332 changes: 139 additions & 193 deletions src/datamodel_code_generator/__init__.py
@@ -61,14 +61,15 @@
from datamodel_code_generator.parser import DefaultPutDict, LiteralType

if TYPE_CHECKING:
from datamodel_code_generator._types import GraphQLParserConfigDict, OpenAPIParserConfigDict, ParserConfigDict
from datamodel_code_generator._types.generate_config_dict import GenerateConfigDict
from datamodel_code_generator.config import GenerateConfig
from datamodel_code_generator.parser.base import Parser
from datamodel_code_generator.config import GenerateConfig, ParserConfig

YamlScalar: TypeAlias = str | int | float | bool | None
YamlValue = TypeAliasType("YamlValue", "dict[str, YamlValue] | list[YamlValue] | YamlScalar")

T = TypeVar("T")
_ConfigT = TypeVar("_ConfigT", bound="ParserConfig")

# Import is_pydantic_v2 here for module-level YamlValue type definition
from datamodel_code_generator.util import is_pydantic_v2 # noqa: E402
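
As context for the alias above: TypeAliasType keeps the self-reference in YamlValue as a lazily evaluated string, so the alias can name itself before it exists. A small, hypothetical sketch (not part of this diff, assuming Python 3.10+ and Pydantic v2, which supports TypeAliasType aliases) of exercising such a recursive alias:

from typing import TypeAlias

from pydantic import TypeAdapter
from typing_extensions import TypeAliasType

YamlScalar: TypeAlias = str | int | float | bool | None
YamlValue = TypeAliasType("YamlValue", "dict[str, YamlValue] | list[YamlValue] | YamlScalar")

# The adapter resolves the forward reference and validates arbitrarily nested data.
TypeAdapter(YamlValue).validate_python({"a": [1, "x", {"b": None}]})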
@@ -448,6 +449,27 @@ def _build_module_content(
return "\n".join(lines)


def _create_parser_config(
config_class: type[_ConfigT],
generate_config: GenerateConfig,
additional_options: ParserConfigDict | OpenAPIParserConfigDict | GraphQLParserConfigDict,
) -> _ConfigT:
"""Create a parser config from GenerateConfig with additional options.

For Pydantic v2: Uses model_validate with extra='ignore' and model_copy.
For Pydantic v1: Uses dict comprehension to filter fields.
"""
if is_pydantic_v2():
return config_class.model_validate(generate_config, from_attributes=True, extra="ignore").model_copy(
update=additional_options
)
parser_config_fields = set(config_class.__fields__.keys())
all_options = {
k: v for k, v in generate_config.dict().items() if k in parser_config_fields and k not in additional_options
} | dict(additional_options)
return config_class.parse_obj(all_options)
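
To make the two branches above concrete, here is a minimal, self-contained sketch with hypothetical two-field stand-ins for GenerateConfig and a parser config (the real classes in datamodel_code_generator.config carry dozens of fields, and the v2 branch's extra="ignore" from the diff is omitted here for brevity):

from pydantic import BaseModel


class GenerateConfigStub(BaseModel):
    encoding: str = "utf-8"
    defer_formatting: bool = False
    output_model_type: str = "pydantic_v2.BaseModel"  # not a parser-config field


class ParserConfigStub(BaseModel):
    encoding: str = "utf-8"
    defer_formatting: bool = False


generate_config = GenerateConfigStub(encoding="latin-1")
additional_options = {"defer_formatting": True}

# Pydantic v2 branch: pull matching attributes off the GenerateConfig instance
# (fields the parser config does not declare, such as output_model_type, are
# simply never looked up), then overlay format-specific options via model_copy.
v2_config = ParserConfigStub.model_validate(
    generate_config, from_attributes=True
).model_copy(update=additional_options)
assert v2_config.encoding == "latin-1" and v2_config.defer_formatting is True

# Pydantic v1 branch: filter the dumped dict down to declared parser-config
# fields, drop anything additional_options will supply, then merge and parse.
parser_config_fields = set(ParserConfigStub.__fields__.keys())
all_options = {
    k: v
    for k, v in generate_config.dict().items()
    if k in parser_config_fields and k not in additional_options
} | dict(additional_options)
v1_config = ParserConfigStub.parse_obj(all_options)
assert v1_config == v2_config

Under Pydantic v2 the last block runs only through deprecated v1-compat shims (dict(), parse_obj, __fields__), which is exactly why the helper branches on is_pydantic_v2().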


def generate( # noqa: PLR0912, PLR0914, PLR0915
input_: Path | str | ParseResult | Mapping[str, Any],
*,
Expand Down Expand Up @@ -580,80 +602,62 @@ def generate( # noqa: PLR0912, PLR0914, PLR0915
if isinstance(input_, Path) and input_.is_file() and input_file_type not in RAW_DATA_TYPES:
input_text = input_text_

kwargs: dict[str, Any] = {}
if input_file_type == InputFileType.OpenAPI: # noqa: PLR1702
from datamodel_code_generator.parser.openapi import OpenAPIParser # noqa: PLC0415
if input_file_type not in {InputFileType.OpenAPI, InputFileType.GraphQL} and input_file_type in RAW_DATA_TYPES:
import json # noqa: PLC0415

parser_class: type[Parser] = OpenAPIParser
kwargs["openapi_scopes"] = config.openapi_scopes
kwargs["include_path_parameters"] = config.include_path_parameters
kwargs["use_status_code_in_response_name"] = config.use_status_code_in_response_name
elif input_file_type == InputFileType.GraphQL:
from datamodel_code_generator.parser.graphql import GraphQLParser # noqa: PLC0415

parser_class: type[Parser] = GraphQLParser
else:
from datamodel_code_generator.parser.jsonschema import JsonSchemaParser # noqa: PLC0415
try:
if isinstance(input_, Path) and input_.is_dir(): # pragma: no cover
msg = f"Input must be a file for {input_file_type}"
raise Error(msg) # noqa: TRY301
obj: dict[str, Any]
if input_file_type == InputFileType.CSV:
import csv # noqa: PLC0415

def get_header_and_first_line(csv_file: IO[str]) -> dict[str, Any]:
csv_reader = csv.DictReader(csv_file)
assert csv_reader.fieldnames is not None
return dict(zip(csv_reader.fieldnames, next(csv_reader), strict=False))

if isinstance(input_, Path):
with input_.open(encoding=config.encoding) as f:
obj = get_header_and_first_line(f)
else:
import io # noqa: PLC0415

parser_class = JsonSchemaParser

if input_file_type in RAW_DATA_TYPES:
import json # noqa: PLC0415

try:
if isinstance(input_, Path) and input_.is_dir(): # pragma: no cover
msg = f"Input must be a file for {input_file_type}"
raise Error(msg) # noqa: TRY301
obj: dict[str, Any]
if input_file_type == InputFileType.CSV:
import csv # noqa: PLC0415

def get_header_and_first_line(csv_file: IO[str]) -> dict[str, Any]:
csv_reader = csv.DictReader(csv_file)
assert csv_reader.fieldnames is not None
return dict(zip(csv_reader.fieldnames, next(csv_reader), strict=False))

if isinstance(input_, Path):
with input_.open(encoding=config.encoding) as f:
obj = get_header_and_first_line(f)
else:
import io # noqa: PLC0415

obj = get_header_and_first_line(io.StringIO(input_text))
elif input_file_type == InputFileType.Yaml:
if isinstance(input_, Path):
obj = load_yaml_dict(input_.read_text(encoding=config.encoding))
else: # pragma: no cover
assert input_text is not None
obj = load_yaml_dict(input_text)
elif input_file_type == InputFileType.Json:
if isinstance(input_, Path):
obj = json.loads(input_.read_text(encoding=config.encoding))
else:
assert input_text is not None
obj = json.loads(input_text)
elif input_file_type == InputFileType.Dict:
import ast # noqa: PLC0415

# Input can be a dict object stored in a python file
obj = (
ast.literal_eval(input_.read_text(encoding=config.encoding))
if isinstance(input_, Path)
else cast("dict[str, Any]", input_)
)
obj = get_header_and_first_line(io.StringIO(input_text))
elif input_file_type == InputFileType.Yaml:
if isinstance(input_, Path):
obj = load_yaml_dict(input_.read_text(encoding=config.encoding))
else: # pragma: no cover
msg = f"Unsupported input file type: {input_file_type}"
raise Error(msg) # noqa: TRY301
except Error:
raise
except Exception as exc:
raise InvalidFileFormatError(exc, input_file_type) from exc
assert input_text is not None
obj = load_yaml_dict(input_text)
elif input_file_type == InputFileType.Json:
if isinstance(input_, Path):
obj = json.loads(input_.read_text(encoding=config.encoding))
else:
assert input_text is not None
obj = json.loads(input_text)
elif input_file_type == InputFileType.Dict:
import ast # noqa: PLC0415

obj = (
ast.literal_eval(input_.read_text(encoding=config.encoding))
if isinstance(input_, Path)
else cast("dict[str, Any]", input_)
)
else: # pragma: no cover
msg = f"Unsupported input file type: {input_file_type}"
raise Error(msg) # noqa: TRY301
except Error:
raise
except Exception as exc:
raise InvalidFileFormatError(exc, input_file_type) from exc

from genson import SchemaBuilder # noqa: PLC0415
from genson import SchemaBuilder # noqa: PLC0415

builder = SchemaBuilder()
builder.add_object(obj)
input_text = json.dumps(builder.to_schema())
builder = SchemaBuilder()
builder.add_object(obj)
input_text = json.dumps(builder.to_schema())
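
For each of the raw-data branches above (CSV header row, parsed YAML/JSON, or a literal dict), the sample object is handed to genson, which infers a JSON Schema that the JSON Schema parser then consumes. A standalone illustration of that inference step, with a made-up sample record:

import json

from genson import SchemaBuilder

obj = {"id": 1, "name": "alice", "tags": ["a", "b"]}  # hypothetical sample record

builder = SchemaBuilder()
builder.add_object(obj)
input_text = json.dumps(builder.to_schema())
# Produces a schema along the lines of:
# {"$schema": "http://json-schema.org/schema#", "type": "object",
#  "properties": {"id": {"type": "integer"}, "name": {"type": "string"},
#                 "tags": {"type": "array", "items": {"type": "string"}}},
#  "required": ["id", "name", "tags"]}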

if isinstance(input_, ParseResult) and input_file_type not in RAW_DATA_TYPES:
input_text = None
@@ -676,11 +680,6 @@ def get_header_and_first_line(csv_file: IO[str]) -> dict[str, Any]:
use_root_model_type_alias=config.use_root_model_type_alias,
)

# Add GraphQL-specific model types if needed
if input_file_type == InputFileType.GraphQL:
kwargs["data_model_scalar_type"] = data_model_types.scalar_model
kwargs["data_model_union_type"] = data_model_types.union_model

if isinstance(input_, Mapping) and input_file_type not in RAW_DATA_TYPES:
source = dict(input_)
else:
@@ -689,125 +688,72 @@ def get_header_and_first_line(csv_file: IO[str]) -> dict[str, Any]:

defer_formatting = config.output is not None and not config.output.suffix

parser = parser_class(
source=source,
data_model_type=data_model_types.data_model,
data_model_root_type=data_model_types.root_model,
data_model_field_type=data_model_types.field_model,
data_type_manager_type=data_model_types.data_type_manager,
base_class=config.base_class,
base_class_map=config.base_class_map,
additional_imports=config.additional_imports,
class_decorators=config.class_decorators,
custom_template_dir=config.custom_template_dir,
extra_template_data=extra_template_data,
target_python_version=config.target_python_version,
dump_resolve_reference_action=data_model_types.dump_resolve_reference_action,
validation=config.validation,
field_constraints=config.field_constraints,
snake_case_field=config.snake_case_field,
strip_default_none=config.strip_default_none,
aliases=config.aliases,
allow_population_by_field_name=config.allow_population_by_field_name,
allow_extra_fields=config.allow_extra_fields,
extra_fields=config.extra_fields,
use_generic_base_class=config.use_generic_base_class,
apply_default_values_for_required_fields=config.apply_default_values_for_required_fields,
force_optional_for_required_fields=config.force_optional_for_required_fields,
class_name=config.class_name,
use_standard_collections=config.use_standard_collections,
base_path=input_.parent if isinstance(input_, Path) and input_.is_file() else None,
use_schema_description=config.use_schema_description,
use_field_description=config.use_field_description,
use_field_description_example=config.use_field_description_example,
use_attribute_docstrings=config.use_attribute_docstrings,
use_inline_field_description=config.use_inline_field_description,
use_default_kwarg=config.use_default_kwarg,
reuse_model=config.reuse_model,
reuse_scope=config.reuse_scope,
shared_module_name=config.shared_module_name,
enum_field_as_literal=config.enum_field_as_literal
if config.enum_field_as_literal is not None
else (LiteralType.All if config.output_model_type == DataModelType.TypingTypedDict else None),
enum_field_as_literal_map=config.enum_field_as_literal_map,
ignore_enum_constraints=config.ignore_enum_constraints,
use_one_literal_as_default=config.use_one_literal_as_default,
use_enum_values_in_discriminator=config.use_enum_values_in_discriminator,
set_default_enum_member=True
if config.output_model_type == DataModelType.DataclassesDataclass
else config.set_default_enum_member,
use_subclass_enum=config.use_subclass_enum,
use_specialized_enum=config.use_specialized_enum,
strict_nullable=config.strict_nullable,
use_generic_container_types=config.use_generic_container_types,
enable_faux_immutability=config.enable_faux_immutability,
remote_text_cache=remote_text_cache,
disable_appending_item_suffix=config.disable_appending_item_suffix,
strict_types=config.strict_types,
empty_enum_field_name=config.empty_enum_field_name,
custom_class_name_generator=config.custom_class_name_generator,
field_extra_keys=config.field_extra_keys,
field_include_all_keys=config.field_include_all_keys,
field_extra_keys_without_x_prefix=config.field_extra_keys_without_x_prefix,
model_extra_keys=config.model_extra_keys,
model_extra_keys_without_x_prefix=config.model_extra_keys_without_x_prefix,
wrap_string_literal=config.wrap_string_literal,
use_title_as_name=config.use_title_as_name,
use_operation_id_as_name=config.use_operation_id_as_name,
use_unique_items_as_set=config.use_unique_items_as_set,
use_tuple_for_fixed_items=config.use_tuple_for_fixed_items,
allof_merge_mode=config.allof_merge_mode,
allof_class_hierarchy=config.allof_class_hierarchy,
http_headers=config.http_headers,
http_ignore_tls=config.http_ignore_tls,
http_timeout=config.http_timeout,
use_annotated=config.use_annotated,
use_serialize_as_any=config.use_serialize_as_any,
use_non_positive_negative_number_constrained_types=config.use_non_positive_negative_number_constrained_types,
use_decimal_for_multiple_of=config.use_decimal_for_multiple_of,
original_field_name_delimiter=config.original_field_name_delimiter,
use_double_quotes=config.use_double_quotes,
use_union_operator=config.use_union_operator,
collapse_root_models=config.collapse_root_models,
collapse_root_models_name_strategy=config.collapse_root_models_name_strategy,
collapse_reuse_models=config.collapse_reuse_models,
skip_root_model=config.skip_root_model,
use_type_alias=config.use_type_alias,
special_field_name_prefix=config.special_field_name_prefix,
remove_special_field_name_prefix=config.remove_special_field_name_prefix,
capitalise_enum_members=config.capitalise_enum_members,
keep_model_order=config.keep_model_order,
known_third_party=data_model_types.known_third_party,
custom_formatters=config.custom_formatters,
custom_formatters_kwargs=config.custom_formatters_kwargs,
use_pendulum=config.use_pendulum,
use_standard_primitive_types=config.use_standard_primitive_types,
http_query_parameters=config.http_query_parameters,
treat_dot_as_module=config.treat_dot_as_module,
use_exact_imports=config.use_exact_imports,
default_field_extras=default_field_extras,
target_datetime_class=config.output_datetime_class,
target_date_class=config.output_date_class,
keyword_only=config.keyword_only,
frozen_dataclasses=config.frozen_dataclasses,
no_alias=config.no_alias,
use_frozen_field=config.use_frozen_field,
use_default_factory_for_optional_nested_models=config.use_default_factory_for_optional_nested_models,
formatters=config.formatters,
defer_formatting=defer_formatting,
encoding=config.encoding,
parent_scoped_naming=config.parent_scoped_naming,
naming_strategy=config.naming_strategy,
duplicate_name_suffix=config.duplicate_name_suffix,
dataclass_arguments=dataclass_arguments,
type_mappings=config.type_mappings,
type_overrides=config.type_overrides,
read_only_write_only_model_type=config.read_only_write_only_model_type,
field_type_collision_strategy=config.field_type_collision_strategy,
target_pydantic_version=config.target_pydantic_version,
**kwargs,
from datamodel_code_generator.config import ( # noqa: PLC0415
GraphQLParserConfig,
JSONSchemaParserConfig,
OpenAPIParserConfig,
)

additional_options: ParserConfigDict = {
"data_model_type": data_model_types.data_model,
"data_model_root_type": data_model_types.root_model,
"data_model_field_type": data_model_types.field_model,
"data_type_manager_type": data_model_types.data_type_manager,
"dump_resolve_reference_action": data_model_types.dump_resolve_reference_action,
"extra_template_data": extra_template_data,
"base_path": input_.parent if isinstance(input_, Path) and input_.is_file() else None,
"remote_text_cache": remote_text_cache,
"known_third_party": data_model_types.known_third_party,
"default_field_extras": default_field_extras,
"target_datetime_class": (
config.output_datetime_class
if config.output_datetime_class is not None
else (
DatetimeClassType.Datetime
if input_file_type == InputFileType.GraphQL
else DatetimeClassType.Awaredatetime
)
),
"target_date_class": config.output_date_class,
"dataclass_arguments": dataclass_arguments,
"defer_formatting": defer_formatting,
"enum_field_as_literal": (
config.enum_field_as_literal
if config.enum_field_as_literal is not None
else (LiteralType.All if config.output_model_type == DataModelType.TypingTypedDict else None)
),
"set_default_enum_member": (
True if config.output_model_type == DataModelType.DataclassesDataclass else config.set_default_enum_member
),
}
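
The format-specific option dicts below extend this shared mapping by unpacking **additional_options last; because the shared and format-specific keys are disjoint, the merge is a plain concatenation (and on any overlap the shared values would win). A sketch of the pattern with simplified, hypothetical stand-ins for the TypedDicts in datamodel_code_generator._types:

from typing import TypedDict


class ParserConfigDictStub(TypedDict, total=False):
    encoding: str
    defer_formatting: bool


class OpenAPIParserConfigDictStub(ParserConfigDictStub, total=False):
    include_path_parameters: bool


shared: ParserConfigDictStub = {"encoding": "utf-8", "defer_formatting": True}
openapi_options: OpenAPIParserConfigDictStub = {
    "include_path_parameters": True,  # format-specific key first
    **shared,  # shared options unpacked last, as in the branches below
}
# -> {'include_path_parameters': True, 'encoding': 'utf-8', 'defer_formatting': True}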

if input_file_type == InputFileType.OpenAPI:
from datamodel_code_generator.parser.openapi import OpenAPIParser # noqa: PLC0415

openapi_additional_options: OpenAPIParserConfigDict = {
"openapi_scopes": config.openapi_scopes,
"include_path_parameters": config.include_path_parameters,
"use_status_code_in_response_name": config.use_status_code_in_response_name,
**additional_options,
}
parser_config = _create_parser_config(OpenAPIParserConfig, config, openapi_additional_options)
parser = OpenAPIParser(source=source, config=parser_config)
elif input_file_type == InputFileType.GraphQL:
from datamodel_code_generator.parser.graphql import GraphQLParser # noqa: PLC0415

graphql_additional_options: GraphQLParserConfigDict = {
"data_model_scalar_type": data_model_types.scalar_model,
"data_model_union_type": data_model_types.union_model,
**additional_options,
}
parser_config = _create_parser_config(GraphQLParserConfig, config, graphql_additional_options)
parser = GraphQLParser(source=source, config=parser_config)
else:
from datamodel_code_generator.parser.jsonschema import JsonSchemaParser # noqa: PLC0415

parser_config = _create_parser_config(JSONSchemaParserConfig, config, additional_options)
parser = JsonSchemaParser(source=source, config=parser_config)

with chdir(config.output):
results = parser.parse(
settings_path=config.settings_path,