Skip to content

Commit 6515880

Browse files
committed
Remove non-init fields in set_default
Signed-off-by: Fabrice Normandin <fabrice.normandin@gmail.com>
1 parent 1bd0587 commit 6515880

1 file changed

Lines changed: 26 additions & 11 deletions

File tree

simple_parsing/wrappers/dataclass_wrapper.py

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import argparse
44
import dataclasses
55
import functools
6-
import inspect
76
import sys
87
import textwrap
98
from dataclasses import MISSING
@@ -21,10 +20,11 @@
2120
logger = getLogger(__name__)
2221

2322
MAX_DOCSTRING_DESC_LINES_HEIGHT: int = 50
24-
"""
25-
Maximum number of lines of the class docstring to include in the autogenerated argument group
26-
description. If fields don't have docstrings or help text, then this is not used, and the entire
27-
docstring is used as the description of the argument group.
23+
"""Maximum number of lines of the class docstring to include in the autogenerated argument group
24+
description.
25+
26+
If fields don't have docstrings or help text, then this is not used, and the entire docstring is
27+
used as the description of the argument group.
2828
"""
2929

3030
DataclassWrapperType = TypeVar("DataclassWrapperType", bound="DataclassWrapper")
@@ -172,7 +172,8 @@ def __init__(
172172
# a "normal" attribute
173173
field_wrapper = self.field_wrapper_class(field, parent=self, prefix=self.prefix)
174174
logger.debug(
175-
f"wrapped field at {field_wrapper.dest} has a default value of {field_wrapper.default}"
175+
f"wrapped field at {field_wrapper.dest} has a default value of "
176+
f"{field_wrapper.default}"
176177
)
177178
if field_default is not dataclasses.MISSING:
178179
field_wrapper.set_default(field_default)
@@ -217,9 +218,12 @@ def add_arguments(self, parser: argparse.ArgumentParser):
217218
def equivalent_argparse_code(self, leading="group") -> str:
218219
code = ""
219220
code += textwrap.dedent(
220-
f"""
221-
group = parser.add_argument_group(title="{self.title.strip()}", description="{self.description.strip()}")
222-
"""
221+
f"""\
222+
group = parser.add_argument_group(
223+
title="{self.title.strip()}",
224+
description="{self.description.strip()}",
225+
)
226+
"""
223227
)
224228
for wrapped_field in self.fields:
225229
if wrapped_field.is_subparser:
@@ -295,6 +299,11 @@ def set_default(self, value: DataclassT | dict | None):
295299
self._default = value
296300
if field_default_values is None:
297301
return
302+
# Ignore default values for fields that have init=False.
303+
for field in dataclasses.fields(self.dataclass):
304+
if not field.init and field.name in field_default_values:
305+
field_default_values.pop(field.name)
306+
298307
unknown_names = set(field_default_values)
299308
for field_wrapper in self.fields:
300309
if field_wrapper.name not in field_default_values:
@@ -314,7 +323,10 @@ def set_default(self, value: DataclassT | dict | None):
314323
unknown_names.remove(nested_dataclass_wrapper.name)
315324
unknown_names.discard("_type_")
316325
if unknown_names:
317-
raise RuntimeError(f"{sorted(unknown_names)} are not fields of {self.dataclass} at path {self.dest!r}!")
326+
raise RuntimeError(
327+
f"{sorted(unknown_names)} are not fields of {self.dataclass} at path "
328+
f"{self.dest!r}!"
329+
)
318330

319331
@property
320332
def title(self) -> str:
@@ -423,13 +435,16 @@ def destinations(self, value: list[str]):
423435

424436
def merge(self, other: DataclassWrapper):
425437
"""Absorb all the relevant attributes from another wrapper.
438+
426439
Args:
427440
other (DataclassWrapper): Another instance to absorb into this one.
428441
"""
429442
# logger.debug(f"merging \n{self}\n with \n{other}")
430443
logger.debug(f"self destinations: {self.destinations}")
431444
logger.debug(f"other destinations: {other.destinations}")
432-
# assert not set(self.destinations).intersection(set(other.destinations)), "shouldn't have overlap in destinations"
445+
# assert not set(self.destinations).intersection(set(other.destinations)), (
446+
# "shouldn't have overlap in destinations"
447+
# )
433448
# self.destinations.extend(other.destinations)
434449
for dest in other.destinations:
435450
if dest not in self.destinations:

0 commit comments

Comments
 (0)