From 525cdfd97964bcbe446682a1dd62f7f0215642a3 Mon Sep 17 00:00:00 2001 From: Koudai Aono Date: Sat, 20 Dec 2025 14:44:06 +0000 Subject: [PATCH] fix: wrap RootModel primitive defaults with default_factory --- src/datamodel_code_generator/model/_types.py | 18 ++++++++++++++++++ src/datamodel_code_generator/model/base.py | 16 +++------------- .../model/pydantic/base_model.py | 18 ++++++++++++------ .../jsonschema/root_model_default_value.py | 10 ++-------- .../root_model_default_value_branches.py | 3 +-- .../root_model_default_value_no_annotated.py | 8 ++------ .../root_model_default_value_non_root.py | 3 +-- .../pydantic_v2_default_object/Nested.py | 4 +--- .../main/openapi/referenced_default.py | 4 +--- .../referenced_default_use_annotated.py | 3 +-- .../openapi/root_model_default_primitive.py | 15 +++++++++++++++ .../openapi/root_model_default_primitive.yaml | 18 ++++++++++++++++++ tests/main/openapi/test_main_openapi.py | 12 ++++++++++++ 13 files changed, 87 insertions(+), 45 deletions(-) create mode 100644 src/datamodel_code_generator/model/_types.py create mode 100644 tests/data/expected/main/openapi/root_model_default_primitive.py create mode 100644 tests/data/openapi/root_model_default_primitive.yaml diff --git a/src/datamodel_code_generator/model/_types.py b/src/datamodel_code_generator/model/_types.py new file mode 100644 index 000000000..e2b96fecf --- /dev/null +++ b/src/datamodel_code_generator/model/_types.py @@ -0,0 +1,18 @@ +"""Internal types for model module.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + + +@dataclass(repr=False) +class WrappedDefault: + """Represents a default value wrapped with its type constructor.""" + + value: Any + type_name: str + + def __repr__(self) -> str: + """Return type constructor representation, e.g., 'CountType(10)'.""" + return f"{self.type_name}({self.value!r})" diff --git a/src/datamodel_code_generator/model/base.py b/src/datamodel_code_generator/model/base.py index 1935cfc40..730a7b838 100644 --- a/src/datamodel_code_generator/model/base.py +++ b/src/datamodel_code_generator/model/base.py @@ -10,7 +10,6 @@ from abc import ABC, abstractmethod from collections import defaultdict from copy import deepcopy -from dataclasses import dataclass from functools import cached_property, lru_cache from pathlib import Path from typing import TYPE_CHECKING, Any, ClassVar, Optional, TypeVar, Union @@ -26,6 +25,7 @@ IMPORT_UNION, Import, ) +from datamodel_code_generator.model._types import WrappedDefault from datamodel_code_generator.reference import Reference, _BaseModel from datamodel_code_generator.types import ( ANY, @@ -39,6 +39,8 @@ ) from datamodel_code_generator.util import PYDANTIC_V2, ConfigDict +__all__ = ["WrappedDefault"] + if TYPE_CHECKING: from collections.abc import Iterator @@ -113,18 +115,6 @@ def merge_constraints(a: ConstraintsBaseT | None, b: ConstraintsBaseT | None) -> }) -@dataclass(repr=False) -class WrappedDefault: - """Represents a default value wrapped with its type constructor.""" - - value: Any - type_name: str - - def __repr__(self) -> str: - """Return type constructor representation, e.g., 'CountType(10)'.""" - return f"{self.type_name}({self.value!r})" - - class DataModelFieldBase(_BaseModel): """Base class for model field representation and rendering.""" diff --git a/src/datamodel_code_generator/model/pydantic/base_model.py b/src/datamodel_code_generator/model/pydantic/base_model.py index 7cc864ba6..a06e819d5 100644 --- a/src/datamodel_code_generator/model/pydantic/base_model.py +++ b/src/datamodel_code_generator/model/pydantic/base_model.py @@ -17,6 +17,7 @@ DataModel, DataModelFieldBase, ) +from datamodel_code_generator.model._types import WrappedDefault from datamodel_code_generator.model.base import UNDEFINED from datamodel_code_generator.model.pydantic.imports import ( IMPORT_ANYURL, @@ -122,6 +123,8 @@ def _get_strict_field_constraint_value(self, constraint: str, value: Any) -> Any return int(value) def _get_default_as_pydantic_model(self) -> str | None: + if isinstance(self.default, WrappedDefault): + return f"lambda :{self.default!r}" for data_type in self.data_type.data_types or (self.data_type,): # TODO: Check nested data_types if data_type.is_dict: @@ -141,15 +144,18 @@ def _get_default_as_pydantic_model(self) -> str | None: f"{self._PARSE_METHOD}(v) for v in {self.default!r}]" ) elif data_type.reference and isinstance(data_type.reference.source, BaseModelBase): + source = data_type.reference.source + is_root_model = hasattr(source, "BASE_CLASS") and source.BASE_CLASS == "pydantic.RootModel" if self.data_type.is_union: if not isinstance(self.default, (dict, list)): + if not is_root_model: + continue + elif isinstance(self.default, dict) and any(dt.is_dict for dt in self.data_type.data_types): continue - if isinstance(self.default, dict) and any(dt.is_dict for dt in self.data_type.data_types): - continue - return ( - f"lambda :{data_type.alias or data_type.reference.source.class_name}." - f"{self._PARSE_METHOD}({self.default!r})" - ) + class_name = data_type.alias or source.class_name + if is_root_model: + return f"lambda :{class_name}({self.default!r})" + return f"lambda :{class_name}.{self._PARSE_METHOD}({self.default!r})" return None def _process_data_in_str(self, data: dict[str, Any]) -> None: diff --git a/tests/data/expected/main/jsonschema/root_model_default_value.py b/tests/data/expected/main/jsonschema/root_model_default_value.py index 5983a84a6..7eae0e095 100644 --- a/tests/data/expected/main/jsonschema/root_model_default_value.py +++ b/tests/data/expected/main/jsonschema/root_model_default_value.py @@ -25,13 +25,7 @@ class NameType(RootModel[str]): class Model(BaseModel): admin_state: AdminStateLeaf | None = AdminStateLeaf.enable - count: Annotated[ - CountType | None, - Field(default_factory=lambda: CountType.model_validate(CountType(10))), - ] + count: Annotated[CountType | None, Field(default_factory=lambda: CountType(10))] name: Annotated[ - NameType | None, - Field( - default_factory=lambda: NameType.model_validate(NameType('default_name')) - ), + NameType | None, Field(default_factory=lambda: NameType('default_name')) ] diff --git a/tests/data/expected/main/jsonschema/root_model_default_value_branches.py b/tests/data/expected/main/jsonschema/root_model_default_value_branches.py index bfce1d352..d32dc962d 100644 --- a/tests/data/expected/main/jsonschema/root_model_default_value_branches.py +++ b/tests/data/expected/main/jsonschema/root_model_default_value_branches.py @@ -15,8 +15,7 @@ class CountType(RootModel[int]): class Model(BaseModel): count_with_default: Annotated[ - CountType | None, - Field(default_factory=lambda: CountType.model_validate(CountType(10))), + CountType | None, Field(default_factory=lambda: CountType(10)) ] count_no_default: CountType | None = None count_list_default: Annotated[ diff --git a/tests/data/expected/main/jsonschema/root_model_default_value_no_annotated.py b/tests/data/expected/main/jsonschema/root_model_default_value_no_annotated.py index 443995bf0..f43f1e381 100644 --- a/tests/data/expected/main/jsonschema/root_model_default_value_no_annotated.py +++ b/tests/data/expected/main/jsonschema/root_model_default_value_no_annotated.py @@ -24,9 +24,5 @@ class NameType(RootModel[constr(min_length=1, max_length=50)]): class Model(BaseModel): admin_state: AdminStateLeaf | None = AdminStateLeaf.enable - count: CountType | None = Field( - default_factory=lambda: CountType.model_validate(10) - ) - name: NameType | None = Field( - default_factory=lambda: NameType.model_validate('default_name') - ) + count: CountType | None = Field(default_factory=lambda: CountType(10)) + name: NameType | None = Field(default_factory=lambda: NameType('default_name')) diff --git a/tests/data/expected/main/jsonschema/root_model_default_value_non_root.py b/tests/data/expected/main/jsonschema/root_model_default_value_non_root.py index b81324b1f..c142c6d64 100644 --- a/tests/data/expected/main/jsonschema/root_model_default_value_non_root.py +++ b/tests/data/expected/main/jsonschema/root_model_default_value_non_root.py @@ -19,8 +19,7 @@ class PersonType(BaseModel): class Model(BaseModel): root_model_field: Annotated[ - CountType | None, - Field(default_factory=lambda: CountType.model_validate(CountType(10))), + CountType | None, Field(default_factory=lambda: CountType(10)) ] non_root_model_field: Annotated[ PersonType | None, diff --git a/tests/data/expected/main/openapi/pydantic_v2_default_object/Nested.py b/tests/data/expected/main/openapi/pydantic_v2_default_object/Nested.py index afca84576..f952a7d7e 100644 --- a/tests/data/expected/main/openapi/pydantic_v2_default_object/Nested.py +++ b/tests/data/expected/main/openapi/pydantic_v2_default_object/Nested.py @@ -23,6 +23,4 @@ class Bar(BaseModel): for v in [{'text': 'abc', 'number': 123}, {'text': 'efg', 'number': 456}] ] ) - nested_foo: Foo | None = Field( - default_factory=lambda: Foo.model_validate('default foo') - ) + nested_foo: Foo | None = Field(default_factory=lambda: Foo('default foo')) diff --git a/tests/data/expected/main/openapi/referenced_default.py b/tests/data/expected/main/openapi/referenced_default.py index 117d137c7..257d7c0b0 100644 --- a/tests/data/expected/main/openapi/referenced_default.py +++ b/tests/data/expected/main/openapi/referenced_default.py @@ -13,6 +13,4 @@ class ModelSettingB(RootModel[confloat(ge=0.0, le=10.0)]): class Model(BaseModel): settingA: confloat(ge=0.0, le=10.0) | None = 5 - settingB: ModelSettingB | None = Field( - default_factory=lambda: ModelSettingB.model_validate(5) - ) + settingB: ModelSettingB | None = Field(default_factory=lambda: ModelSettingB(5)) diff --git a/tests/data/expected/main/openapi/referenced_default_use_annotated.py b/tests/data/expected/main/openapi/referenced_default_use_annotated.py index 85afdc793..e2b6696d5 100644 --- a/tests/data/expected/main/openapi/referenced_default_use_annotated.py +++ b/tests/data/expected/main/openapi/referenced_default_use_annotated.py @@ -16,6 +16,5 @@ class ModelSettingB(RootModel[float]): class Model(BaseModel): settingA: Annotated[float | None, Field(ge=0.0, le=10.0)] = 5 settingB: Annotated[ - ModelSettingB | None, - Field(default_factory=lambda: ModelSettingB.model_validate(ModelSettingB(5))), + ModelSettingB | None, Field(default_factory=lambda: ModelSettingB(5)) ] diff --git a/tests/data/expected/main/openapi/root_model_default_primitive.py b/tests/data/expected/main/openapi/root_model_default_primitive.py new file mode 100644 index 000000000..063dc91df --- /dev/null +++ b/tests/data/expected/main/openapi/root_model_default_primitive.py @@ -0,0 +1,15 @@ +# generated by datamodel-codegen: +# filename: root_model_default_primitive.yaml +# timestamp: 2019-07-26T00:00:00+00:00 + +from __future__ import annotations + +from pydantic import BaseModel, Field, RootModel, conint + + +class Timeout(RootModel[conint(le=14400, gt=0)]): + root: conint(le=14400, gt=0) + + +class CrawlConfiguration(BaseModel): + timeout: Timeout | None = Field(default_factory=lambda: Timeout(3600)) diff --git a/tests/data/openapi/root_model_default_primitive.yaml b/tests/data/openapi/root_model_default_primitive.yaml new file mode 100644 index 000000000..21960d6e6 --- /dev/null +++ b/tests/data/openapi/root_model_default_primitive.yaml @@ -0,0 +1,18 @@ +openapi: 3.1.0 +info: + title: Test + version: 0.1.0 +paths: {} +components: + schemas: + Timeout: + type: integer + maximum: 14400 + exclusiveMinimum: 0 + CrawlConfiguration: + properties: + timeout: + anyOf: + - $ref: "#/components/schemas/Timeout" + - type: "null" + default: 3600 diff --git a/tests/main/openapi/test_main_openapi.py b/tests/main/openapi/test_main_openapi.py index 04c8beef9..666ca25ab 100644 --- a/tests/main/openapi/test_main_openapi.py +++ b/tests/main/openapi/test_main_openapi.py @@ -3242,6 +3242,18 @@ def test_main_openapi_referenced_default_use_annotated(output_file: Path) -> Non ) +def test_main_openapi_root_model_default_primitive(output_file: Path) -> None: + """Test RootModel with primitive default value in union type.""" + run_main_and_assert( + input_path=OPEN_API_DATA_PATH / "root_model_default_primitive.yaml", + output_path=output_file, + input_file_type="openapi", + assert_func=assert_file_content, + expected_file="root_model_default_primitive.py", + extra_args=["--output-model-type", "pydantic_v2.BaseModel"], + ) + + @pytest.mark.cli_doc( options=["--parent-scoped-naming"], input_schema="openapi/duplicate_models2.yaml",