From 6a5249c89e781414d8a1d688f63e05133d4f8606 Mon Sep 17 00:00:00 2001 From: Koudai Aono Date: Tue, 30 Dec 2025 02:30:29 +0000 Subject: [PATCH 1/3] Add AST-based type string parsing helpers Add three helper functions to types.py for robust AST-based parsing of Python type annotation strings: - get_type_base_name(): Extract base type name (e.g., "List[str]" -> "List") - get_subscript_args(): Extract type arguments (e.g., "Dict[str, int]" -> ["str", "int"]) - extract_qualified_names(): Extract fully qualified names for import handling Refactor jsonschema.py to use these helpers: - _get_python_type_flags now uses get_type_base_name and get_subscript_args - _get_python_type_base now uses get_type_base_name - Added support for union operator (|) syntax in type flag detection This provides a solid foundation for handling x-python-type qualified name imports in a follow-up PR. --- .../parser/jsonschema.py | 37 ++----- src/datamodel_code_generator/types.py | 103 ++++++++++++++++++ 2 files changed, 112 insertions(+), 28 deletions(-) diff --git a/src/datamodel_code_generator/parser/jsonschema.py b/src/datamodel_code_generator/parser/jsonschema.py index 19395294c..551aa80f7 100644 --- a/src/datamodel_code_generator/parser/jsonschema.py +++ b/src/datamodel_code_generator/parser/jsonschema.py @@ -76,6 +76,8 @@ StrictTypes, Types, UnionIntFloat, + get_subscript_args, + get_type_base_name, ) from datamodel_code_generator.util import ( BaseModel, @@ -1314,42 +1316,21 @@ class decorator which does not preserve staticmethod descriptors. 
"MutableSet": {"is_set": True}, } - base_type = x_python_type.split("[")[0].strip() + base_type = get_type_base_name(x_python_type) if base_type in type_to_flag: return type_to_flag[base_type] - if base_type in {"Union", "Optional"}: - bracket_start = x_python_type.find("[") - if bracket_start != -1: - inner = x_python_type[bracket_start + 1 : -1] - depth = 0 - current = "" - for char in inner: - if char == "[": - depth += 1 - elif char == "]": - depth -= 1 - if char == "," and depth == 0: - arg_base = current.strip().split("[")[0] - if arg_base in type_to_flag: - return type_to_flag[arg_base] - current = "" - else: - current += char - if current.strip(): - arg_base = current.strip().split("[")[0] - if arg_base in type_to_flag: - return type_to_flag[arg_base] + if base_type in {"Union", "Optional"} or " | " in x_python_type: + for arg in get_subscript_args(x_python_type): + arg_base = get_type_base_name(arg) + if arg_base in type_to_flag: + return type_to_flag[arg_base] return {} def _get_python_type_base(self, python_type: str) -> str: # noqa: PLR6301 """Extract base type from a Python type annotation string.""" - if "." 
in python_type.split("[", maxsplit=1)[0]: - base = python_type.split("[", maxsplit=1)[0].rsplit(".", 1)[-1] - else: - base = python_type.split("[", maxsplit=1)[0].strip() - return base + return get_type_base_name(python_type) def _is_compatible_python_type(self, schema_type: str | None, python_type: str) -> bool: """Check if x-python-type is compatible with the JSON Schema type.""" diff --git a/src/datamodel_code_generator/types.py b/src/datamodel_code_generator/types.py index 4bf0d9314..2990b0208 100644 --- a/src/datamodel_code_generator/types.py +++ b/src/datamodel_code_generator/types.py @@ -7,6 +7,7 @@ from __future__ import annotations +import ast import re from abc import ABC, abstractmethod from copy import deepcopy @@ -191,6 +192,108 @@ def chain_as_tuple(*iterables: Iterable[T]) -> tuple[T, ...]: return tuple(chain(*iterables)) +def get_type_base_name(type_str: str) -> str: + """Extract base type name from a type annotation string using AST. + + Examples: + "List[str]" -> "List" + "foo.bar.Baz" -> "Baz" + "Optional[int]" -> "Optional" + """ + try: + tree = ast.parse(type_str, mode="eval") + except SyntaxError: + return type_str.split("[", maxsplit=1)[0].rsplit(".", 1)[-1].strip() + + body = tree.body + if isinstance(body, ast.Subscript): + body = body.value + + if isinstance(body, ast.Attribute): + return body.attr + if isinstance(body, ast.Name): + return body.id + return type_str.split("[", maxsplit=1)[0].rsplit(".", 1)[-1].strip() + + +def get_subscript_args(type_str: str) -> list[str]: + """Extract type arguments from a subscripted type using AST. 
+ + Examples: + "List[str]" -> ["str"] + "Dict[str, int]" -> ["str", "int"] + "Union[str, int, None]" -> ["str", "int", "None"] + "str | int | None" -> ["str", "int", "None"] + "str" -> [] + """ + try: + tree = ast.parse(type_str, mode="eval") + except SyntaxError: + return [] + + body = tree.body + + if isinstance(body, ast.BinOp) and isinstance(body.op, ast.BitOr): + args: list[str] = [] + + def collect_union_args(node: ast.expr) -> None: + if isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr): + collect_union_args(node.left) + collect_union_args(node.right) + else: + args.append(ast.unparse(node)) + + collect_union_args(body) + return args + + if isinstance(body, ast.Subscript): + slice_node = body.slice + if isinstance(slice_node, ast.Tuple): + return [ast.unparse(elt) for elt in slice_node.elts] + return [ast.unparse(slice_node)] + + return [] + + +def extract_qualified_names(type_str: str) -> list[str]: + """Extract all fully qualified names from a type annotation string using AST. + + Finds patterns like 'module.path.ClassName' where the name contains dots. + + Examples: + "type[foo.bar.Baz]" -> ["foo.bar.Baz"] + "Dict[a.B, c.D]" -> ["a.B", "c.D"] + "str" -> [] + """ + try: + tree = ast.parse(type_str, mode="eval") + except SyntaxError: + return [] + + qualified_names: list[str] = [] + visited: set[int] = set() + + def get_full_name(node: ast.expr) -> str | None: + parts: list[str] = [] + current: ast.expr = node + while isinstance(current, ast.Attribute): + visited.add(id(current)) + parts.append(current.attr) + current = current.value + if isinstance(current, ast.Name): + parts.append(current.id) + return ".".join(reversed(parts)) + return None + + for node in ast.walk(tree): + if isinstance(node, ast.Attribute) and id(node) not in visited: + name = get_full_name(node) + if name and "." 
in name: + qualified_names.append(name) + + return qualified_names + + def _remove_none_from_union(type_: str, *, use_union_operator: bool) -> str: # noqa: PLR0912 """Remove None from a Union type string, handling nested unions.""" if use_union_operator: From e86ed97403925be8627f3edfa0d248a8a4b0ccf9 Mon Sep 17 00:00:00 2001 From: Koudai Aono Date: Tue, 30 Dec 2025 02:44:39 +0000 Subject: [PATCH 2/3] Add tests for AST-based type string parsing helpers Add comprehensive tests for the three new helper functions: - get_type_base_name: 13 test cases - get_subscript_args: 17 test cases - extract_qualified_names: 20 test cases Achieves 100% diff coverage for the new code. --- tests/test_types.py | 109 +++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 108 insertions(+), 1 deletion(-) diff --git a/tests/test_types.py b/tests/test_types.py index 065db5cfa..ab7f5b454 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -4,7 +4,13 @@ import pytest -from datamodel_code_generator.types import _remove_none_from_union, get_optional_type +from datamodel_code_generator.types import ( + _remove_none_from_union, + extract_qualified_names, + get_optional_type, + get_subscript_args, + get_type_base_name, +) @pytest.mark.parametrize( @@ -273,3 +279,104 @@ def test_datatype_deepcopy_memo_cache_hit() -> None: # Second call with same memo - should return cached object (covers memo hit branch) copied2 = data_type.__deepcopy__(memo) # noqa: PLC2801 assert copied2 is copied1 # Same object from memo + + +@pytest.mark.parametrize( + ("type_str", "expected"), + [ + # Simple types + ("str", "str"), + ("int", "int"), + ("List", "List"), + # Subscripted types + ("List[str]", "List"), + ("Dict[str, int]", "Dict"), + ("Optional[int]", "Optional"), + ("Union[str, int]", "Union"), + # Qualified names + ("foo.bar.Baz", "Baz"), + ("datamodel_code_generator.model.base.DataModel", "DataModel"), + # Subscripted with qualified names + ("type[foo.bar.Baz]", "type"), + ("List[foo.Bar]", 
"List"), + # Invalid syntax (fallback to string parsing) + ("List[", "List"), + ("[invalid", ""), # splits on "[" giving empty string + ], +) +def test_get_type_base_name(type_str: str, expected: str) -> None: + """Test get_type_base_name extracts base type correctly.""" + assert get_type_base_name(type_str) == expected + + +@pytest.mark.parametrize( + ("type_str", "expected"), + [ + # Simple types (no subscript) + ("str", []), + ("int", []), + # Single argument + ("List[str]", ["str"]), + ("Optional[int]", ["int"]), + ("type[Foo]", ["Foo"]), + # Multiple arguments + ("Dict[str, int]", ["str", "int"]), + ("Union[str, int, None]", ["str", "int", "None"]), + ("Tuple[int, str, float]", ["int", "str", "float"]), + # Union operator syntax + ("str | int", ["str", "int"]), + ("str | int | None", ["str", "int", "None"]), + ("List[str] | None", ["List[str]", "None"]), + # Complex nested types + ("Dict[str, List[int]]", ["str", "List[int]"]), + ("Union[List[str], Dict[str, int]]", ["List[str]", "Dict[str, int]"]), + # Qualified names in arguments + ("type[foo.bar.Baz]", ["foo.bar.Baz"]), + ("Dict[a.B, c.D]", ["a.B", "c.D"]), + # Invalid syntax + ("List[", []), + ("[invalid", []), + ], +) +def test_get_subscript_args(type_str: str, expected: list[str]) -> None: + """Test get_subscript_args extracts type arguments correctly.""" + assert get_subscript_args(type_str) == expected + + +@pytest.mark.parametrize( + ("type_str", "expected"), + [ + # No qualified names + ("str", []), + ("List[str]", []), + ("Union[str, int]", []), + # Single qualified name + ("foo.Bar", ["foo.Bar"]), + ("foo.bar.Baz", ["foo.bar.Baz"]), + ("datamodel_code_generator.model.base.DataModel", ["datamodel_code_generator.model.base.DataModel"]), + # Qualified names in subscript + ("type[foo.bar.Baz]", ["foo.bar.Baz"]), + ("List[foo.Bar]", ["foo.Bar"]), + ("Optional[a.b.C]", ["a.b.C"]), + # Multiple qualified names + ("Dict[a.B, c.D]", ["a.B", "c.D"]), + ("Union[foo.Bar, baz.Qux]", ["foo.Bar", "baz.Qux"]), + # 
Mixed with simple types + ("Dict[str, foo.Bar]", ["foo.Bar"]), + ("Union[int, a.B, None]", ["a.B"]), + # Union operator syntax + ("foo.Bar | None", ["foo.Bar"]), + ("a.B | c.D", ["a.B", "c.D"]), + # Complex nested + ("Dict[str, List[foo.Bar]]", ["foo.Bar"]), + ("type[datamodel_code_generator.types.DataTypeManager]", ["datamodel_code_generator.types.DataTypeManager"]), + # Attribute on non-Name (function call result) - should not extract + ("foo().bar", []), + ("func().attr.name", []), + # Invalid syntax + ("foo.Bar[", []), + ], +) +def test_extract_qualified_names(type_str: str, expected: list[str]) -> None: + """Test extract_qualified_names finds all fully qualified names.""" + assert extract_qualified_names(type_str) == expected From 6763c635742944f1f59764fbd3c8293241065a6b Mon Sep 17 00:00:00 2001 From: Koudai Aono Date: Tue, 30 Dec 2025 02:57:51 +0000 Subject: [PATCH 3/3] Add test for _get_python_type_flags to cover partial branch Add parametrized test with 25 cases covering: - Direct matches for special container types (Set, FrozenSet, Mapping, etc.) - Union types with special containers - Union types without special containers (completes loop without match) - Non-special container types This ensures 100% diff coverage for jsonschema.py line 1324. 
--- tests/parser/test_jsonschema.py | 42 +++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/tests/parser/test_jsonschema.py b/tests/parser/test_jsonschema.py index bab93fd57..4ff4d3f9e 100644 --- a/tests/parser/test_jsonschema.py +++ b/tests/parser/test_jsonschema.py @@ -1181,3 +1181,45 @@ def test_timestamp_with_time_zone_format() -> None: # Verify the format is mapped correctly assert json_schema_data_formats["string"]["timestamp with time zone"] == Types.date_time + + +@pytest.mark.parametrize( + ("x_python_type", "expected"), + [ + # Direct matches for special container types + ("Set[str]", {"is_set": True}), + ("set[int]", {"is_set": True}), + ("FrozenSet[int]", {"is_frozen_set": True}), + ("frozenset[str]", {"is_frozen_set": True}), + ("Sequence[str]", {"is_sequence": True}), + ("MutableSequence[int]", {"is_sequence": True}), + ("Mapping[str, int]", {"is_mapping": True}), + ("MutableMapping[str, int]", {"is_mapping": True}), + ("AbstractSet[str]", {"is_frozen_set": True}), + ("MutableSet[int]", {"is_set": True}), + # Union with special container type + ("Union[Set[str], None]", {"is_set": True}), + ("Optional[FrozenSet[int]]", {"is_frozen_set": True}), + ("Set[int] | None", {"is_set": True}), + ("Sequence[str] | int", {"is_sequence": True}), + # Union without special container type (loop completes without match) + ("Union[str, int]", {}), + ("str | int", {}), + ("Optional[str]", {}), + ("Union[str, int, float]", {}), + ("Union[List[str], None]", {}), # List is not a special container + ("Optional[Dict[str, int]]", {}), # Dict is not a special container + # Non-special container types + ("List[str]", {}), + ("Dict[str, int]", {}), + ("str", {}), + ("int", {}), + ("CustomType", {}), + ], +) +def test_get_python_type_flags(x_python_type: str, expected: dict[str, bool]) -> None: + """Test _get_python_type_flags extracts collection flags correctly.""" + parser = JsonSchemaParser("") + obj = model_validate(JsonSchemaObject, 
{"x-python-type": x_python_type}) + result = parser._get_python_type_flags(obj) + assert result == expected