Skip to content

Commit c7f65e8

Browse files
authored
Use default_factory for $ref with default values when using --use-annotated (#2618)
* Add support for computed default factories in BaseModel fields * Use default factories for model fields in RootModel and non-root models * Refactor handling of computed default factories in BaseModel fields
1 parent ff5de87 commit c7f65e8

10 files changed

Lines changed: 74 additions & 13 deletions

File tree

src/datamodel_code_generator/model/pydantic/base_model.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,11 @@ class DataModelField(DataModelFieldBase):
6969
constraints: Optional[Constraints] = None # noqa: UP045
7070
_PARSE_METHOD: ClassVar[str] = "parse_obj"
7171

72+
@property
73+
def has_default_factory_in_field(self) -> bool:
74+
"""Check if this field has a default_factory in Field() including computed ones."""
75+
return "default_factory" in self.extras or self.__dict__.get("_computed_default_factory") is not None
76+
7277
@property
7378
def method(self) -> str | None:
7479
"""Get the validation method name."""
@@ -200,20 +205,23 @@ def __str__(self) -> str: # noqa: PLR0912
200205
else:
201206
default_factory = data.pop("default_factory", None)
202207

208+
self.__dict__["_computed_default_factory"] = default_factory
209+
203210
field_arguments = sorted(f"{k}={v!r}" for k, v in data.items() if v is not None)
204211

205212
if not field_arguments and not default_factory:
206213
if self.nullable and self.required:
207214
return "Field(...)" # Field() is for mypy
208215
return ""
209216

217+
if default_factory:
218+
field_arguments = [f"default_factory={default_factory}", *field_arguments]
219+
210220
if self.use_annotated:
211221
field_arguments = self._process_annotated_field_arguments(field_arguments)
212222
elif self.required:
213223
field_arguments = ["...", *field_arguments]
214-
elif default_factory:
215-
field_arguments = [f"default_factory={default_factory}", *field_arguments]
216-
else:
224+
elif not default_factory:
217225
field_arguments = [f"{self.default!r}", *field_arguments]
218226

219227
return f"Field({', '.join(field_arguments)})"

src/datamodel_code_generator/model/template/pydantic/BaseModel.jinja2

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ class {{ class_name }}({{ base_class }}):{% if comment is defined %} # {{ comme
2424
{%- else %}
2525
{{ field.name }}: {{ field.type_hint }}
2626
{%- endif %}
27-
{%- if not (field.required or (field.represented_default == 'None' and field.strip_default_none))
27+
{%- if not field.has_default_factory_in_field and not (field.required or (field.represented_default == 'None' and field.strip_default_none))
2828
%} = {{ field.represented_default }}
2929
{%- endif -%}
3030
{%- endif %}

src/datamodel_code_generator/model/template/pydantic/BaseModel_root.jinja2

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ class {{ class_name }}({{ base_class }}):{% if comment is defined %} # {{ comme
2424
{%- else %}
2525
__root__: {{ field.type_hint }}
2626
{%- endif %}
27-
{%- if not (field.required or (field.represented_default == 'None' and field.strip_default_none))
27+
{%- if not field.has_default_factory_in_field and not (field.required or (field.represented_default == 'None' and field.strip_default_none))
2828
%} = {{ field.represented_default }}
2929
{%- endif -%}
3030
{%- endif %}

src/datamodel_code_generator/model/template/pydantic_v2/BaseModel.jinja2

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ class {{ class_name }}({{ base_class }}):{% if comment is defined %} # {{ comme
3232
{%- else %}
3333
{{ field.name }}: {{ field.type_hint }}
3434
{%- endif %}
35-
{%- if not (field.required or (field.represented_default == 'None' and field.strip_default_none)) or field.data_type.is_optional
35+
{%- if not field.has_default_factory_in_field and (not (field.required or (field.represented_default == 'None' and field.strip_default_none)) or field.data_type.is_optional)
3636
%} = {{ field.represented_default }}
3737
{%- endif -%}
3838
{%- endif %}

src/datamodel_code_generator/model/template/pydantic_v2/RootModel.jinja2

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class {{ class_name }}({{ base_class }}{%- if fields -%}[{{get_type_hint(fields)
3333
{%- else %}
3434
root: {{ field.type_hint }}
3535
{%- endif %}
36-
{%- if not (field.required or (field.represented_default == 'None' and field.strip_default_none))
36+
{%- if not field.has_default_factory_in_field and not (field.required or (field.represented_default == 'None' and field.strip_default_none))
3737
%} = {{ field.represented_default }}
3838
{%- endif -%}
3939
{%- endif %}

tests/data/expected/main/jsonschema/root_model_default_value.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,5 +25,13 @@ class NameType(RootModel[str]):
2525

2626
class Model(BaseModel):
2727
admin_state: Optional[AdminStateLeaf] = AdminStateLeaf.enable
28-
count: Annotated[Optional[CountType], Field()] = CountType(10)
29-
name: Annotated[Optional[NameType], Field()] = NameType('default_name')
28+
count: Annotated[
29+
Optional[CountType],
30+
Field(default_factory=lambda: CountType.model_validate(CountType(10))),
31+
]
32+
name: Annotated[
33+
Optional[NameType],
34+
Field(
35+
default_factory=lambda: NameType.model_validate(NameType('default_name'))
36+
),
37+
]

tests/data/expected/main/jsonschema/root_model_default_value_branches.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,12 @@ class CountType(RootModel[int]):
1414

1515

1616
class Model(BaseModel):
17-
count_with_default: Annotated[Optional[CountType], Field()] = CountType(10)
17+
count_with_default: Annotated[
18+
Optional[CountType],
19+
Field(default_factory=lambda: CountType.model_validate(CountType(10))),
20+
]
1821
count_no_default: Optional[CountType] = None
19-
count_list_default: Annotated[Optional[List[CountType]], Field()] = [1, 2, 3]
22+
count_list_default: Annotated[
23+
Optional[List[CountType]],
24+
Field(default_factory=lambda: [CountType.model_validate(v) for v in [1, 2, 3]]),
25+
]

tests/data/expected/main/jsonschema/root_model_default_value_non_root.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,12 @@ class PersonType(BaseModel):
1818

1919

2020
class Model(BaseModel):
21-
root_model_field: Annotated[Optional[CountType], Field()] = CountType(10)
22-
non_root_model_field: Annotated[Optional[PersonType], Field()] = {'name': 'John'}
21+
root_model_field: Annotated[
22+
Optional[CountType],
23+
Field(default_factory=lambda: CountType.model_validate(CountType(10))),
24+
]
25+
non_root_model_field: Annotated[
26+
Optional[PersonType],
27+
Field(default_factory=lambda: PersonType.model_validate({'name': 'John'})),
28+
]
2329
primitive_field: Optional[str] = 'hello'
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# generated by datamodel-codegen:
2+
# filename: referenced_default.yaml
3+
# timestamp: 2019-07-26T00:00:00+00:00
4+
5+
from __future__ import annotations
6+
7+
from typing import Annotated, Optional
8+
9+
from pydantic import BaseModel, Field, RootModel
10+
11+
12+
class ModelSettingB(RootModel[float]):
13+
root: Annotated[float, Field(ge=0.0, le=10.0)]
14+
15+
16+
class Model(BaseModel):
17+
settingA: Annotated[Optional[float], Field(ge=0.0, le=10.0)] = 5
18+
settingB: Annotated[
19+
Optional[ModelSettingB],
20+
Field(default_factory=lambda: ModelSettingB.model_validate(ModelSettingB(5))),
21+
]

tests/main/openapi/test_main_openapi.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2390,6 +2390,18 @@ def test_main_openapi_referenced_default(output_file: Path) -> None:
23902390
)
23912391

23922392

2393+
def test_main_openapi_referenced_default_use_annotated(output_file: Path) -> None:
2394+
"""Test OpenAPI generation with referenced default values using --use-annotated."""
2395+
run_main_and_assert(
2396+
input_path=OPEN_API_DATA_PATH / "referenced_default.yaml",
2397+
output_path=output_file,
2398+
input_file_type="openapi",
2399+
assert_func=assert_file_content,
2400+
expected_file="referenced_default_use_annotated.py",
2401+
extra_args=["--output-model-type", "pydantic_v2.BaseModel", "--use-annotated"],
2402+
)
2403+
2404+
23932405
def test_duplicate_models(output_file: Path) -> None:
23942406
"""Test OpenAPI generation with duplicate models."""
23952407
run_main_and_assert(

0 commit comments

Comments
 (0)