diff --git a/monai/bundle/config_parser.py b/monai/bundle/config_parser.py
index 1d9920a230..08c6fc296c 100644
--- a/monai/bundle/config_parser.py
+++ b/monai/bundle/config_parser.py
@@ -14,6 +14,7 @@
 import json
 import re
 from collections.abc import Sequence
+from copy import copy as _copy
 from copy import deepcopy
 from pathlib import Path
 from typing import TYPE_CHECKING, Any
@@ -35,6 +36,255 @@
 _default_globals = {"monai": "monai", "torch": "torch", "np": "numpy", "numpy": "numpy"}
 
 
+def _identity(value: Any) -> Any:
+    """Module-level reconstructor used by ``_ConfigProxy.__reduce__`` so proxies pickle as their raw value."""
+    return value
+
+
+def _wrap_parsed(parser: ConfigParser, id: str, value: Any) -> Any:
+    """
+    Wrap a parsed dict/list in a :class:`_ConfigProxy` so nested access keeps chaining; pass scalars through.
+
+    Args:
+        parser: the owning :class:`ConfigParser`, used to resolve chained ids.
+        id: the ``::``-separated id that produced ``value``.
+        value: the parsed content to wrap.
+
+    Returns:
+        A :class:`_ConfigProxy` wrapping ``value`` if it is a ``dict`` or ``list``,
+        otherwise ``value`` unchanged.
+    """
+    if isinstance(value, (dict, list)):
+        return _ConfigProxy(parser, id, value)
+    return value
+
+
+class _ConfigProxy:
+    """
+    Proxy that enables dot-notation and bracket-notation access to nested config structures.
+
+    When :meth:`ConfigParser.__getattr__` resolves to a ``dict`` or ``list``, the result is
+    wrapped in this proxy so that further attribute and index access chains through the
+    config hierarchy using :meth:`ConfigParser.get_parsed_content`. For example::
+
+        parser.training.trainer.max_epochs
+        # equivalent to
+        parser.get_parsed_content("training::trainer::max_epochs")
+
+        parser.transforms[0].keys  # list indexing chains too
+        parser.A.B["C"] = 99  # writes update the config source
+        del parser.A.B["C"]  # deletes update the config source
+
+    Type caveat:
+        Accessing a ``dict``/``list`` member through a :class:`ConfigParser` now returns a
+        ``_ConfigProxy``, not the raw container, so ``type(parser.A)`` is ``_ConfigProxy``
+        and ``isinstance(parser.A, dict)`` is ``False``. Code that needs the real container
+        should use ``parser.A._raw`` (read-only view) or ``parser.get_parsed_content("A")``.
+
+    Precedence and fallback:
+        Config keys take precedence over ``dict``/``list`` attributes and methods. If a
+        config key is not found, the proxy falls back to the underlying ``dict``/``list``
+        so that container methods (``.keys()``, ``.items()`` ...) and native indexing
+        semantics (``IndexError``, negative indices, dict ``KeyError``) still work. A
+        config key that collides with a container method name (e.g. ``"keys"``) shadows
+        that method on attribute access; access it via bracket notation,
+        :meth:`ConfigParser.get_parsed_content`, or ``._raw``.
+
+        When a chained id cannot be resolved, attribute access additionally tries the name
+        as a key of the wrapped ``dict`` (returned unwrapped, like the bracket fallback), so
+        ``proxy.x`` and ``proxy["x"]`` agree whenever the wrapped ``dict`` holds that key.
+
+    Writes:
+        ``__setitem__``/``__setattr__``/``__delitem__``/``__delattr__`` write through to
+        the config *source* (via :class:`ConfigParser`) and reset the reference resolver,
+        so the change is visible from both ``parser.<id>`` and
+        ``parser.get_parsed_content("<id>")``. The container held by an existing proxy is
+        not refreshed by a write, so re-fetch the proxy from the parser afterwards.
+        Writes raise ``TypeError`` when the backing config item is not a ``dict``/``list``
+        or is an expression that references more than one config item.
+    """
+
+    _INTERNAL = ("_parser", "_id", "_value")
+
+    def __init__(self, parser: ConfigParser, id: str, value: Any):
+        """
+        Args:
+            parser: the owning :class:`ConfigParser`.
+            id: the ``::``-separated id this proxy represents.
+            value: the parsed ``dict``/``list`` content this proxy wraps.
+        """
+        self._parser = parser
+        self._id = id
+        self._value = value
+
+    def _child_id(self, key: str | int) -> str:
+        """Return the config id of the child ``key`` below this proxy's id."""
+        return f"{self._id}{ID_SEP_KEY}{key}"
+
+    def _backing_id(self) -> str:
+        """
+        Return the real config id this proxy writes to, resolving all ``$@ref`` hops transitively.
+
+        Raises:
+            TypeError: if a hop is an expression that references more than one config item,
+                because there is then no single config node the write could target.
+        """
+        current = self._id
+        seen: set[str] = set()
+        # ``seen`` stops the walk on a reference cycle instead of looping forever.
+        while current not in seen:
+            seen.add(current)
+            raw = self._parser[current]
+            if not isinstance(raw, str):
+                break
+            refs = ReferenceResolver.match_refs_pattern(raw)
+            if not refs:
+                break
+            if len(refs) > 1:
+                names = ", ".join(refs)
+                raise TypeError(f"cannot write through '{current}': it references several config items ({names}).")
+            # NOTE(review): a single-reference expression that transforms its target
+            # (e.g. indexes into it) still resolves to the referenced node here -- confirm
+            # whether such aliases should be rejected as well.
+            current = next(iter(refs))
+        return current
+
+    def _backing_slot(self, key: str | int) -> tuple[Any, Any]:
+        """
+        Locate the source container and the native key/index that a write to ``key`` targets.
+
+        Returns:
+            ``(node, index)`` where ``node`` is the backing ``dict``/``list`` in the config
+            source and ``index`` is ``key`` for a ``dict`` or ``int(key)`` for a ``list``.
+
+        Raises:
+            TypeError: if the backing config item is not a ``dict``/``list`` (e.g. the value
+                is computed by an expression), so there is nothing to write into.
+        """
+        backing = self._backing_id()
+        node = self._parser[backing]
+        if isinstance(node, dict):
+            return node, key
+        if isinstance(node, list):
+            return node, int(key)
+        raise TypeError(f"cannot write through config item '{backing}': its source is not a dict or list.")
+
+    def _chain(self, key: str) -> Any:
+        """
+        Resolve ``key`` as a nested config id.
+
+        Args:
+            key: the child key/index.
+
+        Returns:
+            The parsed child content, wrapped via :func:`_wrap_parsed`.
+
+        Raises:
+            KeyError: if there is no config item at the chained id.
+        """
+        new_id = self._child_id(key)
+        return _wrap_parsed(self._parser, new_id, self._parser.get_parsed_content(new_id))
+
+    def __getattr__(self, key: str) -> Any:
+        """
+        Resolve ``key`` as a nested config attribute, falling back to the underlying container.
+
+        Dunder names are never treated as config keys, so the proxy stays well-behaved
+        with ``copy``/``pickle``/``hasattr`` and other stdlib introspection. The proxy's own
+        slots are excluded too: ``__getattr__`` only runs for them when they are unset (e.g.
+        on an instance made with ``__new__``), and resolving them as config keys would read
+        ``self._id`` again and recurse without end.
+
+        Raises:
+            AttributeError: if ``key`` is neither a config key, a key of the wrapped ``dict``,
+                nor an attribute of the underlying ``dict``/``list``.
+        """
+        if key in _ConfigProxy._INTERNAL or (key.startswith("__") and key.endswith("__")):
+            raise AttributeError(key)
+        try:
+            return self._chain(key)
+        except KeyError:
+            value = self._value
+            # mirror the ``__getitem__`` fallback so ``proxy.x`` and ``proxy["x"]`` agree
+            if isinstance(value, dict) and key in value:
+                return value[key]
+            return getattr(value, key)
+
+    def __getitem__(self, key: str | int) -> Any:
+        try:
+            return self._chain(str(key))
+        except KeyError:
+            # no config key of that name: defer to the underlying dict/list so normal
+            # indexing semantics apply (IndexError, negative indices, dict KeyError).
+            return self._value[key]
+
+    def __setitem__(self, key: str | int, value: Any) -> None:
+        # Write directly to the backing container so literal dict keys are preserved,
+        # matching the semantics of __delitem__ and __getitem__.
+        node, index = self._backing_slot(key)
+        node[index] = value
+        self._parser.ref_resolver.reset()
+
+    def __delitem__(self, key: str | int) -> None:
+        node, index = self._backing_slot(key)
+        del node[index]
+        self._parser.ref_resolver.reset()
+
+    def __setattr__(self, key: str, value: Any) -> None:
+        if key in _ConfigProxy._INTERNAL:
+            object.__setattr__(self, key, value)
+            return
+        if key == "_raw":
+            raise AttributeError("_raw is read-only")
+        self[key] = value
+
+    def __delattr__(self, key: str) -> None:
+        if key in _ConfigProxy._INTERNAL:
+            # the proxy's own slots are real attributes, never config keys
+            object.__delattr__(self, key)
+            return
+        if key == "_raw":
+            raise AttributeError("_raw is read-only")
+        del self[key]
+
+    def __len__(self) -> int:
+        return len(self._value)
+
+    def __iter__(self) -> Any:
+        return iter(self._value)
+
+    def __contains__(self, item: object) -> bool:
+        return item in self._value
+
+    def __bool__(self) -> bool:
+        return bool(self._value)
+
+    def __repr__(self) -> str:
+        return repr(self._value)
+
+    def __eq__(self, other: object) -> Any:
+        if isinstance(other, _ConfigProxy):
+            other = other._value
+        return self._value == other
+
+    def __copy__(self) -> Any:
+        return _copy(self._value)
+
+    def __deepcopy__(self, memo: Any) -> Any:
+        return deepcopy(self._value, memo)
+
+    def __reduce__(self) -> Any:
+        return (_identity, (self._value,))
+
+    @property
+    def _raw(self) -> Any:
+        """
+        The underlying ``dict``/``list`` container (the reference is read-only; the container
+        contents are not copied).
+        """
+        return self._value
+
+
 class ConfigParser:
     """
     The primary configuration parser. It traverses a structured config (in the form of nested Python dict or list),
@@ -127,14 +377,23 @@ def __getattr__(self, id):
         """
         Get the parsed result of ``ConfigItem`` with the specified ``id``
         with default arguments (e.g. ``lazy=True``, ``instantiate=True`` and ``eval_expr=True``).
+        When the result is a dict or list, it is wrapped in a ``_ConfigProxy`` so that
+        nested attributes and indices chain through the config hierarchy.
+        For example, ``parser.training.trainer.max_epochs`` is equivalent to
+        ``parser.get_parsed_content("training::trainer::max_epochs")``.
 
         Args:
             id: id of the ``ConfigItem``.
 
+        Returns:
+            The parsed content (instance, evaluated expression, or config value). When it
+            is a ``dict`` or ``list`` it is wrapped in a :class:`_ConfigProxy` so nested
+            attributes/indices chain through the config hierarchy.
+
         See also:
              :py:meth:`get_parsed_content`
         """
-        return self.get_parsed_content(id)
+        return _wrap_parsed(self, id, self.get_parsed_content(id))
 
     def __getitem__(self, id: str | int) -> Any:
         """
diff --git a/tests/bundle/test_config_parser.py b/tests/bundle/test_config_parser.py
index 5ead2af382..546957ba7e 100644
--- a/tests/bundle/test_config_parser.py
+++ b/tests/bundle/test_config_parser.py
@@ -11,7 +11,9 @@
 
 from __future__ import annotations
 
+import copy
 import os
+import pickle
 import tempfile
 import unittest
 import warnings
@@ -388,5 +390,143 @@ def test_load_configs(
         self.assertEqual(parser["key2"], expected_merged_vals)
 
 
+class TestConfigProxy(unittest.TestCase):
+    """Nested dot-/bracket-notation access on ConfigParser (issue #6837)."""
+
+    def setUp(self):
+        self.config = {
+            "A": {"B": {"C": 1, "D": [10, 20]}},
+            "training": {"trainer": {"max_epochs": 100, "lr": 0.001}},
+            "transforms": [{"keys": "image"}, {"keys": "label"}],
+            "my_dims": 2,
+            "dims_1": "$@my_dims + 1",
+        }
+        self.parser = ConfigParser(config=self.config, globals={"monai": "monai"})
+
+    def test_nested_attribute_access(self):
+        self.assertEqual(self.parser.A.B.C, 1)
+        self.assertEqual(self.parser.training.trainer.max_epochs, 100)
+        self.assertEqual(self.parser.training.trainer.lr, 0.001)
+        self.assertEqual(self.parser.dims_1, 3)
+
+    def test_nested_index_access(self):
+        self.assertEqual(self.parser.A.B.D[0], 10)
+        self.assertEqual(self.parser.A.B.D[1], 20)
+        self.assertEqual(self.parser.transforms[0].keys, "image")
+        self.assertEqual(self.parser.transforms[1].keys, "label")
+
+    def test_raw_and_container_protocol(self):
+        self.assertEqual(self.parser.A._raw, {"B": {"C": 1, "D": [10, 20]}})
+        self.assertEqual(len(self.parser.A.B.D), 2)
+        self.assertEqual(list(self.parser.A.B.D), [10, 20])
+        self.assertIn("B", self.parser.A)
+        self.assertTrue(self.parser.A.B.D)
+        self.assertFalse(ConfigParser(config={"e": []}, globals={"monai": "monai"}).e)
+
+    def test_native_index_fallback(self):
+        # bracket access falls back to native container semantics when there is no
+        # config key of that name: negative indexing still works.
+        self.assertEqual(self.parser.A.B.D[-1], 20)
+
+    def test_attribute_write_through(self):
+        # attribute assignment updates the config source and is visible from both
+        # ``parser.<id>`` and ``get_parsed_content("<id>")``.
+        self.parser.A.X = [2, 3]
+        self.assertEqual(self.parser.A.X, [2, 3])
+        self.assertIn("X", self.parser.get_parsed_content("A"))
+        self.assertEqual(self.parser.get_parsed_content("A::X"), [2, 3])
+
+    def test_item_write_through(self):
+        self.parser.A.B["C"] = 99
+        self.assertEqual(self.parser.A.B.C, 99)
+        self.assertEqual(self.parser.get_parsed_content("A::B::C"), 99)
+        self.parser.A.B.D[0] = 11
+        self.assertEqual(self.parser.A.B.D._raw, [11, 20])
+        self.assertEqual(self.parser.get_parsed_content("A::B::D"), [11, 20])
+
+    def test_delete_write_through(self):
+        del self.parser.A.B["C"]
+        self.assertNotIn("C", self.parser.get_parsed_content("A::B"))
+        del self.parser.training.trainer
+        self.assertNotIn("trainer", self.parser.get_parsed_content("training"))
+
+    def test_copy_and_pickle_yield_raw_container(self):
+        # proxies copy/pickle as their underlying container (pre-proxy behaviour).
+        a = self.parser.A
+        self.assertEqual(copy.copy(a), {"B": {"C": 1, "D": [10, 20]}})
+        self.assertEqual(copy.deepcopy(a), {"B": {"C": 1, "D": [10, 20]}})
+        self.assertEqual(pickle.loads(pickle.dumps(a)), {"B": {"C": 1, "D": [10, 20]}})  # trusted in-process roundtrip
+
+    def test_config_key_shadows_container_method(self):
+        # a config key named like a dict method shadows it on attribute access;
+        # use bracket notation / ._raw to reach the real container.
+        parser = ConfigParser(config={"sec": {"keys": "image"}}, globals={"monai": "monai"})
+        self.assertEqual(parser.sec.keys, "image")
+        self.assertEqual(parser.sec["keys"], "image")
+        self.assertEqual(list(parser.sec._raw.keys()), ["keys"])
+
+    def test_ref_backed_proxy_write_through(self):
+        # Writes/deletes on a proxy reached via $@ref must update the real backing config
+        # node (i.e. "target"), not crash on the raw ref string (regression for the @ref
+        # write crash: parser.alias["x"] = ... raised ValueError before this fix).
+        parser = ConfigParser(config={"target": {"x": 1, "y": 2}, "alias": "$@target"}, globals={"monai": "monai"})
+        parser.alias["x"] = 99
+        # The change must be visible via both the backing id and a fresh alias proxy.
+        self.assertEqual(parser.get_parsed_content("target::x"), 99)
+        self.assertEqual(parser.alias["x"], 99)
+        del parser.alias["y"]
+        self.assertNotIn("y", parser.get_parsed_content("target"))
+
+    def test_chained_ref_backed_proxy_write_through(self):
+        # _backing_id() must follow the full ref chain, not just one hop.
+        parser = ConfigParser(
+            config={"target": {"x": 1, "y": 2}, "mid": "$@target", "alias": "$@mid"}, globals={"monai": "monai"}
+        )
+        parser.alias["x"] = 99
+        self.assertEqual(parser.get_parsed_content("target::x"), 99)
+        del parser.alias["y"]
+        self.assertNotIn("y", parser.get_parsed_content("target"))
+
+    def test_raw_is_read_only(self):
+        with self.assertRaises(AttributeError):
+            self.parser.A._raw = {"something": "else"}
+        with self.assertRaises(AttributeError):
+            del self.parser.A._raw
+
+    def test_missing_raises(self):
+        with self.assertRaises(IndexError):
+            _ = self.parser.A.B.D[5]
+        with self.assertRaises(KeyError):
+            _ = self.parser.A.B["nonexistent"]
+        with self.assertRaises(AttributeError):
+            _ = self.parser.A.nonexistent
+
+    def test_ref_backed_proxy_attribute_access(self):
+        # dot access must agree with bracket access on a proxy reached via a reference.
+        parser = ConfigParser(config={"target": {"x": 1, "y": 2}, "alias": "$@target"}, globals={"monai": "monai"})
+        self.assertEqual(parser.alias["x"], 1)
+        self.assertEqual(parser.alias.x, 1)
+
+    def test_unwritable_proxy_raises_type_error(self):
+        # a value computed from several references, or by a plain expression, has no
+        # single backing container: the write is rejected and the sources stay untouched.
+        parser = ConfigParser(
+            config={"a": [1], "b": [2], "both": "$@a + @b", "made": "$dict(x=1)"}, globals={"monai": "monai"}
+        )
+        with self.assertRaises(TypeError):
+            parser.both[0] = 5
+        with self.assertRaises(TypeError):
+            parser.made["x"] = 2
+        self.assertEqual(parser.get_parsed_content("a"), [1])
+
+    def test_uninitialised_proxy_does_not_recurse(self):
+        # an instance created without __init__ must raise AttributeError, not RecursionError.
+        proxy_cls = type(self.parser.A)
+        bare = proxy_cls.__new__(proxy_cls)
+        with self.assertRaises(AttributeError):
+            _ = bare.anything
+        self.assertFalse(hasattr(bare, "_value"))
+
+
 if __name__ == "__main__":
     unittest.main()