|
20 | 20 |
|
21 | 21 | from pydantic import BaseModel |
22 | 22 |
|
| 23 | +from datamodel_code_generator import DEFAULT_SHARED_MODULE_NAME, Error, ReuseScope |
23 | 24 | from datamodel_code_generator.format import ( |
24 | 25 | DEFAULT_FORMATTERS, |
25 | 26 | CodeFormatter, |
@@ -405,6 +406,8 @@ def __init__( # noqa: PLR0913, PLR0915 |
405 | 406 | use_inline_field_description: bool = False, |
406 | 407 | use_default_kwarg: bool = False, |
407 | 408 | reuse_model: bool = False, |
| 409 | + reuse_scope: ReuseScope | None = None, |
| 410 | + shared_module_name: str = DEFAULT_SHARED_MODULE_NAME, |
408 | 411 | encoding: str = "utf-8", |
409 | 412 | enum_field_as_literal: LiteralType | None = None, |
410 | 413 | set_default_enum_member: bool = False, |
@@ -495,6 +498,8 @@ def __init__( # noqa: PLR0913, PLR0915 |
495 | 498 | self.use_inline_field_description: bool = use_inline_field_description |
496 | 499 | self.use_default_kwarg: bool = use_default_kwarg |
497 | 500 | self.reuse_model: bool = reuse_model |
| 501 | + self.reuse_scope: ReuseScope | None = reuse_scope |
| 502 | + self.shared_module_name: str = shared_module_name |
498 | 503 | self.encoding: str = encoding |
499 | 504 | self.enum_field_as_literal: LiteralType | None = enum_field_as_literal |
500 | 505 | self.set_default_enum_member: bool = set_default_enum_member |
@@ -1002,7 +1007,7 @@ def __set_reference_default_value_to_field(cls, models: list[DataModel]) -> None |
1002 | 1007 | model_field.default = model_field.data_type.reference.source.default |
1003 | 1008 |
|
1004 | 1009 | def __reuse_model(self, models: list[DataModel], require_update_action_models: list[str]) -> None: |
1005 | | - if not self.reuse_model: |
| 1010 | + if not self.reuse_model or self.reuse_scope == ReuseScope.Tree: |
1006 | 1011 | return |
1007 | 1012 | model_cache: dict[tuple[HashableComparable, ...], Reference] = {} |
1008 | 1013 | duplicates = [] |
@@ -1041,6 +1046,133 @@ def __reuse_model(self, models: list[DataModel], require_update_action_models: l |
1041 | 1046 | for duplicate in duplicates: |
1042 | 1047 | models.remove(duplicate) |
1043 | 1048 |
|
| 1049 | + def __find_duplicate_models_across_modules( # noqa: PLR6301 |
| 1050 | + self, |
| 1051 | + module_models: list[tuple[tuple[str, ...], list[DataModel]]], |
| 1052 | + ) -> list[tuple[tuple[str, ...], DataModel, tuple[str, ...], DataModel]]: |
| 1053 | + """Find duplicate models across all modules by comparing render output and imports.""" |
| 1054 | + all_models: list[tuple[tuple[str, ...], DataModel]] = [] |
| 1055 | + for module, models in module_models: |
| 1056 | + all_models.extend((module, model) for model in models) |
| 1057 | + |
| 1058 | + model_cache: dict[tuple[HashableComparable, ...], tuple[tuple[str, ...], DataModel]] = {} |
| 1059 | + duplicates: list[tuple[tuple[str, ...], DataModel, tuple[str, ...], DataModel]] = [] |
| 1060 | + |
| 1061 | + for module, model in all_models: |
| 1062 | + model_key = tuple(to_hashable(v) for v in (model.render(class_name="M"), model.imports)) |
| 1063 | + cached = model_cache.get(model_key) |
| 1064 | + if cached: |
| 1065 | + canonical_module, canonical_model = cached |
| 1066 | + duplicates.append((module, model, canonical_module, canonical_model)) |
| 1067 | + else: |
| 1068 | + model_cache[model_key] = (module, model) |
| 1069 | + |
| 1070 | + return duplicates |
| 1071 | + |
| 1072 | + def __validate_shared_module_name( |
| 1073 | + self, |
| 1074 | + module_models: list[tuple[tuple[str, ...], list[DataModel]]], |
| 1075 | + ) -> None: |
| 1076 | + """Validate that the shared module name doesn't conflict with existing modules.""" |
| 1077 | + shared_module = self.shared_module_name |
| 1078 | + existing_module_names = {module[0] for module, _ in module_models} |
| 1079 | + if shared_module in existing_module_names: |
| 1080 | + msg = ( |
| 1081 | + f"Schema file or directory '{shared_module}' conflicts with the shared module name. " |
| 1082 | + f"Use --shared-module-name to specify a different name." |
| 1083 | + ) |
| 1084 | + raise Error(msg) |
| 1085 | + |
| 1086 | + def __create_shared_module_from_duplicates( # noqa: PLR0912 |
| 1087 | + self, |
| 1088 | + module_models: list[tuple[tuple[str, ...], list[DataModel]]], |
| 1089 | + duplicates: list[tuple[tuple[str, ...], DataModel, tuple[str, ...], DataModel]], |
| 1090 | + require_update_action_models: list[str], |
| 1091 | + ) -> tuple[tuple[str, ...], list[DataModel]]: |
| 1092 | + """Create shared module with canonical models and replace duplicates with inherited models.""" |
| 1093 | + shared_module = self.shared_module_name |
| 1094 | + |
| 1095 | + shared_models: list[DataModel] = [] |
| 1096 | + canonical_to_shared_ref: dict[DataModel, Reference] = {} |
| 1097 | + canonical_models_seen: set[DataModel] = set() |
| 1098 | + |
| 1099 | + # Process in order of first appearance in duplicates to ensure stable ordering |
| 1100 | + for _, _, _, canonical in duplicates: |
| 1101 | + if canonical in canonical_models_seen: |
| 1102 | + continue |
| 1103 | + canonical_models_seen.add(canonical) |
| 1104 | + canonical.file_path = Path(f"{shared_module}.py") |
| 1105 | + canonical_to_shared_ref[canonical] = canonical.reference |
| 1106 | + shared_models.append(canonical) |
| 1107 | + |
| 1108 | + supports_inheritance = issubclass( |
| 1109 | + self.data_model_type, |
| 1110 | + ( |
| 1111 | + pydantic_model.BaseModel, |
| 1112 | + pydantic_model_v2.BaseModel, |
| 1113 | + dataclass_model.DataClass, |
| 1114 | + ), |
| 1115 | + ) |
| 1116 | + |
| 1117 | + for duplicate_module, duplicate_model, _, canonical_model in duplicates: |
| 1118 | + shared_ref = canonical_to_shared_ref[canonical_model] |
| 1119 | + for module, models in module_models: |
| 1120 | + if module != duplicate_module or duplicate_model not in models: |
| 1121 | + continue |
| 1122 | + if isinstance(duplicate_model, Enum) or not supports_inheritance: |
| 1123 | + for child in duplicate_model.reference.children[:]: |
| 1124 | + data_model = get_most_of_parent(child) |
| 1125 | + if data_model in models and isinstance(child, DataType): |
| 1126 | + child.replace_reference(shared_ref) |
| 1127 | + models.remove(duplicate_model) |
| 1128 | + else: |
| 1129 | + index = models.index(duplicate_model) |
| 1130 | + inherited_model = duplicate_model.__class__( |
| 1131 | + fields=[], |
| 1132 | + base_classes=[shared_ref], |
| 1133 | + description=duplicate_model.description, |
| 1134 | + reference=Reference( |
| 1135 | + name=duplicate_model.name, |
| 1136 | + path=duplicate_model.reference.path + "/reuse", |
| 1137 | + ), |
| 1138 | + custom_template_dir=duplicate_model._custom_template_dir, # noqa: SLF001 |
| 1139 | + ) |
| 1140 | + if shared_ref.path in require_update_action_models: |
| 1141 | + add_model_path_to_list(require_update_action_models, inherited_model) |
| 1142 | + models.insert(index, inherited_model) |
| 1143 | + models.remove(duplicate_model) |
| 1144 | + break |
| 1145 | + else: # pragma: no cover |
| 1146 | + msg = f"Duplicate model {duplicate_model.name} not found in module {duplicate_module}" |
| 1147 | + raise RuntimeError(msg) |
| 1148 | + |
| 1149 | + for canonical in canonical_models_seen: |
| 1150 | + for _module, models in module_models: |
| 1151 | + if canonical in models: |
| 1152 | + models.remove(canonical) |
| 1153 | + break |
| 1154 | + else: # pragma: no cover |
| 1155 | + msg = f"Canonical model {canonical.name} not found in any module" |
| 1156 | + raise RuntimeError(msg) |
| 1157 | + |
| 1158 | + return (shared_module,), shared_models |
| 1159 | + |
| 1160 | + def __reuse_model_tree_scope( |
| 1161 | + self, |
| 1162 | + module_models: list[tuple[tuple[str, ...], list[DataModel]]], |
| 1163 | + require_update_action_models: list[str], |
| 1164 | + ) -> tuple[tuple[str, ...], list[DataModel]] | None: |
| 1165 | + """Deduplicate models across all modules, placing shared models in shared.py.""" |
| 1166 | + if not self.reuse_model or self.reuse_scope != ReuseScope.Tree: |
| 1167 | + return None |
| 1168 | + |
| 1169 | + duplicates = self.__find_duplicate_models_across_modules(module_models) |
| 1170 | + if not duplicates: |
| 1171 | + return None |
| 1172 | + |
| 1173 | + self.__validate_shared_module_name(module_models) |
| 1174 | + return self.__create_shared_module_from_duplicates(module_models, duplicates, require_update_action_models) |
| 1175 | + |
1044 | 1176 | def __collapse_root_models( # noqa: PLR0912 |
1045 | 1177 | self, |
1046 | 1178 | models: list[DataModel], |
@@ -1499,6 +1631,10 @@ def sort_key(data_model: DataModel) -> tuple[int, tuple[str, ...]]: |
1499 | 1631 | )) |
1500 | 1632 | previous_module = module |
1501 | 1633 |
|
| 1634 | + shared_module_entry = self.__reuse_model_tree_scope(module_models, require_update_action_models) |
| 1635 | + if shared_module_entry: |
| 1636 | + module_models.insert(0, shared_module_entry) |
| 1637 | + |
1502 | 1638 | class Processed(NamedTuple): |
1503 | 1639 | module: tuple[str, ...] |
1504 | 1640 | models: list[DataModel] |
|
0 commit comments