Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions src/datamodel_code_generator/model/pydantic_v2/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from datamodel_code_generator.model.pydantic.base_model import (
DataModelField as DataModelFieldV1,
)
from datamodel_code_generator.model.pydantic.imports import IMPORT_FIELD
from datamodel_code_generator.model.pydantic_v2.imports import IMPORT_BASE_MODEL, IMPORT_CONFIG_DICT
from datamodel_code_generator.types import chain_as_tuple
from datamodel_code_generator.util import field_validator, model_validate, model_validator
Expand Down Expand Up @@ -174,14 +175,23 @@ def _process_annotated_field_arguments( # noqa: PLR6301
) -> list[str]:
return field_arguments

def _has_discriminator_in_data_type(self) -> bool:
"""Check if any nested DataType has a discriminator."""
return any(dt.discriminator for dt in self.data_type.all_data_types)

@property
def imports(self) -> tuple[Import, ...]:
    """Get all required imports including AliasChoices and Field for discriminator.

    Extends the base field imports with:
    - ``AliasChoices`` when the field declares validation aliases, and
    - ``Field`` when any nested DataType carries a discriminator (needed for
      the ``Annotated[..., Field(discriminator=...)]`` type hint).
    """
    # NOTE: the source span contained a stale pre-change line
    # (`return chain_as_tuple(base_imports, (IMPORT_ALIAS_CHOICES,))`) left
    # over from diff flattening; it returned early and dead-coded the
    # discriminator handling below. Removed here.
    base_imports = super().imports
    extra_imports: list[Import] = []
    if self.validation_aliases:
        # Imported lazily to avoid a circular import at module load time.
        from datamodel_code_generator.model.pydantic_v2.imports import IMPORT_ALIAS_CHOICES  # noqa: PLC0415

        extra_imports.append(IMPORT_ALIAS_CHOICES)
    if self._has_discriminator_in_data_type():
        extra_imports.append(IMPORT_FIELD)
    if extra_imports:
        return chain_as_tuple(base_imports, tuple(extra_imports))
    return base_imports


Expand Down
22 changes: 11 additions & 11 deletions src/datamodel_code_generator/parser/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1702,6 +1702,10 @@ def __reuse_model_tree_scope(
self.__validate_shared_module_name(module_models)
return self.__create_shared_module_from_duplicates(module_models, duplicates, require_update_action_models)

def _is_pydantic_v2_model(self) -> bool:
"""Check if the output model type is Pydantic v2."""
return self.data_model_type.__module__.startswith("datamodel_code_generator.model.pydantic_v2")

def __collapse_root_models( # noqa: PLR0912, PLR0914, PLR0915
self,
models: list[DataModel],
Expand Down Expand Up @@ -1819,19 +1823,15 @@ def __collapse_root_models( # noqa: PLR0912, PLR0914, PLR0915
model_field.constraints = ConstraintsBase.merge_constraints(
root_type_field.constraints, model_field.constraints
)
if ( # pragma: no cover
isinstance(
root_type_field,
pydantic_model.DataModelField,
discriminator = root_type_field.extras.get("discriminator")
if discriminator and isinstance(root_type_field, pydantic_model.DataModelField):
prop_name = (
discriminator.get("propertyName") if isinstance(discriminator, dict) else discriminator
)
and not model_field.extras.get("discriminator")
and not any(t.is_list for t in model_field.data_type.data_types)
):
discriminator = root_type_field.extras.get("discriminator")
if discriminator:
model_field.extras["discriminator"] = discriminator
if self._is_pydantic_v2_model():
copied_data_type.discriminator = prop_name
assert isinstance(data_type.parent, DataType)
data_type.parent.data_types.remove(data_type) # pragma: no cover
data_type.parent.data_types.remove(data_type)
data_type.parent.data_types.append(copied_data_type)

elif isinstance(data_type.parent, DataType):
Expand Down
6 changes: 5 additions & 1 deletion src/datamodel_code_generator/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from datamodel_code_generator.imports import (
IMPORT_ABC_MAPPING,
IMPORT_ABC_SEQUENCE,
IMPORT_ANNOTATED,
IMPORT_ANY,
IMPORT_DICT,
IMPORT_FROZEN_SET,
Expand Down Expand Up @@ -454,6 +455,7 @@ class Config:
dict_key: Optional[DataType] = None # noqa: UP045
treat_dot_as_module: bool = False
use_serialize_as_any: bool = False
discriminator: Optional[str] = None # noqa: UP045

_exclude_fields: ClassVar[set[str]] = {"parent", "children"}
_pass_fields: ClassVar[set[str]] = {"parent", "children", "data_types", "reference"}
Expand Down Expand Up @@ -621,11 +623,11 @@ def imports(self) -> Iterator[Import]:
if self.import_:
yield self.import_

# Define required imports based on type features and conditions
imports: tuple[tuple[bool, Import], ...] = (
(self.is_optional and not self.use_union_operator, IMPORT_OPTIONAL),
(len(self.data_types) > 1 and not self.use_union_operator, IMPORT_UNION),
(bool(self.literals) or bool(self.enum_member_literals), IMPORT_LITERAL),
(bool(self.discriminator), IMPORT_ANNOTATED),
)

imports = (
Expand Down Expand Up @@ -748,6 +750,8 @@ def type_hint(self) -> str: # noqa: PLR0912, PLR0915
type_ = UNION_OPERATOR_DELIMITER.join(data_types)
else:
type_ = f"{UNION_PREFIX}{UNION_DELIMITER.join(data_types)}]"
if self.discriminator:
type_ = f"Annotated[{type_}, Field(discriminator={self.discriminator!r})]"
elif len(self.data_types) == 1:
type_ = self.data_types[0].type_hint
elif self.enum_member_literals:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# generated by datamodel-codegen:
# filename: discriminator_in_array.yaml
# timestamp: 2023-07-27T00:00:00+00:00
# NOTE(review): expected-output test fixture — it is compared byte-for-byte
# against what datamodel-codegen emits; do not hand-edit the code below.

from __future__ import annotations

from enum import Enum
from typing import Annotated, Literal

from pydantic import BaseModel, Field


# Enum generated from the discriminator property's value set in the schema.
class Type(Enum):
    my_first_object = 'my_first_object'
    my_second_object = 'my_second_object'


class ObjectBase(BaseModel):
    name: str | None = Field(None, description='Name of the object')
    type: Literal['type1'] = Field(..., description='Object type')


class CreateObjectRequest(ObjectBase):
    name: str = Field(..., description='Name of the object')
    type: Literal['type2'] = Field(..., description='Object type')


class UpdateObjectRequest(ObjectBase):
    type: Literal['type3']


class Demo(BaseModel):
    # Discriminated union inside a list, rendered in Pydantic v2 style as
    # Annotated[..., Field(discriminator='type')].
    myArray: list[
        Annotated[
            ObjectBase | CreateObjectRequest | UpdateObjectRequest,
            Field(discriminator='type'),
        ]
    ]
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# generated by datamodel-codegen:
# filename: discriminator_in_array_underscore.yaml
# timestamp: 2023-07-27T00:00:00+00:00
# NOTE(review): expected-output test fixture — it is compared byte-for-byte
# against what datamodel-codegen emits; do not hand-edit the code below.

from __future__ import annotations

from enum import Enum
from typing import Annotated, Literal

from pydantic import BaseModel, Field


# Enum generated from StringAttr's attr_type value set in the schema.
class AttrType(Enum):
    string_val = 'string_val'


class StringAttr(BaseModel):
    attr_type: Literal['string_val']
    value: str


# Second enum gets a numeric suffix to avoid colliding with AttrType.
class AttrType1(Enum):
    number_val = 'number_val'


class NumberAttr(BaseModel):
    attr_type: Literal['number_val']
    value: float


class Container(BaseModel):
    # Discriminator on an underscore-named property ('attr_type') must still
    # produce a valid Annotated[..., Field(discriminator=...)] hint.
    attributes: list[
        Annotated[StringAttr | NumberAttr, Field(discriminator='attr_type')]
    ]
43 changes: 43 additions & 0 deletions tests/data/openapi/discriminator_in_array_underscore.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
openapi: "3.0.0"
components:
schemas:
StringAttr:
type: object
required:
- attr_type
- value
properties:
attr_type:
type: string
enum:
- string_val
value:
type: string
NumberAttr:
type: object
required:
- attr_type
- value
properties:
attr_type:
type: string
enum:
- number_val
value:
type: number
Container:
type: object
required:
- attributes
properties:
attributes:
type: array
items:
oneOf:
- $ref: "#/components/schemas/StringAttr"
- $ref: "#/components/schemas/NumberAttr"
discriminator:
propertyName: attr_type
mapping:
string_val: "#/components/schemas/StringAttr"
number_val: "#/components/schemas/NumberAttr"
45 changes: 31 additions & 14 deletions tests/main/openapi/test_main_openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@
if TYPE_CHECKING:
from pytest_mock import MockerFixture

# Skip marker for tests that exercise Pydantic-v2-only output.
# Compare the parsed major version instead of the raw version string:
# lexicographic comparison would mis-order double-digit majors
# (e.g. "10.0.0" < "2.0.0" is True for strings).
SKIP_PYDANTIC_V1 = pytest.mark.skipif(
    int(pydantic.VERSION.split(".", 1)[0]) < 2,
    reason="This test requires Pydantic v2",
)


@pytest.mark.benchmark
def test_main(output_file: Path) -> None:
Expand Down Expand Up @@ -2580,26 +2585,24 @@ def test_main_openapi_discriminator(input_: str, output: str, output_file: Path)

@freeze_time("2023-07-27")
@pytest.mark.parametrize(
("kind", "option", "expected"),
("kind", "option", "output_model", "expected"),
[
(
"anyOf",
"--collapse-root-models",
"in_array_collapse_root_models.py",
),
(
"oneOf",
"--collapse-root-models",
"in_array_collapse_root_models.py",
),
("anyOf", None, "in_array.py"),
("oneOf", None, "in_array.py"),
("anyOf", "--collapse-root-models", None, "in_array_collapse_root_models.py"),
("oneOf", "--collapse-root-models", None, "in_array_collapse_root_models.py"),
("anyOf", None, None, "in_array.py"),
("oneOf", None, None, "in_array.py"),
("anyOf", "--collapse-root-models", "pydantic_v2.BaseModel", "in_array_collapse_root_models_pydantic_v2.py"),
("oneOf", "--collapse-root-models", "pydantic_v2.BaseModel", "in_array_collapse_root_models_pydantic_v2.py"),
],
)
def test_main_openapi_discriminator_in_array(kind: str, option: str | None, expected: str, output_file: Path) -> None:
def test_main_openapi_discriminator_in_array(
kind: str, option: str | None, output_model: str | None, expected: str, output_file: Path
) -> None:
"""Test OpenAPI generation with discriminator in array."""
input_file = f"discriminator_in_array_{kind.lower()}.yaml"
extra_args = [option] if option else []
if output_model:
extra_args.extend(["--output-model-type", output_model])
run_main_and_assert(
input_path=OPEN_API_DATA_PATH / input_file,
output_path=output_file,
Expand All @@ -2611,6 +2614,20 @@ def test_main_openapi_discriminator_in_array(kind: str, option: str | None, expe
)


@freeze_time("2023-07-27")
@SKIP_PYDANTIC_V1
def test_main_openapi_discriminator_in_array_underscore(output_file: Path) -> None:
    """Test discriminator with underscore property name generates valid Pydantic v2 code."""
    cli_options = [
        "--output-model-type",
        "pydantic_v2.BaseModel",
        "--collapse-root-models",
    ]
    run_main_and_assert(
        input_path=OPEN_API_DATA_PATH / "discriminator_in_array_underscore.yaml",
        output_path=output_file,
        input_file_type="openapi",
        assert_func=assert_file_content,
        expected_file="discriminator/in_array_underscore_pydantic_v2.py",
        extra_args=cli_options,
    )


@pytest.mark.parametrize(
("output_model", "expected_output"),
[
Expand Down
Loading