Skip to content

Commit 96506f6

Browse files
authored
Refactor default handling for Union and dict types in model generation (#2595)
1 parent 61b38e9 commit 96506f6

7 files changed

Lines changed: 160 additions & 5 deletions

File tree

src/datamodel_code_generator/model/msgspec.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -332,8 +332,8 @@ def _get_default_as_struct_model(self) -> str | None:
332332
"""Convert default value to Struct model using msgspec convert."""
333333
for data_type in self.data_type.data_types or (self.data_type,):
334334
# TODO: Check nested data_types
335-
if data_type.is_dict or self.data_type.is_union:
336-
# TODO: Parse Union and dict model for default
335+
if data_type.is_dict:
336+
# TODO: Parse dict model for default
337337
continue # pragma: no cover
338338
if data_type.is_list and len(data_type.data_types) == 1:
339339
data_type_child = data_type.data_types[0]
@@ -347,6 +347,11 @@ def _get_default_as_struct_model(self) -> str | None:
347347
f"type=list[{data_type_child.alias or data_type_child.reference.source.class_name}])"
348348
)
349349
elif data_type.reference and isinstance(data_type.reference.source, Struct):
350+
if self.data_type.is_union:
351+
if not isinstance(self.default, (dict, list)):
352+
continue
353+
if isinstance(self.default, dict) and any(dt.is_dict for dt in self.data_type.data_types):
354+
continue
350355
return (
351356
f"lambda: {self._PARSE_METHOD}({self.default!r}, "
352357
f"type={data_type.alias or data_type.reference.source.class_name})"

src/datamodel_code_generator/model/pydantic/base_model.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,8 +114,8 @@ def _get_strict_field_constraint_value(self, constraint: str, value: Any) -> Any
114114
def _get_default_as_pydantic_model(self) -> str | None:
115115
for data_type in self.data_type.data_types or (self.data_type,):
116116
# TODO: Check nested data_types
117-
if data_type.is_dict or self.data_type.is_union:
118-
# TODO: Parse Union and dict model for default
117+
if data_type.is_dict:
118+
# TODO: Parse dict model for default
119119
continue
120120
if data_type.is_list and len(data_type.data_types) == 1:
121121
data_type_child = data_type.data_types[0]
@@ -128,7 +128,12 @@ def _get_default_as_pydantic_model(self) -> str | None:
128128
f"lambda :[{data_type_child.alias or data_type_child.reference.source.class_name}."
129129
f"{self._PARSE_METHOD}(v) for v in {self.default!r}]"
130130
)
131-
elif data_type.reference and isinstance(data_type.reference.source, BaseModelBase): # pragma: no cover
131+
elif data_type.reference and isinstance(data_type.reference.source, BaseModelBase):
132+
if self.data_type.is_union:
133+
if not isinstance(self.default, (dict, list)):
134+
continue
135+
if isinstance(self.default, dict) and any(dt.is_dict for dt in self.data_type.data_types):
136+
continue
132137
return (
133138
f"lambda :{data_type.alias or data_type.reference.source.class_name}."
134139
f"{self._PARSE_METHOD}({self.default!r})"
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# generated by datamodel-codegen:
2+
# filename: union_default_object.yaml
3+
# timestamp: 2019-07-26T00:00:00+00:00
4+
5+
from __future__ import annotations
6+
7+
from typing import Dict, Union
8+
9+
from msgspec import UNSET, Struct, UnsetType, convert, field
10+
11+
12+
class Interval(Struct):
13+
start: Union[int, UnsetType] = UNSET
14+
end: Union[int, UnsetType] = UNSET
15+
16+
17+
class Container(Struct):
18+
interval_or_string: Union[Interval, str, UnsetType] = field(
19+
default_factory=lambda: convert({'start': 2009, 'end': 2019}, type=Interval)
20+
)
21+
string_or_interval: Union[Interval, str, UnsetType] = 'some string value'
22+
dict_or_interval: Union[Dict[str, str], Interval, UnsetType] = {'key': 'value'}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# generated by datamodel-codegen:
2+
# filename: union_default_object.yaml
3+
# timestamp: 2019-07-26T00:00:00+00:00
4+
5+
from __future__ import annotations
6+
7+
from typing import Dict, Optional, Union
8+
9+
from pydantic import BaseModel, Field
10+
11+
12+
class Interval(BaseModel):
13+
start: Optional[int] = None
14+
end: Optional[int] = None
15+
16+
17+
class Container(BaseModel):
18+
interval_or_string: Optional[Union[Interval, str]] = Field(
19+
default_factory=lambda: Interval.model_validate({'start': 2009, 'end': 2019})
20+
)
21+
string_or_interval: Optional[Union[Interval, str]] = 'some string value'
22+
dict_or_interval: Optional[Union[Dict[str, str], Interval]] = {'key': 'value'}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# generated by datamodel-codegen:
2+
# filename: union_default_object.yaml
3+
# timestamp: 2019-07-26T00:00:00+00:00
4+
5+
from __future__ import annotations
6+
7+
from typing import Dict, Optional, Union
8+
9+
from pydantic import BaseModel, Field
10+
11+
12+
class Interval(BaseModel):
13+
start: Optional[int] = None
14+
end: Optional[int] = None
15+
16+
17+
class Container(BaseModel):
18+
interval_or_string: Optional[Union[Interval, str]] = Field(
19+
default_factory=lambda: Interval.parse_obj({'start': 2009, 'end': 2019})
20+
)
21+
string_or_interval: Optional[Union[Interval, str]] = 'some string value'
22+
dict_or_interval: Optional[Union[Dict[str, str], Interval]] = {'key': 'value'}
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
openapi: 3.1.0
2+
info:
3+
title: Union Default Object Test
4+
version: 0.1.0
5+
servers:
6+
- url: http://example.com
7+
paths:
8+
/test:
9+
get:
10+
responses:
11+
'200':
12+
description: OK
13+
components:
14+
schemas:
15+
Interval:
16+
type: object
17+
properties:
18+
start:
19+
type: integer
20+
end:
21+
type: integer
22+
Container:
23+
type: object
24+
properties:
25+
# Union[Interval, str] with dict default - should use default_factory
26+
interval_or_string:
27+
anyOf:
28+
- $ref: '#/components/schemas/Interval'
29+
- type: string
30+
default:
31+
start: 2009
32+
end: 2019
33+
# Union[Interval, str] with string default - should NOT use default_factory
34+
string_or_interval:
35+
anyOf:
36+
- $ref: '#/components/schemas/Interval'
37+
- type: string
38+
default: "some string value"
39+
# Union[Dict, Interval] with dict default - should NOT use default_factory (dict arm)
40+
dict_or_interval:
41+
anyOf:
42+
- type: object
43+
additionalProperties:
44+
type: string
45+
- $ref: '#/components/schemas/Interval'
46+
default:
47+
key: "value"

tests/main/openapi/test_main_openapi.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1577,6 +1577,38 @@ def test_main_openapi_default_object(output_model: str, expected_output: str, tm
15771577
)
15781578

15791579

1580+
@pytest.mark.parametrize(
1581+
("output_model", "expected_output"),
1582+
[
1583+
(
1584+
"pydantic.BaseModel",
1585+
"union_default_object.py",
1586+
),
1587+
(
1588+
"pydantic_v2.BaseModel",
1589+
"pydantic_v2_union_default_object.py",
1590+
),
1591+
(
1592+
"msgspec.Struct",
1593+
"msgspec_union_default_object.py",
1594+
),
1595+
],
1596+
)
1597+
@pytest.mark.skipif(
1598+
black.__version__.split(".")[0] == "19",
1599+
reason="Installed black doesn't support the old style",
1600+
)
1601+
def test_main_openapi_union_default_object(output_model: str, expected_output: str, output_file: Path) -> None:
1602+
"""Test OpenAPI generation with Union type default object values."""
1603+
run_main_and_assert(
1604+
input_path=OPEN_API_DATA_PATH / "union_default_object.yaml",
1605+
output_path=output_file,
1606+
expected_file=EXPECTED_OPENAPI_PATH / expected_output,
1607+
input_file_type="openapi",
1608+
extra_args=["--output-model", output_model, "--target-python-version", "3.9", "--openapi-scopes", "schemas"],
1609+
)
1610+
1611+
15801612
def test_main_dataclass(output_file: Path) -> None:
15811613
"""Test OpenAPI generation with dataclass output."""
15821614
run_main_and_assert(

0 commit comments

Comments
 (0)