Skip to content

Commit 2659b2a

Browse files
authored
Fix circular imports in generated multi-module packages (#2613)
* Fix circular imports in generated multi-module packages * Add unittest * Fix circular imports and enhance type checking in generated modules * Refactor * Fix circular imports and improve type checking in relative import handling * Fix circular imports and enhance type checking for relative imports with current module handling * Refactor import handling to use centralized dump_all method for __all__ declarations * Fix circular imports and streamline export handling in import generation * Add tests and models for all exports scope with local models * Fix circular imports handling and refine export imports logic in base.py * Fix circular imports and enhance type checking for internal module references in base.py and boo.py * Refactor imports in boo.py to resolve circular dependencies and enhance type checking
1 parent 9c1ba5d commit 2659b2a

129 files changed

Lines changed: 3049 additions & 1320 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

src/datamodel_code_generator/imports.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ def __init__(self, use_exact: bool = False) -> None: # noqa: FBT001, FBT002
5252
self.counter: dict[tuple[str | None, str], int] = defaultdict(int)
5353
self.reference_paths: dict[str, Import] = {}
5454
self.use_exact: bool = use_exact
55+
self._exports: set[str] | None = None
5556

5657
def _set_alias(self, from_: str | None, imports: set[str]) -> list[str]:
5758
"""Apply aliases to imports and return sorted list."""
@@ -127,6 +128,31 @@ def extract_future(self) -> Imports:
127128
future.alias[future_key] = self.alias.pop(future_key)
128129
return future
129130

131+
def add_export(self, name: str) -> None:
    """Register a locally defined name for export without importing it.

    The export set is created on first use; until then ``self._exports``
    stays ``None``.

    Args:
        name: Identifier to include in the generated ``__all__`` list.
    """
    exports = self._exports
    if exports is None:
        # First export registered on this instance: allocate the set lazily.
        exports = self._exports = set()
    exports.add(name)
136+
137+
def dump_all(self, *, multiline: bool = False) -> str:
    """Generate an ``__all__`` declaration from imported names and added exports.

    Imported names are listed under their alias when ``self.alias`` holds a
    non-empty alias for them, otherwise under their own name. Names added
    through ``add_export`` are merged in. The result is de-duplicated and
    sorted.

    Args:
        multiline: If True, format with one name per line.

    Returns:
        Formatted ``__all__ = [...]`` string. When there is nothing to
        export, ``__all__ = []`` is returned in both modes.
    """
    name_set: set[str] = set(self._exports or ())
    for from_, imports in self.items():
        aliases = self.alias.get(from_, {})
        name_set.update(aliases.get(import_) or import_ for import_ in imports)

    if not name_set:
        # Without this guard the multiline template would render a lone
        # comma between the brackets, which is not valid Python.
        return "__all__ = []"

    quoted = [f'"{name}"' for name in sorted(name_set)]
    if multiline:
        # NOTE(review): the indent after each newline is reproduced exactly as
        # it appears in the reviewed text; confirm the intended width against
        # the repository, since extraction may have collapsed whitespace.
        items = ",\n ".join(quoted)
        return f"__all__ = [\n {items},\n]"
    items = ", ".join(quoted)
    return f"__all__ = [{items}]"
155+
130156

131157
IMPORT_ANNOTATED = Import.from_full_path("typing.Annotated")
132158
IMPORT_ANY = Import.from_full_path("typing.Any")
Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
"""Strongly Connected Components detection using Tarjan's algorithm.
2+
3+
Provides SCC detection for module dependency graphs to identify
4+
circular import patterns in generated code.
5+
"""
6+
7+
from __future__ import annotations
8+
9+
from enum import IntEnum
10+
from typing import NamedTuple, TypeAlias
11+
12+
ModulePath: TypeAlias = tuple[str, ...]
13+
ModuleGraph: TypeAlias = dict[ModulePath, set[ModulePath]]
14+
SCC: TypeAlias = set[ModulePath]
15+
SCCList: TypeAlias = list[SCC]
16+
17+
_EMPTY_SET: frozenset[ModulePath] = frozenset()
18+
19+
20+
class _Phase(IntEnum):
21+
"""DFS traversal phase for iterative Tarjan's algorithm."""
22+
23+
VISIT = 0
24+
POSTVISIT = 1
25+
26+
27+
class _Frame(NamedTuple):
28+
"""Call stack frame for iterative DFS."""
29+
30+
node: ModulePath
31+
neighbor_idx: int
32+
phase: _Phase
33+
34+
35+
class _TarjanState:
36+
"""Mutable state for Tarjan's SCC algorithm."""
37+
38+
__slots__ = ("graph", "index", "index_counter", "lowlinks", "on_stack", "result", "sorted_cache", "stack")
39+
40+
def __init__(self, graph: ModuleGraph) -> None:
41+
self.graph = graph
42+
self.index_counter: int = 0
43+
self.stack: list[ModulePath] = []
44+
self.lowlinks: dict[ModulePath, int] = {}
45+
self.index: dict[ModulePath, int] = {}
46+
self.on_stack: set[ModulePath] = set()
47+
self.result: SCCList = []
48+
self.sorted_cache: dict[ModulePath, list[ModulePath]] = {}
49+
50+
def get_sorted_neighbors(self, node: ModulePath) -> list[ModulePath]:
51+
"""Get sorted neighbors with lazy memoization."""
52+
cached: list[ModulePath] | None = self.sorted_cache.get(node)
53+
if cached is None:
54+
cached = sorted(self.graph.get(node, _EMPTY_SET))
55+
self.sorted_cache[node] = cached
56+
return cached
57+
58+
def extract_scc(self, root: ModulePath) -> None:
59+
"""Pop nodes from stack to form an SCC rooted at the given node."""
60+
scc: SCC = set()
61+
while True:
62+
w: ModulePath = self.stack.pop()
63+
self.on_stack.remove(w)
64+
scc.add(w)
65+
if w == root: # pragma: no branch
66+
break
67+
self.result.append(scc)
68+
69+
def initialize_node(self, node: ModulePath) -> None:
70+
"""Initialize a node for DFS traversal."""
71+
self.index[node] = self.lowlinks[node] = self.index_counter
72+
self.index_counter += 1
73+
self.stack.append(node)
74+
self.on_stack.add(node)
75+
76+
77+
def _strongconnect(state: _TarjanState, start: ModulePath) -> None:
    """Run Tarjan's strongconnect from ``start`` using an explicit stack.

    Every SCC completed during the traversal is appended to ``state.result``
    via ``state.extract_scc``.

    Args:
        state: Shared algorithm state; mutated in place.
        start: Node to begin the depth-first traversal from.
    """
    state.initialize_node(start)
    pending: list[_Frame] = [_Frame(start, 0, _Phase.VISIT)]

    while pending:
        current, cursor, phase = pending.pop()
        successors: list[ModulePath] = state.get_sorted_neighbors(current)

        if phase == _Phase.POSTVISIT:
            # Returning from the child at ``cursor``: fold its lowlink into ours,
            # then continue with the next neighbor.
            finished_child: ModulePath = successors[cursor]
            state.lowlinks[current] = min(state.lowlinks[current], state.lowlinks[finished_child])
            cursor += 1

        for position in range(cursor, len(successors)):
            successor: ModulePath = successors[position]

            if successor not in state.index:
                # Remember where to resume, then descend into the new node.
                pending.append(_Frame(current, position, _Phase.POSTVISIT))
                state.initialize_node(successor)
                pending.append(_Frame(successor, 0, _Phase.VISIT))
                break
            if successor in state.on_stack:
                # Edge back into the current DFS stack.
                state.lowlinks[current] = min(state.lowlinks[current], state.index[successor])
        else:
            # No descent happened, so every neighbor is handled: a node whose
            # lowlink equals its own index is the root of an SCC.
            if state.lowlinks[current] == state.index[current]:
                state.extract_scc(current)
113+
114+
115+
def strongly_connected_components(graph: ModuleGraph) -> SCCList:
    """Compute every strongly connected component of ``graph`` with Tarjan's algorithm.

    The traversal is iterative, so deep graphs do not hit Python's recursion
    limit. Start nodes are visited in sorted order and each node's neighbors
    are sorted once and cached, which makes the output deterministic.

    Args:
        graph: Adjacency mapping from a module tuple to the set of module
            tuples it depends on, e.g. ("pkg", "__init__.py") or
            ("pkg", "module.py").

    Returns:
        All SCCs as sets of module tuples, in reverse topological order
        (leaves first). Singleton components without a self-loop are
        included as well.
    """
    # Nodes that only ever appear as a dependency still need to be visited.
    nodes: set[ModulePath] = set(graph).union(*graph.values())

    state = _TarjanState(graph)
    for candidate in sorted(nodes):
        if candidate in state.index:
            # Already reached from an earlier start node.
            continue
        _strongconnect(state, candidate)

    return state.result
143+
144+
145+
def find_circular_sccs(graph: ModuleGraph) -> SCCList:
    """Return only the SCCs that contain an actual dependency cycle.

    A component counts as circular when it has more than one node, or when
    its single node has an edge to itself.

    Args:
        graph: Module dependency graph.

    Returns:
        Circular SCCs ordered by their smallest member, for determinism.
    """

    def has_cycle(component: SCC) -> bool:
        size = len(component)
        if size > 1:
            return True
        if size != 1:
            return False
        (only,) = component
        # A lone node is circular only if it lists itself as a dependency.
        return only in graph and only in graph[only]

    cyclic: SCCList = [component for component in strongly_connected_components(graph) if has_cycle(component)]
    return sorted(cyclic, key=min)

0 commit comments

Comments
 (0)