Skip to content

Commit 6e4fb3a

Browse files
authored
Enhance TypeAlias handling to maintain model order with forward references (#2560)
1 parent 926021e commit 6e4fb3a

4 files changed

Lines changed: 105 additions & 7 deletions

File tree

src/datamodel_code_generator/parser/base.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
DataModelFieldBase,
4848
)
4949
from datamodel_code_generator.model.enum import Enum, Member
50+
from datamodel_code_generator.model.type_alias import TypeAliasBase
5051
from datamodel_code_generator.parser import DefaultPutDict, LiteralType
5152
from datamodel_code_generator.reference import ModelResolver, Reference
5253
from datamodel_code_generator.types import DataType, DataTypeManager, StrictTypes
@@ -1174,22 +1175,28 @@ def __sort_models(
11741175
models.sort(key=lambda x: x.class_name)
11751176

11761177
imported = {i for v in imports.values() for i in v}
1177-
model_class_name_baseclasses: dict[DataModel, tuple[str, set[str]]] = {}
1178+
model_class_name_refs: dict[DataModel, tuple[str, set[str]]] = {}
11781179
for model in models:
11791180
class_name = model.class_name
1180-
model_class_name_baseclasses[model] = (
1181-
class_name,
1182-
{b.type_hint for b in model.base_classes if b.reference} - {class_name},
1183-
)
1181+
base_class_refs = {b.type_hint for b in model.base_classes if b.reference}
1182+
if base_class_refs:
1183+
refs = base_class_refs - {class_name}
1184+
elif isinstance(model, TypeAliasBase):
1185+
refs = {
1186+
t.reference.short_name for f in model.fields for t in f.data_type.all_data_types if t.reference
1187+
} - {class_name}
1188+
else:
1189+
refs = set()
1190+
model_class_name_refs[model] = (class_name, refs)
11841191

11851192
changed: bool = True
11861193
while changed:
11871194
changed = False
11881195
resolved = imported.copy()
11891196
for i in range(len(models) - 1):
11901197
model = models[i]
1191-
class_name, baseclasses = model_class_name_baseclasses[model]
1192-
if not baseclasses - resolved:
1198+
class_name, refs = model_class_name_refs[model]
1199+
if not refs - resolved:
11931200
resolved.add(class_name)
11941201
continue
11951202
models[i], models[i + 1] = models[i + 1], model
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# generated by datamodel-codegen:
2+
# filename: type_alias_forward_ref.json
3+
# timestamp: 2019-07-26T00:00:00+00:00
4+
5+
from __future__ import annotations
6+
7+
from typing import Literal, TypedDict
8+
9+
from typing_extensions import TypeAliasType
10+
11+
12+
class BlobPart(TypedDict):
13+
type: Literal['blob']
14+
data: str
15+
16+
17+
FieldPlaceholder = TypeAliasType("FieldPlaceholder", None)
18+
19+
20+
class TextPart(TypedDict):
21+
type: Literal['text']
22+
content: str
23+
24+
25+
SystemInstructions = TypeAliasType("SystemInstructions", list[TextPart | BlobPart])
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
{
2+
"$schema": "http://json-schema.org/draft-07/schema#",
3+
"title": "_Placeholder",
4+
"type": "null",
5+
"$defs": {
6+
"TextPart": {
7+
"properties": {
8+
"type": {
9+
"const": "text",
10+
"type": "string"
11+
},
12+
"content": {
13+
"type": "string"
14+
}
15+
},
16+
"required": ["type", "content"],
17+
"title": "TextPart",
18+
"type": "object"
19+
},
20+
"BlobPart": {
21+
"properties": {
22+
"type": {
23+
"const": "blob",
24+
"type": "string"
25+
},
26+
"data": {
27+
"type": "string"
28+
}
29+
},
30+
"required": ["type", "data"],
31+
"title": "BlobPart",
32+
"type": "object"
33+
},
34+
"SystemInstructions": {
35+
"type": "array",
36+
"items": {
37+
"anyOf": [
38+
{"$ref": "#/$defs/TextPart"},
39+
{"$ref": "#/$defs/BlobPart"}
40+
]
41+
},
42+
"title": "SystemInstructions"
43+
}
44+
}
45+
}

tests/main/jsonschema/test_main_jsonschema.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,27 @@ def test_main_inheritance_forward_ref_keep_model_order(output_file: Path, tmp_pa
6464
)
6565

6666

67+
@pytest.mark.benchmark
68+
def test_main_type_alias_forward_ref_keep_model_order(output_file: Path) -> None:
69+
"""Test TypeAliasType with forward references keeping model order."""
70+
run_main_and_assert(
71+
input_path=JSON_SCHEMA_DATA_PATH / "type_alias_forward_ref.json",
72+
output_path=output_file,
73+
input_file_type=None,
74+
assert_func=assert_file_content,
75+
extra_args=[
76+
"--keep-model-order",
77+
"--output-model-type",
78+
"typing.TypedDict",
79+
"--use-standard-collections",
80+
"--use-union-operator",
81+
"--use-type-alias",
82+
"--target-python-version",
83+
"3.10",
84+
],
85+
)
86+
87+
6788
@pytest.mark.skip(reason="pytest-xdist does not support the test")
6889
def test_main_without_arguments() -> None:
6990
"""Test main function without arguments raises SystemExit."""

0 commit comments

Comments
 (0)