Skip to content

Commit 5bd8eab

Browse files
Add ryaml as optional YAML backend for faster parsing (#3055)
* Add ryaml as optional YAML backend for faster parsing

* docs: update llms.txt files

  Generated by GitHub Actions

* Bump ryaml minimum version to 0.5.1

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent af51cd7 commit 5bd8eab

10 files changed

Lines changed: 225 additions & 31 deletions

File tree

docs/cli-reference/field-customization.md

Lines changed: 2 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -3678,8 +3678,7 @@ This is useful when schemas have descriptive titles that should be preserved.
36783678

36793679
class ProcessingTaskTitle(BaseModel):
36803680
processing_status_union: ProcessingStatusUnionTitle | None = Field(
3681-
default_factory=lambda: ProcessingStatusUnionTitle('COMPLETED'),
3682-
title='Processing Status Union Title',
3681+
'COMPLETED', title='Processing Status Union Title', validate_default=True
36833682
)
36843683
processing_status: ProcessingStatusTitle | None = 'COMPLETED'
36853684
name: str | None = None
@@ -3706,10 +3705,7 @@ This is useful when schemas have descriptive titles that should be preserved.
37063705
RootModel[ProcessingStatusDetail | ExtendedProcessingTask | ProcessingStatusTitle]
37073706
):
37083707
root: ProcessingStatusDetail | ExtendedProcessingTask | ProcessingStatusTitle = (
3709-
Field(
3710-
default_factory=lambda: ExtendedProcessingTask('COMPLETED'),
3711-
title='Processing Status Union Title',
3712-
)
3708+
Field('COMPLETED', title='Processing Status Union Title', validate_default=True)
37133709
)
37143710

37153711

docs/cli-reference/model-customization.md

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6014,9 +6014,7 @@ The `--use-one-literal-as-default` flag configures the code generation behavior.
60146014

60156015
class NestedNullableEnum(BaseModel):
60166016
nested_version: NestedVersion | None = Field(
6017-
default_factory=lambda: NestedVersion('RC1'),
6018-
description='nullable enum',
6019-
examples=['RC2'],
6017+
'RC1', description='nullable enum', examples=['RC2'], validate_default=True
60206018
)
60216019

60226020

docs/cli-reference/template-customization.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2892,7 +2892,7 @@ helps maintain consistency with codebases that prefer double-quote formatting.
28922892
class MapState2(BaseModel):
28932893
latitude: Latitude
28942894
longitude: Longitude
2895-
zoom: Zoom | None = Field(default_factory=lambda: Zoom(0))
2895+
zoom: Zoom | None = Field(0, validate_default=True)
28962896
bearing: Bearing | None = None
28972897
pitch: Pitch
28982898
drag_rotate: DragRotate | None = Field(None, alias="dragRotate")

docs/cli-reference/typing-customization.md

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1350,9 +1350,7 @@ of Enum classes for all enumerations.
13501350
13511351
class NestedNullableEnum(BaseModel):
13521352
nested_version: NestedVersion | None = Field(
1353-
default_factory=lambda: NestedVersion('RC1'),
1354-
description='nullable enum',
1355-
examples=['RC2'],
1353+
'RC1', description='nullable enum', examples=['RC2'], validate_default=True
13561354
)
13571355
13581356

docs/llms-full.txt

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7242,9 +7242,7 @@ The `--use-one-literal-as-default` flag configures the code generation behavior.
72427242

72437243
class NestedNullableEnum(BaseModel):
72447244
nested_version: NestedVersion | None = Field(
7245-
default_factory=lambda: NestedVersion('RC1'),
7246-
description='nullable enum',
7247-
examples=['RC2'],
7245+
'RC1', description='nullable enum', examples=['RC2'], validate_default=True
72487246
)
72497247

72507248

@@ -11141,8 +11139,7 @@ This is useful when schemas have descriptive titles that should be preserved.
1114111139

1114211140
class ProcessingTaskTitle(BaseModel):
1114311141
processing_status_union: ProcessingStatusUnionTitle | None = Field(
11144-
default_factory=lambda: ProcessingStatusUnionTitle('COMPLETED'),
11145-
title='Processing Status Union Title',
11142+
'COMPLETED', title='Processing Status Union Title', validate_default=True
1114611143
)
1114711144
processing_status: ProcessingStatusTitle | None = 'COMPLETED'
1114811145
name: str | None = None
@@ -11169,10 +11166,7 @@ This is useful when schemas have descriptive titles that should be preserved.
1116911166
RootModel[ProcessingStatusDetail | ExtendedProcessingTask | ProcessingStatusTitle]
1117011167
):
1117111168
root: ProcessingStatusDetail | ExtendedProcessingTask | ProcessingStatusTitle = (
11172-
Field(
11173-
default_factory=lambda: ExtendedProcessingTask('COMPLETED'),
11174-
title='Processing Status Union Title',
11175-
)
11169+
Field('COMPLETED', title='Processing Status Union Title', validate_default=True)
1117611170
)
1117711171

1117811172

@@ -12537,9 +12531,7 @@ of Enum classes for all enumerations.
1253712531

1253812532
class NestedNullableEnum(BaseModel):
1253912533
nested_version: NestedVersion | None = Field(
12540-
default_factory=lambda: NestedVersion('RC1'),
12541-
description='nullable enum',
12542-
examples=['RC2'],
12534+
'RC1', description='nullable enum', examples=['RC2'], validate_default=True
1254312535
)
1254412536

1254512537

@@ -19083,7 +19075,7 @@ helps maintain consistency with codebases that prefer double-quote formatting.
1908319075
class MapState2(BaseModel):
1908419076
latitude: Latitude
1908519077
longitude: Longitude
19086-
zoom: Zoom | None = Field(default_factory=lambda: Zoom(0))
19078+
zoom: Zoom | None = Field(0, validate_default=True)
1908719079
bearing: Bearing | None = None
1908819080
pitch: Pitch
1908919081
drag_rotate: DragRotate | None = Field(None, alias="dragRotate")

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@ optional-dependencies.http = [
6060
optional-dependencies.ruff = [
6161
"ruff>=0.9.10",
6262
]
63+
optional-dependencies.ryaml = [
64+
"ryaml>=0.5.1",
65+
]
6366
optional-dependencies.validation = [
6467
"openapi-spec-validator>=0.2.8,<0.8",
6568
"prance>=0.18.2",

src/datamodel_code_generator/__init__.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,15 @@
9797

9898

9999
def load_yaml(stream: str | TextIO) -> YamlValue:
100-
"""Load YAML content from a string or file-like object."""
100+
"""Load YAML content using ryaml (if available) or PyYAML."""
101+
from datamodel_code_generator.util import get_yaml_backend # noqa: PLC0415
102+
103+
if get_yaml_backend() == "ryaml":
104+
import ryaml # noqa: PLC0415 # ty: ignore[unresolved-import]
105+
106+
text = stream if isinstance(stream, str) else stream.read()
107+
return ryaml.loads(text)
108+
101109
import yaml # noqa: PLC0415
102110

103111
from datamodel_code_generator.util import SafeLoader # noqa: PLC0415
@@ -933,11 +941,11 @@ def get_header_and_first_line(csv_file: IO[str]) -> dict[str, Any]:
933941

934942
def infer_input_type(text: str) -> InputFileType:
935943
"""Automatically detect the input file type from text content."""
936-
import yaml.parser # noqa: PLC0415
944+
from datamodel_code_generator.util import get_yaml_parse_errors # noqa: PLC0415
937945

938946
try:
939947
data = load_yaml(text)
940-
except yaml.parser.ParserError:
948+
except get_yaml_parse_errors():
941949
return InputFileType.CSV
942950
if isinstance(data, dict):
943951
if is_openapi(data):

src/datamodel_code_generator/util.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import re
66
import warnings
77
from functools import lru_cache
8-
from typing import TYPE_CHECKING, Any
8+
from typing import TYPE_CHECKING, Any, Literal
99

1010
if TYPE_CHECKING:
1111
from collections.abc import Callable
@@ -83,6 +83,35 @@ class CustomSafeLoader(_SafeLoader): # type: ignore[valid-type,misc]
8383
return CustomSafeLoader
8484

8585

86+
YamlBackend = Literal["ryaml", "pyyaml"]
87+
88+
89+
@lru_cache(maxsize=1)
90+
def get_yaml_backend() -> YamlBackend:
91+
"""Detect the available YAML backend ('ryaml' or 'pyyaml')."""
92+
try:
93+
import ryaml # noqa: PLC0415, F401 # ty: ignore[unresolved-import]
94+
except ImportError:
95+
return "pyyaml"
96+
else:
97+
return "ryaml"
98+
99+
100+
@lru_cache(maxsize=1)
101+
def get_yaml_parse_errors() -> tuple[type[Exception], ...]:
102+
"""Return YAML parse error types for both backends."""
103+
import yaml # noqa: PLC0415
104+
105+
errors: list[type[Exception]] = [yaml.YAMLError]
106+
try:
107+
import ryaml # noqa: PLC0415 # ty: ignore[unresolved-import]
108+
109+
errors.append(ryaml.InvalidYamlError)
110+
except ImportError:
111+
pass
112+
return tuple(errors)
113+
114+
86115
@lru_cache(maxsize=1)
87116
def _get_base_model_class() -> type:
88117
"""Get BaseModel class with strict=False config lazily."""

tests/test_yaml_backend.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
"""Tests for YAML backend detection and ryaml/PyYAML switching."""
2+
3+
from __future__ import annotations
4+
5+
import io
6+
from typing import TYPE_CHECKING
7+
from unittest.mock import MagicMock, patch
8+
9+
import pytest
10+
import yaml
11+
12+
from datamodel_code_generator import InputFileType, infer_input_type, load_yaml
13+
from datamodel_code_generator.util import get_yaml_backend, get_yaml_parse_errors
14+
15+
if TYPE_CHECKING:
16+
from collections.abc import Iterator
17+
18+
19+
@pytest.fixture(autouse=True)
20+
def _clear_caches() -> Iterator[None]:
21+
"""Clear lru_cache before and after each test."""
22+
get_yaml_backend.cache_clear()
23+
get_yaml_parse_errors.cache_clear()
24+
yield
25+
get_yaml_backend.cache_clear()
26+
get_yaml_parse_errors.cache_clear()
27+
28+
29+
class TestGetYamlBackend:
30+
"""Tests for get_yaml_backend()."""
31+
32+
def test_without_ryaml(self) -> None:
33+
"""When ryaml is not importable, returns 'pyyaml'."""
34+
with patch.dict("sys.modules", {"ryaml": None}):
35+
assert get_yaml_backend() == "pyyaml"
36+
37+
def test_with_ryaml(self) -> None:
38+
"""When ryaml is importable, returns 'ryaml'."""
39+
mock_ryaml = MagicMock()
40+
with patch.dict("sys.modules", {"ryaml": mock_ryaml}):
41+
assert get_yaml_backend() == "ryaml"
42+
43+
44+
class TestGetYamlParseErrors:
45+
"""Tests for get_yaml_parse_errors()."""
46+
47+
def test_pyyaml_only(self) -> None:
48+
"""Without ryaml, only yaml.YAMLError is returned."""
49+
with patch.dict("sys.modules", {"ryaml": None}):
50+
errors = get_yaml_parse_errors()
51+
assert errors == (yaml.YAMLError,)
52+
53+
def test_includes_ryaml(self) -> None:
54+
"""With ryaml, InvalidYamlError is included."""
55+
mock_ryaml = MagicMock()
56+
mock_ryaml.InvalidYamlError = type("InvalidYamlError", (Exception,), {})
57+
with patch.dict("sys.modules", {"ryaml": mock_ryaml}):
58+
errors = get_yaml_parse_errors()
59+
assert yaml.YAMLError in errors
60+
assert mock_ryaml.InvalidYamlError in errors
61+
assert len(errors) == 2
62+
63+
64+
class TestLoadYaml:
65+
"""Tests for load_yaml() with backend switching."""
66+
67+
def test_pyyaml_fallback_string(self) -> None:
68+
"""When ryaml is unavailable, PyYAML is used for string input."""
69+
with patch.dict("sys.modules", {"ryaml": None}):
70+
result = load_yaml("key: value")
71+
assert result == {"key": "value"}
72+
73+
def test_pyyaml_fallback_textio(self) -> None:
74+
"""When ryaml is unavailable, PyYAML is used for TextIO input."""
75+
with patch.dict("sys.modules", {"ryaml": None}):
76+
result = load_yaml(io.StringIO("key: value"))
77+
assert result == {"key": "value"}
78+
79+
def test_with_ryaml_string(self) -> None:
80+
"""When ryaml is available, ryaml.loads() is used for string input."""
81+
mock_ryaml = MagicMock()
82+
mock_ryaml.loads.return_value = {"key": "value"}
83+
with patch.dict("sys.modules", {"ryaml": mock_ryaml}):
84+
result = load_yaml("key: value")
85+
mock_ryaml.loads.assert_called_once_with("key: value")
86+
assert result == {"key": "value"}
87+
88+
def test_with_ryaml_textio(self) -> None:
89+
"""When ryaml is available, TextIO.read() is called before ryaml.loads()."""
90+
mock_ryaml = MagicMock()
91+
mock_ryaml.loads.return_value = {"key": "value"}
92+
stream = io.StringIO("key: value")
93+
with patch.dict("sys.modules", {"ryaml": mock_ryaml}):
94+
result = load_yaml(stream)
95+
mock_ryaml.loads.assert_called_once_with("key: value")
96+
assert result == {"key": "value"}
97+
98+
99+
class TestInferInputType:
100+
"""Tests for infer_input_type() with backend error handling."""
101+
102+
def test_csv_with_pyyaml_error(self) -> None:
103+
"""YAML parse error from PyYAML returns CSV type."""
104+
with patch.dict("sys.modules", {"ryaml": None}):
105+
result = infer_input_type("a,b,c\n1,2,3\n::")
106+
assert result == InputFileType.CSV
107+
108+
def test_csv_with_ryaml_error(self) -> None:
109+
"""YAML parse error from ryaml returns CSV type."""
110+
mock_invalid_yaml_error = type("InvalidYamlError", (Exception,), {})
111+
mock_ryaml = MagicMock()
112+
mock_ryaml.InvalidYamlError = mock_invalid_yaml_error
113+
mock_ryaml.loads.side_effect = mock_invalid_yaml_error("parse error")
114+
with patch.dict("sys.modules", {"ryaml": mock_ryaml}):
115+
result = infer_input_type(":::invalid yaml:::")
116+
assert result == InputFileType.CSV
117+
118+
def test_openapi_detection(self) -> None:
119+
"""OpenAPI input is detected correctly regardless of backend."""
120+
with patch.dict("sys.modules", {"ryaml": None}):
121+
result = infer_input_type("openapi: '3.0.0'\ninfo:\n title: Test\n version: '1.0'")
122+
assert result == InputFileType.OpenAPI

0 commit comments

Comments
 (0)