Skip to content

Commit af3c45e

Browse files
zdenecekilovelinuxkoxudaxi
authored
Support allOf polymorphism with discriminator in OpenAPI (#2530)
* Support allOf polymorphism in openapi * move logic to openapi.py * Apply suggestions from code review Co-authored-by: Antonio Spadaro <ilovelinux@users.noreply.github.com> * refactor parse_object_fields for code reuse * remove target python version from test * Refactor discriminator handling to improve polymorphism support * Enhance discriminator handling for allOf polymorphism support * Add test for OpenAPI generation with discriminator and no allOf subtypes * Update src/datamodel_code_generator/parser/openapi.py --------- Co-authored-by: Antonio Spadaro <ilovelinux@users.noreply.github.com> Co-authored-by: Koudai Aono <koxudaxi@gmail.com>
1 parent b083256 commit af3c45e

6 files changed

Lines changed: 275 additions & 0 deletions

File tree

src/datamodel_code_generator/parser/openapi.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,8 @@ def __init__( # noqa: PLR0913
344344
)
345345
self.open_api_scopes: list[OpenAPIScope] = openapi_scopes or [OpenAPIScope.Schemas]
346346
self.include_path_parameters: bool = include_path_parameters
347+
self._discriminator_schemas: dict[str, dict[str, Any]] = {}
348+
self._discriminator_subtypes: dict[str, list[str]] = defaultdict(list)
347349

348350
def get_ref_model(self, ref: str) -> dict[str, Any]:
349351
"""Resolve a reference to its model definition."""
@@ -362,6 +364,49 @@ def get_data_type(self, obj: JsonSchemaObject) -> DataType:
362364

363365
return super().get_data_type(obj)
364366

367+
def _get_discriminator_union_type(self, ref: str) -> DataType | None:
368+
"""Create a union type for discriminator subtypes if available."""
369+
subtypes = self._discriminator_subtypes.get(ref, [])
370+
if not subtypes:
371+
return None
372+
refs = map(self.model_resolver.add_ref, subtypes)
373+
return self.data_type(data_types=[self.data_type(reference=r) for r in refs])
374+
375+
def get_ref_data_type(self, ref: str) -> DataType:
376+
"""Get data type for a reference, handling discriminator polymorphism."""
377+
if ref in self._discriminator_schemas and (union_type := self._get_discriminator_union_type(ref)):
378+
return union_type
379+
return super().get_ref_data_type(ref)
380+
381+
def parse_object_fields(
382+
self,
383+
obj: JsonSchemaObject,
384+
path: list[str],
385+
module_name: Optional[str] = None, # noqa: UP045
386+
) -> list[DataModelFieldBase]:
387+
"""Parse object fields, adding discriminator info for allOf polymorphism."""
388+
fields = super().parse_object_fields(obj, path, module_name)
389+
properties = obj.properties or {}
390+
391+
result_fields: list[DataModelFieldBase] = []
392+
for field_obj in fields:
393+
field = properties.get(field_obj.original_name)
394+
395+
if (
396+
isinstance(field, JsonSchemaObject)
397+
and field.ref
398+
and (discriminator := self._discriminator_schemas.get(field.ref))
399+
):
400+
new_field_type = self._get_discriminator_union_type(field.ref) or field_obj.data_type
401+
field_obj = self.data_model_field_type(**{ # noqa: PLW2901
402+
**field_obj.__dict__,
403+
"data_type": new_field_type,
404+
"extras": {**field_obj.extras, "discriminator": discriminator},
405+
})
406+
result_fields.append(field_obj)
407+
408+
return result_fields
409+
365410
def resolve_object(self, obj: ReferenceObject | BaseModelT, object_type: type[BaseModelT]) -> BaseModelT:
366411
"""Resolve a reference object to its actual type or return the object as-is."""
367412
if isinstance(obj, ReferenceObject):
@@ -651,6 +696,7 @@ def parse_raw(self) -> None: # noqa: PLR0912, PLR0915
651696

652697
specification: dict[str, Any] = load_yaml_dict(source.text)
653698
self.raw_obj = specification
699+
self._collect_discriminator_schemas()
654700
schemas: dict[str, Any] = specification.get("components", {}).get("schemas", {})
655701
security: list[dict[str, list[str]]] | None = specification.get("security")
656702
if OpenAPIScope.Schemas in self.open_api_scopes:
@@ -727,3 +773,25 @@ def parse_raw(self) -> None: # noqa: PLR0912, PLR0915
727773
)
728774

729775
self._resolve_unparsed_json_pointer()
776+
777+
def _collect_discriminator_schemas(self) -> None:
778+
"""Collect schemas with discriminators but no oneOf/anyOf, and find their subtypes."""
779+
schemas: dict[str, Any] = self.raw_obj.get("components", {}).get("schemas", {})
780+
781+
for schema_name, schema in schemas.items():
782+
discriminator = schema.get("discriminator")
783+
if not discriminator:
784+
continue
785+
786+
if schema.get("oneOf") or schema.get("anyOf"):
787+
continue
788+
789+
ref = f"#/components/schemas/{schema_name}"
790+
self._discriminator_schemas[ref] = discriminator
791+
792+
for schema_name, schema in schemas.items():
793+
for all_of_item in schema.get("allOf", []):
794+
ref_in_allof = all_of_item.get("$ref")
795+
if ref_in_allof and ref_in_allof in self._discriminator_schemas:
796+
subtype_ref = f"#/components/schemas/{schema_name}"
797+
self._discriminator_subtypes[ref_in_allof].append(subtype_ref)
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# generated by datamodel-codegen:
2+
# filename: discriminator_allof.yaml
3+
# timestamp: 2019-07-26T00:00:00+00:00
4+
5+
from __future__ import annotations
6+
7+
from typing import Annotated, Literal
8+
9+
from pydantic import BaseModel, Field
10+
11+
12+
class Pet(BaseModel):
13+
pet_type: Annotated[str, Field(alias='petType')]
14+
15+
16+
class Cat(Pet):
17+
name: str | None = None
18+
pet_type: Literal['cat'] = Field(..., alias='petType')
19+
20+
21+
class Dog(Pet):
22+
bark: str | None = None
23+
pet_type: Literal['dog'] = Field(..., alias='petType')
24+
25+
26+
class Lizard(Pet):
27+
loves_rocks: Annotated[bool | None, Field(alias='lovesRocks')] = None
28+
pet_type: Literal['lizard'] = Field(..., alias='petType')
29+
30+
31+
class PetContainer(BaseModel):
32+
pet: Annotated[Cat | Dog | Lizard, Field(discriminator='pet_type')]
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# generated by datamodel-codegen:
2+
# filename: discriminator_allof_no_subtypes.yaml
3+
# timestamp: 2019-07-26T00:00:00+00:00
4+
5+
from __future__ import annotations
6+
7+
from typing import Optional
8+
9+
from pydantic import BaseModel, Field
10+
11+
12+
class BaseItem(BaseModel):
13+
itemType: str
14+
15+
16+
class FooItem(BaseModel):
17+
fooValue: Optional[str] = None
18+
19+
20+
class BarItem(BaseModel):
21+
barValue: Optional[int] = None
22+
23+
24+
class ItemContainer(BaseModel):
25+
item: BaseItem = Field(..., discriminator='itemType')
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Example from https://spec.openapis.org/oas/v3.1.1.html#examples-1
2+
# This tests discriminator without oneOf/anyOf, where subtypes use allOf
3+
4+
openapi: 3.1.0
5+
info:
6+
title: Test
7+
description: "Test API"
8+
version: 0.0.0
9+
paths:
10+
/pet:
11+
get:
12+
responses:
13+
'200':
14+
description: "Pet"
15+
content:
16+
application/json:
17+
schema:
18+
$ref: "#/components/schemas/PetContainer"
19+
components:
20+
schemas:
21+
PetContainer:
22+
type: object
23+
required:
24+
- pet
25+
properties:
26+
pet:
27+
$ref: "#/components/schemas/Pet"
28+
Pet:
29+
type: object
30+
required:
31+
- petType
32+
properties:
33+
petType:
34+
type: string
35+
discriminator:
36+
propertyName: petType
37+
mapping:
38+
cat: "#/components/schemas/Cat"
39+
dog: "#/components/schemas/Dog"
40+
lizard: "#/components/schemas/Lizard"
41+
Cat:
42+
allOf:
43+
- $ref: "#/components/schemas/Pet"
44+
- type: object
45+
properties:
46+
name:
47+
type: string
48+
Dog:
49+
allOf:
50+
- $ref: "#/components/schemas/Pet"
51+
- type: object
52+
properties:
53+
bark:
54+
type: string
55+
Lizard:
56+
allOf:
57+
- $ref: "#/components/schemas/Pet"
58+
- type: object
59+
properties:
60+
lovesRocks:
61+
type: boolean
62+
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Tests discriminator without any allOf subtypes
2+
# This tests the edge case where a schema has a discriminator but nothing inherits from it
3+
4+
openapi: 3.1.0
5+
info:
6+
title: Test
7+
description: "Test API"
8+
version: 0.0.0
9+
paths:
10+
/item:
11+
get:
12+
responses:
13+
'200':
14+
description: "Item"
15+
content:
16+
application/json:
17+
schema:
18+
$ref: "#/components/schemas/ItemContainer"
19+
components:
20+
schemas:
21+
ItemContainer:
22+
type: object
23+
required:
24+
- item
25+
properties:
26+
item:
27+
$ref: "#/components/schemas/BaseItem"
28+
BaseItem:
29+
type: object
30+
required:
31+
- itemType
32+
properties:
33+
itemType:
34+
type: string
35+
discriminator:
36+
propertyName: itemType
37+
mapping:
38+
foo: "#/components/schemas/FooItem"
39+
bar: "#/components/schemas/BarItem"
40+
# These schemas exist but don't use allOf to inherit from BaseItem
41+
FooItem:
42+
type: object
43+
properties:
44+
fooValue:
45+
type: string
46+
BarItem:
47+
type: object
48+
properties:
49+
barValue:
50+
type: integer

tests/main/openapi/test_main_openapi.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,44 @@ def test_main_openapi_discriminator_with_properties(output_file: Path) -> None:
9797
)
9898

9999

100+
def test_main_openapi_discriminator_allof(output_file: Path) -> None:
101+
"""Test OpenAPI generation with allOf discriminator polymorphism."""
102+
run_main_and_assert(
103+
input_path=OPEN_API_DATA_PATH / "discriminator_allof.yaml",
104+
output_path=output_file,
105+
input_file_type="openapi",
106+
assert_func=assert_file_content,
107+
expected_file=EXPECTED_OPENAPI_PATH / "discriminator" / "allof.py",
108+
extra_args=[
109+
"--output-model-type",
110+
"pydantic_v2.BaseModel",
111+
"--snake-case-field",
112+
"--use-annotated",
113+
"--use-union-operator",
114+
"--collapse-root-models",
115+
],
116+
)
117+
118+
119+
def test_main_openapi_discriminator_allof_no_subtypes(output_file: Path) -> None:
120+
"""Test OpenAPI generation with discriminator but no allOf subtypes.
121+
122+
This tests the edge case where a schema has a discriminator but nothing
123+
inherits from it using allOf.
124+
"""
125+
run_main_and_assert(
126+
input_path=OPEN_API_DATA_PATH / "discriminator_allof_no_subtypes.yaml",
127+
output_path=output_file,
128+
input_file_type="openapi",
129+
assert_func=assert_file_content,
130+
expected_file=EXPECTED_OPENAPI_PATH / "discriminator" / "allof_no_subtypes.py",
131+
extra_args=[
132+
"--output-model-type",
133+
"pydantic_v2.BaseModel",
134+
],
135+
)
136+
137+
100138
def test_main_pydantic_basemodel(output_file: Path) -> None:
101139
"""Test OpenAPI generation with Pydantic BaseModel output."""
102140
run_main_and_assert(

0 commit comments

Comments
 (0)