Skip to content

Commit b3fa159

Browse files
OutSquareCapital and Copilot authored
refactor(optimizer): continuation on typing coverage improvements (#7572)
* refactor (optimizer): add type annotations, raise `OptimizeError` if parent is None in `reorder_joins`
* refactor (optimizer): go back to original func body in `_distribute`
* refactor (optimizer): added a few other type annotations
* fix: added `from __future__ import annotations` for <3.10 python compatibility
  Co-authored-by: Copilot <copilot@github.com>
* fix: narrow pushdown and pushdown_cnf `join_index` mapping key type
* fix: narrow `replace_aliases()` return type, as well as the closure argument
  Co-authored-by: Copilot <copilot@github.com>
* fix: delete incorrect `None` return type for `_distribute`, and improve `_predicate_lengths` typing
  Co-authored-by: Copilot <copilot@github.com>
* chore: ruff formatting + linting
* fix: annotate `_is_limit_1` function, modified None check to make it pass
  Co-authored-by: Copilot <copilot@github.com>
* refactor:
  - improve `_unique_outputs` internal and signature typing
  - added a few annotations to `_join_is_used`, `eliminate_joins`, and `_is_limit_1` functions
  Co-authored-by: Copilot <copilot@github.com>
* fix: prefer parent class `Query` instead of `Select` for `_unique_outputs` internal cast
---------
Co-authored-by: Copilot <copilot@github.com>
1 parent 1cf31d6 commit b3fa159

6 files changed

Lines changed: 59 additions & 36 deletions

File tree

sqlglot/optimizer/eliminate_joins.py

Lines changed: 22 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -30,7 +30,7 @@ def eliminate_joins(expression: E) -> E:
3030
The optimized expression
3131
"""
3232
for scope in traverse_scope(expression):
33-
joins = scope.expression.args.get("joins", [])
33+
joins: list[exp.Join] = scope.expression.args.get("joins", [])
3434
if not joins:
3535
continue
3636

@@ -53,7 +53,7 @@ def eliminate_joins(expression: E) -> E:
5353
return expression
5454

5555

56-
def _should_eliminate_join(scope, join, alias):
56+
def _should_eliminate_join(scope: Scope, join: exp.Join, alias: str) -> bool:
5757
inner_source = scope.sources.get(alias)
5858
return (
5959
isinstance(inner_source, Scope)
@@ -65,20 +65,20 @@ def _should_eliminate_join(scope, join, alias):
6565
)
6666

6767

68-
def _join_is_used(scope, join, alias):
68+
def _join_is_used(scope: Scope, join: exp.Join, alias: str) -> bool:
6969
# We need to find all columns that reference this join.
7070
# But columns in the ON clause shouldn't count.
71-
on = join.args.get("on")
72-
if on:
73-
on_clause_columns = {id(column) for column in on.find_all(exp.Column)}
71+
on: exp.Expr | None = join.args.get("on")
72+
if on is not None:
73+
on_clause_columns: set[int] = {id(column) for column in on.find_all(exp.Column)}
7474
else:
7575
on_clause_columns = set()
7676
return any(
7777
column for column in scope.source_columns(alias) if id(column) not in on_clause_columns
7878
)
7979

8080

81-
def _is_joined_on_all_unique_outputs(scope, join):
81+
def _is_joined_on_all_unique_outputs(scope: Scope, join: exp.Join) -> bool:
8282
unique_outputs = _unique_outputs(scope)
8383
if not unique_outputs:
8484
return False
@@ -88,18 +88,19 @@ def _is_joined_on_all_unique_outputs(scope, join):
8888
return not remaining_unique_outputs
8989

9090

91-
def _unique_outputs(scope):
91+
def _unique_outputs(scope: Scope) -> set[str]:
9292
"""Determine output columns of `scope` that must have a unique combination per row"""
93-
if scope.expression.args.get("distinct"):
94-
return set(scope.expression.named_selects)
93+
expr = t.cast(exp.Query, scope.expression)
94+
if expr.args.get("distinct") is not None:
95+
return set(expr.named_selects)
9596

96-
group = scope.expression.args.get("group")
97-
if group:
98-
grouped_expressions = set(group.expressions)
99-
grouped_outputs = set()
97+
group: exp.Group | None = expr.args.get("group")
98+
if group is not None:
99+
grouped_expressions: set[exp.Expr] = set(group.expressions)
100+
grouped_outputs: set[exp.Expr] = set()
100101

101-
unique_outputs = set()
102-
for select in scope.expression.selects:
102+
unique_outputs: set[str] = set()
103+
for select in expr.selects:
103104
output = select.unalias()
104105
if output in grouped_expressions:
105106
grouped_outputs.add(output)
@@ -112,22 +113,22 @@ def _unique_outputs(scope):
112113
return set()
113114

114115
if _has_single_output_row(scope):
115-
return set(scope.expression.named_selects)
116+
return set(expr.named_selects)
116117

117118
return set()
118119

119120

120-
def _has_single_output_row(scope):
121+
def _has_single_output_row(scope: Scope) -> bool:
121122
return isinstance(scope.expression, exp.Select) and (
122123
all(isinstance(e.unalias(), exp.AggFunc) for e in scope.expression.selects)
123124
or _is_limit_1(scope)
124125
or not scope.expression.args.get("from_")
125126
)
126127

127128

128-
def _is_limit_1(scope):
129-
limit = scope.expression.args.get("limit")
130-
return limit and limit.expression.this == "1"
129+
def _is_limit_1(scope: Scope) -> bool:
130+
limit: exp.Limit | None = scope.expression.args.get("limit")
131+
return limit is not None and limit.expression.this == "1"
131132

132133

133134
def join_condition(join):

sqlglot/optimizer/normalize.py

Lines changed: 14 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
from collections.abc import Callable, Iterator
34
import logging
45

56
from sqlglot import exp
@@ -123,7 +124,9 @@ def normalization_distance(
123124
return total
124125

125126

126-
def _predicate_lengths(expression, dnf, max_=float("inf"), depth=0):
127+
def _predicate_lengths(
128+
expression: exp.Expr, dnf: bool, max_: float = float("inf"), depth: int = 0
129+
) -> Iterator[int]:
127130
"""
128131
Returns a list of predicate lengths when expanded to normalized form.
129132
@@ -151,7 +154,9 @@ def _predicate_lengths(expression, dnf, max_=float("inf"), depth=0):
151154
yield from _predicate_lengths(right, dnf, max_, depth)
152155

153156

154-
def distributive_law(expression, dnf, max_distance, simplifier=None):
157+
def distributive_law(
158+
expression: exp.Expr, dnf: bool, max_distance: float, simplifier: Simplifier | None = None
159+
):
155160
"""
156161
x OR (y AND z) -> (x OR y) AND (x OR z)
157162
(x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)
@@ -187,7 +192,13 @@ def distributive_law(expression, dnf, max_distance, simplifier=None):
187192
return expression
188193

189194

190-
def _distribute(a, b, from_func, to_func, simplifier):
195+
def _distribute(
196+
a,
197+
b,
198+
from_func: Callable[..., exp.Condition],
199+
to_func: Callable[..., exp.Condition],
200+
simplifier: Simplifier,
201+
):
191202
if isinstance(a, exp.Connector):
192203
exp.replace_children(
193204
a,

sqlglot/optimizer/optimize_joins.py

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
from collections.abc import Iterable
33

44
from sqlglot import exp
5+
from sqlglot.errors import OptimizeError
56
from sqlglot.helper import tsort
67

78
JOIN_ATTRS = ("on", "side", "kind", "using", "method")
@@ -18,7 +19,7 @@ def optimize_joins(expression: exp.Expr) -> exp.Expr:
1819
"""
1920

2021
for select in expression.find_all(exp.Select):
21-
joins = select.args.get("joins", [])
22+
joins: list[exp.Join] = select.args.get("joins", [])
2223

2324
if not _is_reorderable(joins):
2425
continue
@@ -57,13 +58,15 @@ def optimize_joins(expression: exp.Expr) -> exp.Expr:
5758
return expression
5859

5960

60-
def reorder_joins(expression) -> exp.Expr:
61+
def reorder_joins(expression: exp.Expr) -> exp.Expr:
6162
"""
6263
Reorder joins by topological sort order based on predicate references.
6364
"""
6465
for from_ in expression.find_all(exp.From):
6566
parent = from_.parent
66-
joins = parent.args.get("joins", [])
67+
if parent is None:
68+
raise OptimizeError("FROM clause without parent expression")
69+
joins: list[exp.Join] = parent.args.get("joins", [])
6770

6871
if not _is_reorderable(joins):
6972
continue

sqlglot/optimizer/pushdown_predicates.py

Lines changed: 9 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
from collections.abc import Mapping
34
import typing as t
45

56
from sqlglot import exp
@@ -80,7 +81,13 @@ def pushdown_predicates(expression: E, dialect: DialectType = None) -> E:
8081
return expression
8182

8283

83-
def pushdown(condition, sources, scope_ref_count, dialect, join_index=None):
84+
def pushdown(
85+
condition: exp.Expr | None,
86+
sources,
87+
scope_ref_count,
88+
dialect: DialectType,
89+
join_index: Mapping[str, int] | None = None,
90+
):
8491
if not condition:
8592
return
8693

@@ -99,7 +106,7 @@ def pushdown(condition, sources, scope_ref_count, dialect, join_index=None):
99106
pushdown_dnf(predicates, sources, scope_ref_count, join_index=join_index)
100107

101108

102-
def pushdown_cnf(predicates, sources, scope_ref_count, join_index=None):
109+
def pushdown_cnf(predicates, sources, scope_ref_count, join_index: Mapping[str, int] | None = None):
103110
"""
104111
If the predicates are in CNF like form, we can simply replace each block in the parent.
105112
"""

sqlglot/optimizer/resolver.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -21,14 +21,14 @@ class Resolver:
2121
This is a class so we can lazily load some things and easily share them across functions.
2222
"""
2323

24-
def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
25-
self.scope = scope
26-
self.schema = schema
27-
self.dialect = schema.dialect or Dialect()
24+
def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True) -> None:
25+
self.scope: Scope = scope
26+
self.schema: Schema = schema
27+
self.dialect: Dialect = schema.dialect or Dialect()
2828
self._source_columns: dict[str, Sequence[str]] | None = None
2929
self._unambiguous_columns: Mapping[str, str] | None = None
3030
self._all_columns: set[str] | None = None
31-
self._infer_schema = infer_schema
31+
self._infer_schema: bool = infer_schema
3232
self._get_source_columns_cache: dict[tuple[str, bool], Sequence[str]] = {}
3333

3434
def get_table(self, column: str | exp.Column) -> exp.Identifier | None:

sqlglot/optimizer/unnest_subqueries.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,4 @@
1+
from __future__ import annotations
12
from sqlglot import exp
23
from sqlglot.helper import name_sequence
34
from sqlglot.optimizer.scope import ScopeType, find_in_scope, traverse_scope
@@ -303,11 +304,11 @@ def remove_aggs(node):
303304
)
304305

305306

306-
def _replace(expression, condition):
307+
def _replace(expression: exp.Expr, condition: exp.ExpOrStr) -> exp.Expr:
307308
return expression.replace(exp.condition(condition))
308309

309310

310-
def _other_operand(expression):
311+
def _other_operand(expression: object) -> exp.Expr | None:
311312
if isinstance(expression, exp.In):
312313
return expression.this
313314

0 commit comments

Comments
 (0)