Skip to content

Commit a20ec0e

Browse files
authored
fix(optimizer): unpivot annotate types (#7543)
* fix(optimizer): unpivot annotate types for bq * fix varchar * fix * simplify * refactor * remove copy * revert copy * refactor impl * varchar default * more robust impl * remove typing * simplify * correlated subquery case * simplify logic * reverse types * feedback 1 george (partial) * feedback 1 (george - full resolved) * refactor * ref * ref 2 * remove dup tests * doc str update
1 parent 1d976e6 commit a20ec0e

4 files changed

Lines changed: 115 additions & 48 deletions

File tree

sqlglot-integration-tests

sqlglot/expressions/query.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from sqlglot.helper import trait, ensure_list
99
from sqlglot.expressions.core import (
1010
Aliases,
11+
Column,
1112
Condition,
1213
Distinct,
1314
Dot,
@@ -1762,6 +1763,60 @@ def unpivot(self) -> bool:
17621763
def fields(self) -> list[Expr]:
17631764
return self.args.get("fields", [])
17641765

1766+
def output_columns(self, pre_pivot_columns: t.Iterable[str]) -> list[str]:
    """
    List the column names visible after this (UN)PIVOT is applied, in output order.

    Source columns consumed by the operator are dropped; the columns it generates
    are appended after the surviving source columns.

    Example:
        >>> from sqlglot import parse_one, exp
        >>> piv = parse_one("SELECT * FROM t UNPIVOT(val FOR name IN (a, b))").find(exp.Pivot)
        >>> piv.output_columns(["a", "b", "c"])
        ['c', 'name', 'val']

    AST shape:
        PIVOT(SUM(val) FOR name IN ('a', 'b')):
            expressions: aggregate(s), e.g. [Sum(this=Column(val))]
            fields:      [In(this=Column(name), expressions=[Literal('a'), Literal('b')])]
            columns:     optional explicit output identifiers (e.g. set by Snowflake)

        UNPIVOT(val FOR name IN (a, b)):
            expressions: value Identifier(s), or Tuple(Identifiers) for multi-value
            fields:      [In(this=Identifier(name), expressions=[Column(a), Column(b)])]
                         For literal-aliased entries (`a AS 'x'`) the IN expressions
                         are wrapped in PivotAlias(this=Column, alias=Literal).

    Args:
        pre_pivot_columns: Columns visible to the operator before it runs
            (e.g. the source table or subquery's projections).

    Returns:
        The surviving source columns followed by the generated columns, or an
        empty list when the operator consumes no column or generates none.
    """
    if self.unpivot:
        in_fields = [f for f in self.fields if isinstance(f, In)]

        # Every column referenced inside an IN (...) list is folded into the
        # name/value columns and therefore disappears from the output.
        consumed: set[str] = {
            col.output_name
            for in_field in in_fields
            for entry in in_field.expressions
            for col in entry.find_all(Column)
        }

        # Name column(s) come first, followed by the value column(s).
        produced = [f.this.name for f in in_fields if isinstance(f.this, Identifier)]
        for value in self.expressions:
            members = value.expressions if isinstance(value, Tuple) else [value]
            produced.extend(m.name for m in members if isinstance(m, Identifier))
    else:
        consumed = {col.output_name for col in self.find_all(Column)}

        # Explicit output identifiers win; otherwise fall back to the aggregates' names.
        produced = [col.output_name for col in self.args.get("columns") or []] or [
            agg.alias_or_name for agg in self.expressions
        ]

    if not (consumed and produced):
        return []

    survivors = [name for name in pre_pivot_columns if name not in consumed]
    return survivors + produced
1819+
17651820

17661821
class UnpivotColumns(Expression):
17671822
arg_types = {"this": True, "expressions": True}

sqlglot/optimizer/annotate_types.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,10 @@ def _get_scope_selects(self, scope: Scope) -> dict[str, dict[str, t.Any]]:
330330
elif isinstance(expression, exp.Selectable):
331331
selects[name] = {s.alias_or_name: s.type for s in expression.selects if s.type}
332332

333+
for pivot in scope.pivots:
334+
if pivot.unpivot and (col_types := self._get_unpivot_column_types(pivot, selects)):
335+
selects[pivot.alias] = col_types
336+
333337
self._scope_selects[scope] = selects
334338

335339
return self._scope_selects[scope]
@@ -419,7 +423,15 @@ def _annotate_expression(
419423
source_scope = source_scope.parent
420424

421425
if isinstance(source, exp.Table):
422-
self._set_type(expr, self.schema.get_column_type(source, expr))
426+
table_col_type = self.schema.get_column_type(source, expr)
427+
if table_col_type.is_type(exp.DType.UNKNOWN) and source.args.get("pivots"):
428+
table_col_type = (
429+
self._get_scope_selects(source_scope or scope)
430+
.get(expr.table, {})
431+
.get(expr.name)
432+
or table_col_type
433+
)
434+
self._set_type(expr, table_col_type)
423435
elif source and source_scope:
424436
col_type = (
425437
self._get_scope_selects(source_scope).get(expr.table, {}).get(expr.name)
@@ -430,6 +442,16 @@ def _annotate_expression(
430442
self._set_type(expr, source.expression.type)
431443
else:
432444
self._set_type(expr, exp.DType.UNKNOWN)
445+
elif (
446+
not source
447+
and scope.pivots
448+
and (
449+
pivot_type := self._get_scope_selects(scope)
450+
.get(expr.table, {})
451+
.get(expr.name)
452+
)
453+
):
454+
self._set_type(expr, pivot_type)
433455
else:
434456
self._set_type(expr, exp.DType.UNKNOWN)
435457

@@ -591,6 +613,39 @@ def _get_setop_column_types(
591613
self._setop_column_types[setop_id] = col_types
592614
return col_types
593615

616+
def _get_unpivot_column_types(
    self, pivot: exp.Pivot, selects: dict[str, dict[str, t.Any]]
) -> dict[str, t.Any]:
    """
    Infer the types of the columns exposed by an UNPIVOT.

    The name column is typed from the alias of the first IN entry when that entry
    is a PivotAlias with an alias, and as VARCHAR otherwise. Each value column is
    typed from the column at the same position in the first IN entry, falling back
    to the type recorded for that column in the pivot's source.

    Args:
        pivot: The UNPIVOT node whose output columns should be typed.
        selects: Mapping of source name to a mapping of column name to type; the
            entry keyed by the pivot parent's `alias_or_name` supplies the
            pre-pivot column types.

    Returns:
        Mapping of output column name to type, restricted to the columns reported
        by `pivot.output_columns` for which a type could be resolved.
    """
    parent = pivot.parent
    src_types = selects.get(parent.alias_or_name, {}) if parent else {}
    new_types: dict[str, t.Any] = {}

    # The value column(s) do not depend on the FOR ... IN field, so resolve them once.
    val_expr = seq_get(pivot.expressions, 0)
    if val_expr is None:
        val_cols = []
    elif isinstance(val_expr, exp.Tuple):
        val_cols = val_expr.expressions
    else:
        val_cols = [val_expr]

    for field in pivot.fields:
        # Mirror Pivot.output_columns: only well-formed `<name> IN (...)` fields count.
        if not isinstance(field, exp.In):
            continue

        field_col = field.this
        first = seq_get(field.expressions, 0)
        if field_col is None or first is None:
            continue

        if isinstance(first, exp.PivotAlias) and (alias_node := first.args.get("alias")):
            # `a AS 'x'`: the name column takes the type of the alias literal.
            new_types[field_col.name] = alias_node.type
            in_src = first.this
        else:
            new_types[field_col.name] = exp.DType.VARCHAR.into_expr()
            in_src = first

        if in_src is None:
            continue

        in_cols = in_src.expressions if isinstance(in_src, exp.Tuple) else [in_src]
        for val_col, in_col in zip(val_cols, in_cols):
            in_type = in_col.type
            # No usable type on the IN column (unset or UNKNOWN): use the source's type.
            if in_type is None or in_col.is_type(exp.DType.UNKNOWN):
                in_type = src_types.get(in_col.output_name)
            new_types[val_col.output_name] = in_type

    return {
        name: type_
        for name in pivot.output_columns(src_types)
        if (type_ := new_types.get(name) or src_types.get(name))
    }
648+
594649
def _annotate_binary(self, expression: B) -> B:
595650
left, right = expression.left, expression.right
596651
if not left or not right:

sqlglot/optimizer/qualify_columns.py

Lines changed: 3 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
if t.TYPE_CHECKING:
1717
from sqlglot._typing import E
18-
from collections.abc import Iterator, Iterable, Sequence
18+
from collections.abc import Iterable
1919

2020

2121
def qualify_columns(
@@ -181,49 +181,6 @@ def _separate_pseudocolumns(scope: Scope, pseudocolumns: set[str]) -> None:
181181
scope.clear_cache()
182182

183183

184-
def _pivot_output_columns(pivot: exp.Pivot, pre_pivot_columns: Sequence[str]) -> list[str]:
    """Derive the column names available after a (UN)PIVOT from its pre-pivot source columns.

    A degenerate pivot (nothing consumed by the operator, or no output names) yields an
    empty list, which lets callers fall through to their non-pivot handling.
    """
    consumed: set[str] = set()

    if pivot.unpivot:
        # Only the columns listed inside `FOR ... IN (...)` are consumed by an UNPIVOT.
        for field in pivot.fields:
            if not isinstance(field, exp.In):
                continue
            for entry in field.expressions:
                consumed.update(col.output_name for col in entry.find_all(exp.Column))
        produced = [ident.name for ident in _unpivot_columns(pivot)]
    else:
        consumed.update(col.output_name for col in pivot.find_all(exp.Column))
        # Explicit output identifiers win; otherwise use the aggregates' names.
        produced = [col.output_name for col in pivot.args.get("columns") or []] or [
            agg.alias_or_name for agg in pivot.expressions
        ]

    if not (consumed and produced):
        return []

    survivors = [name for name in pre_pivot_columns if name not in consumed]
    return survivors + produced
209-
210-
211-
def _unpivot_columns(unpivot: exp.Pivot) -> Iterator[exp.Identifier]:
    """Iterate over the identifiers an UNPIVOT introduces: name column(s), then value column(s)."""
    name_columns = []
    for field in unpivot.fields:
        if isinstance(field, exp.In) and isinstance(field.this, exp.Identifier):
            name_columns.append(field.this)

    def _value_identifiers(value_exprs):
        # A multi-value UNPIVOT wraps its value identifiers in a Tuple.
        for value in value_exprs:
            members = value.expressions if isinstance(value, exp.Tuple) else [value]
            for member in members:
                if isinstance(member, exp.Identifier):
                    yield member

    return itertools.chain(name_columns, _value_identifiers(unpivot.expressions))
225-
226-
227184
def _pop_table_column_aliases(derived_tables: Iterable[exp.Expr]) -> None:
228185
"""
229186
Remove table column aliases.
@@ -638,7 +595,7 @@ def _qualify_columns(
638595
if isinstance(column_source, exp.Table) and (
639596
pivots := column_source.args.get("pivots")
640597
):
641-
source_columns = _pivot_output_columns(pivots[0], source_columns)
598+
source_columns = pivots[0].output_columns(source_columns)
642599
if (
643600
not allow_partial_qualification
644601
and source_columns
@@ -879,7 +836,7 @@ def _expand_stars(
879836
replaced_columns = replace_columns.get(table_id, {})
880837

881838
if pivot:
882-
pivot_columns = pivot.alias_column_names or _pivot_output_columns(pivot, columns)
839+
pivot_columns = pivot.alias_column_names or pivot.output_columns(columns)
883840

884841
if pivot_columns:
885842
new_selections.extend(

0 commit comments

Comments
 (0)