Skip to content

Commit 00a3dff

Browse files
committed
Add multiple --input-model support with inheritance preservation
1 parent efe8dfa commit 00a3dff

12 files changed

Lines changed: 1030 additions & 6 deletions

src/datamodel_code_generator/__main__.py

Lines changed: 248 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,16 @@ def validate_all_exports_collision_strategy(self: Self) -> Self: # pyright: ign
385385
raise Error(self.__validate_all_exports_collision_strategy_err)
386386
return self
387387

388+
from pydantic import field_validator as _field_validator # noqa: PLC0415
389+
390+
@_field_validator("input_model", mode="before")
391+
@classmethod
392+
def coerce_input_model_to_list(cls, v: str | list[str] | None) -> list[str] | None: # pyright: ignore[reportRedeclaration]
393+
"""Convert string input_model to list for backwards compatibility."""
394+
if isinstance(v, str):
395+
return [v]
396+
return v
397+
388398
else:
389399

390400
@model_validator() # pyright: ignore[reportArgumentType]
@@ -443,8 +453,16 @@ def validate_all_exports_collision_strategy(cls, values: dict[str, Any]) -> dict
443453
raise Error(cls.__validate_all_exports_collision_strategy_err)
444454
return values
445455

456+
@field_validator("input_model", mode="before")
457+
@classmethod
458+
def coerce_input_model_to_list(cls, v: str | list[str] | None) -> list[str] | None:
459+
"""Convert string input_model to list for backwards compatibility."""
460+
if isinstance(v, str):
461+
return [v]
462+
return v
463+
446464
input: Optional[Union[Path, str]] = None # noqa: UP007, UP045
447-
input_model: Optional[str] = None # noqa: UP045
465+
input_model: Optional[list[str]] = None # noqa: UP045
448466
input_model_ref_strategy: Optional[InputModelRefStrategy] = None # noqa: UP045
449467
input_file_type: InputFileType = InputFileType.Auto
450468
output_model_type: DataModelType = DataModelType.PydanticBaseModel
@@ -1172,6 +1190,231 @@ def _try_rebuild_model(obj: type) -> None:
11721190
obj.model_rebuild()
11731191

11741192

1193+
def _get_base_model_parents(model_class: type) -> list[type]:
1194+
"""Get parent classes that are BaseModel subclasses (excluding BaseModel itself)."""
1195+
return [
1196+
p
1197+
for p in model_class.__bases__
1198+
if isinstance(p, type) and issubclass(p, BaseModel) and p is not BaseModel
1199+
]
1200+
1201+
1202+
def _transform_single_model_to_inheritance( # noqa: PLR0912
1203+
schema: dict[str, object],
1204+
model_class: type,
1205+
schema_generator: type,
1206+
processed_parents: dict[str, dict[str, object]] | None = None,
1207+
) -> dict[str, object]:
1208+
"""Transform a single model's schema to use allOf inheritance structure.
1209+
1210+
Args:
1211+
schema: The JSON schema generated by Pydantic
1212+
model_class: The Pydantic model class
1213+
schema_generator: The schema generator class
1214+
processed_parents: Cache of already processed parent schemas
1215+
1216+
Returns:
1217+
Transformed schema with allOf structure for inheritance
1218+
"""
1219+
if processed_parents is None:
1220+
processed_parents = {}
1221+
1222+
direct_parents = _get_base_model_parents(model_class)
1223+
1224+
if not direct_parents:
1225+
return schema
1226+
1227+
parent = direct_parents[0]
1228+
parent_name = parent.__name__
1229+
parent_fields = set(parent.model_fields.keys())
1230+
1231+
defs = dict(cast("dict[str, object]", schema.get("$defs", {})))
1232+
1233+
if parent_name in processed_parents:
1234+
parent_schema = processed_parents[parent_name]
1235+
else:
1236+
if hasattr(parent, "model_rebuild"):
1237+
_try_rebuild_model(parent)
1238+
parent_schema = parent.model_json_schema(schema_generator=schema_generator)
1239+
parent_schema = _add_python_type_for_unserializable(parent_schema, parent)
1240+
parent_schema = _add_python_type_info(parent_schema, parent)
1241+
parent_schema = _transform_single_model_to_inheritance(
1242+
parent_schema, parent, schema_generator, processed_parents
1243+
)
1244+
processed_parents[parent_name] = parent_schema
1245+
1246+
if "$defs" in parent_schema:
1247+
parent_defs = cast("dict[str, object]", parent_schema["$defs"])
1248+
for k, v in parent_defs.items():
1249+
if k not in defs:
1250+
defs[k] = v
1251+
1252+
parent_def = {k: v for k, v in parent_schema.items() if k != "$defs"}
1253+
defs[parent_name] = parent_def
1254+
1255+
original_props = cast("dict[str, object]", schema.get("properties", {}))
1256+
child_props = {k: v for k, v in original_props.items() if k not in parent_fields}
1257+
1258+
new_schema: dict[str, object] = {}
1259+
if defs:
1260+
new_schema["$defs"] = defs
1261+
new_schema["allOf"] = [{"$ref": f"#/$defs/{parent_name}"}]
1262+
if child_props:
1263+
new_schema["properties"] = child_props
1264+
original_required = cast("list[str]", schema.get("required", []))
1265+
child_required = [r for r in original_required if r not in parent_fields]
1266+
if child_required:
1267+
new_schema["required"] = child_required
1268+
new_schema["title"] = schema.get("title")
1269+
new_schema["type"] = "object"
1270+
1271+
for key in schema:
1272+
if key not in {"$defs", "properties", "required", "title", "type", "allOf"}:
1273+
new_schema[key] = schema[key]
1274+
1275+
return new_schema
1276+
1277+
1278+
def _load_multiple_model_schemas( # noqa: PLR0912, PLR0914, PLR0915
1279+
input_models: list[str],
1280+
input_file_type: InputFileType,
1281+
ref_strategy: InputModelRefStrategy | None = None,
1282+
output_model_type: DataModelType = DataModelType.PydanticBaseModel,
1283+
) -> dict[str, object]:
1284+
"""Load and merge schemas from multiple Python import paths with inheritance support.
1285+
1286+
Args:
1287+
input_models: List of import paths in 'module.path:ObjectName' format
1288+
input_file_type: Current input file type setting for validation
1289+
ref_strategy: Strategy for handling referenced types
1290+
output_model_type: Target output model type for reuse-foreign strategy
1291+
1292+
Returns:
1293+
Merged schema dict with anyOf referencing all root models
1294+
"""
1295+
import importlib.util # noqa: PLC0415
1296+
import sys # noqa: PLC0415
1297+
1298+
if len(input_models) == 1:
1299+
return _load_model_schema(
1300+
input_models[0], input_file_type, ref_strategy, output_model_type
1301+
)
1302+
1303+
cwd = str(Path.cwd())
1304+
if cwd not in sys.path:
1305+
sys.path.insert(0, cwd)
1306+
1307+
model_classes: list[type] = []
1308+
loaded_modules: dict[str, object] = {}
1309+
1310+
for input_model in input_models:
1311+
modname, sep, qualname = input_model.rpartition(":")
1312+
if not sep or not modname:
1313+
msg = f"Invalid --input-model format: {input_model!r}. Expected 'module:Object' or 'path/to/file.py:Object'."
1314+
raise Error(msg)
1315+
1316+
if modname not in loaded_modules:
1317+
is_path = "/" in modname or "\\" in modname
1318+
if not is_path and modname.endswith(".py"):
1319+
is_path = Path(modname).exists()
1320+
1321+
if is_path:
1322+
file_path = Path(modname).resolve()
1323+
if not file_path.exists():
1324+
msg = f"File not found: {modname!r}"
1325+
raise Error(msg)
1326+
module_name = file_path.stem
1327+
spec = importlib.util.spec_from_file_location(module_name, file_path)
1328+
if spec is None or spec.loader is None:
1329+
msg = f"Cannot load module from {modname!r}"
1330+
raise Error(msg)
1331+
module = importlib.util.module_from_spec(spec)
1332+
sys.modules[module_name] = module
1333+
spec.loader.exec_module(module)
1334+
else:
1335+
try:
1336+
found_spec = importlib.util.find_spec(modname)
1337+
if found_spec is None:
1338+
msg = f"Cannot find module {modname!r}"
1339+
raise Error(msg)
1340+
module = importlib.import_module(modname)
1341+
except ImportError as e:
1342+
msg = f"Cannot import module {modname!r}: {e}"
1343+
raise Error(msg) from e
1344+
loaded_modules[modname] = module
1345+
else:
1346+
module = loaded_modules[modname]
1347+
1348+
try:
1349+
obj = getattr(module, qualname)
1350+
except AttributeError as e:
1351+
msg = f"Module {modname!r} has no attribute {qualname!r}"
1352+
raise Error(msg) from e
1353+
1354+
if not (isinstance(obj, type) and issubclass(obj, BaseModel)):
1355+
msg = f"Multiple --input-model only supports Pydantic v2 BaseModel classes, got {type(obj).__name__}"
1356+
raise Error(msg)
1357+
1358+
if not hasattr(obj, "model_json_schema"):
1359+
msg = "Multiple --input-model with Pydantic model requires Pydantic v2 runtime. Please upgrade Pydantic to v2."
1360+
raise Error(msg)
1361+
1362+
model_classes.append(obj)
1363+
1364+
if input_file_type not in {InputFileType.Auto, InputFileType.JsonSchema}:
1365+
msg = (
1366+
f"--input-file-type must be 'jsonschema' (or omitted) "
1367+
f"when --input-model points to Pydantic models, "
1368+
f"got '{input_file_type.value}'"
1369+
)
1370+
raise Error(msg)
1371+
1372+
schema_generator = _get_input_model_json_schema_class()
1373+
merged_defs: dict[str, object] = {}
1374+
root_refs: list[dict[str, str]] = []
1375+
processed_parents: dict[str, dict[str, object]] = {}
1376+
1377+
for model_class in model_classes:
1378+
model_name = model_class.__name__
1379+
if hasattr(model_class, "model_rebuild"):
1380+
_try_rebuild_model(model_class)
1381+
1382+
schema = model_class.model_json_schema(schema_generator=schema_generator)
1383+
schema = _add_python_type_for_unserializable(schema, model_class)
1384+
schema = _add_python_type_info(schema, model_class)
1385+
1386+
schema = _transform_single_model_to_inheritance(
1387+
schema, model_class, schema_generator, processed_parents
1388+
)
1389+
1390+
if "$defs" in schema:
1391+
schema_defs = cast("dict[str, object]", schema["$defs"])
1392+
for k, v in schema_defs.items():
1393+
if k not in merged_defs:
1394+
merged_defs[k] = v
1395+
1396+
model_def = {k: v for k, v in schema.items() if k != "$defs"}
1397+
merged_defs[model_name] = model_def
1398+
1399+
root_refs.append({"$ref": f"#/$defs/{model_name}"})
1400+
1401+
final_schema: dict[str, object] = {"$defs": merged_defs}
1402+
if len(root_refs) == 1:
1403+
final_schema.update(root_refs[0])
1404+
else:
1405+
final_schema["anyOf"] = root_refs
1406+
1407+
if ref_strategy and ref_strategy != InputModelRefStrategy.RegenerateAll:
1408+
all_nested_models: dict[str, type] = {}
1409+
for model_class in model_classes:
1410+
all_nested_models.update(_collect_nested_models(model_class))
1411+
final_schema = _filter_defs_by_strategy(
1412+
final_schema, all_nested_models, output_model_type, ref_strategy
1413+
)
1414+
1415+
return final_schema
1416+
1417+
11751418
def _load_model_schema( # noqa: PLR0912, PLR0914, PLR0915
11761419
input_model: str,
11771420
input_file_type: InputFileType,
@@ -1262,6 +1505,9 @@ def _load_model_schema( # noqa: PLR0912, PLR0914, PLR0915
12621505
schema = _add_python_type_for_unserializable(schema, obj)
12631506
schema = _add_python_type_info(schema, obj)
12641507

1508+
# Transform to inheritance structure if the model has BaseModel parents
1509+
schema = _transform_single_model_to_inheritance(schema, obj, schema_generator)
1510+
12651511
if ref_strategy and ref_strategy != InputModelRefStrategy.RegenerateAll:
12661512
nested_models = _collect_nested_models(obj)
12671513
model_name = getattr(obj, "__name__", None)
@@ -1890,7 +2136,7 @@ def main(args: Sequence[str] | None = None) -> Exit: # noqa: PLR0911, PLR0912,
18902136
try:
18912137
input_: Path | str | ParseResult
18922138
if config.input_model:
1893-
schema = _load_model_schema(
2139+
schema = _load_multiple_model_schemas(
18942140
config.input_model,
18952141
config.input_file_type,
18962142
config.input_model_ref_strategy,

src/datamodel_code_generator/_types/graphql_parser_config_dict.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from datamodel_code_generator.types import DataTypeManager
3131

3232

33-
class GraphQLParserConfigDict(TypedDict):
33+
class ParserConfig(TypedDict):
3434
data_model_type: NotRequired[type[DataModel]]
3535
data_model_root_type: NotRequired[type[DataModel]]
3636
data_type_manager_type: NotRequired[type[DataTypeManager]]
@@ -142,5 +142,8 @@ class GraphQLParserConfigDict(TypedDict):
142142
read_only_write_only_model_type: NotRequired[ReadOnlyWriteOnlyModelType | None]
143143
field_type_collision_strategy: NotRequired[FieldTypeCollisionStrategy | None]
144144
target_pydantic_version: NotRequired[TargetPydanticVersion | None]
145+
146+
147+
class GraphQLParserConfigDict(ParserConfig):
145148
data_model_scalar_type: NotRequired[type[DataModel]]
146149
data_model_union_type: NotRequired[type[DataModel]]

src/datamodel_code_generator/_types/jsonschema_parser_config_dict.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from datamodel_code_generator.types import DataTypeManager
3131

3232

33-
class JSONSchemaParserConfigDict(TypedDict):
33+
class ParserConfig(TypedDict):
3434
data_model_type: NotRequired[type[DataModel]]
3535
data_model_root_type: NotRequired[type[DataModel]]
3636
data_type_manager_type: NotRequired[type[DataTypeManager]]
@@ -142,3 +142,7 @@ class JSONSchemaParserConfigDict(TypedDict):
142142
read_only_write_only_model_type: NotRequired[ReadOnlyWriteOnlyModelType | None]
143143
field_type_collision_strategy: NotRequired[FieldTypeCollisionStrategy | None]
144144
target_pydantic_version: NotRequired[TargetPydanticVersion | None]
145+
146+
147+
class JSONSchemaParserConfigDict(ParserConfig):
148+
pass

src/datamodel_code_generator/_types/openapi_parser_config_dict.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from datamodel_code_generator.types import DataTypeManager
3232

3333

34-
class OpenAPIParserConfigDict(TypedDict):
34+
class ParserConfig(TypedDict):
3535
data_model_type: NotRequired[type[DataModel]]
3636
data_model_root_type: NotRequired[type[DataModel]]
3737
data_type_manager_type: NotRequired[type[DataTypeManager]]
@@ -143,6 +143,13 @@ class OpenAPIParserConfigDict(TypedDict):
143143
read_only_write_only_model_type: NotRequired[ReadOnlyWriteOnlyModelType | None]
144144
field_type_collision_strategy: NotRequired[FieldTypeCollisionStrategy | None]
145145
target_pydantic_version: NotRequired[TargetPydanticVersion | None]
146+
147+
148+
class JSONSchemaParserConfig(ParserConfig):
149+
pass
150+
151+
152+
class OpenAPIParserConfigDict(JSONSchemaParserConfig):
146153
openapi_scopes: NotRequired[list[OpenAPIScope] | None]
147154
include_path_parameters: NotRequired[bool]
148155
use_status_code_in_response_name: NotRequired[bool]

src/datamodel_code_generator/arguments.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,8 +159,10 @@ def start_section(self, heading: str | None) -> None:
159159
)
160160
base_options.add_argument(
161161
"--input-model",
162+
action="append",
162163
help="Python import path to a Pydantic v2 model or schema dict "
163164
"(e.g., 'mypackage.module:ClassName' or 'mypackage.schemas:SCHEMA_DICT'). "
165+
"Can be specified multiple times for related models with inheritance. "
164166
"For dict input, --input-file-type is required. "
165167
"Cannot be used with --input or --url.",
166168
metavar="MODULE:NAME",
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# generated by datamodel-codegen:
2+
# filename: <stdin>
3+
# timestamp: 1985-10-26T08:21:00+00:00
4+
5+
from __future__ import annotations
6+
7+
from typing import TypeAlias, TypedDict
8+
9+
10+
class GrandParent(TypedDict):
11+
grand_field: str
12+
13+
14+
class Parent(GrandParent):
15+
parent_field: int
16+
17+
18+
class ChildA(Parent):
19+
child_a_field: float
20+
21+
22+
class ChildB(Parent):
23+
child_b_field: bool
24+
25+
26+
Model: TypeAlias = ChildA | ChildB

0 commit comments

Comments
 (0)