22
33from __future__ import annotations
44
5+ import importlib .util
56import inspect
67import shutil
78import sys
9+ import time
810from argparse import Namespace
911from collections .abc import Callable , Generator , Sequence
1012from pathlib import Path
1416import pytest
1517from packaging import version
1618
19+ from datamodel_code_generator import DataModelType
1720from datamodel_code_generator .__main__ import Exit , main
18- from tests .conftest import AssertFileContent , assert_directory_content , assert_output , freeze_time
21+ from datamodel_code_generator .util import PYDANTIC_V2
22+ from tests .conftest import (
23+ AssertFileContent ,
24+ _validation_stats ,
25+ assert_directory_content ,
26+ assert_output ,
27+ freeze_time ,
28+ validate_generated_code ,
29+ )
1930
2031InputFileTypeLiteral = Literal ["auto" , "openapi" , "jsonschema" , "json" , "yaml" , "dict" , "csv" , "graphql" ]
2132CopyFilesMapping = Sequence [tuple [Path , Path ]]
@@ -210,6 +221,8 @@ def run_main_and_assert( # noqa: PLR0912
210221 # stdin options
211222 stdin_path : Path | None = None ,
212223 monkeypatch : pytest .MonkeyPatch | None = None ,
224+ # Code validation options
225+ skip_code_validation : bool = False ,
213226) -> None :
214227 """Execute main() and assert output.
215228
@@ -335,6 +348,119 @@ def run_main_and_assert( # noqa: PLR0912
335348 expected_file = f"{ func_name } .py"
336349 assert_func (output_path , expected_file , transform = transform )
337350
351+ if output_path is not None and not skip_code_validation :
352+ _validate_output_files (output_path , extra_args )
353+
354+
355+ def _get_argument_value (arguments : Sequence [str ] | None , argument_name : str ) -> str | None :
356+ """Extract argument value from arguments."""
357+ if arguments is None :
358+ return None
359+ argument_list = list (arguments )
360+ for index , argument in enumerate (argument_list ):
361+ if argument == argument_name and index + 1 < len (argument_list ):
362+ return argument_list [index + 1 ]
363+ return None
364+
365+
366+ def _parse_target_version (extra_arguments : Sequence [str ] | None ) -> tuple [int , int ] | None :
367+ """Parse target Python version from arguments."""
368+ if (target_version := _get_argument_value (extra_arguments , "--target-python-version" )) is None :
369+ return None
370+ try :
371+ return tuple (int (part ) for part in target_version .split ("." )) # type: ignore[return-value]
372+ except ValueError : # pragma: no cover
373+ return None
374+
375+
376+ def _should_skip_compile (extra_arguments : Sequence [str ] | None ) -> bool :
377+ """Check if compile should be skipped when target version > runtime version."""
378+ if (target_version := _parse_target_version (extra_arguments )) is None :
379+ return False
380+ return target_version > sys .version_info [:2 ]
381+
382+
383+ def _should_skip_exec (extra_arguments : Sequence [str ] | None ) -> bool :
384+ """Check if exec should be skipped based on model type, pydantic version, and Python version."""
385+ output_model_type = _get_argument_value (extra_arguments , "--output-model-type" )
386+ is_pydantic_v1 = output_model_type is None or output_model_type == DataModelType .PydanticBaseModel .value
387+ if (is_pydantic_v1 and PYDANTIC_V2 ) or (
388+ output_model_type == DataModelType .PydanticV2BaseModel .value and not PYDANTIC_V2
389+ ):
390+ return True
391+ if (target_version := _parse_target_version (extra_arguments )) is None :
392+ return True
393+ if target_version != sys .version_info [:2 ]:
394+ return True
395+ return _get_argument_value (extra_arguments , "--base-class" ) is not None
396+
397+
398+ def _validate_output_files (output_path : Path , extra_arguments : Sequence [str ] | None = None ) -> None :
399+ """Validate generated Python files by compiling/executing them."""
400+ if _should_skip_compile (extra_arguments ):
401+ return
402+ should_exec = not _should_skip_exec (extra_arguments )
403+ if output_path .is_file () and output_path .suffix == ".py" :
404+ validate_generated_code (output_path .read_text (encoding = "utf-8" ), str (output_path ), do_exec = should_exec )
405+ elif output_path .is_dir (): # pragma: no cover
406+ for python_file in output_path .rglob ("*.py" ):
407+ validate_generated_code (python_file .read_text (encoding = "utf-8" ), str (python_file ), do_exec = False )
408+ if should_exec : # pragma: no cover
409+ _import_package (output_path )
410+
411+
412+ def _import_package (output_path : Path ) -> None : # pragma: no cover # noqa: PLR0912
413+ """Import generated packages to validate they can be loaded."""
414+ if (output_path / "__init__.py" ).exists ():
415+ packages = [(output_path .parent , output_path .name )]
416+ else :
417+ packages = [
418+ (output_path , directory .name )
419+ for directory in output_path .iterdir ()
420+ if directory .is_dir () and (directory / "__init__.py" ).exists ()
421+ ]
422+ if not packages :
423+ return
424+
425+ imported_modules : list [str ] = []
426+ start_time = time .perf_counter ()
427+ try :
428+ for parent_directory , package_name in packages :
429+ package_path = parent_directory / package_name
430+ sys .path .insert (0 , str (parent_directory ))
431+ spec = importlib .util .spec_from_file_location (
432+ package_name , package_path / "__init__.py" , submodule_search_locations = [str (package_path )]
433+ )
434+ if spec is None or spec .loader is None :
435+ continue
436+ module = importlib .util .module_from_spec (spec )
437+ sys .modules [package_name ] = module
438+ imported_modules .append (package_name )
439+ spec .loader .exec_module (module )
440+
441+ for python_file in package_path .rglob ("*.py" ):
442+ if python_file .name == "__init__.py" :
443+ continue
444+ relative_path = python_file .relative_to (package_path )
445+ module_name = f"{ package_name } .{ '.' .join (relative_path .with_suffix ('' ).parts )} "
446+ submodule_spec = importlib .util .spec_from_file_location (module_name , python_file )
447+ if submodule_spec is None or submodule_spec .loader is None :
448+ continue
449+ submodule = importlib .util .module_from_spec (submodule_spec )
450+ sys .modules [module_name ] = submodule
451+ imported_modules .append (module_name )
452+ submodule_spec .loader .exec_module (submodule )
453+ _validation_stats .record_exec (time .perf_counter () - start_time )
454+ except Exception as exception :
455+ _validation_stats .record_error (str (output_path ), f"{ type (exception ).__name__ } : { exception } " )
456+ raise
457+ finally :
458+ for parent_directory , _ in packages :
459+ if str (parent_directory ) in sys .path :
460+ sys .path .remove (str (parent_directory ))
461+ for module_name in imported_modules :
462+ sys .modules .pop (module_name , None )
463+
338464
339465def run_main_url_and_assert (
340466 * ,
@@ -361,3 +487,5 @@ def run_main_url_and_assert(
361487 return_code = _run_main_url (url , output_path , input_file_type , extra_args = extra_args )
362488 _assert_exit_code (return_code , Exit .OK , f"URL: { url } " )
363489 assert_func (output_path , expected_file , transform = transform )
490+
491+ _validate_output_files (output_path , extra_args )
0 commit comments