From fa4acb533a2425c031551c7b3dc0c6027537c4a5 Mon Sep 17 00:00:00 2001 From: Koudai Aono Date: Mon, 29 Dec 2025 04:23:31 +0000 Subject: [PATCH] Add automatic handling of unserializable types in --input-model --- src/datamodel_code_generator/__main__.py | 218 +++++++++++++++++- .../parser/jsonschema.py | 19 +- .../x_python_type_no_schema_type.py | 5 +- .../python/input_model/pydantic_models.py | 49 +++- tests/test_input_model.py | 117 ++++++++++ 5 files changed, 401 insertions(+), 7 deletions(-) diff --git a/src/datamodel_code_generator/__main__.py b/src/datamodel_code_generator/__main__.py index 9b9a6cf47..25e59f57b 100644 --- a/src/datamodel_code_generator/__main__.py +++ b/src/datamodel_code_generator/__main__.py @@ -599,6 +599,220 @@ def _extract_additional_imports(extra_template_data: defaultdict[str, dict[str, # Types that are lost during JSON Schema conversion and need to be preserved _PRESERVED_TYPE_ORIGINS: dict[type, str] = {} +# Marker for types that Pydantic cannot serialize to JSON Schema +_UNSERIALIZABLE_MARKER = "x-python-unserializable" + + +def _serialize_python_type_full(tp: type) -> str: # noqa: PLR0911 + """Serialize ANY Python type to its string representation. + + Handles: + - Basic types: str, int, bool, etc. + - Generic types: List[str], Dict[str, int], etc. + - Callable: Callable[[str], str], Callable[..., Any] + - Union types: str | int, Optional[str] + - Type: Type[BaseModel] + - Custom classes: mymodule.MyClass + - Nested generics: List[Callable[[str], str]] + """ + import types # noqa: PLC0415 + from typing import Union, get_args, get_origin # noqa: PLC0415 + + if tp is type(None): # pragma: no cover + return "None" + + if tp is ...: # pragma: no cover + return "..." 
+ + origin = get_origin(tp) + args = get_args(tp) + + if origin is None: + module = getattr(tp, "__module__", "") + name = getattr(tp, "__name__", None) or getattr(tp, "__qualname__", None) + + if name is None: + return str(tp).replace("typing.", "") + + if module and module not in {"builtins", "typing", "collections.abc"}: + return f"{module}.{name}" + return name + + if _is_callable_origin(origin): + return _serialize_callable(args) + + if origin is Union or (hasattr(types, "UnionType") and origin is types.UnionType): # pragma: no cover + parts = [_serialize_python_type_full(arg) for arg in args] + return " | ".join(parts) + + if origin is type: + if args: + return f"Type[{_serialize_python_type_full(args[0])}]" + return "Type" # pragma: no cover + + origin_name = _get_origin_name(origin) + if args: + args_str = ", ".join(_serialize_python_type_full(arg) for arg in args) + return f"{origin_name}[{args_str}]" + + return origin_name # pragma: no cover + + +def _is_callable_origin(origin: type | None) -> bool: + """Check if origin is Callable.""" + if origin is None: # pragma: no cover + return False + from collections.abc import Callable as ABCCallable # noqa: PLC0415 + + if origin is ABCCallable: + return True + origin_str = str(origin) + return "Callable" in origin_str or "callable" in origin_str + + +def _serialize_callable(args: tuple[type, ...]) -> str: + """Serialize Callable type.""" + if not args: # pragma: no cover + return "Callable" + + params = args[:-1] + ret = args[-1] + + if len(params) == 1 and params[0] is ...: + return f"Callable[..., {_serialize_python_type_full(ret)}]" + + if len(params) == 1 and isinstance(params[0], (list, tuple)): # pragma: no cover + params = tuple(params[0]) + + params_str = ", ".join(_serialize_python_type_full(p) for p in params) + return f"Callable[[{params_str}], {_serialize_python_type_full(ret)}]" + + +def _get_origin_name(origin: type) -> str: + """Get the name of a generic origin.""" + name = getattr(origin, 
"__name__", None) + if name: + return name + + # Fallback for origins without __name__ (rare edge case) + origin_str = str(origin) # pragma: no cover + if "typing." in origin_str: # pragma: no cover + return origin_str.replace("typing.", "") + + return origin_str # pragma: no cover + + +def _get_input_model_json_schema_class() -> type: + """Get the InputModelJsonSchema class (lazy import to avoid Pydantic v1 issues).""" + from pydantic.json_schema import GenerateJsonSchema # noqa: PLC0415 + + class InputModelJsonSchema(GenerateJsonSchema): + """Custom schema generator that handles ALL unserializable types.""" + + def handle_invalid_for_json_schema( # noqa: PLR6301 + self, + schema: Any, # noqa: ARG002 + error_info: Any, # noqa: ARG002 + ) -> dict[str, Any]: + """Catch ALL types that Pydantic can't serialize to JSON Schema.""" + return { + "type": "object", + _UNSERIALIZABLE_MARKER: True, + } + + def callable_schema( # noqa: PLR6301 + self, + schema: Any, # noqa: ARG002 + ) -> dict[str, Any]: + """Handle Callable types - these raise before handle_invalid_for_json_schema.""" + return { + "type": "string", + _UNSERIALIZABLE_MARKER: True, + } + + return InputModelJsonSchema + + +def _is_type_origin(annotation: type) -> bool: + """Check if annotation is Type[X].""" + from typing import get_origin # noqa: PLC0415 + + origin = get_origin(annotation) + return origin is type + + +def _process_unserializable_property(prop: dict[str, Any], annotation: type) -> None: + """Process a single property, handling anyOf/oneOf/items structures.""" + if "anyOf" in prop: + for item in prop["anyOf"]: + if item.get(_UNSERIALIZABLE_MARKER): + _set_python_type_for_unserializable(item, annotation) + elif "oneOf" in prop: # pragma: no cover + for item in prop["oneOf"]: + if item.get(_UNSERIALIZABLE_MARKER): + _set_python_type_for_unserializable(item, annotation) + elif prop.get(_UNSERIALIZABLE_MARKER): + _set_python_type_for_unserializable(prop, annotation) + elif "items" in prop and 
prop["items"].get(_UNSERIALIZABLE_MARKER): + prop["x-python-type"] = _serialize_python_type_full(annotation) + prop["items"].pop(_UNSERIALIZABLE_MARKER, None) + elif _is_type_origin(annotation): + prop["x-python-type"] = _serialize_python_type_full(annotation) + + +def _set_python_type_for_unserializable(item: dict[str, Any], annotation: type) -> None: + """Set x-python-type and clean up markers.""" + from typing import Union, get_args, get_origin # noqa: PLC0415 + + origin = get_origin(annotation) + actual_type = annotation + + if origin is Union: + for arg in get_args(annotation): # pragma: no branch + if arg is not type(None): # pragma: no branch + actual_type = arg + break + + item["x-python-type"] = _serialize_python_type_full(actual_type) + item.pop(_UNSERIALIZABLE_MARKER, None) + + +def _add_python_type_for_unserializable( + schema: dict[str, Any], + model: type, + visited_defs: set[str] | None = None, +) -> dict[str, Any]: + """Add x-python-type to ALL fields marked as unserializable. 
+ + Handles: + - Top-level properties + - Nested in anyOf/oneOf and array items + - $defs definitions + """ + if visited_defs is None: + visited_defs = set() + + if "properties" in schema: + model_fields = getattr(model, "model_fields", {}) + for field_name, prop in schema["properties"].items(): + if field_name in model_fields: # pragma: no branch + annotation = model_fields[field_name].annotation + _process_unserializable_property(prop, annotation) + + if "$defs" in schema: + nested_models = _collect_nested_models(model) + model_name = getattr(model, "__name__", None) + if model_name: # pragma: no branch + nested_models[model_name] = model + for def_name, def_schema in schema["$defs"].items(): + if def_name in visited_defs: # pragma: no cover + continue + visited_defs.add(def_name) + if def_name in nested_models: # pragma: no branch + _add_python_type_for_unserializable(def_schema, nested_models[def_name], visited_defs) + + return schema + def _init_preserved_type_origins() -> dict[type, str]: """Initialize preserved type origins mapping (lazy initialization).""" @@ -937,7 +1151,9 @@ def _load_model_schema( # noqa: PLR0912, PLR0914, PLR0915 if not hasattr(obj, "model_json_schema"): msg = "--input-model with Pydantic model requires Pydantic v2 runtime. Please upgrade Pydantic to v2." 
raise Error(msg) - schema = obj.model_json_schema() + schema_generator = _get_input_model_json_schema_class() + schema = obj.model_json_schema(schema_generator=schema_generator) + schema = _add_python_type_for_unserializable(schema, obj) schema = _add_python_type_info(schema, obj) if ref_strategy and ref_strategy != InputModelRefStrategy.RegenerateAll: diff --git a/src/datamodel_code_generator/parser/jsonschema.py b/src/datamodel_code_generator/parser/jsonschema.py index 9ccbd5554..19395294c 100644 --- a/src/datamodel_code_generator/parser/jsonschema.py +++ b/src/datamodel_code_generator/parser/jsonschema.py @@ -569,8 +569,15 @@ class JsonSchemaParser(Parser): "AsyncGenerator": Import.from_full_path("collections.abc.AsyncGenerator"), "Pattern": Import.from_full_path("re.Pattern"), "Match": Import.from_full_path("re.Match"), + "Type": Import.from_full_path("typing.Type"), } + # Types that require x-python-type override regardless of schema type + PYTHON_TYPE_OVERRIDE_ALWAYS: ClassVar[frozenset[str]] = frozenset({ + "Callable", + "Type", + }) + def __init__( # noqa: PLR0913 self, source: str | Path | list[Path] | ParseResult, @@ -1346,9 +1353,14 @@ def _get_python_type_base(self, python_type: str) -> str: # noqa: PLR6301 def _is_compatible_python_type(self, schema_type: str | None, python_type: str) -> bool: """Check if x-python-type is compatible with the JSON Schema type.""" + base_type = self._get_python_type_base(python_type) + if base_type in self.PYTHON_TYPE_OVERRIDE_ALWAYS: + return False + all_type_names = self._extract_all_type_names(python_type) + if any(t in self.PYTHON_TYPE_OVERRIDE_ALWAYS for t in all_type_names): + return False if schema_type is None: return True - base_type = self._get_python_type_base(python_type) if base_type in {"Union", "Optional"}: return True compatible = self.COMPATIBLE_PYTHON_TYPES.get(schema_type, frozenset()) @@ -2717,7 +2729,7 @@ def parse_property_names( # noqa: PLR0912 dict_key=key_type, ) - def parse_item( # noqa: 
PLR0911, PLR0912 + def parse_item( # noqa: PLR0911, PLR0912, PLR0914 self, name: str, item: JsonSchemaObject, @@ -2726,6 +2738,9 @@ def parse_item( # noqa: PLR0911, PLR0912 parent: JsonSchemaObject | None = None, ) -> DataType: """Parse a single JSON Schema item into a data type.""" + python_type_override = self._get_python_type_override(item) + if python_type_override: + return python_type_override if self.use_title_as_name and item.title: name = sanitize_module_name(item.title, treat_dot_as_module=self.treat_dot_as_module) singular_name = False diff --git a/tests/data/expected/main/jsonschema/x_python_type_no_schema_type.py b/tests/data/expected/main/jsonschema/x_python_type_no_schema_type.py index 7584ba5b8..a7952aaec 100644 --- a/tests/data/expected/main/jsonschema/x_python_type_no_schema_type.py +++ b/tests/data/expected/main/jsonschema/x_python_type_no_schema_type.py @@ -4,10 +4,11 @@ from __future__ import annotations -from typing import Any, TypedDict +from collections.abc import Callable +from typing import TypedDict from typing_extensions import NotRequired class Model(TypedDict): - callback: NotRequired[Any] + callback: NotRequired[Callable[[str], str]] diff --git a/tests/data/python/input_model/pydantic_models.py b/tests/data/python/input_model/pydantic_models.py index ddcfb167f..d8016f518 100644 --- a/tests/data/python/input_model/pydantic_models.py +++ b/tests/data/python/input_model/pydantic_models.py @@ -2,8 +2,8 @@ from __future__ import annotations -from collections.abc import Mapping, Sequence -from typing import FrozenSet, Optional, Set, Union +from collections.abc import Callable, Mapping, Sequence +from typing import Any, FrozenSet, Optional, Set, Type, Union from pydantic import BaseModel @@ -41,3 +41,48 @@ class RecursiveNode(BaseModel): value: Set[str] children: Optional[list[RecursiveNode]] = None + + +class ModelWithCallableTypes(BaseModel): + """Model with Callable and other unserializable types.""" + + callback: Callable[[str], str] + 
multi_param_callback: Callable[[int, int], bool] + variadic_callback: Callable[..., Any] + no_param_callback: Callable[[], None] + optional_callback: Callable[[str], str] | None + type_field: Type[BaseModel] + nested_callable: list[Callable[[str], int]] + + +class NestedCallableModel(BaseModel): + """Model with nested Callable types for $defs coverage.""" + + handler: Callable[[str], int] + + +class ModelWithNestedCallable(BaseModel): + """Model referencing another model with Callable to test $defs processing.""" + + nested: NestedCallableModel + own_callback: Callable[[int], str] + + +class CustomClass: + """Custom class for testing handle_invalid_for_json_schema.""" + + pass + + +class ModelWithCustomClass(BaseModel): + """Model with a custom class that triggers handle_invalid_for_json_schema.""" + + model_config = {"arbitrary_types_allowed": True} + custom_obj: CustomClass + + +class ModelWithUnionCallable(BaseModel): + """Model with Union of Callable and other types to test Union serialization.""" + + union_callback: Union[Callable[[str], str], int] + raw_callable: Callable # Callable without type args diff --git a/tests/test_input_model.py b/tests/test_input_model.py index 2e9632d88..f711b0253 100644 --- a/tests/test_input_model.py +++ b/tests/test_input_model.py @@ -606,6 +606,123 @@ def test_input_model_optional_mapping_union_syntax(tmp_path: Path) -> None: ) +# ============================================================================ +# Callable and unserializable type tests +# ============================================================================ + + +@SKIP_PYDANTIC_V1 +def test_input_model_callable_basic(tmp_path: Path) -> None: + """Test that Callable[[str], str] is preserved when converting Pydantic model.""" + run_input_model_and_assert( + input_model="tests.data.python.input_model.pydantic_models:ModelWithCallableTypes", + output_path=tmp_path / "output.py", + expected_output_contains=["Callable[[str], str]", "callback:"], + ) + + 
+@SKIP_PYDANTIC_V1 +def test_input_model_callable_multi_param(tmp_path: Path) -> None: + """Test that Callable[[int, int], bool] is preserved.""" + run_input_model_and_assert( + input_model="tests.data.python.input_model.pydantic_models:ModelWithCallableTypes", + output_path=tmp_path / "output.py", + expected_output_contains=["Callable[[int, int], bool]", "multi_param_callback:"], + ) + + +@SKIP_PYDANTIC_V1 +def test_input_model_callable_variadic(tmp_path: Path) -> None: + """Test that Callable[..., Any] is preserved.""" + run_input_model_and_assert( + input_model="tests.data.python.input_model.pydantic_models:ModelWithCallableTypes", + output_path=tmp_path / "output.py", + expected_output_contains=["Callable[..., Any]", "variadic_callback:"], + ) + + +@SKIP_PYDANTIC_V1 +def test_input_model_callable_no_param(tmp_path: Path) -> None: + """Test that Callable[[], None] is preserved.""" + run_input_model_and_assert( + input_model="tests.data.python.input_model.pydantic_models:ModelWithCallableTypes", + output_path=tmp_path / "output.py", + expected_output_contains=["Callable[[], None]", "no_param_callback:"], + ) + + +@SKIP_PYDANTIC_V1 +def test_input_model_callable_optional(tmp_path: Path) -> None: + """Test that Callable[[str], str] | None is preserved.""" + run_input_model_and_assert( + input_model="tests.data.python.input_model.pydantic_models:ModelWithCallableTypes", + output_path=tmp_path / "output.py", + expected_output_contains=["Callable[[str], str] | None", "optional_callback:"], + ) + + +@SKIP_PYDANTIC_V1 +def test_input_model_type_field(tmp_path: Path) -> None: + """Test that Type[BaseModel] is preserved.""" + run_input_model_and_assert( + input_model="tests.data.python.input_model.pydantic_models:ModelWithCallableTypes", + output_path=tmp_path / "output.py", + expected_output_contains=["Type[pydantic.main.BaseModel]", "type_field:"], + ) + + +@SKIP_PYDANTIC_V1 +def test_input_model_nested_callable(tmp_path: Path) -> None: + """Test that 
list[Callable[[str], int]] is preserved.""" + run_input_model_and_assert( + input_model="tests.data.python.input_model.pydantic_models:ModelWithCallableTypes", + output_path=tmp_path / "output.py", + expected_output_contains=["list[Callable[[str], int]]", "nested_callable:"], + ) + + +@SKIP_PYDANTIC_V1 +def test_input_model_nested_model_with_callable(tmp_path: Path) -> None: + """Test that nested models with Callable types in $defs are processed.""" + run_input_model_and_assert( + input_model="tests.data.python.input_model.pydantic_models:ModelWithNestedCallable", + output_path=tmp_path / "output.py", + expected_output_contains=[ + "Callable[[str], int]", + "Callable[[int], str]", + "NestedCallableModel", + ], + ) + + +@SKIP_PYDANTIC_V1 +def test_input_model_custom_class(tmp_path: Path) -> None: + """Test that custom classes trigger handle_invalid_for_json_schema.""" + run_input_model_and_assert( + input_model="tests.data.python.input_model.pydantic_models:ModelWithCustomClass", + output_path=tmp_path / "output.py", + expected_output_contains=[ + "CustomClass", + "custom_obj:", + ], + ) + + +@SKIP_PYDANTIC_V1 +def test_input_model_union_callable(tmp_path: Path) -> None: + """Test that Union[Callable, int] and raw Callable are preserved.""" + run_input_model_and_assert( + input_model="tests.data.python.input_model.pydantic_models:ModelWithUnionCallable", + output_path=tmp_path / "output.py", + expected_output_contains=[ + "Callable[[str], str] | int", + "union_callback:", + "Callable", + "raw_callable:", + ], + ) + + # ============================================================================ # --input-model-ref-strategy tests # ============================================================================