@@ -110,6 +110,7 @@ def __ge__(self, value: Any, /) -> bool: ... # noqa: D105
110110ModelNames : TypeAlias = set [ModelName ]
111111ModelDeps : TypeAlias = dict [ModelName , set [ModelName ]]
112112OrderIndex : TypeAlias = dict [ModelName , int ]
113+ DiscriminatorValue : TypeAlias = str | int | bool
113114
114115_BUILTIN_NAMES : frozenset [str ] = frozenset (name for name in builtins .__dict__ if not name .startswith ("_" ))
115116_BUILTIN_NAMES_INTRODUCED_IN : dict [PythonVersion , frozenset [str ]] = {
@@ -1514,25 +1515,25 @@ def __extract_inherited_enum(cls, models: list[DataModel]) -> None:
15141515 def _create_discriminator_data_type (
15151516 self ,
15161517 enum_source : Enum | None ,
1517- type_names : list [str ],
1518+ discriminator_values : list [DiscriminatorValue ],
15181519 discriminator_model : DataModel ,
15191520 imports : Imports ,
15201521 ) -> DataType :
15211522 """Create a data type for discriminator field, using enum literals if available."""
15221523 if enum_source :
15231524 enum_class_name = enum_source .reference .short_name
15241525 enum_member_literals : list [tuple [str , str ]] = []
1525- for value in type_names :
1526+ for value in discriminator_values :
15261527 member = enum_source .find_member (value )
15271528 if member and member .field .name :
15281529 enum_member_literals .append ((enum_class_name , member .field .name ))
15291530 else : # pragma: no cover
1530- enum_member_literals .append ((enum_class_name , value ))
1531+ enum_member_literals .append ((enum_class_name , str ( value ) ))
15311532 data_type = self .data_type (enum_member_literals = enum_member_literals )
15321533 if enum_source .module_path != discriminator_model .module_path : # pragma: no cover
15331534 imports .append (Import .from_full_path (enum_source .name ))
15341535 else :
1535- data_type = self .data_type (literals = type_names )
1536+ data_type = self .data_type (literals = discriminator_values )
15361537 return data_type
15371538
15381539 def __apply_discriminator_type ( # noqa: PLR0912, PLR0914, PLR0915
@@ -1572,12 +1573,12 @@ def __apply_discriminator_type( # noqa: PLR0912, PLR0914, PLR0915
15721573 ): # pragma: no cover
15731574 continue
15741575
1575- type_names : list [str ] = []
1576+ discriminator_values : list [DiscriminatorValue ] = []
15761577
15771578 def check_paths (
15781579 model : pydantic_model_v2 .BaseModel | Reference ,
15791580 mapping : dict [str , str ],
1580- type_names : list [str ] = type_names ,
1581+ discriminator_values : list [DiscriminatorValue ] = discriminator_values ,
15811582 ) -> None :
15821583 """Validate discriminator mapping paths for a model."""
15831584 for name , path in mapping .items ():
@@ -1589,50 +1590,49 @@ def check_paths(
15891590 t_disc_2 = "/" .join (t_disc .split ("/" )[1 :])
15901591 if t_path not in {t_disc , t_disc_2 }: # pragma: no branch
15911592 continue
1592- type_names .append (name )
1593+ discriminator_values .append (name )
1594+
1595+ def get_discriminator_field_value (
1596+ discriminator_field : DataModelFieldBase ,
1597+ ) -> DiscriminatorValue | None :
1598+ const_value = discriminator_field .extras .get ("const" )
1599+ if const_value is not None :
1600+ return const_value
1601+
1602+ literals = discriminator_field .data_type .literals
1603+ if len (literals ) == 1 :
1604+ return literals [0 ]
1605+
1606+ enum_source = discriminator_field .data_type .find_source (Enum )
1607+ if enum_source and len (enum_source .fields ) == 1 :
1608+ raw_default = enum_source .fields [0 ].default
1609+ if isinstance (raw_default , str ):
1610+ return raw_default .strip ("'\" " )
1611+ return raw_default
1612+ return None
15931613
1594- # First try to get the discriminator value from the const field
15951614 for discriminator_field in discriminator_model .fields :
15961615 if field_name not in {discriminator_field .original_name , discriminator_field .name }:
15971616 continue
1598- if discriminator_field .extras .get ("const" ):
1599- type_names = [discriminator_field .extras ["const" ]]
1617+ discriminator_value = get_discriminator_field_value (discriminator_field )
1618+ if discriminator_value is not None :
1619+ discriminator_values = [discriminator_value ]
16001620 break
16011621
1602- # If no const value found, try to get it from the mapping
1603- if not type_names :
1604- # Check the main discriminator model path
1605- if mapping :
1606- check_paths (discriminator_model , mapping ) # ty: ignore
1622+ if not discriminator_values and mapping :
1623+ check_paths (discriminator_model , mapping ) # ty: ignore
16071624
1608- # Check the base_classes if they exist
1609- if len (type_names ) == 0 :
1610- for base_class in discriminator_model .base_classes :
1611- check_paths (base_class .reference , mapping ) # ty: ignore
1612- else :
1613- for discriminator_field in discriminator_model .fields :
1614- if field_name not in {discriminator_field .original_name , discriminator_field .name }:
1615- continue
1625+ if len (discriminator_values ) == 0 :
1626+ for base_class in discriminator_model .base_classes :
1627+ check_paths (base_class .reference , mapping ) # ty: ignore
16161628
1617- literals = discriminator_field .data_type .literals
1618- if literals and len (literals ) == 1 : # pragma: no cover
1619- type_names = [str (v ) for v in literals ]
1620- break
1621-
1622- enum_source = discriminator_field .data_type .find_source (Enum )
1623- if enum_source and len (enum_source .fields ) == 1 :
1624- first_field = enum_source .fields [0 ]
1625- raw_default = first_field .default
1626- if isinstance (raw_default , str ):
1627- type_names = [raw_default .strip ("'\" " )]
1628- else : # pragma: no cover
1629- type_names = [str (raw_default )]
1630- break
1629+ if not discriminator_values :
1630+ discriminator_values = [discriminator_model .path .split ("/" )[- 1 ]]
16311631
1632- if not type_names :
1633- type_names = [discriminator_model .path .split ("/" )[- 1 ]]
1632+ if not discriminator_values :
1633+ discriminator_values = [discriminator_model .path .split ("/" )[- 1 ]]
16341634
1635- if not type_names : # pragma: no cover
1635+ if not discriminator_values : # pragma: no cover
16361636 msg = f"Discriminator type is not found. { data_type .reference .path } "
16371637 raise RuntimeError (msg )
16381638
@@ -1666,7 +1666,7 @@ def check_paths(
16661666 continue
16671667 literals = discriminator_field .data_type .literals
16681668 const_value = discriminator_field .extras .get ("const" )
1669- expected_value = type_names [0 ] if type_names else None
1669+ expected_value = discriminator_values [0 ] if discriminator_values else None
16701670
16711671 # Check if literals match (existing behavior)
16721672 literals_match = len (literals ) == 1 and literals [0 ] == expected_value
@@ -1701,15 +1701,15 @@ def check_paths(
17011701 field_data_type .remove_reference ()
17021702
17031703 discriminator_field .data_type = self ._create_discriminator_data_type (
1704- enum_source , type_names , discriminator_model , imports
1704+ enum_source , discriminator_values , discriminator_model , imports
17051705 )
17061706 discriminator_field .data_type .parent = discriminator_field
17071707 discriminator_field .required = True
17081708 imports .append (discriminator_field .imports )
17091709 has_one_literal = True
17101710 if not has_one_literal :
17111711 new_data_type = self ._create_discriminator_data_type (
1712- enum_from_base , type_names , discriminator_model , imports
1712+ enum_from_base , discriminator_values , discriminator_model , imports
17131713 )
17141714 # Handle multiple aliases (Pydantic v2 AliasChoices)
17151715 single_alias : str | None = None
0 commit comments