Skip to content

Commit 89f54b1

Browse files
Extract helper methods and reduce code duplication across parsers (#2637)
* Refactor DataModel methods to reduce duplication and improve clarity * Refactor schema parsing methods to reduce duplication and improve readability * Refactor data type handling to reduce duplication and improve clarity * Refactor data type replacement logic to improve clarity and reduce duplication * Refactor data type replacement logic to improve clarity and reduce duplication * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Refactor methods to reduce duplication and improve clarity * Refactor get_dedup_key method to include use_default parameter for improved deduplication logic * Refactor condition checks and remove redundant comments for improved clarity --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 13e6fb1 commit 89f54b1

12 files changed

Lines changed: 425 additions & 387 deletions

File tree

src/datamodel_code_generator/model/base.py

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,23 @@ def process_const(self) -> None:
181181
self.required = False
182182
self.nullable = False
183183

184+
def _process_const_as_literal(self) -> None:
185+
"""Process const values by converting to literal type. Used by subclasses."""
186+
if "const" not in self.extras:
187+
return
188+
const = self.extras["const"]
189+
self.const = True
190+
self.nullable = False
191+
self.replace_data_type(self.data_type.__class__(literals=[const]), clear_old_parent=False)
192+
if not self.default:
193+
self.default = const
194+
195+
def self_reference(self) -> bool:
196+
"""Check if field references its parent model."""
197+
if self.parent is None or not self.parent.reference: # pragma: no cover
198+
return False
199+
return self.parent.reference.path in {d.reference.path for d in self.data_type.all_data_types if d.reference}
200+
184201
@property
185202
def type_hint(self) -> str: # noqa: PLR0911
186203
"""Get the type hint string for this field, including nullability."""
@@ -299,6 +316,20 @@ def copy_deep(self) -> Self:
299316
copied.data_type.data_types = [dt.copy() for dt in self.data_type.data_types]
300317
return copied
301318

319+
def replace_data_type(self, new_data_type: DataType, *, clear_old_parent: bool = True) -> None:
320+
"""Replace data_type and update parent relationships.
321+
322+
Args:
323+
new_data_type: The new DataType to set.
324+
clear_old_parent: If True, clear the old data_type's parent reference.
325+
Set to False when the old data_type may be referenced elsewhere.
326+
"""
327+
if self.data_type.parent is self and clear_old_parent:
328+
self.data_type.swap_with(new_data_type)
329+
else:
330+
self.data_type = new_data_type
331+
new_data_type.parent = self
332+
302333

303334
@lru_cache
304335
def get_template(template_file_path: Path) -> Template:
@@ -369,7 +400,7 @@ class BaseClassDataType(DataType):
369400
UNDEFINED: Any = object()
370401

371402

372-
class DataModel(TemplateBase, Nullable, ABC):
403+
class DataModel(TemplateBase, Nullable, ABC): # noqa: PLR0904
373404
"""Abstract base class for all data model types.
374405
375406
Handles template rendering, import collection, and model relationships.
@@ -483,6 +514,34 @@ def iter_all_fields(self, visited: set[str] | None = None) -> Iterator[DataModel
483514
yield from base_class.reference.source.iter_all_fields(visited)
484515
yield from self.fields
485516

517+
def get_dedup_key(self, class_name: str | None = None, *, use_default: bool = True) -> tuple[Any, ...]:
518+
"""Generate hashable key for model deduplication."""
519+
from datamodel_code_generator.parser.base import to_hashable # noqa: PLC0415
520+
521+
render_class_name = class_name if class_name is not None or not use_default else "M"
522+
return tuple(to_hashable(v) for v in (self.render(class_name=render_class_name), self.imports))
523+
524+
def create_reuse_model(self, base_ref: Reference) -> Self:
525+
"""Create inherited model with empty fields pointing to base reference."""
526+
return self.__class__(
527+
fields=[],
528+
base_classes=[base_ref],
529+
description=self.description,
530+
reference=Reference(
531+
name=self.name,
532+
path=self.reference.path + "/reuse",
533+
),
534+
custom_template_dir=self._custom_template_dir,
535+
)
536+
537+
def replace_children_in_models(self, models: list[DataModel], new_ref: Reference) -> None:
538+
"""Replace reference children if their parent model is in models list."""
539+
from datamodel_code_generator.parser.base import get_most_of_parent # noqa: PLC0415
540+
541+
for child in self.reference.children[:]:
542+
if isinstance(child, DataType) and get_most_of_parent(child) in models:
543+
child.replace_reference(new_ref)
544+
486545
def set_base_class(self) -> None:
487546
"""Set up the base class for this model."""
488547
base_class = self.custom_base_class or self.BASE_CLASS

src/datamodel_code_generator/model/dataclass.py

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -106,15 +106,8 @@ class DataModelField(DataModelFieldBase):
106106
constraints: Optional[Constraints] = None # noqa: UP045
107107

108108
def process_const(self) -> None:
109-
"""Process const field constraint."""
110-
if "const" not in self.extras:
111-
return
112-
self.const = True
113-
self.nullable = False
114-
const = self.extras["const"]
115-
self.data_type = self.data_type.__class__(literals=[const])
116-
if not self.default:
117-
self.default = const
109+
"""Process const field constraint using literal type."""
110+
self._process_const_as_literal()
118111

119112
@property
120113
def imports(self) -> tuple[Import, ...]:
@@ -124,12 +117,6 @@ def imports(self) -> tuple[Import, ...]:
124117
return chain_as_tuple(super().imports, (IMPORT_FIELD,))
125118
return super().imports
126119

127-
def self_reference(self) -> bool: # pragma: no cover
128-
"""Check if field references its parent dataclass."""
129-
return isinstance(self.parent, DataClass) and self.parent.reference.path in {
130-
d.reference.path for d in self.data_type.all_data_types if d.reference
131-
}
132-
133120
@property
134121
def field(self) -> str | None:
135122
"""For backwards compatibility."""

src/datamodel_code_generator/model/msgspec.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -221,12 +221,6 @@ class DataModelField(DataModelFieldBase):
221221
_COMPARE_EXPRESSIONS: ClassVar[set[str]] = {"gt", "ge", "lt", "le", "multiple_of"}
222222
constraints: Optional[Constraints] = None # noqa: UP045
223223

224-
def self_reference(self) -> bool: # pragma: no cover
225-
"""Check if field references its parent Struct."""
226-
return isinstance(self.parent, Struct) and self.parent.reference.path in {
227-
d.reference.path for d in self.data_type.all_data_types if d.reference
228-
}
229-
230224
def process_const(self) -> None:
231225
"""Process const field constraint."""
232226
if "const" not in self.extras:
@@ -235,7 +229,7 @@ def process_const(self) -> None:
235229
self.nullable = False
236230
const = self.extras["const"]
237231
if self.data_type.type == "str" and isinstance(const, str): # pragma: no cover # Literal supports only str
238-
self.data_type = self.data_type.__class__(literals=[const])
232+
self.replace_data_type(self.data_type.__class__(literals=[const]), clear_old_parent=False)
239233

240234
def _get_strict_field_constraint_value(self, constraint: str, value: Any) -> Any:
241235
"""Get constraint value with appropriate numeric type."""

src/datamodel_code_generator/model/pydantic/base_model.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -102,12 +102,6 @@ def field(self) -> str | None:
102102
return None
103103
return result
104104

105-
def self_reference(self) -> bool:
106-
"""Check if this field references its parent model."""
107-
return isinstance(self.parent, BaseModelBase) and self.parent.reference.path in {
108-
d.reference.path for d in self.data_type.all_data_types if d.reference
109-
}
110-
111105
def _get_strict_field_constraint_value(self, constraint: str, value: Any) -> Any:
112106
if value is None or constraint not in self._COMPARE_EXPRESSIONS:
113107
return value

src/datamodel_code_generator/model/pydantic_v2/base_model.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -125,15 +125,8 @@ def validate_extras(cls, values: Any) -> dict[str, Any]: # noqa: N805
125125
return values
126126

127127
def process_const(self) -> None:
128-
"""Process const field by converting it to a literal type with default value."""
129-
if "const" not in self.extras:
130-
return
131-
self.const = True
132-
self.nullable = False
133-
const = self.extras["const"]
134-
self.data_type = self.data_type.__class__(literals=[const])
135-
if not self.default:
136-
self.default = const
128+
"""Process const field constraint using literal type."""
129+
self._process_const_as_literal()
137130

138131
def _process_data_in_str(self, data: dict[str, Any]) -> None:
139132
if self.const:

src/datamodel_code_generator/model/typed_dict.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -118,15 +118,8 @@ class DataModelField(DataModelFieldBase):
118118
DEFAULT_IMPORTS: ClassVar[tuple[Import, ...]] = (IMPORT_NOT_REQUIRED,)
119119

120120
def process_const(self) -> None:
121-
"""Process const field constraint."""
122-
if "const" not in self.extras:
123-
return
124-
self.const = True
125-
self.nullable = False
126-
const = self.extras["const"]
127-
self.data_type = self.data_type.__class__(literals=[const])
128-
if not self.default:
129-
self.default = const
121+
"""Process const field constraint using literal type."""
122+
self._process_const_as_literal()
130123

131124
@property
132125
def key(self) -> str:

0 commit comments

Comments
 (0)