Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ filterwarnings = [
"ignore:^.*jsonschema.exceptions.RefResolutionError is deprecated as of version 4.18.0. If you wish to catch potential reference resolution errors, directly catch referencing.exceptions.Unresolvable..*",
"ignore:^.*`experimental string processing` has been included in `preview` and deprecated. Use `preview` instead..*",
"ignore:^.*No schemas found in components/schemas.*",
"ignore:^.*Dataclass .* has a field ordering conflict due to inheritance.*:UserWarning",
]
norecursedirs = "tests/data/*"
verbosity_assertions = 2
Expand Down
5 changes: 3 additions & 2 deletions src/datamodel_code_generator/model/dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@
from datamodel_code_generator.reference import Reference


def _has_field_assignment(field: DataModelFieldBase) -> bool:
def has_field_assignment(field: DataModelFieldBase) -> bool:
"""Check if a dataclass field has a default value or field() assignment."""
return bool(field.field) or not (
field.required or (field.represented_default == "None" and field.strip_default_none)
)
Expand Down Expand Up @@ -66,7 +67,7 @@ def __init__( # noqa: PLR0913
"""Initialize dataclass with fields sorted by field assignment requirement."""
super().__init__(
reference=reference,
fields=sorted(fields, key=_has_field_assignment),
fields=sorted(fields, key=has_field_assignment),
decorators=decorators,
base_classes=base_classes,
custom_base_class=custom_base_class,
Expand Down
61 changes: 61 additions & 0 deletions src/datamodel_code_generator/parser/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, NamedTuple, Optional, Protocol, TypeVar, cast, runtime_checkable
from urllib.parse import ParseResult
from warnings import warn

from pydantic import BaseModel
from typing_extensions import TypeAlias
Expand Down Expand Up @@ -1798,6 +1799,65 @@ def __set_one_literal_on_default(self, models: list[DataModel]) -> None:
if model_field.nullable is not True: # pragma: no cover
model_field.nullable = False

def __fix_dataclass_field_ordering(self, models: list[DataModel]) -> None:
"""Fix field ordering for dataclasses with inheritance after defaults are set."""
for model in models:
if (inherited := self.__get_dataclass_inherited_info(model)) is None:
continue
inherited_names, has_default = inherited
if not has_default or not any(self.__is_new_required_field(f, inherited_names) for f in model.fields):
continue

if self.target_python_version.has_kw_only_dataclass:
for field in model.fields:
if self.__is_new_required_field(field, inherited_names):
field.extras["kw_only"] = True
else:
warn(
f"Dataclass '{model.class_name}' has a field ordering conflict due to inheritance. "
f"An inherited field has a default value, but new required fields are added. "
f"This will cause a TypeError at runtime. Consider using --target-python-version 3.10 "
f"or higher to enable automatic field(kw_only=True) fix.",
category=UserWarning,
stacklevel=2,
)
model.fields = sorted(model.fields, key=dataclass_model.has_field_assignment)

@classmethod
def __get_dataclass_inherited_info(cls, model: DataModel) -> tuple[set[str], bool] | None:
"""Get inherited field names and whether any has default. Returns None if not applicable."""
if not isinstance(model, dataclass_model.DataClass):
return None
if not model.base_classes or model.dataclass_arguments.get("kw_only"):
return None

inherited_names: set[str] = set()
has_default = False
for base in model.base_classes:
if not base.reference or not isinstance(base.reference.source, DataModel):
continue # pragma: no cover
for f in base.reference.source.iter_all_fields():
if not f.name or f.extras.get("init") is False:
continue # pragma: no cover
inherited_names.add(f.name)
if dataclass_model.has_field_assignment(f):
has_default = True

for f in model.fields:
if f.name not in inherited_names or f.extras.get("init") is False:
continue
if dataclass_model.has_field_assignment(f): # pragma: no branch
has_default = True
return (inherited_names, has_default) if inherited_names else None

def __is_new_required_field(self, field: DataModelFieldBase, inherited: set[str]) -> bool: # noqa: PLR6301
"""Check if field is a new required init field."""
return (
field.name not in inherited
and field.extras.get("init") is not False
and not dataclass_model.has_field_assignment(field)
)

@classmethod
def __update_type_aliases(cls, models: list[DataModel]) -> None:
"""Update type aliases to properly handle forward references per PEP 484."""
Expand Down Expand Up @@ -2463,6 +2523,7 @@ class Processed(NamedTuple):
self.__change_field_name(models)
self.__apply_discriminator_type(models, imports)
self.__set_one_literal_on_default(models)
self.__fix_dataclass_field_ordering(models)

processed_models.append(Processed(module, module_, models, init, imports, scoped_model_resolver))

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# generated by datamodel-codegen:
# filename: dataclass_inheritance_field_ordering.yaml
# timestamp: 2019-07-26T00:00:00+00:00

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ParentWithDefault:
name: Optional[str] = 'default_name'
read_only_field: Optional[str] = None


@dataclass
class ChildWithRequired(ParentWithDefault):
child_id: str = field(kw_only=True)
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# generated by datamodel-codegen:
# filename: discriminator_enum.yaml
# timestamp: 2019-07-26T00:00:00+00:00

from __future__ import annotations

from dataclasses import dataclass, field
from enum import Enum
from typing import Literal, TypeAlias, Union


class RequestVersionEnum(Enum):
v1 = 'v1'
v2 = 'v2'


@dataclass
class RequestBase:
version: RequestVersionEnum


@dataclass
class RequestV1(RequestBase):
request_id: str = field(kw_only=True)
version: Literal['v1'] = 'v1'


@dataclass
class RequestV2(RequestBase):
version: Literal['v2'] = 'v2'


Request: TypeAlias = Union[RequestV1, RequestV2]
25 changes: 25 additions & 0 deletions tests/data/openapi/dataclass_inheritance_field_ordering.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
openapi: "3.0.0"
info:
title: Dataclass Inheritance Field Ordering Test
version: "1.0"
components:
schemas:
ParentWithDefault:
type: object
properties:
name:
type: string
default: "default_name"
read_only_field:
type: string
readOnly: true

ChildWithRequired:
allOf:
- $ref: '#/components/schemas/ParentWithDefault'
- type: object
properties:
child_id:
type: string
required:
- child_id
66 changes: 66 additions & 0 deletions tests/main/openapi/test_main_openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -2413,6 +2413,72 @@ def test_main_openapi_discriminator_one_literal_as_default_dataclass(output_file
)


@pytest.mark.skipif(
black.__version__.split(".")[0] == "19",
reason="Installed black doesn't support the old style",
)
def test_main_openapi_discriminator_one_literal_as_default_dataclass_py310(output_file: Path) -> None:
"""Test OpenAPI generation with discriminator one literal as default for dataclass with Python 3.10+."""
run_main_and_assert(
input_path=OPEN_API_DATA_PATH / "discriminator_enum.yaml",
output_path=output_file,
input_file_type="openapi",
assert_func=assert_file_content,
expected_file=EXPECTED_OPENAPI_PATH / "discriminator" / "dataclass_enum_one_literal_as_default_py310.py",
extra_args=[
"--output-model-type",
"dataclasses.dataclass",
"--use-one-literal-as-default",
"--target-python-version",
"3.10",
],
)


@pytest.mark.skipif(
black.__version__.split(".")[0] == "19",
reason="Installed black doesn't support the old style",
)
def test_main_openapi_discriminator_one_literal_as_default_dataclass_py39_warning(output_file: Path) -> None:
"""Test that Python 3.9 emits warning for dataclass field ordering conflict."""
with pytest.warns(UserWarning, match=r"Dataclass .* has a field ordering conflict due to inheritance"):
run_main_and_assert(
input_path=OPEN_API_DATA_PATH / "discriminator_enum.yaml",
output_path=output_file,
input_file_type="openapi",
assert_func=assert_file_content,
expected_file=EXPECTED_OPENAPI_PATH / "discriminator" / "dataclass_enum_one_literal_as_default.py",
extra_args=[
"--output-model-type",
"dataclasses.dataclass",
"--use-one-literal-as-default",
"--target-python-version",
"3.9",
],
)


@pytest.mark.skipif(
black.__version__.split(".")[0] == "19",
reason="Installed black doesn't support the old style",
)
def test_main_openapi_dataclass_inheritance_parent_default(output_file: Path) -> None:
"""Test dataclass field ordering fix when parent has default field."""
run_main_and_assert(
input_path=OPEN_API_DATA_PATH / "dataclass_inheritance_field_ordering.yaml",
output_path=output_file,
input_file_type="openapi",
assert_func=assert_file_content,
expected_file=EXPECTED_OPENAPI_PATH / "dataclass_inheritance_field_ordering_py310.py",
extra_args=[
"--output-model-type",
"dataclasses.dataclass",
"--target-python-version",
"3.10",
],
)


@pytest.mark.skipif(
black.__version__.split(".")[0] == "19",
reason="Installed black doesn't support the old style",
Expand Down
Loading