@@ -498,6 +498,7 @@ def __init__( # noqa: PLR0913, PLR0915
498498 capitalise_enum_members : bool = False ,
499499 keep_model_order : bool = False ,
500500 use_one_literal_as_default : bool = False ,
501+ use_enum_values_in_discriminator : bool = False ,
501502 known_third_party : list [str ] | None = None ,
502503 custom_formatters : list [str ] | None = None ,
503504 custom_formatters_kwargs : dict [str , Any ] | None = None ,
@@ -636,6 +637,7 @@ def __init__( # noqa: PLR0913, PLR0915
636637 self .capitalise_enum_members = capitalise_enum_members
637638 self .keep_model_order = keep_model_order
638639 self .use_one_literal_as_default = use_one_literal_as_default
640+ self .use_enum_values_in_discriminator = use_enum_values_in_discriminator
639641 self .known_third_party = known_third_party
640642 self .custom_formatter = custom_formatters
641643 self .custom_formatters_kwargs = custom_formatters_kwargs
@@ -912,6 +914,30 @@ def __extract_inherited_enum(cls, models: list[DataModel]) -> None:
912914 )
913915 models .remove (model )
914916
917+ def _create_discriminator_data_type (
918+ self ,
919+ enum_source : Enum | None ,
920+ type_names : list [str ],
921+ discriminator_model : DataModel ,
922+ imports : Imports ,
923+ ) -> DataType :
924+ """Create a data type for discriminator field, using enum literals if available."""
925+ if enum_source :
926+ enum_class_name = enum_source .reference .short_name
927+ enum_member_literals : list [tuple [str , str ]] = []
928+ for value in type_names :
929+ member = enum_source .find_member (value )
930+ if member and member .field .name :
931+ enum_member_literals .append ((enum_class_name , member .field .name ))
932+ else : # pragma: no cover
933+ enum_member_literals .append ((enum_class_name , value ))
934+ data_type = self .data_type (enum_member_literals = enum_member_literals )
935+ if enum_source .module_path != discriminator_model .module_path : # pragma: no cover
936+ imports .append (Import .from_full_path (enum_source .name ))
937+ else :
938+ data_type = self .data_type (literals = type_names )
939+ return data_type
940+
915941 def __apply_discriminator_type ( # noqa: PLR0912, PLR0915
916942 self ,
917943 models : list [DataModel ],
@@ -988,6 +1014,31 @@ def check_paths(
9881014 msg = f"Discriminator type is not found. { data_type .reference .path } "
9891015 raise RuntimeError (msg )
9901016
1017+ enum_from_base : Enum | None = None
1018+ if self .use_enum_values_in_discriminator :
1019+ for base_class in discriminator_model .base_classes :
1020+ if not base_class .reference or not base_class .reference .source : # pragma: no cover
1021+ continue
1022+ base_model = base_class .reference .source
1023+ if not isinstance ( # pragma: no cover
1024+ base_model ,
1025+ (
1026+ pydantic_model .BaseModel ,
1027+ pydantic_model_v2 .BaseModel ,
1028+ dataclass_model .DataClass ,
1029+ msgspec_model .Struct ,
1030+ ),
1031+ ):
1032+ continue
1033+ for base_field in base_model .fields : # pragma: no branch
1034+ if field_name not in {base_field .original_name , base_field .name }: # pragma: no cover
1035+ continue
1036+ enum_from_base = base_field .data_type .find_source (Enum )
1037+ if enum_from_base : # pragma: no branch
1038+ break
1039+ if enum_from_base : # pragma: no branch
1040+ break
1041+
9911042 has_one_literal = False
9921043 for discriminator_field in discriminator_model .fields :
9931044 if field_name not in {discriminator_field .original_name , discriminator_field .name }:
@@ -1001,19 +1052,32 @@ def check_paths(
10011052 discriminator_field .extras ["is_classvar" ] = True
10021053 # Found the discriminator field, no need to keep looking
10031054 break
1055+
1056+ enum_source : Enum | None = None
1057+ if self .use_enum_values_in_discriminator :
1058+ enum_source = ( # pragma: no cover
1059+ discriminator_field .data_type .find_source (Enum ) or enum_from_base
1060+ )
1061+
10041062 for field_data_type in discriminator_field .data_type .all_data_types :
10051063 if field_data_type .reference : # pragma: no cover
10061064 field_data_type .remove_reference ()
1007- discriminator_field .data_type = self .data_type (literals = type_names )
1065+
1066+ discriminator_field .data_type = self ._create_discriminator_data_type (
1067+ enum_source , type_names , discriminator_model , imports
1068+ )
10081069 discriminator_field .data_type .parent = discriminator_field
10091070 discriminator_field .required = True
10101071 imports .append (discriminator_field .imports )
10111072 has_one_literal = True
10121073 if not has_one_literal :
1074+ new_data_type = self ._create_discriminator_data_type (
1075+ enum_from_base , type_names , discriminator_model , imports
1076+ )
10131077 discriminator_model .fields .append (
10141078 self .data_model_field_type (
10151079 name = field_name ,
1016- data_type = self . data_type ( literals = type_names ) ,
1080+ data_type = new_data_type ,
10171081 required = True ,
10181082 alias = alias ,
10191083 )
0 commit comments