diff --git a/src/datamodel_code_generator/__main__.py b/src/datamodel_code_generator/__main__.py index 25e59f57b..d79dfaaa5 100644 --- a/src/datamodel_code_generator/__main__.py +++ b/src/datamodel_code_generator/__main__.py @@ -922,7 +922,7 @@ def _collect_nested_models(model: type, visited: set[type] | None = None) -> dic def _find_models_in_type(tp: type, result: dict[str, type], visited: set[type]) -> None: - """Recursively find BaseModel subclasses, Enums, and dataclasses in a type annotation.""" + """Recursively find BaseModel, Enum, dataclass, TypedDict, and msgspec in a type annotation.""" from dataclasses import is_dataclass # noqa: PLC0415 from enum import Enum as PyEnum # noqa: PLC0415 from typing import get_args # noqa: PLC0415 @@ -931,7 +931,12 @@ def _find_models_in_type(tp: type, result: dict[str, type], visited: set[type]) if issubclass(tp, BaseModel): result[tp.__name__] = tp result.update(_collect_nested_models(tp, visited)) - elif issubclass(tp, PyEnum) or is_dataclass(tp): + elif ( + issubclass(tp, PyEnum) + or is_dataclass(tp) + or hasattr(tp, "__required_keys__") + or hasattr(tp, "__struct_fields__") + ): result[tp.__name__] = tp for arg in get_args(tp): @@ -1004,10 +1009,11 @@ def _add_python_type_info_generic(schema: dict[str, Any], obj: type) -> dict[str _TYPE_FAMILY_PYDANTIC = "pydantic" _TYPE_FAMILY_DATACLASS = "dataclass" _TYPE_FAMILY_TYPEDDICT = "typeddict" +_TYPE_FAMILY_MSGSPEC = "msgspec" _TYPE_FAMILY_OTHER = "other" -def _get_type_family(tp: type) -> str: +def _get_type_family(tp: type) -> str: # noqa: PLR0911 """Determine the type family of a Python type.""" from dataclasses import is_dataclass # noqa: PLC0415 from enum import Enum as PyEnum # noqa: PLC0415 @@ -1027,13 +1033,45 @@ def _get_type_family(tp: type) -> str: if isinstance(tp, type) and hasattr(tp, "__required_keys__"): return _TYPE_FAMILY_TYPEDDICT + if isinstance(tp, type) and hasattr(tp, "__struct_fields__"): # pragma: no cover + return _TYPE_FAMILY_MSGSPEC + + return _TYPE_FAMILY_OTHER # pragma: no cover + + +def _get_output_family(output_model_type: DataModelType) -> str: + """Get the type family corresponding to a DataModelType.""" + pydantic_types = { + DataModelType.PydanticBaseModel, + DataModelType.PydanticV2BaseModel, + DataModelType.PydanticV2Dataclass, + } + if output_model_type in pydantic_types: + return _TYPE_FAMILY_PYDANTIC + if output_model_type == DataModelType.DataclassesDataclass: + return _TYPE_FAMILY_DATACLASS + if output_model_type == DataModelType.TypingTypedDict: + return _TYPE_FAMILY_TYPEDDICT + if output_model_type == DataModelType.MsgspecStruct: + return _TYPE_FAMILY_MSGSPEC return _TYPE_FAMILY_OTHER # pragma: no cover +def _should_reuse_type(source_family: str, output_family: str) -> bool: + """Determine if a source type can be reused without conversion. + + Returns True if the source type should be imported and reused, + False if it needs to be regenerated into the output type. + """ + if source_family == _TYPE_FAMILY_ENUM: + return True + return source_family == output_family + + def _filter_defs_by_strategy( schema: dict[str, Any], nested_models: dict[str, type], - input_model_family: str, + output_model_type: DataModelType, strategy: InputModelRefStrategy, ) -> dict[str, Any]: """Filter $defs based on ref strategy, marking reused types with x-python-import.""" @@ -1043,6 +1081,7 @@ def _filter_defs_by_strategy( if "$defs" not in schema: # pragma: no cover return schema + output_family = _get_output_family(output_model_type) new_defs: dict[str, Any] = {} for def_name, def_schema in schema["$defs"].items(): @@ -1054,7 +1093,7 @@ def _filter_defs_by_strategy( type_family = _get_type_family(nested_type) should_reuse = strategy == InputModelRefStrategy.ReuseAll or ( - strategy == InputModelRefStrategy.ReuseForeign and type_family != input_model_family + strategy == InputModelRefStrategy.ReuseForeign and _should_reuse_type(type_family, output_family) ) if should_reuse: @@ -1074,6 +1113,7 @@ def _load_model_schema( # noqa: PLR0912, PLR0914, PLR0915 input_model: str, input_file_type: InputFileType, ref_strategy: InputModelRefStrategy | None = None, + output_model_type: DataModelType = DataModelType.PydanticBaseModel, ) -> dict[str, object]: """Load schema from a Python import path. @@ -1081,6 +1121,7 @@ def _load_model_schema( # noqa: PLR0912, PLR0914, PLR0915 input_model: Import path in 'module.path:ObjectName' format input_file_type: Current input file type setting for validation ref_strategy: Strategy for handling referenced types + output_model_type: Target output model type for reuse-foreign strategy Returns: Schema dict @@ -1161,8 +1202,7 @@ def _load_model_schema( # noqa: PLR0912, PLR0914, PLR0915 model_name = getattr(obj, "__name__", None) if model_name and "$defs" in schema and model_name in schema["$defs"]: # pragma: no cover nested_models[model_name] = obj - input_family = _get_type_family(obj) - schema = _filter_defs_by_strategy(schema, nested_models, input_family, ref_strategy) + schema = _filter_defs_by_strategy(schema, nested_models, output_model_type, ref_strategy) return schema @@ -1190,8 +1230,7 @@ def _load_model_schema( # noqa: PLR0912, PLR0914, PLR0915 obj_name = getattr(obj, "__name__", None) if obj_name and "$defs" in schema and obj_name in schema["$defs"]: # pragma: no cover nested_models[obj_name] = obj_type - input_family = _get_type_family(obj_type) - schema = _filter_defs_by_strategy(schema, nested_models, input_family, ref_strategy) + schema = _filter_defs_by_strategy(schema, nested_models, output_model_type, ref_strategy) except ImportError as e: msg = "--input-model with dataclass/TypedDict requires Pydantic v2 runtime." raise Error(msg) from e @@ -1789,6 +1828,7 @@ def main(args: Sequence[str] | None = None) -> Exit: # noqa: PLR0911, PLR0912, config.input_model, config.input_file_type, config.input_model_ref_strategy, + config.output_model_type, ) input_ = json.dumps(schema) if config.input_file_type == InputFileType.Auto: diff --git a/tests/data/python/input_model/mixed_nested.py b/tests/data/python/input_model/mixed_nested.py new file mode 100644 index 000000000..d5d42b11b --- /dev/null +++ b/tests/data/python/input_model/mixed_nested.py @@ -0,0 +1,68 @@ +"""Models with mixed types for reuse-foreign same-family tests.""" + +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum + +from pydantic import BaseModel +from typing_extensions import TypedDict + + +class Category(Enum): + """Enum that should always be reused.""" + + A = "a" + B = "b" + + +class NestedTypedDict(TypedDict): + """TypedDict that should be reused when output is TypedDict.""" + + key: str + value: int + + +class NestedPydantic(BaseModel): + """Pydantic model that should be regenerated when output is TypedDict.""" + + name: str + age: int + + +@dataclass +class NestedDataclass: + """Dataclass that should be reused when output is dataclass.""" + + title: str + count: int + + +class ModelWithTypedDict(BaseModel): + """Pydantic model containing a TypedDict field.""" + + data: NestedTypedDict + category: Category + + +class ModelWithPydantic(BaseModel): + """Pydantic model containing another Pydantic model.""" + + nested: NestedPydantic + category: Category + + +class ModelWithDataclass(BaseModel): + """Pydantic model containing a dataclass field.""" + + info: NestedDataclass + category: Category + + +class ModelWithMixed(BaseModel): + """Pydantic model with TypedDict, Pydantic, and dataclass nested types.""" + + typed_dict_field: NestedTypedDict + pydantic_field: NestedPydantic + dataclass_field: NestedDataclass + category: Category diff --git a/tests/test_input_model.py b/tests/test_input_model.py index f711b0253..8f0323394 100644 --- a/tests/test_input_model.py +++ b/tests/test_input_model.py @@ -778,7 +778,7 @@ def test_input_model_ref_strategy_regenerate_all_explicit(tmp_path: Path) -> Non @SKIP_PYDANTIC_V1 def test_input_model_ref_strategy_reuse_foreign(tmp_path: Path) -> None: - """Test reuse-foreign strategy imports enums and dataclasses.""" + """Test reuse-foreign imports enum (always) and same-family types.""" run_input_model_and_assert( input_model="tests.data.python.input_model.nested_models:User", output_path=tmp_path / "output.py", @@ -789,9 +789,8 @@ def test_input_model_ref_strategy_reuse_foreign(tmp_path: Path) -> None: "reuse-foreign", ], expected_output_contains=[ - "from tests.data.python.input_model.nested_models import", - "Metadata", - "Status", + "from tests.data.python.input_model.nested_models import Status", + "class Metadata", "class Address", "class User", ], @@ -800,7 +799,7 @@ def test_input_model_ref_strategy_reuse_foreign(tmp_path: Path) -> None: @SKIP_PYDANTIC_V1 def test_input_model_ref_strategy_reuse_foreign_no_regeneration(tmp_path: Path) -> None: - """Test reuse-foreign strategy does not regenerate foreign types.""" + """Test reuse-foreign imports only types compatible with output (enum always, same family).""" output_path = tmp_path / "output.py" run_input_model_and_assert( input_model="tests.data.python.input_model.nested_models:User", @@ -812,12 +811,13 @@ def test_input_model_ref_strategy_reuse_foreign_no_regeneration(tmp_path: Path) "reuse-foreign", ], expected_output_contains=[ - "from tests.data.python.input_model.nested_models import", + "from tests.data.python.input_model.nested_models import Status", + "class Metadata", + "class Address", ], ) content = output_path.read_text(encoding="utf-8") assert "Status: TypeAlias" not in content - assert "class Metadata" not in content @SKIP_PYDANTIC_V1 @@ -962,3 +962,138 @@ def test_input_model_ref_strategy_typeddict_reuse_foreign(tmp_path: Path) -> Non "class Profile", ], ) + + +@SKIP_PYDANTIC_V1 +def test_input_model_ref_strategy_reuse_foreign_same_family_typeddict(tmp_path: Path) -> None: + """Test reuse-foreign imports TypedDict when output is TypedDict (same family).""" + output_path = tmp_path / "output.py" + run_input_model_and_assert( + input_model="tests.data.python.input_model.mixed_nested:ModelWithTypedDict", + output_path=output_path, + extra_args=[ + "--output-model-type", + "typing.TypedDict", + "--input-model-ref-strategy", + "reuse-foreign", + ], + expected_output_contains=[ + "from tests.data.python.input_model.mixed_nested import", + "Category", + "NestedTypedDict", + ], + ) + content = output_path.read_text(encoding="utf-8") + assert "class NestedTypedDict" not in content + + +@SKIP_PYDANTIC_V1 +def test_input_model_ref_strategy_reuse_foreign_different_family_regenerate(tmp_path: Path) -> None: + """Test reuse-foreign regenerates Pydantic model when output is TypedDict.""" + output_path = tmp_path / "output.py" + run_input_model_and_assert( + input_model="tests.data.python.input_model.mixed_nested:ModelWithPydantic", + output_path=output_path, + extra_args=[ + "--output-model-type", + "typing.TypedDict", + "--input-model-ref-strategy", + "reuse-foreign", + ], + expected_output_contains=[ + "from tests.data.python.input_model.mixed_nested import Category", + "class NestedPydantic", + ], + ) + + +@SKIP_PYDANTIC_V1 +def test_input_model_ref_strategy_reuse_foreign_same_family_dataclass(tmp_path: Path) -> None: + """Test reuse-foreign imports dataclass when output is dataclass (same family).""" + output_path = tmp_path / "output.py" + run_input_model_and_assert( + input_model="tests.data.python.input_model.mixed_nested:ModelWithDataclass", + output_path=output_path, + extra_args=[ + "--output-model-type", + "dataclasses.dataclass", + "--input-model-ref-strategy", + "reuse-foreign", + ], + expected_output_contains=[ + "from tests.data.python.input_model.mixed_nested import", + "Category", + "NestedDataclass", + ], + ) + content = output_path.read_text(encoding="utf-8") + assert "class NestedDataclass" not in content + + +@SKIP_PYDANTIC_V1 +def test_input_model_ref_strategy_reuse_foreign_mixed_types(tmp_path: Path) -> None: + """Test reuse-foreign with mixed nested types (TypedDict, Pydantic, dataclass).""" + output_path = tmp_path / "output.py" + run_input_model_and_assert( + input_model="tests.data.python.input_model.mixed_nested:ModelWithMixed", + output_path=output_path, + extra_args=[ + "--output-model-type", + "typing.TypedDict", + "--input-model-ref-strategy", + "reuse-foreign", + ], + expected_output_contains=[ + "from tests.data.python.input_model.mixed_nested import", + "Category", + "NestedTypedDict", + "class NestedPydantic", + "class NestedDataclass", + ], + ) + content = output_path.read_text(encoding="utf-8") + assert "class NestedTypedDict" not in content + + +@SKIP_PYDANTIC_V1 +def test_input_model_ref_strategy_reuse_foreign_pydantic_output(tmp_path: Path) -> None: + """Test reuse-foreign imports Pydantic when output is Pydantic (same family).""" + output_path = tmp_path / "output.py" + run_input_model_and_assert( + input_model="tests.data.python.input_model.mixed_nested:ModelWithPydantic", + output_path=output_path, + extra_args=[ + "--output-model-type", + "pydantic.BaseModel", + "--input-model-ref-strategy", + "reuse-foreign", + ], + expected_output_contains=[ + "from tests.data.python.input_model.mixed_nested import", + "Category", + "NestedPydantic", + ], + ) + content = output_path.read_text(encoding="utf-8") + assert "class NestedPydantic" not in content + + +@SKIP_PYDANTIC_V1 +def test_input_model_ref_strategy_reuse_foreign_msgspec_output(tmp_path: Path) -> None: + """Test reuse-foreign regenerates non-msgspec types when output is msgspec.""" + output_path = tmp_path / "output.py" + run_input_model_and_assert( + input_model="tests.data.python.input_model.mixed_nested:ModelWithPydantic", + output_path=output_path, + extra_args=[ + "--output-model-type", + "msgspec.Struct", + "--input-model-ref-strategy", + "reuse-foreign", + ], + expected_output_contains=[ + "from tests.data.python.input_model.mixed_nested import Category", + "class NestedPydantic", + "class ModelWithPydantic", + ], + )