Skip to content

Commit d146dcd

Browse files
authored
feat(lineage): support all-columns mode and on_node callback (#7575)
* feat(lineage): support all-columns mode and on_node callback

  Adds an extension to lineage() so that passing column=None produces a dict[str, Node] mapping every top-level output column name to its lineage Node. The single-column form (str | exp.Column) is unchanged and continues to return a Node. Typing overloads disambiguate the two return shapes for callers.

  A new on_node callback is invoked for every Node created during the walk, after its downstream is populated. Combined with Node.payload — a caller-managed dict — this lets callers thread per-node data through the lineage graph during construction without subclassing Node or rewalking it after the fact.

  Performance:

  * Resolving a column to its select expression scanned selectable.selects on every to_node call. Wide queries with many output columns made this O(N^2). Memoize a per-scope {name: select} map and the selectable.is_star bit on first lookup instead.
  * Compile sqlglot/lineage.py via mypyc by listing it in sqlglotc's _source_files. Together with the memoization above, this shrinks end-to-end all-columns lineage cost on large CTE-heavy queries by roughly 2x compared to the unmemoized pure-Python path.

* test(lineage): cover all-columns mode and on_node invariants

  Adds tests for the column=None form of lineage() and the on_node callback contract:

  * column=None returns a dict keyed by every top-level output column, with each entry shaped like single-column lineage().
  * shared upstream Nodes are deduplicated across output columns by the per-call cache (same source column referenced from multiple selects yields a single shared downstream Node).
  * UNION CTEs fan out correctly — each output column points at one downstream per branch and bottoms out at every branch's base table.
  * passing a pre-built Scope returns the same Node tree as the no-scope path, with no second qualify pass.
  * the on_node callback fires children before parents, so callers can populate Node.payload bottom-up from already-finalized children.
  * on_node fires exactly once per Node, even when a Node is reached from multiple parents.
1 parent 717a50d commit d146dcd

3 files changed

Lines changed: 275 additions & 29 deletions

File tree

sqlglot/lineage.py

Lines changed: 111 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -28,6 +28,9 @@ class Node:
2828
source_name: str = ""
2929
reference_node_name: str = ""
3030

31+
# Caller-injected per-node data, populated via the `on_node` hook on lineage()
32+
payload: dict[str, t.Any] = field(default_factory=dict)
33+
3134
def walk(self) -> Iterator[Node]:
3235
visited: set[int] = set()
3336
queue = [self]
@@ -74,36 +77,50 @@ def to_html(self, dialect: DialectType = None, **opts: Unpack[GraphHTMLArgs]) ->
7477
return GraphHTML(nodes, edges, **opts)
7578

7679

80+
@t.overload
81+
def lineage(column: str | exp.Column, sql: str | exp.Expr, **kwargs: t.Any) -> Node: ...
82+
83+
84+
@t.overload
85+
def lineage(column: None, sql: str | exp.Expr, **kwargs: t.Any) -> dict[str, Node]: ...
86+
87+
7788
def lineage(
78-
column: str | exp.Column,
89+
column: str | exp.Column | None,
7990
sql: str | exp.Expr,
8091
schema: dict | Schema | None = None,
8192
sources: Mapping[str, str | exp.Query] | None = None,
8293
dialect: DialectType = None,
8394
scope: Scope | None = None,
8495
trim_selects: bool = True,
8596
copy: bool = True,
97+
on_node: t.Callable[[Node], None] | None = None,
8698
**kwargs,
87-
) -> Node:
88-
"""Build the lineage graph for a column of a SQL query.
99+
) -> Node | dict[str, Node]:
100+
"""Build the lineage graph for a SQL query.
101+
102+
If `column` is given, returns the lineage Node for that single output column.
103+
If `column` is None, returns a dict mapping every top-level output column name
104+
to its lineage Node (with a shared cache so cross-column work is deduplicated).
89105
90106
Args:
91-
column: The column to build the lineage for.
107+
column: The column to build the lineage for. Pass None to get all output columns.
92108
sql: The SQL string or expression.
93109
schema: The schema of tables.
94110
sources: A mapping of queries which will be used to continue building lineage.
95111
dialect: The dialect of input SQL.
96112
scope: A pre-created scope to use instead.
97113
trim_selects: Whether to clean up selects by trimming to only relevant columns.
98114
copy: Whether to copy the Expr arguments.
115+
on_node: Optional callback invoked for every Node created during the walk,
116+
after the Node's downstream is populated. Useful for injecting
117+
caller-managed data into Node.payload during the walk.
99118
**kwargs: Qualification optimizer kwargs.
100119
101120
Returns:
102-
A lineage node.
121+
A Node when `column` is provided, or a dict[str, Node] when `column` is None.
103122
"""
104-
105123
expression = maybe_parse(sql, copy=copy, dialect=dialect)
106-
column = normalize_identifiers.normalize_identifiers(column, dialect=dialect).name
107124

108125
if sources:
109126
expression = exp.expand(
@@ -123,19 +140,50 @@ def lineage(
123140
schema=schema,
124141
**{"validate_qualify_columns": False, "identify": False, **kwargs}, # type: ignore
125142
)
126-
127143
scope = build_scope(expression)
128144

129145
if not scope:
130146
raise SqlglotError("Cannot build lineage, sql must be SELECT")
131147

132148
selectable = scope.expression
133-
if not isinstance(selectable, exp.Selectable) or not any(
134-
select.alias_or_name == column for select in selectable.selects
135-
):
136-
raise SqlglotError(f"Cannot find column '{column}' in query.")
149+
if not isinstance(selectable, exp.Selectable):
150+
raise SqlglotError("Cannot build lineage, sql must be a query")
151+
152+
cache: dict[tuple, Node] = {}
153+
scope_meta: dict[int, tuple[bool, dict[str, exp.Expr]]] = {}
154+
155+
if column is not None:
156+
column_name = normalize_identifiers.normalize_identifiers(column, dialect=dialect).name
157+
if not any(select.alias_or_name == column_name for select in selectable.selects):
158+
raise SqlglotError(f"Cannot find column '{column_name}' in query.")
159+
160+
return to_node(
161+
column_name,
162+
scope,
163+
dialect,
164+
trim_selects=trim_selects,
165+
_cache=cache,
166+
_scope_meta=scope_meta,
167+
on_node=on_node,
168+
)
137169

138-
return to_node(column, scope, dialect, trim_selects=trim_selects, _cache={})
170+
result: dict[str, Node] = {}
171+
for sel in selectable.selects:
172+
name = sel.alias_or_name
173+
if not name:
174+
continue
175+
176+
result[name] = to_node(
177+
name,
178+
scope,
179+
dialect,
180+
trim_selects=trim_selects,
181+
_cache=cache,
182+
_scope_meta=scope_meta,
183+
on_node=on_node,
184+
)
185+
186+
return result
139187

140188

141189
def to_node(
@@ -148,6 +196,8 @@ def to_node(
148196
reference_node_name: str | None = None,
149197
trim_selects: bool = True,
150198
_cache: dict[tuple, Node] | None = None,
199+
_scope_meta: dict[int, tuple[bool, dict[str, exp.Expr]]] | None = None,
200+
on_node: t.Callable[[Node], None] | None = None,
151201
) -> Node:
152202
cache_key = (column, id(scope), scope_name, source_name, reference_node_name)
153203

@@ -167,10 +217,24 @@ def to_node(
167217
)
168218
select = selectable.selects[column]
169219
else:
170-
select = next(
171-
(select for select in selectable.selects if select.alias_or_name == column),
172-
exp.Star() if selectable.is_star else scope.expression,
173-
)
220+
# Resolving a column to its select scans selectable.selects on every call;
221+
# memoize a per-scope {name: select} map and is_star bit instead.
222+
if _scope_meta is None:
223+
select = next(
224+
(s for s in selectable.selects if s.alias_or_name == column),
225+
exp.Star() if selectable.is_star else scope.expression,
226+
)
227+
else:
228+
scope_id = id(scope)
229+
meta = _scope_meta.get(scope_id)
230+
if meta is None:
231+
select_by_name: dict[str, exp.Expr] = {}
232+
for sel in selectable.selects:
233+
select_by_name.setdefault(sel.alias_or_name, sel)
234+
meta = (selectable.is_star, select_by_name)
235+
_scope_meta[scope_id] = meta
236+
is_star, select_by_name = meta
237+
select = select_by_name.get(column, exp.Star() if is_star else scope.expression)
174238

175239
if isinstance(scope.expression, exp.Subquery):
176240
for inner_scope in scope.subquery_scopes:
@@ -183,6 +247,8 @@ def to_node(
183247
reference_node_name=reference_node_name,
184248
trim_selects=trim_selects,
185249
_cache=_cache,
250+
_scope_meta=_scope_meta,
251+
on_node=on_node,
186252
)
187253
# Skip caching a passed-in upstream returned by an inner SetOp:
188254
# a sibling call at the same key with that node as its upstream
@@ -221,10 +287,14 @@ def to_node(
221287
reference_node_name=reference_node_name,
222288
trim_selects=trim_selects,
223289
_cache=_cache,
290+
_scope_meta=_scope_meta,
291+
on_node=on_node,
224292
)
225293

226294
if _cache is not None and created_setop:
227295
_cache[cache_key] = upstream
296+
if created_setop and on_node:
297+
on_node(upstream)
228298
return upstream
229299

230300
if trim_selects and isinstance(scope.expression, exp.Select):
@@ -266,15 +336,18 @@ def to_node(
266336
upstream=node,
267337
trim_selects=trim_selects,
268338
_cache=_cache,
339+
_scope_meta=_scope_meta,
340+
on_node=on_node,
269341
)
270342

271343
# if the select is a star add all scope sources as downstreams
272344
if isinstance(select, exp.Star):
273345
for src in scope.sources.values():
274346
src_expr = src.expression if isinstance(src, Scope) else src
275-
node.downstream.append(
276-
Node(name=select.sql(comments=False), source=src_expr, expression=src_expr)
277-
)
347+
star_node = Node(name=select.sql(comments=False), source=src_expr, expression=src_expr)
348+
node.downstream.append(star_node)
349+
if on_node:
350+
on_node(star_node)
278351

279352
# Find all columns that went into creating this one to list their lineage nodes.
280353
source_columns = set(find_all_in_scope(select, exp.Column))
@@ -340,6 +413,8 @@ def to_node(
340413
reference_node_name=reference_node_name,
341414
trim_selects=trim_selects,
342415
_cache=_cache,
416+
_scope_meta=_scope_meta,
417+
on_node=on_node,
343418
)
344419
elif pivot and pivot.alias_or_name == c.table:
345420
downstream_columns = []
@@ -369,29 +444,36 @@ def to_node(
369444
reference_node_name=reference_node_name,
370445
trim_selects=trim_selects,
371446
_cache=_cache,
447+
_scope_meta=_scope_meta,
448+
on_node=on_node,
372449
)
373450
else:
374451
col_expr = col_source or exp.Placeholder()
375-
node.downstream.append(
376-
Node(
377-
name=downstream_column.sql(comments=False),
378-
source=col_expr,
379-
expression=col_expr,
380-
)
452+
pivot_leaf = Node(
453+
name=downstream_column.sql(comments=False),
454+
source=col_expr,
455+
expression=col_expr,
381456
)
457+
node.downstream.append(pivot_leaf)
458+
if on_node:
459+
on_node(pivot_leaf)
382460
else:
383461
# The source is not a scope and the column is not in any pivot - we've reached the end
384462
# of the line. At this point, if a source is not found it means this column's lineage
385463
# is unknown. This can happen if the definition of a source used in a query is not
386464
# passed into the `sources` map.
387465
col_expr = col_source or exp.Placeholder()
388-
node.downstream.append(
389-
Node(name=c.sql(comments=False), source=col_expr, expression=col_expr)
390-
)
466+
leaf = Node(name=c.sql(comments=False), source=col_expr, expression=col_expr)
467+
node.downstream.append(leaf)
468+
if on_node:
469+
on_node(leaf)
391470

392471
if _cache is not None:
393472
_cache[cache_key] = node
394473

474+
if on_node:
475+
on_node(node)
476+
395477
return node
396478

397479

sqlglotc/setup.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -46,6 +46,7 @@ def _source_files(src_dir):
4646
"errors.py",
4747
"generator.py",
4848
"helper.py",
49+
"lineage.py",
4950
"parser.py",
5051
"schema.py",
5152
"serde.py",

0 commit comments

Comments
 (0)