Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 49 additions & 9 deletions src/datamodel_code_generator/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -922,7 +922,7 @@ def _collect_nested_models(model: type, visited: set[type] | None = None) -> dic


def _find_models_in_type(tp: type, result: dict[str, type], visited: set[type]) -> None:
"""Recursively find BaseModel subclasses, Enums, and dataclasses in a type annotation."""
"""Recursively find BaseModel, Enum, dataclass, TypedDict, and msgspec in a type annotation."""
from dataclasses import is_dataclass # noqa: PLC0415
from enum import Enum as PyEnum # noqa: PLC0415
from typing import get_args # noqa: PLC0415
Expand All @@ -931,7 +931,12 @@ def _find_models_in_type(tp: type, result: dict[str, type], visited: set[type])
if issubclass(tp, BaseModel):
result[tp.__name__] = tp
result.update(_collect_nested_models(tp, visited))
elif issubclass(tp, PyEnum) or is_dataclass(tp):
elif (
issubclass(tp, PyEnum)
or is_dataclass(tp)
or hasattr(tp, "__required_keys__")
or hasattr(tp, "__struct_fields__")
):
result[tp.__name__] = tp

for arg in get_args(tp):
Expand Down Expand Up @@ -1004,10 +1009,11 @@ def _add_python_type_info_generic(schema: dict[str, Any], obj: type) -> dict[str
_TYPE_FAMILY_PYDANTIC = "pydantic"
_TYPE_FAMILY_DATACLASS = "dataclass"
_TYPE_FAMILY_TYPEDDICT = "typeddict"
_TYPE_FAMILY_MSGSPEC = "msgspec"
_TYPE_FAMILY_OTHER = "other"


def _get_type_family(tp: type) -> str:
def _get_type_family(tp: type) -> str: # noqa: PLR0911
"""Determine the type family of a Python type."""
from dataclasses import is_dataclass # noqa: PLC0415
from enum import Enum as PyEnum # noqa: PLC0415
Expand All @@ -1027,13 +1033,45 @@ def _get_type_family(tp: type) -> str:
if isinstance(tp, type) and hasattr(tp, "__required_keys__"):
return _TYPE_FAMILY_TYPEDDICT

if isinstance(tp, type) and hasattr(tp, "__struct_fields__"): # pragma: no cover
return _TYPE_FAMILY_MSGSPEC

return _TYPE_FAMILY_OTHER # pragma: no cover


def _get_output_family(output_model_type: DataModelType) -> str:
    """Get the type family corresponding to a DataModelType.

    Args:
        output_model_type: Target model type chosen for the generated output.

    Returns:
        One of the ``_TYPE_FAMILY_*`` constants. The three Pydantic variants
        all map to the Pydantic family; any member not listed below maps to
        ``_TYPE_FAMILY_OTHER``.
    """
    # A single lookup table keeps the mapping in one place.
    family_by_output_type = {
        DataModelType.PydanticBaseModel: _TYPE_FAMILY_PYDANTIC,
        DataModelType.PydanticV2BaseModel: _TYPE_FAMILY_PYDANTIC,
        DataModelType.PydanticV2Dataclass: _TYPE_FAMILY_PYDANTIC,
        DataModelType.DataclassesDataclass: _TYPE_FAMILY_DATACLASS,
        DataModelType.TypingTypedDict: _TYPE_FAMILY_TYPEDDICT,
        DataModelType.MsgspecStruct: _TYPE_FAMILY_MSGSPEC,
    }
    return family_by_output_type.get(output_model_type, _TYPE_FAMILY_OTHER)


def _should_reuse_type(source_family: str, output_family: str) -> bool:
    """Determine if a source type can be reused without conversion.

    Args:
        source_family: Type family of the nested source type.
        output_family: Type family of the requested output model type.

    Returns:
        True if the source type should be imported and reused,
        False if it needs to be regenerated into the output type.
    """
    # Enums are reused whatever the output family is; every other type is
    # reused only when its family matches the output family.
    return source_family == _TYPE_FAMILY_ENUM or source_family == output_family


def _filter_defs_by_strategy(
schema: dict[str, Any],
nested_models: dict[str, type],
input_model_family: str,
output_model_type: DataModelType,
strategy: InputModelRefStrategy,
) -> dict[str, Any]:
"""Filter $defs based on ref strategy, marking reused types with x-python-import."""
Expand All @@ -1043,6 +1081,7 @@ def _filter_defs_by_strategy(
if "$defs" not in schema: # pragma: no cover
return schema

output_family = _get_output_family(output_model_type)
new_defs: dict[str, Any] = {}

for def_name, def_schema in schema["$defs"].items():
Expand All @@ -1054,7 +1093,7 @@ def _filter_defs_by_strategy(
type_family = _get_type_family(nested_type)

should_reuse = strategy == InputModelRefStrategy.ReuseAll or (
strategy == InputModelRefStrategy.ReuseForeign and type_family != input_model_family
strategy == InputModelRefStrategy.ReuseForeign and _should_reuse_type(type_family, output_family)
)

if should_reuse:
Expand All @@ -1074,13 +1113,15 @@ def _load_model_schema( # noqa: PLR0912, PLR0914, PLR0915
input_model: str,
input_file_type: InputFileType,
ref_strategy: InputModelRefStrategy | None = None,
output_model_type: DataModelType = DataModelType.PydanticBaseModel,
) -> dict[str, object]:
"""Load schema from a Python import path.

Args:
input_model: Import path in 'module.path:ObjectName' format
input_file_type: Current input file type setting for validation
ref_strategy: Strategy for handling referenced types
output_model_type: Target output model type for reuse-foreign strategy

Returns:
Schema dict
Expand Down Expand Up @@ -1161,8 +1202,7 @@ def _load_model_schema( # noqa: PLR0912, PLR0914, PLR0915
model_name = getattr(obj, "__name__", None)
if model_name and "$defs" in schema and model_name in schema["$defs"]: # pragma: no cover
nested_models[model_name] = obj
input_family = _get_type_family(obj)
schema = _filter_defs_by_strategy(schema, nested_models, input_family, ref_strategy)
schema = _filter_defs_by_strategy(schema, nested_models, output_model_type, ref_strategy)

return schema

Expand Down Expand Up @@ -1190,8 +1230,7 @@ def _load_model_schema( # noqa: PLR0912, PLR0914, PLR0915
obj_name = getattr(obj, "__name__", None)
if obj_name and "$defs" in schema and obj_name in schema["$defs"]: # pragma: no cover
nested_models[obj_name] = obj_type
input_family = _get_type_family(obj_type)
schema = _filter_defs_by_strategy(schema, nested_models, input_family, ref_strategy)
schema = _filter_defs_by_strategy(schema, nested_models, output_model_type, ref_strategy)
except ImportError as e:
msg = "--input-model with dataclass/TypedDict requires Pydantic v2 runtime."
raise Error(msg) from e
Expand Down Expand Up @@ -1789,6 +1828,7 @@ def main(args: Sequence[str] | None = None) -> Exit: # noqa: PLR0911, PLR0912,
config.input_model,
config.input_file_type,
config.input_model_ref_strategy,
config.output_model_type,
)
input_ = json.dumps(schema)
if config.input_file_type == InputFileType.Auto:
Expand Down
68 changes: 68 additions & 0 deletions tests/data/python/input_model/mixed_nested.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
"""Models with mixed types for reuse-foreign same-family tests."""

from __future__ import annotations

from dataclasses import dataclass
from enum import Enum

from pydantic import BaseModel
from typing_extensions import TypedDict


class Category(Enum):
    """Enum that should always be reused."""

    # Shared by every ModelWith* fixture in this module.
    A = "a"
    B = "b"


class NestedTypedDict(TypedDict):
    """TypedDict that should be reused when output is TypedDict."""

    # Both keys are required (the default for a TypedDict).
    key: str
    value: int


class NestedPydantic(BaseModel):
    """Pydantic model that should be regenerated when output is TypedDict."""

    # Nested inside ModelWithPydantic and ModelWithMixed.
    name: str
    age: int


@dataclass
class NestedDataclass:
    """Dataclass that should be reused when output is dataclass."""

    # Nested inside ModelWithDataclass and ModelWithMixed.
    title: str
    count: int


class ModelWithTypedDict(BaseModel):
    """Pydantic model containing a TypedDict field."""

    # Pairs one TypedDict field with the enum for the same-family
    # TypedDict-output case.
    data: NestedTypedDict
    category: Category


class ModelWithPydantic(BaseModel):
    """Pydantic model containing another Pydantic model."""

    # Pairs one nested Pydantic model with the enum; used with TypedDict,
    # Pydantic, and msgspec outputs in the reuse-foreign tests.
    nested: NestedPydantic
    category: Category


class ModelWithDataclass(BaseModel):
    """Pydantic model containing a dataclass field."""

    # Pairs one dataclass field with the enum for the same-family
    # dataclass-output case.
    info: NestedDataclass
    category: Category


class ModelWithMixed(BaseModel):
    """Pydantic model with TypedDict, Pydantic, and dataclass nested types."""

    # One field per type family plus the enum, so a single run can show
    # which nested types are imported and which are regenerated.
    typed_dict_field: NestedTypedDict
    pydantic_field: NestedPydantic
    dataclass_field: NestedDataclass
    category: Category
149 changes: 142 additions & 7 deletions tests/test_input_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,7 +778,7 @@ def test_input_model_ref_strategy_regenerate_all_explicit(tmp_path: Path) -> Non

@SKIP_PYDANTIC_V1
def test_input_model_ref_strategy_reuse_foreign(tmp_path: Path) -> None:
"""Test reuse-foreign strategy imports enums and dataclasses."""
"""Test reuse-foreign imports enum (always) and same-family types."""
run_input_model_and_assert(
input_model="tests.data.python.input_model.nested_models:User",
output_path=tmp_path / "output.py",
Expand All @@ -789,9 +789,8 @@ def test_input_model_ref_strategy_reuse_foreign(tmp_path: Path) -> None:
"reuse-foreign",
],
expected_output_contains=[
"from tests.data.python.input_model.nested_models import",
"Metadata",
"Status",
"from tests.data.python.input_model.nested_models import Status",
"class Metadata",
"class Address",
"class User",
],
Expand All @@ -800,7 +799,7 @@ def test_input_model_ref_strategy_reuse_foreign(tmp_path: Path) -> None:

@SKIP_PYDANTIC_V1
def test_input_model_ref_strategy_reuse_foreign_no_regeneration(tmp_path: Path) -> None:
"""Test reuse-foreign strategy does not regenerate foreign types."""
"""Test reuse-foreign imports only types compatible with output (enum always, same family)."""
output_path = tmp_path / "output.py"
run_input_model_and_assert(
input_model="tests.data.python.input_model.nested_models:User",
Expand All @@ -812,12 +811,13 @@ def test_input_model_ref_strategy_reuse_foreign_no_regeneration(tmp_path: Path)
"reuse-foreign",
],
expected_output_contains=[
"from tests.data.python.input_model.nested_models import",
"from tests.data.python.input_model.nested_models import Status",
"class Metadata",
"class Address",
],
)
content = output_path.read_text(encoding="utf-8")
assert "Status: TypeAlias" not in content
assert "class Metadata" not in content


@SKIP_PYDANTIC_V1
Expand Down Expand Up @@ -962,3 +962,138 @@ def test_input_model_ref_strategy_typeddict_reuse_foreign(tmp_path: Path) -> Non
"class Profile",
],
)


@SKIP_PYDANTIC_V1
def test_input_model_ref_strategy_reuse_foreign_same_family_typeddict(tmp_path: Path) -> None:
    """Test reuse-foreign imports TypedDict when output is TypedDict (same family).

    The nested TypedDict and the enum must be imported from the source
    module, and the TypedDict must not be regenerated as a class.
    """
    output_path = tmp_path / "output.py"
    run_input_model_and_assert(
        input_model="tests.data.python.input_model.mixed_nested:ModelWithTypedDict",
        output_path=output_path,
        extra_args=[
            "--output-model-type",
            "typing.TypedDict",
            "--input-model-ref-strategy",
            "reuse-foreign",
        ],
        expected_output_contains=[
            "from tests.data.python.input_model.mixed_nested import",
            "Category",
            "NestedTypedDict",
        ],
    )
    # The name alone also appears in field annotations, so additionally
    # check that no class definition was emitted for it.
    content = output_path.read_text(encoding="utf-8")
    assert "class NestedTypedDict" not in content


@SKIP_PYDANTIC_V1
def test_input_model_ref_strategy_reuse_foreign_different_family_regenerate(tmp_path: Path) -> None:
    """Test reuse-foreign regenerates Pydantic model when output is TypedDict.

    Only the enum is expected to be imported; the nested Pydantic model
    belongs to a different family and must appear as a generated class.
    """
    run_input_model_and_assert(
        input_model="tests.data.python.input_model.mixed_nested:ModelWithPydantic",
        output_path=tmp_path / "output.py",
        extra_args=[
            "--output-model-type",
            "typing.TypedDict",
            "--input-model-ref-strategy",
            "reuse-foreign",
        ],
        expected_output_contains=[
            "from tests.data.python.input_model.mixed_nested import Category",
            "class NestedPydantic",
        ],
    )


@SKIP_PYDANTIC_V1
def test_input_model_ref_strategy_reuse_foreign_same_family_dataclass(tmp_path: Path) -> None:
    """Test reuse-foreign imports dataclass when output is dataclass (same family).

    The nested dataclass and the enum must be imported from the source
    module, and the dataclass must not be regenerated as a class.
    """
    output_path = tmp_path / "output.py"
    run_input_model_and_assert(
        input_model="tests.data.python.input_model.mixed_nested:ModelWithDataclass",
        output_path=output_path,
        extra_args=[
            "--output-model-type",
            "dataclasses.dataclass",
            "--input-model-ref-strategy",
            "reuse-foreign",
        ],
        expected_output_contains=[
            "from tests.data.python.input_model.mixed_nested import",
            "Category",
            "NestedDataclass",
        ],
    )
    # The name alone also appears in field annotations, so additionally
    # check that no class definition was emitted for it.
    content = output_path.read_text(encoding="utf-8")
    assert "class NestedDataclass" not in content


@SKIP_PYDANTIC_V1
def test_input_model_ref_strategy_reuse_foreign_mixed_types(tmp_path: Path) -> None:
    """Test reuse-foreign with mixed nested types (TypedDict, Pydantic, dataclass).

    With TypedDict output, the enum and the nested TypedDict are expected to
    be imported, while the nested Pydantic model and dataclass are expected
    to be regenerated as classes.
    """
    output_path = tmp_path / "output.py"
    run_input_model_and_assert(
        input_model="tests.data.python.input_model.mixed_nested:ModelWithMixed",
        output_path=output_path,
        extra_args=[
            "--output-model-type",
            "typing.TypedDict",
            "--input-model-ref-strategy",
            "reuse-foreign",
        ],
        expected_output_contains=[
            "from tests.data.python.input_model.mixed_nested import",
            "Category",
            "NestedTypedDict",
            "class NestedPydantic",
            "class NestedDataclass",
        ],
    )
    # The same-family TypedDict must be imported, never redefined.
    content = output_path.read_text(encoding="utf-8")
    assert "class NestedTypedDict" not in content


@SKIP_PYDANTIC_V1
def test_input_model_ref_strategy_reuse_foreign_pydantic_output(tmp_path: Path) -> None:
    """Test reuse-foreign imports Pydantic when output is Pydantic (same family).

    The nested Pydantic model and the enum must be imported from the source
    module, and the nested model must not be regenerated as a class.
    """
    output_path = tmp_path / "output.py"
    run_input_model_and_assert(
        input_model="tests.data.python.input_model.mixed_nested:ModelWithPydantic",
        output_path=output_path,
        extra_args=[
            "--output-model-type",
            "pydantic.BaseModel",
            "--input-model-ref-strategy",
            "reuse-foreign",
        ],
        expected_output_contains=[
            "from tests.data.python.input_model.mixed_nested import",
            "Category",
            "NestedPydantic",
        ],
    )
    # The name alone also appears in field annotations, so additionally
    # check that no class definition was emitted for it.
    content = output_path.read_text(encoding="utf-8")
    assert "class NestedPydantic" not in content


@SKIP_PYDANTIC_V1
def test_input_model_ref_strategy_reuse_foreign_msgspec_output(tmp_path: Path) -> None:
    """Test reuse-foreign regenerates non-msgspec types when output is msgspec.

    Only the enum is expected to be imported; both Pydantic models belong to
    a different family and must appear as generated classes.
    """
    run_input_model_and_assert(
        input_model="tests.data.python.input_model.mixed_nested:ModelWithPydantic",
        output_path=tmp_path / "output.py",
        extra_args=[
            "--output-model-type",
            "msgspec.Struct",
            "--input-model-ref-strategy",
            "reuse-foreign",
        ],
        expected_output_contains=[
            "from tests.data.python.input_model.mixed_nested import Category",
            "class NestedPydantic",
            "class ModelWithPydantic",
        ],
    )
Loading