Skip to content

Commit 0906812

Browse files
authored
Refactor parser base post-processing for DRY and type-safe implementation (#2730)
1 parent ae11c41 commit 0906812

8 files changed

Lines changed: 371 additions & 220 deletions

File tree

src/datamodel_code_generator/imports.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,33 @@ def dump_all(self, *, multiline: bool = False) -> str:
168168
items = ", ".join(f'"{name}"' for name in name_list)
169169
return f"__all__ = [{items}]"
170170

171+
def get_effective_name(self, from_: str | None, import_: str) -> str:
172+
"""Get the effective name after alias resolution."""
173+
return self.alias.get(from_, {}).get(import_, import_)
174+
175+
def remove_unused(self, used_names: set[str]) -> None:
176+
"""Remove imports not referenced in used_names.
177+
178+
Note: Checks both effective name (after alias) and original name to handle
179+
cases where code may reference either form (e.g., type annotations may use
180+
original name while runtime code uses alias).
181+
"""
182+
unused = [
183+
(from_, import_)
184+
for from_, imports_ in self.items()
185+
for import_ in imports_
186+
if not {self.get_effective_name(from_, import_), import_}.intersection(used_names)
187+
]
188+
for from_, import_ in unused:
189+
alias = self.alias.get(from_, {}).get(import_)
190+
reference_path = next(
191+
(p for p, i in self.reference_paths.items() if i.from_ == from_ and i.import_ == import_),
192+
None,
193+
)
194+
import_obj = Import(from_=from_, import_=import_, alias=alias, reference_path=reference_path)
195+
while self.counter.get((from_, import_), 0) > 0:
196+
self.remove(import_obj)
197+
171198

172199
IMPORT_ANNOTATED = Import.from_full_path("typing.Annotated")
173200
IMPORT_ANY = Import.from_full_path("typing.Any")

src/datamodel_code_generator/model/base.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -478,6 +478,10 @@ class DataModel(TemplateBase, Nullable, ABC): # noqa: PLR0904
478478
DEFAULT_IMPORTS: ClassVar[tuple[Import, ...]] = ()
479479
IS_ALIAS: ClassVar[bool] = False
480480
SUPPORTS_GENERIC_BASE_CLASS: ClassVar[bool] = True
481+
SUPPORTS_DISCRIMINATOR: ClassVar[bool] = False
482+
SUPPORTS_FIELD_RENAMING: ClassVar[bool] = False
483+
SUPPORTS_WRAPPED_DEFAULT: ClassVar[bool] = False
484+
SUPPORTS_KW_ONLY: ClassVar[bool] = False
481485
has_forward_reference: bool = False
482486

483487
def __init__( # noqa: PLR0913

src/datamodel_code_generator/model/dataclass.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ class DataClass(DataModel):
4242

4343
TEMPLATE_FILE_PATH: ClassVar[str] = "dataclass.jinja2"
4444
DEFAULT_IMPORTS: ClassVar[tuple[Import, ...]] = (IMPORT_DATACLASS,)
45+
SUPPORTS_DISCRIMINATOR: ClassVar[bool] = True
46+
SUPPORTS_KW_ONLY: ClassVar[bool] = True
4547

4648
def __init__( # noqa: PLR0913
4749
self,

src/datamodel_code_generator/model/msgspec.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ class Struct(DataModel):
113113
BASE_CLASS_NAME: ClassVar[str] = "Struct"
114114
BASE_CLASS_ALIAS: ClassVar[str] = "_Struct"
115115
DEFAULT_IMPORTS: ClassVar[tuple[Import, ...]] = ()
116+
SUPPORTS_DISCRIMINATOR: ClassVar[bool] = True
116117
CONFIG_MAPPING: ClassVar[dict[tuple[str, Any], tuple[str, Any] | None]] = {
117118
("allow_mutation", False): ("frozen", True),
118119
("extra_fields", "forbid"): ("forbid_unknown_fields", True),

src/datamodel_code_generator/model/pydantic/base_model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,7 @@ class BaseModel(BaseModelBase):
328328

329329
TEMPLATE_FILE_PATH: ClassVar[str] = "pydantic/BaseModel.jinja2"
330330
BASE_CLASS: ClassVar[str] = "pydantic.BaseModel"
331+
SUPPORTS_DISCRIMINATOR: ClassVar[bool] = True
331332

332333
def __init__( # noqa: PLR0912, PLR0913
333334
self,

src/datamodel_code_generator/model/pydantic_v2/base_model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,9 @@ class BaseModel(BaseModelBase):
174174
BASE_CLASS: ClassVar[str] = "pydantic.BaseModel"
175175
BASE_CLASS_NAME: ClassVar[str] = "BaseModel"
176176
BASE_CLASS_ALIAS: ClassVar[str] = "_BaseModel"
177+
SUPPORTS_DISCRIMINATOR: ClassVar[bool] = True
178+
SUPPORTS_FIELD_RENAMING: ClassVar[bool] = True
179+
SUPPORTS_WRAPPED_DEFAULT: ClassVar[bool] = True
177180
CONFIG_ATTRIBUTES: ClassVar[list[ConfigAttribute]] = [
178181
ConfigAttribute("allow_population_by_field_name", "populate_by_name", False), # noqa: FBT003
179182
ConfigAttribute("populate_by_name", "populate_by_name", False), # noqa: FBT003

0 commit comments

Comments
 (0)