Skip to content

Commit b9eb0c3

Browse files
committed
Add test to reproduce the subgroup issue#276
Signed-off-by: Fabrice Normandin <normandf@mila.quebec>
1 parent ad007f6 commit b9eb0c3

10 files changed

Lines changed: 57 additions & 21 deletions

test/test_huggingface_compat.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1304,11 +1304,11 @@ def test_serialization(tmp_path: Path, filename: str, args: TrainingArguments):
13041304
assert load(TrainingArguments, path) == args
13051305

13061306

1307-
@pytest.mark.xfail(
1308-
raises=TypeError,
1309-
strict=True,
1310-
reason="All fields (non-init ones too) are passed to .set_defaults, which raises a TypeError",
1311-
)
1307+
# @pytest.mark.xfail(
1308+
# raises=TypeError,
1309+
# strict=True,
1310+
# reason="All fields (non-init ones too) are passed to .set_defaults, which raises a TypeError",
1311+
# )
13121312
@pytest.mark.parametrize("filetype", [".yaml", ".json", ".pkl"])
13131313
def test_parse_with_config_file(tmp_path: Path, filetype: str):
13141314
default_args = TrainingArguments(label_smoothing_factor=123.123)

test/test_subgroups.py

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from typing import Callable, TypeVar
1212

1313
import pytest
14+
from simple_parsing.helpers.serialization import save
1415
from pytest_regressions.file_regression import FileRegressionFixture
1516
from typing_extensions import Annotated
1617

@@ -244,7 +245,6 @@ def test_two_subgroups_with_conflict(args_str: str, expected: TwoSubgroupsWithCo
244245

245246

246247
def test_subgroups_with_key_default() -> None:
247-
248248
with pytest.raises(ValueError):
249249
subgroups({"a": A, "b": B}, default_factory="a")
250250

@@ -270,7 +270,9 @@ def test_subgroup_default_needs_to_be_key_in_dict():
270270

271271

272272
def test_subgroup_default_factory_needs_to_be_value_in_dict():
273-
with pytest.raises(ValueError, match="`default_factory` must be a value in the subgroups dict"):
273+
with pytest.raises(
274+
ValueError, match="`default_factory` must be a value in the subgroups dict"
275+
):
274276
_ = subgroups({"a": B, "aa": A}, default_factory=C)
275277

276278

@@ -467,6 +469,7 @@ def test_help_string_displays_default_factory_arguments(
467469
When using `functools.partial` or lambda expressions, we'd ideally also like the help text to
468470
show the field values from inside the `partial` or lambda, if possible.
469471
"""
472+
470473
# NOTE: Here we need to return just A() and B() with these default factories, so the defaults
471474
# for the fields are the same
472475
@dataclass
@@ -585,7 +588,6 @@ class ModelBConfig(ModelConfig):
585588

586589
@dataclass
587590
class Config(TestSetup):
588-
589591
# Which model to use
590592
model: ModelConfig = subgroups(
591593
{"model_a": ModelAConfig, "model_b": ModelBConfig},
@@ -666,7 +668,9 @@ def test_annotated_as_subgroups():
666668

667669
@dataclasses.dataclass
668670
class Config(TestSetup):
669-
model: Model = subgroups({"small": SmallModel, "big": BigModel}, default_factory=SmallModel)
671+
model: Model = subgroups(
672+
{"small": SmallModel, "big": BigModel}, default_factory=SmallModel
673+
)
670674

671675
assert Config.setup().model == SmallModel()
672676
# Hopefully this illustrates why Annotated aren't exactly great:
@@ -880,7 +884,6 @@ class Dataset2Config(DatasetConfig):
880884

881885
@dataclass
882886
class Config(TestSetup):
883-
884887
# Which model to use
885888
model: ModelConfig = subgroups(
886889
{"model_a": ModelAConfig, "model_b": ModelBConfig},
@@ -931,3 +934,39 @@ def test_ordering_of_args_doesnt_matter():
931934
model=ModelAConfig(lr=0.0003, optimizer="Adam", betas=(0.0, 1.0)),
932935
dataset=Dataset2Config(data_dir="data/bar", bar=1.2),
933936
)
937+
938+
939+
@dataclass
940+
class A1:
941+
a_val: int = 1
942+
943+
944+
@dataclass
945+
class A2:
946+
a_val: int = 2
947+
948+
949+
@dataclass
950+
class A1OrA2:
951+
a: A1 | A2 = subgroups({"a1": A1, "a2": A2}, default="a1")
952+
953+
954+
@pytest.mark.parametrize(
955+
("value_in_config", "args", "expected"),
956+
[
957+
(A1OrA2(a=A2()), "", A1OrA2(a=A1())),
958+
(A1OrA2(a=A1()), "", A1OrA2(a=A1())),
959+
],
960+
)
961+
@pytest.mark.parametrize("filetype", [".yaml", ".json", ".pkl"])
962+
def test_parse_with_config_file_with_different_subgroup(
963+
tmp_path: Path,
964+
filetype: str,
965+
value_in_config: A1OrA2,
966+
args: str,
967+
expected: A1OrA2,
968+
):
969+
config_path = (tmp_path / "bob").with_suffix(filetype)
970+
971+
save(value_in_config, config_path, save_dc_types=True)
972+
assert parse(A1OrA2, config_path=config_path, args=args) == expected

test/test_subgroups/test_help[Config---help].md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
1-
# Regression file for [this test](test/test_subgroups.py:725)
1+
# Regression file for [this test](test/test_subgroups.py:729)
22

33
Given Source code:
44

55
```python
66
@dataclass
77
class Config(TestSetup):
8-
98
# Which model to use
109
model: ModelConfig = subgroups(
1110
{"model_a": ModelAConfig, "model_b": ModelBConfig},

test/test_subgroups/test_help[Config---model=model_a --help].md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
1-
# Regression file for [this test](test/test_subgroups.py:725)
1+
# Regression file for [this test](test/test_subgroups.py:729)
22

33
Given Source code:
44

55
```python
66
@dataclass
77
class Config(TestSetup):
8-
98
# Which model to use
109
model: ModelConfig = subgroups(
1110
{"model_a": ModelAConfig, "model_b": ModelBConfig},

test/test_subgroups/test_help[Config---model=model_b --help].md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
1-
# Regression file for [this test](test/test_subgroups.py:725)
1+
# Regression file for [this test](test/test_subgroups.py:729)
22

33
Given Source code:
44

55
```python
66
@dataclass
77
class Config(TestSetup):
8-
98
# Which model to use
109
model: ModelConfig = subgroups(
1110
{"model_a": ModelAConfig, "model_b": ModelBConfig},

test/test_subgroups/test_help[ConfigWithFrozen---conf=even --a 100 --help].md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Regression file for [this test](test/test_subgroups.py:725)
1+
# Regression file for [this test](test/test_subgroups.py:729)
22

33
Given Source code:
44

test/test_subgroups/test_help[ConfigWithFrozen---conf=even --help].md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Regression file for [this test](test/test_subgroups.py:725)
1+
# Regression file for [this test](test/test_subgroups.py:729)
22

33
Given Source code:
44

test/test_subgroups/test_help[ConfigWithFrozen---conf=odd --a 123 --help].md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Regression file for [this test](test/test_subgroups.py:725)
1+
# Regression file for [this test](test/test_subgroups.py:729)
22

33
Given Source code:
44

test/test_subgroups/test_help[ConfigWithFrozen---conf=odd --help].md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Regression file for [this test](test/test_subgroups.py:725)
1+
# Regression file for [this test](test/test_subgroups.py:729)
22

33
Given Source code:
44

test/test_subgroups/test_help[ConfigWithFrozen---help].md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Regression file for [this test](test/test_subgroups.py:725)
1+
# Regression file for [this test](test/test_subgroups.py:729)
22

33
Given Source code:
44

0 commit comments

Comments
 (0)