|
10 | 10 | from argparse import Namespace |
11 | 11 | from collections.abc import Callable, Generator, Sequence |
12 | 12 | from pathlib import Path |
13 | | -from typing import Literal |
| 13 | +from typing import Any, Literal |
14 | 14 |
|
15 | 15 | import black |
16 | 16 | import pytest |
17 | 17 | from packaging import version |
18 | 18 |
|
| 19 | +from datamodel_code_generator import InputFileType, generate |
19 | 20 | from datamodel_code_generator.__main__ import Exit, main |
20 | 21 | from datamodel_code_generator.arguments import arg_parser |
21 | 22 | from tests.conftest import ( |
@@ -240,6 +241,53 @@ def run_main_with_args( |
240 | 241 | return return_code |
241 | 242 |
|
242 | 243 |
|
| 244 | +def run_generate_file_and_assert( |
| 245 | + *, |
| 246 | + input_path: Path, |
| 247 | + output_path: Path, |
| 248 | + input_file_type: InputFileType | None = None, |
| 249 | + assert_func: AssertFileContent, |
| 250 | + expected_file: str | Path | None = None, |
| 251 | + transform: Callable[[str], str] | None = None, |
| 252 | + **generate_kwargs: Any, |
| 253 | +) -> None: |
| 254 | + """Execute generate() for a file input and assert the generated output.""" |
| 255 | + __tracebackhide__ = True |
| 256 | + |
| 257 | + input_: Path = input_path |
| 258 | + if input_path.is_absolute(): |
| 259 | + try: |
| 260 | + input_ = input_path.relative_to(Path.cwd()) |
| 261 | + except ValueError: |
| 262 | + input_ = input_path |
| 263 | + else: |
| 264 | + assert not input_.is_absolute() |
| 265 | + |
| 266 | + generate_options: dict[str, Any] = { |
| 267 | + "output": output_path, |
| 268 | + **generate_kwargs, |
| 269 | + } |
| 270 | + if input_file_type is not None: |
| 271 | + generate_options["input_file_type"] = input_file_type |
| 272 | + |
| 273 | + generate( |
| 274 | + input_=input_, |
| 275 | + **generate_options, |
| 276 | + ) |
| 277 | + |
| 278 | + if expected_file is None: |
| 279 | + frame = inspect.currentframe() |
| 280 | + assert frame is not None |
| 281 | + assert frame.f_back is not None |
| 282 | + func_name = frame.f_back.f_code.co_name |
| 283 | + del frame |
| 284 | + for prefix in ("test_main_", "test_"): |
| 285 | + func_name = func_name.removeprefix(prefix) |
| 286 | + expected_file = f"{func_name}.py" |
| 287 | + |
| 288 | + assert_func(output_path, expected_file, transform=transform) |
| 289 | + |
| 290 | + |
243 | 291 | def run_main_and_assert( # noqa: PLR0912 |
244 | 292 | *, |
245 | 293 | input_path: Path | None = None, |
|
0 commit comments