1111from typing import Callable , TypeVar
1212
1313import pytest
14+ from simple_parsing .helpers .serialization import save
1415from pytest_regressions .file_regression import FileRegressionFixture
1516from typing_extensions import Annotated
1617
@@ -244,7 +245,6 @@ def test_two_subgroups_with_conflict(args_str: str, expected: TwoSubgroupsWithCo
244245
245246
246247def test_subgroups_with_key_default () -> None :
247-
248248 with pytest .raises (ValueError ):
249249 subgroups ({"a" : A , "b" : B }, default_factory = "a" )
250250
@@ -270,7 +270,9 @@ def test_subgroup_default_needs_to_be_key_in_dict():
270270
271271
272272def test_subgroup_default_factory_needs_to_be_value_in_dict ():
273- with pytest .raises (ValueError , match = "`default_factory` must be a value in the subgroups dict" ):
273+ with pytest .raises (
274+ ValueError , match = "`default_factory` must be a value in the subgroups dict"
275+ ):
274276 _ = subgroups ({"a" : B , "aa" : A }, default_factory = C )
275277
276278
@@ -467,6 +469,7 @@ def test_help_string_displays_default_factory_arguments(
467469 When using `functools.partial` or lambda expressions, we'd ideally also like the help text to
468470 show the field values from inside the `partial` or lambda, if possible.
469471 """
472+
470473 # NOTE: Here we need to return just A() and B() with these default factories, so the defaults
471474 # for the fields are the same
472475 @dataclass
@@ -585,7 +588,6 @@ class ModelBConfig(ModelConfig):
585588
586589@dataclass
587590class Config (TestSetup ):
588-
589591 # Which model to use
590592 model : ModelConfig = subgroups (
591593 {"model_a" : ModelAConfig , "model_b" : ModelBConfig },
@@ -666,7 +668,9 @@ def test_annotated_as_subgroups():
666668
667669 @dataclasses .dataclass
668670 class Config (TestSetup ):
669- model : Model = subgroups ({"small" : SmallModel , "big" : BigModel }, default_factory = SmallModel )
671+ model : Model = subgroups (
672+ {"small" : SmallModel , "big" : BigModel }, default_factory = SmallModel
673+ )
670674
671675 assert Config .setup ().model == SmallModel ()
672676 # Hopefully this illustrates why Annotated aren't exactly great:
@@ -880,7 +884,6 @@ class Dataset2Config(DatasetConfig):
880884
881885@dataclass
882886class Config (TestSetup ):
883-
884887 # Which model to use
885888 model : ModelConfig = subgroups (
886889 {"model_a" : ModelAConfig , "model_b" : ModelBConfig },
@@ -931,3 +934,39 @@ def test_ordering_of_args_doesnt_matter():
931934 model = ModelAConfig (lr = 0.0003 , optimizer = "Adam" , betas = (0.0 , 1.0 )),
932935 dataset = Dataset2Config (data_dir = "data/bar" , bar = 1.2 ),
933936 )
937+
938+
939+ @dataclass
940+ class A1 :
941+ a_val : int = 1
942+
943+
944+ @dataclass
945+ class A2 :
946+ a_val : int = 2
947+
948+
949+ @dataclass
950+ class A1OrA2 :
951+ a : A1 | A2 = subgroups ({"a1" : A1 , "a2" : A2 }, default = "a1" )
952+
953+
954+ @pytest .mark .parametrize (
955+ ("value_in_config" , "args" , "expected" ),
956+ [
957+ (A1OrA2 (a = A2 ()), "" , A1OrA2 (a = A1 ())),
958+ (A1OrA2 (a = A1 ()), "" , A1OrA2 (a = A1 ())),
959+ ],
960+ )
961+ @pytest .mark .parametrize ("filetype" , [".yaml" , ".json" , ".pkl" ])
962+ def test_parse_with_config_file_with_different_subgroup (
963+ tmp_path : Path ,
964+ filetype : str ,
965+ value_in_config : A1OrA2 ,
966+ args : str ,
967+ expected : A1OrA2 ,
968+ ):
969+ config_path = (tmp_path / "bob" ).with_suffix (filetype )
970+
971+ save (value_in_config , config_path , save_dc_types = True )
972+ assert parse (A1OrA2 , config_path = config_path , args = args ) == expected
0 commit comments