Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 9 additions & 28 deletions src/datamodel_code_generator/parser/jsonschema.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@
StrictTypes,
Types,
UnionIntFloat,
get_subscript_args,
get_type_base_name,
)
from datamodel_code_generator.util import (
BaseModel,
Expand Down Expand Up @@ -1314,42 +1316,21 @@ class decorator which does not preserve staticmethod descriptors.
"MutableSet": {"is_set": True},
}

base_type = x_python_type.split("[")[0].strip()
base_type = get_type_base_name(x_python_type)
if base_type in type_to_flag:
return type_to_flag[base_type]

if base_type in {"Union", "Optional"}:
bracket_start = x_python_type.find("[")
if bracket_start != -1:
inner = x_python_type[bracket_start + 1 : -1]
depth = 0
current = ""
for char in inner:
if char == "[":
depth += 1
elif char == "]":
depth -= 1
if char == "," and depth == 0:
arg_base = current.strip().split("[")[0]
if arg_base in type_to_flag:
return type_to_flag[arg_base]
current = ""
else:
current += char
if current.strip():
arg_base = current.strip().split("[")[0]
if arg_base in type_to_flag:
return type_to_flag[arg_base]
if base_type in {"Union", "Optional"} or " | " in x_python_type:
for arg in get_subscript_args(x_python_type):
arg_base = get_type_base_name(arg)
if arg_base in type_to_flag:
return type_to_flag[arg_base]

return {}

def _get_python_type_base(self, python_type: str) -> str: # noqa: PLR6301
"""Extract base type from a Python type annotation string."""
if "." in python_type.split("[", maxsplit=1)[0]:
base = python_type.split("[", maxsplit=1)[0].rsplit(".", 1)[-1]
else:
base = python_type.split("[", maxsplit=1)[0].strip()
return base
return get_type_base_name(python_type)

def _is_compatible_python_type(self, schema_type: str | None, python_type: str) -> bool:
"""Check if x-python-type is compatible with the JSON Schema type."""
Expand Down
103 changes: 103 additions & 0 deletions src/datamodel_code_generator/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from __future__ import annotations

import ast
import re
from abc import ABC, abstractmethod
from copy import deepcopy
Expand Down Expand Up @@ -191,6 +192,108 @@ def chain_as_tuple(*iterables: Iterable[T]) -> tuple[T, ...]:
return tuple(chain(*iterables))


def get_type_base_name(type_str: str) -> str:
"""Extract base type name from a type annotation string using AST.

Examples:
"List[str]" -> "List"
"foo.bar.Baz" -> "Baz"
"Optional[int]" -> "Optional"
"""
try:
tree = ast.parse(type_str, mode="eval")
except SyntaxError:
return type_str.split("[", maxsplit=1)[0].rsplit(".", 1)[-1].strip()

body = tree.body
if isinstance(body, ast.Subscript):
body = body.value

if isinstance(body, ast.Attribute):
return body.attr
if isinstance(body, ast.Name):
return body.id
return type_str.split("[", maxsplit=1)[0].rsplit(".", 1)[-1].strip()


def get_subscript_args(type_str: str) -> list[str]:
"""Extract type arguments from a subscripted type using AST.

Examples:
"List[str]" -> ["str"]
"Dict[str, int]" -> ["str", "int"]
"Union[str, int, None]" -> ["str", "int", "None"]
"str | int | None" -> ["str", "int", "None"]
"str" -> []
"""
try:
tree = ast.parse(type_str, mode="eval")
except SyntaxError:
return []

body = tree.body

if isinstance(body, ast.BinOp) and isinstance(body.op, ast.BitOr):
args: list[str] = []

def collect_union_args(node: ast.expr) -> None:
if isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr):
collect_union_args(node.left)
collect_union_args(node.right)
else:
args.append(ast.unparse(node))

collect_union_args(body)
return args

if isinstance(body, ast.Subscript):
slice_node = body.slice
if isinstance(slice_node, ast.Tuple):
return [ast.unparse(elt) for elt in slice_node.elts]
return [ast.unparse(slice_node)]

return []


def extract_qualified_names(type_str: str) -> list[str]:
"""Extract all fully qualified names from a type annotation string using AST.

Finds patterns like 'module.path.ClassName' where the name contains dots.

Examples:
"type[foo.bar.Baz]" -> ["foo.bar.Baz"]
"Dict[a.B, c.D]" -> ["a.B", "c.D"]
"str" -> []
"""
try:
tree = ast.parse(type_str, mode="eval")
except SyntaxError:
return []

qualified_names: list[str] = []
visited: set[int] = set()

def get_full_name(node: ast.expr) -> str | None:
parts: list[str] = []
current: ast.expr = node
while isinstance(current, ast.Attribute):
visited.add(id(current))
parts.append(current.attr)
current = current.value
if isinstance(current, ast.Name):
parts.append(current.id)
return ".".join(reversed(parts))
return None

for node in ast.walk(tree):
if isinstance(node, ast.Attribute) and id(node) not in visited:
name = get_full_name(node)
if name and "." in name:
qualified_names.append(name)

return qualified_names


def _remove_none_from_union(type_: str, *, use_union_operator: bool) -> str: # noqa: PLR0912
"""Remove None from a Union type string, handling nested unions."""
if use_union_operator:
Expand Down
42 changes: 42 additions & 0 deletions tests/parser/test_jsonschema.py
Original file line number Diff line number Diff line change
Expand Up @@ -1181,3 +1181,45 @@ def test_timestamp_with_time_zone_format() -> None:

# Verify the format is mapped correctly
assert json_schema_data_formats["string"]["timestamp with time zone"] == Types.date_time


@pytest.mark.parametrize(
("x_python_type", "expected"),
[
# Direct matches for special container types
("Set[str]", {"is_set": True}),
("set[int]", {"is_set": True}),
("FrozenSet[int]", {"is_frozen_set": True}),
("frozenset[str]", {"is_frozen_set": True}),
("Sequence[str]", {"is_sequence": True}),
("MutableSequence[int]", {"is_sequence": True}),
("Mapping[str, int]", {"is_mapping": True}),
("MutableMapping[str, int]", {"is_mapping": True}),
("AbstractSet[str]", {"is_frozen_set": True}),
("MutableSet[int]", {"is_set": True}),
# Union with special container type
("Union[Set[str], None]", {"is_set": True}),
("Optional[FrozenSet[int]]", {"is_frozen_set": True}),
("Set[int] | None", {"is_set": True}),
("Sequence[str] | int", {"is_sequence": True}),
# Union without special container type (loop completes without match)
("Union[str, int]", {}),
("str | int", {}),
("Optional[str]", {}),
("Union[str, int, float]", {}),
("Union[List[str], None]", {}), # List is not a special container
("Optional[Dict[str, int]]", {}), # Dict is not a special container
# Non-special container types
("List[str]", {}),
("Dict[str, int]", {}),
("str", {}),
("int", {}),
("CustomType", {}),
],
)
def test_get_python_type_flags(x_python_type: str, expected: dict[str, bool]) -> None:
"""Test _get_python_type_flags extracts collection flags correctly."""
parser = JsonSchemaParser("")
obj = model_validate(JsonSchemaObject, {"x-python-type": x_python_type})
result = parser._get_python_type_flags(obj)
assert result == expected
109 changes: 108 additions & 1 deletion tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,13 @@

import pytest

from datamodel_code_generator.types import _remove_none_from_union, get_optional_type
from datamodel_code_generator.types import (
_remove_none_from_union,
extract_qualified_names,
get_optional_type,
get_subscript_args,
get_type_base_name,
)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -273,3 +279,104 @@ def test_datatype_deepcopy_memo_cache_hit() -> None:
# Second call with same memo - should return cached object (covers memo hit branch)
copied2 = data_type.__deepcopy__(memo) # noqa: PLC2801
assert copied2 is copied1 # Same object from memo


@pytest.mark.parametrize(
("type_str", "expected"),
[
# Simple types
("str", "str"),
("int", "int"),
("List", "List"),
# Subscripted types
("List[str]", "List"),
("Dict[str, int]", "Dict"),
("Optional[int]", "Optional"),
("Union[str, int]", "Union"),
# Qualified names
("foo.bar.Baz", "Baz"),
("datamodel_code_generator.model.base.DataModel", "DataModel"),
# Subscripted with qualified names
("type[foo.bar.Baz]", "type"),
("List[foo.Bar]", "List"),
# Invalid syntax (fallback to string parsing)
("List[", "List"),
("[invalid", ""), # splits on "[" giving empty string
],
)
def test_get_type_base_name(type_str: str, expected: str) -> None:
"""Test get_type_base_name extracts base type correctly."""
assert get_type_base_name(type_str) == expected


@pytest.mark.parametrize(
("type_str", "expected"),
[
# Simple types (no subscript)
("str", []),
("int", []),
# Single argument
("List[str]", ["str"]),
("Optional[int]", ["int"]),
("type[Foo]", ["Foo"]),
# Multiple arguments
("Dict[str, int]", ["str", "int"]),
("Union[str, int, None]", ["str", "int", "None"]),
("Tuple[int, str, float]", ["int", "str", "float"]),
# Union operator syntax
("str | int", ["str", "int"]),
("str | int | None", ["str", "int", "None"]),
("List[str] | None", ["List[str]", "None"]),
# Complex nested types
("Dict[str, List[int]]", ["str", "List[int]"]),
("Union[List[str], Dict[str, int]]", ["List[str]", "Dict[str, int]"]),
# Qualified names in arguments
("type[foo.bar.Baz]", ["foo.bar.Baz"]),
("Dict[a.B, c.D]", ["a.B", "c.D"]),
# Invalid syntax
("List[", []),
("[invalid", []),
],
)
def test_get_subscript_args(type_str: str, expected: list[str]) -> None:
"""Test get_subscript_args extracts type arguments correctly."""
assert get_subscript_args(type_str) == expected


@pytest.mark.parametrize(
("type_str", "expected"),
[
# No qualified names
("str", []),
("List[str]", []),
("Union[str, int]", []),
# Single qualified name
("foo.Bar", ["foo.Bar"]),
("foo.bar.Baz", ["foo.bar.Baz"]),
("datamodel_code_generator.model.base.DataModel", ["datamodel_code_generator.model.base.DataModel"]),
# Qualified names in subscript
("type[foo.bar.Baz]", ["foo.bar.Baz"]),
("List[foo.Bar]", ["foo.Bar"]),
("Optional[a.b.C]", ["a.b.C"]),
# Multiple qualified names
("Dict[a.B, c.D]", ["a.B", "c.D"]),
("Union[foo.Bar, baz.Qux]", ["foo.Bar", "baz.Qux"]),
# Mixed with simple types
("Dict[str, foo.Bar]", ["foo.Bar"]),
("Union[int, a.B, None]", ["a.B"]),
# Union operator syntax
("foo.Bar | None", ["foo.Bar"]),
("a.B | c.D", ["a.B", "c.D"]),
# Complex nested
("Dict[str, List[foo.Bar]]", ["foo.Bar"]),
("type[datamodel_code_generator.types.DataTypeManager]", ["datamodel_code_generator.types.DataTypeManager"]),
# Attribute on non-Name (function call result) - should not extract
("foo().bar", []),
("func().attr.name", []),
# Invalid syntax
("foo.Bar[", []),
],
)
def test_extract_qualified_names(type_str: str, expected: list[str]) -> None:
"""Test extract_qualified_names finds all fully qualified names."""
assert extract_qualified_names(type_str) == expected
Loading