diff --git a/sqlglot/planner.py b/sqlglot/planner.py index 8077afddec..17cda2fc10 100644 --- a/sqlglot/planner.py +++ b/sqlglot/planner.py @@ -11,8 +11,8 @@ class Plan: def __init__(self, expression: exp.Expr) -> None: - self.expression = expression.copy() - self.root = Step.from_expression(self.expression) + self.expression: exp.Expr = expression.copy() + self.root: Step = Step.from_expression(self.expression) self._dag: dict[Step, set[Step]] = {} @property @@ -93,10 +93,10 @@ def from_expression(cls, expression: exp.Expr, ctes: dict[str, Step] | None = No """ ctes = ctes or {} expression = expression.unnest() - with_ = expression.args.get("with_") + with_: exp.With | None = expression.args.get("with_") # CTEs break the mold of scope and introduce themselves to all in the context. - if with_: + if with_ is not None: ctes = ctes.copy() for cte in with_.expressions: step = Step.from_expression(cte.this, ctes) @@ -112,23 +112,22 @@ def from_expression(cls, expression: exp.Expr, ctes: dict[str, Step] | None = No else: step = Scan() - joins = expression.args.get("joins") + joins: list[exp.Join] | None = expression.args.get("joins") - if joins: + if joins is not None: join = Join.from_joins(joins, ctes) join.name = step.name join.source_name = step.name join.add_dependency(step) step = join - - projections: list[ - exp.Expr - ] = [] # final selects in this chain of steps representing a select - operands = {} # intermediate computations of agg funcs eg x + 1 in SUM(x + 1) - aggregations = {} + # final selects in this chain of steps representing a select + projections: list[exp.Expr] = [] + # intermediate computations of agg funcs eg x + 1 in SUM(x + 1) + operands: dict[exp.Expr, str] = {} + aggregations: dict[exp.Expr, None] = {} next_operand_name = name_sequence("_a_") - def extract_agg_operands(expression): + def extract_agg_operands(expression: exp.Expr) -> bool: agg_funcs = tuple(expression.find_all(exp.AggFunc)) if agg_funcs: aggregations[expression] = None @@ -144,7 +143,7 @@ def extract_agg_operands(expression): return bool(agg_funcs) - def set_ops_and_aggs(step): + def set_ops_and_aggs(step) -> None: step.operands = tuple(alias(operand, alias_) for operand, alias_ in operands.items()) step.aggregations = list(aggregations) @@ -155,21 +154,21 @@ def set_ops_and_aggs(step): else: projections.append(e) - where = expression.args.get("where") + where: exp.Where | None = expression.args.get("where") - if where: + if where is not None: step.condition = where.this - group = expression.args.get("group") + group: exp.Group | None = expression.args.get("group") - if group or aggregations: + if group is not None or aggregations: aggregate = Aggregate() aggregate.source = step.name aggregate.name = step.name - having = expression.args.get("having") + having: exp.Having | None = expression.args.get("having") - if having: + if having is not None: if extract_agg_operands(exp.alias_(having.this, "_h", quoted=True)): aggregate.condition = exp.column("_h", step.name, quoted=True) else: @@ -205,10 +204,10 @@ def set_ops_and_aggs(step): else: aggregate = None - order = expression.args.get("order") + order: exp.Order | None = expression.args.get("order") - if order: - if aggregate and isinstance(step, Aggregate): + if order is not None: + if aggregate is not None and isinstance(step, Aggregate): for i, ordered in enumerate(order.expressions): if extract_agg_operands(exp.alias_(ordered.this, f"_o_{i}", quoted=True)): ordered.this.replace(exp.column(f"_o_{i}", step.name, quoted=True)) @@ -234,9 +233,9 @@ def set_ops_and_aggs(step): distinct.add_dependency(step) step = distinct - limit = expression.args.get("limit") + limit: exp.Limit | None = expression.args.get("limit") - if limit: + if limit is not None: step.limit = int(limit.text("expression")) return step @@ -304,7 +303,7 @@ def _to_s(self, _indent: str) -> list[str]: class Scan(Step): @classmethod def from_expression(cls, expression: exp.Expr, ctes: dict[str, Step] | None = None) -> Step: - table = expression + table: exp.Expr = expression alias_ = expression.alias_or_name if isinstance(expression, exp.Subquery): @@ -356,7 +355,7 @@ def _to_s(self, indent: str) -> list[str]: lines = [f"{indent}Source: {self.source_name or self.name}"] for name, join in self.joins.items(): lines.append(f"{indent}{name}: {join['side'] or 'INNER'}") - join_key = ", ".join(str(key) for key in t.cast(list, join.get("join_key") or [])) + join_key = ", ".join(str(key) for key in t.cast(list[str], join.get("join_key") or [])) if join_key: lines.append(f"{indent}Key: {join_key}") if join.get("condition"): @@ -396,7 +395,7 @@ def _to_s(self, indent: str) -> list[str]: class Sort(Step): def __init__(self) -> None: super().__init__() - self.key = None + self.key: list[exp.Expr] | None = None def _to_s(self, indent: str) -> list[str]: lines = [f"{indent}Key:"] @@ -408,18 +407,12 @@ def _to_s(self, indent: str) -> list[str]: class SetOperation(Step): - def __init__( - self, - op: type[exp.Expr], - left: str | None, - right: str | None, - distinct: bool = False, - ) -> None: + def __init__(self, op: type[exp.Expr], left: str, right: str, distinct: bool = False) -> None: super().__init__() - self.op = op - self.left = left - self.right = right - self.distinct = distinct + self.op: type[exp.Expr] = op + self.left: str = left + self.right: str = right + self.distinct: bool = distinct @classmethod def from_expression( @@ -442,15 +435,15 @@ def from_expression( step.add_dependency(left) step.add_dependency(right) - limit = expression.args.get("limit") + limit: exp.Limit | None = expression.args.get("limit") - if limit: + if limit is not None: step.limit = int(limit.text("expression")) return step def _to_s(self, indent: str) -> list[str]: - lines = [] + lines: list[str] = [] if self.distinct: lines.append(f"{indent}Distinct: {self.distinct}") return lines diff --git a/sqlglot/schema.py b/sqlglot/schema.py index 6d8c441cef..0efdd1686b 100644 --- a/sqlglot/schema.py +++ b/sqlglot/schema.py @@ -18,7 +18,7 @@ from collections.abc import Sequence from typing_extensions import Unpack - ColumnMapping = t.Union[dict, str, list] + ColumnMapping = t.Union[dict[str, t.Any], str, list[str]] @trait @@ -344,7 +344,7 @@ def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema: def find( self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False ) -> t.Any | None: - schema = super().find( + schema: exp.Table | dict[str, object] | None = super().find( table, raise_on_missing=raise_on_missing, ensure_data_types=ensure_data_types ) if ensure_data_types and isinstance(schema, dict): @@ -417,7 +417,7 @@ def column_names( ) -> list[str]: normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize) - schema = self.find(normalized_table) + schema: exp.Table | dict[str, object] | None = self.find(normalized_table) if schema is None: return [] @@ -440,7 +440,7 @@ def get_column_type( column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize ) - table_schema = self.find(normalized_table, raise_on_missing=False) + table_schema: dict[str, object] | None = self.find(normalized_table, raise_on_missing=False) if table_schema: column_type = table_schema.get(normalized_column_name) @@ -500,7 +500,7 @@ def has_column( column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize ) - table_schema = self.find(normalized_table, raise_on_missing=False) + table_schema: dict[str, object] | None = self.find(normalized_table, raise_on_missing=False) return normalized_column_name in table_schema if table_schema else False def _normalize(self, schema: dict[str, object]) -> dict[str, object]: @@ -708,7 +708,7 @@ def ensure_schema( return MappingSchema(schema, **kwargs) -def ensure_column_mapping(mapping: ColumnMapping | None) -> dict: +def ensure_column_mapping(mapping: ColumnMapping | None) -> dict[str, t.Any]: if mapping is None: return {} elif isinstance(mapping, dict): diff --git a/sqlglot/serde.py b/sqlglot/serde.py index 44598ef12b..bc76de7323 100644 --- a/sqlglot/serde.py +++ b/sqlglot/serde.py @@ -3,6 +3,12 @@ import typing as t from sqlglot import expressions as exp +from types import ModuleType + +if t.TYPE_CHECKING: + from typing_extensions import TypeIs + +StackVal = tuple[t.Union[exp.Expr, exp.DType, t.Any], t.Optional[int], t.Optional[str], bool] INDEX = "i" @@ -21,8 +27,8 @@ def dump(expression: exp.Expr) -> list[dict[str, t.Any]]: Dump an Expr into a JSON serializable List. """ i = 0 - payloads = [] - stack: list[tuple[t.Any, int | None, str | None, bool]] = [(expression, None, None, False)] + payloads: list[dict[str, t.Any]] = [] + stack: list[StackVal] = [(expression, None, None, False)] while stack: node, index, arg_key, is_array = stack.pop() @@ -38,7 +44,7 @@ def dump(expression: exp.Expr) -> list[dict[str, t.Any]]: payloads.append(payload) - if hasattr(node, "parent"): + if _has_parent(node): klass = node.__class__.__qualname__ if node.__class__.__module__ != exp.__name__: @@ -54,12 +60,12 @@ def dump(expression: exp.Expr) -> list[dict[str, t.Any]]: payload[META] = node._meta if node.args: for k, vs in reversed(node.args.items()): - if type(vs) is list: + if isinstance(vs, list): for v in reversed(vs): stack.append((v, i, k, True)) elif vs is not None: stack.append((vs, i, k, False)) - elif type(node) is exp.DType: + elif isinstance(node, exp.DType): payload[CLASS] = DATA_TYPE payload[VALUE] = node.value else: @@ -70,6 +76,10 @@ def dump(expression: exp.Expr) -> list[dict[str, t.Any]]: return payloads +def _has_parent(node: object) -> TypeIs[exp.Expr]: + return hasattr(node, "parent") + + def load( payloads: list[dict[str, t.Any]] | None, ) -> exp.Expr | exp.DType | None: @@ -82,16 +92,16 @@ def load( payload, *tail = payloads root = _load(payload) - nodes: list[object] = [root] + nodes: list[exp.Expr | exp.DType | t.Any] = [root] for payload in tail: if CLASS in payload: - node: object = _load(payload) + node = _load(payload) else: node = payload[VALUE] nodes.append(node) - parent = nodes[payload[INDEX]] - arg_key = payload[ARG_KEY] + parent: exp.Expr = nodes[payload[INDEX]] + arg_key: str = payload[ARG_KEY] if payload.get(IS_ARRAY): parent.append(arg_key, node) @@ -102,11 +112,11 @@ def load( def _load(payload: dict[str, t.Any]) -> exp.Expr | exp.DType: - class_name = payload[CLASS] + class_name: str = payload[CLASS] if class_name == DATA_TYPE: return exp.DType(payload[VALUE]) - + module: ModuleType if "." in class_name: module_path, class_name = class_name.rsplit(".", maxsplit=1) module = __import__(module_path, fromlist=[class_name]) diff --git a/sqlglot/time.py b/sqlglot/time.py index e6de12a773..2ec44f96df 100644 --- a/sqlglot/time.py +++ b/sqlglot/time.py @@ -1,5 +1,6 @@ -import typing as t +from __future__ import annotations import datetime +import typing as t # The generic time format is based on python time.strftime. # https://docs.python.org/3/library/time.html#time.strftime @@ -7,8 +8,8 @@ def format_time( - string: str, mapping: dict[str, str], trie: t.Optional[dict] = None -) -> t.Optional[str]: + string: str, mapping: dict[str, str], trie: dict[str, t.Any] | None = None +) -> str | None: """ Converts a time string given a mapping. @@ -31,7 +32,7 @@ def format_time( size = len(string) trie = trie or new_trie(mapping) current = trie - chunks = [] + chunks: list[str] = [] sym = None while end <= size: @@ -61,7 +62,7 @@ def format_time( return "".join(mapping.get(chars, chars) for chars in chunks) -TIMEZONES = { +TIMEZONES: set[str] = { tz.lower() for tz in ( "Africa/Abidjan", diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py index a9266d3090..eaaf9b4528 100644 --- a/sqlglot/transforms.py +++ b/sqlglot/transforms.py @@ -28,7 +28,7 @@ def preprocess( Function that can be used as a generator transform. """ - def _to_sql(self, expression: exp.Expr) -> str: + def _to_sql(self: Generator, expression: exp.Expr) -> str: expression_type = type(expression) try: @@ -41,7 +41,9 @@ def _to_sql(self, expression: exp.Expr) -> str: if generator: return generator(self, expression) - _sql_handler = getattr(self, expression.key + "_sql", None) + _sql_handler: t.Callable[[exp.Expr], str] | None = getattr( + self, expression.key + "_sql", None + ) if _sql_handler: return _sql_handler(expression) @@ -68,7 +70,7 @@ def _to_sql(self, expression: exp.Expr) -> str: def unnest_generate_date_array_using_recursive_cte(expression: exp.Expr) -> exp.Expr: if isinstance(expression, exp.Select): count = 0 - recursive_ctes = [] + recursive_ctes: list[exp.Expr] = [] for unnest in expression.find_all(exp.Unnest): if ( @@ -79,15 +81,17 @@ def unnest_generate_date_array_using_recursive_cte(expression: exp.Expr) -> exp. continue generate_date_array = unnest.expressions[0] - start = generate_date_array.args.get("start") - end = generate_date_array.args.get("end") - step = generate_date_array.args.get("step") + start: exp.Expr | None = generate_date_array.args.get("start") + end: exp.Expr | None = generate_date_array.args.get("end") + step: exp.Expr | None = generate_date_array.args.get("step") if not start or not end or not isinstance(step, exp.Interval): continue - alias = unnest.args.get("alias") - column_name = alias.columns[0] if isinstance(alias, exp.TableAlias) else "date_value" + alias: exp.TableAlias | None = unnest.args.get("alias") + column_name: str = ( + alias.columns[0] if isinstance(alias, exp.TableAlias) else "date_value" + ) start = exp.cast(start, "date") date_add = exp.func( @@ -114,7 +118,7 @@ def unnest_generate_date_array_using_recursive_cte(expression: exp.Expr) -> exp. count += 1 if recursive_ctes: - with_expression = expression.args.get("with_") or exp.With() + with_expression: exp.With = expression.args.get("with_") or exp.With() with_expression.set("recursive", True) with_expression.set("expressions", [*recursive_ctes, *with_expression.expressions]) expression.set("with_", with_expression) @@ -157,7 +161,7 @@ def eliminate_distinct_on(expression: exp.Expr) -> exp.Expr: distinct_cols = expression.args["distinct"].pop().args["on"].expressions window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols) - order = expression.args.get("order") + order: exp.Order | None = expression.args.get("order") if order: window.set("order", order.pop()) else: @@ -166,7 +170,7 @@ def eliminate_distinct_on(expression: exp.Expr) -> exp.Expr: expression.select(exp.alias_(window, row_number_window_alias), copy=False) # We add aliases to the projections so that we can safely reference them in the outer query - new_selects = [] + new_selects: list[exp.Expr] = [] taken_names = {row_number_window_alias} for select in expression.selects[:-1]: if select.is_star: @@ -175,7 +179,9 @@ def eliminate_distinct_on(expression: exp.Expr) -> exp.Expr: if not isinstance(select, exp.Alias): alias = find_new_name(taken_names, select.output_name or "_col") - quoted = select.this.args.get("quoted") if isinstance(select, exp.Column) else None + quoted: bool | None = ( + select.this.args.get("quoted") if isinstance(select, exp.Column) else None + ) select = select.replace(exp.alias_(select, alias, quoted=quoted)) taken_names.add(select.output_name) @@ -219,16 +225,16 @@ def _select_alias_or_name(select: exp.Expr) -> str | exp.Column: return exp.column(alias_or_name, quoted=identifier.args.get("quoted")) return alias_or_name - outer_selects = exp.select(*list(map(_select_alias_or_name, expression.selects))) - qualify_filters = expression.args["qualify"].pop().this - expression_by_alias = { + outer_selects = exp.select(*map(_select_alias_or_name, expression.selects)) + qualify_filters: exp.Expr = expression.args["qualify"].pop().this + expression_by_alias: dict[str, exp.Expr] = { select.alias: select.this for select in expression.selects if isinstance(select, exp.Alias) } - select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column) - for select_candidate in list(qualify_filters.find_all(select_candidates)): + select_candidates = (exp.Window,) if expression.is_star else (exp.Window, exp.Column) + for select_candidate in qualify_filters.find_all(*select_candidates): if isinstance(select_candidate, exp.Window): if expression_by_alias: for column in select_candidate.find_all(exp.Column): @@ -314,14 +320,14 @@ def _udtf_type(u: exp.Unnest, has_multi_expr: bool) -> type[exp.Func]: from_ = expression.args.get("from_") if from_ and isinstance(from_.this, exp.Unnest): - unnest = from_.this - alias = unnest.args.get("alias") - exprs = unnest.expressions + unnest: exp.Unnest = from_.this + alias: exp.TableAlias | None = unnest.args.get("alias") + exprs: list[exp.Expr] = unnest.expressions has_multi_expr = len(exprs) > 1 this, *_ = _unnest_zip_exprs(unnest, exprs, has_multi_expr) - columns = alias.columns if alias else [] - offset = unnest.args.get("offset") + columns: list[exp.Identifier] = alias.columns if alias else [] + offset: exp.Expr | None = unnest.args.get("offset") if offset: columns.insert( 0, offset if isinstance(offset, exp.Identifier) else exp.to_identifier("pos") @@ -334,7 +340,7 @@ def _udtf_type(u: exp.Unnest, has_multi_expr: bool) -> type[exp.Func]: ) ) - joins = expression.args.get("joins") or [] + joins: list[exp.Join] = expression.args.get("joins") or [] for join in list(joins): join_expr = join.this @@ -347,6 +353,10 @@ def _udtf_type(u: exp.Unnest, has_multi_expr: bool) -> type[exp.Func]: alias = join_expr.args.get("alias") else: alias = unnest.args.get("alias") + if alias is None: + raise UnsupportedError( + "CROSS JOIN UNNEST to LATERAL VIEW EXPLODE transformation requires an alias" + ) exprs = unnest.expressions # The number of unnest.expressions will be changed by _unnest_zip_exprs, we need to record it here has_multi_expr = len(exprs) > 1 @@ -354,7 +364,7 @@ def _udtf_type(u: exp.Unnest, has_multi_expr: bool) -> type[exp.Func]: joins.remove(join) - alias_cols = alias.columns if alias else [] + alias_cols: list[exp.Identifier] = alias.columns # # Handle UNNEST to LATERAL VIEW EXPLODE: Exception is raised when there are 0 or > 2 aliases # Spark LATERAL VIEW EXPLODE requires single alias for array/struct and two for Map type column unlike unnest in trino/presto which can take an arbitrary amount. @@ -378,10 +388,7 @@ def _udtf_type(u: exp.Unnest, has_multi_expr: bool) -> type[exp.Func]: exp.Lateral( this=_udtf_type(unnest, has_multi_expr)(this=e), view=True, - alias=exp.TableAlias( - this=alias.this, # type: ignore - columns=alias_cols, - ), + alias=exp.TableAlias(this=alias.this, columns=alias_cols), ), ) @@ -604,8 +611,8 @@ def epoch_cast_to_ts(expression: exp.Expr) -> exp.Expr: def eliminate_semi_and_anti_joins(expression: exp.Expr) -> exp.Expr: """Convert SEMI and ANTI joins into equivalent forms that use EXIST instead.""" if isinstance(expression, exp.Select): - for join in list(expression.args.get("joins") or []): - on = join.args.get("on") + for join in list[exp.Join](expression.args.get("joins") or []): + on: exp.Expr | None = join.args.get("on") if on and join.kind in ("SEMI", "ANTI"): subquery = exp.select("1").from_(join.this).where(on) exists: exp.Exists | exp.Not = exp.Exists(this=subquery) @@ -625,9 +632,9 @@ def eliminate_full_outer_join(expression: exp.Expr) -> exp.Expr: for queries that have a single FULL OUTER join. """ if isinstance(expression, exp.Select): - full_outer_joins = [ + full_outer_joins: list[tuple[int, exp.Join]] = [ (index, join) - for index, join in enumerate(expression.args.get("joins") or []) + for index, join in enumerate[exp.Join](expression.args.get("joins") or []) if join.side == "FULL" ] @@ -640,7 +647,7 @@ def eliminate_full_outer_join(expression: exp.Expr) -> exp.Expr: join_conditions = full_outer_join.args.get("on") or exp.and_( *[ exp.column(col, tables[0]).eq(exp.column(col, tables[1])) - for col in full_outer_join.args.get("using") + for col in t.cast(list[exp.Identifier], full_outer_join.args.get("using")) ] ) @@ -670,7 +677,7 @@ def move_ctes_to_top_level(expression: E) -> E: TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly). """ - top_level_with = expression.args.get("with_") + top_level_with: exp.With | None = expression.args.get("with_") for inner_with in expression.find_all(exp.With): if inner_with.parent is expression: continue @@ -741,10 +748,10 @@ def ctas_with_tmp_tables_to_create_tmp_view( tmp_storage_provider: t.Callable[[exp.Expr], exp.Expr] = lambda e: e, ) -> exp.Expr: assert isinstance(expression, exp.Create) - properties = expression.args.get("properties") + properties: exp.Properties | None = expression.args.get("properties") temporary = any( isinstance(prop, exp.TemporaryProperty) - for prop in (properties.expressions if properties else []) + for prop in (properties.expressions if properties is not None else []) ) # CTAS with temp tables map to CREATE TEMPORARY VIEW @@ -773,9 +780,11 @@ def move_schema_columns_to_partitioned_by(expression: exp.Expr) -> exp.Expr: if has_schema and is_partitionable: prop = expression.find(exp.PartitionedByProperty) if prop and prop.this and not isinstance(prop.this, exp.Schema): - schema = expression.this - columns = {v.name.upper() for v in prop.this.expressions} - partitions = [col for col in schema.expressions if col.name.upper() in columns] + schema: exp.Schema = expression.this + columns: set[str] = {v.name.upper() for v in prop.this.expressions} + partitions: list[exp.ColumnDef] = [ + col for col in schema.expressions if col.name.upper() in columns + ] schema.set("expressions", [e for e in schema.expressions if e not in partitions]) prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions))) expression.set("this", schema) @@ -800,7 +809,7 @@ def move_partitioned_by_to_schema_columns(expression: exp.Expr) -> exp.Expr: prop_this = exp.Tuple( expressions=[exp.to_identifier(e.this) for e in prop.this.expressions] ) - schema = expression.this + schema: exp.Schema = expression.this for e in prop.this.expressions: schema.append("expressions", e) prop.set("this", prop_this) @@ -871,8 +880,8 @@ def eliminate_join_marks(expression: exp.Expr) -> exp.Expr: for scope in reversed(traverse_scope(expression)): query = scope.expression - where = query.args.get("where") - joins = query.args.get("joins", []) + where: exp.Expr | None = query.args.get("where") + joins: list[exp.Join] = query.args.get("joins", []) if not where or not any(c.args.get("join_mark") for c in where.find_all(exp.Column)): continue @@ -884,7 +893,9 @@ def eliminate_join_marks(expression: exp.Expr) -> exp.Expr: where = normalize(where.this) assert normalized(where), "Cannot normalize JOIN predicates" - joins_ons = defaultdict(list) # dict of {name: list of join AND conditions} + joins_ons: defaultdict[str, list[exp.Expr]] = defaultdict( + list + ) # dict of {name: list of join AND conditions} for cond in [where] if not isinstance(where, exp.And) else where.flatten(): join_cols = [col for col in cond.find_all(exp.Column) if col.args.get("join_mark")] @@ -902,7 +913,7 @@ def eliminate_join_marks(expression: exp.Expr) -> exp.Expr: joins_ons[left_join_table.pop()].append(cond) old_joins = {join.alias_or_name: join for join in joins} - new_joins = {} + new_joins: dict[str, exp.Join] = {} query_from = query.args["from_"] for table, predicates in joins_ons.items(): @@ -924,12 +935,12 @@ def eliminate_join_marks(expression: exp.Expr) -> exp.Expr: parent.pop() if query_from.alias_or_name in new_joins: - only_old_joins = old_joins.keys() - new_joins.keys() + only_old_joins: set[str] = old_joins.keys() - new_joins.keys() assert len(only_old_joins) >= 1, ( "Cannot determine which table to use in the new FROM clause" ) - new_from_name = list(only_old_joins)[0] + new_from_name = list[str](only_old_joins)[0] query.set("from_", exp.From(this=old_joins[new_from_name].this)) if new_joins: @@ -956,7 +967,7 @@ def any_to_exists(expression: exp.Expr) -> exp.Expr: """ if isinstance(expression, exp.Select): for any_expr in expression.find_all(exp.Any): - this = any_expr.this + this: exp.Expr = any_expr.this if isinstance(this, exp.Query) or isinstance(any_expr.parent, (exp.Like, exp.ILike)): continue @@ -972,10 +983,10 @@ def any_to_exists(expression: exp.Expr) -> exp.Expr: def eliminate_window_clause(expression: exp.Expr) -> exp.Expr: """Eliminates the `WINDOW` query clause by inling each named window.""" - if isinstance(expression, exp.Select) and expression.args.get("windows"): + windows: list[exp.Expr] | None = expression.args.get("windows") + if isinstance(expression, exp.Select) and windows is not None: from sqlglot.optimizer.scope import find_all_in_scope - windows = expression.args["windows"] expression.set("windows", None) window_expression: dict[str, exp.Expr] = {} @@ -987,8 +998,8 @@ def _inline_inherited_window(window: exp.Expr) -> None: window.set("alias", None) for key in ("partition_by", "order", "spec"): - arg = inherited_window.args.get(key) - if arg: + arg: exp.Expr | None = inherited_window.args.get(key) + if arg is not None: window.set(key, arg.copy()) for window in windows: @@ -1029,7 +1040,7 @@ def inherit_struct_field_names(expression: exp.Expr) -> exp.Expr: and isinstance(first_item := seq_get(expression.expressions, 0), exp.Struct) and all(isinstance(fld, exp.PropertyEQ) for fld in first_item.expressions) ): - field_names = [fld.this for fld in first_item.expressions] + field_names: list[exp.Identifier] = [fld.this for fld in first_item.expressions] # Apply field names to subsequent structs that don't have them for struct in expression.expressions[1:]: @@ -1037,7 +1048,7 @@ def inherit_struct_field_names(expression: exp.Expr) -> exp.Expr: continue # Convert unnamed expressions to PropertyEQ with inherited names - new_expressions = [] + new_expressions: list[exp.PropertyEQ] = [] for i, expr in enumerate(struct.expressions): if not isinstance(expr, exp.PropertyEQ): # Create PropertyEQ: field_name := value, preserving the type from the inner expression