Skip to content

Commit 40c53cd

Browse files
Fix: Parametrized bigdecimal mapping was being concatenated with params
1 parent 9f169ab commit 40c53cd

6 files changed

Lines changed: 101 additions & 7 deletions

File tree

sqlglot/generator.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -620,6 +620,9 @@ class Generator:
620620

621621
UNSUPPORTED_TYPES: t.ClassVar[set[exp.DType]] = set()
622622

623+
TYPE_DEFAULT_PARAMS: t.ClassVar[dict[exp.DType, tuple[int, ...]]] = {}
624+
TYPE_PARAM_BOUNDS: t.ClassVar[dict[exp.DType, tuple[int | None, ...]]] = {}
625+
623626
TIME_PART_SINGULARS: t.ClassVar = {
624627
"MICROSECONDS": "MICROSECOND",
625628
"SECONDS": "SECOND",
@@ -1629,11 +1632,65 @@ def datatypeparam_sql(self, expression: exp.DataTypeParam) -> str:
16291632
specifier = f" {specifier}" if specifier and self.DATA_TYPE_SPECIFIERS_ALLOWED else ""
16301633
return f"{this}{specifier}"
16311634

1635+
def datatype_param_bound_limiter(
1636+
self, expression: exp.DataType, type_value: exp.DType
1637+
) -> exp.DataType:
1638+
params = expression.expressions
1639+
1640+
if not params:
1641+
if defaults := self.TYPE_DEFAULT_PARAMS.get(type_value):
1642+
expression = expression.copy()
1643+
expression.set(
1644+
"expressions",
1645+
[exp.DataTypeParam(this=exp.Literal.number(d)) for d in defaults],
1646+
)
1647+
return expression
1648+
1649+
bounds = self.TYPE_PARAM_BOUNDS.get(type_value)
1650+
if not bounds:
1651+
return expression
1652+
1653+
new_params = list(params)
1654+
changed = False
1655+
for i, param in enumerate(params):
1656+
bound = bounds[i] if i < len(bounds) else None
1657+
if bound is None:
1658+
continue
1659+
1660+
param_value = param.this if isinstance(param, exp.DataTypeParam) else param
1661+
if (
1662+
isinstance(param_value, exp.Literal)
1663+
and param_value.is_number
1664+
and int(param_value.to_py()) > bound
1665+
):
1666+
self.unsupported(
1667+
f"{type_value.value} parameter ({int(param_value.to_py())}) "
1668+
f"exceeds {self.dialect.__class__.__name__}'s maximum capping to {bound}"
1669+
)
1670+
new_param = param.copy()
1671+
capped = exp.Literal.number(bound)
1672+
if isinstance(new_param, exp.DataTypeParam):
1673+
new_param.set("this", capped)
1674+
else:
1675+
new_param = capped
1676+
new_params[i] = new_param
1677+
changed = True
1678+
1679+
if changed:
1680+
expression = expression.copy()
1681+
expression.set("expressions", new_params)
1682+
return expression
1683+
16321684
def datatype_sql(self, expression: exp.DataType) -> str:
16331685
nested = ""
16341686
values = ""
16351687

16361688
expr_nested = expression.args.get("nested")
1689+
type_value = expression.this
1690+
1691+
if not expr_nested and isinstance(type_value, exp.DType):
1692+
expression = self.datatype_param_bound_limiter(expression, type_value)
1693+
16371694
interior = (
16381695
self.expressions(
16391696
expression, dynamic=True, new_line=True, skip_first=True, skip_last=True
@@ -1642,7 +1699,6 @@ def datatype_sql(self, expression: exp.DataType) -> str:
16421699
else self.expressions(expression, flat=True)
16431700
)
16441701

1645-
type_value = expression.this
16461702
if type_value in self.UNSUPPORTED_TYPES:
16471703
self.unsupported(
16481704
f"Data type {type_value.value} is not supported when targeting {self.dialect.__class__.__name__}"

sqlglot/generators/duckdb.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1727,7 +1727,7 @@ class DuckDBGenerator(generator.Generator):
17271727
exp.DType.BPCHAR: "TEXT",
17281728
exp.DType.CHAR: "TEXT",
17291729
exp.DType.DATETIME: "TIMESTAMP",
1730-
exp.DType.DECFLOAT: "DECIMAL(38, 5)",
1730+
exp.DType.DECFLOAT: "DECIMAL",
17311731
exp.DType.FLOAT: "REAL",
17321732
exp.DType.JSONB: "JSON",
17331733
exp.DType.NCHAR: "TEXT",
@@ -1741,7 +1741,19 @@ class DuckDBGenerator(generator.Generator):
17411741
exp.DType.TIMESTAMP_S: "TIMESTAMP_S",
17421742
exp.DType.TIMESTAMP_MS: "TIMESTAMP_MS",
17431743
exp.DType.TIMESTAMP_NS: "TIMESTAMP_NS",
1744-
exp.DType.BIGDECIMAL: "DECIMAL(38, 5)",
1744+
exp.DType.BIGDECIMAL: "DECIMAL",
1745+
}
1746+
1747+
TYPE_DEFAULT_PARAMS = {
1748+
**generator.Generator.TYPE_DEFAULT_PARAMS,
1749+
exp.DType.BIGDECIMAL: (38, 5),
1750+
exp.DType.DECFLOAT: (38, 5),
1751+
}
1752+
1753+
TYPE_PARAM_BOUNDS = {
1754+
**generator.Generator.TYPE_PARAM_BOUNDS,
1755+
exp.DType.BIGDECIMAL: (38, 38),
1756+
exp.DType.DECFLOAT: (38, 38),
17451757
}
17461758

17471759
# https://github.com/duckdb/duckdb/blob/ff7f24fd8e3128d94371827523dae85ebaf58713/third_party/libpg_query/grammar/keywords/reserved_keywords.list#L1-L77

sqlglot/generators/singlestore.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,12 @@ class SingleStoreGenerator(MySQLGenerator):
361361
exp.DType.JSONB: "BSON",
362362
exp.DType.TIMESTAMP: "TIMESTAMP",
363363
exp.DType.TIMESTAMP_S: "TIMESTAMP",
364-
exp.DType.TIMESTAMP_MS: "TIMESTAMP(6)",
364+
exp.DType.TIMESTAMP_MS: "TIMESTAMP",
365+
}
366+
367+
TYPE_DEFAULT_PARAMS = {
368+
**MySQLGenerator.TYPE_DEFAULT_PARAMS,
369+
exp.DType.TIMESTAMP_MS: (6,),
365370
}
366371

367372
# https://docs.singlestore.com/cloud/reference/sql-reference/restricted-keywords/list-of-restricted-keywords/

sqlglot/generators/spark.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,13 +76,19 @@ class SparkGenerator(Spark2Generator):
7676

7777
TYPE_MAPPING = {
7878
**Spark2Generator.TYPE_MAPPING,
79-
exp.DType.MONEY: "DECIMAL(15, 4)",
80-
exp.DType.SMALLMONEY: "DECIMAL(6, 4)",
79+
exp.DType.MONEY: "DECIMAL",
80+
exp.DType.SMALLMONEY: "DECIMAL",
8181
exp.DType.UUID: "STRING",
8282
exp.DType.TIMESTAMPLTZ: "TIMESTAMP_LTZ",
8383
exp.DType.TIMESTAMPNTZ: "TIMESTAMP_NTZ",
8484
}
8585

86+
TYPE_DEFAULT_PARAMS = {
87+
**Spark2Generator.TYPE_DEFAULT_PARAMS,
88+
exp.DType.MONEY: (15, 4),
89+
exp.DType.SMALLMONEY: (6, 4),
90+
}
91+
8692
TRANSFORMS = {
8793
k: v
8894
for k, v in {

tests/dialects/test_bigquery.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3824,6 +3824,22 @@ def test_bignumeric(self):
38243824
},
38253825
)
38263826

3827+
self.validate_all(
3828+
f"DECLARE x {type_}(20, 4)",
3829+
write={
3830+
"bigquery": "DECLARE x BIGNUMERIC(20, 4)",
3831+
"duckdb": "DECLARE x DECIMAL(20, 4)",
3832+
},
3833+
)
3834+
3835+
self.validate_all(
3836+
f"DECLARE x {type_}(76, 38)",
3837+
write={
3838+
"bigquery": "DECLARE x BIGNUMERIC(76, 38)",
3839+
"duckdb": "DECLARE x DECIMAL(38, 38)",
3840+
},
3841+
)
3842+
38273843
self.validate_all(
38283844
f"SELECT CAST(1 AS {type_})",
38293845
write={

tests/test_expressions.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1342,7 +1342,6 @@ def test_update_positions_empty_meta(self):
13421342
assert expr1.meta == {}
13431343

13441344
def test_pipe_and_apply(self) -> None:
1345-
13461345
def add_val(expr: exp.Expr, val: int, *, squared: bool) -> exp.Expr:
13471346
nb = val**2 if squared else val
13481347
return expr + nb

0 commit comments

Comments
 (0)