Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 57 additions & 1 deletion sqlglot/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,6 +620,9 @@ class Generator:

UNSUPPORTED_TYPES: t.ClassVar[set[exp.DType]] = set()

TYPE_DEFAULT_PARAMS: t.ClassVar[dict[exp.DType, tuple[int, ...]]] = {}
TYPE_PARAM_BOUNDS: t.ClassVar[dict[exp.DType, tuple[int | None, ...]]] = {}

TIME_PART_SINGULARS: t.ClassVar = {
"MICROSECONDS": "MICROSECOND",
"SECONDS": "SECOND",
Expand Down Expand Up @@ -1629,11 +1632,65 @@ def datatypeparam_sql(self, expression: exp.DataTypeParam) -> str:
specifier = f" {specifier}" if specifier and self.DATA_TYPE_SPECIFIERS_ALLOWED else ""
return f"{this}{specifier}"

def datatype_param_bound_limiter(
self, expression: exp.DataType, type_value: exp.DType
) -> exp.DataType:
params = expression.expressions

if not params:
if defaults := self.TYPE_DEFAULT_PARAMS.get(type_value):
expression = expression.copy()
expression.set(
"expressions",
[exp.DataTypeParam(this=exp.Literal.number(d)) for d in defaults],
)
return expression

bounds = self.TYPE_PARAM_BOUNDS.get(type_value)
if not bounds:
return expression

new_params = list(params)
changed = False
for i, param in enumerate(params):
bound = bounds[i] if i < len(bounds) else None
if bound is None:
continue

param_value = param.this if isinstance(param, exp.DataTypeParam) else param
if (
isinstance(param_value, exp.Literal)
and param_value.is_number
and int(param_value.to_py()) > bound
):
self.unsupported(
f"{type_value.value} parameter ({int(param_value.to_py())}) "
f"exceeds {self.dialect.__class__.__name__}'s maximum capping to {bound}"
)
new_param = param.copy()
capped = exp.Literal.number(bound)
if isinstance(new_param, exp.DataTypeParam):
new_param.set("this", capped)
else:
new_param = capped
new_params[i] = new_param
changed = True

if changed:
expression = expression.copy()
expression.set("expressions", new_params)
return expression

def datatype_sql(self, expression: exp.DataType) -> str:
nested = ""
values = ""

expr_nested = expression.args.get("nested")
type_value = expression.this

if not expr_nested and isinstance(type_value, exp.DType):
expression = self.datatype_param_bound_limiter(expression, type_value)

interior = (
self.expressions(
expression, dynamic=True, new_line=True, skip_first=True, skip_last=True
Expand All @@ -1642,7 +1699,6 @@ def datatype_sql(self, expression: exp.DataType) -> str:
else self.expressions(expression, flat=True)
)

type_value = expression.this
if type_value in self.UNSUPPORTED_TYPES:
self.unsupported(
f"Data type {type_value.value} is not supported when targeting {self.dialect.__class__.__name__}"
Expand Down
16 changes: 14 additions & 2 deletions sqlglot/generators/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -1727,7 +1727,7 @@ class DuckDBGenerator(generator.Generator):
exp.DType.BPCHAR: "TEXT",
exp.DType.CHAR: "TEXT",
exp.DType.DATETIME: "TIMESTAMP",
exp.DType.DECFLOAT: "DECIMAL(38, 5)",
exp.DType.DECFLOAT: "DECIMAL",
exp.DType.FLOAT: "REAL",
exp.DType.JSONB: "JSON",
exp.DType.NCHAR: "TEXT",
Expand All @@ -1741,7 +1741,19 @@ class DuckDBGenerator(generator.Generator):
exp.DType.TIMESTAMP_S: "TIMESTAMP_S",
exp.DType.TIMESTAMP_MS: "TIMESTAMP_MS",
exp.DType.TIMESTAMP_NS: "TIMESTAMP_NS",
exp.DType.BIGDECIMAL: "DECIMAL(38, 5)",
exp.DType.BIGDECIMAL: "DECIMAL",
}

TYPE_DEFAULT_PARAMS = {
**generator.Generator.TYPE_DEFAULT_PARAMS,
exp.DType.BIGDECIMAL: (38, 5),
exp.DType.DECFLOAT: (38, 5),
}

TYPE_PARAM_BOUNDS = {
**generator.Generator.TYPE_PARAM_BOUNDS,
exp.DType.BIGDECIMAL: (38, 38),
exp.DType.DECFLOAT: (38, 38),
}

# https://github.com/duckdb/duckdb/blob/ff7f24fd8e3128d94371827523dae85ebaf58713/third_party/libpg_query/grammar/keywords/reserved_keywords.list#L1-L77
Expand Down
7 changes: 6 additions & 1 deletion sqlglot/generators/singlestore.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,12 @@ class SingleStoreGenerator(MySQLGenerator):
exp.DType.JSONB: "BSON",
exp.DType.TIMESTAMP: "TIMESTAMP",
exp.DType.TIMESTAMP_S: "TIMESTAMP",
exp.DType.TIMESTAMP_MS: "TIMESTAMP(6)",
exp.DType.TIMESTAMP_MS: "TIMESTAMP",
}

TYPE_DEFAULT_PARAMS = {
**MySQLGenerator.TYPE_DEFAULT_PARAMS,
exp.DType.TIMESTAMP_MS: (6,),
}

# https://docs.singlestore.com/cloud/reference/sql-reference/restricted-keywords/list-of-restricted-keywords/
Expand Down
10 changes: 8 additions & 2 deletions sqlglot/generators/spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,19 @@ class SparkGenerator(Spark2Generator):

TYPE_MAPPING = {
**Spark2Generator.TYPE_MAPPING,
exp.DType.MONEY: "DECIMAL(15, 4)",
exp.DType.SMALLMONEY: "DECIMAL(6, 4)",
exp.DType.MONEY: "DECIMAL",
exp.DType.SMALLMONEY: "DECIMAL",
exp.DType.UUID: "STRING",
exp.DType.TIMESTAMPLTZ: "TIMESTAMP_LTZ",
exp.DType.TIMESTAMPNTZ: "TIMESTAMP_NTZ",
}

TYPE_DEFAULT_PARAMS = {
**Spark2Generator.TYPE_DEFAULT_PARAMS,
exp.DType.MONEY: (15, 4),
exp.DType.SMALLMONEY: (6, 4),
}

TRANSFORMS = {
k: v
for k, v in {
Expand Down
16 changes: 16 additions & 0 deletions tests/dialects/test_bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -3824,6 +3824,22 @@ def test_bignumeric(self):
},
)

self.validate_all(
f"DECLARE x {type_}(20, 4)",
write={
"bigquery": "DECLARE x BIGNUMERIC(20, 4)",
"duckdb": "DECLARE x DECIMAL(20, 4)",
},
)

self.validate_all(
f"DECLARE x {type_}(76, 38)",
write={
"bigquery": "DECLARE x BIGNUMERIC(76, 38)",
"duckdb": "DECLARE x DECIMAL(38, 38)",
},
)

self.validate_all(
f"SELECT CAST(1 AS {type_})",
write={
Expand Down
1 change: 0 additions & 1 deletion tests/test_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1342,7 +1342,6 @@ def test_update_positions_empty_meta(self):
assert expr1.meta == {}

def test_pipe_and_apply(self) -> None:

def add_val(expr: exp.Expr, val: int, *, squared: bool) -> exp.Expr:
nb = val**2 if squared else val
return expr + nb
Expand Down
Loading