Skip to content

Commit efd679a

Browse files
authored
Add compile/exec validation for generated Python code in tests (#2665)
- Add code validation for generated files and introduce pytest options for execution
- Add validation for generated code execution and improve target version checks
- Add option to skip code validation in OpenAPI tests
- Add coverage exclusions for code validation functions in conftest.py
- Add coverage exclusions for conditional checks in conftest.py
- Refactor output model argument naming to use `--output-model-type` across tests and validation functions
- Add target Python version argument to Pydantic v2 collision test
- Refactor Pydantic v2 collision test to use the current Python version for the target argument
- Refactor generated code validation to improve exception handling and remove unused imports
- Refactor generated code validation to handle directory inputs correctly
- Refactor validation logic to use DataModelType for Pydantic version checks
1 parent 3379183 commit efd679a

7 files changed

Lines changed: 540 additions & 236 deletions

File tree

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ dev = [
8080
test = [
8181
"freezegun; python_version<'3.10'",
8282
"inline-snapshot>=0.31.1",
83+
"msgspec>=0.18",
8384
"pytest>=6.1",
8485
"pytest>=8.3.4",
8586
"pytest-benchmark",
@@ -144,6 +145,7 @@ conflicts = [
144145
[
145146
{ group = "pydantic1" },
146147
{ group = "pkg-meta" },
148+
{ group = "dev" },
147149
],
148150
]
149151

tests/conftest.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import difflib
66
import inspect
77
import sys
8+
import time
89
from typing import TYPE_CHECKING, Any, Protocol
910

1011
import pytest
@@ -16,6 +17,55 @@
1617
from collections.abc import Callable
1718
from pathlib import Path
1819

20+
21+
class CodeValidationStats:
    """Accumulate timing and error statistics for generated-code validation."""

    def __init__(self) -> None:
        """Start every counter at zero with no recorded errors."""
        self.compile_count = 0
        self.exec_count = 0
        self.compile_time = 0.0
        self.exec_time = 0.0
        # Each entry is a (file path, error message) pair for the summary report.
        self.errors: list[tuple[str, str]] = []

    def record_compile(self, elapsed: float) -> None:
        """Count one compile() call and accumulate its wall-clock duration."""
        self.compile_count = self.compile_count + 1
        self.compile_time = self.compile_time + elapsed

    def record_exec(self, elapsed: float) -> None:
        """Count one exec() call and accumulate its wall-clock duration."""
        self.exec_count = self.exec_count + 1
        self.exec_time = self.exec_time + elapsed

    def record_error(self, file_path: str, error: str) -> None:  # pragma: no cover
        """Remember a validation failure so it can be reported at session end."""
        self.errors.append((file_path, error))
45+
46+
47+
# Module-level singleton shared by the validation helpers and the
# pytest_terminal_summary hook in this conftest.
_validation_stats = CodeValidationStats()
48+
49+
50+
def pytest_terminal_summary(terminalreporter: Any, exitstatus: int, config: pytest.Config) -> None:  # noqa: ARG001 # pragma: no cover
    """Print code validation summary at the end of test run."""
    # Only emit the section when at least one generated file was compiled;
    # otherwise the summary would just be noise in unrelated runs.
    if _validation_stats.compile_count > 0:
        terminalreporter.write_sep("=", "Code Validation Summary")
        terminalreporter.write_line(
            f"Compiled {_validation_stats.compile_count} files in {_validation_stats.compile_time:.3f}s "
            f"(avg: {_validation_stats.compile_time / _validation_stats.compile_count * 1000:.2f}ms)"
        )
        # NOTE(review): the exec/error sub-sections are placed inside the
        # compile guard; in practice exec/error counts are only non-zero when
        # something compiled, but confirm this nesting against the original.
        if _validation_stats.exec_count > 0:
            terminalreporter.write_line(
                f"Executed {_validation_stats.exec_count} files in {_validation_stats.exec_time:.3f}s "
                f"(avg: {_validation_stats.exec_time / _validation_stats.exec_count * 1000:.2f}ms)"
            )
        if _validation_stats.errors:
            terminalreporter.write_line(f"\nValidation errors: {len(_validation_stats.errors)}")
            for file_path, error in _validation_stats.errors:
                terminalreporter.write_line(f" {file_path}: {error}")
67+
68+
1969
if sys.version_info >= (3, 10):
2070
from datetime import datetime, timezone
2171

@@ -358,3 +408,33 @@ def _preload_heavy_modules() -> None:
358408
import isort # noqa: PLC0415, F401
359409

360410
import datamodel_code_generator # noqa: PLC0415, F401
411+
412+
413+
def validate_generated_code(
    code: str,
    file_path: str,
    *,
    do_exec: bool = False,
) -> None:
    """Validate generated code by compiling and optionally executing it.

    Args:
        code: The generated Python code to validate.
        file_path: Path to the file (for error reporting).
        do_exec: Whether to execute the code after compiling (default: False).
    """
    try:
        # Time the compile step and feed it into the session-wide stats.
        t0 = time.perf_counter()
        code_object = compile(code, file_path, "exec")
        _validation_stats.record_compile(time.perf_counter() - t0)

        if do_exec:
            # Execute in a fresh, empty namespace so runs cannot leak state.
            t0 = time.perf_counter()
            exec(code_object, {})
            _validation_stats.record_exec(time.perf_counter() - t0)
    except SyntaxError as e:  # pragma: no cover
        # Record the failure for the terminal summary, then let pytest see it.
        _validation_stats.record_error(file_path, f"SyntaxError: {e}")
        raise
    except Exception as e:  # pragma: no cover
        _validation_stats.record_error(file_path, f"{type(e).__name__}: {e}")
        raise

tests/main/conftest.py

Lines changed: 129 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@
22

33
from __future__ import annotations
44

5+
import importlib.util
56
import inspect
67
import shutil
78
import sys
9+
import time
810
from argparse import Namespace
911
from collections.abc import Callable, Generator, Sequence
1012
from pathlib import Path
@@ -14,8 +16,17 @@
1416
import pytest
1517
from packaging import version
1618

19+
from datamodel_code_generator import DataModelType
1720
from datamodel_code_generator.__main__ import Exit, main
18-
from tests.conftest import AssertFileContent, assert_directory_content, assert_output, freeze_time
21+
from datamodel_code_generator.util import PYDANTIC_V2
22+
from tests.conftest import (
23+
AssertFileContent,
24+
_validation_stats,
25+
assert_directory_content,
26+
assert_output,
27+
freeze_time,
28+
validate_generated_code,
29+
)
1930

2031
InputFileTypeLiteral = Literal["auto", "openapi", "jsonschema", "json", "yaml", "dict", "csv", "graphql"]
2132
CopyFilesMapping = Sequence[tuple[Path, Path]]
@@ -210,6 +221,8 @@ def run_main_and_assert( # noqa: PLR0912
210221
# stdin options
211222
stdin_path: Path | None = None,
212223
monkeypatch: pytest.MonkeyPatch | None = None,
224+
# Code validation options
225+
skip_code_validation: bool = False,
213226
) -> None:
214227
"""Execute main() and assert output.
215228
@@ -335,6 +348,119 @@ def run_main_and_assert( # noqa: PLR0912
335348
expected_file = f"{func_name}.py"
336349
assert_func(output_path, expected_file, transform=transform)
337350

351+
if output_path is not None and not skip_code_validation:
352+
_validate_output_files(output_path, extra_args)
353+
354+
355+
def _get_argument_value(arguments: Sequence[str] | None, argument_name: str) -> str | None:
356+
"""Extract argument value from arguments."""
357+
if arguments is None:
358+
return None
359+
argument_list = list(arguments)
360+
for index, argument in enumerate(argument_list):
361+
if argument == argument_name and index + 1 < len(argument_list):
362+
return argument_list[index + 1]
363+
return None
364+
365+
366+
def _parse_target_version(extra_arguments: Sequence[str] | None) -> tuple[int, int] | None:
    """Return the ``--target-python-version`` value as an int tuple, or None.

    None means the flag was absent or its value was not dot-separated integers.
    """
    raw = _get_argument_value(extra_arguments, "--target-python-version")
    if raw is None:
        return None
    try:
        parts = [int(piece) for piece in raw.split(".")]
    except ValueError:  # pragma: no cover
        return None
    return tuple(parts)  # type: ignore[return-value]
374+
375+
376+
def _should_skip_compile(extra_arguments: Sequence[str] | None) -> bool:
    """Return True when the requested target version is newer than the runtime.

    Code generated for a newer Python than the interpreter running the tests
    may use syntax that cannot compile here, so validation must be skipped.
    """
    target = _parse_target_version(extra_arguments)
    return target is not None and target > sys.version_info[:2]
381+
382+
383+
def _should_skip_exec(extra_arguments: Sequence[str] | None) -> bool:
    """Check if exec should be skipped based on model type, pydantic version, and Python version."""
    output_model_type = _get_argument_value(extra_arguments, "--output-model-type")
    # No explicit model type defaults to pydantic v1 output.
    is_pydantic_v1 = output_model_type is None or output_model_type == DataModelType.PydanticBaseModel.value
    # Skip when the generated pydantic major version does not match the
    # installed pydantic version — the generated imports would fail.
    if (is_pydantic_v1 and PYDANTIC_V2) or (
        output_model_type == DataModelType.PydanticV2BaseModel.value and not PYDANTIC_V2
    ):
        return True
    # Only execute when the test pinned a target version...
    if (target_version := _parse_target_version(extra_arguments)) is None:
        return True
    # ...and that target exactly matches the running interpreter.
    if target_version != sys.version_info[:2]:
        return True
    # Custom base classes live outside the generated module; executing would
    # fail on the missing import, so skip in that case too.
    return _get_argument_value(extra_arguments, "--base-class") is not None
396+
397+
398+
def _validate_output_files(output_path: Path, extra_arguments: Sequence[str] | None = None) -> None:
    """Compile (and, when safe, execute/import) the generated Python output.

    A single ``.py`` file is compiled and optionally executed in place; a
    directory has every module compiled individually and is then imported as
    a package when execution is allowed.
    """
    if _should_skip_compile(extra_arguments):
        return
    should_exec = not _should_skip_exec(extra_arguments)
    if output_path.is_file() and output_path.suffix == ".py":
        source = output_path.read_text(encoding="utf-8")
        validate_generated_code(source, str(output_path), do_exec=should_exec)
        return
    if not output_path.is_dir():  # pragma: no cover
        return
    # Directory output: compile each module on its own (never exec one file of
    # a package in isolation), then validate imports at the package level.
    for python_file in output_path.rglob("*.py"):  # pragma: no cover
        validate_generated_code(python_file.read_text(encoding="utf-8"), str(python_file), do_exec=False)
    if should_exec:  # pragma: no cover
        _import_package(output_path)
410+
411+
412+
def _import_package(output_path: Path) -> None:  # pragma: no cover # noqa: PLR0912
    """Import generated packages to validate they can be loaded.

    Every package imported here is torn back out of ``sys.modules`` and
    ``sys.path`` afterwards so one test's generated code cannot leak into
    another test's import namespace.
    """
    # output_path may itself be the package root, or a directory that
    # contains one or more generated packages.
    if (output_path / "__init__.py").exists():
        packages = [(output_path.parent, output_path.name)]
    else:
        packages = [
            (output_path, directory.name)
            for directory in output_path.iterdir()
            if directory.is_dir() and (directory / "__init__.py").exists()
        ]
    if not packages:
        return

    imported_modules: list[str] = []
    start_time = time.perf_counter()
    try:
        for parent_directory, package_name in packages:
            package_path = parent_directory / package_name
            # Make the package resolvable for any absolute imports it performs.
            sys.path.insert(0, str(parent_directory))
            spec = importlib.util.spec_from_file_location(
                package_name, package_path / "__init__.py", submodule_search_locations=[str(package_path)]
            )
            if spec is None or spec.loader is None:
                continue
            module = importlib.util.module_from_spec(spec)
            # Register in sys.modules BEFORE exec_module so the package's own
            # relative/self imports resolve during execution.
            sys.modules[package_name] = module
            imported_modules.append(package_name)
            spec.loader.exec_module(module)

            # Import every submodule explicitly; importing the package alone
            # does not execute modules that __init__.py never references.
            for python_file in package_path.rglob("*.py"):
                if python_file.name == "__init__.py":
                    continue
                relative_path = python_file.relative_to(package_path)
                module_name = f"{package_name}.{'.'.join(relative_path.with_suffix('').parts)}"
                submodule_spec = importlib.util.spec_from_file_location(module_name, python_file)
                if submodule_spec is None or submodule_spec.loader is None:
                    continue
                submodule = importlib.util.module_from_spec(submodule_spec)
                sys.modules[module_name] = submodule
                imported_modules.append(module_name)
                submodule_spec.loader.exec_module(submodule)
        # One timing entry covers the whole import pass.
        _validation_stats.record_exec(time.perf_counter() - start_time)
    except Exception as exception:
        _validation_stats.record_error(str(output_path), f"{type(exception).__name__}: {exception}")
        raise
    finally:
        # Undo all interpreter-global mutations regardless of success.
        for parent_directory, _ in packages:
            if str(parent_directory) in sys.path:
                sys.path.remove(str(parent_directory))
        for module_name in imported_modules:
            sys.modules.pop(module_name, None)
463+
338464

339465
def run_main_url_and_assert(
340466
*,
@@ -361,3 +487,5 @@ def run_main_url_and_assert(
361487
return_code = _run_main_url(url, output_path, input_file_type, extra_args=extra_args)
362488
_assert_exit_code(return_code, Exit.OK, f"URL: {url}")
363489
assert_func(output_path, expected_file, transform=transform)
490+
491+
_validate_output_files(output_path, extra_args)

tests/main/graphql/test_main_graphql.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def test_main_graphql_simple_star_wars(output_model: str, expected_output: str,
3939
input_file_type="graphql",
4040
assert_func=assert_file_content,
4141
expected_file=expected_output,
42-
extra_args=["--output-model", output_model],
42+
extra_args=["--output-model-type", output_model],
4343
)
4444

4545

@@ -323,7 +323,7 @@ def test_main_graphql_dataclass_arguments(output_file: Path) -> None:
323323
assert_func=assert_file_content,
324324
expected_file="simple_star_wars_dataclass_arguments.py",
325325
extra_args=[
326-
"--output-model",
326+
"--output-model-type",
327327
"dataclasses.dataclass",
328328
"--dataclass-arguments",
329329
'{"slots": true, "order": true}',
@@ -347,7 +347,7 @@ def test_main_graphql_dataclass_arguments_with_pydantic(output_file: Path) -> No
347347
assert_func=assert_file_content,
348348
expected_file="simple_star_wars.py",
349349
extra_args=[
350-
"--output-model",
350+
"--output-model-type",
351351
"pydantic.BaseModel",
352352
"--dataclass-arguments",
353353
'{"slots": true, "order": true}',
@@ -372,7 +372,7 @@ def test_main_graphql_dataclass_frozen_keyword_only(output_file: Path) -> None:
372372
assert_func=assert_file_content,
373373
expected_file="simple_star_wars_dataclass_frozen_kw_only.py",
374374
extra_args=[
375-
"--output-model",
375+
"--output-model-type",
376376
"dataclasses.dataclass",
377377
"--frozen",
378378
"--keyword-only",

tests/main/jsonschema/test_main_jsonschema.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ def test_main_jsonschema_dataclass_arguments_with_pydantic(output_file: Path) ->
270270
assert_func=assert_file_content,
271271
expected_file="general.py",
272272
extra_args=[
273-
"--output-model",
273+
"--output-model-type",
274274
"pydantic.BaseModel",
275275
"--dataclass-arguments",
276276
'{"slots": true, "order": true}',
@@ -291,7 +291,7 @@ def test_main_jsonschema_dataclass_frozen_keyword_only(output_file: Path) -> Non
291291
assert_func=assert_file_content,
292292
expected_file="general_dataclass_frozen_kw_only.py",
293293
extra_args=[
294-
"--output-model",
294+
"--output-model-type",
295295
"dataclasses.dataclass",
296296
"--frozen",
297297
"--keyword-only",
@@ -403,7 +403,7 @@ def test_main_null_and_array(output_model: str, expected_output: str, output_fil
403403
input_file_type="jsonschema",
404404
assert_func=assert_file_content,
405405
expected_file=expected_output,
406-
extra_args=["--output-model", output_model],
406+
extra_args=["--output-model-type", output_model],
407407
)
408408

409409

@@ -443,7 +443,7 @@ def test_main_complicated_enum_default_member(
443443
output_model: str, expected_output: str, option: str | None, output_file: Path
444444
) -> None:
445445
"""Test complicated enum with default member."""
446-
extra_args = [a for a in [option, "--output-model", output_model] if a]
446+
extra_args = [a for a in [option, "--output-model-type", output_model] if a]
447447
run_main_and_assert(
448448
input_path=JSON_SCHEMA_DATA_PATH / "complicated_enum.json",
449449
output_path=output_file,
@@ -1672,7 +1672,7 @@ def test_main_jsonschema_combine_any_of_object(
16721672
union_mode: str | None, output_model: str, expected_output: str, output_file: Path
16731673
) -> None:
16741674
"""Test combining anyOf with objects."""
1675-
extra_args = ["--output-model", output_model]
1675+
extra_args = ["--output-model-type", output_model]
16761676
if union_mode is not None:
16771677
extra_args.extend(["--union-mode", union_mode])
16781678
run_main_and_assert(
@@ -1689,9 +1689,9 @@ def test_main_jsonschema_combine_any_of_object(
16891689
@pytest.mark.parametrize(
16901690
("extra_args", "expected_file"),
16911691
[
1692-
(["--output-model", "pydantic_v2.BaseModel"], "jsonschema_root_model_ordering.py"),
1692+
(["--output-model-type", "pydantic_v2.BaseModel"], "jsonschema_root_model_ordering.py"),
16931693
(
1694-
["--output-model", "pydantic_v2.BaseModel", "--keep-model-order"],
1694+
["--output-model-type", "pydantic_v2.BaseModel", "--keep-model-order"],
16951695
"jsonschema_root_model_ordering_keep_model_order.py",
16961696
),
16971697
],
@@ -1745,7 +1745,7 @@ def test_main_jsonschema_field_extras_field_include_all_keys(
17451745
assert_func=assert_file_content,
17461746
expected_file=expected_output,
17471747
extra_args=[
1748-
"--output-model",
1748+
"--output-model-type",
17491749
output_model,
17501750
"--field-include-all-keys",
17511751
"--field-extra-keys-without-x-prefix",
@@ -1778,7 +1778,7 @@ def test_main_jsonschema_field_extras_field_extra_keys(
17781778
assert_func=assert_file_content,
17791779
expected_file=expected_output,
17801780
extra_args=[
1781-
"--output-model",
1781+
"--output-model-type",
17821782
output_model,
17831783
"--field-extra-keys",
17841784
"key2",
@@ -1810,7 +1810,7 @@ def test_main_jsonschema_field_extras(output_model: str, expected_output: str, o
18101810
input_file_type="jsonschema",
18111811
assert_func=assert_file_content,
18121812
expected_file=expected_output,
1813-
extra_args=["--output-model", output_model],
1813+
extra_args=["--output-model-type", output_model],
18141814
)
18151815

18161816

0 commit comments

Comments
 (0)