Skip to content

Commit b7ef8c7

Browse files
raj-openkoxudaxi
andauthored
Bugfix-TypeAlias-Rebuild ---> Main (#2566)
* bugfix-type-alias-rebuild > main: added linting exceptions for docstrings * bugfix-type-alias-rebuild > main: added property to indicate if TypeAlias NOTE: we could also bypass this and just check the instance of the model, but this is more intentional. * bugfix-type-alias-rebuild > main: do not add model path to update list if model is a type alias * bugfix-type-alias-rebuild > main: added tests for coverage * bugfix-type-alias-rebuild > main: added changes from review * bugfix-type-alias-rebuild > main: added missing case for coverage * bugfix-type-alias-rebuild > main: removed type declaration for QA * bugfix-type-alias-rebuild > main: added e2e test scenarios * bugfix-type-alias-rebuild > main: added e2e test cases * bugfix-type-alias-rebuild > main: commented out case as recursive types not yet corrected for py311 * bugfix-type-alias-rebuild > main: linting * Revert "bugfix-type-alias-rebuild > main: added linting exceptions for docstrings" This reverts commit 4ba7085. * bugfix-type-alias-rebuild > main: linting * bugfix-type-alias-rebuild > main: linting --------- Co-authored-by: raj-open <raj-open@users.noreply.github.com> Co-authored-by: Koudai Aono <koxudaxi@gmail.com>
1 parent 4c27f8d commit b7ef8c7

8 files changed

Lines changed: 264 additions & 5 deletions

File tree

src/datamodel_code_generator/model/base.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,7 @@ class DataModel(TemplateBase, Nullable, ABC):
340340
TEMPLATE_FILE_PATH: ClassVar[str] = ""
341341
BASE_CLASS: ClassVar[str] = ""
342342
DEFAULT_IMPORTS: ClassVar[tuple[Import, ...]] = ()
343+
IS_ALIAS: bool = False
343344

344345
def __init__( # noqa: PLR0913
345346
self,
@@ -523,6 +524,11 @@ def all_data_types(self) -> Iterator[DataType]:
523524
yield from field.data_type.all_data_types
524525
yield from self.base_classes
525526

527+
@property
528+
def is_alias(self) -> bool:
529+
"""Whether is a type alias (i.e. not an instance of BaseModel/RootModel)."""
530+
return self.IS_ALIAS
531+
526532
@property
527533
def nullable(self) -> bool:
528534
"""Check if this model is nullable."""

src/datamodel_code_generator/model/type_alias.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
class TypeAliasBase(DataModel):
2323
"""Base class for all type alias implementations."""
2424

25+
IS_ALIAS: bool = True
26+
2527
@property
2628
def imports(self) -> tuple[Import, ...]:
2729
"""Get imports including Annotated if needed."""

src/datamodel_code_generator/parser/base.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,28 @@ def dump_templates(templates: list[DataModel]) -> str:
132132
MAX_RECURSION_COUNT: int = sys.getrecursionlimit()
133133

134134

135+
def add_model_path_to_list(
136+
paths: list[str] | None,
137+
model: DataModel,
138+
/,
139+
) -> list[str]:
140+
"""
141+
Auxiliary method which adds model path to list, provided the following hold.
142+
143+
- model is not a type alias
144+
- path is not already in the list.
145+
146+
"""
147+
if paths is None:
148+
paths = []
149+
if model.is_alias:
150+
return paths
151+
if (path := model.path) in paths:
152+
return paths
153+
paths.append(path)
154+
return paths
155+
156+
135157
def sort_data_models( # noqa: PLR0912, PLR0915
136158
unsorted_data_models: list[DataModel],
137159
sorted_data_models: SortedDataModels | None = None,
@@ -151,13 +173,13 @@ def sort_data_models( # noqa: PLR0912, PLR0915
151173
sorted_data_models[model.path] = model
152174
elif model.path in model.reference_classes and len(model.reference_classes) == 1: # only self-referencing
153175
sorted_data_models[model.path] = model
154-
require_update_action_models.append(model.path)
176+
add_model_path_to_list(require_update_action_models, model)
155177
elif (
156178
not model.reference_classes - {model.path} - set(sorted_data_models)
157179
): # reference classes have been resolved
158180
sorted_data_models[model.path] = model
159181
if model.path in model.reference_classes:
160-
require_update_action_models.append(model.path)
182+
add_model_path_to_list(require_update_action_models, model)
161183
else:
162184
unresolved_references.append(model)
163185
if unresolved_references:
@@ -206,11 +228,11 @@ def sort_data_models( # noqa: PLR0912, PLR0915
206228
if not unresolved_model:
207229
sorted_data_models[model.path] = model
208230
if update_action_parent:
209-
require_update_action_models.append(model.path)
231+
add_model_path_to_list(require_update_action_models, model)
210232
continue
211233
if not unresolved_model - unsorted_data_model_names:
212234
sorted_data_models[model.path] = model
213-
require_update_action_models.append(model.path)
235+
add_model_path_to_list(require_update_action_models, model)
214236
continue
215237
# unresolved
216238
unresolved_classes = ", ".join(
@@ -1003,7 +1025,7 @@ def __reuse_model(self, models: list[DataModel], require_update_action_models: l
10031025
custom_template_dir=model._custom_template_dir, # noqa: SLF001
10041026
)
10051027
if cached_model_reference.path in require_update_action_models:
1006-
require_update_action_models.append(inherited_model.path)
1028+
add_model_path_to_list(require_update_action_models, inherited_model)
10071029
models.insert(index, inherited_model)
10081030
models.remove(model)
10091031

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# generated by datamodel-codegen:
2+
# filename: type_alias_recursive.yaml
3+
# timestamp: 2019-07-26T00:00:00+00:00
4+
5+
from __future__ import annotations
6+
7+
from typing import Dict, List, Optional, Union
8+
9+
from pydantic import BaseModel
10+
from typing_extensions import TypeAlias
11+
12+
13+
class File(BaseModel):
14+
path: str
15+
16+
17+
class Folder(BaseModel):
18+
address: Optional[str] = None
19+
files: List[File]
20+
subfolders: Optional[List[Folder]] = None
21+
22+
23+
ElementaryType: TypeAlias = Optional[Union[bool, str, int, float]]
24+
25+
26+
JsonType: TypeAlias = Union[ElementaryType, List["JsonType"], Dict[str, "JsonType"]]
27+
28+
29+
class Space(BaseModel):
30+
label: Optional[str] = None
31+
data: Optional[JsonType] = None
32+
dual: Optional[DualSpace] = None
33+
34+
35+
class DualSpace(BaseModel):
36+
label: Optional[str] = None
37+
data: Optional[JsonType] = None
38+
predual: Optional[Space] = None
39+
40+
41+
Folder.update_forward_refs()
42+
Space.update_forward_refs()
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# generated by datamodel-codegen:
2+
# filename: type_alias_recursive.yaml
3+
# timestamp: 2019-07-26T00:00:00+00:00
4+
5+
from __future__ import annotations
6+
7+
from pydantic import BaseModel
8+
9+
10+
class File(BaseModel):
11+
path: str
12+
13+
14+
class Folder(BaseModel):
15+
address: str | None = None
16+
files: list[File]
17+
subfolders: list[Folder] | None = None
18+
19+
20+
type ElementaryType = bool | str | int | float | None
21+
22+
23+
type JsonType = ElementaryType | list[JsonType] | dict[str, JsonType]
24+
25+
26+
class Space(BaseModel):
27+
label: str | None = None
28+
data: JsonType | None = None
29+
dual: DualSpace | None = None
30+
31+
32+
class DualSpace(BaseModel):
33+
label: str | None = None
34+
data: JsonType | None = None
35+
predual: Space | None = None
36+
37+
38+
Folder.model_rebuild()
39+
Space.model_rebuild()
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
openapi: "3.0.0"
2+
info:
3+
version: 1.0.0
4+
title: TypeAlias Test
5+
description: |-
6+
Test cases for recursive models
7+
components:
8+
schemas:
9+
10+
# ----------------------------------------------------------------
11+
# Simple self-referential recursive types
12+
# ----------------------------------------------------------------
13+
14+
File:
15+
type: object
16+
required:
17+
- path
18+
properties:
19+
path:
20+
type: string
21+
22+
Folder:
23+
type: object
24+
required:
25+
- files
26+
- folders
27+
properties:
28+
address:
29+
type: string
30+
files:
31+
type: array
32+
items:
33+
$ref: "#/components/schemas/File"
34+
subfolders:
35+
type: array
36+
items:
37+
$ref: "#/components/schemas/Folder"
38+
39+
ElementaryType:
40+
nullable: true
41+
oneOf:
42+
- type: boolean
43+
- type: string
44+
- type: integer
45+
- type: number
46+
47+
JsonType:
48+
oneOf:
49+
- $ref: "#/components/schemas/ElementaryType"
50+
- type: array
51+
items:
52+
$ref: "#/components/schemas/JsonType"
53+
- type: object
54+
additionalProperties:
55+
$ref: "#/components/schemas/JsonType"
56+
57+
# ----------------------------------------------------------------
58+
# Binary recursive types
59+
# ----------------------------------------------------------------
60+
61+
Space:
62+
type: object
63+
properties:
64+
label:
65+
type: string
66+
data:
67+
$ref: "#/components/schemas/JsonType"
68+
dual:
69+
$ref: "#/components/schemas/DualSpace"
70+
71+
DualSpace:
72+
type: object
73+
properties:
74+
label:
75+
type: string
76+
data:
77+
$ref: "#/components/schemas/JsonType"
78+
predual:
79+
$ref: "#/components/schemas/Space"

tests/main/openapi/test_main_openapi.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2130,6 +2130,34 @@ def test_main_openapi_type_alias_py312(output_file: Path) -> None:
21302130
)
21312131

21322132

2133+
@pytest.mark.skipif(
2134+
int(black.__version__.split(".")[0]) < 23,
2135+
reason="Installed black doesn't support the new 'type' statement",
2136+
)
2137+
def test_main_openapi_type_alias_recursive_py312(output_file: Path) -> None:
2138+
"""
2139+
Test that handling of type aliases work as expected for recursive types.
2140+
2141+
NOTE: applied to python 3.12--14
2142+
"""
2143+
run_main_and_assert(
2144+
input_path=OPEN_API_DATA_PATH / "type_alias_recursive.yaml",
2145+
output_path=output_file,
2146+
input_file_type="openapi",
2147+
assert_func=assert_file_content,
2148+
expected_file="type_alias_recursive_py312.py",
2149+
extra_args=[
2150+
"--use-type-alias",
2151+
"--target-python-version",
2152+
"3.12",
2153+
"--use-standard-collections",
2154+
"--use-union-operator",
2155+
"--output-model-type",
2156+
"pydantic_v2.BaseModel",
2157+
],
2158+
)
2159+
2160+
21332161
def test_main_openapi_byte_format(output_file: Path) -> None:
21342162
"""Test OpenAPI generation with byte format."""
21352163
run_main_and_assert(

tests/parser/test_base.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,10 @@
99

1010
from datamodel_code_generator.model import DataModel, DataModelFieldBase
1111
from datamodel_code_generator.model.pydantic import BaseModel, DataModelField
12+
from datamodel_code_generator.model.type_alias import TypeAlias, TypeAliasBackport, TypeAliasTypeBackport, TypeStatement
1213
from datamodel_code_generator.parser.base import (
1314
Parser,
15+
add_model_path_to_list,
1416
escape_characters,
1517
exact_import,
1618
relative,
@@ -54,6 +56,45 @@ def test_parser() -> None:
5456
assert c.base_class == "Base"
5557

5658

59+
def test_add_model_path_to_list() -> None:
60+
"""Test method which adds model paths to "update" list."""
61+
reference_1 = Reference(path="Base1", original_name="A", name="A")
62+
reference_2 = Reference(path="Alias2", original_name="B", name="B")
63+
reference_3 = Reference(path="Alias3", original_name="B", name="B")
64+
reference_4 = Reference(path="Alias4", original_name="B", name="B")
65+
reference_5 = Reference(path="Alias5", original_name="B", name="B")
66+
model1 = BaseModel(fields=[], reference=reference_1)
67+
model2 = TypeAlias(fields=[], reference=reference_2)
68+
model3 = TypeAliasBackport(fields=[], reference=reference_3)
69+
model4 = TypeAliasTypeBackport(fields=[], reference=reference_4)
70+
model5 = TypeStatement(fields=[], reference=reference_5)
71+
72+
paths = add_model_path_to_list(None, model1)
73+
assert "Base1" in paths
74+
assert len(paths) == 1
75+
76+
paths = list[str]()
77+
add_model_path_to_list(paths, model1)
78+
assert "Base1" in paths
79+
assert len(paths) == 1
80+
81+
add_model_path_to_list(paths, model1)
82+
assert len(paths) != 2
83+
assert len(paths) == 1
84+
85+
add_model_path_to_list(paths, model2)
86+
assert "Alias2" not in paths
87+
88+
add_model_path_to_list(paths, model3)
89+
assert "Alias3" not in paths
90+
91+
add_model_path_to_list(paths, model4)
92+
assert "Alias4" not in paths
93+
94+
add_model_path_to_list(paths, model5)
95+
assert "Alias5" not in paths
96+
97+
5798
def test_sort_data_models() -> None:
5899
"""Test sorting data models by dependencies."""
59100
reference_a = Reference(path="A", original_name="A", name="A")

0 commit comments

Comments
 (0)