Skip to content

Commit 306b1d4

Browse files
committed
Add --use-default-factory-for-optional-nested-models option
1 parent 4423a49 commit 306b1d4

21 files changed

Lines changed: 353 additions & 1 deletion

src/datamodel_code_generator/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -470,6 +470,7 @@ def generate( # noqa: PLR0912, PLR0913, PLR0914, PLR0915
470470
frozen_dataclasses: bool = False,
471471
no_alias: bool = False,
472472
use_frozen_field: bool = False,
473+
use_default_factory_for_optional_nested_models: bool = False,
473474
formatters: list[Formatter] = DEFAULT_FORMATTERS,
474475
settings_path: Path | None = None,
475476
parent_scoped_naming: bool = False,
@@ -717,6 +718,7 @@ def get_header_and_first_line(csv_file: IO[str]) -> dict[str, Any]:
717718
frozen_dataclasses=frozen_dataclasses,
718719
no_alias=no_alias,
719720
use_frozen_field=use_frozen_field,
721+
use_default_factory_for_optional_nested_models=use_default_factory_for_optional_nested_models,
720722
formatters=formatters,
721723
encoding=encoding,
722724
parent_scoped_naming=parent_scoped_naming,

src/datamodel_code_generator/__main__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -459,6 +459,7 @@ def validate_all_exports_collision_strategy(cls, values: dict[str, Any]) -> dict
459459
dataclass_arguments: Optional[DataclassArguments] = None # noqa: UP045
460460
no_alias: bool = False
461461
use_frozen_field: bool = False
462+
use_default_factory_for_optional_nested_models: bool = False
462463
formatters: list[Formatter] = DEFAULT_FORMATTERS
463464
parent_scoped_naming: bool = False
464465
disable_future_imports: bool = False
@@ -761,6 +762,7 @@ def run_generate_from_config( # noqa: PLR0913, PLR0917
761762
frozen_dataclasses=config.frozen_dataclasses,
762763
no_alias=config.no_alias,
763764
use_frozen_field=config.use_frozen_field,
765+
use_default_factory_for_optional_nested_models=config.use_default_factory_for_optional_nested_models,
764766
formatters=config.formatters,
765767
settings_path=settings_path,
766768
parent_scoped_naming=config.parent_scoped_naming,

src/datamodel_code_generator/arguments.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -594,6 +594,13 @@ def start_section(self, heading: str | None) -> None:
594594
action="store_true",
595595
default=None,
596596
)
597+
field_options.add_argument(
598+
"--use-default-factory-for-optional-nested-models",
599+
help="Use default_factory for optional nested model fields instead of None default. "
600+
"E.g., `field: Optional[Model] = Field(default_factory=Model)` instead of `field: Optional[Model] = None`",
601+
action="store_true",
602+
default=None,
603+
)
597604

598605
# ======================================================================================
599606
# Options for templating output

src/datamodel_code_generator/cli_options.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,9 @@ class CLIOptionMeta:
9292
"--strip-default-none": CLIOptionMeta(name="--strip-default-none", category=OptionCategory.MODEL),
9393
"--dataclass-arguments": CLIOptionMeta(name="--dataclass-arguments", category=OptionCategory.MODEL),
9494
"--use-frozen-field": CLIOptionMeta(name="--use-frozen-field", category=OptionCategory.MODEL),
95+
"--use-default-factory-for-optional-nested-models": CLIOptionMeta(
96+
name="--use-default-factory-for-optional-nested-models", category=OptionCategory.MODEL
97+
),
9598
"--union-mode": CLIOptionMeta(name="--union-mode", category=OptionCategory.MODEL),
9699
"--parent-scoped-naming": CLIOptionMeta(name="--parent-scoped-naming", category=OptionCategory.MODEL),
97100
"--use-one-literal-as-default": CLIOptionMeta(name="--use-one-literal-as-default", category=OptionCategory.MODEL),

src/datamodel_code_generator/model/base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ class Config:
166166
read_only: bool = False
167167
write_only: bool = False
168168
use_frozen_field: bool = False
169+
use_default_factory_for_optional_nested_models: bool = False
169170

170171
if not TYPE_CHECKING:
171172
if not PYDANTIC_V2:

src/datamodel_code_generator/model/dataclass.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,19 @@ def field(self) -> str | None:
143143
return None
144144
return result
145145

146+
def _get_default_factory_for_nested_model(self) -> str | None:
147+
"""Get default_factory for nested dataclass model fields.
148+
149+
Returns the class name if the field type references a DataClass,
150+
otherwise returns None.
151+
"""
152+
for data_type in self.data_type.data_types or (self.data_type,):
153+
if data_type.is_dict:
154+
continue
155+
if data_type.reference and isinstance(data_type.reference.source, DataClass):
156+
return data_type.alias or data_type.reference.source.class_name
157+
return None
158+
146159
def __str__(self) -> str:
147160
"""Generate field() call or default value representation."""
148161
data: dict[str, Any] = {k: v for k, v in self.extras.items() if k in self._FIELD_KEYS}
@@ -161,6 +174,17 @@ def __str__(self) -> str:
161174
}
162175
}
163176

177+
# Handle default_factory for optional nested models
178+
if (
179+
self.use_default_factory_for_optional_nested_models
180+
and not self.required
181+
and (self.default is None or self.default is UNDEFINED)
182+
and "default_factory" not in data
183+
):
184+
nested_model_name = self._get_default_factory_for_nested_model()
185+
if nested_model_name:
186+
data["default_factory"] = nested_model_name
187+
164188
if not data:
165189
return ""
166190

src/datamodel_code_generator/model/msgspec.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ def field(self) -> str | None:
247247
return None
248248
return result
249249

250-
def __str__(self) -> str:
250+
def __str__(self) -> str: # noqa: PLR0912
251251
"""Generate field() call or default value representation."""
252252
data: dict[str, Any] = {k: v for k, v in self.extras.items() if k in self._FIELD_KEYS}
253253
if self.alias:
@@ -284,6 +284,18 @@ def __str__(self) -> str:
284284
else:
285285
data["default_factory"] = type(default_value).__name__
286286

287+
# Handle default_factory for optional nested models
288+
if (
289+
self.use_default_factory_for_optional_nested_models
290+
and not self.required
291+
and (self.default is None or self.default is UNDEFINED)
292+
and "default_factory" not in data
293+
):
294+
nested_model_name = self._get_default_factory_for_optional_nested_model()
295+
if nested_model_name:
296+
data["default_factory"] = nested_model_name
297+
data.pop("default", None)
298+
287299
if not data:
288300
return ""
289301

@@ -412,6 +424,19 @@ def _get_default_as_struct_model(self) -> str | None:
412424
)
413425
return None
414426

427+
def _get_default_factory_for_optional_nested_model(self) -> str | None:
428+
"""Get default_factory for optional nested Struct model fields.
429+
430+
Returns the class name if the field type references a Struct,
431+
otherwise returns None.
432+
"""
433+
for data_type in self.data_type.data_types or (self.data_type,):
434+
if data_type.is_dict:
435+
continue
436+
if data_type.reference and isinstance(data_type.reference.source, Struct):
437+
return data_type.alias or data_type.reference.source.class_name
438+
return None
439+
415440

416441
class DataTypeManager(_DataTypeManager):
417442
"""Type manager for msgspec Struct models."""

src/datamodel_code_generator/model/pydantic/base_model.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,19 @@ def _get_default_as_pydantic_model(self) -> str | None:
152152
)
153153
return None
154154

155+
def _get_default_factory_for_optional_nested_model(self) -> str | None:
156+
"""Get default_factory for optional nested Pydantic model fields.
157+
158+
Returns the class name if the field type references a BaseModel,
159+
otherwise returns None.
160+
"""
161+
for data_type in self.data_type.data_types or (self.data_type,):
162+
if data_type.is_dict:
163+
continue
164+
if data_type.reference and isinstance(data_type.reference.source, BaseModelBase):
165+
return data_type.alias or data_type.reference.source.class_name
166+
return None
167+
155168
def _process_data_in_str(self, data: dict[str, Any]) -> None:
156169
if self.const:
157170
data["const"] = True
@@ -204,6 +217,15 @@ def __str__(self) -> str: # noqa: PLR0912
204217
else:
205218
default_factory = data.pop("default_factory", None)
206219

220+
# Handle default_factory for optional nested models
221+
if (
222+
default_factory is None
223+
and self.use_default_factory_for_optional_nested_models
224+
and not self.required
225+
and (self.default is None or self.default is UNDEFINED)
226+
):
227+
default_factory = self._get_default_factory_for_optional_nested_model()
228+
207229
self.__dict__["_computed_default_factory"] = default_factory
208230

209231
field_arguments = sorted(f"{k}={v!r}" for k, v in data.items() if v is not None)

src/datamodel_code_generator/parser/base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -734,6 +734,7 @@ def __init__( # noqa: PLR0913, PLR0915
734734
frozen_dataclasses: bool = False,
735735
no_alias: bool = False,
736736
use_frozen_field: bool = False,
737+
use_default_factory_for_optional_nested_models: bool = False,
737738
formatters: list[Formatter] = DEFAULT_FORMATTERS,
738739
parent_scoped_naming: bool = False,
739740
dataclass_arguments: DataclassArguments | None = None,
@@ -875,6 +876,7 @@ def __init__( # noqa: PLR0913, PLR0915
875876
self.type_mappings: dict[tuple[str, str], str] = Parser._parse_type_mappings(type_mappings)
876877
self.read_only_write_only_model_type: ReadOnlyWriteOnlyModelType | None = read_only_write_only_model_type
877878
self.use_frozen_field: bool = use_frozen_field
879+
self.use_default_factory_for_optional_nested_models: bool = use_default_factory_for_optional_nested_models
878880

879881
@property
880882
def field_name_model_type(self) -> ModelType:

src/datamodel_code_generator/parser/graphql.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,7 @@ def __init__( # noqa: PLR0913
190190
read_only_write_only_model_type: ReadOnlyWriteOnlyModelType | None = None,
191191
use_serialize_as_any: bool = False,
192192
use_frozen_field: bool = False,
193+
use_default_factory_for_optional_nested_models: bool = False,
193194
) -> None:
194195
"""Initialize the GraphQL parser with configuration options."""
195196
super().__init__(
@@ -284,6 +285,7 @@ def __init__( # noqa: PLR0913
284285
read_only_write_only_model_type=read_only_write_only_model_type,
285286
use_serialize_as_any=use_serialize_as_any,
286287
use_frozen_field=use_frozen_field,
288+
use_default_factory_for_optional_nested_models=use_default_factory_for_optional_nested_models,
287289
)
288290

289291
self.data_model_scalar_type = data_model_scalar_type

0 commit comments

Comments
 (0)