Skip to content

Commit 0f7a6c9

Browse files
authored
Add automatic handling of unserializable types in --input-model (#2851)
1 parent 36d102c commit 0f7a6c9

5 files changed

Lines changed: 401 additions & 7 deletions

File tree

src/datamodel_code_generator/__main__.py

Lines changed: 217 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -599,6 +599,220 @@ def _extract_additional_imports(extra_template_data: defaultdict[str, dict[str,
599599
# Types that are lost during JSON Schema conversion and need to be preserved
600600
_PRESERVED_TYPE_ORIGINS: dict[type, str] = {}
601601

602+
# Marker for types that Pydantic cannot serialize to JSON Schema
603+
_UNSERIALIZABLE_MARKER = "x-python-unserializable"
604+
605+
606+
def _serialize_python_type_full(tp: type) -> str: # noqa: PLR0911
607+
"""Serialize ANY Python type to its string representation.
608+
609+
Handles:
610+
- Basic types: str, int, bool, etc.
611+
- Generic types: List[str], Dict[str, int], etc.
612+
- Callable: Callable[[str], str], Callable[..., Any]
613+
- Union types: str | int, Optional[str]
614+
- Type: Type[BaseModel]
615+
- Custom classes: mymodule.MyClass
616+
- Nested generics: List[Callable[[str], str]]
617+
"""
618+
import types # noqa: PLC0415
619+
from typing import Union, get_args, get_origin # noqa: PLC0415
620+
621+
if tp is type(None): # pragma: no cover
622+
return "None"
623+
624+
if tp is ...: # pragma: no cover
625+
return "..."
626+
627+
origin = get_origin(tp)
628+
args = get_args(tp)
629+
630+
if origin is None:
631+
module = getattr(tp, "__module__", "")
632+
name = getattr(tp, "__name__", None) or getattr(tp, "__qualname__", None)
633+
634+
if name is None:
635+
return str(tp).replace("typing.", "")
636+
637+
if module and module not in {"builtins", "typing", "collections.abc"}:
638+
return f"{module}.{name}"
639+
return name
640+
641+
if _is_callable_origin(origin):
642+
return _serialize_callable(args)
643+
644+
if origin is Union or (hasattr(types, "UnionType") and origin is types.UnionType): # pragma: no cover
645+
parts = [_serialize_python_type_full(arg) for arg in args]
646+
return " | ".join(parts)
647+
648+
if origin is type:
649+
if args:
650+
return f"Type[{_serialize_python_type_full(args[0])}]"
651+
return "Type" # pragma: no cover
652+
653+
origin_name = _get_origin_name(origin)
654+
if args:
655+
args_str = ", ".join(_serialize_python_type_full(arg) for arg in args)
656+
return f"{origin_name}[{args_str}]"
657+
658+
return origin_name # pragma: no cover
659+
660+
661+
def _is_callable_origin(origin: type | None) -> bool:
662+
"""Check if origin is Callable."""
663+
if origin is None: # pragma: no cover
664+
return False
665+
from collections.abc import Callable as ABCCallable # noqa: PLC0415
666+
667+
if origin is ABCCallable:
668+
return True
669+
origin_str = str(origin)
670+
return "Callable" in origin_str or "callable" in origin_str
671+
672+
673+
def _serialize_callable(args: tuple[type, ...]) -> str:
674+
"""Serialize Callable type."""
675+
if not args: # pragma: no cover
676+
return "Callable"
677+
678+
params = args[:-1]
679+
ret = args[-1]
680+
681+
if len(params) == 1 and params[0] is ...:
682+
return f"Callable[..., {_serialize_python_type_full(ret)}]"
683+
684+
if len(params) == 1 and isinstance(params[0], (list, tuple)): # pragma: no cover
685+
params = tuple(params[0])
686+
687+
params_str = ", ".join(_serialize_python_type_full(p) for p in params)
688+
return f"Callable[[{params_str}], {_serialize_python_type_full(ret)}]"
689+
690+
691+
def _get_origin_name(origin: type) -> str:
692+
"""Get the name of a generic origin."""
693+
name = getattr(origin, "__name__", None)
694+
if name:
695+
return name
696+
697+
# Fallback for origins without __name__ (rare edge case)
698+
origin_str = str(origin) # pragma: no cover
699+
if "typing." in origin_str: # pragma: no cover
700+
return origin_str.replace("typing.", "")
701+
702+
return origin_str # pragma: no cover
703+
704+
705+
def _get_input_model_json_schema_class() -> type:
    """Build and return the InputModelJsonSchema class.

    The pydantic import happens inside the function (lazy import to avoid
    Pydantic v1 issues), so the class is created on each call.
    """
    from pydantic.json_schema import GenerateJsonSchema  # noqa: PLC0415

    def _marked_schema(json_type: str) -> dict[str, Any]:
        # Fresh dict per call so callers can mutate the result safely.
        return {"type": json_type, _UNSERIALIZABLE_MARKER: True}

    class InputModelJsonSchema(GenerateJsonSchema):
        """Schema generator that emits marker schemas for unserializable types."""

        def handle_invalid_for_json_schema(  # noqa: PLR6301
            self,
            schema: Any,  # noqa: ARG002
            error_info: Any,  # noqa: ARG002
        ) -> dict[str, Any]:
            """Return a marked ``object`` schema instead of failing on an invalid type."""
            return _marked_schema("object")

        def callable_schema(  # noqa: PLR6301
            self,
            schema: Any,  # noqa: ARG002
        ) -> dict[str, Any]:
            """Return a marked ``string`` schema for Callable types."""
            return _marked_schema("string")

    return InputModelJsonSchema
734+
735+
736+
def _is_type_origin(annotation: type) -> bool:
737+
"""Check if annotation is Type[X]."""
738+
from typing import get_origin # noqa: PLC0415
739+
740+
origin = get_origin(annotation)
741+
return origin is type
742+
743+
744+
def _process_unserializable_property(prop: dict[str, Any], annotation: type) -> None:
    """Process a single property, handling anyOf/oneOf/items structures.

    The first matching case wins:

    1. ``anyOf`` (or, when absent, ``oneOf``): every marked member gets its
       ``x-python-type`` set; nothing else on the property is inspected.
    2. The property itself is marked.
    3. ``items`` is an object schema carrying the marker: the whole annotation is
       recorded on the property and the marker is removed from ``items``.
    4. The annotation is ``Type[X]``: it is recorded on the property.

    Args:
        prop: The property schema; modified in place.
        annotation: The field annotation the property was generated from.
    """
    # anyOf takes precedence over oneOf; only the first keyword present is processed.
    for keyword in ("anyOf", "oneOf"):
        if keyword in prop:
            for item in prop[keyword]:
                # Sub-schemas may legally be booleans; only object schemas can carry the marker.
                if isinstance(item, dict) and item.get(_UNSERIALIZABLE_MARKER):
                    _set_python_type_for_unserializable(item, annotation)
            return

    if prop.get(_UNSERIALIZABLE_MARKER):
        _set_python_type_for_unserializable(prop, annotation)
        return

    # JSON Schema allows ``items`` to be a boolean or (in older drafts) a list, neither
    # of which has ``.get``; only an object schema can hold the marker.
    items = prop.get("items")
    if isinstance(items, dict) and items.get(_UNSERIALIZABLE_MARKER):
        prop["x-python-type"] = _serialize_python_type_full(annotation)
        items.pop(_UNSERIALIZABLE_MARKER, None)
        return

    if _is_type_origin(annotation):
        prop["x-python-type"] = _serialize_python_type_full(annotation)
761+
762+
763+
def _set_python_type_for_unserializable(item: dict[str, Any], annotation: type) -> None:
    """Set x-python-type and clean up markers.

    When the annotation is a union, the first non-None member is recorded instead
    of the whole union. Both ``typing.Union`` / ``Optional`` and PEP 604 ``X | Y``
    unions (``types.UnionType``) are unwrapped, so ``Optional[X]`` and ``X | None``
    produce the same result on every Python version.

    Args:
        item: The schema fragment carrying the marker; modified in place.
        annotation: The field annotation the fragment was generated from.
    """
    import types  # noqa: PLC0415
    from typing import Union, get_args, get_origin  # noqa: PLC0415

    origin = get_origin(annotation)
    actual_type = annotation

    # ``types.UnionType`` only exists on Python 3.10+; guard against comparing a
    # ``None`` origin with a missing attribute.
    pep604_union = getattr(types, "UnionType", None)
    is_union = origin is Union or (pep604_union is not None and origin is pep604_union)

    if is_union:
        # NOTE(review): this picks the first non-None member, which is not necessarily
        # the member that produced the marker (e.g. ``int | Callable``) - confirm.
        for arg in get_args(annotation):  # pragma: no branch
            if arg is not type(None):  # pragma: no branch
                actual_type = arg
                break

    item["x-python-type"] = _serialize_python_type_full(actual_type)
    item.pop(_UNSERIALIZABLE_MARKER, None)
778+
779+
780+
def _add_python_type_for_unserializable(
781+
schema: dict[str, Any],
782+
model: type,
783+
visited_defs: set[str] | None = None,
784+
) -> dict[str, Any]:
785+
"""Add x-python-type to ALL fields marked as unserializable.
786+
787+
Handles:
788+
- Top-level properties
789+
- Nested in anyOf/oneOf/allOf
790+
- $defs definitions
791+
"""
792+
if visited_defs is None:
793+
visited_defs = set()
794+
795+
if "properties" in schema:
796+
model_fields = getattr(model, "model_fields", {})
797+
for field_name, prop in schema["properties"].items():
798+
if field_name in model_fields: # pragma: no branch
799+
annotation = model_fields[field_name].annotation
800+
_process_unserializable_property(prop, annotation)
801+
802+
if "$defs" in schema:
803+
nested_models = _collect_nested_models(model)
804+
model_name = getattr(model, "__name__", None)
805+
if model_name: # pragma: no branch
806+
nested_models[model_name] = model
807+
for def_name, def_schema in schema["$defs"].items():
808+
if def_name in visited_defs: # pragma: no cover
809+
continue
810+
visited_defs.add(def_name)
811+
if def_name in nested_models: # pragma: no branch
812+
_add_python_type_for_unserializable(def_schema, nested_models[def_name], visited_defs)
813+
814+
return schema
815+
602816

603817
def _init_preserved_type_origins() -> dict[type, str]:
604818
"""Initialize preserved type origins mapping (lazy initialization)."""
@@ -937,7 +1151,9 @@ def _load_model_schema( # noqa: PLR0912, PLR0914, PLR0915
9371151
if not hasattr(obj, "model_json_schema"):
9381152
msg = "--input-model with Pydantic model requires Pydantic v2 runtime. Please upgrade Pydantic to v2."
9391153
raise Error(msg)
940-
schema = obj.model_json_schema()
1154+
schema_generator = _get_input_model_json_schema_class()
1155+
schema = obj.model_json_schema(schema_generator=schema_generator)
1156+
schema = _add_python_type_for_unserializable(schema, obj)
9411157
schema = _add_python_type_info(schema, obj)
9421158

9431159
if ref_strategy and ref_strategy != InputModelRefStrategy.RegenerateAll:

src/datamodel_code_generator/parser/jsonschema.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -569,8 +569,15 @@ class JsonSchemaParser(Parser):
569569
"AsyncGenerator": Import.from_full_path("collections.abc.AsyncGenerator"),
570570
"Pattern": Import.from_full_path("re.Pattern"),
571571
"Match": Import.from_full_path("re.Match"),
572+
"Type": Import.from_full_path("typing.Type"),
572573
}
573574

575+
# Types that require x-python-type override regardless of schema type
576+
PYTHON_TYPE_OVERRIDE_ALWAYS: ClassVar[frozenset[str]] = frozenset({
577+
"Callable",
578+
"Type",
579+
})
580+
574581
def __init__( # noqa: PLR0913
575582
self,
576583
source: str | Path | list[Path] | ParseResult,
@@ -1346,9 +1353,14 @@ def _get_python_type_base(self, python_type: str) -> str: # noqa: PLR6301
13461353

13471354
def _is_compatible_python_type(self, schema_type: str | None, python_type: str) -> bool:
13481355
"""Check if x-python-type is compatible with the JSON Schema type."""
1356+
base_type = self._get_python_type_base(python_type)
1357+
if base_type in self.PYTHON_TYPE_OVERRIDE_ALWAYS:
1358+
return False
1359+
all_type_names = self._extract_all_type_names(python_type)
1360+
if any(t in self.PYTHON_TYPE_OVERRIDE_ALWAYS for t in all_type_names):
1361+
return False
13491362
if schema_type is None:
13501363
return True
1351-
base_type = self._get_python_type_base(python_type)
13521364
if base_type in {"Union", "Optional"}:
13531365
return True
13541366
compatible = self.COMPATIBLE_PYTHON_TYPES.get(schema_type, frozenset())
@@ -2717,7 +2729,7 @@ def parse_property_names( # noqa: PLR0912
27172729
dict_key=key_type,
27182730
)
27192731

2720-
def parse_item( # noqa: PLR0911, PLR0912
2732+
def parse_item( # noqa: PLR0911, PLR0912, PLR0914
27212733
self,
27222734
name: str,
27232735
item: JsonSchemaObject,
@@ -2726,6 +2738,9 @@ def parse_item( # noqa: PLR0911, PLR0912
27262738
parent: JsonSchemaObject | None = None,
27272739
) -> DataType:
27282740
"""Parse a single JSON Schema item into a data type."""
2741+
python_type_override = self._get_python_type_override(item)
2742+
if python_type_override:
2743+
return python_type_override
27292744
if self.use_title_as_name and item.title:
27302745
name = sanitize_module_name(item.title, treat_dot_as_module=self.treat_dot_as_module)
27312746
singular_name = False

tests/data/expected/main/jsonschema/x_python_type_no_schema_type.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,11 @@
44

55
from __future__ import annotations
66

7-
from typing import Any, TypedDict
7+
from collections.abc import Callable
8+
from typing import TypedDict
89

910
from typing_extensions import NotRequired
1011

1112

1213
class Model(TypedDict):
13-
callback: NotRequired[Any]
14+
callback: NotRequired[Callable[[str], str]]

tests/data/python/input_model/pydantic_models.py

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22

33
from __future__ import annotations
44

5-
from collections.abc import Mapping, Sequence
6-
from typing import FrozenSet, Optional, Set, Union
5+
from collections.abc import Callable, Mapping, Sequence
6+
from typing import Any, FrozenSet, Optional, Set, Type, Union
77

88
from pydantic import BaseModel
99

@@ -41,3 +41,48 @@ class RecursiveNode(BaseModel):
4141

4242
value: Set[str]
4343
children: Optional[list[RecursiveNode]] = None
44+
45+
46+
class ModelWithCallableTypes(BaseModel):
47+
"""Model with Callable and other unserializable types."""
48+
49+
callback: Callable[[str], str]
50+
multi_param_callback: Callable[[int, int], bool]
51+
variadic_callback: Callable[..., Any]
52+
no_param_callback: Callable[[], None]
53+
optional_callback: Callable[[str], str] | None
54+
type_field: Type[BaseModel]
55+
nested_callable: list[Callable[[str], int]]
56+
57+
58+
class NestedCallableModel(BaseModel):
59+
"""Model with nested Callable types for $defs coverage."""
60+
61+
handler: Callable[[str], int]
62+
63+
64+
class ModelWithNestedCallable(BaseModel):
65+
"""Model referencing another model with Callable to test $defs processing."""
66+
67+
nested: NestedCallableModel
68+
own_callback: Callable[[int], str]
69+
70+
71+
class CustomClass:
72+
"""Custom class for testing handle_invalid_for_json_schema."""
73+
74+
pass
75+
76+
77+
class ModelWithCustomClass(BaseModel):
78+
"""Model with a custom class that triggers handle_invalid_for_json_schema."""
79+
80+
model_config = {"arbitrary_types_allowed": True}
81+
custom_obj: CustomClass
82+
83+
84+
class ModelWithUnionCallable(BaseModel):
85+
"""Model with Union of Callable and other types to test Union serialization."""
86+
87+
union_callback: Union[Callable[[str], str], int]
88+
raw_callable: Callable # Callable without type args

0 commit comments

Comments
 (0)