Skip to content

Commit b9bee27

Browse files
authored
Add AST-based type string parsing helpers (#2856)
* Add AST-based type string parsing helpers Add three helper functions to types.py for robust AST-based parsing of Python type annotation strings: - get_type_base_name(): Extract base type name (e.g., "List[str]" -> "List") - get_subscript_args(): Extract type arguments (e.g., "Dict[str, int]" -> ["str", "int"]) - extract_qualified_names(): Extract fully qualified names for import handling Refactor jsonschema.py to use these helpers: - _get_python_type_flags now uses get_type_base_name and get_subscript_args - _get_python_type_base now uses get_type_base_name - Added support for union operator (|) syntax in type flag detection This provides a solid foundation for handling x-python-type qualified name imports in a follow-up PR. * Add tests for AST-based type string parsing helpers Add comprehensive tests for the three new helper functions: - get_type_base_name: 14 test cases - get_subscript_args: 18 test cases - extract_qualified_names: 20 test cases Achieves 100% diff coverage for the new code. * Add test for _get_python_type_flags to cover partial branch Add parametrized test with 25 cases covering: - Direct matches for special container types (Set, FrozenSet, Mapping, etc.) - Union types with special containers - Union types without special containers (completes loop without match) - Non-special container types This ensures 100% diff coverage for jsonschema.py line 1324.
1 parent 7a4709c commit b9bee27

4 files changed

Lines changed: 262 additions & 29 deletions

File tree

src/datamodel_code_generator/parser/jsonschema.py

Lines changed: 9 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@
7676
StrictTypes,
7777
Types,
7878
UnionIntFloat,
79+
get_subscript_args,
80+
get_type_base_name,
7981
)
8082
from datamodel_code_generator.util import (
8183
BaseModel,
@@ -1314,42 +1316,21 @@ class decorator which does not preserve staticmethod descriptors.
13141316
"MutableSet": {"is_set": True},
13151317
}
13161318

1317-
base_type = x_python_type.split("[")[0].strip()
1319+
base_type = get_type_base_name(x_python_type)
13181320
if base_type in type_to_flag:
13191321
return type_to_flag[base_type]
13201322

1321-
if base_type in {"Union", "Optional"}:
1322-
bracket_start = x_python_type.find("[")
1323-
if bracket_start != -1:
1324-
inner = x_python_type[bracket_start + 1 : -1]
1325-
depth = 0
1326-
current = ""
1327-
for char in inner:
1328-
if char == "[":
1329-
depth += 1
1330-
elif char == "]":
1331-
depth -= 1
1332-
if char == "," and depth == 0:
1333-
arg_base = current.strip().split("[")[0]
1334-
if arg_base in type_to_flag:
1335-
return type_to_flag[arg_base]
1336-
current = ""
1337-
else:
1338-
current += char
1339-
if current.strip():
1340-
arg_base = current.strip().split("[")[0]
1341-
if arg_base in type_to_flag:
1342-
return type_to_flag[arg_base]
1323+
if base_type in {"Union", "Optional"} or " | " in x_python_type:
1324+
for arg in get_subscript_args(x_python_type):
1325+
arg_base = get_type_base_name(arg)
1326+
if arg_base in type_to_flag:
1327+
return type_to_flag[arg_base]
13431328

13441329
return {}
13451330

13461331
def _get_python_type_base(self, python_type: str) -> str: # noqa: PLR6301
13471332
"""Extract base type from a Python type annotation string."""
1348-
if "." in python_type.split("[", maxsplit=1)[0]:
1349-
base = python_type.split("[", maxsplit=1)[0].rsplit(".", 1)[-1]
1350-
else:
1351-
base = python_type.split("[", maxsplit=1)[0].strip()
1352-
return base
1333+
return get_type_base_name(python_type)
13531334

13541335
def _is_compatible_python_type(self, schema_type: str | None, python_type: str) -> bool:
13551336
"""Check if x-python-type is compatible with the JSON Schema type."""

src/datamodel_code_generator/types.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from __future__ import annotations
99

10+
import ast
1011
import re
1112
from abc import ABC, abstractmethod
1213
from copy import deepcopy
@@ -191,6 +192,108 @@ def chain_as_tuple(*iterables: Iterable[T]) -> tuple[T, ...]:
191192
return tuple(chain(*iterables))
192193

193194

195+
def get_type_base_name(type_str: str) -> str:
196+
"""Extract base type name from a type annotation string using AST.
197+
198+
Examples:
199+
"List[str]" -> "List"
200+
"foo.bar.Baz" -> "Baz"
201+
"Optional[int]" -> "Optional"
202+
"""
203+
try:
204+
tree = ast.parse(type_str, mode="eval")
205+
except SyntaxError:
206+
return type_str.split("[", maxsplit=1)[0].rsplit(".", 1)[-1].strip()
207+
208+
body = tree.body
209+
if isinstance(body, ast.Subscript):
210+
body = body.value
211+
212+
if isinstance(body, ast.Attribute):
213+
return body.attr
214+
if isinstance(body, ast.Name):
215+
return body.id
216+
return type_str.split("[", maxsplit=1)[0].rsplit(".", 1)[-1].strip()
217+
218+
219+
def get_subscript_args(type_str: str) -> list[str]:
220+
"""Extract type arguments from a subscripted type using AST.
221+
222+
Examples:
223+
"List[str]" -> ["str"]
224+
"Dict[str, int]" -> ["str", "int"]
225+
"Union[str, int, None]" -> ["str", "int", "None"]
226+
"str | int | None" -> ["str", "int", "None"]
227+
"str" -> []
228+
"""
229+
try:
230+
tree = ast.parse(type_str, mode="eval")
231+
except SyntaxError:
232+
return []
233+
234+
body = tree.body
235+
236+
if isinstance(body, ast.BinOp) and isinstance(body.op, ast.BitOr):
237+
args: list[str] = []
238+
239+
def collect_union_args(node: ast.expr) -> None:
240+
if isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr):
241+
collect_union_args(node.left)
242+
collect_union_args(node.right)
243+
else:
244+
args.append(ast.unparse(node))
245+
246+
collect_union_args(body)
247+
return args
248+
249+
if isinstance(body, ast.Subscript):
250+
slice_node = body.slice
251+
if isinstance(slice_node, ast.Tuple):
252+
return [ast.unparse(elt) for elt in slice_node.elts]
253+
return [ast.unparse(slice_node)]
254+
255+
return []
256+
257+
258+
def extract_qualified_names(type_str: str) -> list[str]:
259+
"""Extract all fully qualified names from a type annotation string using AST.
260+
261+
Finds patterns like 'module.path.ClassName' where the name contains dots.
262+
263+
Examples:
264+
"type[foo.bar.Baz]" -> ["foo.bar.Baz"]
265+
"Dict[a.B, c.D]" -> ["a.B", "c.D"]
266+
"str" -> []
267+
"""
268+
try:
269+
tree = ast.parse(type_str, mode="eval")
270+
except SyntaxError:
271+
return []
272+
273+
qualified_names: list[str] = []
274+
visited: set[int] = set()
275+
276+
def get_full_name(node: ast.expr) -> str | None:
277+
parts: list[str] = []
278+
current: ast.expr = node
279+
while isinstance(current, ast.Attribute):
280+
visited.add(id(current))
281+
parts.append(current.attr)
282+
current = current.value
283+
if isinstance(current, ast.Name):
284+
parts.append(current.id)
285+
return ".".join(reversed(parts))
286+
return None
287+
288+
for node in ast.walk(tree):
289+
if isinstance(node, ast.Attribute) and id(node) not in visited:
290+
name = get_full_name(node)
291+
if name and "." in name:
292+
qualified_names.append(name)
293+
294+
return qualified_names
295+
296+
194297
def _remove_none_from_union(type_: str, *, use_union_operator: bool) -> str: # noqa: PLR0912
195298
"""Remove None from a Union type string, handling nested unions."""
196299
if use_union_operator:

tests/parser/test_jsonschema.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1181,3 +1181,45 @@ def test_timestamp_with_time_zone_format() -> None:
11811181

11821182
# Verify the format is mapped correctly
11831183
assert json_schema_data_formats["string"]["timestamp with time zone"] == Types.date_time
1184+
1185+
1186+
@pytest.mark.parametrize(
1187+
("x_python_type", "expected"),
1188+
[
1189+
# Direct matches for special container types
1190+
("Set[str]", {"is_set": True}),
1191+
("set[int]", {"is_set": True}),
1192+
("FrozenSet[int]", {"is_frozen_set": True}),
1193+
("frozenset[str]", {"is_frozen_set": True}),
1194+
("Sequence[str]", {"is_sequence": True}),
1195+
("MutableSequence[int]", {"is_sequence": True}),
1196+
("Mapping[str, int]", {"is_mapping": True}),
1197+
("MutableMapping[str, int]", {"is_mapping": True}),
1198+
("AbstractSet[str]", {"is_frozen_set": True}),
1199+
("MutableSet[int]", {"is_set": True}),
1200+
# Union with special container type
1201+
("Union[Set[str], None]", {"is_set": True}),
1202+
("Optional[FrozenSet[int]]", {"is_frozen_set": True}),
1203+
("Set[int] | None", {"is_set": True}),
1204+
("Sequence[str] | int", {"is_sequence": True}),
1205+
# Union without special container type (loop completes without match)
1206+
("Union[str, int]", {}),
1207+
("str | int", {}),
1208+
("Optional[str]", {}),
1209+
("Union[str, int, float]", {}),
1210+
("Union[List[str], None]", {}), # List is not a special container
1211+
("Optional[Dict[str, int]]", {}), # Dict is not a special container
1212+
# Non-special container types
1213+
("List[str]", {}),
1214+
("Dict[str, int]", {}),
1215+
("str", {}),
1216+
("int", {}),
1217+
("CustomType", {}),
1218+
],
1219+
)
1220+
def test_get_python_type_flags(x_python_type: str, expected: dict[str, bool]) -> None:
1221+
"""Test _get_python_type_flags extracts collection flags correctly."""
1222+
parser = JsonSchemaParser("")
1223+
obj = model_validate(JsonSchemaObject, {"x-python-type": x_python_type})
1224+
result = parser._get_python_type_flags(obj)
1225+
assert result == expected

tests/test_types.py

Lines changed: 108 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,13 @@
44

55
import pytest
66

7-
from datamodel_code_generator.types import _remove_none_from_union, get_optional_type
7+
from datamodel_code_generator.types import (
8+
_remove_none_from_union,
9+
extract_qualified_names,
10+
get_optional_type,
11+
get_subscript_args,
12+
get_type_base_name,
13+
)
814

915

1016
@pytest.mark.parametrize(
@@ -273,3 +279,104 @@ def test_datatype_deepcopy_memo_cache_hit() -> None:
273279
# Second call with same memo - should return cached object (covers memo hit branch)
274280
copied2 = data_type.__deepcopy__(memo) # noqa: PLC2801
275281
assert copied2 is copied1 # Same object from memo
282+
283+
284+
@pytest.mark.parametrize(
285+
("type_str", "expected"),
286+
[
287+
# Simple types
288+
("str", "str"),
289+
("int", "int"),
290+
("List", "List"),
291+
# Subscripted types
292+
("List[str]", "List"),
293+
("Dict[str, int]", "Dict"),
294+
("Optional[int]", "Optional"),
295+
("Union[str, int]", "Union"),
296+
# Qualified names
297+
("foo.bar.Baz", "Baz"),
298+
("datamodel_code_generator.model.base.DataModel", "DataModel"),
299+
# Subscripted with qualified names
300+
("type[foo.bar.Baz]", "type"),
301+
("List[foo.Bar]", "List"),
302+
# Invalid syntax (fallback to string parsing)
303+
("List[", "List"),
304+
("[invalid", ""), # splits on "[" giving empty string
305+
],
306+
)
307+
def test_get_type_base_name(type_str: str, expected: str) -> None:
308+
"""Test get_type_base_name extracts base type correctly."""
309+
assert get_type_base_name(type_str) == expected
310+
311+
312+
@pytest.mark.parametrize(
313+
("type_str", "expected"),
314+
[
315+
# Simple types (no subscript)
316+
("str", []),
317+
("int", []),
318+
# Single argument
319+
("List[str]", ["str"]),
320+
("Optional[int]", ["int"]),
321+
("type[Foo]", ["Foo"]),
322+
# Multiple arguments
323+
("Dict[str, int]", ["str", "int"]),
324+
("Union[str, int, None]", ["str", "int", "None"]),
325+
("Tuple[int, str, float]", ["int", "str", "float"]),
326+
# Union operator syntax
327+
("str | int", ["str", "int"]),
328+
("str | int | None", ["str", "int", "None"]),
329+
("List[str] | None", ["List[str]", "None"]),
330+
# Complex nested types
331+
("Dict[str, List[int]]", ["str", "List[int]"]),
332+
("Union[List[str], Dict[str, int]]", ["List[str]", "Dict[str, int]"]),
333+
# Qualified names in arguments
334+
("type[foo.bar.Baz]", ["foo.bar.Baz"]),
335+
("Dict[a.B, c.D]", ["a.B", "c.D"]),
336+
# Invalid syntax
337+
("List[", []),
338+
("[invalid", []),
339+
],
340+
)
341+
def test_get_subscript_args(type_str: str, expected: list[str]) -> None:
342+
"""Test get_subscript_args extracts type arguments correctly."""
343+
assert get_subscript_args(type_str) == expected
344+
345+
346+
@pytest.mark.parametrize(
347+
("type_str", "expected"),
348+
[
349+
# No qualified names
350+
("str", []),
351+
("List[str]", []),
352+
("Union[str, int]", []),
353+
# Single qualified name
354+
("foo.Bar", ["foo.Bar"]),
355+
("foo.bar.Baz", ["foo.bar.Baz"]),
356+
("datamodel_code_generator.model.base.DataModel", ["datamodel_code_generator.model.base.DataModel"]),
357+
# Qualified names in subscript
358+
("type[foo.bar.Baz]", ["foo.bar.Baz"]),
359+
("List[foo.Bar]", ["foo.Bar"]),
360+
("Optional[a.b.C]", ["a.b.C"]),
361+
# Multiple qualified names
362+
("Dict[a.B, c.D]", ["a.B", "c.D"]),
363+
("Union[foo.Bar, baz.Qux]", ["foo.Bar", "baz.Qux"]),
364+
# Mixed with simple types
365+
("Dict[str, foo.Bar]", ["foo.Bar"]),
366+
("Union[int, a.B, None]", ["a.B"]),
367+
# Union operator syntax
368+
("foo.Bar | None", ["foo.Bar"]),
369+
("a.B | c.D", ["a.B", "c.D"]),
370+
# Complex nested
371+
("Dict[str, List[foo.Bar]]", ["foo.Bar"]),
372+
("type[datamodel_code_generator.types.DataTypeManager]", ["datamodel_code_generator.types.DataTypeManager"]),
373+
# Attribute on non-Name (function call result) - should not extract
374+
("foo().bar", []),
375+
("func().attr.name", []),
376+
# Invalid syntax
377+
("foo.Bar[", []),
378+
],
379+
)
380+
def test_extract_qualified_names(type_str: str, expected: list[str]) -> None:
381+
"""Test extract_qualified_names finds all fully qualified names."""
382+
assert extract_qualified_names(type_str) == expected

0 commit comments

Comments
 (0)