diff --git a/src/datamodel_code_generator/imports.py b/src/datamodel_code_generator/imports.py
index 01e85c08a..593f17dd8 100644
--- a/src/datamodel_code_generator/imports.py
+++ b/src/datamodel_code_generator/imports.py
@@ -168,6 +168,33 @@ def dump_all(self, *, multiline: bool = False) -> str:
         items = ", ".join(f'"{name}"' for name in name_list)
         return f"__all__ = [{items}]"
 
+    def get_effective_name(self, from_: str | None, import_: str) -> str:
+        """Get the effective name after alias resolution."""
+        return self.alias.get(from_, {}).get(import_, import_)
+
+    def remove_unused(self, used_names: set[str]) -> None:
+        """Remove imports not referenced in used_names.
+
+        Note: Checks both effective name (after alias) and original name to handle
+        cases where code may reference either form (e.g., type annotations may use
+        original name while runtime code uses alias).
+        """
+        unused = [
+            (from_, import_)
+            for from_, imports_ in self.items()
+            for import_ in imports_
+            if not {self.get_effective_name(from_, import_), import_}.intersection(used_names)
+        ]
+        for from_, import_ in unused:
+            alias = self.alias.get(from_, {}).get(import_)
+            reference_path = next(
+                (p for p, i in self.reference_paths.items() if i.from_ == from_ and i.import_ == import_),
+                None,
+            )
+            import_obj = Import(from_=from_, import_=import_, alias=alias, reference_path=reference_path)
+            while self.counter.get((from_, import_), 0) > 0:
+                self.remove(import_obj)
+
 
 IMPORT_ANNOTATED = Import.from_full_path("typing.Annotated")
 IMPORT_ANY = Import.from_full_path("typing.Any")
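For context, a minimal usage sketch of the new `Imports.remove_unused` helper added above in imports.py. The import targets and the used-name set below are made up; inside the parser the set comes from `_collect_used_names_from_models` later in this patch.

```python
from datamodel_code_generator.imports import Import, Imports

imports = Imports(False)  # positional use_exact flag, as in Imports(self.use_exact_imports)
imports.append(Import.from_full_path("typing.Optional"))
imports.append(Import.from_full_path("pydantic.BaseModel"))

# Only BaseModel is still referenced by the generated module, so the
# typing.Optional entry is dropped and pydantic.BaseModel is kept.
imports.remove_unused({"BaseModel"})
```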
diff --git a/src/datamodel_code_generator/model/base.py b/src/datamodel_code_generator/model/base.py
index 0eded4cb0..47a7f7467 100644
--- a/src/datamodel_code_generator/model/base.py
+++ b/src/datamodel_code_generator/model/base.py
@@ -478,6 +478,10 @@ class DataModel(TemplateBase, Nullable, ABC):  # noqa: PLR0904
     DEFAULT_IMPORTS: ClassVar[tuple[Import, ...]] = ()
     IS_ALIAS: ClassVar[bool] = False
     SUPPORTS_GENERIC_BASE_CLASS: ClassVar[bool] = True
+    SUPPORTS_DISCRIMINATOR: ClassVar[bool] = False
+    SUPPORTS_FIELD_RENAMING: ClassVar[bool] = False
+    SUPPORTS_WRAPPED_DEFAULT: ClassVar[bool] = False
+    SUPPORTS_KW_ONLY: ClassVar[bool] = False
     has_forward_reference: bool = False
 
     def __init__(  # noqa: PLR0913
diff --git a/src/datamodel_code_generator/model/dataclass.py b/src/datamodel_code_generator/model/dataclass.py
index c3ef3b0c6..545b6c42c 100644
--- a/src/datamodel_code_generator/model/dataclass.py
+++ b/src/datamodel_code_generator/model/dataclass.py
@@ -42,6 +42,8 @@ class DataClass(DataModel):
     TEMPLATE_FILE_PATH: ClassVar[str] = "dataclass.jinja2"
     DEFAULT_IMPORTS: ClassVar[tuple[Import, ...]] = (IMPORT_DATACLASS,)
+    SUPPORTS_DISCRIMINATOR: ClassVar[bool] = True
+    SUPPORTS_KW_ONLY: ClassVar[bool] = True
 
     def __init__(  # noqa: PLR0913
         self,
diff --git a/src/datamodel_code_generator/model/msgspec.py b/src/datamodel_code_generator/model/msgspec.py
index e34c75d9a..e6d9ac1ed 100644
--- a/src/datamodel_code_generator/model/msgspec.py
+++ b/src/datamodel_code_generator/model/msgspec.py
@@ -113,6 +113,7 @@ class Struct(DataModel):
     BASE_CLASS_NAME: ClassVar[str] = "Struct"
     BASE_CLASS_ALIAS: ClassVar[str] = "_Struct"
     DEFAULT_IMPORTS: ClassVar[tuple[Import, ...]] = ()
+    SUPPORTS_DISCRIMINATOR: ClassVar[bool] = True
     CONFIG_MAPPING: ClassVar[dict[tuple[str, Any], tuple[str, Any] | None]] = {
         ("allow_mutation", False): ("frozen", True),
         ("extra_fields", "forbid"): ("forbid_unknown_fields", True),
diff --git a/src/datamodel_code_generator/model/pydantic/base_model.py b/src/datamodel_code_generator/model/pydantic/base_model.py
index 249f3f4f0..57d88711c 100644
--- a/src/datamodel_code_generator/model/pydantic/base_model.py
+++ b/src/datamodel_code_generator/model/pydantic/base_model.py
@@ -328,6 +328,7 @@ class BaseModel(BaseModelBase):
     TEMPLATE_FILE_PATH: ClassVar[str] = "pydantic/BaseModel.jinja2"
     BASE_CLASS: ClassVar[str] = "pydantic.BaseModel"
+    SUPPORTS_DISCRIMINATOR: ClassVar[bool] = True
 
     def __init__(  # noqa: PLR0912, PLR0913
         self,
diff --git a/src/datamodel_code_generator/model/pydantic_v2/base_model.py b/src/datamodel_code_generator/model/pydantic_v2/base_model.py
index da2ecaf18..97e8d0db7 100644
--- a/src/datamodel_code_generator/model/pydantic_v2/base_model.py
+++ b/src/datamodel_code_generator/model/pydantic_v2/base_model.py
@@ -174,6 +174,9 @@ class BaseModel(BaseModelBase):
     BASE_CLASS: ClassVar[str] = "pydantic.BaseModel"
     BASE_CLASS_NAME: ClassVar[str] = "BaseModel"
     BASE_CLASS_ALIAS: ClassVar[str] = "_BaseModel"
+    SUPPORTS_DISCRIMINATOR: ClassVar[bool] = True
+    SUPPORTS_FIELD_RENAMING: ClassVar[bool] = True
+    SUPPORTS_WRAPPED_DEFAULT: ClassVar[bool] = True
     CONFIG_ATTRIBUTES: ClassVar[list[ConfigAttribute]] = [
         ConfigAttribute("allow_population_by_field_name", "populate_by_name", False),  # noqa: FBT003
         ConfigAttribute("populate_by_name", "populate_by_name", False),  # noqa: FBT003
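The new `SUPPORTS_*` ClassVars replace per-backend `isinstance` chains in the parser with capability checks on the model class itself. A small illustration follows; the helper function is hypothetical, but the check mirrors the updated `__apply_discriminator_type` in parser/base.py below.

```python
from datamodel_code_generator.model.base import DataModel


def supports_discriminator(model: object) -> bool:
    # A backend opts in by flipping the ClassVar; e.g. pydantic_v2.BaseModel,
    # DataClass, and msgspec Struct all set SUPPORTS_DISCRIMINATOR = True.
    return isinstance(model, DataModel) and model.SUPPORTS_DISCRIMINATOR
```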
diff --git a/src/datamodel_code_generator/parser/base.py b/src/datamodel_code_generator/parser/base.py
index 7a21111aa..f6963e359 100644
--- a/src/datamodel_code_generator/parser/base.py
+++ b/src/datamodel_code_generator/parser/base.py
@@ -101,6 +101,32 @@ def __ge__(self, value: Any, /) -> bool: ...  # noqa: D105
 
 ClassNode: TypeAlias = tuple[ModelName, ...]
 ClassGraph: TypeAlias = dict[ClassNode, set[ClassNode]]
+ModulePath: TypeAlias = tuple[str, ...]
+ModuleModels: TypeAlias = list[tuple[ModulePath, list[DataModel]]]
+ForwarderMap: TypeAlias = dict[ModulePath, tuple[ModulePath, list[tuple[str, str]]]]
+
+
+class ModuleContext(NamedTuple):
+    """Context for processing a single module during code generation."""
+
+    module: ModulePath
+    module_key: ModulePath
+    models: list[DataModel]
+    is_init: bool
+    imports: Imports
+    scoped_model_resolver: ModelResolver
+
+
+class ParseConfig(NamedTuple):
+    """Configuration for the parse operation."""
+
+    with_import: bool
+    use_deferred_annotations: bool
+    code_formatter: CodeFormatter | None
+    module_split_mode: ModuleSplitMode | None
+    all_exports_scope: AllExportsScope | None
+    all_exports_collision_strategy: AllExportsCollisionStrategy | None
+
 
 class _KeepModelOrderDeps(NamedTuple):
     strong: ModelDeps
@@ -1242,16 +1268,10 @@ def __apply_discriminator_type(  # noqa: PLR0912, PLR0914, PLR0915
                 continue
 
             discriminator_model = data_type.reference.source
-            if not isinstance(  # pragma: no cover
-                discriminator_model,
-                (
-                    pydantic_model.BaseModel,
-                    pydantic_model_v2.BaseModel,
-                    dataclass_model.DataClass,
-                    msgspec_model.Struct,
-                ),
-            ):
-                continue  # pragma: no cover
+            if (
+                not isinstance(discriminator_model, DataModel) or not discriminator_model.SUPPORTS_DISCRIMINATOR
+            ):  # pragma: no cover
+                continue
 
             type_names: list[str] = []
 
@@ -1744,7 +1764,7 @@ def __wrap_root_model_default_values(
         models: list[DataModel],
     ) -> None:
         """Wrap RootModel reference default values with their type constructors."""
-        if not self.use_annotated:
+        if not self.use_annotated or not self.data_model_type.SUPPORTS_WRAPPED_DEFAULT:
             return
         for model, model_field, data_type in iter_models_field_data_types(models):
             if isinstance(model, (Enum, self.data_model_root_type)):
@@ -1755,8 +1775,7 @@ def __wrap_root_model_default_values(
                 continue
             if isinstance(model_field.default, list):
                 continue
-            if data_type.reference and isinstance(data_type.reference.source, pydantic_model_v2.RootModel):
-                # Use alias if available (handles import collisions)
+            if data_type.reference and isinstance(data_type.reference.source, self.data_model_root_type):
                 type_name = data_type.alias or data_type.reference.short_name
                 model_field.default = WrappedDefault(
                     value=model_field.default,
@@ -1821,7 +1840,7 @@ def __change_field_name(
         self,
         models: list[DataModel],
     ) -> None:
-        if not issubclass(self.data_model_type, pydantic_model_v2.BaseModel):
+        if not self.data_model_type.SUPPORTS_FIELD_RENAMING:
            return
         for model in models:
             if "Enum" in model.base_class:
@@ -1879,7 +1898,7 @@ def __fix_dataclass_field_ordering(self, models: list[DataModel]) -> None:
     @classmethod
     def __get_dataclass_inherited_info(cls, model: DataModel) -> tuple[set[str], bool] | None:
         """Get inherited field names and whether any has default. Returns None if not applicable."""
-        if not isinstance(model, dataclass_model.DataClass):
+        if not model.SUPPORTS_KW_ONLY:
             return None
         if not model.base_classes or model.dataclass_arguments.get("kw_only"):
             return None
@@ -2011,9 +2030,7 @@ def __alias_shadowed_imports(  # noqa: PLR6301
 
     def __apply_generic_base_class(  # noqa: PLR0912, PLR0914, PLR0915
         self,
-        processed_models: Sequence[
-            tuple[tuple[str, ...], tuple[str, ...], list[DataModel], bool, Imports, ModelResolver]
-        ],
+        processed_models: Sequence[ModuleContext],
     ) -> None:
         if not self.use_generic_base_class or not self.generic_base_class_config:
             return
@@ -2105,9 +2122,7 @@ def _collect_exports_for_init(
         cls,
         module: tuple[str, ...],
-        processed_models: Sequence[
-            tuple[tuple[str, ...], tuple[str, ...], Sequence[DataModel], bool, Imports, ModelResolver]
-        ],
+        processed_models: Sequence[ModuleContext],
         scope: AllExportsScope,
     ) -> list[tuple[str, tuple[str, ...], str]]:
         """Collect exports for __init__.py based on scope."""
@@ -2218,17 +2233,12 @@ def _collect_used_names_from_models(cls, models: list[DataModel]) -> set[str]:
         def add(name: str | None) -> None:
             if not name:
                 return
-            # first segment is sufficient to match import target or alias
             names.add(name.split(".")[0])
 
-        def walk_data_type(data_type: DataType) -> None:
+        def collect_data_type_names(data_type: DataType) -> None:
             add(data_type.alias or data_type.type)
             if data_type.reference:
                 add(data_type.reference.short_name)
-            for child in data_type.data_types:
-                walk_data_type(child)
-            if data_type.dict_key:
-                walk_data_type(data_type.dict_key)
 
         for model in models:
             add(model.class_name)
@@ -2242,7 +2252,7 @@ def walk_data_type(data_type: DataType) -> None:
                     continue
                 add(field.name)
                 add(field.alias)
-                walk_data_type(field.data_type)
+                field.data_type.walk(collect_data_type_names)
         return names
 
     def __generate_forwarder_content(  # noqa: PLR6301
@@ -2534,19 +2544,17 @@ def __get_resolve_reference_action_parts(
             ),
         ]
 
-    def parse(  # noqa: PLR0912, PLR0913, PLR0914, PLR0915, PLR0917
+    def _prepare_parse_config(  # noqa: PLR0913, PLR0917
         self,
-        with_import: bool | None = True,  # noqa: FBT001, FBT002
-        format_: bool | None = True,  # noqa: FBT001, FBT002
-        settings_path: Path | None = None,
-        disable_future_imports: bool = False,  # noqa: FBT001, FBT002
-        all_exports_scope: AllExportsScope | None = None,
-        all_exports_collision_strategy: AllExportsCollisionStrategy | None = None,
-        module_split_mode: ModuleSplitMode | None = None,
-    ) -> str | dict[tuple[str, ...], Result]:
-        """Parse schema and generate code, returning single file or module dict."""
-        self.parse_raw()
-
+        with_import: bool | None,  # noqa: FBT001
+        format_: bool | None,  # noqa: FBT001
+        settings_path: Path | None,
+        disable_future_imports: bool,  # noqa: FBT001
+        all_exports_scope: AllExportsScope | None,
+        all_exports_collision_strategy: AllExportsCollisionStrategy | None,
+        module_split_mode: ModuleSplitMode | None,
+    ) -> ParseConfig:
+        """Prepare configuration for the parse operation."""
         use_deferred_annotations = bool(
             self.target_python_version.has_native_deferred_annotations or (with_import and not disable_future_imports)
         )
@@ -2558,8 +2566,9 @@ def parse(  # noqa: PLR0912, PLR0913, PLR0914, PLR0915, PLR0917
         ):
             self.imports.append(IMPORT_ANNOTATIONS)
 
+        code_formatter: CodeFormatter | None = None
         if format_:
-            code_formatter: CodeFormatter | None = CodeFormatter(
+            code_formatter = CodeFormatter(
                 self.target_python_version,
                 settings_path,
                 self.wrap_string_literal,
@@ -2570,36 +2579,51 @@ def parse(  # noqa: PLR0912, PLR0913, PLR0914, PLR0915, PLR0917
                 encoding=self.encoding,
                 formatters=self.formatters,
             )
-        else:
-            code_formatter = None
 
-        _, sorted_data_models, require_update_action_models = sort_data_models(self.results)
+        return ParseConfig(
+            with_import=bool(with_import),
+            use_deferred_annotations=use_deferred_annotations,
+            code_formatter=code_formatter,
+            module_split_mode=module_split_mode,
+            all_exports_scope=all_exports_scope,
+            all_exports_collision_strategy=all_exports_collision_strategy,
+        )
 
-        results: dict[tuple[str, ...], Result] = {}
+    def _build_module_structure(
+        self,
+        sorted_data_models: SortedDataModels,
+        require_update_action_models: list[str],
+        module_split_mode: ModuleSplitMode | None,
+    ) -> tuple[
+        ModuleModels,
+        set[ModulePath],
+        ForwarderMap,
+        dict[str, str],
+        dict[DataModel, tuple[ModulePath, list[DataModel]]],
+        dict[str, str],
+    ]:
+        """Build module structure from sorted models."""
 
-        def module_key(data_model: DataModel) -> tuple[str, ...]:
+        def module_key(data_model: DataModel) -> ModulePath:
             if module_split_mode == ModuleSplitMode.Single:
                 file_name = camel_to_snake(data_model.class_name)
                 return (*data_model.module_path, file_name)
             return tuple(data_model.module_path)
 
-        def sort_key(data_model: DataModel) -> tuple[int, tuple[str, ...]]:
+        def sort_key(data_model: DataModel) -> tuple[int, ModulePath]:
             key = module_key(data_model)
             return (len(key), key)
 
-        # process in reverse order to correctly establish module levels
         grouped_models = groupby(
             sorted(sorted_data_models.values(), key=sort_key, reverse=True),
             key=module_key,
         )
-        module_models: list[tuple[tuple[str, ...], list[DataModel]]] = []
-        unused_models: list[DataModel] = []
-        model_to_module_models: dict[DataModel, tuple[tuple[str, ...], list[DataModel]]] = {}
-        module_to_import: dict[tuple[str, ...], Imports] = {}
+        module_models: ModuleModels = []
+        model_to_module_models: dict[DataModel, tuple[ModulePath, list[DataModel]]] = {}
         model_path_to_module_name: dict[str, str] = {}
-        previous_module: tuple[str, ...] = ()
+        previous_module: ModulePath = ()
         for module, models in ((k, [*v]) for k, v in grouped_models):
             for model in models:
                 model_to_module_models[model] = module, models
@@ -2609,210 +2633,281 @@ def sort_key(data_model: DataModel) -> tuple[int, tuple[str, ...]]:
             self.__replace_duplicate_name_in_module(models)
             if len(previous_module) - len(module) > 1:
                 module_models.extend(
-                    (
-                        previous_module[:parts],
-                        [],
-                    )
-                    for parts in range(len(previous_module) - 1, len(module), -1)
+                    (previous_module[:parts], []) for parts in range(len(previous_module) - 1, len(module), -1)
                 )
-            module_models.append((
-                module,
-                models,
-            ))
+            module_models.append((module, models))
             previous_module = module
 
         shared_module_entry = self.__reuse_model_tree_scope(module_models, require_update_action_models)
         if shared_module_entry:
             module_models.insert(0, shared_module_entry)
 
-        # Resolve circular imports by moving models to _internal.py modules
         module_models, internal_modules, forwarder_map, path_mapping = self.__resolve_circular_imports(module_models)
 
-        # Update require_update_action_models with new paths for relocated models
         if path_mapping:
             require_update_action_models[:] = [path_mapping.get(path, path) for path in require_update_action_models]
 
-        class Processed(NamedTuple):
-            module: tuple[str, ...]
-            module_key: tuple[str, ...]  # Original module tuple (without file extension)
-            models: list[DataModel]
-            init: bool
-            imports: Imports
-            scoped_model_resolver: ModelResolver
-
-        processed_models: list[Processed] = []
+        return (
+            module_models,
+            internal_modules,
+            forwarder_map,
+            path_mapping,
+            model_to_module_models,
+            model_path_to_module_name,
+        )
 
-        for module_, models in module_models:
-            imports = module_to_import[module_] = Imports(self.use_exact_imports)
-            init = False
-            if module_:
-                if len(module_) == 1:
-                    parent = ("__init__.py",)
+    def _process_single_module(  # noqa: PLR0913, PLR0917
+        self,
+        module_: ModulePath,
+        models: list[DataModel],
+        results: dict[ModulePath, Result],
+        config: ParseConfig,
+        internal_modules: set[ModulePath],
+        model_path_to_module_name: dict[str, str],
+        require_update_action_models: list[str],
+        unused_models: list[DataModel],
+    ) -> ModuleContext:
+        """Process a single module and return its context."""
+        imports = Imports(self.use_exact_imports)
+        is_init = False
+
+        if module_:
+            if len(module_) == 1:
+                parent: ModulePath = ("__init__.py",)
+                if parent not in results:
+                    results[parent] = Result(body="")
+            else:
+                for i in range(1, len(module_)):
+                    parent = (*module_[:i], "__init__.py")
                     if parent not in results:
                         results[parent] = Result(body="")
-                else:
-                    for i in range(1, len(module_)):
-                        parent = (*module_[:i], "__init__.py")
-                        if parent not in results:
-                            results[parent] = Result(body="")
-                if (*module_, "__init__.py") in results:
-                    module = (*module_, "__init__.py")
-                    init = True
-                else:
-                    module = tuple(part.replace("-", "_") for part in (*module_[:-1], f"{module_[-1]}.py"))
+            if (*module_, "__init__.py") in results:
+                module = (*module_, "__init__.py")
+                is_init = True
             else:
-                module = ("__init__.py",)
-
-            all_module_fields = {field.name for model in models for field in model.fields if field.name is not None}
-            scoped_model_resolver = ModelResolver(exclude_names=all_module_fields)
-
-            self.__alias_shadowed_imports(models, all_module_fields)
-            self.__override_required_field(models)
-            self.__replace_unique_list_to_set(models)
-            self.__change_from_import(
-                models,
-                imports,
-                scoped_model_resolver,
-                init=init,
-                internal_modules=internal_modules,
-                model_path_to_module_name=model_path_to_module_name,
-            )
-            self.__extract_inherited_enum(models)
-            self.__set_reference_default_value_to_field(models)
-            self.__reuse_model(models, require_update_action_models)
-            self.__collapse_root_models(models, unused_models, imports, scoped_model_resolver)
-            self.__set_default_enum_member(models)
-            self.__wrap_root_model_default_values(models)
-            self.__sort_models(
-                models,
-                imports,
-                use_deferred_annotations=bool(
-                    self.target_python_version.has_native_deferred_annotations
-                    or (with_import and not disable_future_imports)
-                ),
-            )
-            self.__change_field_name(models)
-            self.__apply_discriminator_type(models, imports)
-            self.__set_one_literal_on_default(models)
-            self.__fix_dataclass_field_ordering(models)
-            self.__update_type_aliases(models)
-
-            processed_models.append(Processed(module, module_, models, init, imports, scoped_model_resolver))
-
-        self.__apply_generic_base_class(processed_models)
+                module = tuple(part.replace("-", "_") for part in (*module_[:-1], f"{module_[-1]}.py"))
+        else:
+            module = ("__init__.py",)
+
+        all_module_fields = {field.name for model in models for field in model.fields if field.name is not None}
+        scoped_model_resolver = ModelResolver(exclude_names=all_module_fields)
+
+        self.__alias_shadowed_imports(models, all_module_fields)
+        self.__override_required_field(models)
+        self.__replace_unique_list_to_set(models)
+        self.__change_from_import(
+            models,
+            imports,
+            scoped_model_resolver,
+            init=is_init,
+            internal_modules=internal_modules,
+            model_path_to_module_name=model_path_to_module_name,
+        )
+        self.__extract_inherited_enum(models)
+        self.__set_reference_default_value_to_field(models)
+        self.__reuse_model(models, require_update_action_models)
+        self.__collapse_root_models(models, unused_models, imports, scoped_model_resolver)
+        self.__set_default_enum_member(models)
+        self.__wrap_root_model_default_values(models)
+        self.__sort_models(models, imports, use_deferred_annotations=config.use_deferred_annotations)
+        self.__change_field_name(models)
+        self.__apply_discriminator_type(models, imports)
+        self.__set_one_literal_on_default(models)
+        self.__fix_dataclass_field_ordering(models)
+        self.__update_type_aliases(models)
+
+        return ModuleContext(module, module_, models, is_init, imports, scoped_model_resolver)
+
+    def _finalize_modules(
+        self,
+        contexts: list[ModuleContext],
+        unused_models: list[DataModel],
+        model_to_module_models: dict[DataModel, tuple[ModulePath, list[DataModel]]],
+        module_to_import: dict[ModulePath, Imports],
+    ) -> None:
+        """Finalize module processing: apply generic base class and remove unused imports."""
+        self.__apply_generic_base_class(contexts)
 
-        for processed_model in processed_models:
-            for model in processed_model.models:
-                processed_model.imports.append(model.imports)
+        for ctx in contexts:
+            for model in ctx.models:
+                ctx.imports.append(model.imports)
 
         for unused_model in unused_models:
             module, models = model_to_module_models[unused_model]
-            if unused_model in models:  # pragma: no cover
+            if unused_model in models:  # pragma: no branch
                 imports = module_to_import[module]
                 imports.remove(unused_model.imports)
                 models.remove(unused_model)
 
-        for processed_model in processed_models:
-            # postprocess imports to remove unused imports.
-            used_names = self._collect_used_names_from_models(processed_model.models)
-            unused_imports = [
-                (from_, import_)
-                for from_, imports_ in processed_model.imports.items()
-                for import_ in imports_
-                if not {processed_model.imports.alias.get(from_, {}).get(import_, import_), import_}.intersection(
-                    used_names
-                )
-            ]
-            for from_, import_ in unused_imports:
-                import_obj = Import(from_=from_, import_=import_)
-                while processed_model.imports.counter.get((from_, import_), 0) > 0:
-                    processed_model.imports.remove(import_obj)
+        for ctx in contexts:
+            used_names = self._collect_used_names_from_models(ctx.models)
+            ctx.imports.remove_unused(used_names)
 
-        for module, mod_key, models, init, imports, scoped_model_resolver in processed_models:  # noqa: B007
-            # process after removing unused models
-            self.__change_imported_model_name(models, imports, scoped_model_resolver)
+        for ctx in contexts:
+            self.__change_imported_model_name(ctx.models, ctx.imports, ctx.scoped_model_resolver)
 
-        future_imports = self.imports.extract_future()
-        future_imports_str = str(future_imports)
+    def _generate_module_output(  # noqa: PLR0913, PLR0917
+        self,
+        ctx: ModuleContext,
+        config: ParseConfig,
+        contexts: list[ModuleContext],
+        forwarder_map: ForwarderMap,
+        require_update_action_models: list[str],
+        future_imports_str: str,
+    ) -> Result | None:
+        """Generate output for a single module."""
+        result: list[str] = []
+        export_imports: Imports | None = None
+
+        if config.all_exports_scope is not None and ctx.module[-1] == "__init__.py":
+            child_exports = self._collect_exports_for_init(ctx.module, contexts, config.all_exports_scope)
+            if child_exports:
+                local_model_names = {
+                    m.reference.short_name
+                    for m in ctx.models
+                    if m.reference and not m.reference.short_name.startswith("_")  # pragma: no branch
+                }
+                resolved_exports = self._resolve_export_collisions(
+                    child_exports, config.all_exports_collision_strategy, local_model_names
+                )
+                export_imports = self._build_all_exports_code(resolved_exports)
+
+        if ctx.models:
+            if config.with_import:
+                import_parts = [s for s in [future_imports_str, str(self.imports), str(ctx.imports)] if s]
+                result += [*import_parts, "\n"]
+
+            if export_imports:
+                result += [str(export_imports), ""]
+                for m in ctx.models:
+                    if m.reference and not m.reference.short_name.startswith("_"):  # pragma: no branch
+                        export_imports.add_export(m.reference.short_name)
+                result += [export_imports.dump_all(multiline=True) + "\n"]
+
+            code = dump_templates(ctx.models)
+            result += [code]
+
+            result += self.__get_resolve_reference_action_parts(
+                ctx.models,
+                require_update_action_models,
+                use_deferred_annotations=config.use_deferred_annotations,
+            )
 
-        for module, mod_key, models, init, imports, scoped_model_resolver in processed_models:  # noqa: B007
-            result: list[str] = []
-            export_imports: Imports | None = None
-
-            if all_exports_scope is not None and module[-1] == "__init__.py":
-                child_exports = self._collect_exports_for_init(module, processed_models, all_exports_scope)
-                if child_exports:
-                    local_model_names = {
-                        m.reference.short_name
-                        for m in models
-                        if m.reference and not m.reference.short_name.startswith("_")
-                    }
-                    resolved_exports = self._resolve_export_collisions(
-                        child_exports, all_exports_collision_strategy, local_model_names
-                    )
-                    export_imports = self._build_all_exports_code(resolved_exports)
+        if not result and ctx.module_key in forwarder_map:
+            internal_module, class_mappings = forwarder_map[ctx.module_key]
+            forwarder_content = self.__generate_forwarder_content(
+                ctx.module_key, internal_module, class_mappings, is_init=ctx.is_init
+            )
+            result = [forwarder_content]
 
-            if models:
-                if with_import:
-                    import_parts = [s for s in [future_imports_str, str(self.imports), str(imports)] if s]
-                    result += [*import_parts, "\n"]
+        if not result and not ctx.is_init:
+            return None
 
-                if export_imports:
-                    result += [str(export_imports), ""]
-                    for m in models:
-                        if m.reference and not m.reference.short_name.startswith("_"):  # pragma: no branch
-                            export_imports.add_export(m.reference.short_name)
-                    result += [export_imports.dump_all(multiline=True) + "\n"]
+        body = "\n".join(result)
+        if config.code_formatter:
+            body = config.code_formatter.format_code(body)
 
-                code = dump_templates(models)
-                result += [code]
+        return Result(
+            body=body,
+            future_imports=future_imports_str,
+            source=ctx.models[0].file_path if ctx.models else None,
+        )
 
-                result += self.__get_resolve_reference_action_parts(
-                    models,
-                    require_update_action_models,
-                    use_deferred_annotations=use_deferred_annotations,
+    def _generate_empty_init_exports(
+        self,
+        results: dict[ModulePath, Result],
+        contexts: list[ModuleContext],
+        config: ParseConfig,
+        future_imports_str: str,
+    ) -> None:
+        """Generate exports for empty __init__.py files."""
+        if config.all_exports_scope is None:  # pragma: no cover
+            return
+        processed_init_modules = {ctx.module for ctx in contexts if ctx.module[-1] == "__init__.py"}
+        for init_module, init_result in list(results.items()):
+            if init_module[-1] != "__init__.py" or init_module in processed_init_modules or init_result.body:
+                continue
+            child_exports = self._collect_exports_for_init(init_module, contexts, config.all_exports_scope)
+            if child_exports:
+                resolved = self._resolve_export_collisions(child_exports, config.all_exports_collision_strategy, set())
+                export_imports = self._build_all_exports_code(resolved)
+                import_parts = [s for s in [future_imports_str, str(self.imports)] if s] if config.with_import else []
+                parts = import_parts + (["\n"] if import_parts else [])
+                parts += [str(export_imports), "", export_imports.dump_all(multiline=True)]
+                body = "\n".join(parts)
+                results[init_module] = Result(
+                    body=config.code_formatter.format_code(body) if config.code_formatter else body,
+                    future_imports=future_imports_str,
                 )
 
-            # Generate forwarder content for modules that had models moved to _internal
-            if not result and mod_key in forwarder_map:
-                internal_module, class_mappings = forwarder_map[mod_key]
-                forwarder_content = self.__generate_forwarder_content(
-                    mod_key, internal_module, class_mappings, is_init=init
-                )
-                result = [forwarder_content]
+    def parse(  # noqa: PLR0913, PLR0914, PLR0917
+        self,
+        with_import: bool | None = True,  # noqa: FBT001, FBT002
+        format_: bool | None = True,  # noqa: FBT001, FBT002
+        settings_path: Path | None = None,
+        disable_future_imports: bool = False,  # noqa: FBT001, FBT002
+        all_exports_scope: AllExportsScope | None = None,
+        all_exports_collision_strategy: AllExportsCollisionStrategy | None = None,
+        module_split_mode: ModuleSplitMode | None = None,
+    ) -> str | dict[tuple[str, ...], Result]:
+        """Parse schema and generate code, returning single file or module dict."""
+        self.parse_raw()
 
-            if not result and not init:
-                continue
-            body = "\n".join(result)
-            if code_formatter:
-                body = code_formatter.format_code(body)
-
-            results[module] = Result(
-                body=body,
-                future_imports=future_imports_str,
-                source=models[0].file_path if models else None,
+        config = self._prepare_parse_config(
+            with_import,
+            format_,
+            settings_path,
+            disable_future_imports,
+            all_exports_scope,
+            all_exports_collision_strategy,
+            module_split_mode,
+        )
+
+        _, sorted_data_models, require_update_action_models = sort_data_models(self.results)
+
+        (
+            module_models,
+            internal_modules,
+            forwarder_map,
+            _path_mapping,
+            model_to_module_models,
+            model_path_to_module_name,
+        ) = self._build_module_structure(sorted_data_models, require_update_action_models, module_split_mode)
+
+        results: dict[ModulePath, Result] = {}
+        unused_models: list[DataModel] = []
+        module_to_import: dict[ModulePath, Imports] = {}
+        contexts: list[ModuleContext] = []
+
+        for module_, models in module_models:
+            ctx = self._process_single_module(
+                module_,
+                models,
+                results,
+                config,
+                internal_modules,
+                model_path_to_module_name,
+                require_update_action_models,
+                unused_models,
             )
+            module_to_import[module_] = ctx.imports
+            contexts.append(ctx)
 
-        if all_exports_scope is not None:
-            processed_init_modules = {m for m, _, _, _, _, _ in processed_models if m[-1] == "__init__.py"}
-            for init_module, init_result in list(results.items()):
-                if init_module[-1] != "__init__.py" or init_module in processed_init_modules or init_result.body:
-                    continue
-                if child_exports := self._collect_exports_for_init(
-                    init_module, processed_models, all_exports_scope
-                ):  # pragma: no branch
-                    resolved = self._resolve_export_collisions(child_exports, all_exports_collision_strategy, set())
-                    export_imports = self._build_all_exports_code(resolved)
-                    import_parts = [s for s in [future_imports_str, str(self.imports)] if s] if with_import else []
-                    parts = import_parts + (["\n"] if import_parts else [])
-                    parts += [str(export_imports), "", export_imports.dump_all(multiline=True)]
-                    body = "\n".join(parts)
-                    results[init_module] = Result(
-                        body=code_formatter.format_code(body) if code_formatter else body,
-                        future_imports=future_imports_str,
-                    )
+        self._finalize_modules(contexts, unused_models, model_to_module_models, module_to_import)
+
+        future_imports = self.imports.extract_future()
+        future_imports_str = str(future_imports)
+
+        for ctx in contexts:
+            result = self._generate_module_output(
+                ctx, config, contexts, forwarder_map, require_update_action_models, future_imports_str
+            )
+            if result is not None:
+                results[ctx.module] = result
+
+        if config.all_exports_scope is not None:
+            self._generate_empty_init_exports(results, contexts, config, future_imports_str)
 
-        # retain existing behaviour
         if [*results] == [("__init__.py",)]:
             single_result = results["__init__.py",]
             return single_result.body
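The decomposition above does not change the public contract of `parse()`. A minimal caller-side sketch follows; constructing a concrete parser is elided here, and `parser` is assumed to be an already-configured `Parser` subclass instance.

```python
from datamodel_code_generator.parser.base import Parser


def emit(parser: Parser) -> None:
    output = parser.parse(with_import=True, format_=True)
    if isinstance(output, str):
        # single-module generation still returns the body directly
        print(output)
    else:
        # otherwise a dict mapping module path tuples to Result objects
        for module_path, result in output.items():
            print("/".join(module_path), len(result.body))
```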
diff --git a/src/datamodel_code_generator/types.py b/src/datamodel_code_generator/types.py
index 9d309dae1..eca9192b7 100644
--- a/src/datamodel_code_generator/types.py
+++ b/src/datamodel_code_generator/types.py
@@ -429,6 +429,24 @@ def all_data_types(self) -> Iterator[DataType]:
             yield from self.dict_key.all_data_types
         yield self
 
+    def walk(
+        self,
+        visitor: Callable[[DataType], None],
+        visited: set[int] | None = None,
+    ) -> None:
+        """Recursively walk this DataType tree, calling visitor on each node."""
+        if visited is None:
+            visited = set()
+        node_id = id(self)
+        if node_id in visited:
+            return
+        visited.add(node_id)
+        visitor(self)
+        for child in self.data_types:
+            child.walk(visitor, visited)
+        if self.dict_key:
+            self.dict_key.walk(visitor, visited)
+
     def find_source(self, source_type: type[SourceT]) -> SourceT | None:
         """Find the first reference source matching the given type from all nested data types."""
         for data_type in self.all_data_types:  # pragma: no branch
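A small sketch of the new `DataType.walk` visitor; the tree and visitor below are hypothetical, while `_collect_used_names_from_models` in parser/base.py uses the same pattern with `collect_data_type_names`.

```python
from datamodel_code_generator.types import DataType

names: list[str] = []


def collect(data_type: DataType) -> None:
    if data_type.type:
        names.append(data_type.type)


root = DataType(type="dict", data_types=[DataType(type="str"), DataType(type="int")])
root.walk(collect)
print(names)  # every node visited once; cycles are guarded via the id()-based visited set
```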