Skip to content

Commit b84474f

Browse files
authored
Fix RootModel generation order to define referenced types first (#2592)
* Fix RootModel ordering and add tests for model reference handling * Refactor RootModel classes to include 'friends' attribute and ensure model rebuilding
1 parent 96506f6 commit b84474f

5 files changed

Lines changed: 249 additions & 6 deletions

File tree

src/datamodel_code_generator/parser/base.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -208,11 +208,19 @@ def sort_data_models( # noqa: PLR0912, PLR0915
208208
ordered_models: list[tuple[int, DataModel]] = []
209209
unresolved_reference_model_names = [m.path for m in unresolved_references]
210210
for model in unresolved_references:
211-
indexes = [
212-
unresolved_reference_model_names.index(b.reference.path)
213-
for b in model.base_classes
214-
if b.reference and b.reference.path in unresolved_reference_model_names
215-
]
211+
if isinstance(model, pydantic_model_v2.RootModel):
212+
indexes = [
213+
unresolved_reference_model_names.index(ref_path)
214+
for f in model.fields
215+
for t in f.data_type.all_data_types
216+
if t.reference and (ref_path := t.reference.path) in unresolved_reference_model_names
217+
]
218+
else:
219+
indexes = [
220+
unresolved_reference_model_names.index(b.reference.path)
221+
for b in model.base_classes
222+
if b.reference and b.reference.path in unresolved_reference_model_names
223+
]
216224
if indexes:
217225
ordered_models.append((
218226
max(indexes),
@@ -1389,7 +1397,7 @@ def __sort_models(
13891397
base_class_refs = {b.type_hint for b in model.base_classes if b.reference}
13901398
if base_class_refs:
13911399
refs = base_class_refs - {class_name}
1392-
elif isinstance(model, TypeAliasBase):
1400+
elif isinstance(model, (TypeAliasBase, pydantic_model_v2.RootModel)):
13931401
refs = {
13941402
t.reference.short_name for f in model.fields for t in f.data_type.all_data_types if t.reference
13951403
} - {class_name}
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# generated by datamodel-codegen:
2+
# filename: root_model_ordering.json
3+
# timestamp: 2019-07-26T00:00:00+00:00
4+
5+
from __future__ import annotations
6+
7+
from typing import List, Literal, Optional, Union
8+
9+
from pydantic import BaseModel, Field, RootModel
10+
11+
12+
class Zoo(BaseModel):
13+
animals: Optional[List[Animals]] = Field([], title='Animals')
14+
15+
16+
class Dog(BaseModel):
17+
name: Literal['dog'] = Field('dog', title='woof')
18+
friends: Optional[List[Friends]] = Field([], title='Friends')
19+
20+
21+
class Cat(BaseModel):
22+
name: Literal['cat'] = Field('cat', title='meow')
23+
friends: Optional[List[Friends]] = Field([], title='Friends')
24+
25+
26+
class Bird(BaseModel):
27+
name: Literal['bird'] = Field('bird', title='chirp')
28+
friends: Optional[List[Friends]] = Field([], title='Friends')
29+
30+
31+
class Animals(RootModel[Union[Dog, Cat, Bird]]):
32+
root: Union[Dog, Cat, Bird] = Field(..., discriminator='name', title='Animal')
33+
34+
35+
class Friends(RootModel[Union[Dog, Cat, Bird]]):
36+
root: Union[Dog, Cat, Bird] = Field(..., discriminator='name', title='Animal')
37+
38+
39+
Zoo.model_rebuild()
40+
Dog.model_rebuild()
41+
Cat.model_rebuild()
42+
Bird.model_rebuild()
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# generated by datamodel-codegen:
2+
# filename: root_model_ordering.json
3+
# timestamp: 2019-07-26T00:00:00+00:00
4+
5+
from __future__ import annotations
6+
7+
from typing import List, Literal, Optional, Union
8+
9+
from pydantic import BaseModel, Field, RootModel
10+
11+
12+
class Bird(BaseModel):
13+
name: Literal['bird'] = Field('bird', title='chirp')
14+
friends: Optional[List[Friends]] = Field([], title='Friends')
15+
16+
17+
class Cat(BaseModel):
18+
name: Literal['cat'] = Field('cat', title='meow')
19+
friends: Optional[List[Friends]] = Field([], title='Friends')
20+
21+
22+
class Dog(BaseModel):
23+
name: Literal['dog'] = Field('dog', title='woof')
24+
friends: Optional[List[Friends]] = Field([], title='Friends')
25+
26+
27+
class Friends(RootModel[Union[Dog, Cat, Bird]]):
28+
root: Union[Dog, Cat, Bird] = Field(..., discriminator='name', title='Animal')
29+
30+
31+
class Zoo(BaseModel):
32+
animals: Optional[List[Animals]] = Field([], title='Animals')
33+
34+
35+
class Animals(RootModel[Union[Dog, Cat, Bird]]):
36+
root: Union[Dog, Cat, Bird] = Field(..., discriminator='name', title='Animal')
37+
38+
39+
Bird.model_rebuild()
40+
Cat.model_rebuild()
41+
Dog.model_rebuild()
42+
Zoo.model_rebuild()
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
{
2+
"$defs": {
3+
"Dog": {
4+
"properties": {
5+
"name": {
6+
"const": "dog",
7+
"default": "dog",
8+
"title": "woof",
9+
"type": "string"
10+
},
11+
"friends": {
12+
"items": {
13+
"discriminator": {
14+
"mapping": {
15+
"bird": "#/$defs/Bird",
16+
"cat": "#/$defs/Cat",
17+
"dog": "#/$defs/Dog"
18+
},
19+
"propertyName": "name"
20+
},
21+
"oneOf": [
22+
{"$ref": "#/$defs/Dog"},
23+
{"$ref": "#/$defs/Cat"},
24+
{"$ref": "#/$defs/Bird"}
25+
],
26+
"title": "Animal"
27+
},
28+
"title": "Friends",
29+
"type": "array",
30+
"default": []
31+
}
32+
},
33+
"title": "Dog",
34+
"type": "object"
35+
},
36+
"Cat": {
37+
"properties": {
38+
"name": {
39+
"const": "cat",
40+
"default": "cat",
41+
"title": "meow",
42+
"type": "string"
43+
},
44+
"friends": {
45+
"items": {
46+
"discriminator": {
47+
"mapping": {
48+
"bird": "#/$defs/Bird",
49+
"cat": "#/$defs/Cat",
50+
"dog": "#/$defs/Dog"
51+
},
52+
"propertyName": "name"
53+
},
54+
"oneOf": [
55+
{"$ref": "#/$defs/Dog"},
56+
{"$ref": "#/$defs/Cat"},
57+
{"$ref": "#/$defs/Bird"}
58+
],
59+
"title": "Animal"
60+
},
61+
"title": "Friends",
62+
"type": "array",
63+
"default": []
64+
}
65+
},
66+
"title": "Cat",
67+
"type": "object"
68+
},
69+
"Bird": {
70+
"properties": {
71+
"name": {
72+
"const": "bird",
73+
"default": "bird",
74+
"title": "chirp",
75+
"type": "string"
76+
},
77+
"friends": {
78+
"items": {
79+
"discriminator": {
80+
"mapping": {
81+
"bird": "#/$defs/Bird",
82+
"cat": "#/$defs/Cat",
83+
"dog": "#/$defs/Dog"
84+
},
85+
"propertyName": "name"
86+
},
87+
"oneOf": [
88+
{"$ref": "#/$defs/Dog"},
89+
{"$ref": "#/$defs/Cat"},
90+
{"$ref": "#/$defs/Bird"}
91+
],
92+
"title": "Animal"
93+
},
94+
"title": "Friends",
95+
"type": "array",
96+
"default": []
97+
}
98+
},
99+
"title": "Bird",
100+
"type": "object"
101+
}
102+
},
103+
"properties": {
104+
"animals": {
105+
"default": [],
106+
"items": {
107+
"discriminator": {
108+
"mapping": {
109+
"bird": "#/$defs/Bird",
110+
"cat": "#/$defs/Cat",
111+
"dog": "#/$defs/Dog"
112+
},
113+
"propertyName": "name"
114+
},
115+
"oneOf": [
116+
{"$ref": "#/$defs/Dog"},
117+
{"$ref": "#/$defs/Cat"},
118+
{"$ref": "#/$defs/Bird"}
119+
],
120+
"title": "Animal"
121+
},
122+
"title": "Animals",
123+
"type": "array"
124+
}
125+
},
126+
"title": "Zoo",
127+
"type": "object"
128+
}

tests/main/jsonschema/test_main_jsonschema.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1505,6 +1505,29 @@ def test_main_jsonschema_combine_any_of_object(
15051505
)
15061506

15071507

1508+
@pytest.mark.benchmark
1509+
@pytest.mark.parametrize(
1510+
("extra_args", "expected_file"),
1511+
[
1512+
(["--output-model", "pydantic_v2.BaseModel"], "jsonschema_root_model_ordering.py"),
1513+
(
1514+
["--output-model", "pydantic_v2.BaseModel", "--keep-model-order"],
1515+
"jsonschema_root_model_ordering_keep_model_order.py",
1516+
),
1517+
],
1518+
)
1519+
def test_main_jsonschema_root_model_ordering(output_file: Path, extra_args: list[str], expected_file: str) -> None:
1520+
"""Test RootModel is ordered after the types it references."""
1521+
run_main_and_assert(
1522+
input_path=JSON_SCHEMA_DATA_PATH / "root_model_ordering.json",
1523+
output_path=output_file,
1524+
input_file_type="jsonschema",
1525+
assert_func=assert_file_content,
1526+
expected_file=expected_file,
1527+
extra_args=extra_args,
1528+
)
1529+
1530+
15081531
@pytest.mark.benchmark
15091532
def test_main_jsonschema_field_include_all_keys(output_file: Path) -> None:
15101533
"""Test field generation including all keys."""

0 commit comments

Comments
 (0)