Skip to content

Commit 94b081c

Browse files
authored
Use Enum Members in Discriminators with Safe Imports (#2609)
* Feat: Add support for using enum member literals in discriminator fields * Feat: Implement enum member imports for discriminator literals and add tests * Refactor: Remove deprecated test for enum values in discriminator literals * Refactor: Clean up code with pragma comments for better coverage tracking * Refactor: Add pragma comments to improve coverage tracking in base.py * Refactor: Simplify discriminator data type creation and enhance enum source resolution
1 parent 3930969 commit 94b081c

14 files changed

Lines changed: 249 additions & 3 deletions

File tree

README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,9 @@ Typing customization:
381381
times.
382382
--use-annotated Use typing.Annotated for Field(). Also, `--field-constraints` option
383383
will be enabled.
384+
--use-enum-values-in-discriminator
385+
Use enum member literals in discriminator fields instead of string
386+
literals
384387
--use-generic-container-types
385388
Use generic container types for type hinting (typing.Sequence,
386389
typing.Mapping). If `--use-standard-collections` option is set, then

docs/index.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,9 @@ Typing customization:
373373
times.
374374
--use-annotated Use typing.Annotated for Field(). Also, `--field-constraints` option
375375
will be enabled.
376+
--use-enum-values-in-discriminator
377+
Use enum member literals in discriminator fields instead of string
378+
literals
376379
--use-generic-container-types
377380
Use generic container types for type hinting (typing.Sequence,
378381
typing.Mapping). If `--use-standard-collections` option is set, then

src/datamodel_code_generator/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,7 @@ def generate( # noqa: PLR0912, PLR0913, PLR0914, PLR0915
394394
encoding: str = "utf-8",
395395
enum_field_as_literal: LiteralType | None = None,
396396
use_one_literal_as_default: bool = False,
397+
use_enum_values_in_discriminator: bool = False,
397398
set_default_enum_member: bool = False,
398399
use_subclass_enum: bool = False,
399400
use_specialized_enum: bool = True,
@@ -631,6 +632,7 @@ def get_header_and_first_line(csv_file: IO[str]) -> dict[str, Any]:
631632
if output_model_type == DataModelType.TypingTypedDict
632633
else enum_field_as_literal,
633634
use_one_literal_as_default=use_one_literal_as_default,
635+
use_enum_values_in_discriminator=use_enum_values_in_discriminator,
634636
set_default_enum_member=True
635637
if output_model_type == DataModelType.DataclassesDataclass
636638
else set_default_enum_member,

src/datamodel_code_generator/__main__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,7 @@ def validate_all_exports_collision_strategy(cls, values: dict[str, Any]) -> dict
398398
encoding: str = DEFAULT_ENCODING
399399
enum_field_as_literal: Optional[LiteralType] = None # noqa: UP045
400400
use_one_literal_as_default: bool = False
401+
use_enum_values_in_discriminator: bool = False
401402
set_default_enum_member: bool = False
402403
use_subclass_enum: bool = False
403404
use_specialized_enum: bool = True
@@ -803,6 +804,7 @@ def main(args: Sequence[str] | None = None) -> Exit: # noqa: PLR0911, PLR0912,
803804
encoding=config.encoding,
804805
enum_field_as_literal=config.enum_field_as_literal,
805806
use_one_literal_as_default=config.use_one_literal_as_default,
807+
use_enum_values_in_discriminator=config.use_enum_values_in_discriminator,
806808
set_default_enum_member=config.set_default_enum_member,
807809
use_subclass_enum=config.use_subclass_enum,
808810
use_specialized_enum=config.use_specialized_enum,

src/datamodel_code_generator/arguments.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,12 @@ def start_section(self, heading: str | None) -> None:
373373
action="store_true",
374374
default=None,
375375
)
376+
typing_options.add_argument(
377+
"--use-enum-values-in-discriminator",
378+
help="Use enum member literals in discriminator fields instead of string literals",
379+
action="store_true",
380+
default=None,
381+
)
376382
typing_options.add_argument(
377383
"--use-standard-collections",
378384
help="Use standard collections for type hinting (list, dict)",

src/datamodel_code_generator/parser/base.py

Lines changed: 66 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -498,6 +498,7 @@ def __init__( # noqa: PLR0913, PLR0915
498498
capitalise_enum_members: bool = False,
499499
keep_model_order: bool = False,
500500
use_one_literal_as_default: bool = False,
501+
use_enum_values_in_discriminator: bool = False,
501502
known_third_party: list[str] | None = None,
502503
custom_formatters: list[str] | None = None,
503504
custom_formatters_kwargs: dict[str, Any] | None = None,
@@ -636,6 +637,7 @@ def __init__( # noqa: PLR0913, PLR0915
636637
self.capitalise_enum_members = capitalise_enum_members
637638
self.keep_model_order = keep_model_order
638639
self.use_one_literal_as_default = use_one_literal_as_default
640+
self.use_enum_values_in_discriminator = use_enum_values_in_discriminator
639641
self.known_third_party = known_third_party
640642
self.custom_formatter = custom_formatters
641643
self.custom_formatters_kwargs = custom_formatters_kwargs
@@ -912,6 +914,30 @@ def __extract_inherited_enum(cls, models: list[DataModel]) -> None:
912914
)
913915
models.remove(model)
914916

917+
def _create_discriminator_data_type(
918+
self,
919+
enum_source: Enum | None,
920+
type_names: list[str],
921+
discriminator_model: DataModel,
922+
imports: Imports,
923+
) -> DataType:
924+
"""Create a data type for discriminator field, using enum literals if available."""
925+
if enum_source:
926+
enum_class_name = enum_source.reference.short_name
927+
enum_member_literals: list[tuple[str, str]] = []
928+
for value in type_names:
929+
member = enum_source.find_member(value)
930+
if member and member.field.name:
931+
enum_member_literals.append((enum_class_name, member.field.name))
932+
else: # pragma: no cover
933+
enum_member_literals.append((enum_class_name, value))
934+
data_type = self.data_type(enum_member_literals=enum_member_literals)
935+
if enum_source.module_path != discriminator_model.module_path: # pragma: no cover
936+
imports.append(Import.from_full_path(enum_source.name))
937+
else:
938+
data_type = self.data_type(literals=type_names)
939+
return data_type
940+
915941
def __apply_discriminator_type( # noqa: PLR0912, PLR0915
916942
self,
917943
models: list[DataModel],
@@ -988,6 +1014,31 @@ def check_paths(
9881014
msg = f"Discriminator type is not found. {data_type.reference.path}"
9891015
raise RuntimeError(msg)
9901016

1017+
enum_from_base: Enum | None = None
1018+
if self.use_enum_values_in_discriminator:
1019+
for base_class in discriminator_model.base_classes:
1020+
if not base_class.reference or not base_class.reference.source: # pragma: no cover
1021+
continue
1022+
base_model = base_class.reference.source
1023+
if not isinstance( # pragma: no cover
1024+
base_model,
1025+
(
1026+
pydantic_model.BaseModel,
1027+
pydantic_model_v2.BaseModel,
1028+
dataclass_model.DataClass,
1029+
msgspec_model.Struct,
1030+
),
1031+
):
1032+
continue
1033+
for base_field in base_model.fields: # pragma: no branch
1034+
if field_name not in {base_field.original_name, base_field.name}: # pragma: no cover
1035+
continue
1036+
enum_from_base = base_field.data_type.find_source(Enum)
1037+
if enum_from_base: # pragma: no branch
1038+
break
1039+
if enum_from_base: # pragma: no branch
1040+
break
1041+
9911042
has_one_literal = False
9921043
for discriminator_field in discriminator_model.fields:
9931044
if field_name not in {discriminator_field.original_name, discriminator_field.name}:
@@ -1001,19 +1052,32 @@ def check_paths(
10011052
discriminator_field.extras["is_classvar"] = True
10021053
# Found the discriminator field, no need to keep looking
10031054
break
1055+
1056+
enum_source: Enum | None = None
1057+
if self.use_enum_values_in_discriminator:
1058+
enum_source = ( # pragma: no cover
1059+
discriminator_field.data_type.find_source(Enum) or enum_from_base
1060+
)
1061+
10041062
for field_data_type in discriminator_field.data_type.all_data_types:
10051063
if field_data_type.reference: # pragma: no cover
10061064
field_data_type.remove_reference()
1007-
discriminator_field.data_type = self.data_type(literals=type_names)
1065+
1066+
discriminator_field.data_type = self._create_discriminator_data_type(
1067+
enum_source, type_names, discriminator_model, imports
1068+
)
10081069
discriminator_field.data_type.parent = discriminator_field
10091070
discriminator_field.required = True
10101071
imports.append(discriminator_field.imports)
10111072
has_one_literal = True
10121073
if not has_one_literal:
1074+
new_data_type = self._create_discriminator_data_type(
1075+
enum_from_base, type_names, discriminator_model, imports
1076+
)
10131077
discriminator_model.fields.append(
10141078
self.data_model_field_type(
10151079
name=field_name,
1016-
data_type=self.data_type(literals=type_names),
1080+
data_type=new_data_type,
10171081
required=True,
10181082
alias=alias,
10191083
)

src/datamodel_code_generator/parser/graphql.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@ def __init__( # noqa: PLR0913
167167
capitalise_enum_members: bool = False,
168168
keep_model_order: bool = False,
169169
use_one_literal_as_default: bool = False,
170+
use_enum_values_in_discriminator: bool = False,
170171
known_third_party: list[str] | None = None,
171172
custom_formatters: list[str] | None = None,
172173
custom_formatters_kwargs: dict[str, Any] | None = None,
@@ -222,6 +223,7 @@ def __init__( # noqa: PLR0913
222223
encoding=encoding,
223224
enum_field_as_literal=enum_field_as_literal,
224225
use_one_literal_as_default=use_one_literal_as_default,
226+
use_enum_values_in_discriminator=use_enum_values_in_discriminator,
225227
set_default_enum_member=set_default_enum_member,
226228
use_subclass_enum=use_subclass_enum,
227229
use_specialized_enum=use_specialized_enum,

src/datamodel_code_generator/parser/jsonschema.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -496,6 +496,7 @@ def __init__( # noqa: PLR0913
496496
encoding: str = "utf-8",
497497
enum_field_as_literal: LiteralType | None = None,
498498
use_one_literal_as_default: bool = False,
499+
use_enum_values_in_discriminator: bool = False,
499500
set_default_enum_member: bool = False,
500501
use_subclass_enum: bool = False,
501502
use_specialized_enum: bool = True,
@@ -585,6 +586,7 @@ def __init__( # noqa: PLR0913
585586
encoding=encoding,
586587
enum_field_as_literal=enum_field_as_literal,
587588
use_one_literal_as_default=use_one_literal_as_default,
589+
use_enum_values_in_discriminator=use_enum_values_in_discriminator,
588590
set_default_enum_member=set_default_enum_member,
589591
use_subclass_enum=use_subclass_enum,
590592
use_specialized_enum=use_specialized_enum,

src/datamodel_code_generator/parser/openapi.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,7 @@ def __init__( # noqa: PLR0913
214214
encoding: str = "utf-8",
215215
enum_field_as_literal: LiteralType | None = None,
216216
use_one_literal_as_default: bool = False,
217+
use_enum_values_in_discriminator: bool = False,
217218
set_default_enum_member: bool = False,
218219
use_subclass_enum: bool = False,
219220
use_specialized_enum: bool = True,
@@ -305,6 +306,7 @@ def __init__( # noqa: PLR0913
305306
encoding=encoding,
306307
enum_field_as_literal=enum_field_as_literal,
307308
use_one_literal_as_default=use_one_literal_as_default,
309+
use_enum_values_in_discriminator=use_enum_values_in_discriminator,
308310
set_default_enum_member=set_default_enum_member,
309311
use_subclass_enum=use_subclass_enum,
310312
use_specialized_enum=use_specialized_enum,

src/datamodel_code_generator/types.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
from datamodel_code_generator.util import PYDANTIC_V2, ConfigDict
5555

5656
T = TypeVar("T")
57+
SourceT = TypeVar("SourceT")
5758

5859
OPTIONAL = "Optional"
5960
OPTIONAL_PREFIX = f"{OPTIONAL}["
@@ -321,6 +322,7 @@ class Config:
321322
is_set: bool = False
322323
is_custom_type: bool = False
323324
literals: list[Union[StrictBool, StrictInt, StrictStr]] = [] # noqa: RUF012, UP007
325+
enum_member_literals: list[tuple[str, str]] = [] # noqa: RUF012 # [(EnumClassName, member_name), ...]
324326
use_standard_collections: bool = False
325327
use_generic_container: bool = False
326328
use_union_operator: bool = False
@@ -406,6 +408,16 @@ def all_data_types(self) -> Iterator[DataType]:
406408
yield from data_type.all_data_types
407409
yield self
408410

411+
def find_source(self, source_type: type[SourceT]) -> SourceT | None:
412+
"""Find the first reference source matching the given type from all nested data types."""
413+
for data_type in self.all_data_types: # pragma: no branch
414+
if not data_type.reference: # pragma: no cover
415+
continue
416+
source = data_type.reference.source
417+
if isinstance(source, source_type): # pragma: no cover
418+
return source
419+
return None # pragma: no cover
420+
409421
@property
410422
def all_imports(self) -> Iterator[Import]:
411423
"""Recursively yield all imports from nested DataTypes and self."""
@@ -424,7 +436,7 @@ def imports(self) -> Iterator[Import]:
424436
imports: tuple[tuple[bool, Import], ...] = (
425437
(self.is_optional and not self.use_union_operator, IMPORT_OPTIONAL),
426438
(len(self.data_types) > 1 and not self.use_union_operator, IMPORT_UNION),
427-
(bool(self.literals), IMPORT_LITERAL),
439+
(bool(self.literals) or bool(self.enum_member_literals), IMPORT_LITERAL),
428440
)
429441

430442
if self.use_generic_container:
@@ -513,6 +525,9 @@ def type_hint(self) -> str: # noqa: PLR0912, PLR0915
513525
type_ = f"{UNION_PREFIX}{UNION_DELIMITER.join(data_types)}]"
514526
elif len(self.data_types) == 1:
515527
type_ = self.data_types[0].type_hint
528+
elif self.enum_member_literals:
529+
parts = [f"{enum_class}.{member}" for enum_class, member in self.enum_member_literals]
530+
type_ = f"{LITERAL}[{', '.join(parts)}]"
516531
elif self.literals:
517532
type_ = f"{LITERAL}[{', '.join(repr(literal) for literal in self.literals)}]"
518533
elif self.reference:

0 commit comments

Comments
 (0)