Skip to content

Commit 52c9b92

Browse files
authored
Merge branch 'main' into required-nullable-annotated
2 parents eb05935 + af3c45e commit 52c9b92

52 files changed

Lines changed: 801 additions & 187 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

docs/root-model-and-type-alias.md

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,23 @@ When a schema defines a simple type (not an object with properties), `datamodel-
1010
- A RootModel or type alias is also generated for the main schema, allowing you to define a single type alias from a schema file (e.g. `model.json` containing `{"title": "MyString", "type": "string"}`)
1111
- Type aliases cannot be combined with `Annotated` for Pydantic v1
1212

13-
## Pydantic v1 vs v2
13+
## Type Alias Behavior by Output Type and Python Version
1414

15-
The type of type alias generated depends on the Pydantic version:
15+
The type of type alias generated depends on the output model type and target Python version:
1616

17-
- **Pydantic v2**: Uses `TypeAliasType` (Python 3.9-3.11) or `type` statement (Python 3.12+)
18-
- **Pydantic v1**: Uses `TypeAlias` because Pydantic v1 cannot handle `TypeAliasType` objects
17+
| Output Type | Python 3.12+ | Python 3.10-3.11 | Python 3.9 |
18+
|-------------|--------------|------------------|------------|
19+
| **Pydantic v2** | `type` statement | `TypeAliasType` (typing_extensions) | `TypeAliasType` (typing_extensions) |
20+
| **Pydantic v1** | `TypeAlias` | `TypeAlias` | `TypeAlias` (typing_extensions) |
21+
| **TypedDict** | `type` statement | `TypeAlias` | `TypeAlias` (typing_extensions) |
22+
| **dataclasses** | `type` statement | `TypeAlias` | `TypeAlias` (typing_extensions) |
23+
| **msgspec** | `type` statement | `TypeAlias` | `TypeAlias` (typing_extensions) |
24+
25+
**Why the difference?**
26+
27+
- **Pydantic v2** requires `TypeAliasType` because it cannot properly handle `TypeAlias` annotations
28+
- **Other output types** (TypedDict, dataclasses, msgspec) use `TypeAlias` for better compatibility with libraries that may not expect `TypeAliasType` objects
29+
- **Python 3.12+** uses the native `type` statement for all output types
1930

2031
## Example
2132

src/datamodel_code_generator/model/__init__.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -55,28 +55,28 @@ def get_data_model_types(
5555
)
5656
from .types import DataTypeManager # noqa: PLC0415
5757

58-
# Pydantic v1 does not support TypeAliasType type, fallback to TypeAlias
59-
if data_model_type == DataModelType.PydanticBaseModel:
60-
if target_python_version.has_type_alias:
61-
# Python 3.10+: typing.TypeAlias
62-
type_alias_class = type_alias.TypeAlias
63-
scalar_class = scalar.DataTypeScalar
64-
union_class = union.DataTypeUnion
58+
# Pydantic v2 requires TypeAliasType; other output types use TypeAlias for better compatibility
59+
if data_model_type == DataModelType.PydanticV2BaseModel:
60+
if target_python_version.has_type_statement:
61+
type_alias_class = type_alias.TypeStatement
62+
scalar_class = scalar.DataTypeScalarTypeStatement
63+
union_class = union.DataTypeUnionTypeStatement
6564
else:
66-
# Python 3.9: typing_extensions.TypeAlias
67-
type_alias_class = type_alias.TypeAliasBackport
68-
scalar_class = scalar.DataTypeScalarBackport
69-
union_class = union.DataTypeUnionBackport
65+
type_alias_class = type_alias.TypeAliasTypeBackport
66+
scalar_class = scalar.DataTypeScalarTypeBackport
67+
union_class = union.DataTypeUnionTypeBackport
7068
elif target_python_version.has_type_statement:
71-
# Python 3.12+ with Pydantic v2 or other formats: Use type statement
7269
type_alias_class = type_alias.TypeStatement
7370
scalar_class = scalar.DataTypeScalarTypeStatement
7471
union_class = union.DataTypeUnionTypeStatement
72+
elif target_python_version.has_type_alias:
73+
type_alias_class = type_alias.TypeAlias
74+
scalar_class = scalar.DataTypeScalar
75+
union_class = union.DataTypeUnion
7576
else:
76-
# Python 3.9-3.11 with Pydantic v2 or other formats: Use TypeAliasType
77-
type_alias_class = type_alias.TypeAliasTypeBackport
78-
scalar_class = scalar.DataTypeScalarTypeBackport
79-
union_class = union.DataTypeUnionTypeBackport
77+
type_alias_class = type_alias.TypeAliasBackport
78+
scalar_class = scalar.DataTypeScalarBackport
79+
union_class = union.DataTypeUnionBackport
8080

8181
if data_model_type == DataModelType.PydanticBaseModel:
8282
return DataModelSet(

src/datamodel_code_generator/model/base.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,7 @@ class DataModel(TemplateBase, Nullable, ABC):
340340
TEMPLATE_FILE_PATH: ClassVar[str] = ""
341341
BASE_CLASS: ClassVar[str] = ""
342342
DEFAULT_IMPORTS: ClassVar[tuple[Import, ...]] = ()
343+
IS_ALIAS: bool = False
343344

344345
def __init__( # noqa: PLR0913
345346
self,
@@ -523,6 +524,11 @@ def all_data_types(self) -> Iterator[DataType]:
523524
yield from field.data_type.all_data_types
524525
yield from self.base_classes
525526

527+
@property
528+
def is_alias(self) -> bool:
529+
"""Whether is a type alias (i.e. not an instance of BaseModel/RootModel)."""
530+
return self.IS_ALIAS
531+
526532
@property
527533
def nullable(self) -> bool:
528534
"""Check if this model is nullable."""

src/datamodel_code_generator/model/type_alias.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
class TypeAliasBase(DataModel):
2323
"""Base class for all type alias implementations."""
2424

25+
IS_ALIAS: bool = True
26+
2527
@property
2628
def imports(self) -> tuple[Import, ...]:
2729
"""Get imports including Annotated if needed."""

src/datamodel_code_generator/parser/base.py

Lines changed: 64 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,28 @@ def dump_templates(templates: list[DataModel]) -> str:
132132
MAX_RECURSION_COUNT: int = sys.getrecursionlimit()
133133

134134

135+
def add_model_path_to_list(
136+
paths: list[str] | None,
137+
model: DataModel,
138+
/,
139+
) -> list[str]:
140+
"""
141+
Auxiliary method which adds model path to list, provided the following hold.
142+
143+
- model is not a type alias
144+
- path is not already in the list.
145+
146+
"""
147+
if paths is None:
148+
paths = []
149+
if model.is_alias:
150+
return paths
151+
if (path := model.path) in paths:
152+
return paths
153+
paths.append(path)
154+
return paths
155+
156+
135157
def sort_data_models( # noqa: PLR0912, PLR0915
136158
unsorted_data_models: list[DataModel],
137159
sorted_data_models: SortedDataModels | None = None,
@@ -151,13 +173,13 @@ def sort_data_models( # noqa: PLR0912, PLR0915
151173
sorted_data_models[model.path] = model
152174
elif model.path in model.reference_classes and len(model.reference_classes) == 1: # only self-referencing
153175
sorted_data_models[model.path] = model
154-
require_update_action_models.append(model.path)
176+
add_model_path_to_list(require_update_action_models, model)
155177
elif (
156178
not model.reference_classes - {model.path} - set(sorted_data_models)
157179
): # reference classes have been resolved
158180
sorted_data_models[model.path] = model
159181
if model.path in model.reference_classes:
160-
require_update_action_models.append(model.path)
182+
add_model_path_to_list(require_update_action_models, model)
161183
else:
162184
unresolved_references.append(model)
163185
if unresolved_references:
@@ -206,11 +228,11 @@ def sort_data_models( # noqa: PLR0912, PLR0915
206228
if not unresolved_model:
207229
sorted_data_models[model.path] = model
208230
if update_action_parent:
209-
require_update_action_models.append(model.path)
231+
add_model_path_to_list(require_update_action_models, model)
210232
continue
211233
if not unresolved_model - unsorted_data_model_names:
212234
sorted_data_models[model.path] = model
213-
require_update_action_models.append(model.path)
235+
add_model_path_to_list(require_update_action_models, model)
214236
continue
215237
# unresolved
216238
unresolved_classes = ", ".join(
@@ -1003,7 +1025,7 @@ def __reuse_model(self, models: list[DataModel], require_update_action_models: l
10031025
custom_template_dir=model._custom_template_dir, # noqa: SLF001
10041026
)
10051027
if cached_model_reference.path in require_update_action_models:
1006-
require_update_action_models.append(inherited_model.path)
1028+
add_model_path_to_list(require_update_action_models, inherited_model)
10071029
models.insert(index, inherited_model)
10081030
models.remove(model)
10091031

@@ -1341,6 +1363,39 @@ def __alias_shadowed_imports( # noqa: PLR6301
13411363
reference_path=model_field.data_type.import_.reference_path,
13421364
)
13431365

1366+
@classmethod
1367+
def _collect_used_names_from_models(cls, models: list[DataModel]) -> set[str]:
1368+
"""Collect identifiers referenced by models before rendering."""
1369+
names: set[str] = set()
1370+
1371+
def add(name: str | None) -> None:
1372+
if not name:
1373+
return
1374+
# first segment is sufficient to match import target or alias
1375+
names.add(name.split(".")[0])
1376+
1377+
def walk_data_type(data_type: DataType) -> None:
1378+
add(data_type.alias or data_type.type)
1379+
if data_type.reference:
1380+
add(data_type.reference.short_name)
1381+
for child in data_type.data_types:
1382+
walk_data_type(child)
1383+
if data_type.dict_key:
1384+
walk_data_type(data_type.dict_key)
1385+
1386+
for model in models:
1387+
add(model.class_name)
1388+
add(model.duplicate_class_name)
1389+
for base in model.base_classes:
1390+
add(base.type_hint)
1391+
for import_ in model.imports:
1392+
add(import_.alias or import_.import_.split(".")[-1])
1393+
for field in model.fields:
1394+
add(field.name)
1395+
add(field.alias)
1396+
walk_data_type(field.data_type)
1397+
return names
1398+
13441399
def parse( # noqa: PLR0912, PLR0914, PLR0915
13451400
self,
13461401
with_import: bool | None = True, # noqa: FBT001, FBT002
@@ -1466,12 +1521,14 @@ class Processed(NamedTuple):
14661521

14671522
for processed_model in processed_models:
14681523
# postprocess imports to remove unused imports.
1469-
model_code = str("\n".join([str(m) for m in processed_model.models]))
1524+
used_names = self._collect_used_names_from_models(processed_model.models)
14701525
unused_imports = [
14711526
(from_, import_)
14721527
for from_, imports_ in processed_model.imports.items()
14731528
for import_ in imports_
1474-
if import_ not in model_code
1529+
if not {processed_model.imports.alias.get(from_, {}).get(import_, import_), import_}.intersection(
1530+
used_names
1531+
)
14751532
]
14761533
for from_, import_ in unused_imports:
14771534
processed_model.imports.remove(Import(from_=from_, import_=import_))

src/datamodel_code_generator/parser/openapi.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,8 @@ def __init__( # noqa: PLR0913
344344
)
345345
self.open_api_scopes: list[OpenAPIScope] = openapi_scopes or [OpenAPIScope.Schemas]
346346
self.include_path_parameters: bool = include_path_parameters
347+
self._discriminator_schemas: dict[str, dict[str, Any]] = {}
348+
self._discriminator_subtypes: dict[str, list[str]] = defaultdict(list)
347349

348350
def get_ref_model(self, ref: str) -> dict[str, Any]:
349351
"""Resolve a reference to its model definition."""
@@ -362,6 +364,49 @@ def get_data_type(self, obj: JsonSchemaObject) -> DataType:
362364

363365
return super().get_data_type(obj)
364366

367+
def _get_discriminator_union_type(self, ref: str) -> DataType | None:
368+
"""Create a union type for discriminator subtypes if available."""
369+
subtypes = self._discriminator_subtypes.get(ref, [])
370+
if not subtypes:
371+
return None
372+
refs = map(self.model_resolver.add_ref, subtypes)
373+
return self.data_type(data_types=[self.data_type(reference=r) for r in refs])
374+
375+
def get_ref_data_type(self, ref: str) -> DataType:
376+
"""Get data type for a reference, handling discriminator polymorphism."""
377+
if ref in self._discriminator_schemas and (union_type := self._get_discriminator_union_type(ref)):
378+
return union_type
379+
return super().get_ref_data_type(ref)
380+
381+
def parse_object_fields(
382+
self,
383+
obj: JsonSchemaObject,
384+
path: list[str],
385+
module_name: Optional[str] = None, # noqa: UP045
386+
) -> list[DataModelFieldBase]:
387+
"""Parse object fields, adding discriminator info for allOf polymorphism."""
388+
fields = super().parse_object_fields(obj, path, module_name)
389+
properties = obj.properties or {}
390+
391+
result_fields: list[DataModelFieldBase] = []
392+
for field_obj in fields:
393+
field = properties.get(field_obj.original_name)
394+
395+
if (
396+
isinstance(field, JsonSchemaObject)
397+
and field.ref
398+
and (discriminator := self._discriminator_schemas.get(field.ref))
399+
):
400+
new_field_type = self._get_discriminator_union_type(field.ref) or field_obj.data_type
401+
field_obj = self.data_model_field_type(**{ # noqa: PLW2901
402+
**field_obj.__dict__,
403+
"data_type": new_field_type,
404+
"extras": {**field_obj.extras, "discriminator": discriminator},
405+
})
406+
result_fields.append(field_obj)
407+
408+
return result_fields
409+
365410
def resolve_object(self, obj: ReferenceObject | BaseModelT, object_type: type[BaseModelT]) -> BaseModelT:
366411
"""Resolve a reference object to its actual type or return the object as-is."""
367412
if isinstance(obj, ReferenceObject):
@@ -651,6 +696,7 @@ def parse_raw(self) -> None: # noqa: PLR0912, PLR0915
651696

652697
specification: dict[str, Any] = load_yaml_dict(source.text)
653698
self.raw_obj = specification
699+
self._collect_discriminator_schemas()
654700
schemas: dict[str, Any] = specification.get("components", {}).get("schemas", {})
655701
security: list[dict[str, list[str]]] | None = specification.get("security")
656702
if OpenAPIScope.Schemas in self.open_api_scopes:
@@ -727,3 +773,25 @@ def parse_raw(self) -> None: # noqa: PLR0912, PLR0915
727773
)
728774

729775
self._resolve_unparsed_json_pointer()
776+
777+
def _collect_discriminator_schemas(self) -> None:
778+
"""Collect schemas with discriminators but no oneOf/anyOf, and find their subtypes."""
779+
schemas: dict[str, Any] = self.raw_obj.get("components", {}).get("schemas", {})
780+
781+
for schema_name, schema in schemas.items():
782+
discriminator = schema.get("discriminator")
783+
if not discriminator:
784+
continue
785+
786+
if schema.get("oneOf") or schema.get("anyOf"):
787+
continue
788+
789+
ref = f"#/components/schemas/{schema_name}"
790+
self._discriminator_schemas[ref] = discriminator
791+
792+
for schema_name, schema in schemas.items():
793+
for all_of_item in schema.get("allOf", []):
794+
ref_in_allof = all_of_item.get("$ref")
795+
if ref_in_allof and ref_in_allof in self._discriminator_schemas:
796+
subtype_ref = f"#/components/schemas/{schema_name}"
797+
self._discriminator_subtypes[ref_in_allof].append(subtype_ref)

tests/data/expected/main/graphql/simple_star_wars_dataclass.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,27 +7,27 @@
77
from dataclasses import dataclass
88
from typing import List, Literal, Optional
99

10-
from typing_extensions import TypeAliasType
10+
from typing_extensions import TypeAlias
1111

12-
Boolean = TypeAliasType("Boolean", bool)
12+
Boolean: TypeAlias = bool
1313
"""
1414
The `Boolean` scalar type represents `true` or `false`.
1515
"""
1616

1717

18-
ID = TypeAliasType("ID", str)
18+
ID: TypeAlias = str
1919
"""
2020
The `ID` scalar type represents a unique identifier, often used to refetch an object or as key for a cache. The ID type appears in a JSON response as a String; however, it is not intended to be human-readable. When expected as an input type, any string (such as `"4"`) or integer (such as `4`) input value will be accepted as an ID.
2121
"""
2222

2323

24-
Int = TypeAliasType("Int", int)
24+
Int: TypeAlias = int
2525
"""
2626
The `Int` scalar type represents non-fractional signed whole numeric values. Int can represent values between -(2^31) and 2^31 - 1.
2727
"""
2828

2929

30-
String = TypeAliasType("String", str)
30+
String: TypeAlias = str
3131
"""
3232
The `String` scalar type represents textual data, represented as UTF-8 character sequences. The String type is most often used by GraphQL to represent free-form human-readable text.
3333
"""

tests/data/expected/main/graphql/simple_star_wars_dataclass_arguments.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,27 +7,27 @@
77
from dataclasses import dataclass
88
from typing import List, Literal, Optional
99

10-
from typing_extensions import TypeAliasType
10+
from typing_extensions import TypeAlias
1111

12-
Boolean = TypeAliasType("Boolean", bool)
12+
Boolean: TypeAlias = bool
1313
"""
1414
The `Boolean` scalar type represents `true` or `false`.
1515
"""
1616

1717

18-
ID = TypeAliasType("ID", str)
18+
ID: TypeAlias = str
1919
"""
2020
The `ID` scalar type represents a unique identifier, often used to refetch an object or as key for a cache. The ID type appears in a JSON response as a String; however, it is not intended to be human-readable. When expected as an input type, any string (such as `"4"`) or integer (such as `4`) input value will be accepted as an ID.
2121
"""
2222

2323

24-
Int = TypeAliasType("Int", int)
24+
Int: TypeAlias = int
2525
"""
2626
The `Int` scalar type represents non-fractional signed whole numeric values. Int can represent values between -(2^31) and 2^31 - 1.
2727
"""
2828

2929

30-
String = TypeAliasType("String", str)
30+
String: TypeAlias = str
3131
"""
3232
The `String` scalar type represents textual data, represented as UTF-8 character sequences. The String type is most often used by GraphQL to represent free-form human-readable text.
3333
"""

0 commit comments

Comments
 (0)