diff --git a/src/datamodel_code_generator/model/scalar.py b/src/datamodel_code_generator/model/scalar.py index dcddddb4c..89a8ee7ae 100644 --- a/src/datamodel_code_generator/model/scalar.py +++ b/src/datamodel_code_generator/model/scalar.py @@ -67,9 +67,9 @@ def __init__( # noqa: PLR0913 extra_template_data[scalar_name] = defaultdict(dict) # py_type - py_type = extra_template_data[scalar_name].get( + py_type = extra_template_data[reference.original_name].get( "py_type", - DEFAULT_GRAPHQL_SCALAR_TYPES.get(reference.name, DEFAULT_GRAPHQL_SCALAR_TYPE), + DEFAULT_GRAPHQL_SCALAR_TYPES.get(reference.original_name, DEFAULT_GRAPHQL_SCALAR_TYPE), ) extra_template_data[scalar_name]["py_type"] = py_type diff --git a/src/datamodel_code_generator/parser/graphql.py b/src/datamodel_code_generator/parser/graphql.py index b4e37abf8..d644ac232 100644 --- a/src/datamodel_code_generator/parser/graphql.py +++ b/src/datamodel_code_generator/parser/graphql.py @@ -479,8 +479,10 @@ def parse_input_object(self, input_graphql_object: graphql.GraphQLInputObjectTyp def parse_union(self, union_object: graphql.GraphQLUnionType) -> None: """Parse a GraphQL union type and add it to results.""" - fields = [self.data_model_field_type(name=type_.name, data_type=DataType()) for type_ in union_object.types] - + fields = [ + self.data_model_field_type(name=self.references[type_.name].name, data_type=DataType()) + for type_ in union_object.types + ] data_model_type = self.data_model_union_type( reference=self.references[union_object.name], fields=fields, diff --git a/tests/data/expected/main/graphql/simple_star_wars_class_name_prefix.py b/tests/data/expected/main/graphql/simple_star_wars_class_name_prefix.py new file mode 100644 index 000000000..891e908da --- /dev/null +++ b/tests/data/expected/main/graphql/simple_star_wars_class_name_prefix.py @@ -0,0 +1,158 @@ +# generated by datamodel-codegen: +# filename: simple-star-wars.graphql +# timestamp: 2019-07-26T00:00:00+00:00 + +from __future__ import annotations + +from typing import Literal, TypeAlias + +from pydantic import BaseModel, Field + +FooBoolean: TypeAlias = bool +""" +The `Boolean` scalar type represents `true` or `false`. +""" + + +FooID: TypeAlias = str +""" +The `ID` scalar type represents a unique identifier, often used to refetch an object or as key for a cache. The ID type appears in a JSON response as a String; however, it is not intended to be human-readable. When expected as an input type, any string (such as `"4"`) or integer (such as `4`) input value will be accepted as an ID. +""" + + +FooInt: TypeAlias = int +""" +The `Int` scalar type represents non-fractional signed whole numeric values. Int can represent values between -(2^31) and 2^31 - 1. +""" + + +FooString: TypeAlias = str +""" +The `String` scalar type represents textual data, represented as UTF-8 character sequences. The String type is most often used by GraphQL to represent free-form human-readable text. +""" + + +class FooFilm(BaseModel): + characters: list[FooPerson] + characters_ids: list[FooID] + director: FooString + episode_id: FooInt + id: FooID + opening_crawl: FooString + planets: list[FooPlanet] + planets_ids: list[FooID] + producer: FooString | None = None + release_date: FooString + species: list[FooSpecies] + species_ids: list[FooID] + starships: list[FooStarship] + starships_ids: list[FooID] + title: FooString + vehicles: list[FooVehicle] + vehicles_ids: list[FooID] + typename__: Literal['Film'] | None = Field('Film', alias='__typename') + + +class FooPerson(BaseModel): + birth_year: FooString | None = None + eye_color: FooString | None = None + films: list[FooFilm] + films_ids: list[FooID] + gender: FooString | None = None + hair_color: FooString | None = None + height: FooInt | None = None + homeworld: FooPlanet | None = None + homeworld_id: FooID | None = None + id: FooID + mass: FooInt | None = None + name: FooString + skin_color: FooString | None = None + species: list[FooSpecies] + species_ids: list[FooID] + starships: list[FooStarship] + starships_ids: list[FooID] + vehicles: list[FooVehicle] + vehicles_ids: list[FooID] + typename__: Literal['Person'] | None = Field('Person', alias='__typename') + + +class FooPlanet(BaseModel): + climate: FooString | None = None + diameter: FooString | None = None + films: list[FooFilm] + films_ids: list[FooID] + gravity: FooString | None = None + id: FooID + name: FooString + orbital_period: FooString | None = None + population: FooString | None = None + residents: list[FooPerson] + residents_ids: list[FooID] + rotation_period: FooString | None = None + surface_water: FooString | None = None + terrain: FooString | None = None + typename__: Literal['Planet'] | None = Field('Planet', alias='__typename') + + +class FooSpecies(BaseModel): + average_height: FooString | None = None + average_lifespan: FooString | None = None + classification: FooString | None = None + designation: FooString | None = None + eye_colors: FooString | None = None + films: list[FooFilm] + films_ids: list[FooID] + hair_colors: FooString | None = None + id: FooID + language: FooString | None = None + name: FooString + people: list[FooPerson] + people_ids: list[FooID] + skin_colors: FooString | None = None + typename__: Literal['Species'] | None = Field('Species', alias='__typename') + + +class FooStarship(BaseModel): + MGLT: FooString | None = None + cargo_capacity: FooString | None = None + consumables: FooString | None = None + cost_in_credits: FooString | None = None + crew: FooString | None = None + films: list[FooFilm] + films_ids: list[FooID] + hyperdrive_rating: FooString | None = None + id: FooID + length: FooString | None = None + manufacturer: FooString | None = None + max_atmosphering_speed: FooString | None = None + model: FooString | None = None + name: FooString + passengers: FooString | None = None + pilots: list[FooPerson] + pilots_ids: list[FooID] + starship_class: FooString | None = None + typename__: Literal['Starship'] | None = Field('Starship', alias='__typename') + + +class FooVehicle(BaseModel): + cargo_capacity: FooString | None = None + consumables: FooString | None = None + cost_in_credits: FooString | None = None + crew: FooString | None = None + films: list[FooFilm] + films_ids: list[FooID] + id: FooID + length: FooString | None = None + manufacturer: FooString | None = None + max_atmosphering_speed: FooString | None = None + model: FooString | None = None + name: FooString + passengers: FooString | None = None + pilots: list[FooPerson] + pilots_ids: list[FooID] + vehicle_class: FooString | None = None + typename__: Literal['Vehicle'] | None = Field('Vehicle', alias='__typename') + + +FooFilm.update_forward_refs() +FooPerson.update_forward_refs() diff --git a/tests/data/expected/parser/graphql/union_with_prefix.py b/tests/data/expected/parser/graphql/union_with_prefix.py new file mode 100644 index 000000000..5bee2bab4 --- /dev/null +++ b/tests/data/expected/parser/graphql/union_with_prefix.py @@ -0,0 +1,59 @@ +# generated by datamodel-codegen: +# filename: union.graphql +# timestamp: 2019-07-26T00:00:00+00:00 + +from __future__ import annotations + +from typing import Literal, TypeAlias, Union + +from pydantic import BaseModel, Field + +FooBoolean: TypeAlias = bool +""" +The `Boolean` scalar type represents `true` or `false`. +""" + + +FooID: TypeAlias = str +""" +The `ID` scalar type represents a unique identifier, often used to refetch an object or as key for a cache. The ID type appears in a JSON response as a String; however, it is not intended to be human-readable. When expected as an input type, any string (such as `"4"`) or integer (such as `4`) input value will be accepted as an ID. +""" + + +FooInt: TypeAlias = int +""" +The `Int` scalar type represents non-fractional signed whole numeric values. Int can represent values between -(2^31) and 2^31 - 1. +""" + + +FooString: TypeAlias = str +""" +The `String` scalar type represents textual data, represented as UTF-8 character sequences. The String type is most often used by GraphQL to represent free-form human-readable text. +""" + + +class FooIResource(BaseModel): + id: FooID + typename__: Literal['IResource'] | None = Field('IResource', alias='__typename') + + +class FooCar(FooIResource): + id: FooID + passengerCapacity: FooInt + typename__: Literal['Car'] | None = Field('Car', alias='__typename') + + +class FooEmployee(FooIResource): + firstName: FooString | None = None + id: FooID + lastName: FooString | None = None + typename__: Literal['Employee'] | None = Field('Employee', alias='__typename') + + +FooResource: TypeAlias = Union[ + 'FooCar', + 'FooEmployee', +] + + +FooTechnicalResource: TypeAlias = FooCar diff --git a/tests/main/graphql/test_main_graphql.py b/tests/main/graphql/test_main_graphql.py index 0ed095fd9..c33914869 100644 --- a/tests/main/graphql/test_main_graphql.py +++ b/tests/main/graphql/test_main_graphql.py @@ -762,6 +762,21 @@ def test_main_graphql_dataclass_frozen_keyword_only(output_file: Path) -> None: ) +def test_main_graphql_class_name_prefix(output_file: Path) -> None: + """Test GraphQL code generation with class name prefixing.""" + run_main_and_assert( + input_path=GRAPHQL_DATA_PATH / "simple-star-wars.graphql", + output_path=output_file, + input_file_type="graphql", + assert_func=assert_file_content, + expected_file="simple_star_wars_class_name_prefix.py", + extra_args=[ + "--class-name-prefix", + "Foo", + ], + ) + + def test_main_graphql_union_snake_case_field(output_file: Path) -> None: """Test that union type references are not converted to snake_case.""" run_main_and_assert( diff --git a/tests/parser/test_graphql.py b/tests/parser/test_graphql.py index 5c236bd45..af8825133 100644 --- a/tests/parser/test_graphql.py +++ b/tests/parser/test_graphql.py @@ -55,6 +55,18 @@ def test_graphql_union_commented(output_file: Path) -> None: ) +def test_graphql_union_with_prefix(output_file: Path) -> None: + """Test parsing GraphQL union with class name prefix (Unions should reference prefixed class names).""" + run_main_and_assert( + input_path=GRAPHQL_DATA_PATH / "union.graphql", + output_path=output_file, + input_file_type="graphql", + assert_func=assert_file_content, + expected_file="union_with_prefix.py", + extra_args=["--class-name-prefix", "Foo"], + ) + + @pytest.mark.parametrize( ("frozen_dataclasses", "keyword_only", "parser_dataclass_args", "kwargs_dataclass_args", "expected"), [