Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 35 additions & 42 deletions sqlglot/planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@

class Plan:
def __init__(self, expression: exp.Expr) -> None:
self.expression = expression.copy()
self.root = Step.from_expression(self.expression)
self.expression: exp.Expr = expression.copy()
self.root: Step = Step.from_expression(self.expression)
self._dag: dict[Step, set[Step]] = {}

@property
Expand Down Expand Up @@ -93,10 +93,10 @@ def from_expression(cls, expression: exp.Expr, ctes: dict[str, Step] | None = No
"""
ctes = ctes or {}
expression = expression.unnest()
with_ = expression.args.get("with_")
with_: exp.With | None = expression.args.get("with_")

# CTEs break the mold of scope and introduce themselves to all in the context.
if with_:
if with_ is not None:
ctes = ctes.copy()
for cte in with_.expressions:
step = Step.from_expression(cte.this, ctes)
Expand All @@ -112,23 +112,22 @@ def from_expression(cls, expression: exp.Expr, ctes: dict[str, Step] | None = No
else:
step = Scan()

joins = expression.args.get("joins")
joins: list[exp.Join] | None = expression.args.get("joins")

if joins:
if joins is not None:
join = Join.from_joins(joins, ctes)
join.name = step.name
join.source_name = step.name
join.add_dependency(step)
step = join

projections: list[
exp.Expr
] = [] # final selects in this chain of steps representing a select
operands = {} # intermediate computations of agg funcs eg x + 1 in SUM(x + 1)
aggregations = {}
# final selects in this chain of steps representing a select
projections: list[exp.Expr] = []
# intermediate computations of agg funcs eg x + 1 in SUM(x + 1)
operands: dict[exp.Expr, str] = {}
aggregations: dict[exp.Expr, None] = {}
next_operand_name = name_sequence("_a_")

def extract_agg_operands(expression):
def extract_agg_operands(expression: exp.Expr) -> bool:
agg_funcs = tuple(expression.find_all(exp.AggFunc))
if agg_funcs:
aggregations[expression] = None
Expand All @@ -144,7 +143,7 @@ def extract_agg_operands(expression):

return bool(agg_funcs)

def set_ops_and_aggs(step):
def set_ops_and_aggs(step) -> None:
step.operands = tuple(alias(operand, alias_) for operand, alias_ in operands.items())
step.aggregations = list(aggregations)

Expand All @@ -155,21 +154,21 @@ def set_ops_and_aggs(step):
else:
projections.append(e)

where = expression.args.get("where")
where: exp.Where | None = expression.args.get("where")

if where:
if where is not None:
step.condition = where.this

group = expression.args.get("group")
group: exp.Group | None = expression.args.get("group")

if group or aggregations:
if group is not None or aggregations:
aggregate = Aggregate()
aggregate.source = step.name
aggregate.name = step.name

having = expression.args.get("having")
having: exp.Having | None = expression.args.get("having")

if having:
if having is not None:
if extract_agg_operands(exp.alias_(having.this, "_h", quoted=True)):
aggregate.condition = exp.column("_h", step.name, quoted=True)
else:
Expand Down Expand Up @@ -205,10 +204,10 @@ def set_ops_and_aggs(step):
else:
aggregate = None

order = expression.args.get("order")
order: exp.Order | None = expression.args.get("order")

if order:
if aggregate and isinstance(step, Aggregate):
if order is not None:
if aggregate is not None and isinstance(step, Aggregate):
for i, ordered in enumerate(order.expressions):
if extract_agg_operands(exp.alias_(ordered.this, f"_o_{i}", quoted=True)):
ordered.this.replace(exp.column(f"_o_{i}", step.name, quoted=True))
Expand All @@ -234,9 +233,9 @@ def set_ops_and_aggs(step):
distinct.add_dependency(step)
step = distinct

limit = expression.args.get("limit")
limit: exp.Limit | None = expression.args.get("limit")

if limit:
if limit is not None:
step.limit = int(limit.text("expression"))

return step
Expand Down Expand Up @@ -304,7 +303,7 @@ def _to_s(self, _indent: str) -> list[str]:
class Scan(Step):
@classmethod
def from_expression(cls, expression: exp.Expr, ctes: dict[str, Step] | None = None) -> Step:
table = expression
table: exp.Expr = expression
alias_ = expression.alias_or_name

if isinstance(expression, exp.Subquery):
Expand Down Expand Up @@ -356,7 +355,7 @@ def _to_s(self, indent: str) -> list[str]:
lines = [f"{indent}Source: {self.source_name or self.name}"]
for name, join in self.joins.items():
lines.append(f"{indent}{name}: {join['side'] or 'INNER'}")
join_key = ", ".join(str(key) for key in t.cast(list, join.get("join_key") or []))
join_key = ", ".join(str(key) for key in t.cast(list[str], join.get("join_key") or []))
if join_key:
lines.append(f"{indent}Key: {join_key}")
if join.get("condition"):
Expand Down Expand Up @@ -396,7 +395,7 @@ def _to_s(self, indent: str) -> list[str]:
class Sort(Step):
def __init__(self) -> None:
super().__init__()
self.key = None
self.key: list[exp.Expr] | None = None

def _to_s(self, indent: str) -> list[str]:
lines = [f"{indent}Key:"]
Expand All @@ -408,18 +407,12 @@ def _to_s(self, indent: str) -> list[str]:


class SetOperation(Step):
def __init__(
self,
op: type[exp.Expr],
left: str | None,
right: str | None,
distinct: bool = False,
) -> None:
def __init__(self, op: type[exp.Expr], left: str, right: str, distinct: bool = False) -> None:
super().__init__()
self.op = op
self.left = left
self.right = right
self.distinct = distinct
self.op: type[exp.Expr] = op
self.left: str = left
self.right: str = right
self.distinct: bool = distinct

@classmethod
def from_expression(
Expand All @@ -442,15 +435,15 @@ def from_expression(
step.add_dependency(left)
step.add_dependency(right)

limit = expression.args.get("limit")
limit: exp.Limit | None = expression.args.get("limit")

if limit:
if limit is not None:
step.limit = int(limit.text("expression"))

return step

def _to_s(self, indent: str) -> list[str]:
lines = []
lines: list[str] = []
if self.distinct:
lines.append(f"{indent}Distinct: {self.distinct}")
return lines
Expand Down
12 changes: 6 additions & 6 deletions sqlglot/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from collections.abc import Sequence
from typing_extensions import Unpack

ColumnMapping = t.Union[dict, str, list]
ColumnMapping = t.Union[dict[str, t.Any], str, list[str]]


@trait
Expand Down Expand Up @@ -344,7 +344,7 @@ def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
def find(
self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
) -> t.Any | None:
schema = super().find(
schema: exp.Table | dict[str, object] | None = super().find(
table, raise_on_missing=raise_on_missing, ensure_data_types=ensure_data_types
)
if ensure_data_types and isinstance(schema, dict):
Expand Down Expand Up @@ -417,7 +417,7 @@ def column_names(
) -> list[str]:
normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

schema = self.find(normalized_table)
schema: exp.Table | dict[str, object] | None = self.find(normalized_table)
if schema is None:
return []

Expand All @@ -440,7 +440,7 @@ def get_column_type(
column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
)

table_schema = self.find(normalized_table, raise_on_missing=False)
table_schema: dict[str, object] | None = self.find(normalized_table, raise_on_missing=False)
if table_schema:
column_type = table_schema.get(normalized_column_name)

Expand Down Expand Up @@ -500,7 +500,7 @@ def has_column(
column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
)

table_schema = self.find(normalized_table, raise_on_missing=False)
table_schema: dict[str, object] | None = self.find(normalized_table, raise_on_missing=False)
return normalized_column_name in table_schema if table_schema else False

def _normalize(self, schema: dict[str, object]) -> dict[str, object]:
Expand Down Expand Up @@ -708,7 +708,7 @@ def ensure_schema(
return MappingSchema(schema, **kwargs)


def ensure_column_mapping(mapping: ColumnMapping | None) -> dict:
def ensure_column_mapping(mapping: ColumnMapping | None) -> dict[str, t.Any]:
if mapping is None:
return {}
elif isinstance(mapping, dict):
Expand Down
32 changes: 21 additions & 11 deletions sqlglot/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@
import typing as t

from sqlglot import expressions as exp
from types import ModuleType

if t.TYPE_CHECKING:
from typing_extensions import TypeIs

StackVal = tuple[t.Union[exp.Expr, exp.DType, t.Any], t.Optional[int], t.Optional[str], bool]


INDEX = "i"
Expand All @@ -21,8 +27,8 @@ def dump(expression: exp.Expr) -> list[dict[str, t.Any]]:
Dump an Expr into a JSON serializable List.
"""
i = 0
payloads = []
stack: list[tuple[t.Any, int | None, str | None, bool]] = [(expression, None, None, False)]
payloads: list[dict[str, t.Any]] = []
stack: list[StackVal] = [(expression, None, None, False)]

while stack:
node, index, arg_key, is_array = stack.pop()
Expand All @@ -38,7 +44,7 @@ def dump(expression: exp.Expr) -> list[dict[str, t.Any]]:

payloads.append(payload)

if hasattr(node, "parent"):
if _has_parent(node):
klass = node.__class__.__qualname__

if node.__class__.__module__ != exp.__name__:
Expand All @@ -54,12 +60,12 @@ def dump(expression: exp.Expr) -> list[dict[str, t.Any]]:
payload[META] = node._meta
if node.args:
for k, vs in reversed(node.args.items()):
if type(vs) is list:
if isinstance(vs, list):
for v in reversed(vs):
stack.append((v, i, k, True))
elif vs is not None:
stack.append((vs, i, k, False))
elif type(node) is exp.DType:
elif isinstance(node, exp.DType):
payload[CLASS] = DATA_TYPE
payload[VALUE] = node.value
else:
Expand All @@ -70,6 +76,10 @@ def dump(expression: exp.Expr) -> list[dict[str, t.Any]]:
return payloads


def _has_parent(node: object) -> TypeIs[exp.Expr]:
return hasattr(node, "parent")


def load(
payloads: list[dict[str, t.Any]] | None,
) -> exp.Expr | exp.DType | None:
Expand All @@ -82,16 +92,16 @@ def load(

payload, *tail = payloads
root = _load(payload)
nodes: list[object] = [root]
nodes: list[exp.Expr | exp.DType | t.Any] = [root]
for payload in tail:
if CLASS in payload:
node: object = _load(payload)
node = _load(payload)
else:
node = payload[VALUE]

nodes.append(node)
parent = nodes[payload[INDEX]]
arg_key = payload[ARG_KEY]
parent: exp.Expr = nodes[payload[INDEX]]
arg_key: str = payload[ARG_KEY]

if payload.get(IS_ARRAY):
parent.append(arg_key, node)
Expand All @@ -102,11 +112,11 @@ def load(


def _load(payload: dict[str, t.Any]) -> exp.Expr | exp.DType:
class_name = payload[CLASS]
class_name: str = payload[CLASS]

if class_name == DATA_TYPE:
return exp.DType(payload[VALUE])

module: ModuleType
if "." in class_name:
module_path, class_name = class_name.rsplit(".", maxsplit=1)
module = __import__(module_path, fromlist=[class_name])
Expand Down
11 changes: 6 additions & 5 deletions sqlglot/time.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import typing as t
from __future__ import annotations
import datetime
import typing as t

# The generic time format is based on python time.strftime.
# https://docs.python.org/3/library/time.html#time.strftime
from sqlglot.trie import TrieResult, in_trie, new_trie


def format_time(
string: str, mapping: dict[str, str], trie: t.Optional[dict] = None
) -> t.Optional[str]:
string: str, mapping: dict[str, str], trie: dict[str, t.Any] | None = None
) -> str | None:
"""
Converts a time string given a mapping.

Expand All @@ -31,7 +32,7 @@ def format_time(
size = len(string)
trie = trie or new_trie(mapping)
current = trie
chunks = []
chunks: list[str] = []
sym = None

while end <= size:
Expand Down Expand Up @@ -61,7 +62,7 @@ def format_time(
return "".join(mapping.get(chars, chars) for chars in chunks)


TIMEZONES = {
TIMEZONES: set[str] = {
tz.lower()
for tz in (
"Africa/Abidjan",
Expand Down
Loading
Loading