|
20 | 20 | Import, |
21 | 21 | ) |
22 | 22 | from datamodel_code_generator.model import DataModel, DataModelFieldBase |
23 | | -from datamodel_code_generator.model.base import UNDEFINED |
| 23 | +from datamodel_code_generator.model.base import UNDEFINED, BaseClassDataType |
24 | 24 | from datamodel_code_generator.model.imports import ( |
25 | 25 | IMPORT_MSGSPEC_CONVERT, |
26 | 26 | IMPORT_MSGSPEC_FIELD, |
27 | 27 | IMPORT_MSGSPEC_META, |
| 28 | + IMPORT_MSGSPEC_STRUCT, |
28 | 29 | IMPORT_MSGSPEC_UNSET, |
29 | 30 | IMPORT_MSGSPEC_UNSETTYPE, |
30 | 31 | ) |
@@ -109,7 +110,18 @@ class Struct(DataModel): |
109 | 110 |
|
110 | 111 | TEMPLATE_FILE_PATH: ClassVar[str] = "msgspec.jinja2" |
111 | 112 | BASE_CLASS: ClassVar[str] = "msgspec.Struct" |
| 113 | + BASE_CLASS_NAME: ClassVar[str] = "Struct" |
| 114 | + BASE_CLASS_ALIAS: ClassVar[str] = "_Struct" |
112 | 115 | DEFAULT_IMPORTS: ClassVar[tuple[Import, ...]] = () |
| 116 | + CONFIG_MAPPING: ClassVar[dict[tuple[str, Any], tuple[str, Any] | None]] = { |
| 117 | + ("allow_mutation", False): ("frozen", True), |
| 118 | + ("extra_fields", "forbid"): ("forbid_unknown_fields", True), |
| 119 | + ("extra_fields", "allow"): None, |
| 120 | + ("extra_fields", "ignore"): None, |
| 121 | + ("allow_extra_fields", True): None, |
| 122 | + ("allow_population_by_field_name", True): None, |
| 123 | + ("use_attribute_docstrings", True): None, |
| 124 | + } |
113 | 125 |
|
114 | 126 | def __init__( # noqa: PLR0913 |
115 | 127 | self, |
@@ -154,6 +166,46 @@ def add_base_class_kwarg(self, name: str, value: str) -> None: |
154 | 166 | """Add keyword argument to base class constructor.""" |
155 | 167 | self.extra_template_data["base_class_kwargs"][name] = value |
156 | 168 |
|
| 169 | + @classmethod |
| 170 | + def create_base_class_model( |
| 171 | + cls, |
| 172 | + config: dict[str, Any], |
| 173 | + reference: Reference, |
| 174 | + custom_template_dir: Path | None = None, |
| 175 | + keyword_only: bool = False, # noqa: FBT001, FBT002 |
| 176 | + treat_dot_as_module: bool = False, # noqa: FBT001, FBT002 |
| 177 | + ) -> Struct | None: |
| 178 | + """Create a shared base class model for DRY configuration. |
| 179 | +
|
| 180 | + Creates a Struct that inherits from msgspec.Struct (aliased as _Struct) |
| 181 | + with the specified configuration. Updates the reference path and name in place. |
| 182 | + """ |
| 183 | + reference.path = f"#/{cls.BASE_CLASS_NAME}" |
| 184 | + reference.name = cls.BASE_CLASS_NAME |
| 185 | + |
| 186 | + base_model = cls( |
| 187 | + reference=reference, |
| 188 | + fields=[], |
| 189 | + custom_template_dir=custom_template_dir, |
| 190 | + keyword_only=keyword_only, |
| 191 | + treat_dot_as_module=treat_dot_as_module, |
| 192 | + ) |
| 193 | + |
| 194 | + base_model.base_classes = [BaseClassDataType(type=cls.BASE_CLASS_ALIAS)] |
| 195 | + |
| 196 | + for key, value in config.items(): |
| 197 | + mapping_result = cls.CONFIG_MAPPING.get((key, value)) |
| 198 | + if mapping_result is None: |
| 199 | + continue |
| 200 | + mapped_key, mapped_value = mapping_result |
| 201 | + base_model.add_base_class_kwarg(mapped_key, str(mapped_value)) |
| 202 | + |
| 203 | + base_model._additional_imports.append( |
| 204 | + Import(from_=IMPORT_MSGSPEC_STRUCT.from_, import_=IMPORT_MSGSPEC_STRUCT.import_, alias=cls.BASE_CLASS_ALIAS) |
| 205 | + ) |
| 206 | + |
| 207 | + return base_model |
| 208 | + |
157 | 209 |
|
158 | 210 | class Constraints(_Constraints): |
159 | 211 | """Constraint model for msgspec fields.""" |
|
0 commit comments