diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 635bdfd7c0..ed3e678e53 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -620,6 +620,9 @@ class Generator: UNSUPPORTED_TYPES: t.ClassVar[set[exp.DType]] = set() + TYPE_DEFAULT_PARAMS: t.ClassVar[dict[exp.DType, tuple[int, ...]]] = {} + TYPE_PARAM_BOUNDS: t.ClassVar[dict[exp.DType, tuple[int | None, ...]]] = {} + TIME_PART_SINGULARS: t.ClassVar = { "MICROSECONDS": "MICROSECOND", "SECONDS": "SECOND", @@ -1629,11 +1632,65 @@ def datatypeparam_sql(self, expression: exp.DataTypeParam) -> str: specifier = f" {specifier}" if specifier and self.DATA_TYPE_SPECIFIERS_ALLOWED else "" return f"{this}{specifier}" + def datatype_param_bound_limiter( + self, expression: exp.DataType, type_value: exp.DType + ) -> exp.DataType: + params = expression.expressions + + if not params: + if defaults := self.TYPE_DEFAULT_PARAMS.get(type_value): + expression = expression.copy() + expression.set( + "expressions", + [exp.DataTypeParam(this=exp.Literal.number(d)) for d in defaults], + ) + return expression + + bounds = self.TYPE_PARAM_BOUNDS.get(type_value) + if not bounds: + return expression + + new_params = list(params) + changed = False + for i, param in enumerate(params): + bound = bounds[i] if i < len(bounds) else None + if bound is None: + continue + + param_value = param.this if isinstance(param, exp.DataTypeParam) else param + if ( + isinstance(param_value, exp.Literal) + and param_value.is_number + and int(param_value.to_py()) > bound + ): + self.unsupported( + f"{type_value.value} parameter ({int(param_value.to_py())}) " + f"exceeds {self.dialect.__class__.__name__}'s maximum capping to {bound}" + ) + new_param = param.copy() + capped = exp.Literal.number(bound) + if isinstance(new_param, exp.DataTypeParam): + new_param.set("this", capped) + else: + new_param = capped + new_params[i] = new_param + changed = True + + if changed: + expression = expression.copy() + expression.set("expressions", new_params) + return expression + def datatype_sql(self, expression: exp.DataType) -> str: nested = "" values = "" expr_nested = expression.args.get("nested") + type_value = expression.this + + if not expr_nested and isinstance(type_value, exp.DType): + expression = self.datatype_param_bound_limiter(expression, type_value) + interior = ( self.expressions( expression, dynamic=True, new_line=True, skip_first=True, skip_last=True @@ -1642,7 +1699,6 @@ def datatype_sql(self, expression: exp.DataType) -> str: else self.expressions(expression, flat=True) ) - type_value = expression.this if type_value in self.UNSUPPORTED_TYPES: self.unsupported( f"Data type {type_value.value} is not supported when targeting {self.dialect.__class__.__name__}" diff --git a/sqlglot/generators/duckdb.py b/sqlglot/generators/duckdb.py index 061a71b314..33d4581171 100644 --- a/sqlglot/generators/duckdb.py +++ b/sqlglot/generators/duckdb.py @@ -1727,7 +1727,7 @@ class DuckDBGenerator(generator.Generator): exp.DType.BPCHAR: "TEXT", exp.DType.CHAR: "TEXT", exp.DType.DATETIME: "TIMESTAMP", - exp.DType.DECFLOAT: "DECIMAL(38, 5)", + exp.DType.DECFLOAT: "DECIMAL", exp.DType.FLOAT: "REAL", exp.DType.JSONB: "JSON", exp.DType.NCHAR: "TEXT", @@ -1741,7 +1741,19 @@ class DuckDBGenerator(generator.Generator): exp.DType.TIMESTAMP_S: "TIMESTAMP_S", exp.DType.TIMESTAMP_MS: "TIMESTAMP_MS", exp.DType.TIMESTAMP_NS: "TIMESTAMP_NS", - exp.DType.BIGDECIMAL: "DECIMAL(38, 5)", + exp.DType.BIGDECIMAL: "DECIMAL", + } + + TYPE_DEFAULT_PARAMS = { + **generator.Generator.TYPE_DEFAULT_PARAMS, + exp.DType.BIGDECIMAL: (38, 5), + exp.DType.DECFLOAT: (38, 5), + } + + TYPE_PARAM_BOUNDS = { + **generator.Generator.TYPE_PARAM_BOUNDS, + exp.DType.BIGDECIMAL: (38, 38), + exp.DType.DECFLOAT: (38, 38), } # https://github.com/duckdb/duckdb/blob/ff7f24fd8e3128d94371827523dae85ebaf58713/third_party/libpg_query/grammar/keywords/reserved_keywords.list#L1-L77 diff --git a/sqlglot/generators/singlestore.py b/sqlglot/generators/singlestore.py index f2d8303e42..c761274749 100644 --- a/sqlglot/generators/singlestore.py +++ b/sqlglot/generators/singlestore.py @@ -361,7 +361,12 @@ class SingleStoreGenerator(MySQLGenerator): exp.DType.JSONB: "BSON", exp.DType.TIMESTAMP: "TIMESTAMP", exp.DType.TIMESTAMP_S: "TIMESTAMP", - exp.DType.TIMESTAMP_MS: "TIMESTAMP(6)", + exp.DType.TIMESTAMP_MS: "TIMESTAMP", + } + + TYPE_DEFAULT_PARAMS = { + **MySQLGenerator.TYPE_DEFAULT_PARAMS, + exp.DType.TIMESTAMP_MS: (6,), } # https://docs.singlestore.com/cloud/reference/sql-reference/restricted-keywords/list-of-restricted-keywords/ diff --git a/sqlglot/generators/spark.py b/sqlglot/generators/spark.py index d989f7ece9..0d8556f279 100644 --- a/sqlglot/generators/spark.py +++ b/sqlglot/generators/spark.py @@ -76,13 +76,19 @@ class SparkGenerator(Spark2Generator): TYPE_MAPPING = { **Spark2Generator.TYPE_MAPPING, - exp.DType.MONEY: "DECIMAL(15, 4)", - exp.DType.SMALLMONEY: "DECIMAL(6, 4)", + exp.DType.MONEY: "DECIMAL", + exp.DType.SMALLMONEY: "DECIMAL", exp.DType.UUID: "STRING", exp.DType.TIMESTAMPLTZ: "TIMESTAMP_LTZ", exp.DType.TIMESTAMPNTZ: "TIMESTAMP_NTZ", } + TYPE_DEFAULT_PARAMS = { + **Spark2Generator.TYPE_DEFAULT_PARAMS, + exp.DType.MONEY: (15, 4), + exp.DType.SMALLMONEY: (6, 4), + } + TRANSFORMS = { k: v for k, v in { diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index 798b3a5fa8..fabf9a6cc0 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -3824,6 +3824,22 @@ def test_bignumeric(self): }, ) + self.validate_all( + f"DECLARE x {type_}(20, 4)", + write={ + "bigquery": "DECLARE x BIGNUMERIC(20, 4)", + "duckdb": "DECLARE x DECIMAL(20, 4)", + }, + ) + + self.validate_all( + f"DECLARE x {type_}(76, 38)", + write={ + "bigquery": "DECLARE x BIGNUMERIC(76, 38)", + "duckdb": "DECLARE x DECIMAL(38, 38)", + }, + ) + self.validate_all( f"SELECT CAST(1 AS {type_})", write={ diff --git a/tests/test_expressions.py b/tests/test_expressions.py index 8c7b3cf571..a6be9cdb37 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -1342,7 +1342,6 @@ def test_update_positions_empty_meta(self): assert expr1.meta == {} def test_pipe_and_apply(self) -> None: - def add_val(expr: exp.Expr, val: int, *, squared: bool) -> exp.Expr: nb = val**2 if squared else val return expr + nb