Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
218 changes: 217 additions & 1 deletion src/datamodel_code_generator/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,6 +599,220 @@ def _extract_additional_imports(extra_template_data: defaultdict[str, dict[str,
# Types that are lost during JSON Schema conversion and need to be preserved
_PRESERVED_TYPE_ORIGINS: dict[type, str] = {}

# Marker for types that Pydantic cannot serialize to JSON Schema
_UNSERIALIZABLE_MARKER = "x-python-unserializable"


def _serialize_python_type_full(tp: type) -> str:  # noqa: PLR0911
    """Render any Python type annotation as a source-style string.

    Covers plain classes (qualified with their module unless it is a
    well-known one), ``typing`` generics, ``Callable`` signatures, unions
    (both ``Union[...]`` and PEP 604 ``X | Y``), ``Type[X]`` and arbitrarily
    nested combinations of all of the above.
    """
    import types  # noqa: PLC0415
    from typing import Union, get_args, get_origin  # noqa: PLC0415

    if tp is type(None):  # pragma: no cover
        return "None"

    if tp is ...:  # pragma: no cover
        return "..."

    origin, type_args = get_origin(tp), get_args(tp)

    if origin is None:
        # Non-generic type: prefer __name__/__qualname__, falling back to
        # str() for typing constructs that expose neither.
        mod = getattr(tp, "__module__", "")
        label = getattr(tp, "__name__", None) or getattr(tp, "__qualname__", None)

        if label is None:
            return str(tp).replace("typing.", "")

        if mod and mod not in {"builtins", "typing", "collections.abc"}:
            return f"{mod}.{label}"
        return label

    if _is_callable_origin(origin):
        return _serialize_callable(type_args)

    if origin is Union or (hasattr(types, "UnionType") and origin is types.UnionType):  # pragma: no cover
        return " | ".join(_serialize_python_type_full(member) for member in type_args)

    if origin is type:
        if type_args:
            return f"Type[{_serialize_python_type_full(type_args[0])}]"
        return "Type"  # pragma: no cover

    head = _get_origin_name(origin)
    if type_args:
        inner = ", ".join(_serialize_python_type_full(member) for member in type_args)
        return f"{head}[{inner}]"

    return head  # pragma: no cover


def _is_callable_origin(origin: type | None) -> bool:
"""Check if origin is Callable."""
if origin is None: # pragma: no cover
return False
from collections.abc import Callable as ABCCallable # noqa: PLC0415

if origin is ABCCallable:
return True
origin_str = str(origin)
return "Callable" in origin_str or "callable" in origin_str


def _serialize_callable(args: tuple[type, ...]) -> str:
    """Format the flattened ``get_args`` tuple of a Callable annotation.

    The last element is the return type; everything before it is the
    parameter list (or a single Ellipsis for ``Callable[..., X]``).
    """
    if not args:  # pragma: no cover
        return "Callable"

    *param_types, return_type = args
    rendered_ret = _serialize_python_type_full(return_type)

    if len(param_types) == 1 and param_types[0] is ...:
        return f"Callable[..., {rendered_ret}]"

    # Some origins keep the parameter list as a nested list/tuple instead of
    # flattening it; unwrap so each parameter is serialized individually.
    if len(param_types) == 1 and isinstance(param_types[0], (list, tuple)):  # pragma: no cover
        param_types = list(param_types[0])

    rendered_params = ", ".join(_serialize_python_type_full(p) for p in param_types)
    return f"Callable[[{rendered_params}], {rendered_ret}]"


def _get_origin_name(origin: type) -> str:
"""Get the name of a generic origin."""
name = getattr(origin, "__name__", None)
if name:
return name

# Fallback for origins without __name__ (rare edge case)
origin_str = str(origin) # pragma: no cover
if "typing." in origin_str: # pragma: no cover
return origin_str.replace("typing.", "")

return origin_str # pragma: no cover


def _get_input_model_json_schema_class() -> type:
    """Get the InputModelJsonSchema class (lazy import to avoid Pydantic v1 issues).

    Returns a ``GenerateJsonSchema`` subclass whose overrides tag fields that
    Pydantic cannot serialize with ``_UNSERIALIZABLE_MARKER`` instead of
    raising, so a later pass can attach the real Python type.
    """
    from pydantic.json_schema import GenerateJsonSchema  # noqa: PLC0415

    class InputModelJsonSchema(GenerateJsonSchema):
        """Custom schema generator that handles ALL unserializable types."""

        def handle_invalid_for_json_schema(  # noqa: PLR6301
            self,
            schema: Any,  # noqa: ARG002
            error_info: Any,  # noqa: ARG002
        ) -> dict[str, Any]:
            """Catch ALL types that Pydantic can't serialize to JSON Schema."""
            # Emit a placeholder object schema; the marker key lets a later
            # pass replace it with an x-python-type annotation.
            return {
                "type": "object",
                _UNSERIALIZABLE_MARKER: True,
            }

        def callable_schema(  # noqa: PLR6301
            self,
            schema: Any,  # noqa: ARG002
        ) -> dict[str, Any]:
            """Handle Callable types - these raise before handle_invalid_for_json_schema."""
            # Represent callables as strings in the schema; the marker allows
            # the exact signature to be restored from the field annotation.
            return {
                "type": "string",
                _UNSERIALIZABLE_MARKER: True,
            }

    return InputModelJsonSchema


def _is_type_origin(annotation: type) -> bool:
"""Check if annotation is Type[X]."""
from typing import get_origin # noqa: PLC0415

origin = get_origin(annotation)
return origin is type


def _process_unserializable_property(prop: dict[str, Any], annotation: type) -> None:
    """Stamp x-python-type on *prop* wherever the unserializable marker appears.

    Handles four shapes, mutating *prop* in place: union alternatives
    (anyOf/oneOf), a directly marked property, a marked array ``items``
    entry, and plain ``Type[X]`` fields.
    """
    for combinator in ("anyOf", "oneOf"):
        if combinator in prop:
            # Only the marked alternatives need the real Python type attached.
            for alternative in prop[combinator]:
                if alternative.get(_UNSERIALIZABLE_MARKER):
                    _set_python_type_for_unserializable(alternative, annotation)
            return

    if prop.get(_UNSERIALIZABLE_MARKER):
        _set_python_type_for_unserializable(prop, annotation)
    elif "items" in prop and prop["items"].get(_UNSERIALIZABLE_MARKER):
        # Arrays: the element annotation is lost, so record the full list type.
        prop["x-python-type"] = _serialize_python_type_full(annotation)
        prop["items"].pop(_UNSERIALIZABLE_MARKER, None)
    elif _is_type_origin(annotation):
        # Type[X] serializes without error but loses X; always record it.
        prop["x-python-type"] = _serialize_python_type_full(annotation)


def _set_python_type_for_unserializable(item: dict[str, Any], annotation: type) -> None:
    """Set x-python-type on *item* and remove the unserializable marker.

    For Optional/Union annotations the first non-None member is recorded,
    since that is the member Pydantic failed to serialize (the marked item
    is one alternative of an anyOf, not the whole union).
    """
    import types  # noqa: PLC0415
    from typing import Union, get_args, get_origin  # noqa: PLC0415

    origin = get_origin(annotation)
    actual_type = annotation

    # Fix: recognize PEP 604 unions (X | None -> types.UnionType) as well as
    # typing.Union, matching _serialize_python_type_full's union handling.
    # Previously `Callable[[str], str] | None` was serialized as the whole
    # union instead of the callable member.
    is_union = origin is Union or (hasattr(types, "UnionType") and origin is types.UnionType)
    if is_union:
        for arg in get_args(annotation):  # pragma: no branch
            if arg is not type(None):  # pragma: no branch
                actual_type = arg
                break

    item["x-python-type"] = _serialize_python_type_full(actual_type)
    item.pop(_UNSERIALIZABLE_MARKER, None)


def _add_python_type_for_unserializable(
    schema: dict[str, Any],
    model: type,
    visited_defs: set[str] | None = None,
) -> dict[str, Any]:
    """Add x-python-type to ALL fields marked as unserializable.

    Walks top-level properties, then recurses into ``$defs`` using the
    nested models reachable from *model*. *visited_defs* guards against
    infinite recursion on self-referential definitions. Returns *schema*,
    which is mutated in place.
    """
    seen = set() if visited_defs is None else visited_defs

    # Top-level properties: match each schema property to its field annotation.
    fields = getattr(model, "model_fields", {})
    for field_name, prop in schema.get("properties", {}).items():
        if field_name in fields:  # pragma: no branch
            _process_unserializable_property(prop, fields[field_name].annotation)

    if "$defs" in schema:
        # Map definition names back to model classes so each $defs entry can
        # be processed with its own field annotations.
        lookup = _collect_nested_models(model)
        own_name = getattr(model, "__name__", None)
        if own_name:  # pragma: no branch
            lookup[own_name] = model
        for def_name, def_schema in schema["$defs"].items():
            if def_name in seen:  # pragma: no cover
                continue
            seen.add(def_name)
            if def_name in lookup:  # pragma: no branch
                _add_python_type_for_unserializable(def_schema, lookup[def_name], seen)

    return schema


def _init_preserved_type_origins() -> dict[type, str]:
"""Initialize preserved type origins mapping (lazy initialization)."""
Expand Down Expand Up @@ -937,7 +1151,9 @@ def _load_model_schema( # noqa: PLR0912, PLR0914, PLR0915
if not hasattr(obj, "model_json_schema"):
msg = "--input-model with Pydantic model requires Pydantic v2 runtime. Please upgrade Pydantic to v2."
raise Error(msg)
schema = obj.model_json_schema()
schema_generator = _get_input_model_json_schema_class()
schema = obj.model_json_schema(schema_generator=schema_generator)
schema = _add_python_type_for_unserializable(schema, obj)
schema = _add_python_type_info(schema, obj)

if ref_strategy and ref_strategy != InputModelRefStrategy.RegenerateAll:
Expand Down
19 changes: 17 additions & 2 deletions src/datamodel_code_generator/parser/jsonschema.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,8 +569,15 @@ class JsonSchemaParser(Parser):
"AsyncGenerator": Import.from_full_path("collections.abc.AsyncGenerator"),
"Pattern": Import.from_full_path("re.Pattern"),
"Match": Import.from_full_path("re.Match"),
"Type": Import.from_full_path("typing.Type"),
}

# Types that require x-python-type override regardless of schema type
PYTHON_TYPE_OVERRIDE_ALWAYS: ClassVar[frozenset[str]] = frozenset({
"Callable",
"Type",
})

def __init__( # noqa: PLR0913
self,
source: str | Path | list[Path] | ParseResult,
Expand Down Expand Up @@ -1346,9 +1353,14 @@ def _get_python_type_base(self, python_type: str) -> str: # noqa: PLR6301

def _is_compatible_python_type(self, schema_type: str | None, python_type: str) -> bool:
"""Check if x-python-type is compatible with the JSON Schema type."""
base_type = self._get_python_type_base(python_type)
if base_type in self.PYTHON_TYPE_OVERRIDE_ALWAYS:
return False
all_type_names = self._extract_all_type_names(python_type)
if any(t in self.PYTHON_TYPE_OVERRIDE_ALWAYS for t in all_type_names):
return False
if schema_type is None:
return True
base_type = self._get_python_type_base(python_type)
if base_type in {"Union", "Optional"}:
return True
compatible = self.COMPATIBLE_PYTHON_TYPES.get(schema_type, frozenset())
Expand Down Expand Up @@ -2717,7 +2729,7 @@ def parse_property_names( # noqa: PLR0912
dict_key=key_type,
)

def parse_item( # noqa: PLR0911, PLR0912
def parse_item( # noqa: PLR0911, PLR0912, PLR0914
self,
name: str,
item: JsonSchemaObject,
Expand All @@ -2726,6 +2738,9 @@ def parse_item( # noqa: PLR0911, PLR0912
parent: JsonSchemaObject | None = None,
) -> DataType:
"""Parse a single JSON Schema item into a data type."""
python_type_override = self._get_python_type_override(item)
if python_type_override:
return python_type_override
if self.use_title_as_name and item.title:
name = sanitize_module_name(item.title, treat_dot_as_module=self.treat_dot_as_module)
singular_name = False
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@

from __future__ import annotations

from typing import Any, TypedDict
from collections.abc import Callable
from typing import TypedDict

from typing_extensions import NotRequired


class Model(TypedDict):
callback: NotRequired[Any]
callback: NotRequired[Callable[[str], str]]
49 changes: 47 additions & 2 deletions tests/data/python/input_model/pydantic_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

from __future__ import annotations

from collections.abc import Mapping, Sequence
from typing import FrozenSet, Optional, Set, Union
from collections.abc import Callable, Mapping, Sequence
from typing import Any, FrozenSet, Optional, Set, Type, Union

from pydantic import BaseModel

Expand Down Expand Up @@ -41,3 +41,48 @@ class RecursiveNode(BaseModel):

value: Set[str]
children: Optional[list[RecursiveNode]] = None


class ModelWithCallableTypes(BaseModel):
    """Model with Callable and other unserializable types."""

    # Each field exercises a distinct Callable/Type shape the schema
    # generator must round-trip via x-python-type.
    callback: Callable[[str], str]
    multi_param_callback: Callable[[int, int], bool]
    variadic_callback: Callable[..., Any]
    no_param_callback: Callable[[], None]
    optional_callback: Callable[[str], str] | None  # PEP 604 optional callable
    type_field: Type[BaseModel]  # Type[X] loses X in plain JSON Schema
    nested_callable: list[Callable[[str], int]]  # callable inside a container


class NestedCallableModel(BaseModel):
    """Model with nested Callable types for $defs coverage."""

    # Appears under $defs when referenced by another model.
    handler: Callable[[str], int]


class ModelWithNestedCallable(BaseModel):
    """Model referencing another model with Callable to test $defs processing."""

    nested: NestedCallableModel  # forces NestedCallableModel into $defs
    own_callback: Callable[[int], str]


class CustomClass:
    """Custom class for testing handle_invalid_for_json_schema.

    Deliberately empty: Pydantic cannot build a JSON Schema for arbitrary
    classes, which triggers the unserializable-type fallback path.
    """


class ModelWithCustomClass(BaseModel):
    """Model with a custom class that triggers handle_invalid_for_json_schema."""

    # arbitrary_types_allowed lets Pydantic accept CustomClass at runtime even
    # though it cannot serialize it to JSON Schema.
    model_config = {"arbitrary_types_allowed": True}
    custom_obj: CustomClass


class ModelWithUnionCallable(BaseModel):
    """Model with Union of Callable and other types to test Union serialization."""

    union_callback: Union[Callable[[str], str], int]  # mixed union member kinds
    raw_callable: Callable  # Callable without type args
Loading
Loading