Skip to content

Commit 40929f0

Browse files
pr feedback 2
1 parent 82739a7 commit 40929f0

6 files changed

Lines changed: 84 additions & 30 deletions

File tree

sqlglot/generator.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -620,6 +620,9 @@ class Generator:
620620

621621
UNSUPPORTED_TYPES: t.ClassVar[set[exp.DType]] = set()
622622

623+
TYPE_DEFAULT_PARAMS: t.ClassVar[dict[exp.DType, tuple[int, ...]]] = {}
624+
TYPE_PARAM_BOUNDS: t.ClassVar[dict[exp.DType, tuple[int | None, ...]]] = {}
625+
623626
TIME_PART_SINGULARS: t.ClassVar = {
624627
"MICROSECONDS": "MICROSECOND",
625628
"SECONDS": "SECOND",
@@ -1629,11 +1632,64 @@ def datatypeparam_sql(self, expression: exp.DataTypeParam) -> str:
16291632
specifier = f" {specifier}" if specifier and self.DATA_TYPE_SPECIFIERS_ALLOWED else ""
16301633
return f"{this}{specifier}"
16311634

1635+
def datatype_param_bound_limiter(
1636+
self, expression: exp.DataType, type_value: exp.DType
1637+
) -> exp.DataType:
1638+
params = expression.expressions
1639+
1640+
if not params:
1641+
if defaults := self.TYPE_DEFAULT_PARAMS.get(type_value):
1642+
expression = expression.copy()
1643+
expression.set(
1644+
"expressions",
1645+
[exp.DataTypeParam(this=exp.Literal.number(d)) for d in defaults],
1646+
)
1647+
return expression
1648+
1649+
bounds = self.TYPE_PARAM_BOUNDS.get(type_value)
1650+
if not bounds:
1651+
return expression
1652+
1653+
new_params = list(params)
1654+
changed = False
1655+
for i, param in enumerate(params):
1656+
if i >= len(bounds) or bounds[i] is None:
1657+
continue
1658+
1659+
param_value = param.this if isinstance(param, exp.DataTypeParam) else param
1660+
if (
1661+
isinstance(param_value, exp.Literal)
1662+
and param_value.is_number
1663+
and int(param_value.to_py()) > bounds[i]
1664+
):
1665+
self.unsupported(
1666+
f"{type_value.value} parameter ({int(param_value.to_py())}) "
1667+
f"exceeds {self.dialect.__class__.__name__}'s maximum capping to {bounds[i]}"
1668+
)
1669+
new_param = param.copy()
1670+
capped = exp.Literal.number(bounds[i])
1671+
if isinstance(new_param, exp.DataTypeParam):
1672+
new_param.set("this", capped)
1673+
else:
1674+
new_param = capped
1675+
new_params[i] = new_param
1676+
changed = True
1677+
1678+
if changed:
1679+
expression = expression.copy()
1680+
expression.set("expressions", new_params)
1681+
return expression
1682+
16321683
def datatype_sql(self, expression: exp.DataType) -> str:
16331684
nested = ""
16341685
values = ""
16351686

16361687
expr_nested = expression.args.get("nested")
1688+
type_value = expression.this
1689+
1690+
if not expr_nested and isinstance(type_value, exp.DType):
1691+
expression = self.datatype_param_bound_limiter(expression, type_value)
1692+
16371693
interior = (
16381694
self.expressions(
16391695
expression, dynamic=True, new_line=True, skip_first=True, skip_last=True
@@ -1642,7 +1698,6 @@ def datatype_sql(self, expression: exp.DataType) -> str:
16421698
else self.expressions(expression, flat=True)
16431699
)
16441700

1645-
type_value = expression.this
16461701
if type_value in self.UNSUPPORTED_TYPES:
16471702
self.unsupported(
16481703
f"Data type {type_value.value} is not supported when targeting {self.dialect.__class__.__name__}"

sqlglot/generators/duckdb.py

Lines changed: 13 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -655,26 +655,6 @@ def _struct_sql(self: DuckDBGenerator, expression: exp.Struct) -> str:
655655
return f"ROW({csv_args})" if is_bq_inline_struct else f"{{{csv_args}}}"
656656

657657

658-
def _big_decimal_sql(self: DuckDBGenerator, expression: exp.DataType) -> str:
659-
if params := expression.expressions:
660-
precision_param = params[0].this if isinstance(params[0], exp.DataTypeParam) else params[0]
661-
if (
662-
isinstance(precision_param, exp.Literal)
663-
and precision_param.is_number
664-
and int(precision_param.to_py()) > 38
665-
):
666-
self.unsupported(
667-
"DECIMAL precision exceeds maximum of 38 for DuckDB, capping to (38,5)"
668-
)
669-
expression = expression.copy()
670-
expression.expressions[0].set("this", exp.Literal.number(38))
671-
expression.expressions[1].set("this", exp.Literal.number(5))
672-
673-
return f"DECIMAL({self.expressions(expression, flat=True)})"
674-
675-
return "DECIMAL(38, 5)"
676-
677-
678658
def _datatype_sql(self: DuckDBGenerator, expression: exp.DataType) -> str:
679659
if expression.is_type("array"):
680660
return f"{self.expressions(expression, flat=True)}[{self.expressions(expression, key='values', flat=True)}]"
@@ -683,9 +663,6 @@ def _datatype_sql(self: DuckDBGenerator, expression: exp.DataType) -> str:
683663
if expression.is_type(exp.DType.TIME, exp.DType.TIMETZ, exp.DType.TIMESTAMPTZ):
684664
return expression.this.value
685665

686-
if expression.is_type(exp.DType.BIGDECIMAL):
687-
return _big_decimal_sql(self, expression)
688-
689666
return self.datatype_sql(expression)
690667

691668

@@ -1750,7 +1727,7 @@ class DuckDBGenerator(generator.Generator):
17501727
exp.DType.BPCHAR: "TEXT",
17511728
exp.DType.CHAR: "TEXT",
17521729
exp.DType.DATETIME: "TIMESTAMP",
1753-
exp.DType.DECFLOAT: "DECIMAL(38, 5)",
1730+
exp.DType.DECFLOAT: "DECIMAL",
17541731
exp.DType.FLOAT: "REAL",
17551732
exp.DType.JSONB: "JSON",
17561733
exp.DType.NCHAR: "TEXT",
@@ -1767,6 +1744,18 @@ class DuckDBGenerator(generator.Generator):
17671744
exp.DType.BIGDECIMAL: "DECIMAL",
17681745
}
17691746

1747+
TYPE_DEFAULT_PARAMS = {
1748+
**generator.Generator.TYPE_DEFAULT_PARAMS,
1749+
exp.DType.BIGDECIMAL: (38, 5),
1750+
exp.DType.DECFLOAT: (38, 5),
1751+
}
1752+
1753+
TYPE_PARAM_BOUNDS = {
1754+
**generator.Generator.TYPE_PARAM_BOUNDS,
1755+
exp.DType.BIGDECIMAL: (38, 38),
1756+
exp.DType.DECFLOAT: (38, 38),
1757+
}
1758+
17701759
# https://github.com/duckdb/duckdb/blob/ff7f24fd8e3128d94371827523dae85ebaf58713/third_party/libpg_query/grammar/keywords/reserved_keywords.list#L1-L77
17711760
RESERVED_KEYWORDS = {
17721761
"array",

sqlglot/generators/singlestore.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,12 @@ class SingleStoreGenerator(MySQLGenerator):
361361
exp.DType.JSONB: "BSON",
362362
exp.DType.TIMESTAMP: "TIMESTAMP",
363363
exp.DType.TIMESTAMP_S: "TIMESTAMP",
364-
exp.DType.TIMESTAMP_MS: "TIMESTAMP(6)",
364+
exp.DType.TIMESTAMP_MS: "TIMESTAMP",
365+
}
366+
367+
TYPE_DEFAULT_PARAMS = {
368+
**MySQLGenerator.TYPE_DEFAULT_PARAMS,
369+
exp.DType.TIMESTAMP_MS: (6,),
365370
}
366371

367372
# https://docs.singlestore.com/cloud/reference/sql-reference/restricted-keywords/list-of-restricted-keywords/

sqlglot/generators/spark.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,13 +76,19 @@ class SparkGenerator(Spark2Generator):
7676

7777
TYPE_MAPPING = {
7878
**Spark2Generator.TYPE_MAPPING,
79-
exp.DType.MONEY: "DECIMAL(15, 4)",
80-
exp.DType.SMALLMONEY: "DECIMAL(6, 4)",
79+
exp.DType.MONEY: "DECIMAL",
80+
exp.DType.SMALLMONEY: "DECIMAL",
8181
exp.DType.UUID: "STRING",
8282
exp.DType.TIMESTAMPLTZ: "TIMESTAMP_LTZ",
8383
exp.DType.TIMESTAMPNTZ: "TIMESTAMP_NTZ",
8484
}
8585

86+
TYPE_DEFAULT_PARAMS = {
87+
**Spark2Generator.TYPE_DEFAULT_PARAMS,
88+
exp.DType.MONEY: (15, 4),
89+
exp.DType.SMALLMONEY: (6, 4),
90+
}
91+
8692
TRANSFORMS = {
8793
k: v
8894
for k, v in {

tests/dialects/test_bigquery.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3836,7 +3836,7 @@ def test_bignumeric(self):
38363836
f"DECLARE x {type_}(76, 38)",
38373837
write={
38383838
"bigquery": "DECLARE x BIGNUMERIC(76, 38)",
3839-
"duckdb": "DECLARE x DECIMAL(38, 5)",
3839+
"duckdb": "DECLARE x DECIMAL(38, 38)",
38403840
},
38413841
)
38423842

tests/test_expressions.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1342,7 +1342,6 @@ def test_update_positions_empty_meta(self):
13421342
assert expr1.meta == {}
13431343

13441344
def test_pipe_and_apply(self) -> None:
1345-
13461345
def add_val(expr: exp.Expr, val: int, *, squared: bool) -> exp.Expr:
13471346
nb = val**2 if squared else val
13481347
return expr + nb

0 commit comments

Comments
 (0)