Skip to content

Commit 590722c

Browse files
committed
Add test and use trick from anivegesena
Signed-off-by: Fabrice Normandin <fabrice.normandin@gmail.com>
1 parent 6515880 commit 590722c

4 files changed

Lines changed: 74 additions & 36 deletions

File tree

.pre-commit-config.yaml

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,10 @@ repos:
2828

2929
- repo: https://github.com/charliermarsh/ruff-pre-commit
3030
# Ruff version.
31-
rev: 'v0.0.261'
31+
rev: "v0.0.261"
3232
hooks:
3333
- id: ruff
34-
args: ['--line-length', '99', '--fix']
34+
args: ["--line-length", "99", "--fix"]
3535
require_serial: true
3636

3737
# python code formatting
@@ -44,7 +44,7 @@ repos:
4444

4545
# python docstring formatting
4646
- repo: https://github.com/myint/docformatter
47-
rev: v1.5.1
47+
rev: v1.7.5
4848
hooks:
4949
- id: docformatter
5050
args: [--in-place, --wrap-summaries=99, --wrap-descriptions=99]
@@ -65,7 +65,6 @@ repos:
6565
- id: nbstripout
6666
require_serial: true
6767

68-
6968
# md formatting
7069
- repo: https://github.com/executablebooks/mdformat
7170
rev: 0.7.16
@@ -80,7 +79,6 @@ repos:
8079
# - mdformat-black
8180
require_serial: true
8281

83-
8482
# word spelling linter
8583
- repo: https://github.com/codespell-project/codespell
8684
rev: v2.2.2

simple_parsing/helpers/serialization/serializable.py

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,8 @@ def __init_subclass__(
208208
if parent in SerializableMixin.subclasses and parent is not SerializableMixin:
209209
decode_into_subclasses = parent.decode_into_subclasses
210210
logger.debug(
211-
f"Parent class {parent} has decode_into_subclasses = {decode_into_subclasses}"
211+
f"Parent class {parent} has decode_into_subclasses = "
212+
f"{decode_into_subclasses}"
212213
)
213214
break
214215

@@ -220,7 +221,10 @@ def __init_subclass__(
220221
register_decoding_fn(cls, cls.from_dict)
221222

222223
def to_dict(
223-
self, dict_factory: type[dict] = dict, recurse: bool = True, save_dc_types: bool = False
224+
self,
225+
dict_factory: type[dict] = dict,
226+
recurse: bool = True,
227+
save_dc_types: bool | int = False,
224228
) -> dict:
225229
"""Serializes this dataclass to a dict.
226230
@@ -462,7 +466,8 @@ class SimpleSerializable(SerializableMixin, decode_into_subclasses=True):
462466
def get_serializable_dataclass_types_from_forward_ref(
463467
forward_ref: type, serializable_base_class: type[S] = SerializableMixin
464468
) -> list[type[S]]:
465-
"""Gets all the subclasses of `serializable_base_class` that have the same name as the argument of this forward reference annotation."""
469+
"""Gets all the subclasses of `serializable_base_class` that have the same name as the argument
470+
of this forward reference annotation."""
466471
arg = get_forward_arg(forward_ref)
467472
potential_classes: list[type] = []
468473
for serializable_class in serializable_base_class.subclasses:
@@ -595,6 +600,7 @@ def loads_yaml(
595600

596601
def read_file(path: str | Path) -> dict:
597602
"""Returns the contents of the given file as a dictionary.
603+
598604
Uses the right function depending on `path.suffix`:
599605
{
600606
".yml": yaml.safe_load,
@@ -613,7 +619,7 @@ def save(
613619
obj: Any,
614620
path: str | Path,
615621
format: FormatExtension | None = None,
616-
save_dc_types: bool = False,
622+
save_dc_types: bool | int = False,
617623
**kwargs,
618624
) -> None:
619625
"""Save the given dataclass or dictionary to the given file."""
@@ -704,7 +710,7 @@ def to_dict(
704710
dc: DataclassT,
705711
dict_factory: type[dict] = dict,
706712
recurse: bool = True,
707-
save_dc_types: bool = False,
713+
save_dc_types: bool | int = False,
708714
) -> dict:
709715
"""Serializes this dataclass to a dict.
710716
@@ -736,6 +742,11 @@ def to_dict(
736742
else:
737743
d[DC_TYPE_KEY] = module + "." + class_name
738744

745+
# Decrement save_dc_types if it is an int, so that we only save the type of the subgroups
746+
# dataclass, not all dataclasses recursively.
747+
if save_dc_types is not True and save_dc_types > 0:
748+
save_dc_types -= 1
749+
739750
for f in fields(dc):
740751
name = f.name
741752
value = getattr(dc, name)
@@ -763,7 +774,8 @@ def to_dict(
763774
encoded = encoding_fn(value)
764775
except Exception as e:
765776
logger.error(
766-
f"Unable to encode value {value} of type {type(value)}! Leaving it as-is. (exception: {e})"
777+
f"Unable to encode value {value} of type {type(value)}! Leaving it as-is. "
778+
f"(exception: {e})"
767779
)
768780
encoded = value
769781
d[name] = encoded
@@ -931,13 +943,11 @@ def is_dataclass_or_optional_dataclass_type(t: type) -> bool:
931943

932944
@functools.lru_cache(maxsize=None)
933945
def _locate(path: str) -> Any:
934-
"""
935-
COPIED FROM Hydra:
936-
https://github.com/facebookresearch/hydra/blob/f8940600d0ab5c695961ad83abd042ffe9458caf/hydra/_internal/utils.py#L614
946+
"""COPIED FROM Hydra: https://github.com/facebookresearch/hydra/blob/f8940600d0ab5c695961ad83ab
947+
d042ffe9458caf/hydra/_internal/utils.py#L614.
937948
938-
Locate an object by name or dotted path, importing as necessary.
939-
This is similar to the pydoc function `locate`, except that it checks for
940-
the module from the given path from back to front.
949+
Locate an object by name or dotted path, importing as necessary. This is similar to the pydoc
950+
function `locate`, except that it checks for the module from the given path from back to front.
941951
"""
942952
if path == "":
943953
raise ImportError("Empty path")
@@ -972,7 +982,8 @@ def _locate(path: str) -> Any:
972982
except ModuleNotFoundError as exc_import:
973983
raise ImportError(
974984
f"Error loading '{path}':\n{repr(exc_import)}"
975-
+ f"\nAre you sure that '{part}' is importable from module '{parent_dotpath}'?"
985+
+ f"\nAre you sure that '{part}' is importable from module "
986+
f"'{parent_dotpath}'?"
976987
) from exc_import
977988
except Exception as exc_import:
978989
raise ImportError(

simple_parsing/helpers/subgroups.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,9 @@
99
from typing import Any, Callable, Mapping, TypeVar, Union
1010

1111
from typing_extensions import TypeAlias
12+
from simple_parsing.helpers.serialization.serializable import to_dict
1213

13-
from simple_parsing.utils import DataclassT, is_dataclass_instance, is_dataclass_type
14+
from simple_parsing.utils import Dataclass, DataclassT, is_dataclass_instance, is_dataclass_type
1415

1516
logger = get_logger(__name__)
1617

@@ -80,7 +81,11 @@ def subgroups(
8081
metadata["subgroup_default"] = default
8182
metadata["subgroup_dataclass_types"] = {}
8283

83-
subgroup_dataclass_types: dict[Key, type[DataclassT]] = {}
84+
# Custom encoding function that will add the _type_ key with the subgroup dataclass type.
85+
# Using an int here means that only to the subgroup dataclass.
86+
kwargs.setdefault("encoding_fn", functools.partial(to_dict, save_dc_types=1))
87+
88+
subgroup_dataclass_types: dict[Key, type[Dataclass]] = {}
8489
choices = subgroups.keys()
8590

8691
# NOTE: Perhaps we could raise a warning if the default_factory is a Lambda, since we have to
@@ -198,7 +203,8 @@ def _get_dataclass_type_from_callable(
198203
f"{dataclass_fn!r}, because it doesn't have a return type annotation, and we don't "
199204
f"want to call it just to figure out what it produces."
200205
)
201-
# NOTE: recurse here, so it also works with `partial(partial(...))` and `partial(some_function)`
206+
# NOTE: recurse here, so it also works with `partial(partial(...))` and
207+
# `partial(some_function)`
202208
# Recurse, so this also works with partial(partial(...)) (idk why you'd do that though.)
203209

204210
if isinstance(signature.return_annotation, str):
@@ -241,7 +247,8 @@ def _get_dataclass_type_from_callable(
241247
def is_lambda(obj: Any) -> bool:
242248
"""Returns True if the given object is a lambda expression.
243249
244-
Taken from https://stackoverflow.com/questions/3655842/how-can-i-test-whether-a-variable-holds-a-lambda
250+
Taken from
251+
https://stackoverflow.com/questions/3655842/how-can-i-test-whether-a-variable-holds-a-lambda
245252
"""
246253
LAMBDA = lambda: 0 # noqa: E731
247254
return isinstance(obj, type(LAMBDA)) and obj.__name__ == LAMBDA.__name__

test/test_subgroups.py

Lines changed: 36 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,13 @@
1111
from typing import Callable, TypeVar
1212

1313
import pytest
14-
from simple_parsing.helpers.serialization import save
1514
from pytest_regressions.file_regression import FileRegressionFixture
1615
from typing_extensions import Annotated
1716

17+
from simple_parsing.utils import Dataclass
18+
from simple_parsing.helpers.serialization import save
1819
from simple_parsing import ArgumentParser, parse, subgroups
20+
from simple_parsing.helpers.serialization.serializable import from_dict, to_dict
1921
from simple_parsing.wrappers.field_wrapper import ArgumentGenerationMode, NestedMode
2022

2123
from .test_choice import Color
@@ -190,7 +192,7 @@ def test_parse(dataclass_type: type[TestClass], args: str, expected: TestClass):
190192

191193

192194
def test_subgroup_choice_is_saved_on_namespace():
193-
"""test for https://github.com/lebrice/SimpleParsing/issues/139
195+
"""Test for https://github.com/lebrice/SimpleParsing/issues/139.
194196
195197
Need to save the chosen subgroup name somewhere on the args.
196198
"""
@@ -278,7 +280,9 @@ def test_subgroup_default_factory_needs_to_be_value_in_dict():
278280

279281
def test_lambdas_dont_return_same_instance():
280282
"""Slightly unrelated, but I just want to check if lambda expressions return the same object
281-
instance when a default factory looks like `lambda: A()`. If so, then I won't encourage this.
283+
instance when a default factory looks like `lambda: A()`.
284+
285+
If so, then I won't encourage this.
282286
"""
283287

284288
@dataclass
@@ -294,8 +298,7 @@ class Config(TestSetup):
294298

295299
def test_partials_new_args_overwrite_set_values():
296300
"""Double-check that functools.partial overwrites the keywords that are stored when it is
297-
created with the ones that are passed when calling it.
298-
"""
301+
created with the ones that are passed when calling it."""
299302
# just to avoid the test passing if I were to hard-code the same value as the default by
300303
# accident.
301304
default_a = A().a
@@ -425,7 +428,10 @@ class Foo(TestSetup):
425428
marks=pytest.mark.xfail(
426429
strict=True,
427430
raises=NotImplementedError,
428-
reason="Lambda expressions aren't allowed in the subgroup dict or default_factory at the moment.",
431+
reason=(
432+
"Lambda expressions aren't allowed in the subgroup dict or default_factory at the "
433+
"moment."
434+
),
429435
),
430436
)
431437

@@ -440,7 +446,8 @@ class Foo(TestSetup):
440446
],
441447
)
442448
def test_other_default_factories(a_factory: Callable[[], A], b_factory: Callable[[], B]):
443-
"""Test using other kinds of default factories (i.e. functools.partial or lambda expressions)"""
449+
"""Test using other kinds of default factories (i.e. functools.partial or lambda
450+
expressions)"""
444451

445452
@dataclass
446453
class Foo(TestSetup):
@@ -596,7 +603,7 @@ class Config(TestSetup):
596603

597604

598605
def test_destination_substring_of_other_destination_issue191():
599-
"""Test for https://github.com/lebrice/SimpleParsing/issues/191"""
606+
"""Test for https://github.com/lebrice/SimpleParsing/issues/191."""
600607

601608
parser = ArgumentParser()
602609
parser.add_arguments(Config, dest="config")
@@ -783,11 +790,12 @@ def test_help(
783790
# ModelConfig = _ModelConfig()
784791
# SmallModel = _ModelConfig(num_layers=1, hidden_dim=32)
785792
# BigModel = _ModelConfig(num_layers=32, hidden_dim=128)
786-
787-
# @dataclasses.dataclass
788-
# class Config(TestSetup):
789-
# model: Model = subgroups({"small": SmallModel, "big": BigModel}, default_factory=SmallModel)
790-
793+
#
794+
# @dataclasses.dataclass
795+
# class Config(TestSetup):
796+
# model: Model = subgroups({"small": SmallModel, "big": BigModel},
797+
# default_factory=SmallModel)
798+
#
791799
# assert Config.setup().model == SmallModel()
792800
# # Hopefully this illustrates why Annotated aren't exactly great:
793801
# # At runtime, they are basically the same as the original dataclass when called.
@@ -803,7 +811,7 @@ def test_help(
803811

804812
@pytest.mark.parametrize("frozen", [True, False])
805813
def test_nested_subgroups(frozen: bool):
806-
"""Assert that #160 is fixed: https://github.com/lebrice/SimpleParsing/issues/160"""
814+
"""Assert that #160 is fixed: https://github.com/lebrice/SimpleParsing/issues/160."""
807815

808816
@dataclass(frozen=frozen)
809817
class FooConfig:
@@ -978,3 +986,17 @@ def test_parse_with_config_file_with_different_subgroup(
978986

979987
save(value_in_config, config_path, save_dc_types=True)
980988
assert parse(A1OrA2, config_path=config_path, args=args) == expected
989+
990+
991+
@pytest.mark.parametrize(
992+
"value",
993+
[
994+
A1OrA2(),
995+
A1OrA2(a=A2(a_val=2)),
996+
],
997+
)
998+
def test_roundtrip(value: Dataclass):
999+
"""Test to reproduce
1000+
https://github.com/lebrice/SimpleParsing/pull/284#issuecomment-1783490388."""
1001+
assert from_dict(type(value), to_dict(value)) == value
1002+
assert to_dict(from_dict(type(value), to_dict(value))) == to_dict(value)

0 commit comments

Comments
 (0)