Skip to content

Commit ccf8794

Browse files
authored
Fix Pydantic v2 discriminated unions in array fields (#2907)
* Fix Pydantic v2 discriminated unions in array fields
* Add e2e tests for discriminator in array
* Remove complex runtime test that fails in CI
1 parent 3db6667 commit ccf8794

7 files changed

Lines changed: 174 additions & 28 deletions

File tree

src/datamodel_code_generator/model/pydantic_v2/base_model.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from datamodel_code_generator.model.pydantic.base_model import (
2424
DataModelField as DataModelFieldV1,
2525
)
26+
from datamodel_code_generator.model.pydantic.imports import IMPORT_FIELD
2627
from datamodel_code_generator.model.pydantic_v2.imports import IMPORT_BASE_MODEL, IMPORT_CONFIG_DICT
2728
from datamodel_code_generator.types import chain_as_tuple
2829
from datamodel_code_generator.util import field_validator, model_validate, model_validator
@@ -174,14 +175,23 @@ def _process_annotated_field_arguments( # noqa: PLR6301
174175
) -> list[str]:
175176
return field_arguments
176177

178+
def _has_discriminator_in_data_type(self) -> bool:
179+
"""Check if any nested DataType has a discriminator."""
180+
return any(dt.discriminator for dt in self.data_type.all_data_types)
181+
177182
@property
178183
def imports(self) -> tuple[Import, ...]:
179-
"""Get all required imports including AliasChoices if needed."""
184+
"""Get all required imports including AliasChoices and Field for discriminator."""
180185
base_imports = super().imports
186+
extra_imports: list[Import] = []
181187
if self.validation_aliases:
182188
from datamodel_code_generator.model.pydantic_v2.imports import IMPORT_ALIAS_CHOICES # noqa: PLC0415
183189

184-
return chain_as_tuple(base_imports, (IMPORT_ALIAS_CHOICES,))
190+
extra_imports.append(IMPORT_ALIAS_CHOICES)
191+
if self._has_discriminator_in_data_type():
192+
extra_imports.append(IMPORT_FIELD)
193+
if extra_imports:
194+
return chain_as_tuple(base_imports, tuple(extra_imports))
185195
return base_imports
186196

187197

src/datamodel_code_generator/parser/base.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1702,6 +1702,10 @@ def __reuse_model_tree_scope(
17021702
self.__validate_shared_module_name(module_models)
17031703
return self.__create_shared_module_from_duplicates(module_models, duplicates, require_update_action_models)
17041704

1705+
def _is_pydantic_v2_model(self) -> bool:
1706+
"""Check if the output model type is Pydantic v2."""
1707+
return self.data_model_type.__module__.startswith("datamodel_code_generator.model.pydantic_v2")
1708+
17051709
def __collapse_root_models( # noqa: PLR0912, PLR0914, PLR0915
17061710
self,
17071711
models: list[DataModel],
@@ -1819,19 +1823,15 @@ def __collapse_root_models( # noqa: PLR0912, PLR0914, PLR0915
18191823
model_field.constraints = ConstraintsBase.merge_constraints(
18201824
root_type_field.constraints, model_field.constraints
18211825
)
1822-
if ( # pragma: no cover
1823-
isinstance(
1824-
root_type_field,
1825-
pydantic_model.DataModelField,
1826+
discriminator = root_type_field.extras.get("discriminator")
1827+
if discriminator and isinstance(root_type_field, pydantic_model.DataModelField):
1828+
prop_name = (
1829+
discriminator.get("propertyName") if isinstance(discriminator, dict) else discriminator
18261830
)
1827-
and not model_field.extras.get("discriminator")
1828-
and not any(t.is_list for t in model_field.data_type.data_types)
1829-
):
1830-
discriminator = root_type_field.extras.get("discriminator")
1831-
if discriminator:
1832-
model_field.extras["discriminator"] = discriminator
1831+
if self._is_pydantic_v2_model():
1832+
copied_data_type.discriminator = prop_name
18331833
assert isinstance(data_type.parent, DataType)
1834-
data_type.parent.data_types.remove(data_type) # pragma: no cover
1834+
data_type.parent.data_types.remove(data_type)
18351835
data_type.parent.data_types.append(copied_data_type)
18361836

18371837
elif isinstance(data_type.parent, DataType):

src/datamodel_code_generator/types.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
from datamodel_code_generator.imports import (
4141
IMPORT_ABC_MAPPING,
4242
IMPORT_ABC_SEQUENCE,
43+
IMPORT_ANNOTATED,
4344
IMPORT_ANY,
4445
IMPORT_DICT,
4546
IMPORT_FROZEN_SET,
@@ -454,6 +455,7 @@ class Config:
454455
dict_key: Optional[DataType] = None # noqa: UP045
455456
treat_dot_as_module: bool = False
456457
use_serialize_as_any: bool = False
458+
discriminator: Optional[str] = None # noqa: UP045
457459

458460
_exclude_fields: ClassVar[set[str]] = {"parent", "children"}
459461
_pass_fields: ClassVar[set[str]] = {"parent", "children", "data_types", "reference"}
@@ -621,11 +623,11 @@ def imports(self) -> Iterator[Import]:
621623
if self.import_:
622624
yield self.import_
623625

624-
# Define required imports based on type features and conditions
625626
imports: tuple[tuple[bool, Import], ...] = (
626627
(self.is_optional and not self.use_union_operator, IMPORT_OPTIONAL),
627628
(len(self.data_types) > 1 and not self.use_union_operator, IMPORT_UNION),
628629
(bool(self.literals) or bool(self.enum_member_literals), IMPORT_LITERAL),
630+
(bool(self.discriminator), IMPORT_ANNOTATED),
629631
)
630632

631633
imports = (
@@ -748,6 +750,8 @@ def type_hint(self) -> str: # noqa: PLR0912, PLR0915
748750
type_ = UNION_OPERATOR_DELIMITER.join(data_types)
749751
else:
750752
type_ = f"{UNION_PREFIX}{UNION_DELIMITER.join(data_types)}]"
753+
if self.discriminator:
754+
type_ = f"Annotated[{type_}, Field(discriminator={self.discriminator!r})]"
751755
elif len(self.data_types) == 1:
752756
type_ = self.data_types[0].type_hint
753757
elif self.enum_member_literals:
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# generated by datamodel-codegen:
2+
# filename: discriminator_in_array.yaml
3+
# timestamp: 2023-07-27T00:00:00+00:00
4+
5+
from __future__ import annotations
6+
7+
from enum import Enum
8+
from typing import Annotated, Literal
9+
10+
from pydantic import BaseModel, Field
11+
12+
13+
class Type(Enum):
14+
my_first_object = 'my_first_object'
15+
my_second_object = 'my_second_object'
16+
17+
18+
class ObjectBase(BaseModel):
19+
name: str | None = Field(None, description='Name of the object')
20+
type: Literal['type1'] = Field(..., description='Object type')
21+
22+
23+
class CreateObjectRequest(ObjectBase):
24+
name: str = Field(..., description='Name of the object')
25+
type: Literal['type2'] = Field(..., description='Object type')
26+
27+
28+
class UpdateObjectRequest(ObjectBase):
29+
type: Literal['type3']
30+
31+
32+
class Demo(BaseModel):
33+
myArray: list[
34+
Annotated[
35+
ObjectBase | CreateObjectRequest | UpdateObjectRequest,
36+
Field(discriminator='type'),
37+
]
38+
]
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# generated by datamodel-codegen:
2+
# filename: discriminator_in_array_underscore.yaml
3+
# timestamp: 2023-07-27T00:00:00+00:00
4+
5+
from __future__ import annotations
6+
7+
from enum import Enum
8+
from typing import Annotated, Literal
9+
10+
from pydantic import BaseModel, Field
11+
12+
13+
class AttrType(Enum):
14+
string_val = 'string_val'
15+
16+
17+
class StringAttr(BaseModel):
18+
attr_type: Literal['string_val']
19+
value: str
20+
21+
22+
class AttrType1(Enum):
23+
number_val = 'number_val'
24+
25+
26+
class NumberAttr(BaseModel):
27+
attr_type: Literal['number_val']
28+
value: float
29+
30+
31+
class Container(BaseModel):
32+
attributes: list[
33+
Annotated[StringAttr | NumberAttr, Field(discriminator='attr_type')]
34+
]
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
openapi: "3.0.0"
2+
components:
3+
schemas:
4+
StringAttr:
5+
type: object
6+
required:
7+
- attr_type
8+
- value
9+
properties:
10+
attr_type:
11+
type: string
12+
enum:
13+
- string_val
14+
value:
15+
type: string
16+
NumberAttr:
17+
type: object
18+
required:
19+
- attr_type
20+
- value
21+
properties:
22+
attr_type:
23+
type: string
24+
enum:
25+
- number_val
26+
value:
27+
type: number
28+
Container:
29+
type: object
30+
required:
31+
- attributes
32+
properties:
33+
attributes:
34+
type: array
35+
items:
36+
oneOf:
37+
- $ref: "#/components/schemas/StringAttr"
38+
- $ref: "#/components/schemas/NumberAttr"
39+
discriminator:
40+
propertyName: attr_type
41+
mapping:
42+
string_val: "#/components/schemas/StringAttr"
43+
number_val: "#/components/schemas/NumberAttr"

tests/main/openapi/test_main_openapi.py

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,11 @@
4848
if TYPE_CHECKING:
4949
from pytest_mock import MockerFixture
5050

51+
SKIP_PYDANTIC_V1 = pytest.mark.skipif(
52+
pydantic.VERSION < "2.0.0",
53+
reason="This test requires Pydantic v2",
54+
)
55+
5156

5257
@pytest.mark.benchmark
5358
def test_main(output_file: Path) -> None:
@@ -2580,26 +2585,24 @@ def test_main_openapi_discriminator(input_: str, output: str, output_file: Path)
25802585

25812586
@freeze_time("2023-07-27")
25822587
@pytest.mark.parametrize(
2583-
("kind", "option", "expected"),
2588+
("kind", "option", "output_model", "expected"),
25842589
[
2585-
(
2586-
"anyOf",
2587-
"--collapse-root-models",
2588-
"in_array_collapse_root_models.py",
2589-
),
2590-
(
2591-
"oneOf",
2592-
"--collapse-root-models",
2593-
"in_array_collapse_root_models.py",
2594-
),
2595-
("anyOf", None, "in_array.py"),
2596-
("oneOf", None, "in_array.py"),
2590+
("anyOf", "--collapse-root-models", None, "in_array_collapse_root_models.py"),
2591+
("oneOf", "--collapse-root-models", None, "in_array_collapse_root_models.py"),
2592+
("anyOf", None, None, "in_array.py"),
2593+
("oneOf", None, None, "in_array.py"),
2594+
("anyOf", "--collapse-root-models", "pydantic_v2.BaseModel", "in_array_collapse_root_models_pydantic_v2.py"),
2595+
("oneOf", "--collapse-root-models", "pydantic_v2.BaseModel", "in_array_collapse_root_models_pydantic_v2.py"),
25972596
],
25982597
)
2599-
def test_main_openapi_discriminator_in_array(kind: str, option: str | None, expected: str, output_file: Path) -> None:
2598+
def test_main_openapi_discriminator_in_array(
2599+
kind: str, option: str | None, output_model: str | None, expected: str, output_file: Path
2600+
) -> None:
26002601
"""Test OpenAPI generation with discriminator in array."""
26012602
input_file = f"discriminator_in_array_{kind.lower()}.yaml"
26022603
extra_args = [option] if option else []
2604+
if output_model:
2605+
extra_args.extend(["--output-model-type", output_model])
26032606
run_main_and_assert(
26042607
input_path=OPEN_API_DATA_PATH / input_file,
26052608
output_path=output_file,
@@ -2611,6 +2614,20 @@ def test_main_openapi_discriminator_in_array(kind: str, option: str | None, expe
26112614
)
26122615

26132616

2617+
@freeze_time("2023-07-27")
2618+
@SKIP_PYDANTIC_V1
2619+
def test_main_openapi_discriminator_in_array_underscore(output_file: Path) -> None:
2620+
"""Test discriminator with underscore property name generates valid Pydantic v2 code."""
2621+
run_main_and_assert(
2622+
input_path=OPEN_API_DATA_PATH / "discriminator_in_array_underscore.yaml",
2623+
output_path=output_file,
2624+
input_file_type="openapi",
2625+
assert_func=assert_file_content,
2626+
expected_file="discriminator/in_array_underscore_pydantic_v2.py",
2627+
extra_args=["--output-model-type", "pydantic_v2.BaseModel", "--collapse-root-models"],
2628+
)
2629+
2630+
26142631
@pytest.mark.parametrize(
26152632
("output_model", "expected_output"),
26162633
[

0 commit comments

Comments
 (0)