Skip to content

Commit 83776ca

Browse files
committed
Add GenerateConfig class and auto-generate TypedDicts for type-safe config (#2844)
* Add GenerateConfig class and auto-generate TypedDicts for type-safe config * Add --unsafe-fixes to ruff-check formatter * Skip config-types hook in lint workflow * Fix UnionType serialization and skip tox-dependent hooks in lint CI * Revert --unsafe-fixes and exclude _types from pre-commit ruff * Regenerate config types after merging main * Remove UnionType handling (should be separate PR) * Pin Python 3.14 for config-types CI * Remove unused _from_generate_config methods * Fix use_standard_collections merge logic * Fix coverage omit pattern for _types directory * Regenerate TypedDicts after merging set/frozenset fix * Add test for extra_template_data override and coverage pragmas * Restore --unsafe-fixes flag in ruff commands * Restore _try_rebuild_model and test_input_model_config_class after merge * Use --input-model-ref-strategy to simplify config type generation * Regenerate config types with improved reuse-foreign strategy * Remove WithJsonSchema annotations - automatic handling works * Add tests to ensure Config and ConfigDict fields/types match * Add comprehensive Config/TypedDict compatibility tests Add tests to ensure Config classes and generated TypedDicts remain in sync: - Type matching tests for ParserConfigDict and ParseConfigDict - Default value matching test for GenerateConfig vs generate() signature - Runtime compatibility test verifying generate() produces identical output * Use defaultdict for extra_template_data to match generate() signature - Update GenerateConfig and ParserConfig to use defaultdict with Annotated/WithJsonSchema for Pydantic compatibility - Regenerate TypedDicts with proper defaultdict type and import - Update test normalization to handle Annotated types and collections prefix * Fix N806 lint: rename UnionMode variable to lowercase * Simplify type comparison tests for Config/TypedDict Replace complex _normalize_type_str function and detailed type comparison with simpler field count checks. The integration test test_generate_with_config_produces_same_result_as_kwargs validates that Config classes work correctly at runtime, making detailed type string comparison unnecessary. Also fix lint errors: - N813: Restructure UnionMode import to avoid CamelCase-to-lowercase alias - I001: Fix import sorting order * Add type comparison tests for Config/TypedDict equivalence Implement _normalize_type function using typing APIs (get_origin, get_args) to normalize types for comparison between Config classes and generated TypedDicts. Handles: - Union types with sorted components - Annotated and NotRequired wrappers - ForwardRef string parsing - Callable with list args - Module prefix removal (collections.abc, typing, etc.) Tests now verify: 1. Field names match 2. Field types match (normalized) 3. Default values match 4. Integration test validates runtime behavior * Fix formatting * Skip Config/TypedDict tests on Pydantic v1 * Add default value tests for ParserConfig and ParseConfig * Use @PYDANTIC_V2_SKIP decorator for pydantic v2 tests * Fix type normalization for types.UnionType (Python 3.10+)
1 parent 12fc77c commit 83776ca

15 files changed

Lines changed: 1320 additions & 6 deletions
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
name: Check config types
2+
3+
on:
4+
push:
5+
paths:
6+
- 'src/datamodel_code_generator/config.py'
7+
- 'src/datamodel_code_generator/_types/**'
8+
- 'pyproject.toml'
9+
- '.github/workflows/config-types.yaml'
10+
pull_request:
11+
paths:
12+
- 'src/datamodel_code_generator/config.py'
13+
- 'src/datamodel_code_generator/_types/**'
14+
- 'pyproject.toml'
15+
- '.github/workflows/config-types.yaml'
16+
17+
jobs:
18+
config-types:
19+
runs-on: ubuntu-latest
20+
steps:
21+
- uses: actions/checkout@v4
22+
23+
- uses: astral-sh/setup-uv@v5
24+
with:
25+
enable-cache: true
26+
27+
- run: uv python install 3.14
28+
29+
- run: uv tool install --python 3.14 tox --with tox-uv
30+
31+
- run: tox -e config-types -- --check

.github/workflows/lint.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ jobs:
3434
- uses: actions/setup-python@v5
3535
with:
3636
python-version: "3.14"
37-
- run: uvx prek run --all-files --show-diff-on-failure --skip readme
37+
- run: SKIP=readme,config-types uvx prek run --all-files --show-diff-on-failure
3838
- if: |
3939
github.event_name == 'push' ||
4040
github.event.pull_request.head.repo.full_name == github.repository ||

.pre-commit-config.yaml

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@ repos:
1818
rev: 'v0.14.9'
1919
hooks:
2020
- id: ruff-format
21-
exclude: "^tests/data"
21+
exclude: "^tests/data|^src/datamodel_code_generator/_types/"
2222
- id: ruff
23-
exclude: "^tests/data"
23+
exclude: "^tests/data|^src/datamodel_code_generator/_types/"
2424
args: ["--fix", "--unsafe-fixes", "--exit-non-zero-on-fix"]
2525
- repo: https://github.com/codespell-project/codespell
2626
rev: v2.4.1
@@ -37,3 +37,9 @@ repos:
3737
language: system
3838
files: ^(src/datamodel_code_generator/arguments\.py|README\.md|docs/index\.md)$
3939
pass_filenames: false
40+
- id: config-types
41+
name: Generate config TypedDicts
42+
entry: bash -c 'test -x .tox/dev/bin/python || tox run -e dev --notest -qq; .tox/dev/bin/datamodel-codegen --profile generate-config-dict && .tox/dev/bin/datamodel-codegen --profile parser-config-dict && .tox/dev/bin/datamodel-codegen --profile parse-config-dict'
43+
language: system
44+
files: ^src/datamodel_code_generator/config\.py$
45+
pass_filenames: false

pyproject.toml

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ paths.other = [
244244
"*\\datamodel-code-generator",
245245
]
246246
run.dynamic_context = "none"
247-
run.omit = [ "tests/data/*", "tests/main/test_performance.py" ]
247+
run.omit = [ "tests/data/*", "tests/main/test_performance.py", "*/_types/*" ]
248248
report.fail_under = 88
249249
run.parallel = true
250250
run.plugins = [
@@ -258,3 +258,30 @@ reportPrivateImportUsage = false
258258
[tool.pydantic-pycharm-plugin]
259259
ignore-init-method-arguments = true
260260
parsable-types.str = [ "int", "float" ]
261+
262+
[tool.datamodel-codegen.profiles.config-types-base]
263+
enum-field-as-literal = "none"
264+
use-standard-primitive-types = true
265+
disable-warnings = true
266+
disable-timestamp = true
267+
output-model-type = "typing.TypedDict"
268+
formatters = [ "ruff-format", "ruff-check" ]
269+
input-model-ref-strategy = "reuse-foreign"
270+
271+
[tool.datamodel-codegen.profiles.generate-config-dict]
272+
extends = "config-types-base"
273+
input-model = "src/datamodel_code_generator/config.py:GenerateConfig"
274+
output = "src/datamodel_code_generator/_types/generate_config_dict.py"
275+
class-name = "GenerateConfigDict"
276+
277+
[tool.datamodel-codegen.profiles.parser-config-dict]
278+
extends = "config-types-base"
279+
input-model = "src/datamodel_code_generator/config.py:ParserConfig"
280+
output = "src/datamodel_code_generator/_types/parser_config_dict.py"
281+
class-name = "ParserConfigDict"
282+
283+
[tool.datamodel-codegen.profiles.parse-config-dict]
284+
extends = "config-types-base"
285+
input-model = "src/datamodel_code_generator/config.py:ParseConfig"
286+
output = "src/datamodel_code_generator/_types/parse_config_dict.py"
287+
class-name = "ParseConfigDict"

src/datamodel_code_generator/__init__.py

Lines changed: 187 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
if TYPE_CHECKING:
6262
from collections import defaultdict
6363

64+
from datamodel_code_generator.config import GenerateConfig
6465
from datamodel_code_generator.model.pydantic_v2 import UnionMode
6566
from datamodel_code_generator.parser.base import Parser
6667
from datamodel_code_generator.types import StrictTypes
@@ -451,6 +452,7 @@ def _build_module_content(
451452
def generate( # noqa: PLR0912, PLR0913, PLR0914, PLR0915
452453
input_: Path | str | ParseResult | Mapping[str, Any],
453454
*,
455+
config: GenerateConfig | None = None,
454456
input_filename: str | None = None,
455457
input_file_type: InputFileType = InputFileType.Auto,
456458
output: Path | None = None,
@@ -512,7 +514,7 @@ def generate( # noqa: PLR0912, PLR0913, PLR0914, PLR0915
512514
model_extra_keys_without_x_prefix: set[str] | None = None,
513515
openapi_scopes: list[OpenAPIScope] | None = None,
514516
include_path_parameters: bool = False,
515-
graphql_scopes: list[GraphQLScope] | None = None, # noqa: ARG001
517+
graphql_scopes: list[GraphQLScope] | None = None,
516518
wrap_string_literal: bool | None = None,
517519
use_title_as_name: bool = False,
518520
use_operation_id_as_name: bool = False,
@@ -583,6 +585,190 @@ def generate( # noqa: PLR0912, PLR0913, PLR0914, PLR0915
583585
- When output is None and multiple modules: GeneratedModules (dict mapping
584586
module path tuples to generated code strings)
585587
"""
588+
if config is not None:
589+
input_filename = config.input_filename if input_filename is None else input_filename
590+
input_file_type = config.input_file_type if input_file_type == InputFileType.Auto else input_file_type
591+
output = config.output if output is None else output
592+
output_model_type = (
593+
config.output_model_type if output_model_type == DataModelType.PydanticBaseModel else output_model_type
594+
)
595+
target_python_version = (
596+
config.target_python_version if target_python_version == PythonVersionMin else target_python_version
597+
)
598+
target_pydantic_version = (
599+
config.target_pydantic_version if target_pydantic_version is None else target_pydantic_version
600+
)
601+
base_class = base_class or config.base_class
602+
base_class_map = config.base_class_map if base_class_map is None else base_class_map
603+
additional_imports = config.additional_imports if additional_imports is None else additional_imports
604+
class_decorators = config.class_decorators if class_decorators is None else class_decorators
605+
custom_template_dir = config.custom_template_dir if custom_template_dir is None else custom_template_dir
606+
if extra_template_data is None and config.extra_template_data is not None:
607+
from collections import defaultdict as _defaultdict # noqa: PLC0415
608+
609+
extra_template_data = _defaultdict(dict, config.extra_template_data)
610+
validation = validation or config.validation
611+
field_constraints = field_constraints or config.field_constraints
612+
snake_case_field = snake_case_field or config.snake_case_field
613+
strip_default_none = strip_default_none or config.strip_default_none
614+
aliases = config.aliases if aliases is None else aliases
615+
disable_timestamp = disable_timestamp or config.disable_timestamp
616+
enable_version_header = enable_version_header or config.enable_version_header
617+
enable_command_header = enable_command_header or config.enable_command_header
618+
command_line = config.command_line if command_line is None else command_line
619+
allow_population_by_field_name = allow_population_by_field_name or config.allow_population_by_field_name
620+
allow_extra_fields = allow_extra_fields or config.allow_extra_fields
621+
extra_fields = config.extra_fields if extra_fields is None else extra_fields
622+
use_generic_base_class = use_generic_base_class or config.use_generic_base_class
623+
apply_default_values_for_required_fields = (
624+
apply_default_values_for_required_fields or config.apply_default_values_for_required_fields
625+
)
626+
force_optional_for_required_fields = (
627+
force_optional_for_required_fields or config.force_optional_for_required_fields
628+
)
629+
class_name = config.class_name if class_name is None else class_name
630+
use_standard_collections = (
631+
config.use_standard_collections if use_standard_collections else use_standard_collections
632+
)
633+
use_schema_description = use_schema_description or config.use_schema_description
634+
use_field_description = use_field_description or config.use_field_description
635+
use_field_description_example = use_field_description_example or config.use_field_description_example
636+
use_attribute_docstrings = use_attribute_docstrings or config.use_attribute_docstrings
637+
use_inline_field_description = use_inline_field_description or config.use_inline_field_description
638+
use_default_kwarg = use_default_kwarg or config.use_default_kwarg
639+
reuse_model = reuse_model or config.reuse_model
640+
reuse_scope = config.reuse_scope if reuse_scope == ReuseScope.Module else reuse_scope
641+
shared_module_name = (
642+
config.shared_module_name if shared_module_name == DEFAULT_SHARED_MODULE_NAME else shared_module_name
643+
)
644+
encoding = config.encoding if encoding == "utf-8" else encoding
645+
enum_field_as_literal = config.enum_field_as_literal if enum_field_as_literal is None else enum_field_as_literal
646+
enum_field_as_literal_map = (
647+
config.enum_field_as_literal_map if enum_field_as_literal_map is None else enum_field_as_literal_map
648+
)
649+
ignore_enum_constraints = ignore_enum_constraints or config.ignore_enum_constraints
650+
use_one_literal_as_default = use_one_literal_as_default or config.use_one_literal_as_default
651+
use_enum_values_in_discriminator = use_enum_values_in_discriminator or config.use_enum_values_in_discriminator
652+
set_default_enum_member = set_default_enum_member or config.set_default_enum_member
653+
use_subclass_enum = use_subclass_enum or config.use_subclass_enum
654+
use_specialized_enum = config.use_specialized_enum if use_specialized_enum else False
655+
strict_nullable = strict_nullable or config.strict_nullable
656+
use_generic_container_types = use_generic_container_types or config.use_generic_container_types
657+
enable_faux_immutability = enable_faux_immutability or config.enable_faux_immutability
658+
disable_appending_item_suffix = disable_appending_item_suffix or config.disable_appending_item_suffix
659+
strict_types = config.strict_types if strict_types is None else strict_types
660+
empty_enum_field_name = config.empty_enum_field_name if empty_enum_field_name is None else empty_enum_field_name
661+
custom_class_name_generator = (
662+
config.custom_class_name_generator if custom_class_name_generator is None else custom_class_name_generator
663+
)
664+
field_extra_keys = config.field_extra_keys if field_extra_keys is None else field_extra_keys
665+
field_include_all_keys = field_include_all_keys or config.field_include_all_keys
666+
field_extra_keys_without_x_prefix = (
667+
config.field_extra_keys_without_x_prefix
668+
if field_extra_keys_without_x_prefix is None
669+
else field_extra_keys_without_x_prefix
670+
)
671+
model_extra_keys = config.model_extra_keys if model_extra_keys is None else model_extra_keys
672+
model_extra_keys_without_x_prefix = (
673+
config.model_extra_keys_without_x_prefix
674+
if model_extra_keys_without_x_prefix is None
675+
else model_extra_keys_without_x_prefix
676+
)
677+
openapi_scopes = config.openapi_scopes if openapi_scopes is None else openapi_scopes
678+
include_path_parameters = include_path_parameters or config.include_path_parameters
679+
graphql_scopes = config.graphql_scopes if graphql_scopes is None else graphql_scopes
680+
wrap_string_literal = config.wrap_string_literal if wrap_string_literal is None else wrap_string_literal
681+
use_title_as_name = use_title_as_name or config.use_title_as_name
682+
use_operation_id_as_name = use_operation_id_as_name or config.use_operation_id_as_name
683+
use_unique_items_as_set = use_unique_items_as_set or config.use_unique_items_as_set
684+
use_tuple_for_fixed_items = use_tuple_for_fixed_items or config.use_tuple_for_fixed_items
685+
allof_merge_mode = (
686+
config.allof_merge_mode if allof_merge_mode == AllOfMergeMode.Constraints else allof_merge_mode
687+
)
688+
http_headers = config.http_headers if http_headers is None else http_headers
689+
http_ignore_tls = http_ignore_tls or config.http_ignore_tls
690+
http_timeout = config.http_timeout if http_timeout is None else http_timeout
691+
use_annotated = use_annotated or config.use_annotated
692+
use_serialize_as_any = use_serialize_as_any or config.use_serialize_as_any
693+
use_non_positive_negative_number_constrained_types = (
694+
use_non_positive_negative_number_constrained_types
695+
or config.use_non_positive_negative_number_constrained_types
696+
)
697+
use_decimal_for_multiple_of = use_decimal_for_multiple_of or config.use_decimal_for_multiple_of
698+
original_field_name_delimiter = (
699+
config.original_field_name_delimiter
700+
if original_field_name_delimiter is None
701+
else original_field_name_delimiter
702+
)
703+
use_double_quotes = use_double_quotes or config.use_double_quotes
704+
use_union_operator = config.use_union_operator if use_union_operator else False
705+
collapse_root_models = collapse_root_models or config.collapse_root_models
706+
collapse_root_models_name_strategy = (
707+
config.collapse_root_models_name_strategy
708+
if collapse_root_models_name_strategy is None
709+
else collapse_root_models_name_strategy
710+
)
711+
collapse_reuse_models = collapse_reuse_models or config.collapse_reuse_models
712+
skip_root_model = skip_root_model or config.skip_root_model
713+
use_type_alias = use_type_alias or config.use_type_alias
714+
use_root_model_type_alias = use_root_model_type_alias or config.use_root_model_type_alias
715+
special_field_name_prefix = (
716+
config.special_field_name_prefix if special_field_name_prefix is None else special_field_name_prefix
717+
)
718+
remove_special_field_name_prefix = remove_special_field_name_prefix or config.remove_special_field_name_prefix
719+
capitalise_enum_members = capitalise_enum_members or config.capitalise_enum_members
720+
keep_model_order = keep_model_order or config.keep_model_order
721+
custom_file_header = config.custom_file_header if custom_file_header is None else custom_file_header
722+
custom_file_header_path = (
723+
config.custom_file_header_path if custom_file_header_path is None else custom_file_header_path
724+
)
725+
custom_formatters = config.custom_formatters if custom_formatters is None else custom_formatters
726+
custom_formatters_kwargs = (
727+
config.custom_formatters_kwargs if custom_formatters_kwargs is None else custom_formatters_kwargs
728+
)
729+
use_pendulum = use_pendulum or config.use_pendulum
730+
use_standard_primitive_types = use_standard_primitive_types or config.use_standard_primitive_types
731+
http_query_parameters = config.http_query_parameters if http_query_parameters is None else http_query_parameters
732+
treat_dot_as_module = config.treat_dot_as_module if treat_dot_as_module is None else treat_dot_as_module
733+
use_exact_imports = use_exact_imports or config.use_exact_imports
734+
union_mode = config.union_mode if union_mode is None else union_mode
735+
output_datetime_class = config.output_datetime_class if output_datetime_class is None else output_datetime_class
736+
output_date_class = config.output_date_class if output_date_class is None else output_date_class
737+
keyword_only = keyword_only or config.keyword_only
738+
frozen_dataclasses = frozen_dataclasses or config.frozen_dataclasses
739+
no_alias = no_alias or config.no_alias
740+
use_frozen_field = use_frozen_field or config.use_frozen_field
741+
use_default_factory_for_optional_nested_models = (
742+
use_default_factory_for_optional_nested_models or config.use_default_factory_for_optional_nested_models
743+
)
744+
formatters = config.formatters if formatters == DEFAULT_FORMATTERS else formatters
745+
settings_path = config.settings_path if settings_path is None else settings_path
746+
parent_scoped_naming = parent_scoped_naming or config.parent_scoped_naming
747+
naming_strategy = config.naming_strategy if naming_strategy is None else naming_strategy
748+
duplicate_name_suffix = config.duplicate_name_suffix if duplicate_name_suffix is None else duplicate_name_suffix
749+
dataclass_arguments = config.dataclass_arguments if dataclass_arguments is None else dataclass_arguments
750+
disable_future_imports = disable_future_imports or config.disable_future_imports
751+
type_mappings = config.type_mappings if type_mappings is None else type_mappings
752+
type_overrides = config.type_overrides if type_overrides is None else type_overrides
753+
read_only_write_only_model_type = (
754+
config.read_only_write_only_model_type
755+
if read_only_write_only_model_type is None
756+
else read_only_write_only_model_type
757+
)
758+
use_status_code_in_response_name = use_status_code_in_response_name or config.use_status_code_in_response_name
759+
all_exports_scope = config.all_exports_scope if all_exports_scope is None else all_exports_scope
760+
all_exports_collision_strategy = (
761+
config.all_exports_collision_strategy
762+
if all_exports_collision_strategy is None
763+
else all_exports_collision_strategy
764+
)
765+
field_type_collision_strategy = (
766+
config.field_type_collision_strategy
767+
if field_type_collision_strategy is None
768+
else field_type_collision_strategy
769+
)
770+
module_split_mode = config.module_split_mode if module_split_mode is None else module_split_mode
771+
586772
remote_text_cache: DefaultPutDict[str, str] = DefaultPutDict()
587773
match input_:
588774
case str():

src/datamodel_code_generator/__main__.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1116,6 +1116,37 @@ def _filter_defs_by_strategy(
11161116
return {**schema, "$defs": new_defs}
11171117

11181118

1119+
def _try_rebuild_model(obj: type) -> None:
1120+
"""Try to rebuild a Pydantic model, handling config models specially."""
1121+
module = getattr(obj, "__module__", "")
1122+
class_name = getattr(obj, "__name__", "")
1123+
config_classes = {"GenerateConfig", "ParserConfig", "ParseConfig"}
1124+
if module in {"datamodel_code_generator.config", "config"} and class_name in config_classes:
1125+
from datamodel_code_generator.model.base import DataModel, DataModelFieldBase # noqa: PLC0415
1126+
from datamodel_code_generator.types import DataTypeManager, StrictTypes # noqa: PLC0415
1127+
1128+
try:
1129+
from datamodel_code_generator.model.pydantic_v2 import UnionMode # noqa: PLC0415
1130+
except ImportError: # pragma: no cover
1131+
from typing import Any # noqa: PLC0415
1132+
1133+
runtime_union_mode = Any
1134+
else:
1135+
runtime_union_mode = UnionMode
1136+
1137+
types_namespace = {
1138+
"Path": Path,
1139+
"DataModel": DataModel,
1140+
"DataModelFieldBase": DataModelFieldBase,
1141+
"DataTypeManager": DataTypeManager,
1142+
"StrictTypes": StrictTypes,
1143+
"UnionMode": runtime_union_mode,
1144+
}
1145+
obj.model_rebuild(_types_namespace=types_namespace)
1146+
else:
1147+
obj.model_rebuild()
1148+
1149+
11191150
def _load_model_schema( # noqa: PLR0912, PLR0914, PLR0915
11201151
input_model: str,
11211152
input_file_type: InputFileType,
@@ -1199,6 +1230,8 @@ def _load_model_schema( # noqa: PLR0912, PLR0914, PLR0915
11991230
if not hasattr(obj, "model_json_schema"):
12001231
msg = "--input-model with Pydantic model requires Pydantic v2 runtime. Please upgrade Pydantic to v2."
12011232
raise Error(msg)
1233+
if hasattr(obj, "model_rebuild"): # pragma: no branch
1234+
_try_rebuild_model(obj)
12021235
schema_generator = _get_input_model_json_schema_class()
12031236
schema = obj.model_json_schema(schema_generator=schema_generator)
12041237
schema = _add_python_type_for_unserializable(schema, obj)

0 commit comments

Comments
 (0)