Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 10 additions & 37 deletions src/datamodel_code_generator/parser/graphql.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,10 @@

from __future__ import annotations

from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
)
from urllib.parse import ParseResult

from typing_extensions import Unpack

Expand All @@ -26,7 +24,6 @@
from datamodel_code_generator.parser.base import (
DataType,
Parser,
Source,
escape_characters,
)
from datamodel_code_generator.reference import ModelType, Reference
Expand All @@ -40,7 +37,8 @@


if TYPE_CHECKING:
from collections.abc import Iterator
from pathlib import Path
from urllib.parse import ParseResult

from datamodel_code_generator._types import GraphQLParserConfigDict
from datamodel_code_generator.config import GraphQLParserConfig
Expand Down Expand Up @@ -137,31 +135,6 @@ def __init__(
self.use_standard_collections = use_standard_collections
self.use_union_operator = use_union_operator

def _get_context_source_path_parts(self) -> Iterator[tuple[Source, list[str]]]:
# TODO (denisart): Temporarily this method duplicates
# the method `datamodel_code_generator.parser.jsonschema.JsonSchemaParser._get_context_source_path_parts`.

if isinstance(self.source, list) or ( # pragma: no cover
isinstance(self.source, Path) and self.source.is_dir()
): # pragma: no cover
self.current_source_path = Path()
self.model_resolver.after_load_files = {
self.base_path.joinpath(s.path).resolve().as_posix() for s in self.iter_source
}

for source in self.iter_source:
if isinstance(self.source, ParseResult): # pragma: no cover
path_parts = self.get_url_path_parts(self.source)
else:
path_parts = list(source.path.parts)
if self.current_source_path is not None: # pragma: no cover
self.current_source_path = source.path
with (
self.model_resolver.current_base_path_context(source.path.parent),
self.model_resolver.current_root_context(path_parts),
):
yield source, path_parts

def _resolve_types(self, paths: list[str], schema: graphql.GraphQLSchema) -> None:
for type_name, type_ in schema.type_map.items():
if type_name.startswith("__"):
Expand Down Expand Up @@ -523,13 +496,13 @@ def parse_raw(self) -> None:
graphql.type.introspection.TypeKind.UNION: self.parse_union,
}

for source, path_parts in self._get_context_source_path_parts():
schema: graphql.GraphQLSchema = build_graphql_schema(source.text)
self.raw_obj = schema
combined_schema = "\n".join(source.text for source in self.iter_source)
schema: graphql.GraphQLSchema = build_graphql_schema(combined_schema)
self.raw_obj = schema

self._resolve_types(path_parts, schema)
self._resolve_types([], schema)
Comment thread
koxudaxi marked this conversation as resolved.

for next_type in self.parse_order:
for obj in self.support_graphql_types[next_type]:
parser_ = mapper_from_graphql_type_to_parser_method[next_type]
parser_(obj)
for next_type in self.parse_order:
for obj in self.support_graphql_types[next_type]:
parser_ = mapper_from_graphql_type_to_parser_method[next_type]
parser_(obj)
46 changes: 46 additions & 0 deletions tests/data/expected/main/graphql/split_graphql_schemas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# generated by datamodel-codegen:
# filename: split
# timestamp: 2019-07-26T00:00:00+00:00

from __future__ import annotations

from typing import Literal, TypeAlias

from pydantic import BaseModel, Field

Boolean: TypeAlias = bool
"""
The `Boolean` scalar type represents `true` or `false`.
"""


Int: TypeAlias = int
"""
The `Int` scalar type represents non-fractional signed whole numeric values. Int can represent values between -(2^31) and 2^31 - 1.
"""


String: TypeAlias = str
"""
The `String` scalar type represents textual data, represented as UTF-8 character sequences. The String type is most often used by GraphQL to represent free-form human-readable text.
"""


class Baz(BaseModel):
quux: Int
typename__: Literal['Baz'] | None = Field('Baz', alias='__typename')


class Bar(BaseModel):
eggs: String
foo: Foo
typename__: Literal['Bar'] | None = Field('Bar', alias='__typename')


class Foo(BaseModel):
baz: Bar
id: Int
typename__: Literal['Foo'] | None = Field('Foo', alias='__typename')


Bar.update_forward_refs()
7 changes: 7 additions & 0 deletions tests/data/graphql/split/a.graphql
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
type Foo {
id: Int!
}

extend type Bar {
eggs: String!
}
4 changes: 4 additions & 0 deletions tests/data/graphql/split/b.graphql
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
type Bar {
foo: Foo!
}

7 changes: 7 additions & 0 deletions tests/data/graphql/split/c.graphql
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
extend type Foo {
baz: Bar!
}

type Baz {
quux: Int!
}
11 changes: 11 additions & 0 deletions tests/main/graphql/test_main_graphql.py
Original file line number Diff line number Diff line change
Expand Up @@ -772,3 +772,14 @@ def test_main_graphql_union_snake_case_field(output_file: Path) -> None:
expected_file="union_snake_case_field.py",
extra_args=["--snake-case-field", "--output-model-type", "pydantic_v2.BaseModel"],
)


def test_main_graphql_split_graphql_schemas(output_file: Path) -> None:
"""Test GraphQL code generation with multiple schema files in a directory."""
run_main_and_assert(
input_path=GRAPHQL_DATA_PATH / "split",
output_path=output_file,
input_file_type="graphql",
assert_func=assert_file_content,
expected_file="split_graphql_schemas.py",
)
Loading