1111from typing import Callable , TypeVar
1212
1313import pytest
14- from simple_parsing .helpers .serialization import save
1514from pytest_regressions .file_regression import FileRegressionFixture
1615from typing_extensions import Annotated
1716
17+ from simple_parsing .utils import Dataclass
18+ from simple_parsing .helpers .serialization import save
1819from simple_parsing import ArgumentParser , parse , subgroups
20+ from simple_parsing .helpers .serialization .serializable import from_dict , to_dict
1921from simple_parsing .wrappers .field_wrapper import ArgumentGenerationMode , NestedMode
2022
2123from .test_choice import Color
@@ -190,7 +192,7 @@ def test_parse(dataclass_type: type[TestClass], args: str, expected: TestClass):
190192
191193
192194def test_subgroup_choice_is_saved_on_namespace ():
193- """test for https://github.com/lebrice/SimpleParsing/issues/139
195+ """Test for https://github.com/lebrice/SimpleParsing/issues/139.
194196
195197 Need to save the chosen subgroup name somewhere on the args.
196198 """
@@ -278,7 +280,9 @@ def test_subgroup_default_factory_needs_to_be_value_in_dict():
278280
279281def test_lambdas_dont_return_same_instance ():
280282 """Slightly unrelated, but I just want to check if lambda expressions return the same object
281- instance when a default factory looks like `lambda: A()`. If so, then I won't encourage this.
283+ instance when a default factory looks like `lambda: A()`.
284+
285+ If so, then I won't encourage this.
282286 """
283287
284288 @dataclass
@@ -294,8 +298,7 @@ class Config(TestSetup):
294298
295299def test_partials_new_args_overwrite_set_values ():
296300 """Double-check that functools.partial overwrites the keywords that are stored when it is
297- created with the ones that are passed when calling it.
298- """
301+ created with the ones that are passed when calling it."""
299302 # just to avoid the test passing if I were to hard-code the same value as the default by
300303 # accident.
301304 default_a = A ().a
@@ -425,7 +428,10 @@ class Foo(TestSetup):
425428 marks = pytest .mark .xfail (
426429 strict = True ,
427430 raises = NotImplementedError ,
428- reason = "Lambda expressions aren't allowed in the subgroup dict or default_factory at the moment." ,
431+ reason = (
432+ "Lambda expressions aren't allowed in the subgroup dict or default_factory at the "
433+ "moment."
434+ ),
429435 ),
430436)
431437
@@ -440,7 +446,8 @@ class Foo(TestSetup):
440446 ],
441447)
442448def test_other_default_factories (a_factory : Callable [[], A ], b_factory : Callable [[], B ]):
443- """Test using other kinds of default factories (i.e. functools.partial or lambda expressions)"""
449+ """Test using other kinds of default factories (i.e. functools.partial or lambda
450+ expressions)"""
444451
445452 @dataclass
446453 class Foo (TestSetup ):
@@ -596,7 +603,7 @@ class Config(TestSetup):
596603
597604
598605def test_destination_substring_of_other_destination_issue191 ():
599- """Test for https://github.com/lebrice/SimpleParsing/issues/191"""
606+ """Test for https://github.com/lebrice/SimpleParsing/issues/191. """
600607
601608 parser = ArgumentParser ()
602609 parser .add_arguments (Config , dest = "config" )
@@ -783,11 +790,12 @@ def test_help(
783790# ModelConfig = _ModelConfig()
784791# SmallModel = _ModelConfig(num_layers=1, hidden_dim=32)
785792# BigModel = _ModelConfig(num_layers=32, hidden_dim=128)
786-
787- # @dataclasses.dataclass
788- # class Config(TestSetup):
789- # model: Model = subgroups({"small": SmallModel, "big": BigModel}, default_factory=SmallModel)
790-
793+ #
794+ # @dataclasses.dataclass
795+ # class Config(TestSetup):
796+ # model: Model = subgroups({"small": SmallModel, "big": BigModel},
797+ # default_factory=SmallModel)
798+ #
791799# assert Config.setup().model == SmallModel()
792800# # Hopefully this illustrates why Annotated aren't exactly great:
793801# # At runtime, they are basically the same as the original dataclass when called.
@@ -803,7 +811,7 @@ def test_help(
803811
804812@pytest .mark .parametrize ("frozen" , [True , False ])
805813def test_nested_subgroups (frozen : bool ):
806- """Assert that #160 is fixed: https://github.com/lebrice/SimpleParsing/issues/160"""
814+ """Assert that #160 is fixed: https://github.com/lebrice/SimpleParsing/issues/160. """
807815
808816 @dataclass (frozen = frozen )
809817 class FooConfig :
@@ -978,3 +986,17 @@ def test_parse_with_config_file_with_different_subgroup(
978986
979987 save (value_in_config , config_path , save_dc_types = True )
980988 assert parse (A1OrA2 , config_path = config_path , args = args ) == expected
989+
990+
991+ @pytest .mark .parametrize (
992+ "value" ,
993+ [
994+ A1OrA2 (),
995+ A1OrA2 (a = A2 (a_val = 2 )),
996+ ],
997+ )
998+ def test_roundtrip (value : Dataclass ):
999+ """Test to reproduce
1000+ https://github.com/lebrice/SimpleParsing/pull/284#issuecomment-1783490388."""
1001+ assert from_dict (type (value ), to_dict (value )) == value
1002+ assert to_dict (from_dict (type (value ), to_dict (value ))) == to_dict (value )
0 commit comments