diff --git a/src/datamodel_code_generator/parser/graphql.py b/src/datamodel_code_generator/parser/graphql.py index 1249b61e0..53a751298 100644 --- a/src/datamodel_code_generator/parser/graphql.py +++ b/src/datamodel_code_generator/parser/graphql.py @@ -6,12 +6,10 @@ from __future__ import annotations -from pathlib import Path from typing import ( TYPE_CHECKING, Any, ) -from urllib.parse import ParseResult from typing_extensions import Unpack @@ -26,7 +24,6 @@ from datamodel_code_generator.parser.base import ( DataType, Parser, - Source, escape_characters, ) from datamodel_code_generator.reference import ModelType, Reference @@ -40,7 +37,8 @@ if TYPE_CHECKING: - from collections.abc import Iterator + from pathlib import Path + from urllib.parse import ParseResult from datamodel_code_generator._types import GraphQLParserConfigDict from datamodel_code_generator.config import GraphQLParserConfig @@ -137,31 +135,6 @@ def __init__( self.use_standard_collections = use_standard_collections self.use_union_operator = use_union_operator - def _get_context_source_path_parts(self) -> Iterator[tuple[Source, list[str]]]: - # TODO (denisart): Temporarily this method duplicates - # the method `datamodel_code_generator.parser.jsonschema.JsonSchemaParser._get_context_source_path_parts`. - - if isinstance(self.source, list) or ( # pragma: no cover - isinstance(self.source, Path) and self.source.is_dir() - ): # pragma: no cover - self.current_source_path = Path() - self.model_resolver.after_load_files = { - self.base_path.joinpath(s.path).resolve().as_posix() for s in self.iter_source - } - - for source in self.iter_source: - if isinstance(self.source, ParseResult): # pragma: no cover - path_parts = self.get_url_path_parts(self.source) - else: - path_parts = list(source.path.parts) - if self.current_source_path is not None: # pragma: no cover - self.current_source_path = source.path - with ( - self.model_resolver.current_base_path_context(source.path.parent), - self.model_resolver.current_root_context(path_parts), - ): - yield source, path_parts - def _resolve_types(self, paths: list[str], schema: graphql.GraphQLSchema) -> None: for type_name, type_ in schema.type_map.items(): if type_name.startswith("__"): @@ -523,13 +496,13 @@ def parse_raw(self) -> None: graphql.type.introspection.TypeKind.UNION: self.parse_union, } - for source, path_parts in self._get_context_source_path_parts(): - schema: graphql.GraphQLSchema = build_graphql_schema(source.text) - self.raw_obj = schema + combined_schema = "\n".join(source.text for source in self.iter_source) + schema: graphql.GraphQLSchema = build_graphql_schema(combined_schema) + self.raw_obj = schema - self._resolve_types(path_parts, schema) + self._resolve_types([], schema) - for next_type in self.parse_order: - for obj in self.support_graphql_types[next_type]: - parser_ = mapper_from_graphql_type_to_parser_method[next_type] - parser_(obj) + for next_type in self.parse_order: + for obj in self.support_graphql_types[next_type]: + parser_ = mapper_from_graphql_type_to_parser_method[next_type] + parser_(obj) diff --git a/tests/data/expected/main/graphql/split_graphql_schemas.py b/tests/data/expected/main/graphql/split_graphql_schemas.py new file mode 100644 index 000000000..7dd685173 --- /dev/null +++ b/tests/data/expected/main/graphql/split_graphql_schemas.py @@ -0,0 +1,46 @@ +# generated by datamodel-codegen: +# filename: split +# timestamp: 2019-07-26T00:00:00+00:00 + +from __future__ import annotations + +from typing import Literal, TypeAlias + +from pydantic import BaseModel, Field + +Boolean: TypeAlias = bool +""" +The `Boolean` scalar type represents `true` or `false`. +""" + + +Int: TypeAlias = int +""" +The `Int` scalar type represents non-fractional signed whole numeric values. Int can represent values between -(2^31) and 2^31 - 1. +""" + + +String: TypeAlias = str +""" +The `String` scalar type represents textual data, represented as UTF-8 character sequences. The String type is most often used by GraphQL to represent free-form human-readable text. +""" + + +class Baz(BaseModel): + quux: Int + typename__: Literal['Baz'] | None = Field('Baz', alias='__typename') + + +class Bar(BaseModel): + eggs: String + foo: Foo + typename__: Literal['Bar'] | None = Field('Bar', alias='__typename') + + +class Foo(BaseModel): + baz: Bar + id: Int + typename__: Literal['Foo'] | None = Field('Foo', alias='__typename') + + +Bar.update_forward_refs() diff --git a/tests/data/graphql/split/a.graphql b/tests/data/graphql/split/a.graphql new file mode 100644 index 000000000..129cb1bdc --- /dev/null +++ b/tests/data/graphql/split/a.graphql @@ -0,0 +1,7 @@ +type Foo { + id: Int! +} + +extend type Bar { + eggs: String! +} diff --git a/tests/data/graphql/split/b.graphql b/tests/data/graphql/split/b.graphql new file mode 100644 index 000000000..18acd8cb9 --- /dev/null +++ b/tests/data/graphql/split/b.graphql @@ -0,0 +1,4 @@ +type Bar { + foo: Foo! +} + diff --git a/tests/data/graphql/split/c.graphql b/tests/data/graphql/split/c.graphql new file mode 100644 index 000000000..97d6dcb95 --- /dev/null +++ b/tests/data/graphql/split/c.graphql @@ -0,0 +1,7 @@ +extend type Foo { + baz: Bar! +} + +type Baz { + quux: Int! +} diff --git a/tests/main/graphql/test_main_graphql.py b/tests/main/graphql/test_main_graphql.py index 32310d0ec..253f4ceb7 100644 --- a/tests/main/graphql/test_main_graphql.py +++ b/tests/main/graphql/test_main_graphql.py @@ -772,3 +772,14 @@ def test_main_graphql_union_snake_case_field(output_file: Path) -> None: expected_file="union_snake_case_field.py", extra_args=["--snake-case-field", "--output-model-type", "pydantic_v2.BaseModel"], ) + + +def test_main_graphql_split_graphql_schemas(output_file: Path) -> None: + """Test GraphQL code generation with multiple schema files in a directory.""" + run_main_and_assert( + input_path=GRAPHQL_DATA_PATH / "split", + output_path=output_file, + input_file_type="graphql", + assert_func=assert_file_content, + expected_file="split_graphql_schemas.py", + )