Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions src/datamodel_code_generator/model/_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
"""Internal types for model module."""

from __future__ import annotations

from dataclasses import dataclass
from typing import Any


@dataclass(repr=False)
class WrappedDefault:
"""Represents a default value wrapped with its type constructor."""

value: Any
type_name: str

def __repr__(self) -> str:
"""Return type constructor representation, e.g., 'CountType(10)'."""
return f"{self.type_name}({self.value!r})"
16 changes: 3 additions & 13 deletions src/datamodel_code_generator/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from abc import ABC, abstractmethod
from collections import defaultdict
from copy import deepcopy
from dataclasses import dataclass
from functools import cached_property, lru_cache
from pathlib import Path
from typing import TYPE_CHECKING, Any, ClassVar, Optional, TypeVar, Union
Expand All @@ -26,6 +25,7 @@
IMPORT_UNION,
Import,
)
from datamodel_code_generator.model._types import WrappedDefault
from datamodel_code_generator.reference import Reference, _BaseModel
from datamodel_code_generator.types import (
ANY,
Expand All @@ -39,6 +39,8 @@
)
from datamodel_code_generator.util import PYDANTIC_V2, ConfigDict

__all__ = ["WrappedDefault"]

if TYPE_CHECKING:
from collections.abc import Iterator

Expand Down Expand Up @@ -113,18 +115,6 @@ def merge_constraints(a: ConstraintsBaseT | None, b: ConstraintsBaseT | None) ->
})


@dataclass(repr=False)
class WrappedDefault:
"""Represents a default value wrapped with its type constructor."""

value: Any
type_name: str

def __repr__(self) -> str:
"""Return type constructor representation, e.g., 'CountType(10)'."""
return f"{self.type_name}({self.value!r})"


class DataModelFieldBase(_BaseModel):
"""Base class for model field representation and rendering."""

Expand Down
18 changes: 12 additions & 6 deletions src/datamodel_code_generator/model/pydantic/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
DataModel,
DataModelFieldBase,
)
from datamodel_code_generator.model._types import WrappedDefault
from datamodel_code_generator.model.base import UNDEFINED
from datamodel_code_generator.model.pydantic.imports import (
IMPORT_ANYURL,
Expand Down Expand Up @@ -122,6 +123,8 @@ def _get_strict_field_constraint_value(self, constraint: str, value: Any) -> Any
return int(value)

def _get_default_as_pydantic_model(self) -> str | None:
if isinstance(self.default, WrappedDefault):
return f"lambda :{self.default!r}"
for data_type in self.data_type.data_types or (self.data_type,):
# TODO: Check nested data_types
if data_type.is_dict:
Expand All @@ -141,15 +144,18 @@ def _get_default_as_pydantic_model(self) -> str | None:
f"{self._PARSE_METHOD}(v) for v in {self.default!r}]"
)
elif data_type.reference and isinstance(data_type.reference.source, BaseModelBase):
source = data_type.reference.source
is_root_model = hasattr(source, "BASE_CLASS") and source.BASE_CLASS == "pydantic.RootModel"
if self.data_type.is_union:
if not isinstance(self.default, (dict, list)):
if not is_root_model:
continue
elif isinstance(self.default, dict) and any(dt.is_dict for dt in self.data_type.data_types):
continue
if isinstance(self.default, dict) and any(dt.is_dict for dt in self.data_type.data_types):
continue
return (
f"lambda :{data_type.alias or data_type.reference.source.class_name}."
f"{self._PARSE_METHOD}({self.default!r})"
)
class_name = data_type.alias or source.class_name
if is_root_model:
return f"lambda :{class_name}({self.default!r})"
return f"lambda :{class_name}.{self._PARSE_METHOD}({self.default!r})"
return None

def _process_data_in_str(self, data: dict[str, Any]) -> None:
Expand Down
10 changes: 2 additions & 8 deletions tests/data/expected/main/jsonschema/root_model_default_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,7 @@ class NameType(RootModel[str]):

class Model(BaseModel):
admin_state: AdminStateLeaf | None = AdminStateLeaf.enable
count: Annotated[
CountType | None,
Field(default_factory=lambda: CountType.model_validate(CountType(10))),
]
count: Annotated[CountType | None, Field(default_factory=lambda: CountType(10))]
name: Annotated[
NameType | None,
Field(
default_factory=lambda: NameType.model_validate(NameType('default_name'))
),
NameType | None, Field(default_factory=lambda: NameType('default_name'))
]
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@ class CountType(RootModel[int]):

class Model(BaseModel):
count_with_default: Annotated[
CountType | None,
Field(default_factory=lambda: CountType.model_validate(CountType(10))),
CountType | None, Field(default_factory=lambda: CountType(10))
]
count_no_default: CountType | None = None
count_list_default: Annotated[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,5 @@ class NameType(RootModel[constr(min_length=1, max_length=50)]):

class Model(BaseModel):
admin_state: AdminStateLeaf | None = AdminStateLeaf.enable
count: CountType | None = Field(
default_factory=lambda: CountType.model_validate(10)
)
name: NameType | None = Field(
default_factory=lambda: NameType.model_validate('default_name')
)
count: CountType | None = Field(default_factory=lambda: CountType(10))
name: NameType | None = Field(default_factory=lambda: NameType('default_name'))
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@ class PersonType(BaseModel):

class Model(BaseModel):
root_model_field: Annotated[
CountType | None,
Field(default_factory=lambda: CountType.model_validate(CountType(10))),
CountType | None, Field(default_factory=lambda: CountType(10))
]
non_root_model_field: Annotated[
PersonType | None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,4 @@ class Bar(BaseModel):
for v in [{'text': 'abc', 'number': 123}, {'text': 'efg', 'number': 456}]
]
)
nested_foo: Foo | None = Field(
default_factory=lambda: Foo.model_validate('default foo')
)
nested_foo: Foo | None = Field(default_factory=lambda: Foo('default foo'))
4 changes: 1 addition & 3 deletions tests/data/expected/main/openapi/referenced_default.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,4 @@ class ModelSettingB(RootModel[confloat(ge=0.0, le=10.0)]):

class Model(BaseModel):
settingA: confloat(ge=0.0, le=10.0) | None = 5
settingB: ModelSettingB | None = Field(
default_factory=lambda: ModelSettingB.model_validate(5)
)
settingB: ModelSettingB | None = Field(default_factory=lambda: ModelSettingB(5))
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,5 @@ class ModelSettingB(RootModel[float]):
class Model(BaseModel):
settingA: Annotated[float | None, Field(ge=0.0, le=10.0)] = 5
settingB: Annotated[
ModelSettingB | None,
Field(default_factory=lambda: ModelSettingB.model_validate(ModelSettingB(5))),
ModelSettingB | None, Field(default_factory=lambda: ModelSettingB(5))
]
15 changes: 15 additions & 0 deletions tests/data/expected/main/openapi/root_model_default_primitive.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# generated by datamodel-codegen:
# filename: root_model_default_primitive.yaml
# timestamp: 2019-07-26T00:00:00+00:00

from __future__ import annotations

from pydantic import BaseModel, Field, RootModel, conint


class Timeout(RootModel[conint(le=14400, gt=0)]):
root: conint(le=14400, gt=0)


class CrawlConfiguration(BaseModel):
timeout: Timeout | None = Field(default_factory=lambda: Timeout(3600))
18 changes: 18 additions & 0 deletions tests/data/openapi/root_model_default_primitive.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
openapi: 3.1.0
info:
title: Test
version: 0.1.0
paths: {}
components:
schemas:
Timeout:
type: integer
maximum: 14400
exclusiveMinimum: 0
CrawlConfiguration:
properties:
timeout:
anyOf:
- $ref: "#/components/schemas/Timeout"
- type: "null"
default: 3600
12 changes: 12 additions & 0 deletions tests/main/openapi/test_main_openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -3242,6 +3242,18 @@ def test_main_openapi_referenced_default_use_annotated(output_file: Path) -> Non
)


def test_main_openapi_root_model_default_primitive(output_file: Path) -> None:
"""Test RootModel with primitive default value in union type."""
run_main_and_assert(
input_path=OPEN_API_DATA_PATH / "root_model_default_primitive.yaml",
output_path=output_file,
input_file_type="openapi",
assert_func=assert_file_content,
expected_file="root_model_default_primitive.py",
extra_args=["--output-model-type", "pydantic_v2.BaseModel"],
)


@pytest.mark.cli_doc(
options=["--parent-scoped-naming"],
input_schema="openapi/duplicate_models2.yaml",
Expand Down
Loading