@@ -344,6 +344,8 @@ def __init__( # noqa: PLR0913
344344 )
345345 self .open_api_scopes : list [OpenAPIScope ] = openapi_scopes or [OpenAPIScope .Schemas ]
346346 self .include_path_parameters : bool = include_path_parameters
347+ self ._discriminator_schemas : dict [str , dict [str , Any ]] = {}
348+ self ._discriminator_subtypes : dict [str , list [str ]] = defaultdict (list )
347349
348350 def get_ref_model (self , ref : str ) -> dict [str , Any ]:
349351 """Resolve a reference to its model definition."""
@@ -362,6 +364,49 @@ def get_data_type(self, obj: JsonSchemaObject) -> DataType:
362364
363365 return super ().get_data_type (obj )
364366
367+ def _get_discriminator_union_type (self , ref : str ) -> DataType | None :
368+ """Create a union type for discriminator subtypes if available."""
369+ subtypes = self ._discriminator_subtypes .get (ref , [])
370+ if not subtypes :
371+ return None
372+ refs = map (self .model_resolver .add_ref , subtypes )
373+ return self .data_type (data_types = [self .data_type (reference = r ) for r in refs ])
374+
375+ def get_ref_data_type (self , ref : str ) -> DataType :
376+ """Get data type for a reference, handling discriminator polymorphism."""
377+ if ref in self ._discriminator_schemas and (union_type := self ._get_discriminator_union_type (ref )):
378+ return union_type
379+ return super ().get_ref_data_type (ref )
380+
381+ def parse_object_fields (
382+ self ,
383+ obj : JsonSchemaObject ,
384+ path : list [str ],
385+ module_name : Optional [str ] = None , # noqa: UP045
386+ ) -> list [DataModelFieldBase ]:
387+ """Parse object fields, adding discriminator info for allOf polymorphism."""
388+ fields = super ().parse_object_fields (obj , path , module_name )
389+ properties = obj .properties or {}
390+
391+ result_fields : list [DataModelFieldBase ] = []
392+ for field_obj in fields :
393+ field = properties .get (field_obj .original_name )
394+
395+ if (
396+ isinstance (field , JsonSchemaObject )
397+ and field .ref
398+ and (discriminator := self ._discriminator_schemas .get (field .ref ))
399+ ):
400+ new_field_type = self ._get_discriminator_union_type (field .ref ) or field_obj .data_type
401+ field_obj = self .data_model_field_type (** { # noqa: PLW2901
402+ ** field_obj .__dict__ ,
403+ "data_type" : new_field_type ,
404+ "extras" : {** field_obj .extras , "discriminator" : discriminator },
405+ })
406+ result_fields .append (field_obj )
407+
408+ return result_fields
409+
365410 def resolve_object (self , obj : ReferenceObject | BaseModelT , object_type : type [BaseModelT ]) -> BaseModelT :
366411 """Resolve a reference object to its actual type or return the object as-is."""
367412 if isinstance (obj , ReferenceObject ):
@@ -651,6 +696,7 @@ def parse_raw(self) -> None: # noqa: PLR0912, PLR0915
651696
652697 specification : dict [str , Any ] = load_yaml_dict (source .text )
653698 self .raw_obj = specification
699+ self ._collect_discriminator_schemas ()
654700 schemas : dict [str , Any ] = specification .get ("components" , {}).get ("schemas" , {})
655701 security : list [dict [str , list [str ]]] | None = specification .get ("security" )
656702 if OpenAPIScope .Schemas in self .open_api_scopes :
@@ -727,3 +773,25 @@ def parse_raw(self) -> None: # noqa: PLR0912, PLR0915
727773 )
728774
729775 self ._resolve_unparsed_json_pointer ()
776+
777+ def _collect_discriminator_schemas (self ) -> None :
778+ """Collect schemas with discriminators but no oneOf/anyOf, and find their subtypes."""
779+ schemas : dict [str , Any ] = self .raw_obj .get ("components" , {}).get ("schemas" , {})
780+
781+ for schema_name , schema in schemas .items ():
782+ discriminator = schema .get ("discriminator" )
783+ if not discriminator :
784+ continue
785+
786+ if schema .get ("oneOf" ) or schema .get ("anyOf" ):
787+ continue
788+
789+ ref = f"#/components/schemas/{ schema_name } "
790+ self ._discriminator_schemas [ref ] = discriminator
791+
792+ for schema_name , schema in schemas .items ():
793+ for all_of_item in schema .get ("allOf" , []):
794+ ref_in_allof = all_of_item .get ("$ref" )
795+ if ref_in_allof and ref_in_allof in self ._discriminator_schemas :
796+ subtype_ref = f"#/components/schemas/{ schema_name } "
797+ self ._discriminator_subtypes [ref_in_allof ].append (subtype_ref )
0 commit comments