@@ -385,6 +385,16 @@ def validate_all_exports_collision_strategy(self: Self) -> Self: # pyright: ign
385385 raise Error (self .__validate_all_exports_collision_strategy_err )
386386 return self
387387
388+ from pydantic import field_validator as _field_validator # noqa: PLC0415
389+
390+ @_field_validator ("input_model" , mode = "before" )
391+ @classmethod
392+ def coerce_input_model_to_list (cls , v : str | list [str ] | None ) -> list [str ] | None : # pyright: ignore[reportRedeclaration]
393+ """Convert string input_model to list for backwards compatibility."""
394+ if isinstance (v , str ):
395+ return [v ]
396+ return v
397+
388398 else :
389399
390400 @model_validator () # pyright: ignore[reportArgumentType]
@@ -443,8 +453,16 @@ def validate_all_exports_collision_strategy(cls, values: dict[str, Any]) -> dict
443453 raise Error (cls .__validate_all_exports_collision_strategy_err )
444454 return values
445455
456+ @field_validator ("input_model" , mode = "before" )
457+ @classmethod
458+ def coerce_input_model_to_list (cls , v : str | list [str ] | None ) -> list [str ] | None :
459+ """Convert string input_model to list for backwards compatibility."""
460+ if isinstance (v , str ):
461+ return [v ]
462+ return v
463+
446464 input : Optional [Union [Path , str ]] = None # noqa: UP007, UP045
447- input_model : Optional [str ] = None # noqa: UP045
465+ input_model : Optional [list [ str ] ] = None # noqa: UP045
448466 input_model_ref_strategy : Optional [InputModelRefStrategy ] = None # noqa: UP045
449467 input_file_type : InputFileType = InputFileType .Auto
450468 output_model_type : DataModelType = DataModelType .PydanticBaseModel
@@ -1172,6 +1190,231 @@ def _try_rebuild_model(obj: type) -> None:
11721190 obj .model_rebuild ()
11731191
11741192
1193+ def _get_base_model_parents (model_class : type ) -> list [type ]:
1194+ """Get parent classes that are BaseModel subclasses (excluding BaseModel itself)."""
1195+ return [
1196+ p
1197+ for p in model_class .__bases__
1198+ if isinstance (p , type ) and issubclass (p , BaseModel ) and p is not BaseModel
1199+ ]
1200+
1201+
1202+ def _transform_single_model_to_inheritance ( # noqa: PLR0912
1203+ schema : dict [str , object ],
1204+ model_class : type ,
1205+ schema_generator : type ,
1206+ processed_parents : dict [str , dict [str , object ]] | None = None ,
1207+ ) -> dict [str , object ]:
1208+ """Transform a single model's schema to use allOf inheritance structure.
1209+
1210+ Args:
1211+ schema: The JSON schema generated by Pydantic
1212+ model_class: The Pydantic model class
1213+ schema_generator: The schema generator class
1214+ processed_parents: Cache of already processed parent schemas
1215+
1216+ Returns:
1217+ Transformed schema with allOf structure for inheritance
1218+ """
1219+ if processed_parents is None :
1220+ processed_parents = {}
1221+
1222+ direct_parents = _get_base_model_parents (model_class )
1223+
1224+ if not direct_parents :
1225+ return schema
1226+
1227+ parent = direct_parents [0 ]
1228+ parent_name = parent .__name__
1229+ parent_fields = set (parent .model_fields .keys ())
1230+
1231+ defs = dict (cast ("dict[str, object]" , schema .get ("$defs" , {})))
1232+
1233+ if parent_name in processed_parents :
1234+ parent_schema = processed_parents [parent_name ]
1235+ else :
1236+ if hasattr (parent , "model_rebuild" ):
1237+ _try_rebuild_model (parent )
1238+ parent_schema = parent .model_json_schema (schema_generator = schema_generator )
1239+ parent_schema = _add_python_type_for_unserializable (parent_schema , parent )
1240+ parent_schema = _add_python_type_info (parent_schema , parent )
1241+ parent_schema = _transform_single_model_to_inheritance (
1242+ parent_schema , parent , schema_generator , processed_parents
1243+ )
1244+ processed_parents [parent_name ] = parent_schema
1245+
1246+ if "$defs" in parent_schema :
1247+ parent_defs = cast ("dict[str, object]" , parent_schema ["$defs" ])
1248+ for k , v in parent_defs .items ():
1249+ if k not in defs :
1250+ defs [k ] = v
1251+
1252+ parent_def = {k : v for k , v in parent_schema .items () if k != "$defs" }
1253+ defs [parent_name ] = parent_def
1254+
1255+ original_props = cast ("dict[str, object]" , schema .get ("properties" , {}))
1256+ child_props = {k : v for k , v in original_props .items () if k not in parent_fields }
1257+
1258+ new_schema : dict [str , object ] = {}
1259+ if defs :
1260+ new_schema ["$defs" ] = defs
1261+ new_schema ["allOf" ] = [{"$ref" : f"#/$defs/{ parent_name } " }]
1262+ if child_props :
1263+ new_schema ["properties" ] = child_props
1264+ original_required = cast ("list[str]" , schema .get ("required" , []))
1265+ child_required = [r for r in original_required if r not in parent_fields ]
1266+ if child_required :
1267+ new_schema ["required" ] = child_required
1268+ new_schema ["title" ] = schema .get ("title" )
1269+ new_schema ["type" ] = "object"
1270+
1271+ for key in schema :
1272+ if key not in {"$defs" , "properties" , "required" , "title" , "type" , "allOf" }:
1273+ new_schema [key ] = schema [key ]
1274+
1275+ return new_schema
1276+
1277+
1278+ def _load_multiple_model_schemas ( # noqa: PLR0912, PLR0914, PLR0915
1279+ input_models : list [str ],
1280+ input_file_type : InputFileType ,
1281+ ref_strategy : InputModelRefStrategy | None = None ,
1282+ output_model_type : DataModelType = DataModelType .PydanticBaseModel ,
1283+ ) -> dict [str , object ]:
1284+ """Load and merge schemas from multiple Python import paths with inheritance support.
1285+
1286+ Args:
1287+ input_models: List of import paths in 'module.path:ObjectName' format
1288+ input_file_type: Current input file type setting for validation
1289+ ref_strategy: Strategy for handling referenced types
1290+ output_model_type: Target output model type for reuse-foreign strategy
1291+
1292+ Returns:
1293+ Merged schema dict with anyOf referencing all root models
1294+ """
1295+ import importlib .util # noqa: PLC0415
1296+ import sys # noqa: PLC0415
1297+
1298+ if len (input_models ) == 1 :
1299+ return _load_model_schema (
1300+ input_models [0 ], input_file_type , ref_strategy , output_model_type
1301+ )
1302+
1303+ cwd = str (Path .cwd ())
1304+ if cwd not in sys .path :
1305+ sys .path .insert (0 , cwd )
1306+
1307+ model_classes : list [type ] = []
1308+ loaded_modules : dict [str , object ] = {}
1309+
1310+ for input_model in input_models :
1311+ modname , sep , qualname = input_model .rpartition (":" )
1312+ if not sep or not modname :
1313+ msg = f"Invalid --input-model format: { input_model !r} . Expected 'module:Object' or 'path/to/file.py:Object'."
1314+ raise Error (msg )
1315+
1316+ if modname not in loaded_modules :
1317+ is_path = "/" in modname or "\\ " in modname
1318+ if not is_path and modname .endswith (".py" ):
1319+ is_path = Path (modname ).exists ()
1320+
1321+ if is_path :
1322+ file_path = Path (modname ).resolve ()
1323+ if not file_path .exists ():
1324+ msg = f"File not found: { modname !r} "
1325+ raise Error (msg )
1326+ module_name = file_path .stem
1327+ spec = importlib .util .spec_from_file_location (module_name , file_path )
1328+ if spec is None or spec .loader is None :
1329+ msg = f"Cannot load module from { modname !r} "
1330+ raise Error (msg )
1331+ module = importlib .util .module_from_spec (spec )
1332+ sys .modules [module_name ] = module
1333+ spec .loader .exec_module (module )
1334+ else :
1335+ try :
1336+ found_spec = importlib .util .find_spec (modname )
1337+ if found_spec is None :
1338+ msg = f"Cannot find module { modname !r} "
1339+ raise Error (msg )
1340+ module = importlib .import_module (modname )
1341+ except ImportError as e :
1342+ msg = f"Cannot import module { modname !r} : { e } "
1343+ raise Error (msg ) from e
1344+ loaded_modules [modname ] = module
1345+ else :
1346+ module = loaded_modules [modname ]
1347+
1348+ try :
1349+ obj = getattr (module , qualname )
1350+ except AttributeError as e :
1351+ msg = f"Module { modname !r} has no attribute { qualname !r} "
1352+ raise Error (msg ) from e
1353+
1354+ if not (isinstance (obj , type ) and issubclass (obj , BaseModel )):
1355+ msg = f"Multiple --input-model only supports Pydantic v2 BaseModel classes, got { type (obj ).__name__ } "
1356+ raise Error (msg )
1357+
1358+ if not hasattr (obj , "model_json_schema" ):
1359+ msg = "Multiple --input-model with Pydantic model requires Pydantic v2 runtime. Please upgrade Pydantic to v2."
1360+ raise Error (msg )
1361+
1362+ model_classes .append (obj )
1363+
1364+ if input_file_type not in {InputFileType .Auto , InputFileType .JsonSchema }:
1365+ msg = (
1366+ f"--input-file-type must be 'jsonschema' (or omitted) "
1367+ f"when --input-model points to Pydantic models, "
1368+ f"got '{ input_file_type .value } '"
1369+ )
1370+ raise Error (msg )
1371+
1372+ schema_generator = _get_input_model_json_schema_class ()
1373+ merged_defs : dict [str , object ] = {}
1374+ root_refs : list [dict [str , str ]] = []
1375+ processed_parents : dict [str , dict [str , object ]] = {}
1376+
1377+ for model_class in model_classes :
1378+ model_name = model_class .__name__
1379+ if hasattr (model_class , "model_rebuild" ):
1380+ _try_rebuild_model (model_class )
1381+
1382+ schema = model_class .model_json_schema (schema_generator = schema_generator )
1383+ schema = _add_python_type_for_unserializable (schema , model_class )
1384+ schema = _add_python_type_info (schema , model_class )
1385+
1386+ schema = _transform_single_model_to_inheritance (
1387+ schema , model_class , schema_generator , processed_parents
1388+ )
1389+
1390+ if "$defs" in schema :
1391+ schema_defs = cast ("dict[str, object]" , schema ["$defs" ])
1392+ for k , v in schema_defs .items ():
1393+ if k not in merged_defs :
1394+ merged_defs [k ] = v
1395+
1396+ model_def = {k : v for k , v in schema .items () if k != "$defs" }
1397+ merged_defs [model_name ] = model_def
1398+
1399+ root_refs .append ({"$ref" : f"#/$defs/{ model_name } " })
1400+
1401+ final_schema : dict [str , object ] = {"$defs" : merged_defs }
1402+ if len (root_refs ) == 1 :
1403+ final_schema .update (root_refs [0 ])
1404+ else :
1405+ final_schema ["anyOf" ] = root_refs
1406+
1407+ if ref_strategy and ref_strategy != InputModelRefStrategy .RegenerateAll :
1408+ all_nested_models : dict [str , type ] = {}
1409+ for model_class in model_classes :
1410+ all_nested_models .update (_collect_nested_models (model_class ))
1411+ final_schema = _filter_defs_by_strategy (
1412+ final_schema , all_nested_models , output_model_type , ref_strategy
1413+ )
1414+
1415+ return final_schema
1416+
1417+
11751418def _load_model_schema ( # noqa: PLR0912, PLR0914, PLR0915
11761419 input_model : str ,
11771420 input_file_type : InputFileType ,
@@ -1262,6 +1505,9 @@ def _load_model_schema( # noqa: PLR0912, PLR0914, PLR0915
12621505 schema = _add_python_type_for_unserializable (schema , obj )
12631506 schema = _add_python_type_info (schema , obj )
12641507
1508+ # Transform to inheritance structure if the model has BaseModel parents
1509+ schema = _transform_single_model_to_inheritance (schema , obj , schema_generator )
1510+
12651511 if ref_strategy and ref_strategy != InputModelRefStrategy .RegenerateAll :
12661512 nested_models = _collect_nested_models (obj )
12671513 model_name = getattr (obj , "__name__" , None )
@@ -1890,7 +2136,7 @@ def main(args: Sequence[str] | None = None) -> Exit: # noqa: PLR0911, PLR0912,
18902136 try :
18912137 input_ : Path | str | ParseResult
18922138 if config .input_model :
1893- schema = _load_model_schema (
2139+ schema = _load_multiple_model_schemas (
18942140 config .input_model ,
18952141 config .input_file_type ,
18962142 config .input_model_ref_strategy ,
0 commit comments