Skip to content

Commit 5a8cd0d

Browse files
authored
Fix non-string OpenAPI discriminator literals (#3070)
* Fix non-string OpenAPI discriminator literals * Tighten discriminator value types
1 parent 5dcbc09 commit 5a8cd0d

9 files changed

Lines changed: 334 additions & 43 deletions

File tree

src/datamodel_code_generator/parser/base.py

Lines changed: 43 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ def __ge__(self, value: Any, /) -> bool: ... # noqa: D105
110110
ModelNames: TypeAlias = set[ModelName]
111111
ModelDeps: TypeAlias = dict[ModelName, set[ModelName]]
112112
OrderIndex: TypeAlias = dict[ModelName, int]
113+
DiscriminatorValue: TypeAlias = str | int | bool
113114

114115
_BUILTIN_NAMES: frozenset[str] = frozenset(name for name in builtins.__dict__ if not name.startswith("_"))
115116
_BUILTIN_NAMES_INTRODUCED_IN: dict[PythonVersion, frozenset[str]] = {
@@ -1514,25 +1515,25 @@ def __extract_inherited_enum(cls, models: list[DataModel]) -> None:
15141515
def _create_discriminator_data_type(
15151516
self,
15161517
enum_source: Enum | None,
1517-
type_names: list[str],
1518+
discriminator_values: list[DiscriminatorValue],
15181519
discriminator_model: DataModel,
15191520
imports: Imports,
15201521
) -> DataType:
15211522
"""Create a data type for discriminator field, using enum literals if available."""
15221523
if enum_source:
15231524
enum_class_name = enum_source.reference.short_name
15241525
enum_member_literals: list[tuple[str, str]] = []
1525-
for value in type_names:
1526+
for value in discriminator_values:
15261527
member = enum_source.find_member(value)
15271528
if member and member.field.name:
15281529
enum_member_literals.append((enum_class_name, member.field.name))
15291530
else: # pragma: no cover
1530-
enum_member_literals.append((enum_class_name, value))
1531+
enum_member_literals.append((enum_class_name, str(value)))
15311532
data_type = self.data_type(enum_member_literals=enum_member_literals)
15321533
if enum_source.module_path != discriminator_model.module_path: # pragma: no cover
15331534
imports.append(Import.from_full_path(enum_source.name))
15341535
else:
1535-
data_type = self.data_type(literals=type_names)
1536+
data_type = self.data_type(literals=discriminator_values)
15361537
return data_type
15371538

15381539
def __apply_discriminator_type( # noqa: PLR0912, PLR0914, PLR0915
@@ -1572,12 +1573,12 @@ def __apply_discriminator_type( # noqa: PLR0912, PLR0914, PLR0915
15721573
): # pragma: no cover
15731574
continue
15741575

1575-
type_names: list[str] = []
1576+
discriminator_values: list[DiscriminatorValue] = []
15761577

15771578
def check_paths(
15781579
model: pydantic_model_v2.BaseModel | Reference,
15791580
mapping: dict[str, str],
1580-
type_names: list[str] = type_names,
1581+
discriminator_values: list[DiscriminatorValue] = discriminator_values,
15811582
) -> None:
15821583
"""Validate discriminator mapping paths for a model."""
15831584
for name, path in mapping.items():
@@ -1589,50 +1590,49 @@ def check_paths(
15891590
t_disc_2 = "/".join(t_disc.split("/")[1:])
15901591
if t_path not in {t_disc, t_disc_2}: # pragma: no branch
15911592
continue
1592-
type_names.append(name)
1593+
discriminator_values.append(name)
1594+
1595+
def get_discriminator_field_value(
1596+
discriminator_field: DataModelFieldBase,
1597+
) -> DiscriminatorValue | None:
1598+
const_value = discriminator_field.extras.get("const")
1599+
if const_value is not None:
1600+
return const_value
1601+
1602+
literals = discriminator_field.data_type.literals
1603+
if len(literals) == 1:
1604+
return literals[0]
1605+
1606+
enum_source = discriminator_field.data_type.find_source(Enum)
1607+
if enum_source and len(enum_source.fields) == 1:
1608+
raw_default = enum_source.fields[0].default
1609+
if isinstance(raw_default, str):
1610+
return raw_default.strip("'\"")
1611+
return raw_default
1612+
return None
15931613

1594-
# First try to get the discriminator value from the const field
15951614
for discriminator_field in discriminator_model.fields:
15961615
if field_name not in {discriminator_field.original_name, discriminator_field.name}:
15971616
continue
1598-
if discriminator_field.extras.get("const"):
1599-
type_names = [discriminator_field.extras["const"]]
1617+
discriminator_value = get_discriminator_field_value(discriminator_field)
1618+
if discriminator_value is not None:
1619+
discriminator_values = [discriminator_value]
16001620
break
16011621

1602-
# If no const value found, try to get it from the mapping
1603-
if not type_names:
1604-
# Check the main discriminator model path
1605-
if mapping:
1606-
check_paths(discriminator_model, mapping) # ty: ignore
1622+
if not discriminator_values and mapping:
1623+
check_paths(discriminator_model, mapping) # ty: ignore
16071624

1608-
# Check the base_classes if they exist
1609-
if len(type_names) == 0:
1610-
for base_class in discriminator_model.base_classes:
1611-
check_paths(base_class.reference, mapping) # ty: ignore
1612-
else:
1613-
for discriminator_field in discriminator_model.fields:
1614-
if field_name not in {discriminator_field.original_name, discriminator_field.name}:
1615-
continue
1625+
if len(discriminator_values) == 0:
1626+
for base_class in discriminator_model.base_classes:
1627+
check_paths(base_class.reference, mapping) # ty: ignore
16161628

1617-
literals = discriminator_field.data_type.literals
1618-
if literals and len(literals) == 1: # pragma: no cover
1619-
type_names = [str(v) for v in literals]
1620-
break
1621-
1622-
enum_source = discriminator_field.data_type.find_source(Enum)
1623-
if enum_source and len(enum_source.fields) == 1:
1624-
first_field = enum_source.fields[0]
1625-
raw_default = first_field.default
1626-
if isinstance(raw_default, str):
1627-
type_names = [raw_default.strip("'\"")]
1628-
else: # pragma: no cover
1629-
type_names = [str(raw_default)]
1630-
break
1629+
if not discriminator_values:
1630+
discriminator_values = [discriminator_model.path.split("/")[-1]]
16311631

1632-
if not type_names:
1633-
type_names = [discriminator_model.path.split("/")[-1]]
1632+
if not discriminator_values:
1633+
discriminator_values = [discriminator_model.path.split("/")[-1]]
16341634

1635-
if not type_names: # pragma: no cover
1635+
if not discriminator_values: # pragma: no cover
16361636
msg = f"Discriminator type is not found. {data_type.reference.path}"
16371637
raise RuntimeError(msg)
16381638

@@ -1666,7 +1666,7 @@ def check_paths(
16661666
continue
16671667
literals = discriminator_field.data_type.literals
16681668
const_value = discriminator_field.extras.get("const")
1669-
expected_value = type_names[0] if type_names else None
1669+
expected_value = discriminator_values[0] if discriminator_values else None
16701670

16711671
# Check if literals match (existing behavior)
16721672
literals_match = len(literals) == 1 and literals[0] == expected_value
@@ -1701,15 +1701,15 @@ def check_paths(
17011701
field_data_type.remove_reference()
17021702

17031703
discriminator_field.data_type = self._create_discriminator_data_type(
1704-
enum_source, type_names, discriminator_model, imports
1704+
enum_source, discriminator_values, discriminator_model, imports
17051705
)
17061706
discriminator_field.data_type.parent = discriminator_field
17071707
discriminator_field.required = True
17081708
imports.append(discriminator_field.imports)
17091709
has_one_literal = True
17101710
if not has_one_literal:
17111711
new_data_type = self._create_discriminator_data_type(
1712-
enum_from_base, type_names, discriminator_model, imports
1712+
enum_from_base, discriminator_values, discriminator_model, imports
17131713
)
17141714
# Handle multiple aliases (Pydantic v2 AliasChoices)
17151715
single_alias: str | None = None
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# generated by datamodel-codegen:
2+
# filename: discriminator_integer_mapping.yaml
3+
# timestamp: 2019-07-26T00:00:00+00:00
4+
5+
from __future__ import annotations
6+
7+
from enum import IntEnum
8+
from typing import Literal
9+
10+
from pydantic import BaseModel, Field, RootModel
11+
12+
13+
class Kind(IntEnum):
14+
integer_1 = 1
15+
16+
17+
class Foo(BaseModel):
18+
kind: Literal[1]
19+
20+
21+
class Kind1(IntEnum):
22+
integer_2 = 2
23+
24+
25+
class Bar(BaseModel):
26+
kind: Literal[2]
27+
28+
29+
class Base(RootModel[Foo | Bar]):
30+
root: Foo | Bar = Field(..., discriminator='kind')
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# generated by datamodel-codegen:
2+
# filename: discriminator_integer_no_mapping.yaml
3+
# timestamp: 2019-07-26T00:00:00+00:00
4+
5+
from __future__ import annotations
6+
7+
from enum import IntEnum
8+
from typing import Literal
9+
10+
from pydantic import BaseModel, Field, RootModel
11+
12+
13+
class Kind(IntEnum):
14+
integer_1 = 1
15+
16+
17+
class Foo(BaseModel):
18+
kind: Literal[1]
19+
20+
21+
class Kind1(IntEnum):
22+
integer_2 = 2
23+
24+
25+
class Bar(BaseModel):
26+
kind: Literal[2]
27+
28+
29+
class Base(RootModel[Foo | Bar]):
30+
root: Foo | Bar = Field(..., discriminator='kind')
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# generated by datamodel-codegen:
2+
# filename: discriminator_integer_no_mapping.yaml
3+
# timestamp: 2019-07-26T00:00:00+00:00
4+
5+
from __future__ import annotations
6+
7+
from typing import Literal
8+
9+
from pydantic import BaseModel, Field, RootModel
10+
11+
12+
class Foo(BaseModel):
13+
kind: Literal[1]
14+
15+
16+
class Bar(BaseModel):
17+
kind: Literal[2]
18+
19+
20+
class Base(RootModel[Foo | Bar]):
21+
root: Foo | Bar = Field(..., discriminator='kind')
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# generated by datamodel-codegen:
2+
# filename: discriminator_partial_mapping.yaml
3+
# timestamp: 2019-07-26T00:00:00+00:00
4+
5+
from __future__ import annotations
6+
7+
from typing import Literal
8+
9+
from pydantic import BaseModel, Field
10+
11+
12+
class BaseItem(BaseModel):
13+
itemType: str
14+
15+
16+
class FooItem(BaseItem):
17+
fooValue: str | None = None
18+
itemType: Literal['foo']
19+
20+
21+
class BarItem(BaseItem):
22+
barValue: int | None = None
23+
itemType: Literal['BarItem']
24+
25+
26+
class ItemContainer(BaseModel):
27+
item: FooItem | BarItem = Field(..., discriminator='itemType')
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
openapi: 3.0.3
2+
info:
3+
title: Minimal Integer Discriminator
4+
version: 1.0.0
5+
paths: {}
6+
components:
7+
schemas:
8+
Base:
9+
oneOf:
10+
- $ref: '#/components/schemas/Foo'
11+
- $ref: '#/components/schemas/Bar'
12+
discriminator:
13+
propertyName: kind
14+
mapping:
15+
'1': '#/components/schemas/Foo'
16+
'2': '#/components/schemas/Bar'
17+
Foo:
18+
type: object
19+
properties:
20+
kind:
21+
type: integer
22+
enum:
23+
- 1
24+
required:
25+
- kind
26+
Bar:
27+
type: object
28+
properties:
29+
kind:
30+
type: integer
31+
enum:
32+
- 2
33+
required:
34+
- kind
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
openapi: 3.0.3
2+
info:
3+
title: Minimal Integer Discriminator
4+
version: 1.0.0
5+
paths: {}
6+
components:
7+
schemas:
8+
Base:
9+
oneOf:
10+
- $ref: '#/components/schemas/Foo'
11+
- $ref: '#/components/schemas/Bar'
12+
discriminator:
13+
propertyName: kind
14+
Foo:
15+
type: object
16+
properties:
17+
kind:
18+
type: integer
19+
enum:
20+
- 1
21+
required:
22+
- kind
23+
Bar:
24+
type: object
25+
properties:
26+
kind:
27+
type: integer
28+
enum:
29+
- 2
30+
required:
31+
- kind
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
openapi: 3.1.0
2+
info:
3+
title: Test
4+
version: 0.0.0
5+
paths:
6+
/item:
7+
get:
8+
responses:
9+
'200':
10+
description: Item
11+
content:
12+
application/json:
13+
schema:
14+
$ref: '#/components/schemas/ItemContainer'
15+
components:
16+
schemas:
17+
ItemContainer:
18+
type: object
19+
required:
20+
- item
21+
properties:
22+
item:
23+
$ref: '#/components/schemas/BaseItem'
24+
BaseItem:
25+
type: object
26+
required:
27+
- itemType
28+
properties:
29+
itemType:
30+
type: string
31+
discriminator:
32+
propertyName: itemType
33+
mapping:
34+
foo: '#/components/schemas/FooItem'
35+
FooItem:
36+
allOf:
37+
- $ref: '#/components/schemas/BaseItem'
38+
- type: object
39+
properties:
40+
fooValue:
41+
type: string
42+
BarItem:
43+
allOf:
44+
- $ref: '#/components/schemas/BaseItem'
45+
- type: object
46+
properties:
47+
barValue:
48+
type: integer

0 commit comments

Comments
 (0)