Skip to content

Commit aa714fc

Browse files
authored
Fix msgspec discriminated union ClassVar generation and nullable type handling (#2620)
* Add support for --ignore-pyproject and --profile options in CLI * Fix discriminator handling for ClassVar and add nullable checks in msgspec * Add support for discriminator ClassVar with Meta constraints in msgspec
1 parent ed2de6f commit aa714fc

9 files changed

Lines changed: 256 additions & 18 deletions

File tree

src/datamodel_code_generator/model/base.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ def imports(self) -> tuple[Import, ...]:
217217

218218
if has_optional:
219219
imports.append((IMPORT_OPTIONAL,))
220-
if self.use_annotated and self.annotated:
220+
if self.use_annotated and self.needs_annotated_import:
221221
imports.append((IMPORT_ANNOTATED,))
222222
return chain_as_tuple(*imports)
223223

@@ -269,6 +269,16 @@ def annotated(self) -> str | None:
269269
"""Get the Annotated type hint content, if any."""
270270
return None
271271

272+
@property
273+
def needs_annotated_import(self) -> bool:
274+
"""Check if this field requires the Annotated import."""
275+
return bool(self.annotated)
276+
277+
@property
278+
def needs_meta_import(self) -> bool: # pragma: no cover
279+
"""Check if this field requires the Meta import (msgspec only)."""
280+
return False
281+
272282
@property
273283
def has_default_factory(self) -> bool:
274284
"""Check if this field has a default_factory."""

src/datamodel_code_generator/model/msgspec.py

Lines changed: 38 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def new_imports(self: DataModelFieldBaseT) -> tuple[Import, ...]:
8989
extra_imports.append(IMPORT_MSGSPEC_FIELD)
9090
if self.field and "lambda: convert" in self.field:
9191
extra_imports.append(IMPORT_MSGSPEC_CONVERT)
92-
if self.annotated:
92+
if isinstance(self, DataModelField) and self.needs_meta_import:
9393
extra_imports.append(IMPORT_MSGSPEC_META)
9494
if self.extras.get("is_classvar"):
9595
extra_imports.append(IMPORT_CLASSVAR)
@@ -281,7 +281,7 @@ def __str__(self) -> str:
281281
def type_hint(self) -> str:
282282
"""Return the type hint, using UnsetType for non-required non-nullable fields."""
283283
type_hint = super().type_hint
284-
if self._not_required and not self.nullable:
284+
if self._not_required and not self.nullable and not self.data_type.is_optional:
285285
return get_neither_required_nor_nullable_type(type_hint, self.data_type.use_union_operator)
286286
return type_hint
287287

@@ -294,12 +294,8 @@ def fall_back_to_nullable(self) -> bool:
294294
"""Return whether to fall back to nullable type instead of UnsetType."""
295295
return not self._not_required
296296

297-
@property
298-
def annotated(self) -> str | None:
299-
"""Get Annotated type hint with Meta constraints."""
300-
if not self.use_annotated: # pragma: no cover
301-
return None
302-
297+
def _get_meta_string(self) -> str | None:
298+
"""Compute Meta(...) string if there are any meta constraints."""
303299
data: dict[str, Any] = {k: v for k, v in self.extras.items() if k in self._META_FIELD_KEYS}
304300
has_type_constraints = self.data_type.kwargs is not None and len(self.data_type.kwargs) > 0
305301
if (
@@ -317,21 +313,50 @@ def annotated(self) -> str | None:
317313
}
318314

319315
meta_arguments = sorted(f"{k}={v!r}" for k, v in data.items() if v is not None)
320-
if not meta_arguments:
316+
return f"Meta({', '.join(meta_arguments)})" if meta_arguments else None
317+
318+
@property
319+
def annotated(self) -> str | None:
320+
"""Get Annotated type hint with Meta constraints."""
321+
if not self.use_annotated: # pragma: no cover
321322
return None
322323

323-
meta = f"Meta({', '.join(meta_arguments)})"
324+
meta = self._get_meta_string()
325+
326+
if not meta and not self.extras.get("is_classvar"):
327+
return None
324328

325329
if not self.required and not self.extras.get("is_classvar"):
326330
type_hint = self.data_type.type_hint
327331
annotated_type = f"Annotated[{type_hint}, {meta}]"
328332
return get_neither_required_nor_nullable_type(annotated_type, self.data_type.use_union_operator)
329333

330-
annotated_type = f"Annotated[{self.type_hint}, {meta}]"
334+
# Handle ClassVar case (for discriminator fields in msgspec)
331335
if self.extras.get("is_classvar"):
332-
annotated_type = f"ClassVar[{annotated_type}]"
336+
if meta:
337+
annotated_type = f"Annotated[{self.type_hint}, {meta}]"
338+
return f"ClassVar[{annotated_type}]"
339+
return f"ClassVar[{self.type_hint}]"
340+
341+
return f"Annotated[{self.type_hint}, {meta}]"
333342

334-
return annotated_type
343+
@property
344+
def needs_annotated_import(self) -> bool:
345+
"""Check if this field requires the Annotated import.
346+
347+
ClassVar fields without Meta constraints don't need Annotated.
348+
"""
349+
if not self.annotated:
350+
return False
351+
# ClassVar without Meta doesn't use Annotated
352+
if self.extras.get("is_classvar"):
353+
return self._get_meta_string() is not None
354+
return True
355+
356+
@property
357+
def needs_meta_import(self) -> bool:
358+
"""Check if this field requires the Meta import."""
359+
return self._get_meta_string() is not None
335360

336361
def _get_default_as_struct_model(self) -> str | None:
337362
"""Convert default value to Struct model using msgspec convert."""

src/datamodel_code_generator/parser/base.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -991,7 +991,7 @@ def _create_discriminator_data_type(
991991
data_type = self.data_type(literals=type_names)
992992
return data_type
993993

994-
def __apply_discriminator_type( # noqa: PLR0912, PLR0915
994+
def __apply_discriminator_type( # noqa: PLR0912, PLR0914, PLR0915
995995
self,
996996
models: list[DataModel],
997997
imports: Imports,
@@ -1097,7 +1097,15 @@ def check_paths(
10971097
if field_name not in {discriminator_field.original_name, discriminator_field.name}:
10981098
continue
10991099
literals = discriminator_field.data_type.literals
1100-
if len(literals) == 1 and literals[0] == (type_names[0] if type_names else None):
1100+
const_value = discriminator_field.extras.get("const")
1101+
expected_value = type_names[0] if type_names else None
1102+
1103+
# Check if literals match (existing behavior)
1104+
literals_match = len(literals) == 1 and literals[0] == expected_value
1105+
# Check if const value matches (for msgspec with type: string + const)
1106+
const_match = const_value is not None and const_value == expected_value
1107+
1108+
if literals_match:
11011109
has_one_literal = True
11021110
if isinstance(discriminator_model, msgspec_model.Struct): # pragma: no cover
11031111
discriminator_model.add_base_class_kwarg("tag_field", f"'{field_name}'")
@@ -1106,6 +1114,14 @@ def check_paths(
11061114
# Found the discriminator field, no need to keep looking
11071115
break
11081116

1117+
# For msgspec with const value but no literal (type: string + const case)
1118+
if const_match and isinstance(discriminator_model, msgspec_model.Struct): # pragma: no cover
1119+
has_one_literal = True
1120+
discriminator_model.add_base_class_kwarg("tag_field", f"'{field_name}'")
1121+
discriminator_model.add_base_class_kwarg("tag", repr(const_value))
1122+
discriminator_field.extras["is_classvar"] = True
1123+
break
1124+
11091125
enum_source: Enum | None = None
11101126
if self.use_enum_values_in_discriminator:
11111127
enum_source = ( # pragma: no cover

src/datamodel_code_generator/parser/jsonschema.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -392,8 +392,15 @@ def ref_type(self) -> JSONReference | None:
392392

393393
@cached_property
394394
def type_has_null(self) -> bool:
395-
"""Check if the type list contains null."""
396-
return isinstance(self.type, list) and "null" in self.type
395+
"""Check if the type list or oneOf/anyOf contains null."""
396+
if isinstance(self.type, list) and "null" in self.type:
397+
return True
398+
for item in self.oneOf + self.anyOf:
399+
if item.type == "null":
400+
return True
401+
if isinstance(item.type, list) and "null" in item.type:
402+
return True
403+
return False
397404

398405
@cached_property
399406
def has_multiple_types(self) -> bool:
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# generated by datamodel-codegen:
2+
# filename: discriminator_with_meta_msgspec.json
3+
# timestamp: 2019-07-26T00:00:00+00:00
4+
5+
from __future__ import annotations
6+
7+
from typing import Annotated, ClassVar, Literal, Union
8+
9+
from msgspec import UNSET, Meta, Struct, UnsetType
10+
11+
12+
class SystemMessage(Struct, tag_field='role', tag='system'):
13+
role: ClassVar[Annotated[Literal['system'], Meta(title='Message Role')]]
14+
content: str
15+
16+
17+
class UserMessage(Struct, tag_field='role', tag='user'):
18+
role: ClassVar[Annotated[Literal['user'], Meta(title='Message Role')]]
19+
content: str
20+
21+
22+
class Model(Struct):
23+
message: Union[SystemMessage, UserMessage, UnsetType] = UNSET
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# generated by datamodel-codegen:
2+
# filename: discriminator_with_type_string.json
3+
# timestamp: 2019-07-26T00:00:00+00:00
4+
5+
from __future__ import annotations
6+
7+
from typing import ClassVar, Literal, Union
8+
9+
from msgspec import UNSET, Struct, UnsetType
10+
11+
12+
class SystemMessage(Struct, tag_field='role', tag='system'):
13+
role: ClassVar[Literal['system']]
14+
content: str
15+
16+
17+
class UserMessage(Struct, tag_field='role', tag='user'):
18+
role: ClassVar[Literal['user']]
19+
content: str
20+
21+
22+
class Model(Struct):
23+
message: Union[SystemMessage, UserMessage, UnsetType] = UNSET
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
{
2+
"$defs": {
3+
"SystemMessage": {
4+
"type": "object",
5+
"properties": {
6+
"role": {
7+
"type": "string",
8+
"const": "system",
9+
"default": "system",
10+
"title": "Message Role"
11+
},
12+
"content": {
13+
"type": "string"
14+
}
15+
},
16+
"required": ["role", "content"]
17+
},
18+
"UserMessage": {
19+
"type": "object",
20+
"properties": {
21+
"role": {
22+
"type": "string",
23+
"const": "user",
24+
"default": "user",
25+
"title": "Message Role"
26+
},
27+
"content": {
28+
"type": "string"
29+
}
30+
},
31+
"required": ["role", "content"]
32+
}
33+
},
34+
"type": "object",
35+
"properties": {
36+
"message": {
37+
"discriminator": {
38+
"propertyName": "role",
39+
"mapping": {
40+
"system": "#/$defs/SystemMessage",
41+
"user": "#/$defs/UserMessage"
42+
}
43+
},
44+
"oneOf": [
45+
{"$ref": "#/$defs/SystemMessage"},
46+
{"$ref": "#/$defs/UserMessage"}
47+
]
48+
}
49+
}
50+
}
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
{
2+
"$defs": {
3+
"SystemMessage": {
4+
"type": "object",
5+
"properties": {
6+
"role": {
7+
"type": "string",
8+
"const": "system",
9+
"default": "system"
10+
},
11+
"content": {
12+
"type": "string"
13+
}
14+
},
15+
"required": ["role", "content"]
16+
},
17+
"UserMessage": {
18+
"type": "object",
19+
"properties": {
20+
"role": {
21+
"type": "string",
22+
"const": "user",
23+
"default": "user"
24+
},
25+
"content": {
26+
"type": "string"
27+
}
28+
},
29+
"required": ["role", "content"]
30+
}
31+
},
32+
"type": "object",
33+
"properties": {
34+
"message": {
35+
"discriminator": {
36+
"propertyName": "role",
37+
"mapping": {
38+
"system": "#/$defs/SystemMessage",
39+
"user": "#/$defs/UserMessage"
40+
}
41+
},
42+
"oneOf": [
43+
{"$ref": "#/$defs/SystemMessage"},
44+
{"$ref": "#/$defs/UserMessage"}
45+
]
46+
}
47+
}
48+
}

tests/main/jsonschema/test_main_jsonschema.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2485,6 +2485,42 @@ def test_main_jsonschema_openapi_keyword_only_msgspec_with_extra_data(tmp_path:
24852485
assert_file_content(output_file, "discriminator_literals_msgspec_keyword_only_omit_defaults.py")
24862486

24872487

2488+
@MSGSPEC_LEGACY_BLACK_SKIP
2489+
def test_main_msgspec_discriminator_with_type_string(output_file: Path) -> None:
2490+
"""Test msgspec Struct generation with discriminator using type: string + const."""
2491+
run_main_and_assert(
2492+
input_path=JSON_SCHEMA_DATA_PATH / "discriminator_with_type_string.json",
2493+
output_path=output_file,
2494+
input_file_type="jsonschema",
2495+
assert_func=assert_file_content,
2496+
expected_file="discriminator_with_type_string_msgspec.py",
2497+
extra_args=[
2498+
"--output-model-type",
2499+
"msgspec.Struct",
2500+
"--target-python-version",
2501+
"3.10",
2502+
],
2503+
)
2504+
2505+
2506+
@MSGSPEC_LEGACY_BLACK_SKIP
2507+
def test_main_msgspec_discriminator_with_meta(output_file: Path) -> None:
2508+
"""Test msgspec Struct generation with discriminator ClassVar having Meta constraints."""
2509+
run_main_and_assert(
2510+
input_path=JSON_SCHEMA_DATA_PATH / "discriminator_with_meta_msgspec.json",
2511+
output_path=output_file,
2512+
input_file_type="jsonschema",
2513+
assert_func=assert_file_content,
2514+
expected_file="discriminator_with_meta_msgspec.py",
2515+
extra_args=[
2516+
"--output-model-type",
2517+
"msgspec.Struct",
2518+
"--target-python-version",
2519+
"3.10",
2520+
],
2521+
)
2522+
2523+
24882524
@MSGSPEC_LEGACY_BLACK_SKIP
24892525
def test_main_msgspec_null_field(output_file: Path) -> None:
24902526
"""Test msgspec Struct generation with null type fields."""

0 commit comments

Comments
 (0)