Skip to content

Commit e9ffb9f

Browse files
authored
Fix Pydantic @model_validator() usage for Pydantic 2.12 (#2472)
Co-authored-by: Antonio Spadaro <9268789+ilovelinux@users.noreply.github.com>
1 parent 10da583 commit e9ffb9f

2 files changed

Lines changed: 161 additions & 68 deletions

File tree

src/datamodel_code_generator/__main__.py

Lines changed: 90 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from enum import IntEnum
1414
from io import TextIOBase
1515
from pathlib import Path
16-
from typing import TYPE_CHECKING, Any, Optional, Union, cast
16+
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union, cast
1717
from urllib.parse import ParseResult, urlparse
1818

1919
import argcomplete
@@ -136,48 +136,6 @@ def validate_url(cls, value: Any) -> ParseResult | None: # noqa: N805
136136
msg = f"This protocol doesn't support only http/https. --input={value}"
137137
raise Error(msg) # pragma: no cover
138138

139-
@model_validator()
140-
def validate_original_field_name_delimiter(cls, values: dict[str, Any]) -> dict[str, Any]: # noqa: N805
141-
if values.get("original_field_name_delimiter") is not None and not values.get("snake_case_field"):
142-
msg = "`--original-field-name-delimiter` can not be used without `--snake-case-field`."
143-
raise Error(msg)
144-
return values
145-
146-
@model_validator()
147-
def validate_custom_file_header(cls, values: dict[str, Any]) -> dict[str, Any]: # noqa: N805
148-
if values.get("custom_file_header") and values.get("custom_file_header_path"):
149-
msg = "`--custom_file_header_path` can not be used with `--custom_file_header`."
150-
raise Error(msg) # pragma: no cover
151-
return values
152-
153-
@model_validator()
154-
def validate_keyword_only(cls, values: dict[str, Any]) -> dict[str, Any]: # noqa: N805
155-
output_model_type: DataModelType = values.get("output_model_type") # pyright: ignore[reportAssignmentType]
156-
python_target: PythonVersion = values.get("target_python_version") # pyright: ignore[reportAssignmentType]
157-
if (
158-
values.get("keyword_only")
159-
and output_model_type == DataModelType.DataclassesDataclass
160-
and not python_target.has_kw_only_dataclass
161-
):
162-
msg = f"`--keyword-only` requires `--target-python-version` {PythonVersion.PY_310.value} or higher."
163-
raise Error(msg)
164-
return values
165-
166-
@model_validator()
167-
def validate_output_datetime_class(cls, values: dict[str, Any]) -> dict[str, Any]: # noqa: N805
168-
datetime_class_type: DatetimeClassType | None = values.get("output_datetime_class")
169-
if (
170-
datetime_class_type
171-
and datetime_class_type is not DatetimeClassType.Datetime
172-
and values.get("output_model_type") == DataModelType.DataclassesDataclass
173-
):
174-
msg = (
175-
'`--output-datetime-class` only allows "datetime" for '
176-
f"`--output-model-type` {DataModelType.DataclassesDataclass.value}"
177-
)
178-
raise Error(msg)
179-
return values
180-
181139
# Pydantic 1.5.1 doesn't support each_item=True correctly
182140
@field_validator("http_headers", mode="before")
183141
def validate_http_headers(cls, value: Any) -> list[tuple[str, str]] | None: # noqa: N805
@@ -225,18 +183,104 @@ def validate_custom_formatters(cls, values: dict[str, Any]) -> dict[str, Any]:
225183
values["custom_formatters"] = custom_formatters.split(",")
226184
return values
227185

186+
__validate_output_datetime_class_err: ClassVar[str] = (
187+
'`--output-datetime-class` only allows "datetime" for '
188+
f"`--output-model-type` {DataModelType.DataclassesDataclass.value}"
189+
)
190+
191+
__validate_original_field_name_delimiter_err: ClassVar[str] = (
192+
"`--original-field-name-delimiter` can not be used without `--snake-case-field`."
193+
)
194+
195+
__validate_custom_file_header_err: ClassVar[str] = (
196+
"`--custom_file_header_path` can not be used with `--custom_file_header`."
197+
)
198+
__validate_keyword_only_err: ClassVar[str] = (
199+
f"`--keyword-only` requires `--target-python-version` {PythonVersion.PY_310.value} or higher."
200+
)
201+
228202
if PYDANTIC_V2:
229203

230204
@model_validator() # pyright: ignore[reportArgumentType]
231-
def validate_root(self: Self) -> Self:
205+
def validate_output_datetime_class(self: Self) -> Self: # pyright: ignore[reportRedeclaration]
206+
datetime_class_type: DatetimeClassType | None = self.output_datetime_class
207+
if (
208+
datetime_class_type
209+
and datetime_class_type is not DatetimeClassType.Datetime
210+
and self.output_model_type == DataModelType.DataclassesDataclass
211+
):
212+
raise Error(self.__validate_output_datetime_class_err)
213+
return self
214+
215+
@model_validator() # pyright: ignore[reportArgumentType]
216+
def validate_original_field_name_delimiter(self: Self) -> Self: # pyright: ignore[reportRedeclaration]
217+
if self.original_field_name_delimiter is not None and not self.snake_case_field:
218+
raise Error(self.__validate_original_field_name_delimiter_err)
219+
return self
220+
221+
@model_validator() # pyright: ignore[reportArgumentType]
222+
def validate_custom_file_header(self: Self) -> Self: # pyright: ignore[reportRedeclaration]
223+
if self.custom_file_header and self.custom_file_header_path:
224+
raise Error(self.__validate_custom_file_header_err)
225+
return self
226+
227+
@model_validator() # pyright: ignore[reportArgumentType]
228+
def validate_keyword_only(self: Self) -> Self: # pyright: ignore[reportRedeclaration]
229+
output_model_type: DataModelType = self.output_model_type
230+
python_target: PythonVersion = self.target_python_version
231+
if (
232+
self.keyword_only
233+
and output_model_type == DataModelType.DataclassesDataclass
234+
and not python_target.has_kw_only_dataclass
235+
):
236+
raise Error(self.__validate_keyword_only_err)
237+
return self
238+
239+
@model_validator() # pyright: ignore[reportArgumentType]
240+
def validate_root(self: Self) -> Self: # pyright: ignore[reportRedeclaration]
232241
if self.use_annotated:
233242
self.field_constraints = True
234243
return self
235244

236245
else:
237246

238-
@model_validator()
239-
def validate_root(cls, values: Any) -> Any: # noqa: N805
247+
@model_validator() # pyright: ignore[reportArgumentType]
248+
def validate_output_datetime_class(cls, values: dict[str, Any]) -> dict[str, Any]: # noqa: N805
249+
datetime_class_type: DatetimeClassType | None = values.get("output_datetime_class")
250+
if (
251+
datetime_class_type
252+
and datetime_class_type is not DatetimeClassType.Datetime
253+
and values.get("output_model_type") == DataModelType.DataclassesDataclass
254+
):
255+
raise Error(cls.__validate_output_datetime_class_err)
256+
return values
257+
258+
@model_validator() # pyright: ignore[reportArgumentType]
259+
def validate_original_field_name_delimiter(cls, values: dict[str, Any]) -> dict[str, Any]: # noqa: N805
260+
if values.get("original_field_name_delimiter") is not None and not values.get("snake_case_field"):
261+
raise Error(cls.__validate_original_field_name_delimiter_err)
262+
return values
263+
264+
@model_validator() # pyright: ignore[reportArgumentType]
265+
def validate_custom_file_header(cls, values: dict[str, Any]) -> dict[str, Any]: # noqa: N805
266+
if values.get("custom_file_header") and values.get("custom_file_header_path"):
267+
raise Error(cls.__validate_custom_file_header_err)
268+
return values
269+
270+
@model_validator() # pyright: ignore[reportArgumentType]
271+
def validate_keyword_only(cls, values: dict[str, Any]) -> dict[str, Any]: # noqa: N805
272+
output_model_type: DataModelType = cast("DataModelType", values.get("output_model_type"))
273+
python_target: PythonVersion = cast("PythonVersion", values.get("target_python_version"))
274+
if (
275+
values.get("keyword_only")
276+
and output_model_type == DataModelType.DataclassesDataclass
277+
and not python_target.has_kw_only_dataclass
278+
):
279+
raise Error(cls.__validate_keyword_only_err)
280+
return values
281+
282+
@model_validator() # pyright: ignore[reportArgumentType]
283+
def validate_root(cls, values: dict[str, Any]) -> dict[str, Any]: # noqa: N805
240284
if values.get("use_annotated"):
241285
values["field_constraints"] = True
242286
return values

src/datamodel_code_generator/util.py

Lines changed: 71 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,33 @@
11
from __future__ import annotations
22

33
import copy
4-
from typing import TYPE_CHECKING, Any, Callable, TypeVar
4+
from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar, overload
55

66
import pydantic
77
from packaging import version
88
from pydantic import BaseModel as _BaseModel
99

10+
if TYPE_CHECKING:
11+
from pathlib import Path
12+
1013
PYDANTIC_VERSION = version.parse(pydantic.VERSION if isinstance(pydantic.VERSION, str) else str(pydantic.VERSION))
1114

1215
PYDANTIC_V2: bool = version.parse("2.0b3") <= PYDANTIC_VERSION
1316

14-
if TYPE_CHECKING:
15-
from pathlib import Path
16-
from typing import Literal
17-
17+
try:
18+
from yaml import CSafeLoader as SafeLoader
19+
except ImportError: # pragma: no cover
1820
from yaml import SafeLoader
1921

20-
def load_toml(path: Path) -> dict[str, Any]: ...
22+
try:
23+
from tomllib import load as load_tomllib # type: ignore[ignoreMissingImports]
24+
except ImportError:
25+
from tomli import load as load_tomllib # type: ignore[ignoreMissingImports]
2126

22-
else:
23-
try:
24-
from yaml import CSafeLoader as SafeLoader
25-
except ImportError: # pragma: no cover
26-
from yaml import SafeLoader
27-
28-
try:
29-
from tomllib import load as load_tomllib
30-
except ImportError:
31-
from tomli import load as load_tomllib
3227

33-
def load_toml(path: Path) -> dict[str, Any]:
34-
with path.open("rb") as f:
35-
return load_tomllib(f)
28+
def load_toml(path: Path) -> dict[str, Any]:
29+
with path.open("rb") as f:
30+
return load_tomllib(f)
3631

3732

3833
SafeLoaderTemp = copy.deepcopy(SafeLoader)
@@ -44,16 +39,70 @@ def load_toml(path: Path) -> dict[str, Any]:
4439
SafeLoader = SafeLoaderTemp
4540

4641
Model = TypeVar("Model", bound=_BaseModel)
42+
T = TypeVar("T")
43+
44+
45+
@overload
46+
def model_validator(
47+
mode: Literal["before"],
48+
) -> (
49+
Callable[[Callable[[type[Model], T], T]], Callable[[type[Model], T], T]]
50+
| Callable[[Callable[[Model, T], T]], Callable[[Model, T], T]]
51+
): ...
4752

4853

54+
@overload
4955
def model_validator(
56+
mode: Literal["after"],
57+
) -> (
58+
Callable[[Callable[[type[Model], T], T]], Callable[[type[Model], T], T]]
59+
| Callable[[Callable[[Model, T], T]], Callable[[Model, T], T]]
60+
| Callable[[Callable[[Model], Model]], Callable[[Model], Model]]
61+
): ...
62+
63+
64+
@overload
65+
def model_validator() -> (
66+
Callable[[Callable[[type[Model], T], T]], Callable[[type[Model], T], T]]
67+
| Callable[[Callable[[Model, T], T]], Callable[[Model, T], T]]
68+
| Callable[[Callable[[Model], Model]], Callable[[Model], Model]]
69+
): ...
70+
71+
72+
def model_validator( # pyright: ignore[reportInconsistentOverload]
5073
mode: Literal["before", "after"] = "after",
51-
) -> Callable[[Callable[[Model, Any], Any]], Callable[[Model, Any], Any]]:
52-
def inner(method: Callable[[Model, Any], Any]) -> Callable[[Model, Any], Any]:
74+
) -> (
75+
Callable[[Callable[[type[Model], T], T]], Callable[[type[Model], T], T]]
76+
| Callable[[Callable[[Model, T], T]], Callable[[Model, T], T]]
77+
| Callable[[Callable[[Model], Model]], Callable[[Model], Model]]
78+
):
79+
"""
80+
Decorator for model validators in Pydantic models.
81+
82+
Uses `model_validator` in Pydantic v2 and `root_validator` in Pydantic v1.
83+
84+
We support only `before` mode because `after` mode needs different validator
85+
implementation for v1 and v2.
86+
"""
87+
88+
@overload
89+
def inner(method: Callable[[type[Model], T], T]) -> Callable[[type[Model], T], T]: ...
90+
91+
@overload
92+
def inner(method: Callable[[Model, T], T]) -> Callable[[Model, T], T]: ...
93+
94+
@overload
95+
def inner(method: Callable[[Model], Model]) -> Callable[[Model], Model]: ...
96+
97+
def inner(
98+
method: Callable[[type[Model], T], T] | Callable[[Model, T], T] | Callable[[Model], Model],
99+
) -> Callable[[type[Model], T], T] | Callable[[Model, T], T] | Callable[[Model], Model]:
53100
if PYDANTIC_V2:
54101
from pydantic import model_validator as model_validator_v2 # noqa: PLC0415
55102

56-
return model_validator_v2(mode=mode)(method) # pyright: ignore[reportReturnType]
103+
if method == "before":
104+
return model_validator_v2(mode=mode)(classmethod(method)) # type: ignore[reportReturnType]
105+
return model_validator_v2(mode=mode)(method) # type: ignore[reportReturnType]
57106
from pydantic import root_validator # noqa: PLC0415
58107

59108
return root_validator(method, pre=mode == "before") # pyright: ignore[reportCallIssue]

0 commit comments

Comments
 (0)