Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 20 additions & 8 deletions src/datamodel_code_generator/parser/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1610,33 +1610,45 @@ def __create_shared_module_from_duplicates( # noqa: PLR0912
),
)

module_models_sets: dict[tuple[str, ...], set[DataModel]] = {
module: set(models) for module, models in module_models
}
models_to_remove: dict[tuple[str, ...], set[DataModel]] = defaultdict(set)

for duplicate_module, duplicate_model, _, canonical_model in duplicates:
shared_ref = canonical_to_shared_ref[canonical_model]
models_set = module_models_sets.get(duplicate_module)
if not models_set or duplicate_model not in models_set: # pragma: no cover
msg = f"Duplicate model {duplicate_model.name} not found in module {duplicate_module}"
raise RuntimeError(msg)

for module, models in module_models:
if module != duplicate_module or duplicate_model not in models:
if module != duplicate_module:
continue
if isinstance(duplicate_model, Enum) or not supports_inheritance or self.collapse_reuse_models:
duplicate_model.replace_children_in_models(models, shared_ref)
models.remove(duplicate_model)
models_to_remove[module].add(duplicate_model)
else:
inherited_model = duplicate_model.create_reuse_model(shared_ref)
if shared_ref.path in require_update_action_models:
add_model_path_to_list(require_update_action_models, inherited_model)
self._replace_model_in_list(models, duplicate_model, inherited_model)
break
else: # pragma: no cover
msg = f"Duplicate model {duplicate_model.name} not found in module {duplicate_module}"
raise RuntimeError(msg)

for canonical in canonical_models_seen:
for _module, models in module_models:
if canonical in models:
models.remove(canonical)
for module, models_set in module_models_sets.items():
if canonical in models_set:
models_to_remove[module].add(canonical)
break
else: # pragma: no cover
msg = f"Canonical model {canonical.name} not found in any module"
raise RuntimeError(msg)

for module, models in module_models:
to_remove = models_to_remove.get(module)
if to_remove:
models[:] = [m for m in models if m not in to_remove]

return (shared_module,), shared_models

def __reuse_model_tree_scope(
Expand Down
2 changes: 1 addition & 1 deletion src/datamodel_code_generator/parser/jsonschema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2548,7 +2548,7 @@ def parse_array(
reference = self.model_resolver.add(path, name, loaded=True, class_name=True)
field = self.parse_array_fields(original_name or name, obj, [*path, name])

if reference in [d.reference for d in field.data_type.all_data_types if d.reference]:
if any(d.reference == reference for d in field.data_type.all_data_types if d.reference):
# self-reference
field = self.data_model_field_type(
data_type=self.data_type(
Expand Down
31 changes: 17 additions & 14 deletions src/datamodel_code_generator/parser/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,6 +676,7 @@ def parse_all_parameters(
"""Parse all operation parameters into a data model."""
fields: list[DataModelFieldBase] = []
exclude_field_names: set[str] = set()
seen_parameter_names: set[str] = set()
reference = self.model_resolver.add(path, name, class_name=True, unique=True)
for parameter_ in parameters:
parameter = self.resolve_object(parameter_, ParameterObject)
Expand All @@ -687,9 +688,10 @@ def parse_all_parameters(
):
continue

if any(field.original_name == parameter_name for field in fields):
if parameter_name in seen_parameter_names:
msg = f"Parameter name '{parameter_name}' is used more than once."
raise Exception(msg) # noqa: TRY002
seen_parameter_names.add(parameter_name)

field_name, alias = self.model_resolver.get_valid_field_name_and_alias(
field_name=parameter_name,
Expand Down Expand Up @@ -919,21 +921,22 @@ def parse_raw(self) -> None:
def _collect_discriminator_schemas(self) -> None:
"""Collect schemas with discriminators but no oneOf/anyOf, and find their subtypes."""
schemas: dict[str, Any] = self.raw_obj.get("components", {}).get("schemas", {})
potential_subtypes: dict[str, list[str]] = {}

for schema_name, schema in schemas.items():
discriminator = schema.get("discriminator")
if not discriminator:
continue

if schema.get("oneOf") or schema.get("anyOf"):
continue

ref = f"#/components/schemas/{schema_name}"
self._discriminator_schemas[ref] = discriminator

for schema_name, schema in schemas.items():
for all_of_item in schema.get("allOf", []):
ref_in_allof = all_of_item.get("$ref")
if ref_in_allof and ref_in_allof in self._discriminator_schemas:
if discriminator and not schema.get("oneOf") and not schema.get("anyOf"):
ref = f"#/components/schemas/{schema_name}"
self._discriminator_schemas[ref] = discriminator

all_of = schema.get("allOf")
if all_of:
refs = [item.get("$ref") for item in all_of if item.get("$ref")]
if refs:
potential_subtypes[schema_name] = refs

for schema_name, refs in potential_subtypes.items():
for ref_in_allof in refs:
if ref_in_allof in self._discriminator_schemas:
subtype_ref = f"#/components/schemas/{schema_name}"
self._discriminator_subtypes[ref_in_allof].append(subtype_ref)
Loading