Skip to content

Commit 2df42fb

Browse files
authored
Fix discriminator to use enum value instead of model name for single-value enums (#2652)
* Fix discriminator handling for single-value enums and add related tests * Fix type name handling for single-value literals in discriminator logic
1 parent 35871c2 commit 2df42fb

8 files changed

Lines changed: 263 additions & 1 deletion

File tree

src/datamodel_code_generator/parser/base.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1086,7 +1086,27 @@ def check_paths(
10861086
for base_class in discriminator_model.base_classes:
10871087
check_paths(base_class.reference, mapping) # pyright: ignore[reportArgumentType]
10881088
else:
1089-
type_names = [discriminator_model.path.split("/")[-1]]
1089+
for discriminator_field in discriminator_model.fields:
1090+
if field_name not in {discriminator_field.original_name, discriminator_field.name}:
1091+
continue
1092+
1093+
literals = discriminator_field.data_type.literals
1094+
if literals and len(literals) == 1: # pragma: no cover
1095+
type_names = [str(v) for v in literals]
1096+
break
1097+
1098+
enum_source = discriminator_field.data_type.find_source(Enum)
1099+
if enum_source and len(enum_source.fields) == 1:
1100+
first_field = enum_source.fields[0]
1101+
raw_default = first_field.default
1102+
if isinstance(raw_default, str):
1103+
type_names = [raw_default.strip("'\"")]
1104+
else: # pragma: no cover
1105+
type_names = [str(raw_default)]
1106+
break
1107+
1108+
if not type_names:
1109+
type_names = [discriminator_model.path.split("/")[-1]]
10901110

10911111
if not type_names: # pragma: no cover
10921112
msg = f"Discriminator type is not found. {data_type.reference.path}"
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# generated by datamodel-codegen:
2+
# filename: discriminator_enum_single_value.yaml
3+
# timestamp: 2019-07-26T00:00:00+00:00
4+
5+
from __future__ import annotations
6+
7+
from enum import Enum
8+
9+
from pydantic import BaseModel, Field, RootModel
10+
11+
12+
class ToolType(Enum):
13+
function = 'function'
14+
15+
16+
class ToolBase(BaseModel):
17+
type: ToolType
18+
19+
20+
class FunctionToolCall(ToolBase):
21+
id: str
22+
23+
24+
class ToolCall(RootModel[FunctionToolCall]):
25+
root: FunctionToolCall = Field(..., discriminator='type')
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# generated by datamodel-codegen:
2+
# filename: discriminator_enum_single_value_anyof.yaml
3+
# timestamp: 2019-07-26T00:00:00+00:00
4+
5+
from __future__ import annotations
6+
7+
from enum import Enum
8+
from typing import Literal, Union
9+
10+
from pydantic import BaseModel, Field, RootModel
11+
12+
13+
class ToolType(Enum):
14+
function = 'function'
15+
16+
17+
class FunctionToolCall(BaseModel):
18+
id: str
19+
type: Literal['function']
20+
21+
22+
class CustomToolCall(BaseModel):
23+
type: Literal['CustomToolCall']
24+
25+
26+
class ToolCallUnion(RootModel[Union[FunctionToolCall, CustomToolCall]]):
27+
root: Union[FunctionToolCall, CustomToolCall] = Field(..., discriminator='type')
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# generated by datamodel-codegen:
2+
# filename: discriminator_enum_single_value_anyof.yaml
3+
# timestamp: 2019-07-26T00:00:00+00:00
4+
5+
from __future__ import annotations
6+
7+
from enum import Enum
8+
from typing import Literal, Union
9+
10+
from pydantic import BaseModel, Field, RootModel
11+
12+
13+
class ToolType(Enum):
14+
function = 'function'
15+
16+
17+
class FunctionToolCall(BaseModel):
18+
id: str
19+
type: Literal[ToolType.function]
20+
21+
22+
class CustomToolCall(BaseModel):
23+
type: Literal['CustomToolCall']
24+
25+
26+
class ToolCallUnion(RootModel[Union[FunctionToolCall, CustomToolCall]]):
27+
root: Union[FunctionToolCall, CustomToolCall] = Field(..., discriminator='type')
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# generated by datamodel-codegen:
2+
# filename: discriminator_enum_single_value.yaml
3+
# timestamp: 2019-07-26T00:00:00+00:00
4+
5+
from __future__ import annotations
6+
7+
from enum import Enum
8+
9+
from pydantic import BaseModel, Field, RootModel
10+
11+
12+
class ToolType(Enum):
13+
function = 'function'
14+
15+
16+
class ToolBase(BaseModel):
17+
type: ToolType
18+
19+
20+
class FunctionToolCall(ToolBase):
21+
id: str
22+
23+
24+
class ToolCall(RootModel[FunctionToolCall]):
25+
root: FunctionToolCall = Field(..., discriminator='type')
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
openapi: "3.0.0"
2+
components:
3+
schemas:
4+
ToolCall:
5+
oneOf:
6+
- $ref: '#/components/schemas/FunctionToolCall'
7+
discriminator:
8+
propertyName: type
9+
10+
ToolType:
11+
type: string
12+
enum:
13+
- function
14+
15+
ToolBase:
16+
type: object
17+
properties:
18+
type:
19+
$ref: '#/components/schemas/ToolType'
20+
required:
21+
- type
22+
23+
FunctionToolCall:
24+
allOf:
25+
- $ref: '#/components/schemas/ToolBase'
26+
type: object
27+
properties:
28+
id:
29+
type: string
30+
required:
31+
- id
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
openapi: "3.0.0"
2+
components:
3+
schemas:
4+
ToolCallUnion:
5+
anyOf:
6+
- $ref: '#/components/schemas/FunctionToolCall'
7+
- $ref: '#/components/schemas/CustomToolCall'
8+
discriminator:
9+
propertyName: type
10+
11+
ToolType:
12+
type: string
13+
enum:
14+
- function
15+
16+
FunctionToolCall:
17+
type: object
18+
properties:
19+
id:
20+
type: string
21+
type:
22+
$ref: '#/components/schemas/ToolType'
23+
required:
24+
- id
25+
- type
26+
27+
CustomToolCall:
28+
type: object
29+
properties:
30+
type:
31+
type: string

tests/main/openapi/test_main_openapi.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,82 @@ def test_main_openapi_discriminator_enum_duplicate(output_file: Path) -> None:
133133
)
134134

135135

136+
@pytest.mark.skipif(
137+
black.__version__.split(".")[0] == "19",
138+
reason="Installed black doesn't support the old style",
139+
)
140+
def test_main_openapi_discriminator_enum_single_value(output_file: Path) -> None:
141+
"""Single-value enum discriminator with allOf inheritance."""
142+
run_main_and_assert(
143+
input_path=OPEN_API_DATA_PATH / "discriminator_enum_single_value.yaml",
144+
output_path=output_file,
145+
input_file_type="openapi",
146+
assert_func=assert_file_content,
147+
expected_file=EXPECTED_OPENAPI_PATH / "discriminator" / "enum_single_value.py",
148+
extra_args=["--target-python-version", "3.10", "--output-model-type", "pydantic_v2.BaseModel"],
149+
)
150+
151+
152+
@pytest.mark.skipif(
153+
black.__version__.split(".")[0] == "19",
154+
reason="Installed black doesn't support the old style",
155+
)
156+
def test_main_openapi_discriminator_enum_single_value_use_enum(output_file: Path) -> None:
157+
"""Single-value enum with allOf + --use-enum-values-in-discriminator."""
158+
run_main_and_assert(
159+
input_path=OPEN_API_DATA_PATH / "discriminator_enum_single_value.yaml",
160+
output_path=output_file,
161+
input_file_type="openapi",
162+
assert_func=assert_file_content,
163+
expected_file=EXPECTED_OPENAPI_PATH / "discriminator" / "enum_single_value_use_enum.py",
164+
extra_args=[
165+
"--target-python-version",
166+
"3.10",
167+
"--output-model-type",
168+
"pydantic_v2.BaseModel",
169+
"--use-enum-values-in-discriminator",
170+
],
171+
)
172+
173+
174+
@pytest.mark.skipif(
175+
black.__version__.split(".")[0] == "19",
176+
reason="Installed black doesn't support the old style",
177+
)
178+
def test_main_openapi_discriminator_enum_single_value_anyof(output_file: Path) -> None:
179+
"""Single-value enum discriminator with anyOf - uses enum value, not model name."""
180+
run_main_and_assert(
181+
input_path=OPEN_API_DATA_PATH / "discriminator_enum_single_value_anyof.yaml",
182+
output_path=output_file,
183+
input_file_type="openapi",
184+
assert_func=assert_file_content,
185+
expected_file=EXPECTED_OPENAPI_PATH / "discriminator" / "enum_single_value_anyof.py",
186+
extra_args=["--target-python-version", "3.10", "--output-model-type", "pydantic_v2.BaseModel"],
187+
)
188+
189+
190+
@pytest.mark.skipif(
191+
black.__version__.split(".")[0] == "19",
192+
reason="Installed black doesn't support the old style",
193+
)
194+
def test_main_openapi_discriminator_enum_single_value_anyof_use_enum(output_file: Path) -> None:
195+
"""Single-value enum with anyOf + --use-enum-values-in-discriminator."""
196+
run_main_and_assert(
197+
input_path=OPEN_API_DATA_PATH / "discriminator_enum_single_value_anyof.yaml",
198+
output_path=output_file,
199+
input_file_type="openapi",
200+
assert_func=assert_file_content,
201+
expected_file=EXPECTED_OPENAPI_PATH / "discriminator" / "enum_single_value_anyof_use_enum.py",
202+
extra_args=[
203+
"--target-python-version",
204+
"3.10",
205+
"--output-model-type",
206+
"pydantic_v2.BaseModel",
207+
"--use-enum-values-in-discriminator",
208+
],
209+
)
210+
211+
136212
def test_main_openapi_discriminator_with_properties(output_file: Path) -> None:
137213
"""Test OpenAPI generation with discriminator properties."""
138214
run_main_and_assert(

0 commit comments

Comments
 (0)