Skip to content

Commit 36530b1

Browse files
authored
refactor: streamline parser configuration creation and enhance input handling (#2880)
1 parent efe8dfa commit 36530b1

1 file changed

Lines changed: 139 additions & 193 deletions

File tree

src/datamodel_code_generator/__init__.py

Lines changed: 139 additions & 193 deletions
Original file line number | Diff line number | Diff line change
@@ -61,14 +61,15 @@
6161
from datamodel_code_generator.parser import DefaultPutDict, LiteralType
6262

6363
if TYPE_CHECKING:
64+
from datamodel_code_generator._types import GraphQLParserConfigDict, OpenAPIParserConfigDict, ParserConfigDict
6465
from datamodel_code_generator._types.generate_config_dict import GenerateConfigDict
65-
from datamodel_code_generator.config import GenerateConfig
66-
from datamodel_code_generator.parser.base import Parser
66+
from datamodel_code_generator.config import GenerateConfig, ParserConfig
6767

6868
YamlScalar: TypeAlias = str | int | float | bool | None
6969
YamlValue = TypeAliasType("YamlValue", "dict[str, YamlValue] | list[YamlValue] | YamlScalar")
7070

7171
T = TypeVar("T")
72+
_ConfigT = TypeVar("_ConfigT", bound="ParserConfig")
7273

7374
# Import is_pydantic_v2 here for module-level YamlValue type definition
7475
from datamodel_code_generator.util import is_pydantic_v2 # noqa: E402
@@ -448,6 +449,27 @@ def _build_module_content(
448449
return "\n".join(lines)
449450

450451

452+
def _create_parser_config(
453+
config_class: type[_ConfigT],
454+
generate_config: GenerateConfig,
455+
additional_options: ParserConfigDict | OpenAPIParserConfigDict | GraphQLParserConfigDict,
456+
) -> _ConfigT:
457+
"""Create a parser config from GenerateConfig with additional options.
458+
459+
For Pydantic v2: Uses model_validate with extra='ignore' and model_copy.
460+
For Pydantic v1: Uses dict comprehension to filter fields.
461+
"""
462+
if is_pydantic_v2():
463+
return config_class.model_validate(generate_config, from_attributes=True, extra="ignore").model_copy(
464+
update=additional_options
465+
)
466+
parser_config_fields = set(config_class.__fields__.keys())
467+
all_options = {
468+
k: v for k, v in generate_config.dict().items() if k in parser_config_fields and k not in additional_options
469+
} | dict(additional_options)
470+
return config_class.parse_obj(all_options)
471+
472+
451473
def generate( # noqa: PLR0912, PLR0914, PLR0915
452474
input_: Path | str | ParseResult | Mapping[str, Any],
453475
*,
@@ -580,80 +602,62 @@ def generate( # noqa: PLR0912, PLR0914, PLR0915
580602
if isinstance(input_, Path) and input_.is_file() and input_file_type not in RAW_DATA_TYPES:
581603
input_text = input_text_
582604

583-
kwargs: dict[str, Any] = {}
584-
if input_file_type == InputFileType.OpenAPI: # noqa: PLR1702
585-
from datamodel_code_generator.parser.openapi import OpenAPIParser # noqa: PLC0415
605+
if input_file_type not in {InputFileType.OpenAPI, InputFileType.GraphQL} and input_file_type in RAW_DATA_TYPES:
606+
import json # noqa: PLC0415
586607

587-
parser_class: type[Parser] = OpenAPIParser
588-
kwargs["openapi_scopes"] = config.openapi_scopes
589-
kwargs["include_path_parameters"] = config.include_path_parameters
590-
kwargs["use_status_code_in_response_name"] = config.use_status_code_in_response_name
591-
elif input_file_type == InputFileType.GraphQL:
592-
from datamodel_code_generator.parser.graphql import GraphQLParser # noqa: PLC0415
593-
594-
parser_class: type[Parser] = GraphQLParser
595-
else:
596-
from datamodel_code_generator.parser.jsonschema import JsonSchemaParser # noqa: PLC0415
608+
try:
609+
if isinstance(input_, Path) and input_.is_dir(): # pragma: no cover
610+
msg = f"Input must be a file for {input_file_type}"
611+
raise Error(msg) # noqa: TRY301
612+
obj: dict[str, Any]
613+
if input_file_type == InputFileType.CSV:
614+
import csv # noqa: PLC0415
615+
616+
def get_header_and_first_line(csv_file: IO[str]) -> dict[str, Any]:
617+
csv_reader = csv.DictReader(csv_file)
618+
assert csv_reader.fieldnames is not None
619+
return dict(zip(csv_reader.fieldnames, next(csv_reader), strict=False))
620+
621+
if isinstance(input_, Path):
622+
with input_.open(encoding=config.encoding) as f:
623+
obj = get_header_and_first_line(f)
624+
else:
625+
import io # noqa: PLC0415
597626

598-
parser_class = JsonSchemaParser
599-
600-
if input_file_type in RAW_DATA_TYPES:
601-
import json # noqa: PLC0415
602-
603-
try:
604-
if isinstance(input_, Path) and input_.is_dir(): # pragma: no cover
605-
msg = f"Input must be a file for {input_file_type}"
606-
raise Error(msg) # noqa: TRY301
607-
obj: dict[str, Any]
608-
if input_file_type == InputFileType.CSV:
609-
import csv # noqa: PLC0415
610-
611-
def get_header_and_first_line(csv_file: IO[str]) -> dict[str, Any]:
612-
csv_reader = csv.DictReader(csv_file)
613-
assert csv_reader.fieldnames is not None
614-
return dict(zip(csv_reader.fieldnames, next(csv_reader), strict=False))
615-
616-
if isinstance(input_, Path):
617-
with input_.open(encoding=config.encoding) as f:
618-
obj = get_header_and_first_line(f)
619-
else:
620-
import io # noqa: PLC0415
621-
622-
obj = get_header_and_first_line(io.StringIO(input_text))
623-
elif input_file_type == InputFileType.Yaml:
624-
if isinstance(input_, Path):
625-
obj = load_yaml_dict(input_.read_text(encoding=config.encoding))
626-
else: # pragma: no cover
627-
assert input_text is not None
628-
obj = load_yaml_dict(input_text)
629-
elif input_file_type == InputFileType.Json:
630-
if isinstance(input_, Path):
631-
obj = json.loads(input_.read_text(encoding=config.encoding))
632-
else:
633-
assert input_text is not None
634-
obj = json.loads(input_text)
635-
elif input_file_type == InputFileType.Dict:
636-
import ast # noqa: PLC0415
637-
638-
# Input can be a dict object stored in a python file
639-
obj = (
640-
ast.literal_eval(input_.read_text(encoding=config.encoding))
641-
if isinstance(input_, Path)
642-
else cast("dict[str, Any]", input_)
643-
)
627+
obj = get_header_and_first_line(io.StringIO(input_text))
628+
elif input_file_type == InputFileType.Yaml:
629+
if isinstance(input_, Path):
630+
obj = load_yaml_dict(input_.read_text(encoding=config.encoding))
644631
else: # pragma: no cover
645-
msg = f"Unsupported input file type: {input_file_type}"
646-
raise Error(msg) # noqa: TRY301
647-
except Error:
648-
raise
649-
except Exception as exc:
650-
raise InvalidFileFormatError(exc, input_file_type) from exc
632+
assert input_text is not None
633+
obj = load_yaml_dict(input_text)
634+
elif input_file_type == InputFileType.Json:
635+
if isinstance(input_, Path):
636+
obj = json.loads(input_.read_text(encoding=config.encoding))
637+
else:
638+
assert input_text is not None
639+
obj = json.loads(input_text)
640+
elif input_file_type == InputFileType.Dict:
641+
import ast # noqa: PLC0415
642+
643+
obj = (
644+
ast.literal_eval(input_.read_text(encoding=config.encoding))
645+
if isinstance(input_, Path)
646+
else cast("dict[str, Any]", input_)
647+
)
648+
else: # pragma: no cover
649+
msg = f"Unsupported input file type: {input_file_type}"
650+
raise Error(msg) # noqa: TRY301
651+
except Error:
652+
raise
653+
except Exception as exc:
654+
raise InvalidFileFormatError(exc, input_file_type) from exc
651655

652-
from genson import SchemaBuilder # noqa: PLC0415
656+
from genson import SchemaBuilder # noqa: PLC0415
653657

654-
builder = SchemaBuilder()
655-
builder.add_object(obj)
656-
input_text = json.dumps(builder.to_schema())
658+
builder = SchemaBuilder()
659+
builder.add_object(obj)
660+
input_text = json.dumps(builder.to_schema())
657661

658662
if isinstance(input_, ParseResult) and input_file_type not in RAW_DATA_TYPES:
659663
input_text = None
@@ -676,11 +680,6 @@ def get_header_and_first_line(csv_file: IO[str]) -> dict[str, Any]:
676680
use_root_model_type_alias=config.use_root_model_type_alias,
677681
)
678682

679-
# Add GraphQL-specific model types if needed
680-
if input_file_type == InputFileType.GraphQL:
681-
kwargs["data_model_scalar_type"] = data_model_types.scalar_model
682-
kwargs["data_model_union_type"] = data_model_types.union_model
683-
684683
if isinstance(input_, Mapping) and input_file_type not in RAW_DATA_TYPES:
685684
source = dict(input_)
686685
else:
@@ -689,125 +688,72 @@ def get_header_and_first_line(csv_file: IO[str]) -> dict[str, Any]:
689688

690689
defer_formatting = config.output is not None and not config.output.suffix
691690

692-
parser = parser_class(
693-
source=source,
694-
data_model_type=data_model_types.data_model,
695-
data_model_root_type=data_model_types.root_model,
696-
data_model_field_type=data_model_types.field_model,
697-
data_type_manager_type=data_model_types.data_type_manager,
698-
base_class=config.base_class,
699-
base_class_map=config.base_class_map,
700-
additional_imports=config.additional_imports,
701-
class_decorators=config.class_decorators,
702-
custom_template_dir=config.custom_template_dir,
703-
extra_template_data=extra_template_data,
704-
target_python_version=config.target_python_version,
705-
dump_resolve_reference_action=data_model_types.dump_resolve_reference_action,
706-
validation=config.validation,
707-
field_constraints=config.field_constraints,
708-
snake_case_field=config.snake_case_field,
709-
strip_default_none=config.strip_default_none,
710-
aliases=config.aliases,
711-
allow_population_by_field_name=config.allow_population_by_field_name,
712-
allow_extra_fields=config.allow_extra_fields,
713-
extra_fields=config.extra_fields,
714-
use_generic_base_class=config.use_generic_base_class,
715-
apply_default_values_for_required_fields=config.apply_default_values_for_required_fields,
716-
force_optional_for_required_fields=config.force_optional_for_required_fields,
717-
class_name=config.class_name,
718-
use_standard_collections=config.use_standard_collections,
719-
base_path=input_.parent if isinstance(input_, Path) and input_.is_file() else None,
720-
use_schema_description=config.use_schema_description,
721-
use_field_description=config.use_field_description,
722-
use_field_description_example=config.use_field_description_example,
723-
use_attribute_docstrings=config.use_attribute_docstrings,
724-
use_inline_field_description=config.use_inline_field_description,
725-
use_default_kwarg=config.use_default_kwarg,
726-
reuse_model=config.reuse_model,
727-
reuse_scope=config.reuse_scope,
728-
shared_module_name=config.shared_module_name,
729-
enum_field_as_literal=config.enum_field_as_literal
730-
if config.enum_field_as_literal is not None
731-
else (LiteralType.All if config.output_model_type == DataModelType.TypingTypedDict else None),
732-
enum_field_as_literal_map=config.enum_field_as_literal_map,
733-
ignore_enum_constraints=config.ignore_enum_constraints,
734-
use_one_literal_as_default=config.use_one_literal_as_default,
735-
use_enum_values_in_discriminator=config.use_enum_values_in_discriminator,
736-
set_default_enum_member=True
737-
if config.output_model_type == DataModelType.DataclassesDataclass
738-
else config.set_default_enum_member,
739-
use_subclass_enum=config.use_subclass_enum,
740-
use_specialized_enum=config.use_specialized_enum,
741-
strict_nullable=config.strict_nullable,
742-
use_generic_container_types=config.use_generic_container_types,
743-
enable_faux_immutability=config.enable_faux_immutability,
744-
remote_text_cache=remote_text_cache,
745-
disable_appending_item_suffix=config.disable_appending_item_suffix,
746-
strict_types=config.strict_types,
747-
empty_enum_field_name=config.empty_enum_field_name,
748-
custom_class_name_generator=config.custom_class_name_generator,
749-
field_extra_keys=config.field_extra_keys,
750-
field_include_all_keys=config.field_include_all_keys,
751-
field_extra_keys_without_x_prefix=config.field_extra_keys_without_x_prefix,
752-
model_extra_keys=config.model_extra_keys,
753-
model_extra_keys_without_x_prefix=config.model_extra_keys_without_x_prefix,
754-
wrap_string_literal=config.wrap_string_literal,
755-
use_title_as_name=config.use_title_as_name,
756-
use_operation_id_as_name=config.use_operation_id_as_name,
757-
use_unique_items_as_set=config.use_unique_items_as_set,
758-
use_tuple_for_fixed_items=config.use_tuple_for_fixed_items,
759-
allof_merge_mode=config.allof_merge_mode,
760-
allof_class_hierarchy=config.allof_class_hierarchy,
761-
http_headers=config.http_headers,
762-
http_ignore_tls=config.http_ignore_tls,
763-
http_timeout=config.http_timeout,
764-
use_annotated=config.use_annotated,
765-
use_serialize_as_any=config.use_serialize_as_any,
766-
use_non_positive_negative_number_constrained_types=config.use_non_positive_negative_number_constrained_types,
767-
use_decimal_for_multiple_of=config.use_decimal_for_multiple_of,
768-
original_field_name_delimiter=config.original_field_name_delimiter,
769-
use_double_quotes=config.use_double_quotes,
770-
use_union_operator=config.use_union_operator,
771-
collapse_root_models=config.collapse_root_models,
772-
collapse_root_models_name_strategy=config.collapse_root_models_name_strategy,
773-
collapse_reuse_models=config.collapse_reuse_models,
774-
skip_root_model=config.skip_root_model,
775-
use_type_alias=config.use_type_alias,
776-
special_field_name_prefix=config.special_field_name_prefix,
777-
remove_special_field_name_prefix=config.remove_special_field_name_prefix,
778-
capitalise_enum_members=config.capitalise_enum_members,
779-
keep_model_order=config.keep_model_order,
780-
known_third_party=data_model_types.known_third_party,
781-
custom_formatters=config.custom_formatters,
782-
custom_formatters_kwargs=config.custom_formatters_kwargs,
783-
use_pendulum=config.use_pendulum,
784-
use_standard_primitive_types=config.use_standard_primitive_types,
785-
http_query_parameters=config.http_query_parameters,
786-
treat_dot_as_module=config.treat_dot_as_module,
787-
use_exact_imports=config.use_exact_imports,
788-
default_field_extras=default_field_extras,
789-
target_datetime_class=config.output_datetime_class,
790-
target_date_class=config.output_date_class,
791-
keyword_only=config.keyword_only,
792-
frozen_dataclasses=config.frozen_dataclasses,
793-
no_alias=config.no_alias,
794-
use_frozen_field=config.use_frozen_field,
795-
use_default_factory_for_optional_nested_models=config.use_default_factory_for_optional_nested_models,
796-
formatters=config.formatters,
797-
defer_formatting=defer_formatting,
798-
encoding=config.encoding,
799-
parent_scoped_naming=config.parent_scoped_naming,
800-
naming_strategy=config.naming_strategy,
801-
duplicate_name_suffix=config.duplicate_name_suffix,
802-
dataclass_arguments=dataclass_arguments,
803-
type_mappings=config.type_mappings,
804-
type_overrides=config.type_overrides,
805-
read_only_write_only_model_type=config.read_only_write_only_model_type,
806-
field_type_collision_strategy=config.field_type_collision_strategy,
807-
target_pydantic_version=config.target_pydantic_version,
808-
**kwargs,
691+
from datamodel_code_generator.config import ( # noqa: PLC0415
692+
GraphQLParserConfig,
693+
JSONSchemaParserConfig,
694+
OpenAPIParserConfig,
809695
)
810696

697+
additional_options: ParserConfigDict = {
698+
"data_model_type": data_model_types.data_model,
699+
"data_model_root_type": data_model_types.root_model,
700+
"data_model_field_type": data_model_types.field_model,
701+
"data_type_manager_type": data_model_types.data_type_manager,
702+
"dump_resolve_reference_action": data_model_types.dump_resolve_reference_action,
703+
"extra_template_data": extra_template_data,
704+
"base_path": input_.parent if isinstance(input_, Path) and input_.is_file() else None,
705+
"remote_text_cache": remote_text_cache,
706+
"known_third_party": data_model_types.known_third_party,
707+
"default_field_extras": default_field_extras,
708+
"target_datetime_class": (
709+
config.output_datetime_class
710+
if config.output_datetime_class is not None
711+
else (
712+
DatetimeClassType.Datetime
713+
if input_file_type == InputFileType.GraphQL
714+
else DatetimeClassType.Awaredatetime
715+
)
716+
),
717+
"target_date_class": config.output_date_class,
718+
"dataclass_arguments": dataclass_arguments,
719+
"defer_formatting": defer_formatting,
720+
"enum_field_as_literal": (
721+
config.enum_field_as_literal
722+
if config.enum_field_as_literal is not None
723+
else (LiteralType.All if config.output_model_type == DataModelType.TypingTypedDict else None)
724+
),
725+
"set_default_enum_member": (
726+
True if config.output_model_type == DataModelType.DataclassesDataclass else config.set_default_enum_member
727+
),
728+
}
729+
730+
if input_file_type == InputFileType.OpenAPI:
731+
from datamodel_code_generator.parser.openapi import OpenAPIParser # noqa: PLC0415
732+
733+
openapi_additional_options: OpenAPIParserConfigDict = {
734+
"openapi_scopes": config.openapi_scopes,
735+
"include_path_parameters": config.include_path_parameters,
736+
"use_status_code_in_response_name": config.use_status_code_in_response_name,
737+
**additional_options,
738+
}
739+
parser_config = _create_parser_config(OpenAPIParserConfig, config, openapi_additional_options)
740+
parser = OpenAPIParser(source=source, config=parser_config)
741+
elif input_file_type == InputFileType.GraphQL:
742+
from datamodel_code_generator.parser.graphql import GraphQLParser # noqa: PLC0415
743+
744+
graphql_additional_options: GraphQLParserConfigDict = {
745+
"data_model_scalar_type": data_model_types.scalar_model,
746+
"data_model_union_type": data_model_types.union_model,
747+
**additional_options,
748+
}
749+
parser_config = _create_parser_config(GraphQLParserConfig, config, graphql_additional_options)
750+
parser = GraphQLParser(source=source, config=parser_config)
751+
else:
752+
from datamodel_code_generator.parser.jsonschema import JsonSchemaParser # noqa: PLC0415
753+
754+
parser_config = _create_parser_config(JSONSchemaParserConfig, config, additional_options)
755+
parser = JsonSchemaParser(source=source, config=parser_config)
756+
811757
with chdir(config.output):
812758
results = parser.parse(
813759
settings_path=config.settings_path,

0 commit comments

Comments (0)