diff --git a/packages/bigframes/bigframes/core/compile/sqlglot/sqlglot_ir.py b/packages/bigframes/bigframes/core/compile/sqlglot/sqlglot_ir.py index 1e0b561e8c5b..b29a23cd84b8 100644 --- a/packages/bigframes/bigframes/core/compile/sqlglot/sqlglot_ir.py +++ b/packages/bigframes/bigframes/core/compile/sqlglot/sqlglot_ir.py @@ -249,12 +249,13 @@ def select( # TODO: Explicitly insert CTEs into plan if len(selections) > 0: to_select = [ - sge.Alias( - this=expr, + expr + if (isinstance(expr, sge.Alias) and expr.alias == id) + or (isinstance(expr, sge.Column) and expr.name == id) + else sge.Alias( + this=expr.this if isinstance(expr, sge.Alias) else expr, alias=sql.identifier(id), ) - if expr.alias_or_name != id - else expr for id, expr in selections ] new_expr = self.expr.select(*to_select) diff --git a/packages/bigframes/bigframes/core/sql_nodes.py b/packages/bigframes/bigframes/core/sql_nodes.py index 4cb4b02f7b80..c7a05a082f29 100644 --- a/packages/bigframes/bigframes/core/sql_nodes.py +++ b/packages/bigframes/bigframes/core/sql_nodes.py @@ -276,7 +276,14 @@ def _node_expressions(self): @property def is_star_selection(self) -> bool: - return tuple(self.ids) == tuple(self.child.ids) + if tuple(self.ids) != tuple(self.child.ids): + return False + for cdef in self.selections: + if not isinstance(cdef.expression, ex.DerefOp): + return False + if cdef.expression.id != cdef.id: + return False + return True @functools.cache def get_id_mapping(self) -> dict[identifiers.ColumnId, ex.Expression]: diff --git a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_datetime_ops/test_to_datetime/out.sql b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_datetime_ops/test_to_datetime/out.sql index 9a2913e44beb..57ec17bf681a 100644 --- a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_datetime_ops/test_to_datetime/out.sql +++ b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_datetime_ops/test_to_datetime/out.sql @@ -1,7 +1,7 @@ SELECT CAST(TIMESTAMP_MICROS(CAST(TRUNC(`int64_col` * 0.001) AS INT64)) AS DATETIME) AS `int64_col`, - SAFE_CAST(`string_col` AS DATETIME), + SAFE_CAST(`string_col` AS DATETIME) AS `string_col`, CAST(TIMESTAMP_MICROS(CAST(TRUNC(`float64_col` * 0.001) AS INT64)) AS DATETIME) AS `float64_col`, - SAFE_CAST(`timestamp_col` AS DATETIME), + SAFE_CAST(`timestamp_col` AS DATETIME) AS `timestamp_col`, SAFE_CAST(`string_col` AS DATETIME) AS `string_col_fmt` -FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0` \ No newline at end of file +FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0` diff --git a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_astype_float/out.sql b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_astype_float/out.sql index 3d48001e77ad..7f7bd86084ea 100644 --- a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_astype_float/out.sql +++ b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_astype_float/out.sql @@ -1,5 +1,5 @@ SELECT - CAST(CAST(`bool_col` AS INT64) AS FLOAT64), + CAST(CAST(`bool_col` AS INT64) AS FLOAT64) AS `bool_col`, CAST('1.34235e4' AS FLOAT64) AS `str_const`, SAFE_CAST(SAFE_CAST(`bool_col` AS INT64) AS FLOAT64) AS `bool_w_safe` -FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0` \ No newline at end of file +FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0` diff --git a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_astype_string/out.sql b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_astype_string/out.sql index 3ea2299cc4f9..174f18d98233 100644 --- a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_astype_string/out.sql +++ b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_astype_string/out.sql @@ -1,5 +1,5 @@ SELECT - CAST(`int64_col` AS STRING), + CAST(`int64_col` AS STRING) AS `int64_col`, INITCAP(CAST(`bool_col` AS STRING)) AS `bool_col`, INITCAP(SAFE_CAST(`bool_col` AS STRING)) AS `bool_w_safe` -FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0` \ No newline at end of file +FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0` diff --git a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/test_datetime_ops.py b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/test_datetime_ops.py index fd3aacc7e271..e86059b160a8 100644 --- a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/test_datetime_ops.py +++ b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/test_datetime_ops.py @@ -217,7 +217,7 @@ def test_to_datetime(scalar_types_df: bpd.DataFrame, snapshot): ) sql = utils._apply_ops_to_sql(bf_df, list(ops_map.values()), list(ops_map.keys())) - snapshot.assert_match(sql, "out.sql") + snapshot.assert_match(sql + "\n", "out.sql") def test_to_timestamp(scalar_types_df: bpd.DataFrame, snapshot): diff --git a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/test_generic_ops.py b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/test_generic_ops.py index fb5a9fd7ce84..185a8df04509 100644 --- a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/test_generic_ops.py +++ b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/test_generic_ops.py @@ -60,7 +60,7 @@ def test_astype_float(scalar_types_df: bpd.DataFrame, snapshot): "bool_w_safe": ops.AsTypeOp(to_type=to_type, safe=True).as_expr("bool_col"), } sql = utils._apply_ops_to_sql(bf_df, list(ops_map.values()), list(ops_map.keys())) - snapshot.assert_match(sql, "out.sql") + snapshot.assert_match(sql + "\n", "out.sql") def test_astype_bool(scalar_types_df: bpd.DataFrame, snapshot): @@ -107,7 +107,7 @@ def test_astype_string(scalar_types_df: bpd.DataFrame, snapshot): "bool_w_safe": ops.AsTypeOp(to_type=to_type, safe=True).as_expr("bool_col"), } sql = utils._apply_ops_to_sql(bf_df, list(ops_map.values()), list(ops_map.keys())) - snapshot.assert_match(sql, "out.sql") + snapshot.assert_match(sql + "\n", "out.sql") def test_astype_json(scalar_types_df: bpd.DataFrame, snapshot): diff --git a/packages/bigframes/tests/unit/core/compile/sqlglot/snapshots/test_compile_readtable/test_compile_astype_aliases/out.sql b/packages/bigframes/tests/unit/core/compile/sqlglot/snapshots/test_compile_readtable/test_compile_astype_aliases/out.sql new file mode 100644 index 000000000000..cd056c650fd3 --- /dev/null +++ b/packages/bigframes/tests/unit/core/compile/sqlglot/snapshots/test_compile_readtable/test_compile_astype_aliases/out.sql @@ -0,0 +1,5 @@ +SELECT + `rowindex`, + CAST(`timestamp_col` AS STRING) AS `timestamp_col`, + CAST(`int64_col` AS FLOAT64) AS `int64_col` +FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0` diff --git a/packages/bigframes/tests/unit/core/compile/sqlglot/test_compile_readtable.py b/packages/bigframes/tests/unit/core/compile/sqlglot/test_compile_readtable.py index ea9875302a93..0f2058f21f68 100644 --- a/packages/bigframes/tests/unit/core/compile/sqlglot/test_compile_readtable.py +++ b/packages/bigframes/tests/unit/core/compile/sqlglot/test_compile_readtable.py @@ -80,3 +80,15 @@ def test_compile_readtable_w_columns_filters(compiler_session, snapshot): filters=filters, ) snapshot.assert_match(bf_df.sql, "out.sql") + + +def test_compile_astype_aliases(scalar_types_df: bpd.DataFrame, snapshot): + # Test case for issue #17394 (CAST columns lose their aliases) + bf_df = scalar_types_df[["timestamp_col", "int64_col"]] + result = bf_df.astype( + { + "timestamp_col": "string[pyarrow]", + "int64_col": "Float64", + } + ) + snapshot.assert_match(result.sql + "\n", "out.sql")