Skip to content

Commit 27d5f07

Browse files
authored
Refactor reference path handling and update circular import resolution to include path mappings (#2621)
1 parent 70bf516 commit 27d5f07

4 files changed

Lines changed: 33 additions & 14 deletions

File tree

src/datamodel_code_generator/model/base.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -578,6 +578,12 @@ def path(self) -> str:
578578
"""Get the full reference path for this model."""
579579
return self.reference.path
580580

581+
def set_reference_path(self, new_path: str) -> None:
582+
"""Set reference path and clear cached path property."""
583+
self.reference.path = new_path
584+
if "path" in self.__dict__:
585+
del self.__dict__["path"]
586+
581587
def render(self, *, class_name: str | None = None) -> str:
582588
"""Render the model to a string using the template."""
583589
return self._render(

src/datamodel_code_generator/parser/base.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1968,20 +1968,24 @@ def __rename_and_relocate_scc_models( # noqa: PLR6301
19681968
model_to_original_module: dict[int, tuple[str, ...]],
19691969
internal_module: tuple[str, ...],
19701970
internal_path: Path,
1971-
) -> defaultdict[tuple[str, ...], list[tuple[str, str]]]:
1971+
) -> tuple[defaultdict[tuple[str, ...], list[tuple[str, str]]], dict[str, str]]:
19721972
"""Rename duplicate classes and relocate models to internal module.
19731973
19741974
Returns:
1975-
Mapping from original module to list of (original_name, new_name) tuples.
1975+
Tuple of:
1976+
- Mapping from original module to list of (original_name, new_name) tuples.
1977+
- Mapping from old reference paths to new reference paths.
19761978
"""
19771979
class_name_counts = Counter(model.class_name for model in all_scc_models)
19781980
class_name_seen: dict[str, int] = {}
19791981
internal_module_str = ".".join(internal_module)
19801982
module_class_mappings: defaultdict[tuple[str, ...], list[tuple[str, str]]] = defaultdict(list)
1983+
path_mapping: dict[str, str] = {}
19811984

19821985
for model in all_scc_models:
19831986
original_class_name = model.class_name
19841987
original_module = model_to_original_module[id(model)]
1988+
old_path = model.path # Save old path before updating
19851989

19861990
if class_name_counts[original_class_name] > 1:
19871991
seen_count = class_name_seen.get(original_class_name, 0)
@@ -1991,12 +1995,14 @@ def __rename_and_relocate_scc_models( # noqa: PLR6301
19911995
new_class_name = original_class_name
19921996

19931997
model.reference.name = new_class_name
1994-
model.reference.path = f"{internal_module_str}.{new_class_name}"
1998+
new_path = f"{internal_module_str}.{new_class_name}"
1999+
model.set_reference_path(new_path)
19952000
model.file_path = internal_path
19962001

19972002
module_class_mappings[original_module].append((original_class_name, new_class_name))
2003+
path_mapping[old_path] = new_path
19982004

1999-
return module_class_mappings
2005+
return module_class_mappings, path_mapping
20002006

20012007
def __build_module_dependency_graph( # noqa: PLR6301
20022008
self,
@@ -2031,13 +2037,14 @@ def add_cross_module_edge(ref_path: str, source_module: tuple[str, ...]) -> None
20312037

20322038
return graph
20332039

2034-
def __resolve_circular_imports(
2040+
def __resolve_circular_imports( # noqa: PLR0914
20352041
self,
20362042
module_models_list: list[tuple[tuple[str, ...], list[DataModel]]],
20372043
) -> tuple[
20382044
list[tuple[tuple[str, ...], list[DataModel]]],
20392045
set[tuple[str, ...]],
20402046
dict[tuple[str, ...], tuple[tuple[str, ...], list[tuple[str, str]]]],
2047+
dict[str, str],
20412048
]:
20422049
"""Resolve circular imports by merging all SCCs into _internal.py modules.
20432050
@@ -2050,15 +2057,17 @@ def __resolve_circular_imports(
20502057
- Updated module_models_list with models moved to _internal modules
20512058
- Set of _internal modules created
20522059
- Forwarder map: original_module -> (internal_module, [(original_name, new_name)])
2060+
- Path mapping: old_reference_path -> new_reference_path
20532061
"""
20542062
graph = self.__build_module_dependency_graph(module_models_list)
20552063

20562064
circular_sccs = find_circular_sccs(graph)
20572065

20582066
forwarder_map: dict[tuple[str, ...], tuple[tuple[str, ...], list[tuple[str, str]]]] = {}
2067+
all_path_mappings: dict[str, str] = {}
20592068

20602069
if not circular_sccs:
2061-
return module_models_list, set(), forwarder_map
2070+
return module_models_list, set(), forwarder_map, all_path_mappings
20622071

20632072
# All circular SCCs are problematic and should be merged into _internal.py
20642073
# to break the import cycles.
@@ -2077,9 +2086,10 @@ def __resolve_circular_imports(
20772086
internal_path = Path("/".join(internal_module))
20782087

20792088
all_scc_models, model_to_original_module = self.__collect_scc_models(scc, result_modules)
2080-
module_class_mappings = self.__rename_and_relocate_scc_models(
2089+
module_class_mappings, path_mapping = self.__rename_and_relocate_scc_models(
20812090
all_scc_models, model_to_original_module, internal_module, internal_path
20822091
)
2092+
all_path_mappings.update(path_mapping)
20832093

20842094
for scc_module in scc:
20852095
if scc_module in result_modules: # pragma: no branch
@@ -2099,7 +2109,7 @@ def __resolve_circular_imports(
20992109
if module not in internal_modules_created: # pragma: no branch
21002110
new_module_models.append((module, result_modules.get(module, [])))
21012111

2102-
return new_module_models, internal_modules_created, forwarder_map
2112+
return new_module_models, internal_modules_created, forwarder_map, all_path_mappings
21032113

21042114
def parse( # noqa: PLR0912, PLR0913, PLR0914, PLR0915, PLR0917
21052115
self,
@@ -2177,7 +2187,11 @@ def sort_key(data_model: DataModel) -> tuple[int, tuple[str, ...]]:
21772187
module_models.insert(0, shared_module_entry)
21782188

21792189
# Resolve circular imports by moving models to _internal.py modules
2180-
module_models, internal_modules, forwarder_map = self.__resolve_circular_imports(module_models)
2190+
module_models, internal_modules, forwarder_map, path_mapping = self.__resolve_circular_imports(module_models)
2191+
2192+
# Update require_update_action_models with new paths for relocated models
2193+
if path_mapping:
2194+
require_update_action_models[:] = [path_mapping.get(path, path) for path in require_update_action_models]
21812195

21822196
class Processed(NamedTuple):
21832197
module: tuple[str, ...]

tests/data/expected/main/openapi/circular_imports_with_inheritance/_internal.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@
88

99
from pydantic import BaseModel
1010

11-
from . import BaseEntity as BaseEntity_1
12-
1311

1412
class BaseEntity(BaseModel):
1513
id: Optional[str] = None
@@ -21,17 +19,17 @@ class RootModel(BaseModel):
2119
auth: Optional[Authorization] = None
2220

2321

24-
class Invoice(BaseEntity_1):
22+
class Invoice(BaseEntity):
2523
total: Optional[int] = None
2624
session: Optional[Session] = None
2725

2826

29-
class Session(BaseEntity_1):
27+
class Session(BaseEntity):
3028
status: Optional[str] = None
3129
root_ref: Optional[RootModel] = None
3230

3331

34-
class Authorization(BaseEntity_1):
32+
class Authorization(BaseEntity):
3533
amount: Optional[int] = None
3634
invoice: Optional[Invoice] = None
3735

tests/data/expected/main/openapi/modular_reuse_model/_internal.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,3 +62,4 @@ class ListModel(BaseModel):
6262

6363

6464
Tea_1.update_forward_refs()
65+
TeaClone.update_forward_refs()

0 commit comments

Comments
 (0)