diff --git a/.github/workflows/config-types.yaml b/.github/workflows/config-types.yaml new file mode 100644 index 000000000..ccdca9bd6 --- /dev/null +++ b/.github/workflows/config-types.yaml @@ -0,0 +1,31 @@ +name: Check config types + +on: + push: + paths: + - 'src/datamodel_code_generator/config.py' + - 'src/datamodel_code_generator/_types/**' + - 'pyproject.toml' + - '.github/workflows/config-types.yaml' + pull_request: + paths: + - 'src/datamodel_code_generator/config.py' + - 'src/datamodel_code_generator/_types/**' + - 'pyproject.toml' + - '.github/workflows/config-types.yaml' + +jobs: + config-types: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - uses: astral-sh/setup-uv@v5 + with: + enable-cache: true + + - run: uv python install 3.14 + + - run: uv tool install --python 3.14 tox --with tox-uv + + - run: tox -e config-types -- --check diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml index d11444ea2..0d9508d89 100644 --- a/.github/workflows/lint.yaml +++ b/.github/workflows/lint.yaml @@ -34,7 +34,7 @@ jobs: - uses: actions/setup-python@v5 with: python-version: "3.14" - - run: uvx prek run --all-files --show-diff-on-failure --skip readme + - run: SKIP=readme,config-types uvx prek run --all-files --show-diff-on-failure - if: | github.event_name == 'push' || github.event.pull_request.head.repo.full_name == github.repository || diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9ff6ece90..b8b131bc8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -18,9 +18,9 @@ repos: rev: 'v0.14.9' hooks: - id: ruff-format - exclude: "^tests/data" + exclude: "^tests/data|^src/datamodel_code_generator/_types/" - id: ruff - exclude: "^tests/data" + exclude: "^tests/data|^src/datamodel_code_generator/_types/" args: ["--fix", "--unsafe-fixes", "--exit-non-zero-on-fix"] - repo: https://github.com/codespell-project/codespell rev: v2.4.1 @@ -37,3 +37,9 @@ repos: language: system files: ^(src/datamodel_code_generator/arguments\.py|README\.md|docs/index\.md)$ pass_filenames: false + - id: config-types + name: Generate config TypedDicts + entry: bash -c 'test -x .tox/dev/bin/python || tox run -e dev --notest -qq; .tox/dev/bin/datamodel-codegen --profile generate-config-dict && .tox/dev/bin/datamodel-codegen --profile parser-config-dict && .tox/dev/bin/datamodel-codegen --profile parse-config-dict' + language: system + files: ^src/datamodel_code_generator/config\.py$ + pass_filenames: false diff --git a/pyproject.toml b/pyproject.toml index 79a0d63ab..9c488419f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -244,7 +244,7 @@ paths.other = [ "*\\datamodel-code-generator", ] run.dynamic_context = "none" -run.omit = [ "tests/data/*", "tests/main/test_performance.py" ] +run.omit = [ "tests/data/*", "tests/main/test_performance.py", "*/_types/*" ] report.fail_under = 88 run.parallel = true run.plugins = [ @@ -258,3 +258,30 @@ reportPrivateImportUsage = false [tool.pydantic-pycharm-plugin] ignore-init-method-arguments = true parsable-types.str = [ "int", "float" ] + +[tool.datamodel-codegen.profiles.config-types-base] +enum-field-as-literal = "none" +use-standard-primitive-types = true +disable-warnings = true +disable-timestamp = true +output-model-type = "typing.TypedDict" +formatters = [ "ruff-format", "ruff-check" ] +input-model-ref-strategy = "reuse-foreign" + +[tool.datamodel-codegen.profiles.generate-config-dict] +extends = "config-types-base" +input-model = "src/datamodel_code_generator/config.py:GenerateConfig" +output = "src/datamodel_code_generator/_types/generate_config_dict.py" +class-name = "GenerateConfigDict" + +[tool.datamodel-codegen.profiles.parser-config-dict] +extends = "config-types-base" +input-model = "src/datamodel_code_generator/config.py:ParserConfig" +output = "src/datamodel_code_generator/_types/parser_config_dict.py" +class-name = "ParserConfigDict" + +[tool.datamodel-codegen.profiles.parse-config-dict] +extends = "config-types-base" +input-model = "src/datamodel_code_generator/config.py:ParseConfig" +output = "src/datamodel_code_generator/_types/parse_config_dict.py" +class-name = "ParseConfigDict" diff --git a/src/datamodel_code_generator/__init__.py b/src/datamodel_code_generator/__init__.py index eacca90b7..d544eef7a 100644 --- a/src/datamodel_code_generator/__init__.py +++ b/src/datamodel_code_generator/__init__.py @@ -61,6 +61,7 @@ if TYPE_CHECKING: from collections import defaultdict + from datamodel_code_generator.config import GenerateConfig from datamodel_code_generator.model.pydantic_v2 import UnionMode from datamodel_code_generator.parser.base import Parser from datamodel_code_generator.types import StrictTypes @@ -451,6 +452,7 @@ def _build_module_content( def generate( # noqa: PLR0912, PLR0913, PLR0914, PLR0915 input_: Path | str | ParseResult | Mapping[str, Any], *, + config: GenerateConfig | None = None, input_filename: str | None = None, input_file_type: InputFileType = InputFileType.Auto, output: Path | None = None, @@ -512,7 +514,7 @@ def generate( # noqa: PLR0912, PLR0913, PLR0914, PLR0915 model_extra_keys_without_x_prefix: set[str] | None = None, openapi_scopes: list[OpenAPIScope] | None = None, include_path_parameters: bool = False, - graphql_scopes: list[GraphQLScope] | None = None, # noqa: ARG001 + graphql_scopes: list[GraphQLScope] | None = None, wrap_string_literal: bool | None = None, use_title_as_name: bool = False, use_operation_id_as_name: bool = False, @@ -583,6 +585,190 @@ def generate( # noqa: PLR0912, PLR0913, PLR0914, PLR0915 - When output is None and multiple modules: GeneratedModules (dict mapping module path tuples to generated code strings) """ + if config is not None: + input_filename = config.input_filename if input_filename is None else input_filename + input_file_type = config.input_file_type if input_file_type == InputFileType.Auto else input_file_type + output = config.output if output is None else output + output_model_type = ( + config.output_model_type if output_model_type == DataModelType.PydanticBaseModel else output_model_type + ) + target_python_version = ( + config.target_python_version if target_python_version == PythonVersionMin else target_python_version + ) + target_pydantic_version = ( + config.target_pydantic_version if target_pydantic_version is None else target_pydantic_version + ) + base_class = base_class or config.base_class + base_class_map = config.base_class_map if base_class_map is None else base_class_map + additional_imports = config.additional_imports if additional_imports is None else additional_imports + class_decorators = config.class_decorators if class_decorators is None else class_decorators + custom_template_dir = config.custom_template_dir if custom_template_dir is None else custom_template_dir + if extra_template_data is None and config.extra_template_data is not None: + from collections import defaultdict as _defaultdict # noqa: PLC0415 + + extra_template_data = _defaultdict(dict, config.extra_template_data) + validation = validation or config.validation + field_constraints = field_constraints or config.field_constraints + snake_case_field = snake_case_field or config.snake_case_field + strip_default_none = strip_default_none or config.strip_default_none + aliases = config.aliases if aliases is None else aliases + disable_timestamp = disable_timestamp or config.disable_timestamp + enable_version_header = enable_version_header or config.enable_version_header + enable_command_header = enable_command_header or config.enable_command_header + command_line = config.command_line if command_line is None else command_line + allow_population_by_field_name = allow_population_by_field_name or config.allow_population_by_field_name + allow_extra_fields = allow_extra_fields or config.allow_extra_fields + extra_fields = config.extra_fields if extra_fields is None else extra_fields + use_generic_base_class = use_generic_base_class or config.use_generic_base_class + apply_default_values_for_required_fields = ( + apply_default_values_for_required_fields or config.apply_default_values_for_required_fields + ) + force_optional_for_required_fields = ( + force_optional_for_required_fields or config.force_optional_for_required_fields + ) + class_name = config.class_name if class_name is None else class_name + use_standard_collections = ( + config.use_standard_collections if use_standard_collections else use_standard_collections + ) + use_schema_description = use_schema_description or config.use_schema_description + use_field_description = use_field_description or config.use_field_description + use_field_description_example = use_field_description_example or config.use_field_description_example + use_attribute_docstrings = use_attribute_docstrings or config.use_attribute_docstrings + use_inline_field_description = use_inline_field_description or config.use_inline_field_description + use_default_kwarg = use_default_kwarg or config.use_default_kwarg + reuse_model = reuse_model or config.reuse_model + reuse_scope = config.reuse_scope if reuse_scope == ReuseScope.Module else reuse_scope + shared_module_name = ( + config.shared_module_name if shared_module_name == DEFAULT_SHARED_MODULE_NAME else shared_module_name + ) + encoding = config.encoding if encoding == "utf-8" else encoding + enum_field_as_literal = config.enum_field_as_literal if enum_field_as_literal is None else enum_field_as_literal + enum_field_as_literal_map = ( + config.enum_field_as_literal_map if enum_field_as_literal_map is None else enum_field_as_literal_map + ) + ignore_enum_constraints = ignore_enum_constraints or config.ignore_enum_constraints + use_one_literal_as_default = use_one_literal_as_default or config.use_one_literal_as_default + use_enum_values_in_discriminator = use_enum_values_in_discriminator or config.use_enum_values_in_discriminator + set_default_enum_member = set_default_enum_member or config.set_default_enum_member + use_subclass_enum = use_subclass_enum or config.use_subclass_enum + use_specialized_enum = config.use_specialized_enum if use_specialized_enum else False + strict_nullable = strict_nullable or config.strict_nullable + use_generic_container_types = use_generic_container_types or config.use_generic_container_types + enable_faux_immutability = enable_faux_immutability or config.enable_faux_immutability + disable_appending_item_suffix = disable_appending_item_suffix or config.disable_appending_item_suffix + strict_types = config.strict_types if strict_types is None else strict_types + empty_enum_field_name = config.empty_enum_field_name if empty_enum_field_name is None else empty_enum_field_name + custom_class_name_generator = ( + config.custom_class_name_generator if custom_class_name_generator is None else custom_class_name_generator + ) + field_extra_keys = config.field_extra_keys if field_extra_keys is None else field_extra_keys + field_include_all_keys = field_include_all_keys or config.field_include_all_keys + field_extra_keys_without_x_prefix = ( + config.field_extra_keys_without_x_prefix + if field_extra_keys_without_x_prefix is None + else field_extra_keys_without_x_prefix + ) + model_extra_keys = config.model_extra_keys if model_extra_keys is None else model_extra_keys + model_extra_keys_without_x_prefix = ( + config.model_extra_keys_without_x_prefix + if model_extra_keys_without_x_prefix is None + else model_extra_keys_without_x_prefix + ) + openapi_scopes = config.openapi_scopes if openapi_scopes is None else openapi_scopes + include_path_parameters = include_path_parameters or config.include_path_parameters + graphql_scopes = config.graphql_scopes if graphql_scopes is None else graphql_scopes + wrap_string_literal = config.wrap_string_literal if wrap_string_literal is None else wrap_string_literal + use_title_as_name = use_title_as_name or config.use_title_as_name + use_operation_id_as_name = use_operation_id_as_name or config.use_operation_id_as_name + use_unique_items_as_set = use_unique_items_as_set or config.use_unique_items_as_set + use_tuple_for_fixed_items = use_tuple_for_fixed_items or config.use_tuple_for_fixed_items + allof_merge_mode = ( + config.allof_merge_mode if allof_merge_mode == AllOfMergeMode.Constraints else allof_merge_mode + ) + http_headers = config.http_headers if http_headers is None else http_headers + http_ignore_tls = http_ignore_tls or config.http_ignore_tls + http_timeout = config.http_timeout if http_timeout is None else http_timeout + use_annotated = use_annotated or config.use_annotated + use_serialize_as_any = use_serialize_as_any or config.use_serialize_as_any + use_non_positive_negative_number_constrained_types = ( + use_non_positive_negative_number_constrained_types + or config.use_non_positive_negative_number_constrained_types + ) + use_decimal_for_multiple_of = use_decimal_for_multiple_of or config.use_decimal_for_multiple_of + original_field_name_delimiter = ( + config.original_field_name_delimiter + if original_field_name_delimiter is None + else original_field_name_delimiter + ) + use_double_quotes = use_double_quotes or config.use_double_quotes + use_union_operator = config.use_union_operator if use_union_operator else False + collapse_root_models = collapse_root_models or config.collapse_root_models + collapse_root_models_name_strategy = ( + config.collapse_root_models_name_strategy + if collapse_root_models_name_strategy is None + else collapse_root_models_name_strategy + ) + collapse_reuse_models = collapse_reuse_models or config.collapse_reuse_models + skip_root_model = skip_root_model or config.skip_root_model + use_type_alias = use_type_alias or config.use_type_alias + use_root_model_type_alias = use_root_model_type_alias or config.use_root_model_type_alias + special_field_name_prefix = ( + config.special_field_name_prefix if special_field_name_prefix is None else special_field_name_prefix + ) + remove_special_field_name_prefix = remove_special_field_name_prefix or config.remove_special_field_name_prefix + capitalise_enum_members = capitalise_enum_members or config.capitalise_enum_members + keep_model_order = keep_model_order or config.keep_model_order + custom_file_header = config.custom_file_header if custom_file_header is None else custom_file_header + custom_file_header_path = ( + config.custom_file_header_path if custom_file_header_path is None else custom_file_header_path + ) + custom_formatters = config.custom_formatters if custom_formatters is None else custom_formatters + custom_formatters_kwargs = ( + config.custom_formatters_kwargs if custom_formatters_kwargs is None else custom_formatters_kwargs + ) + use_pendulum = use_pendulum or config.use_pendulum + use_standard_primitive_types = use_standard_primitive_types or config.use_standard_primitive_types + http_query_parameters = config.http_query_parameters if http_query_parameters is None else http_query_parameters + treat_dot_as_module = config.treat_dot_as_module if treat_dot_as_module is None else treat_dot_as_module + use_exact_imports = use_exact_imports or config.use_exact_imports + union_mode = config.union_mode if union_mode is None else union_mode + output_datetime_class = config.output_datetime_class if output_datetime_class is None else output_datetime_class + output_date_class = config.output_date_class if output_date_class is None else output_date_class + keyword_only = keyword_only or config.keyword_only + frozen_dataclasses = frozen_dataclasses or config.frozen_dataclasses + no_alias = no_alias or config.no_alias + use_frozen_field = use_frozen_field or config.use_frozen_field + use_default_factory_for_optional_nested_models = ( + use_default_factory_for_optional_nested_models or config.use_default_factory_for_optional_nested_models + ) + formatters = config.formatters if formatters == DEFAULT_FORMATTERS else formatters + settings_path = config.settings_path if settings_path is None else settings_path + parent_scoped_naming = parent_scoped_naming or config.parent_scoped_naming + naming_strategy = config.naming_strategy if naming_strategy is None else naming_strategy + duplicate_name_suffix = config.duplicate_name_suffix if duplicate_name_suffix is None else duplicate_name_suffix + dataclass_arguments = config.dataclass_arguments if dataclass_arguments is None else dataclass_arguments + disable_future_imports = disable_future_imports or config.disable_future_imports + type_mappings = config.type_mappings if type_mappings is None else type_mappings + type_overrides = config.type_overrides if type_overrides is None else type_overrides + read_only_write_only_model_type = ( + config.read_only_write_only_model_type + if read_only_write_only_model_type is None + else read_only_write_only_model_type + ) + use_status_code_in_response_name = use_status_code_in_response_name or config.use_status_code_in_response_name + all_exports_scope = config.all_exports_scope if all_exports_scope is None else all_exports_scope + all_exports_collision_strategy = ( + config.all_exports_collision_strategy + if all_exports_collision_strategy is None + else all_exports_collision_strategy + ) + field_type_collision_strategy = ( + config.field_type_collision_strategy + if field_type_collision_strategy is None + else field_type_collision_strategy + ) + module_split_mode = config.module_split_mode if module_split_mode is None else module_split_mode + remote_text_cache: DefaultPutDict[str, str] = DefaultPutDict() match input_: case str(): diff --git a/src/datamodel_code_generator/__main__.py b/src/datamodel_code_generator/__main__.py index 9258f052d..ba0f4de73 100644 --- a/src/datamodel_code_generator/__main__.py +++ b/src/datamodel_code_generator/__main__.py @@ -1116,6 +1116,37 @@ def _filter_defs_by_strategy( return {**schema, "$defs": new_defs} +def _try_rebuild_model(obj: type) -> None: + """Try to rebuild a Pydantic model, handling config models specially.""" + module = getattr(obj, "__module__", "") + class_name = getattr(obj, "__name__", "") + config_classes = {"GenerateConfig", "ParserConfig", "ParseConfig"} + if module in {"datamodel_code_generator.config", "config"} and class_name in config_classes: + from datamodel_code_generator.model.base import DataModel, DataModelFieldBase # noqa: PLC0415 + from datamodel_code_generator.types import DataTypeManager, StrictTypes # noqa: PLC0415 + + try: + from datamodel_code_generator.model.pydantic_v2 import UnionMode # noqa: PLC0415 + except ImportError: # pragma: no cover + from typing import Any # noqa: PLC0415 + + runtime_union_mode = Any + else: + runtime_union_mode = UnionMode + + types_namespace = { + "Path": Path, + "DataModel": DataModel, + "DataModelFieldBase": DataModelFieldBase, + "DataTypeManager": DataTypeManager, + "StrictTypes": StrictTypes, + "UnionMode": runtime_union_mode, + } + obj.model_rebuild(_types_namespace=types_namespace) + else: + obj.model_rebuild() + + def _load_model_schema( # noqa: PLR0912, PLR0914, PLR0915 input_model: str, input_file_type: InputFileType, @@ -1199,6 +1230,8 @@ def _load_model_schema( # noqa: PLR0912, PLR0914, PLR0915 if not hasattr(obj, "model_json_schema"): msg = "--input-model with Pydantic model requires Pydantic v2 runtime. Please upgrade Pydantic to v2." raise Error(msg) + if hasattr(obj, "model_rebuild"): # pragma: no branch + _try_rebuild_model(obj) schema_generator = _get_input_model_json_schema_class() schema = obj.model_json_schema(schema_generator=schema_generator) schema = _add_python_type_for_unserializable(schema, obj) diff --git a/src/datamodel_code_generator/_types/__init__.py b/src/datamodel_code_generator/_types/__init__.py new file mode 100644 index 000000000..61c932e82 --- /dev/null +++ b/src/datamodel_code_generator/_types/__init__.py @@ -0,0 +1,9 @@ +"""Auto-generated TypedDict definitions for config classes.""" + +from __future__ import annotations + +from datamodel_code_generator._types.generate_config_dict import GenerateConfigDict +from datamodel_code_generator._types.parse_config_dict import ParseConfigDict +from datamodel_code_generator._types.parser_config_dict import ParserConfigDict + +__all__ = ["GenerateConfigDict", "ParseConfigDict", "ParserConfigDict"] diff --git a/src/datamodel_code_generator/_types/generate_config_dict.py b/src/datamodel_code_generator/_types/generate_config_dict.py new file mode 100644 index 000000000..7cbbe3daf --- /dev/null +++ b/src/datamodel_code_generator/_types/generate_config_dict.py @@ -0,0 +1,158 @@ +# generated by datamodel-codegen: +# filename: + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, TypedDict + +from typing_extensions import NotRequired + +if TYPE_CHECKING: + from collections import defaultdict + from collections.abc import Callable, Mapping, Sequence + from pathlib import Path + + from datamodel_code_generator.enums import ( + AllExportsCollisionStrategy, + AllExportsScope, + AllOfMergeMode, + CollapseRootModelsNameStrategy, + DataclassArguments, + DataModelType, + FieldTypeCollisionStrategy, + GraphQLScope, + InputFileType, + ModuleSplitMode, + NamingStrategy, + OpenAPIScope, + ReadOnlyWriteOnlyModelType, + ReuseScope, + StrictTypes, + TargetPydanticVersion, + UnionMode, + ) + from datamodel_code_generator.format import DateClassType, DatetimeClassType, Formatter, PythonVersion + from datamodel_code_generator.parser import LiteralType + + +class GenerateConfigDict(TypedDict): + input_filename: NotRequired[str | None] + input_file_type: NotRequired[InputFileType] + output: NotRequired[Path | None] + output_model_type: NotRequired[DataModelType] + target_python_version: NotRequired[PythonVersion] + target_pydantic_version: NotRequired[TargetPydanticVersion | None] + base_class: NotRequired[str] + base_class_map: NotRequired[dict[str, str] | None] + additional_imports: NotRequired[list[str] | None] + class_decorators: NotRequired[list[str] | None] + custom_template_dir: NotRequired[Path | None] + extra_template_data: NotRequired[defaultdict[str, dict[str, Any]] | None] + validation: NotRequired[bool] + field_constraints: NotRequired[bool] + snake_case_field: NotRequired[bool] + strip_default_none: NotRequired[bool] + aliases: NotRequired[Mapping[str, str] | None] + disable_timestamp: NotRequired[bool] + enable_version_header: NotRequired[bool] + enable_command_header: NotRequired[bool] + command_line: NotRequired[str | None] + allow_population_by_field_name: NotRequired[bool] + allow_extra_fields: NotRequired[bool] + extra_fields: NotRequired[str | None] + use_generic_base_class: NotRequired[bool] + apply_default_values_for_required_fields: NotRequired[bool] + force_optional_for_required_fields: NotRequired[bool] + class_name: NotRequired[str | None] + use_standard_collections: NotRequired[bool] + use_schema_description: NotRequired[bool] + use_field_description: NotRequired[bool] + use_field_description_example: NotRequired[bool] + use_attribute_docstrings: NotRequired[bool] + use_inline_field_description: NotRequired[bool] + use_default_kwarg: NotRequired[bool] + reuse_model: NotRequired[bool] + reuse_scope: NotRequired[ReuseScope] + shared_module_name: NotRequired[str] + encoding: NotRequired[str] + enum_field_as_literal: NotRequired[LiteralType | None] + enum_field_as_literal_map: NotRequired[dict[str, str] | None] + ignore_enum_constraints: NotRequired[bool] + use_one_literal_as_default: NotRequired[bool] + use_enum_values_in_discriminator: NotRequired[bool] + set_default_enum_member: NotRequired[bool] + use_subclass_enum: NotRequired[bool] + use_specialized_enum: NotRequired[bool] + strict_nullable: NotRequired[bool] + use_generic_container_types: NotRequired[bool] + enable_faux_immutability: NotRequired[bool] + disable_appending_item_suffix: NotRequired[bool] + strict_types: NotRequired[Sequence[StrictTypes] | None] + empty_enum_field_name: NotRequired[str | None] + custom_class_name_generator: NotRequired[Callable[[str], str] | None] + field_extra_keys: NotRequired[set[str] | None] + field_include_all_keys: NotRequired[bool] + field_extra_keys_without_x_prefix: NotRequired[set[str] | None] + model_extra_keys: NotRequired[set[str] | None] + model_extra_keys_without_x_prefix: NotRequired[set[str] | None] + openapi_scopes: NotRequired[list[OpenAPIScope] | None] + include_path_parameters: NotRequired[bool] + graphql_scopes: NotRequired[list[GraphQLScope] | None] + wrap_string_literal: NotRequired[bool | None] + use_title_as_name: NotRequired[bool] + use_operation_id_as_name: NotRequired[bool] + use_unique_items_as_set: NotRequired[bool] + use_tuple_for_fixed_items: NotRequired[bool] + allof_merge_mode: NotRequired[AllOfMergeMode] + http_headers: NotRequired[Sequence[tuple[str, str]] | None] + http_ignore_tls: NotRequired[bool] + http_timeout: NotRequired[float | None] + use_annotated: NotRequired[bool] + use_serialize_as_any: NotRequired[bool] + use_non_positive_negative_number_constrained_types: NotRequired[bool] + use_decimal_for_multiple_of: NotRequired[bool] + original_field_name_delimiter: NotRequired[str | None] + use_double_quotes: NotRequired[bool] + use_union_operator: NotRequired[bool] + collapse_root_models: NotRequired[bool] + collapse_root_models_name_strategy: NotRequired[CollapseRootModelsNameStrategy | None] + collapse_reuse_models: NotRequired[bool] + skip_root_model: NotRequired[bool] + use_type_alias: NotRequired[bool] + use_root_model_type_alias: NotRequired[bool] + special_field_name_prefix: NotRequired[str | None] + remove_special_field_name_prefix: NotRequired[bool] + capitalise_enum_members: NotRequired[bool] + keep_model_order: NotRequired[bool] + custom_file_header: NotRequired[str | None] + custom_file_header_path: NotRequired[Path | None] + custom_formatters: NotRequired[list[str] | None] + custom_formatters_kwargs: NotRequired[dict[str, Any] | None] + use_pendulum: NotRequired[bool] + use_standard_primitive_types: NotRequired[bool] + http_query_parameters: NotRequired[Sequence[tuple[str, str]] | None] + treat_dot_as_module: NotRequired[bool | None] + use_exact_imports: NotRequired[bool] + union_mode: NotRequired[UnionMode | None] + output_datetime_class: NotRequired[DatetimeClassType | None] + output_date_class: NotRequired[DateClassType | None] + keyword_only: NotRequired[bool] + frozen_dataclasses: NotRequired[bool] + no_alias: NotRequired[bool] + use_frozen_field: NotRequired[bool] + use_default_factory_for_optional_nested_models: NotRequired[bool] + formatters: NotRequired[list[Formatter]] + settings_path: NotRequired[Path | None] + parent_scoped_naming: NotRequired[bool] + naming_strategy: NotRequired[NamingStrategy | None] + duplicate_name_suffix: NotRequired[dict[str, str] | None] + dataclass_arguments: NotRequired[DataclassArguments | None] + disable_future_imports: NotRequired[bool] + type_mappings: NotRequired[list[str] | None] + type_overrides: NotRequired[dict[str, str] | None] + read_only_write_only_model_type: NotRequired[ReadOnlyWriteOnlyModelType | None] + use_status_code_in_response_name: NotRequired[bool] + all_exports_scope: NotRequired[AllExportsScope | None] + all_exports_collision_strategy: NotRequired[AllExportsCollisionStrategy | None] + field_type_collision_strategy: NotRequired[FieldTypeCollisionStrategy | None] + module_split_mode: NotRequired[ModuleSplitMode | None] diff --git a/src/datamodel_code_generator/_types/parse_config_dict.py b/src/datamodel_code_generator/_types/parse_config_dict.py new file mode 100644 index 000000000..9a489a1f9 --- /dev/null +++ b/src/datamodel_code_generator/_types/parse_config_dict.py @@ -0,0 +1,23 @@ +# generated by datamodel-codegen: +# filename: + +from __future__ import annotations + +from typing import TYPE_CHECKING, TypedDict + +from typing_extensions import NotRequired + +if TYPE_CHECKING: + from pathlib import Path + + from datamodel_code_generator.enums import AllExportsCollisionStrategy, AllExportsScope, ModuleSplitMode + + +class ParseConfigDict(TypedDict): + with_import: NotRequired[bool | None] + format_: NotRequired[bool | None] + settings_path: NotRequired[Path | None] + disable_future_imports: NotRequired[bool] + all_exports_scope: NotRequired[AllExportsScope | None] + all_exports_collision_strategy: NotRequired[AllExportsCollisionStrategy | None] + module_split_mode: NotRequired[ModuleSplitMode | None] diff --git a/src/datamodel_code_generator/_types/parser_config_dict.py b/src/datamodel_code_generator/_types/parser_config_dict.py new file mode 100644 index 000000000..c78b052c2 --- /dev/null +++ b/src/datamodel_code_generator/_types/parser_config_dict.py @@ -0,0 +1,142 @@ +# generated by datamodel-codegen: +# filename: + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, TypedDict + +from typing_extensions import NotRequired + +if TYPE_CHECKING: + from collections import defaultdict + from collections.abc import Callable, Iterable, Mapping, Sequence + from pathlib import Path + + from datamodel_code_generator.enums import ( + AllOfMergeMode, + CollapseRootModelsNameStrategy, + DataclassArguments, + FieldTypeCollisionStrategy, + NamingStrategy, + ReadOnlyWriteOnlyModelType, + ReuseScope, + StrictTypes, + TargetPydanticVersion, + ) + from datamodel_code_generator.format import DateClassType, DatetimeClassType, Formatter, PythonVersion + from datamodel_code_generator.model.base import DataModel, DataModelFieldBase + from datamodel_code_generator.parser import DefaultPutDict, LiteralType + from datamodel_code_generator.types import DataTypeManager + + +class ParserConfigDict(TypedDict): + data_model_type: NotRequired[type[DataModel]] + data_model_root_type: NotRequired[type[DataModel]] + data_type_manager_type: NotRequired[type[DataTypeManager]] + data_model_field_type: NotRequired[type[DataModelFieldBase]] + base_class: NotRequired[str | None] + base_class_map: NotRequired[dict[str, str] | None] + additional_imports: NotRequired[list[str] | None] + class_decorators: NotRequired[list[str] | None] + custom_template_dir: NotRequired[Path | None] + extra_template_data: NotRequired[defaultdict[str, dict[str, Any]] | None] + target_python_version: NotRequired[PythonVersion] + dump_resolve_reference_action: NotRequired[Callable[[Iterable[str]], str] | None] + validation: NotRequired[bool] + field_constraints: NotRequired[bool] + snake_case_field: NotRequired[bool] + strip_default_none: NotRequired[bool] + aliases: NotRequired[Mapping[str, str] | None] + allow_population_by_field_name: NotRequired[bool] + apply_default_values_for_required_fields: NotRequired[bool] + allow_extra_fields: NotRequired[bool] + extra_fields: NotRequired[str | None] + use_generic_base_class: NotRequired[bool] + force_optional_for_required_fields: NotRequired[bool] + class_name: NotRequired[str | None] + use_standard_collections: NotRequired[bool] + base_path: NotRequired[Path | None] + use_schema_description: NotRequired[bool] + use_field_description: NotRequired[bool] + use_field_description_example: NotRequired[bool] + use_attribute_docstrings: NotRequired[bool] + use_inline_field_description: NotRequired[bool] + use_default_kwarg: NotRequired[bool] + reuse_model: NotRequired[bool] + reuse_scope: NotRequired[ReuseScope | None] + shared_module_name: NotRequired[str] + encoding: NotRequired[str] + enum_field_as_literal: NotRequired[LiteralType | None] + enum_field_as_literal_map: NotRequired[dict[str, str] | None] + ignore_enum_constraints: NotRequired[bool] + set_default_enum_member: NotRequired[bool] + use_subclass_enum: NotRequired[bool] + use_specialized_enum: NotRequired[bool] + strict_nullable: NotRequired[bool] + use_generic_container_types: NotRequired[bool] + enable_faux_immutability: NotRequired[bool] + remote_text_cache: NotRequired[DefaultPutDict[str, str] | None] + disable_appending_item_suffix: NotRequired[bool] + strict_types: NotRequired[Sequence[StrictTypes] | None] + empty_enum_field_name: NotRequired[str | None] + custom_class_name_generator: NotRequired[Callable[[str], str] | None] + field_extra_keys: NotRequired[set[str] | None] + field_include_all_keys: NotRequired[bool] + field_extra_keys_without_x_prefix: NotRequired[set[str] | None] + model_extra_keys: NotRequired[set[str] | None] + model_extra_keys_without_x_prefix: NotRequired[set[str] | None] + wrap_string_literal: NotRequired[bool | None] + use_title_as_name: NotRequired[bool] + use_operation_id_as_name: NotRequired[bool] + use_unique_items_as_set: NotRequired[bool] + use_tuple_for_fixed_items: NotRequired[bool] + allof_merge_mode: NotRequired[AllOfMergeMode] + http_headers: NotRequired[Sequence[tuple[str, str]] | None] + http_ignore_tls: NotRequired[bool] + http_timeout: NotRequired[float | None] + use_annotated: NotRequired[bool] + use_serialize_as_any: NotRequired[bool] + use_non_positive_negative_number_constrained_types: NotRequired[bool] + use_decimal_for_multiple_of: NotRequired[bool] + original_field_name_delimiter: NotRequired[str | None] + use_double_quotes: NotRequired[bool] + use_union_operator: NotRequired[bool] + allow_responses_without_content: NotRequired[bool] + collapse_root_models: NotRequired[bool] + collapse_root_models_name_strategy: NotRequired[CollapseRootModelsNameStrategy | None] + collapse_reuse_models: NotRequired[bool] + skip_root_model: NotRequired[bool] + use_type_alias: NotRequired[bool] + special_field_name_prefix: NotRequired[str | None] + remove_special_field_name_prefix: NotRequired[bool] + capitalise_enum_members: NotRequired[bool] + keep_model_order: NotRequired[bool] + use_one_literal_as_default: NotRequired[bool] + use_enum_values_in_discriminator: NotRequired[bool] + known_third_party: NotRequired[list[str] | None] + custom_formatters: NotRequired[list[str] | None] + custom_formatters_kwargs: NotRequired[dict[str, Any] | None] + use_pendulum: NotRequired[bool] + use_standard_primitive_types: NotRequired[bool] + http_query_parameters: NotRequired[Sequence[tuple[str, str]] | None] + treat_dot_as_module: NotRequired[bool | None] + use_exact_imports: NotRequired[bool] + default_field_extras: NotRequired[dict[str, Any] | None] + target_datetime_class: NotRequired[DatetimeClassType | None] + target_date_class: NotRequired[DateClassType | None] + keyword_only: NotRequired[bool] + frozen_dataclasses: NotRequired[bool] + no_alias: NotRequired[bool] + use_frozen_field: NotRequired[bool] + use_default_factory_for_optional_nested_models: NotRequired[bool] + formatters: NotRequired[list[Formatter]] + defer_formatting: NotRequired[bool] + parent_scoped_naming: NotRequired[bool] + naming_strategy: NotRequired[NamingStrategy | None] + duplicate_name_suffix: NotRequired[dict[str, str] | None] + dataclass_arguments: NotRequired[DataclassArguments | None] + type_mappings: NotRequired[list[str] | None] + type_overrides: NotRequired[dict[str, str] | None] + read_only_write_only_model_type: NotRequired[ReadOnlyWriteOnlyModelType | None] + field_type_collision_strategy: NotRequired[FieldTypeCollisionStrategy | None] + target_pydantic_version: NotRequired[TargetPydanticVersion | None] diff --git a/src/datamodel_code_generator/config.py b/src/datamodel_code_generator/config.py new file mode 100644 index 000000000..b8bc22402 --- /dev/null +++ b/src/datamodel_code_generator/config.py @@ -0,0 +1,336 @@ +"""Configuration models for datamodel-code-generator.""" + +from __future__ import annotations + +from collections import defaultdict +from collections.abc import Callable, Iterable, Mapping, Sequence +from pathlib import Path # noqa: TC003 - used at runtime by Pydantic +from typing import TYPE_CHECKING, Annotated, Any + +from pydantic import BaseModel, Field, WithJsonSchema + +from datamodel_code_generator.enums import ( + DEFAULT_SHARED_MODULE_NAME, + AllExportsCollisionStrategy, + AllExportsScope, + AllOfMergeMode, + CollapseRootModelsNameStrategy, + DataclassArguments, + DataModelType, + FieldTypeCollisionStrategy, + GraphQLScope, + InputFileType, + ModuleSplitMode, + NamingStrategy, + OpenAPIScope, + ReadOnlyWriteOnlyModelType, + ReuseScope, + TargetPydanticVersion, +) +from datamodel_code_generator.format import ( + DEFAULT_FORMATTERS, + DateClassType, + DatetimeClassType, + Formatter, + PythonVersion, + PythonVersionMin, +) +from datamodel_code_generator.model import pydantic as pydantic_model +from datamodel_code_generator.parser import DefaultPutDict, LiteralType +from datamodel_code_generator.util import ConfigDict, is_pydantic_v2 + +if TYPE_CHECKING: + from datamodel_code_generator.model.base import DataModel, DataModelFieldBase + from datamodel_code_generator.model.pydantic_v2 import UnionMode + from datamodel_code_generator.types import DataTypeManager, StrictTypes + + +CallableSchema = Callable[[str], str] +DumpResolveReferenceAction = Callable[[Iterable[str]], str] +DefaultPutDictSchema = DefaultPutDict[str, str] +ExtraTemplateDataType = Annotated[ + defaultdict[str, Annotated[dict[str, Any], Field(default_factory=dict)]], + WithJsonSchema({"type": "object", "x-python-type": "defaultdict[str, dict[str, Any]]"}), +] + + +class GenerateConfig(BaseModel): + """Configuration model for generate().""" + + if is_pydantic_v2(): + model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) + else: # pragma: no cover + + class Config: + """Pydantic v1 model config.""" + + extra = "forbid" + arbitrary_types_allowed = True + + input_filename: str | None = None + input_file_type: InputFileType = InputFileType.Auto + output: Path | None = None + output_model_type: DataModelType = DataModelType.PydanticBaseModel + target_python_version: PythonVersion = PythonVersionMin + target_pydantic_version: TargetPydanticVersion | None = None + base_class: str = "" + base_class_map: dict[str, str] | None = None + additional_imports: list[str] | None = None + class_decorators: list[str] | None = None + custom_template_dir: Path | None = None + extra_template_data: ExtraTemplateDataType | None = None + validation: bool = False + field_constraints: bool = False + snake_case_field: bool = False + strip_default_none: bool = False + aliases: Mapping[str, str] | None = None + disable_timestamp: bool = False + enable_version_header: bool = False + enable_command_header: bool = False + command_line: str | None = None + allow_population_by_field_name: bool = False + allow_extra_fields: bool = False + extra_fields: str | None = None + use_generic_base_class: bool = False + apply_default_values_for_required_fields: bool = False + force_optional_for_required_fields: bool = False + class_name: str | None = None + use_standard_collections: bool = True + use_schema_description: bool = False + use_field_description: bool = False + use_field_description_example: bool = False + use_attribute_docstrings: bool = False + use_inline_field_description: bool = False + use_default_kwarg: bool = False + reuse_model: bool = False + reuse_scope: ReuseScope = ReuseScope.Module + shared_module_name: str = DEFAULT_SHARED_MODULE_NAME + encoding: str = "utf-8" + enum_field_as_literal: LiteralType | None = None + enum_field_as_literal_map: dict[str, str] | None = None + ignore_enum_constraints: bool = False + use_one_literal_as_default: bool = False + use_enum_values_in_discriminator: bool = False + set_default_enum_member: bool = False + use_subclass_enum: bool = False + use_specialized_enum: bool = True + strict_nullable: bool = False + use_generic_container_types: bool = False + enable_faux_immutability: bool = False + disable_appending_item_suffix: bool = False + strict_types: Sequence[StrictTypes] | None = None + empty_enum_field_name: str | None = None + custom_class_name_generator: CallableSchema | None = None + field_extra_keys: set[str] | None = None + field_include_all_keys: bool = False + field_extra_keys_without_x_prefix: set[str] | None = None + model_extra_keys: set[str] | None = None + model_extra_keys_without_x_prefix: set[str] | None = None + openapi_scopes: list[OpenAPIScope] | None = None + include_path_parameters: bool = False + graphql_scopes: list[GraphQLScope] | None = None + wrap_string_literal: bool | None = None + use_title_as_name: bool = False + use_operation_id_as_name: bool = False + use_unique_items_as_set: bool = False + use_tuple_for_fixed_items: bool = False + allof_merge_mode: AllOfMergeMode = AllOfMergeMode.Constraints + http_headers: Sequence[tuple[str, str]] | None = None + http_ignore_tls: bool = False + http_timeout: float | None = None + use_annotated: bool = False + use_serialize_as_any: bool = False + use_non_positive_negative_number_constrained_types: bool = False + use_decimal_for_multiple_of: bool = False + original_field_name_delimiter: str | None = None + use_double_quotes: bool = False + use_union_operator: bool = True + collapse_root_models: bool = False + collapse_root_models_name_strategy: CollapseRootModelsNameStrategy | None = None + collapse_reuse_models: bool = False + skip_root_model: bool = False + use_type_alias: bool = False + use_root_model_type_alias: bool = False + special_field_name_prefix: str | None = None + remove_special_field_name_prefix: bool = False + capitalise_enum_members: bool = False + keep_model_order: bool = False + custom_file_header: str | None = None + custom_file_header_path: Path | None = None + custom_formatters: list[str] | None = None + custom_formatters_kwargs: dict[str, Any] | None = None + use_pendulum: bool = False + use_standard_primitive_types: bool = False + http_query_parameters: Sequence[tuple[str, str]] | None = None + treat_dot_as_module: bool | None = None + use_exact_imports: bool = False + union_mode: UnionMode | None = None + output_datetime_class: DatetimeClassType | None = None + output_date_class: DateClassType | None = None + keyword_only: bool = False + frozen_dataclasses: bool = False + no_alias: bool = False + use_frozen_field: bool = False + use_default_factory_for_optional_nested_models: bool = False + formatters: list[Formatter] = DEFAULT_FORMATTERS + settings_path: Path | None = None + parent_scoped_naming: bool = False + naming_strategy: NamingStrategy | None = None + duplicate_name_suffix: dict[str, str] | None = None + dataclass_arguments: DataclassArguments | None = None + disable_future_imports: bool = False + type_mappings: list[str] | None = None + type_overrides: dict[str, str] | None = None + read_only_write_only_model_type: ReadOnlyWriteOnlyModelType | None = None + use_status_code_in_response_name: bool = False + all_exports_scope: AllExportsScope | None = None + all_exports_collision_strategy: AllExportsCollisionStrategy | None = None + field_type_collision_strategy: FieldTypeCollisionStrategy | None = None + module_split_mode: ModuleSplitMode | None = None + + +class ParserConfig(BaseModel): + """Configuration model for Parser.__init__().""" + + if is_pydantic_v2(): + model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) + else: # pragma: no cover + + class Config: + """Pydantic v1 model config.""" + + extra = "forbid" + arbitrary_types_allowed = True + + data_model_type: type[DataModel] = pydantic_model.BaseModel + data_model_root_type: type[DataModel] = pydantic_model.CustomRootType + data_type_manager_type: type[DataTypeManager] = pydantic_model.DataTypeManager + data_model_field_type: type[DataModelFieldBase] = pydantic_model.DataModelField + base_class: str | None = None + base_class_map: dict[str, str] | None = None + additional_imports: list[str] | None = None + class_decorators: list[str] | None = None + custom_template_dir: Path | None = None + extra_template_data: ExtraTemplateDataType | None = None + target_python_version: PythonVersion = PythonVersionMin + dump_resolve_reference_action: DumpResolveReferenceAction | None = None + validation: bool = False + field_constraints: bool = False + snake_case_field: bool = False + strip_default_none: bool = False + aliases: Mapping[str, str] | None = None + allow_population_by_field_name: bool = False + apply_default_values_for_required_fields: bool = False + allow_extra_fields: bool = False + extra_fields: str | None = None + use_generic_base_class: bool = False + force_optional_for_required_fields: bool = False + class_name: str | None = None + use_standard_collections: bool = False + base_path: Path | None = None + use_schema_description: bool = False + use_field_description: bool = False + use_field_description_example: bool = False + use_attribute_docstrings: bool = False + use_inline_field_description: bool = False + use_default_kwarg: bool = False + reuse_model: bool = False + reuse_scope: ReuseScope | None = None + shared_module_name: str = DEFAULT_SHARED_MODULE_NAME + encoding: str = "utf-8" + enum_field_as_literal: LiteralType | None = None + enum_field_as_literal_map: dict[str, str] | None = None + ignore_enum_constraints: bool = False + set_default_enum_member: bool = False + use_subclass_enum: bool = False + use_specialized_enum: bool = True + strict_nullable: bool = False + use_generic_container_types: bool = False + enable_faux_immutability: bool = False + remote_text_cache: DefaultPutDictSchema | None = None + disable_appending_item_suffix: bool = False + strict_types: Sequence[StrictTypes] | None = None + empty_enum_field_name: str | None = None + custom_class_name_generator: CallableSchema | None = None + field_extra_keys: set[str] | None = None + field_include_all_keys: bool = False + field_extra_keys_without_x_prefix: set[str] | None = None + model_extra_keys: set[str] | None = None + model_extra_keys_without_x_prefix: set[str] | None = None + wrap_string_literal: bool | None = None + use_title_as_name: bool = False + use_operation_id_as_name: bool = False + use_unique_items_as_set: bool = False + use_tuple_for_fixed_items: bool = False + allof_merge_mode: AllOfMergeMode = AllOfMergeMode.Constraints + http_headers: Sequence[tuple[str, str]] | None = None + http_ignore_tls: bool = False + http_timeout: float | None = None + use_annotated: bool = False + use_serialize_as_any: bool = False + use_non_positive_negative_number_constrained_types: bool = False + use_decimal_for_multiple_of: bool = False + original_field_name_delimiter: str | None = None + use_double_quotes: bool = False + use_union_operator: bool = False + allow_responses_without_content: bool = False + collapse_root_models: bool = False + collapse_root_models_name_strategy: CollapseRootModelsNameStrategy | None = None + collapse_reuse_models: bool = False + skip_root_model: bool = False + use_type_alias: bool = False + special_field_name_prefix: str | None = None + remove_special_field_name_prefix: bool = False + capitalise_enum_members: bool = False + keep_model_order: bool = False + use_one_literal_as_default: bool = False + use_enum_values_in_discriminator: bool = False + known_third_party: list[str] | None = None + custom_formatters: list[str] | None = None + custom_formatters_kwargs: dict[str, Any] | None = None + use_pendulum: bool = False + use_standard_primitive_types: bool = False + http_query_parameters: Sequence[tuple[str, str]] | None = None + treat_dot_as_module: bool | None = None + use_exact_imports: bool = False + default_field_extras: dict[str, Any] | None = None + target_datetime_class: DatetimeClassType | None = None + target_date_class: DateClassType | None = None + keyword_only: bool = False + frozen_dataclasses: bool = False + no_alias: bool = False + use_frozen_field: bool = False + use_default_factory_for_optional_nested_models: bool = False + formatters: list[Formatter] = DEFAULT_FORMATTERS + defer_formatting: bool = False + parent_scoped_naming: bool = False + naming_strategy: NamingStrategy | None = None + duplicate_name_suffix: dict[str, str] | None = None + dataclass_arguments: DataclassArguments | None = None + type_mappings: list[str] | None = None + type_overrides: dict[str, str] | None = None + read_only_write_only_model_type: ReadOnlyWriteOnlyModelType | None = None + field_type_collision_strategy: FieldTypeCollisionStrategy | None = None + target_pydantic_version: TargetPydanticVersion | None = None + + +class ParseConfig(BaseModel): + """Configuration model for Parser.parse().""" + + if is_pydantic_v2(): + model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) + else: # pragma: no cover + + class Config: + """Pydantic v1 model config.""" + + extra = "forbid" + arbitrary_types_allowed = True + + with_import: bool | None = True + format_: bool | None = True + settings_path: Path | None = None + disable_future_imports: bool = False + all_exports_scope: AllExportsScope | None = None + all_exports_collision_strategy: AllExportsCollisionStrategy | None = None + module_split_mode: ModuleSplitMode | None = None diff --git a/tests/main/test_main_general.py b/tests/main/test_main_general.py index e9181f162..a5f8fe195 100644 --- a/tests/main/test_main_general.py +++ b/tests/main/test_main_general.py @@ -6,6 +6,7 @@ from typing import TYPE_CHECKING import black +import pydantic import pytest from inline_snapshot import snapshot @@ -22,6 +23,7 @@ ) from datamodel_code_generator.__main__ import Config, Exit from datamodel_code_generator.arguments import _dataclass_arguments +from datamodel_code_generator.config import GenerateConfig from datamodel_code_generator.format import CodeFormatter, PythonVersion from datamodel_code_generator.parser.openapi import OpenAPIParser from tests.conftest import assert_output, create_assert_file_content, freeze_time @@ -1965,3 +1967,53 @@ def test_pydantic_v1_deprecation_warning(output_file: Path, mocker: MockerFixtur output_path=output_file, input_file_type="jsonschema", ) + + +@pytest.mark.skipif(pydantic.VERSION < "2.0.0", reason="GenerateConfig requires Pydantic v2") +def test_generate_with_config_object(output_file: Path) -> None: + """Test generate() with GenerateConfig object.""" + from datamodel_code_generator.model.pydantic_v2 import UnionMode + from datamodel_code_generator.types import StrictTypes + + GenerateConfig.model_rebuild(_types_namespace={"StrictTypes": StrictTypes, "UnionMode": UnionMode}) + config = GenerateConfig( + input_filename="test.json", + output_model_type=DataModelType.PydanticV2BaseModel, + use_schema_description=True, + snake_case_field=True, + field_constraints=True, + extra_template_data={"Model": {"custom_key": "custom_value"}}, + ) + generate( + input_='{"type": "object", "properties": {"userName": {"type": "string"}}}', + output=output_file, + config=config, + ) + content = output_file.read_text(encoding="utf-8") + assert "class Model" in content + assert "user_name" in content + + +@pytest.mark.skipif(pydantic.VERSION < "2.0.0", reason="GenerateConfig requires Pydantic v2") +def test_generate_with_config_object_extra_template_data_override(output_file: Path) -> None: + """Test generate() with extra_template_data passed directly, overriding config.""" + from collections import defaultdict + + from datamodel_code_generator.model.pydantic_v2 import UnionMode + from datamodel_code_generator.types import StrictTypes + + GenerateConfig.model_rebuild(_types_namespace={"StrictTypes": StrictTypes, "UnionMode": UnionMode}) + config = GenerateConfig( + input_filename="test.json", + output_model_type=DataModelType.PydanticV2BaseModel, + extra_template_data={"Model": {"config_key": "config_value"}}, + ) + # Pass extra_template_data directly - this should override config value + generate( + input_='{"type": "object", "properties": {"name": {"type": "string"}}}', + output=output_file, + config=config, + extra_template_data=defaultdict(dict, {"Model": {"direct_key": "direct_value"}}), + ) + content = output_file.read_text(encoding="utf-8") + assert "class Model" in content diff --git a/tests/main/test_public_api_signature_baseline.py b/tests/main/test_public_api_signature_baseline.py index 6bf21acb3..7042596bd 100644 --- a/tests/main/test_public_api_signature_baseline.py +++ b/tests/main/test_public_api_signature_baseline.py @@ -3,7 +3,11 @@ from __future__ import annotations import inspect -from typing import TYPE_CHECKING, Any +import types +from typing import TYPE_CHECKING, Annotated, Any, ForwardRef, Union, get_args, get_origin + +import pytest +from typing_extensions import NotRequired from datamodel_code_generator import DEFAULT_FORMATTERS, DEFAULT_SHARED_MODULE_NAME, generate from datamodel_code_generator.enums import ( @@ -27,6 +31,9 @@ from datamodel_code_generator.model import pydantic as pydantic_model from datamodel_code_generator.model.pydantic import BaseModel from datamodel_code_generator.parser.base import Parser, YamlValue, title_to_class_name +from datamodel_code_generator.util import is_pydantic_v2 + +PYDANTIC_V2_SKIP = pytest.mark.skipif(not is_pydantic_v2(), reason="Pydantic v2 required") if TYPE_CHECKING: from collections import defaultdict @@ -34,6 +41,7 @@ from pathlib import Path from urllib.parse import ParseResult + from datamodel_code_generator.config import GenerateConfig from datamodel_code_generator.format import DateClassType, DatetimeClassType, Formatter, PythonVersion from datamodel_code_generator.model.dataclass import DataclassArguments from datamodel_code_generator.model.pydantic import DataTypeManager @@ -45,6 +53,7 @@ def _baseline_generate( input_: Path | str | ParseResult | Mapping[str, Any], *, + config: GenerateConfig | None = None, input_filename: str | None = None, input_file_type: InputFileType = InputFileType.Auto, output: Path | None = None, @@ -287,6 +296,18 @@ def __init__( ) -> None: raise NotImplementedError + def parse( + self, + with_import: bool | None = True, + format_: bool | None = True, + settings_path: Path | None = None, + disable_future_imports: bool = False, + all_exports_scope: AllExportsScope | None = None, + all_exports_collision_strategy: AllExportsCollisionStrategy | None = None, + module_split_mode: ModuleSplitMode | None = None, + ) -> str | dict[tuple[str, ...], Any]: + raise NotImplementedError + def _kwonly_params(signature: inspect.Signature) -> list[inspect.Parameter]: return [param for param in signature.parameters.values() if param.kind is inspect.Parameter.KEYWORD_ONLY] @@ -296,6 +317,83 @@ def _kwonly_by_name(signature: inspect.Signature) -> dict[str, inspect.Parameter return {param.name: param for param in _kwonly_params(signature)} +def _params_by_name(signature: inspect.Signature) -> dict[str, inspect.Parameter]: + return { + name: param + for name, param in signature.parameters.items() + if name != "self" and param.kind in {inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY} + } + + +def _type_to_str(tp: Any) -> str: + """Convert type to normalized string.""" + if tp is type(None): + return "None" + if isinstance(tp, type): + return tp.__name__ + if isinstance(tp, str): + return tp + return ( + str(tp) + .replace("collections.abc.", "") + .replace("collections.", "") + .replace("typing.", "") + .replace("pathlib.", "") + ) + + +def _normalize_union_str(type_str: str) -> str: + """Normalize a union type string by sorting its components.""" + if " | " in type_str: + parts = [p.strip() for p in type_str.split(" | ")] + return " | ".join(sorted(parts)) + return type_str + + +def _normalize_type(tp: Any) -> str: # noqa: PLR0911 + """Normalize type for comparison between Config and TypedDict.""" + if tp is None or tp is type(None): + return "None" + + if isinstance(tp, str): + return _normalize_union_str(tp) + + if isinstance(tp, ForwardRef): + arg = tp.__forward_arg__ + if arg.startswith("NotRequired[") and arg.endswith("]"): + arg = arg[12:-1] + return _normalize_union_str(arg) + + if isinstance(tp, list): + return f"[{', '.join(_normalize_type(t) for t in tp)}]" + + origin = get_origin(tp) + args = get_args(tp) + + if origin in {Annotated, NotRequired}: + return _normalize_type(args[0]) if args else _type_to_str(tp) + + if origin is Union or isinstance(tp, types.UnionType): + if isinstance(tp, types.UnionType): + args = get_args(tp) + normalized_args = sorted(_normalize_type(a) for a in args) + return _type_to_str(" | ".join(normalized_args)) + + if origin is not None: + if args: + normalized_args = [_normalize_type(a) for a in args] + origin_name = getattr(origin, "__name__", str(origin)) + return _type_to_str(f"{origin_name}[{', '.join(normalized_args)}]") + return _type_to_str(origin) + + return _type_to_str(tp) + + +def _types_match(config_type: Any, dict_type: Any) -> bool: + """Check if Config type and TypedDict type are equivalent.""" + return _normalize_type(config_type) == _normalize_type(dict_type) + + def test_generate_signature_matches_baseline() -> None: """Ensure generate keeps the origin/main kw-only args and annotations.""" expected = inspect.signature(_baseline_generate) @@ -312,3 +410,196 @@ def test_parser_signature_matches_baseline() -> None: assert _kwonly_by_name(actual).keys() == _kwonly_by_name(expected).keys() for name, param in _kwonly_by_name(expected).items(): assert _kwonly_by_name(actual)[name].annotation == param.annotation + + +@PYDANTIC_V2_SKIP +def test_generate_config_dict_fields_match_generate_config() -> None: + """Ensure GenerateConfigDict has same field names as GenerateConfig.""" + from datamodel_code_generator._types import GenerateConfigDict + from datamodel_code_generator.config import GenerateConfig + + config_fields = set(GenerateConfig.model_fields.keys()) + dict_fields = set(GenerateConfigDict.__annotations__.keys()) + assert config_fields == dict_fields, f"Mismatch: {config_fields ^ dict_fields}" + + +@PYDANTIC_V2_SKIP +def test_generate_config_dict_types_match_generate_config() -> None: + """Ensure GenerateConfigDict field types match GenerateConfig.""" + from datamodel_code_generator._types import GenerateConfigDict + from datamodel_code_generator.config import GenerateConfig + + for field_name, field_info in GenerateConfig.model_fields.items(): + config_type = field_info.annotation + dict_type = GenerateConfigDict.__annotations__[field_name] + assert _types_match(config_type, dict_type), ( + f"Type mismatch for {field_name}: Config={_normalize_type(config_type)}, Dict={_normalize_type(dict_type)}" + ) + + +@PYDANTIC_V2_SKIP +def test_parser_config_dict_fields_match_parser_config() -> None: + """Ensure ParserConfigDict has same field names as ParserConfig.""" + from datamodel_code_generator._types import ParserConfigDict + from datamodel_code_generator.config import ParserConfig + + config_fields = set(ParserConfig.model_fields.keys()) + dict_fields = set(ParserConfigDict.__annotations__.keys()) + assert config_fields == dict_fields, f"Mismatch: {config_fields ^ dict_fields}" + + +@PYDANTIC_V2_SKIP +def test_parse_config_dict_fields_match_parse_config() -> None: + """Ensure ParseConfigDict has same field names as ParseConfig.""" + from datamodel_code_generator._types import ParseConfigDict + from datamodel_code_generator.config import ParseConfig + + config_fields = set(ParseConfig.model_fields.keys()) + dict_fields = set(ParseConfigDict.__annotations__.keys()) + assert config_fields == dict_fields, f"Mismatch: {config_fields ^ dict_fields}" + + +@PYDANTIC_V2_SKIP +def test_parser_config_dict_types_match_parser_config() -> None: + """Ensure ParserConfigDict field types match ParserConfig.""" + from datamodel_code_generator._types import ParserConfigDict + from datamodel_code_generator.config import ParserConfig + + for field_name, field_info in ParserConfig.model_fields.items(): + config_type = field_info.annotation + dict_type = ParserConfigDict.__annotations__[field_name] + assert _types_match(config_type, dict_type), ( + f"Type mismatch for {field_name}: Config={_normalize_type(config_type)}, Dict={_normalize_type(dict_type)}" + ) + + +@PYDANTIC_V2_SKIP +def test_parse_config_dict_types_match_parse_config() -> None: + """Ensure ParseConfigDict field types match ParseConfig.""" + from datamodel_code_generator._types import ParseConfigDict + from datamodel_code_generator.config import ParseConfig + + for field_name, field_info in ParseConfig.model_fields.items(): + config_type = field_info.annotation + dict_type = ParseConfigDict.__annotations__[field_name] + assert _types_match(config_type, dict_type), ( + f"Type mismatch for {field_name}: Config={_normalize_type(config_type)}, Dict={_normalize_type(dict_type)}" + ) + + +@PYDANTIC_V2_SKIP +def test_generate_config_defaults_match_generate_signature() -> None: + """Ensure GenerateConfig default values match generate() signature defaults.""" + from datamodel_code_generator.config import GenerateConfig + + expected_sig = inspect.signature(_baseline_generate) + expected_params = _kwonly_by_name(expected_sig) + + for field_name, field_info in GenerateConfig.model_fields.items(): + if field_name not in expected_params: + continue + + param = expected_params[field_name] + config_default = field_info.default + + # Handle Parameter.empty vs None + if param.default is inspect.Parameter.empty: + # No default in signature means required, but Config may have None default + continue + + assert config_default == param.default, ( + f"Default mismatch for {field_name}: Config={config_default}, generate()={param.default}" + ) + + +@PYDANTIC_V2_SKIP +def test_parser_config_defaults_match_parser_signature() -> None: + """Ensure ParserConfig default values match Parser.__init__ signature defaults.""" + from datamodel_code_generator.config import ParserConfig + + expected_sig = inspect.signature(_BaselineParser.__init__) + expected_params = _kwonly_by_name(expected_sig) + + for field_name, field_info in ParserConfig.model_fields.items(): + if field_name not in expected_params: + continue + + param = expected_params[field_name] + config_default = field_info.default + + if param.default is inspect.Parameter.empty: + continue + + if callable(param.default) and config_default is None: + continue + + assert config_default == param.default, ( + f"Default mismatch for {field_name}: Config={config_default}, Parser.__init__()={param.default}" + ) + + +@PYDANTIC_V2_SKIP +def test_parse_config_defaults_match_parse_signature() -> None: + """Ensure ParseConfig default values match Parser.parse() signature defaults.""" + from datamodel_code_generator.config import ParseConfig + + expected_sig = inspect.signature(_BaselineParser.parse) + expected_params = _params_by_name(expected_sig) + + for field_name, field_info in ParseConfig.model_fields.items(): + if field_name not in expected_params: + continue + + param = expected_params[field_name] + config_default = field_info.default + + if param.default is inspect.Parameter.empty: + continue + + assert config_default == param.default, ( + f"Default mismatch for {field_name}: Config={config_default}, Parser.parse()={param.default}" + ) + + +@PYDANTIC_V2_SKIP +def test_generate_with_config_produces_same_result_as_kwargs(tmp_path: Path) -> None: + """Ensure generate() with GenerateConfig produces same result as kwargs.""" + from datamodel_code_generator.config import GenerateConfig + from datamodel_code_generator.enums import DataModelType + from datamodel_code_generator.types import StrictTypes + + if hasattr(GenerateConfig, "model_rebuild"): + types_namespace: dict[str, type | None] = {"StrictTypes": StrictTypes, "UnionMode": None} + try: + from datamodel_code_generator.model.pydantic_v2 import UnionMode + + types_namespace["UnionMode"] = UnionMode + except ImportError: + pass + GenerateConfig.model_rebuild(_types_namespace=types_namespace) + + schema = '{"type": "object", "properties": {"name": {"type": "string"}}}' + output_kwargs = tmp_path / "output_kwargs.py" + output_config = tmp_path / "output_config.py" + + # Generate with kwargs + generate( + input_=schema, + output=output_kwargs, + output_model_type=DataModelType.PydanticV2BaseModel, + ) + + # Generate with config + config = GenerateConfig( + output_model_type=DataModelType.PydanticV2BaseModel, + ) + generate( + input_=schema, + output=output_config, + config=config, + ) + + # Compare results + kwargs_content = output_kwargs.read_text(encoding="utf-8") + config_content = output_config.read_text(encoding="utf-8") + assert kwargs_content == config_content, "Output differs between kwargs and config" diff --git a/tests/test_input_model.py b/tests/test_input_model.py index ada8ee317..451895602 100644 --- a/tests/test_input_model.py +++ b/tests/test_input_model.py @@ -1129,3 +1129,14 @@ def test_input_model_ref_strategy_reuse_foreign_msgspec_output(tmp_path: Path) - "class ModelWithPydantic", ], ) + + +@SKIP_PYDANTIC_V1 +def test_input_model_config_class(tmp_path: Path) -> None: + """Test that config classes like GenerateConfig are properly handled.""" + run_input_model_and_assert( + input_model="datamodel_code_generator.config:GenerateConfig", + output_path=tmp_path / "output.py", + extra_args=["--output-model-type", "typing.TypedDict"], + expected_output_contains=["TypedDict", "Callable[[str], str]"], + ) diff --git a/tox.ini b/tox.ini index be296e36a..6c14593f7 100644 --- a/tox.ini +++ b/tox.ini @@ -112,6 +112,15 @@ commands = check-wheel-contents --no-config {env_tmp_dir} dependency_groups = pkg-meta +[testenv:config-types] +description = Generate TypedDict files from config models (use --check to validate only) +commands = + datamodel-codegen --profile generate-config-dict {posargs} + datamodel-codegen --profile parser-config-dict {posargs} + datamodel-codegen --profile parse-config-dict {posargs} +dependency_groups = dev +no_default_groups = true + [testenv:type] description = run type check on code base commands =