From 1ed0c9cee2215172b10dc4b6fb38d16051ad6acc Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Thu, 7 May 2026 11:11:10 +0200 Subject: [PATCH 01/16] basic support for AnnData accessors this should keep scanpy plotting working with MuData objects --- pyproject.toml | 2 +- src/mudata/_core/mudata.py | 26 ++++++++++++++++++++++++-- tests/test_obs_var.py | 13 +++++++++++++ 3 files changed, 38 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 364f3d9..ac9db94 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,7 +48,7 @@ dev = [ "pre-commit", "twine>=4.0.2", ] -test = [ "coverage>=7.10", "mudata[io]", "pytest" ] +test = [ "coverage>=7.10", "mudata[io]", "packaging", "pytest" ] doc = [ "docutils>=0.8,!=0.18.*,!=0.19.*", "ipykernel", diff --git a/src/mudata/_core/mudata.py b/src/mudata/_core/mudata.py index d5fd1f6..e31d0dc 100644 --- a/src/mudata/_core/mudata.py +++ b/src/mudata/_core/mudata.py @@ -533,8 +533,30 @@ def strings_to_categoricals(self, df: pd.DataFrame | None = None) -> pd.DataFram def __getitem__(self, index) -> AnnData | MuData: if isinstance(index, str): return self._mod[index] - else: - return MuData(self, as_view=True, index=index) + + with suppress(ImportError): + from anndata.acc import AdRef + + if isinstance(index, AdRef): + try: + return index.acc.get(self, index.idx) + except KeyError as e: + if index.acc.dim in ("obs", "var"): + for modname, mod in self._mod.items(): + try: + index.acc.get(mod, index.idx) + except KeyError: + pass + else: + raise KeyError( + f"There is no key {index.idx} in MuData .{index.acc.dim} but there is one in {modname} .{index.acc.dim}. Consider running `pull_{index.acc.dim}()` to update global .{index.acc.dim}." + ) from e + raise KeyError( + f"There is no key {index.idx} in MuData .{index.acc.dim} or in .{index.acc.dim} of any modalities." 
+ ) from e + else: + raise + return MuData(self, as_view=True, index=index) @property def mod(self) -> Mapping[str, AnnData | MuData]: diff --git a/tests/test_obs_var.py b/tests/test_obs_var.py index 3057c50..40a97e4 100644 --- a/tests/test_obs_var.py +++ b/tests/test_obs_var.py @@ -1,8 +1,10 @@ from pathlib import Path +import anndata as ad import numpy as np import pandas as pd import pytest +from packaging.version import Version import mudata as md @@ -145,3 +147,14 @@ def test_names_make_unique(mdata: md.MuData): with pytest.raises(TypeError, match="axis="): getattr(mdata, f"{attr}_names_make_unique")() + + +@pytest.mark.skipif( + Version(ad.__version__) < Version("0.13dev0"), reason="anndata version too old, no accessor support" +) +def test_accessors(mdata: md.MuData): + assert (mdata[ad.acc.A.obs["arange"]] == mdata.obs["arange"]).all() + with pytest.raises(KeyError, match="any modalities"): + mdata[ad.acc.A.var["test"]] + with pytest.raises(KeyError, match="there is one in"): + mdata[ad.acc.A.var["assert-bool"]] From a079a29946ee8376d84f9d002c018d85eb7a2fe3 Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Fri, 8 May 2026 09:02:18 +0200 Subject: [PATCH 02/16] simplify --- src/mudata/_core/mudata.py | 12 ++---------- tests/test_obs_var.py | 2 +- 2 files changed, 3 insertions(+), 11 deletions(-) diff --git a/src/mudata/_core/mudata.py b/src/mudata/_core/mudata.py index e31d0dc..e556e5c 100644 --- a/src/mudata/_core/mudata.py +++ b/src/mudata/_core/mudata.py @@ -543,19 +543,11 @@ def __getitem__(self, index) -> AnnData | MuData: except KeyError as e: if index.acc.dim in ("obs", "var"): for modname, mod in self._mod.items(): - try: - index.acc.get(mod, index.idx) - except KeyError: - pass - else: + if index in mod: raise KeyError( f"There is no key {index.idx} in MuData .{index.acc.dim} but there is one in {modname} .{index.acc.dim}. Consider running `pull_{index.acc.dim}()` to update global .{index.acc.dim}." 
) from e - raise KeyError( - f"There is no key {index.idx} in MuData .{index.acc.dim} or in .{index.acc.dim} of any modalities." - ) from e - else: - raise + raise return MuData(self, as_view=True, index=index) @property diff --git a/tests/test_obs_var.py b/tests/test_obs_var.py index 40a97e4..451148a 100644 --- a/tests/test_obs_var.py +++ b/tests/test_obs_var.py @@ -154,7 +154,7 @@ def test_names_make_unique(mdata: md.MuData): ) def test_accessors(mdata: md.MuData): assert (mdata[ad.acc.A.obs["arange"]] == mdata.obs["arange"]).all() - with pytest.raises(KeyError, match="any modalities"): + with pytest.raises(KeyError, match="test"): mdata[ad.acc.A.var["test"]] with pytest.raises(KeyError, match="there is one in"): mdata[ad.acc.A.var["assert-bool"]] From ac1fbd2449b36636e9d912489ec075ea5956546a Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Fri, 8 May 2026 09:45:11 +0200 Subject: [PATCH 03/16] implement __contains__ --- src/mudata/_core/mudata.py | 10 ++++++++++ tests/test_obs_var.py | 1 + tests/test_update.py | 3 +++ 3 files changed, 14 insertions(+) diff --git a/src/mudata/_core/mudata.py b/src/mudata/_core/mudata.py index e556e5c..5e25ad8 100644 --- a/src/mudata/_core/mudata.py +++ b/src/mudata/_core/mudata.py @@ -550,6 +550,16 @@ def __getitem__(self, index) -> AnnData | MuData: raise return MuData(self, as_view=True, index=index) + def __contains__(self, key) -> bool: + if isinstance(key, str): + return key in self._mod + with suppress(ImportError): + from anndata.acc import AdRef, MapAcc, RefAcc + + if isinstance(key, AdRef | RefAcc | MapAcc): + return AnnData.__contains__(self, key) + raise TypeError(f"Unexpected key {key!r}.") + @property def mod(self) -> Mapping[str, AnnData | MuData]: """Dictionary of modalities.""" diff --git a/tests/test_obs_var.py b/tests/test_obs_var.py index 451148a..4636b4f 100644 --- a/tests/test_obs_var.py +++ b/tests/test_obs_var.py @@ -153,6 +153,7 @@ def test_names_make_unique(mdata: md.MuData): Version(ad.__version__) < 
Version("0.13dev0"), reason="anndata version too old, no accessor support" ) def test_accessors(mdata: md.MuData): + assert ad.acc.A.obs["arange"] in mdata assert (mdata[ad.acc.A.obs["arange"]] == mdata.obs["arange"]).all() with pytest.raises(KeyError, match="test"): mdata[ad.acc.A.var["test"]] diff --git a/tests/test_update.py b/tests/test_update.py index 0a32eb0..318ff49 100644 --- a/tests/test_update.py +++ b/tests/test_update.py @@ -146,6 +146,9 @@ def test_update_simple(mdata: MuData, axis: Axis): for mod in mdata.mod.keys(): assert mdata.obsmap[mod].dtype.kind == "u" assert mdata.varmap[mod].dtype.kind == "u" + assert mod in mdata + with pytest.raises(TypeError): + 1 in mdata # noqa: B015 # names along non-axis are concatenated assert mdata.shape[1 - axis] == sum(mod.shape[1 - axis] for mod in mdata.mod.values()) From 703f0e132fb740ee71d1f4dffc73cf20e9b52899 Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Fri, 5 Jun 2026 13:17:54 +0200 Subject: [PATCH 04/16] add MuData accessors --- src/mudata/__init__.py | 5 ++ src/mudata/_core/mudata.py | 12 ++- src/mudata/acc/__init__.py | 172 +++++++++++++++++++++++++++++++++++++ tests/conftest.py | 6 +- tests/test_accessors.py | 84 ++++++++++++++++++ tests/test_obs_var.py | 14 --- 6 files changed, 275 insertions(+), 18 deletions(-) create mode 100644 src/mudata/acc/__init__.py create mode 100644 tests/test_accessors.py diff --git a/src/mudata/__init__.py b/src/mudata/__init__.py index 6f4b046..e8be205 100644 --- a/src/mudata/__init__.py +++ b/src/mudata/__init__.py @@ -1,5 +1,7 @@ """Multimodal datasets""" +from contextlib import suppress + from anndata import AnnData from scverse_misc import ExtensionNamespace from scverse_misc import make_register_namespace_decorator as _make_register_namespace_decorator @@ -23,6 +25,9 @@ from ._core.to_ import to_anndata, to_mudata from ._version import __version__, __version_tuple__ +with suppress(ImportError): + from . 
import acc + # file format versions __anndataversion__ = "0.1.0" __mudataversion__ = "0.1.0" diff --git a/src/mudata/_core/mudata.py b/src/mudata/_core/mudata.py index 5e25ad8..d5a4463 100644 --- a/src/mudata/_core/mudata.py +++ b/src/mudata/_core/mudata.py @@ -555,9 +555,17 @@ def __contains__(self, key) -> bool: return key in self._mod with suppress(ImportError): from anndata.acc import AdRef, MapAcc, RefAcc - - if isinstance(key, AdRef | RefAcc | MapAcc): + from ..acc import ModAcc, ModMapAcc, _ModalityMapAcc, _ModalityMixin + + if isinstance(key, ModAcc | _ModalityMapAcc): + return key.isin(self) + elif isinstance(key, _ModalityMixin): + return AnnData.__contains__(self.mod[key.mod], key) + elif isinstance(key, ModMapAcc): + return bool(self.mod) + elif isinstance(key, AdRef | RefAcc | MapAcc): return AnnData.__contains__(self, key) + raise TypeError(f"Unexpected key {key!r}.") @property diff --git a/src/mudata/acc/__init__.py b/src/mudata/acc/__init__.py new file mode 100644 index 0000000..e6ee63e --- /dev/null +++ b/src/mudata/acc/__init__.py @@ -0,0 +1,172 @@ +from __future__ import annotations + +from dataclasses import KW_ONLY, dataclass, field +from typing import TYPE_CHECKING + +import pandas as pd +from anndata.acc import ( + AdAcc, + AdRef, + GraphAcc, + GraphMapAcc, + Idx2D, + LayerAcc, + LayerMapAcc, + MapAcc, + MetaAcc, + MultiAcc, + MultiMapAcc, +) +from anndata.compat import XVariable +from anndata.typing import InMemoryArray + +if TYPE_CHECKING: + from anndata import AnnData + + from .. 
import MuData + + +@dataclass(frozen=True, kw_only=True) +class _ModalityMixin: + mod: str + + +@dataclass(frozen=True) +class _ModalityMapAcc[I, R](_ModalityMixin): + def isin(self, mdata: MuData, idx: I | None = None) -> bool: + if self.mod not in mdata.mod: + return False + else: + return super().isin(mdata[self.mod], idx) + + def get(self, mdata: MuData, idx: I, /) -> R: + return super().get(mdata[self.mod], idx) + + +@dataclass(frozen=True) +class ModLayerAcc[R: AdRef[Idx2D]](_ModalityMapAcc[Idx2D, InMemoryArray], LayerAcc[R]): + def __repr__(self) -> str: + return f"A.mod[{self.mod!r}].X" if self.k is None else f"A.mod[{self.mod}].layers[{self.k!r}]" + + +@dataclass(frozen=True, kw_only=True) +class ModLayerMapAcc[R: AdRef](_ModalityMixin, LayerMapAcc[R]): + ref_acc_cls: type[ModLayerAcc] = ModLayerAcc + + def __getitem__(self, k: str | None, /) -> ModLayerAcc[R]: + if not isinstance(k, str | None): + raise TypeError(f"Unsupported layer {k!r}") + return self.ref_acc_cls(mod=self.mod, k=k, ref_class=self.ref_class) + + def __repr__(self) -> str: + return f"A.mod[{self.mod!r}].layers" + + +@dataclass(frozen=True) +class ModMetaAcc[R: AdRef[str | None]](_ModalityMapAcc[str, pd.api.extensions.ExtensionArray | XVariable], MetaAcc[R]): + def __repr__(self) -> str: + return f"A.mod[{self.mod!r}].{self.dim}" + + +@dataclass(frozen=True) +class ModMultiAcc[R: AdRef[int]](_ModalityMapAcc[int, InMemoryArray], MultiAcc[R]): + def __repr__(self) -> str: + return f"A.mod[{self.mod!r}].{self.dim}m[self.k!r]" + + +@dataclass(frozen=True, kw_only=True) +class ModMultiMapAcc[R: AdRef](_ModalityMixin, MultiMapAcc[R]): + ref_acc_cls: type[ModMultiAcc] = ModMultiAcc + + def __getitem__(self, k: str, /) -> ModMultiAcc[R]: + if not isinstance(k, str): + raise TypeError(f"Unsupported {self.dim}m key {k!r}") + return self.ref_acc_cls(mod=self.mod, k=k, dim=self.dim, ref_class=self.ref_class) + + def __repr__(self) -> str: + return f"A.mod[{self.mod!r}].{self.dim}m" + + 
+@dataclass(frozen=True) +class ModGraphAcc[R: AdRef[Idx2D]](_ModalityMapAcc[Idx2D, InMemoryArray], GraphAcc[R]): + def __repr__(self) -> str: + return f"A.mod[{self.mod!r}].{self.dim}p[{self.k!r}]" + + +@dataclass(frozen=True, kw_only=True) +class ModGraphMapAcc[R: AdRef](_ModalityMixin, GraphMapAcc[R]): + ref_acc_cls: type[ModGraphAcc] = ModGraphAcc + + def __getitem__(self, k: str, /) -> ModGraphAcc[R]: + if not isinstance(k, str): + raise TypeError(f"Unsupported {self.dim}p key {k!r}") + return self.ref_acc_cls(mod=self.mod, k=k, dim=self.dim, ref_class=self.ref_class) + + def __repr__(self) -> str: + return f"A.mod[{self.mod!r}].{self.dim}p" + + +@dataclass(frozen=True, kw_only=True) +class ModAcc[R: AdRef](_ModalityMixin, AdAcc[R]): + layer_cls: type[ModLayerAcc] = ModLayerAcc + meta_cls: type[ModMetaAcc] = ModMetaAcc + multi_cls: type[ModMultiAcc] = ModMultiAcc + graph_cls: type[ModGraphAcc] = ModGraphAcc + + def isin(self, mdata: MuData) -> bool: + return self.mod in mdata.mod + + def get(self, mdata: MuData) -> ad.AnnData: + return mdata.mod[self.mod] + + def __post_init__(self) -> None: + x = self.layer_cls(mod=self.mod, k=None, ref_class=self.ref_class) + layers = ModLayerMapAcc(mod=self.mod, ref_class=self.ref_class, ref_acc_cls=self.layer_cls) + object.__setattr__(self, "X", x) + object.__setattr__(self, "layers", layers) + for dim in ("obs", "var"): + meta = self.meta_cls(mod=self.mod, dim=dim, ref_class=self.ref_class) + multi = ModMultiMapAcc(mod=self.mod, dim=dim, ref_class=self.ref_class, ref_acc_cls=self.multi_cls) + graphs = ModGraphMapAcc(mod=self.mod, dim=dim, ref_class=self.ref_class, ref_acc_cls=self.graph_cls) + object.__setattr__(self, dim, meta) + object.__setattr__(self, f"{dim}m", multi) + object.__setattr__(self, f"{dim}p", graphs) + + def __repr__(self) -> str: + return f"A.mod[{self.mod}]" + + +@dataclass(frozen=True) +class ModMapAcc[R: AdRef](MapAcc[ModAcc[R]]): + ref_class: type[R] + ref_acc_cls: type[ModAcc] = ModAcc + + def 
__getitem__(self, k: str, /) -> ModAcc[R]: + if not isinstance(k, str): + raise TypeError(f"Unsupported mod key {k!r}") + return self.ref_acc_cls(mod=k, ref_class=self.ref_class) + + def __repr__(self) -> str: + return "A.mod" + + +@dataclass(frozen=True) +class MuAcc[R: AdRef](AdAcc[R]): + mod_cls: type[ModAcc] = ModAcc + """Class to use for `mod` accessors.""" + + mod: ModMapAcc[R] = field(init=False) + + def __post_init__(self) -> None: + super().__post_init__() + mod = ModMapAcc(ref_class=self.ref_class, ref_acc_cls=self.mod_cls) + object.__setattr__(self, "mod", mod) + + def __getitem__(self, k: str, /) -> ModAcc[R]: + return self.mod[k] + + def __repr__(self) -> str: + return "A" + + +A = MuAcc() diff --git a/tests/conftest.py b/tests/conftest.py index d637007..9a4bc66 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -37,10 +37,12 @@ def rng() -> np.random.Generator: def mdata(rng: np.random.Generator, request: pytest.FixtureRequest) -> MuData: axis = getattr(request, "param", 0) mod1 = AnnData( - np.arange(0, 200, 0.1).reshape(-1, 20), obs=pd.DataFrame(index=rng.choice(150, size=100, replace=False)) + np.arange(0, 200, 0.1).reshape(-1, 20), + obs=pd.DataFrame(index=rng.choice(150, size=100, replace=False).astype(str)), ) mod2 = AnnData( - np.arange(101, 3101, 1).reshape(-1, 30), obs=pd.DataFrame(index=rng.choice(150, size=100, replace=False)) + np.arange(101, 3101, 1).reshape(-1, 30), + obs=pd.DataFrame(index=rng.choice(150, size=100, replace=False).astype(str)), ) mod1.var["assert-bool"] = True mod2.var["assert-bool"] = False diff --git a/tests/test_accessors.py b/tests/test_accessors.py new file mode 100644 index 0000000..cf5a4b3 --- /dev/null +++ b/tests/test_accessors.py @@ -0,0 +1,84 @@ +import anndata as ad +import numpy as np +import pandas as pd +import pytest +from packaging.version import Version + +import mudata as md +from mudata.acc import A + +if Version(ad.__version__) < Version("0.13dev0"): + pytest.skip("anndata version too old, no 
accessor support", allow_module_level=True) + + +@pytest.fixture +def mdata_augmented(mdata: md.MuData, rng: np.random.Generator): + mdata["mod1"].layers["counts"] = rng.poisson(1, size=mdata["mod1"].shape) + mdata["mod2"].obsp["test"] = rng.normal(size=(mdata["mod2"].n_obs, mdata["mod2"].n_obs)) + + return mdata + + +def test_anndata_accessors(mdata: md.MuData): + assert ad.acc.A.obs["arange"] in mdata + assert (mdata[ad.acc.A.obs["arange"]] == mdata.obs["arange"]).all() + with pytest.raises(KeyError, match="test"): + mdata[ad.acc.A.var["test"]] + with pytest.raises(KeyError, match="there is one in"): + mdata[ad.acc.A.var["assert-bool"]] + + +PATHS = [ + (A.mod["mod1"], lambda md: md.mod["mod1"]), + (A["mod1"], lambda md: md["mod1"]), + (A.mod["mod1"].var, lambda md: md.mod["mod1"].var), + (A.mod["mod1"].var["assert-bool"], lambda md: md.mod["mod1"].var["assert-bool"]), + (A.mod["mod1"].X, lambda md: md.mod["mod1"].X), + (A.mod["mod1"].X["obs_2", :], lambda md: md.mod["mod1"]["obs_2", :].X.squeeze()), + (A.mod["mod1"].X[:, "mod1_var_1"], lambda md: md.mod["mod1"][:, "mod1_var_1"].X.squeeze()), + (A["mod1"].layers, lambda md: md["mod1"].layers), + (A["mod1"].layers["counts"], lambda md: md["mod1"].layers["counts"]), + (A["mod1"].layers["counts"]["obs_2", :], lambda md: md["mod1"]["obs_2", :].layers["counts"].squeeze()), + (A["mod2"].obsp, lambda md: md["mod2"].obsp), + (A["mod2"].obsp["test"], lambda md: md["mod2"].obsp["test"]), + (A["mod2"].obsp["test"][:, "obs_3"], lambda md: md["mod2"].obsp["test"][:, md["mod2"].obs_names.get_loc("obs_3")]), +] + + +@pytest.mark.parametrize("acc", [path[0] for path in PATHS]) +def test_in(mdata_augmented: md.MuData, acc): + assert acc in mdata_augmented + + +@pytest.mark.parametrize( + "acc", + [ + A.mod["mod3"], + A["mod3"], + A.mod["mod3"].var, + A.mod["mod1"].var["does_not_exist"], + A.mod["mod3"].X, + A.mod["mod3"].X["obs_2", :], + A.mod["mod3"].X[:, "mod1_var_1"], + A["mod2"].layers, + A["mod1"].layers["does_not_exist"], + 
A["mod1"].layers["does_not_exist"]["obs_2", :], + A["mod1"].obsp, + A["mod2"].obsp["does_not_exist"], + A["mod2"].obsp["does_not_exist"][:, "obs_3"], + ], +) +def test_not_in(mdata: md.MuData, acc): + assert acc not in mdata + + +@pytest.mark.parametrize("acc_expected", [path for path in PATHS if isinstance(path[0], ad.acc.AdRef)]) +def test_get(mdata_augmented: md.MuData, acc_expected): + acc, expected = acc_expected + + val = mdata_augmented[acc] + expected = expected(mdata_augmented) + if isinstance(expected, pd.DataFrame | pd.Series | np.ndarray): + assert np.all(val == expected) + else: + assert val == expected diff --git a/tests/test_obs_var.py b/tests/test_obs_var.py index 4636b4f..3057c50 100644 --- a/tests/test_obs_var.py +++ b/tests/test_obs_var.py @@ -1,10 +1,8 @@ from pathlib import Path -import anndata as ad import numpy as np import pandas as pd import pytest -from packaging.version import Version import mudata as md @@ -147,15 +145,3 @@ def test_names_make_unique(mdata: md.MuData): with pytest.raises(TypeError, match="axis="): getattr(mdata, f"{attr}_names_make_unique")() - - -@pytest.mark.skipif( - Version(ad.__version__) < Version("0.13dev0"), reason="anndata version too old, no accessor support" -) -def test_accessors(mdata: md.MuData): - assert ad.acc.A.obs["arange"] in mdata - assert (mdata[ad.acc.A.obs["arange"]] == mdata.obs["arange"]).all() - with pytest.raises(KeyError, match="test"): - mdata[ad.acc.A.var["test"]] - with pytest.raises(KeyError, match="there is one in"): - mdata[ad.acc.A.var["assert-bool"]] From 752ce5722fc08a01312673a7dffbcdb179fec7de Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Fri, 5 Jun 2026 14:22:58 +0200 Subject: [PATCH 05/16] prevent X and layers accessors for MuData objects --- src/mudata/acc/__init__.py | 5 +++++ tests/test_accessors.py | 7 +++++++ 2 files changed, 12 insertions(+) diff --git a/src/mudata/acc/__init__.py b/src/mudata/acc/__init__.py index e6ee63e..5297af9 100644 --- a/src/mudata/acc/__init__.py +++ 
b/src/mudata/acc/__init__.py @@ -168,5 +168,10 @@ def __getitem__(self, k: str, /) -> ModAcc[R]: def __repr__(self) -> str: return "A" + def __getattribute__(self, name: str): + if name in ("X", "layers"): + raise AttributeError(f"{self.__class__.__name__!r} object has no attribute {name!r}") + return super().__getattribute__(name) + A = MuAcc() diff --git a/tests/test_accessors.py b/tests/test_accessors.py index cf5a4b3..71600b0 100644 --- a/tests/test_accessors.py +++ b/tests/test_accessors.py @@ -82,3 +82,10 @@ def test_get(mdata_augmented: md.MuData, acc_expected): assert np.all(val == expected) else: assert val == expected + + +def test_no_data(): + with pytest.raises(AttributeError): + A.X # noqa: B018 + with pytest.raises(AttributeError): + A.layers # noqa: B018 From 4d30e5b9ebcf33c6471814a0b7a0e7f68ea35085 Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Fri, 5 Jun 2026 14:57:33 +0200 Subject: [PATCH 06/16] add obsmap/varmap accessors --- src/mudata/_core/mudata.py | 6 +++--- src/mudata/acc/__init__.py | 43 +++++++++++++++++++++++++++++++------- tests/test_accessors.py | 6 ++++++ 3 files changed, 44 insertions(+), 11 deletions(-) diff --git a/src/mudata/_core/mudata.py b/src/mudata/_core/mudata.py index d5a4463..1d2a4cf 100644 --- a/src/mudata/_core/mudata.py +++ b/src/mudata/_core/mudata.py @@ -555,13 +555,13 @@ def __contains__(self, key) -> bool: return key in self._mod with suppress(ImportError): from anndata.acc import AdRef, MapAcc, RefAcc - from ..acc import ModAcc, ModMapAcc, _ModalityMapAcc, _ModalityMixin + from ..acc import ModAcc, MultiModAcc, _ModalityMapAcc, _ModalityMixin if isinstance(key, ModAcc | _ModalityMapAcc): return key.isin(self) elif isinstance(key, _ModalityMixin): - return AnnData.__contains__(self.mod[key.mod], key) - elif isinstance(key, ModMapAcc): + return key in self.mod[key.mod] + elif isinstance(key, MultiModAcc): return bool(self.mod) elif isinstance(key, AdRef | RefAcc | MapAcc): return AnnData.__contains__(self, key) diff 
--git a/src/mudata/acc/__init__.py b/src/mudata/acc/__init__.py index 5297af9..0a06f66 100644 --- a/src/mudata/acc/__init__.py +++ b/src/mudata/acc/__init__.py @@ -1,12 +1,13 @@ from __future__ import annotations from dataclasses import KW_ONLY, dataclass, field -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import pandas as pd from anndata.acc import ( AdAcc, AdRef, + Axes, GraphAcc, GraphMapAcc, Idx2D, @@ -16,6 +17,7 @@ MetaAcc, MultiAcc, MultiMapAcc, + RefAcc, ) from anndata.compat import XVariable from anndata.typing import InMemoryArray @@ -49,7 +51,7 @@ def __repr__(self) -> str: return f"A.mod[{self.mod!r}].X" if self.k is None else f"A.mod[{self.mod}].layers[{self.k!r}]" -@dataclass(frozen=True, kw_only=True) +@dataclass(frozen=True) class ModLayerMapAcc[R: AdRef](_ModalityMixin, LayerMapAcc[R]): ref_acc_cls: type[ModLayerAcc] = ModLayerAcc @@ -74,7 +76,7 @@ def __repr__(self) -> str: return f"A.mod[{self.mod!r}].{self.dim}m[self.k!r]" -@dataclass(frozen=True, kw_only=True) +@dataclass(frozen=True) class ModMultiMapAcc[R: AdRef](_ModalityMixin, MultiMapAcc[R]): ref_acc_cls: type[ModMultiAcc] = ModMultiAcc @@ -93,7 +95,7 @@ def __repr__(self) -> str: return f"A.mod[{self.mod!r}].{self.dim}p[{self.k!r}]" -@dataclass(frozen=True, kw_only=True) +@dataclass(frozen=True) class ModGraphMapAcc[R: AdRef](_ModalityMixin, GraphMapAcc[R]): ref_acc_cls: type[ModGraphAcc] = ModGraphAcc @@ -106,6 +108,28 @@ def __repr__(self) -> str: return f"A.mod[{self.mod!r}].{self.dim}p" +@dataclass(frozen=True) +class ModMapAcc[R: AdRef[str]](RefAcc[R, str]): + dim: Literal[obs, var] + + def dims(self, idx: Any, /) -> Axes: + return (self.dim,) + + def __repr__(self) -> str: + return f"A.{self.dim}map" + + def idx_repr(self, idx: str, /) -> str: + return f"[{idx}]" + + def isin(self, mdata: MuData, idx: str | None = None) -> bool: + m = getattr(mdata, f"{self.dim}map") + return idx is None or idx in m + + def get(self, mdata: MuData, idx: str, /) -> 
InMemoryArray: + m = getattr(mdata, f"{self.dim}map") + return m[idx] + + @dataclass(frozen=True, kw_only=True) class ModAcc[R: AdRef](_ModalityMixin, AdAcc[R]): layer_cls: type[ModLayerAcc] = ModLayerAcc @@ -137,7 +161,7 @@ def __repr__(self) -> str: @dataclass(frozen=True) -class ModMapAcc[R: AdRef](MapAcc[ModAcc[R]]): +class MultiModAcc[R: AdRef](MapAcc[ModAcc[R]]): ref_class: type[R] ref_acc_cls: type[ModAcc] = ModAcc @@ -155,12 +179,15 @@ class MuAcc[R: AdRef](AdAcc[R]): mod_cls: type[ModAcc] = ModAcc """Class to use for `mod` accessors.""" - mod: ModMapAcc[R] = field(init=False) + mod: MultiModAcc[R] = field(init=False) + obsmap: ModMapAcc[R] = field(init=False) + varmap: ModMapAcc[R] = field(init=False) def __post_init__(self) -> None: super().__post_init__() - mod = ModMapAcc(ref_class=self.ref_class, ref_acc_cls=self.mod_cls) - object.__setattr__(self, "mod", mod) + object.__setattr__(self, "mod", MultiModAcc(ref_class=self.ref_class, ref_acc_cls=self.mod_cls)) + object.__setattr__(self, "obsmap", ModMapAcc("obs", ref_class=self.ref_class)) + object.__setattr__(self, "varmap", ModMapAcc("var", ref_class=self.ref_class)) def __getitem__(self, k: str, /) -> ModAcc[R]: return self.mod[k] diff --git a/tests/test_accessors.py b/tests/test_accessors.py index 71600b0..a5ce402 100644 --- a/tests/test_accessors.py +++ b/tests/test_accessors.py @@ -42,6 +42,10 @@ def test_anndata_accessors(mdata: md.MuData): (A["mod2"].obsp, lambda md: md["mod2"].obsp), (A["mod2"].obsp["test"], lambda md: md["mod2"].obsp["test"]), (A["mod2"].obsp["test"][:, "obs_3"], lambda md: md["mod2"].obsp["test"][:, md["mod2"].obs_names.get_loc("obs_3")]), + (A.obsmap, lambda md: md.obsmap), + (A.varmap, lambda md: md.varmap), + (A.obsmap["mod1"], lambda md: md.obsmap["mod1"]), + (A.varmap["mod2"], lambda md: md.varmap["mod2"]), ] @@ -66,6 +70,8 @@ def test_in(mdata_augmented: md.MuData, acc): A["mod1"].obsp, A["mod2"].obsp["does_not_exist"], A["mod2"].obsp["does_not_exist"][:, "obs_3"], + 
A.obsmap["mod3"], + A.varmap["mod3"], ], ) def test_not_in(mdata: md.MuData, acc): From d8c023c610db0bd7496d05756b93e2e5c0505995 Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Fri, 5 Jun 2026 15:09:01 +0200 Subject: [PATCH 07/16] fixup! prevent X and layers accessors for MuData objects --- src/mudata/acc/__init__.py | 10 +++++----- tests/test_accessors.py | 5 +++++ 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/src/mudata/acc/__init__.py b/src/mudata/acc/__init__.py index 0a06f66..8546a8c 100644 --- a/src/mudata/acc/__init__.py +++ b/src/mudata/acc/__init__.py @@ -189,16 +189,16 @@ def __post_init__(self) -> None: object.__setattr__(self, "obsmap", ModMapAcc("obs", ref_class=self.ref_class)) object.__setattr__(self, "varmap", ModMapAcc("var", ref_class=self.ref_class)) + del self.__dict__["X"] + del self.__dict__["layers"] + del self.__dataclass_fields__["X"] + del self.__dataclass_fields__["layers"] + def __getitem__(self, k: str, /) -> ModAcc[R]: return self.mod[k] def __repr__(self) -> str: return "A" - def __getattribute__(self, name: str): - if name in ("X", "layers"): - raise AttributeError(f"{self.__class__.__name__!r} object has no attribute {name!r}") - return super().__getattribute__(name) - A = MuAcc() diff --git a/tests/test_accessors.py b/tests/test_accessors.py index a5ce402..745b469 100644 --- a/tests/test_accessors.py +++ b/tests/test_accessors.py @@ -1,3 +1,5 @@ +from dataclasses import fields + import anndata as ad import numpy as np import pandas as pd @@ -95,3 +97,6 @@ def test_no_data(): A.X # noqa: B018 with pytest.raises(AttributeError): A.layers # noqa: B018 + + for field in fields(A): + assert field.name not in ("X", "layers") From 19d25c31a8ba1c951bdb555c2c2fbabbc91ef0ef Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Fri, 5 Jun 2026 16:31:31 +0200 Subject: [PATCH 08/16] add docs --- docs/accessors.md | 34 ++++++++++++++++++++++++++ docs/api.md | 8 ++++++ docs/conf.py | 2 +- src/mudata/acc/__init__.py | 50 
+++++++++++++++++++++++++++++++++++--- 4 files changed, 89 insertions(+), 5 deletions(-) create mode 100644 docs/accessors.md diff --git a/docs/accessors.md b/docs/accessors.md new file mode 100644 index 0000000..d711789 --- /dev/null +++ b/docs/accessors.md @@ -0,0 +1,34 @@ +# Accessors and paths + +```{eval-rst} +.. module:: mudata.acc +``` + +[](#mudata.acc) provides [accessors](inv:anndata:*:term#accessor) that create [references](inv:anndata:*:term#reference) to axis-aligned 1D and 2D arrays in [MuData](#mudata.MuData) objects. +See the corresponding [AnnData documentation](inv:anndata:*:doc#accessors). + +:::{important} +This functionality requires AnnData 0.13 or later. +::: + +The central [accessor](inv:anndata:*:term#accessor) is [](#A). +```{eval-rst} +.. autodata:: A +``` +See [](#MuAcc) and [AdAcc](#anndata.acc.AdAcc) for examples of how to use it to create [references](inv:anndata:*:term#reference) (i.e. [AdRefs](#anndata.acc.AdRef)). + +```{eval-rst} +.. autosummary:: + :toctree: generated + + MuAcc + MultiModAcc + ModAcc + ModMapAcc + ModMetaAcc + ModLayerAcc + ModGraphAcc + ModMultiAcc + ModMultiMapAcc + ModGraphMapAcc +``` diff --git a/docs/api.md b/docs/api.md index 22a9483..688701a 100644 --- a/docs/api.md +++ b/docs/api.md @@ -35,6 +35,14 @@ write_zarr ``` +## Accessors +```{eval-rst} +.. toctree:: + :hidden: + + mudata.acc +``` + ## Extensions ```{eval-rst} .. 
autosummary:: diff --git a/docs/conf.py b/docs/conf.py index 96d0e75..929a36c 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -96,7 +96,7 @@ intersphinx_mapping = { "python": ("https://docs.python.org/3", None), - "anndata": ("https://anndata.readthedocs.io/en/stable/", None), + "anndata": ("https://anndata.readthedocs.io/en/latest/", None), "scanpy": ("https://scanpy.readthedocs.io/en/stable/", None), "numpy": ("https://numpy.org/doc/stable/", None), "pandas": ("https://pandas.pydata.org/docs/", None), diff --git a/src/mudata/acc/__init__.py b/src/mudata/acc/__init__.py index 8546a8c..06179f1 100644 --- a/src/mudata/acc/__init__.py +++ b/src/mudata/acc/__init__.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import KW_ONLY, dataclass, field -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Literal import pandas as pd from anndata.acc import ( @@ -31,6 +31,7 @@ @dataclass(frozen=True, kw_only=True) class _ModalityMixin: mod: str + """Modality this accessor refers to.""" @dataclass(frozen=True) @@ -47,12 +48,16 @@ def get(self, mdata: MuData, idx: I, /) -> R: @dataclass(frozen=True) class ModLayerAcc[R: AdRef[Idx2D]](_ModalityMapAcc[Idx2D, InMemoryArray], LayerAcc[R]): + """Reference accessor for arrays in :attr:`~anndata.acc.AdAcc.layers`.""" + def __repr__(self) -> str: return f"A.mod[{self.mod!r}].X" if self.k is None else f"A.mod[{self.mod}].layers[{self.k!r}]" @dataclass(frozen=True) class ModLayerMapAcc[R: AdRef](_ModalityMixin, LayerMapAcc[R]): + """Accessor for arrays in :attr:~anndata.acc.AdAcc.layers`.""" + ref_acc_cls: type[ModLayerAcc] = ModLayerAcc def __getitem__(self, k: str | None, /) -> ModLayerAcc[R]: @@ -66,18 +71,24 @@ def __repr__(self) -> str: @dataclass(frozen=True) class ModMetaAcc[R: AdRef[str | None]](_ModalityMapAcc[str, pd.api.extensions.ExtensionArray | XVariable], MetaAcc[R]): + """Reference accessor for arrays from metadata containers (:attr:`~anndata.acc.AdAcc.obs` / 
:attr:`~anndata.acc.AdAcc.var`).""" + def __repr__(self) -> str: return f"A.mod[{self.mod!r}].{self.dim}" @dataclass(frozen=True) class ModMultiAcc[R: AdRef[int]](_ModalityMapAcc[int, InMemoryArray], MultiAcc[R]): + """Reference accessor for arrays from multi-dimensional containers (:attr:`~anndata.acc.AdAcc.obsm` / :attr:`~anndata.acc.AdAcc.varm`).""" + def __repr__(self) -> str: return f"A.mod[{self.mod!r}].{self.dim}m[self.k!r]" @dataclass(frozen=True) class ModMultiMapAcc[R: AdRef](_ModalityMixin, MultiMapAcc[R]): + """Accessor for multi-dimensional array containers (:attr:`~anndata.acc.AdAcc.obsm` / :attr:`~anndata.acc.AdAcc.varm`).""" + ref_acc_cls: type[ModMultiAcc] = ModMultiAcc def __getitem__(self, k: str, /) -> ModMultiAcc[R]: @@ -91,12 +102,16 @@ def __repr__(self) -> str: @dataclass(frozen=True) class ModGraphAcc[R: AdRef[Idx2D]](_ModalityMapAcc[Idx2D, InMemoryArray], GraphAcc[R]): + """Reference accessor for arrays from graph containers (:attr:`~anndata.acc.AdAcc.obsp` / :attr:`~anndata.acc.AdAcc.varp`).""" + def __repr__(self) -> str: return f"A.mod[{self.mod!r}].{self.dim}p[{self.k!r}]" @dataclass(frozen=True) class ModGraphMapAcc[R: AdRef](_ModalityMixin, GraphMapAcc[R]): + """Accessor for graph containers (:attr:`~anndata.acc.AdAcc.obsp` / :attr:`~anndata.acc.AdAcc.varp`)""" + ref_acc_cls: type[ModGraphAcc] = ModGraphAcc def __getitem__(self, k: str, /) -> ModGraphAcc[R]: @@ -110,37 +125,55 @@ def __repr__(self) -> str: @dataclass(frozen=True) class ModMapAcc[R: AdRef[str]](RefAcc[R, str]): - dim: Literal[obs, var] + """Reference accessor for modality maps (:attr:`~MuAcc.obsmap` / :attr:`~MuAcc.varmap`).""" + + dim: Literal["obs", "var"] + """Axis this accessor refers to, e.g. 
`A.obsmap.dim == "obs"`.""" def dims(self, idx: Any, /) -> Axes: + """Get which dimension this array refers to.""" return (self.dim,) def __repr__(self) -> str: return f"A.{self.dim}map" def idx_repr(self, idx: str, /) -> str: + """Get a string representation of the index.""" return f"[{idx}]" def isin(self, mdata: MuData, idx: str | None = None) -> bool: + """Check if the referenced array is in the :class:`~mudata.MuData` object.""" m = getattr(mdata, f"{self.dim}map") return idx is None or idx in m def get(self, mdata: MuData, idx: str, /) -> InMemoryArray: + """Get the referenced array from the :class:`~mudata.MuData` object.""" m = getattr(mdata, f"{self.dim}map") return m[idx] @dataclass(frozen=True, kw_only=True) class ModAcc[R: AdRef](_ModalityMixin, AdAcc[R]): + """Accessor to create :class:`AdRefs <anndata.acc.AdRef>` (:data:`A`) for modalities (:attr:`~MuAcc.mod`).""" + layer_cls: type[ModLayerAcc] = ModLayerAcc + """Class to use for `layers` accessors.""" + meta_cls: type[ModMetaAcc] = ModMetaAcc + """Class to use for `obs`/`var` accessors.""" + multi_cls: type[ModMultiAcc] = ModMultiAcc + """Class to use for `obsm`/`varm` accessors.""" + graph_cls: type[ModGraphAcc] = ModGraphAcc + """Class to use for `obsp`/`varp` accessors.""" def isin(self, mdata: MuData) -> bool: + """Check if the referenced modality is in the :class:`~mudata.MuData` object.""" return self.mod in mdata.mod - def get(self, mdata: MuData) -> ad.AnnData: + def get(self, mdata: MuData) -> AnnData: + """Get the referenced modality from the :class:`~mudata.MuData` object.""" return mdata.mod[self.mod] def __post_init__(self) -> None: @@ -162,6 +195,8 @@ def __repr__(self) -> str: @dataclass(frozen=True) class MultiModAcc[R: AdRef](MapAcc[ModAcc[R]]): + """Accessor for modalities (:attr:`~MuAcc.mod`).""" + ref_class: type[R] ref_acc_cls: type[ModAcc] = ModAcc @@ -176,12 +211,19 @@ def __repr__(self) -> str: @dataclass(frozen=True) class MuAcc[R: AdRef](AdAcc[R]): + """Accessor to create :class:`AdRefs <anndata.acc.AdRef>` 
(:data:`A`).""" + mod_cls: type[ModAcc] = ModAcc """Class to use for `mod` accessors.""" mod: MultiModAcc[R] = field(init=False) + """Access modalities.""" + obsmap: ModMapAcc[R] = field(init=False) + """Access mappings of observation indices in the MuData to indices in individual modalities.""" + varmap: ModMapAcc[R] = field(init=False) + """Access mappings of variable indices in the MuData to indices in individual modalities.""" def __post_init__(self) -> None: super().__post_init__() @@ -201,4 +243,4 @@ def __repr__(self) -> str: return "A" -A = MuAcc() +A: MuAcc[AdRef] = MuAcc() From bad1dcf8d5a392f44edcfa03950c4788912d7e8e Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Fri, 5 Jun 2026 17:11:17 +0200 Subject: [PATCH 09/16] implement resolve() --- src/mudata/acc/__init__.py | 32 +++++++++++++++++++++++++++++++- tests/test_accessors.py | 14 ++++++++++++++ 2 files changed, 45 insertions(+), 1 deletion(-) diff --git a/src/mudata/acc/__init__.py b/src/mudata/acc/__init__.py index 06179f1..f4cb1bf 100644 --- a/src/mudata/acc/__init__.py +++ b/src/mudata/acc/__init__.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import KW_ONLY, dataclass, field -from typing import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING, Any, ClassVar, Literal import pandas as pd from anndata.acc import ( @@ -225,6 +225,8 @@ class MuAcc[R: AdRef](AdAcc[R]): varmap: ModMapAcc[R] = field(init=False) """Access mappings of variable indices in the MuData to indices in individual modalities.""" + ATTRS: ClassVar = frozenset(("mod", "obs", "var", "obsm", "varm", "obsp", "varp", "obsmap", "varmap")) + def __post_init__(self) -> None: super().__post_init__() object.__setattr__(self, "mod", MultiModAcc(ref_class=self.ref_class, ref_acc_cls=self.mod_cls)) @@ -242,5 +244,33 @@ def __getitem__(self, k: str, /) -> ModAcc[R]: def __repr__(self) -> str: return "A" + def resolve(self, spec: str, *, strict: bool = True) -> R | None: + """Create :class:`~anndata.acc.AdRef` 
from a simplified string.""" + if not strict: + try: + self.resolve(spec) + except ValueError: + return None + + firstdot = spec.find(".") + if firstdot < 0: + raise ValueError(f"Cannot parse accessor {spec!r} that is not period-separated.") + firstattr = spec[:firstdot] + match firstattr: + case "mod": + modend = spec.find(".", firstdot + 1) + mod = spec[firstdot + 1 : modend] + if not mod: + raise ValueError(f"Cannot parse accessor{spec!r} that has an empty modality.") + acc = self.mod[mod] + return super().resolve.__func__(acc, spec[modend + 1 :], strict=strict) + case "obsmap" | "varmap": + if firstdot == len(spec): + raise ValueError(f"Cannot parse accessor{spec!r} that has an empty modality.") + mod = spec[firstdot + 1 :] + return getattr(self, firstattr)[mod] + case _: + return super().resolve(spec, strict=strict) + A: MuAcc[AdRef] = MuAcc() diff --git a/tests/test_accessors.py b/tests/test_accessors.py index 745b469..d6e2405 100644 --- a/tests/test_accessors.py +++ b/tests/test_accessors.py @@ -100,3 +100,17 @@ def test_no_data(): for field in fields(A): assert field.name not in ("X", "layers") + + +def test_resolve(): + assert A.resolve("mod.rna.X[:, ACT1]") == A.mod["rna"].X[:, "ACT1"] + assert A.resolve("obsmap.rna") == A.obsmap["rna"] + + with pytest.raises(ValueError, match="Unknown accessor"): + A.resolve("rna.X[:, :]") + + with pytest.raises(ValueError, match="empty modality"): + A.resolve("mod..X[:, :]") + + with pytest.raises(ValueError, match="period-separated"): + A.resolve("abcd") From a62f99845f287ffc9211ba5d399ea3f530dfa6c0 Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Mon, 8 Jun 2026 11:22:44 +0200 Subject: [PATCH 10/16] fix docs skip over X and layers attributes in the autosummary template, this is the only way to keep them out of MuAcc docs --- .../_templates/autosummary/class-accessor.rst | 26 +++++++++++++++++++ docs/accessors.md | 3 ++- docs/conf.py | 7 +++-- docs/extensions/skip_private_bases.py | 17 ++++++++++++ pyproject.toml | 2 +- 
src/mudata/acc/__init__.py | 17 ++++++++---- 6 files changed, 63 insertions(+), 9 deletions(-) create mode 100644 docs/_templates/autosummary/class-accessor.rst create mode 100644 docs/extensions/skip_private_bases.py diff --git a/docs/_templates/autosummary/class-accessor.rst b/docs/_templates/autosummary/class-accessor.rst new file mode 100644 index 0000000..cf3deb0 --- /dev/null +++ b/docs/_templates/autosummary/class-accessor.rst @@ -0,0 +1,26 @@ +{{ fullname | escape | underline}} + +{% set attributes = attributes | select("ne", "X") | select("ne", "layers") %} + +.. currentmodule:: {{ module }} + +.. autoclass:: {{ objname }} + :show-inheritance: + + {% block attributes %} + {%- for item in attributes %} + {%- if loop.first %} + .. rubric:: Attributes + {% endif %} + .. autoattribute:: {{ item }} + {%- endfor %} + {% endblock %} + + {% block methods %} + {%- for item in methods if item != "__init__" and item not in inherited_members %} + {%- if loop.first %} + .. rubric:: Methods + {% endif %} + .. automethod:: {{ item }} + {%- endfor %} + {% endblock %} diff --git a/docs/accessors.md b/docs/accessors.md index d711789..45ec21a 100644 --- a/docs/accessors.md +++ b/docs/accessors.md @@ -8,7 +8,7 @@ See the corresponding [AnnData documentation](inv:anndata:*:doc#accessors). :::{important} -This functionality requires AnnData 0.13 or later. +This functionality requires AnnData 0.13 or newer. ::: The central [accessor](inv:anndata:*:term#accessor) is [](#A). @@ -20,6 +20,7 @@ See [](#MuAcc) and [AdAcc](#anndata.acc.AdAcc) for examples of how to use it to ```{eval-rst} .. 
autosummary:: :toctree: generated + :template: class-accessor MuAcc MultiModAcc diff --git a/docs/conf.py b/docs/conf.py index 929a36c..03d5fb5 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -133,8 +133,11 @@ pygments_style = "default" katex_prerender = shutil.which(katex.NODEJS_BINARY) is not None -nitpick_ignore = [ +nitpick_ignore = ( # If building the documentation fails because of a missing link that is outside your control, # you can add an exception to this list. # ("py:class", "igraph.Graph"), -] + ("py:obj", "typing.R"), + ("py:obj", "anndata.acc.Axes"), + ("py:class", "AdRef"), +) diff --git a/docs/extensions/skip_private_bases.py b/docs/extensions/skip_private_bases.py new file mode 100644 index 0000000..eb4a07c --- /dev/null +++ b/docs/extensions/skip_private_bases.py @@ -0,0 +1,17 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Generic, get_origin + +from sphinx.util.typing import ExtensionMetadata + +if TYPE_CHECKING: + from sphinx.application import Sphinx + + +def skip_private_bases(app: Sphinx, name: str, obj: type, _unused, bases: list[type]) -> None: + bases[:] = [b for b in bases if b is not object if get_origin(b) is not Generic if not b.__name__.startswith("_")] + + +def setup(app: Sphinx) -> ExtensionMetadata: + app.connect("autodoc-process-bases", skip_private_bases) + return ExtensionMetadata(parallel_read_safe=True) diff --git a/pyproject.toml b/pyproject.toml index ac9db94..277d746 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -123,7 +123,7 @@ lint.ignore = [ "TID252", # allow relative imports ] lint.per-file-ignores."*/__init__.py" = [ "F401" ] -lint.per-file-ignores."docs/*" = [ "I" ] +lint.per-file-ignores."docs/*" = [ "D", "I" ] lint.per-file-ignores."docs/notebooks/*" = [ "D", "F403", "F405" ] lint.per-file-ignores."tests/*" = [ "D" ] lint.pydocstyle.convention = "numpy" diff --git a/src/mudata/acc/__init__.py b/src/mudata/acc/__init__.py index f4cb1bf..ad4fc18 100644 --- 
a/src/mudata/acc/__init__.py +++ b/src/mudata/acc/__init__.py @@ -1,5 +1,6 @@ from __future__ import annotations +from collections.abc import Hashable from dataclasses import KW_ONLY, dataclass, field from typing import TYPE_CHECKING, Any, ClassVar, Literal @@ -70,7 +71,7 @@ def __repr__(self) -> str: @dataclass(frozen=True) -class ModMetaAcc[R: AdRef[str | None]](_ModalityMapAcc[str, pd.api.extensions.ExtensionArray | XVariable], MetaAcc[R]): +class ModMetaAcc[R: AdRef[str]](_ModalityMapAcc[str, pd.api.extensions.ExtensionArray | XVariable], MetaAcc[R]): """Reference accessor for arrays from metadata containers (:attr:`~anndata.acc.AdAcc.obs` / :attr:`~anndata.acc.AdAcc.var`).""" def __repr__(self) -> str: @@ -193,8 +194,8 @@ def __repr__(self) -> str: return f"A.mod[{self.mod}]" -@dataclass(frozen=True) -class MultiModAcc[R: AdRef](MapAcc[ModAcc[R]]): +@dataclass(frozen=True, kw_only=True) +class MultiModAcc[R: AdRef](MapAcc[ModAcc]): """Accessor for modalities (:attr:`~MuAcc.mod`).""" ref_class: type[R] @@ -235,8 +236,6 @@ def __post_init__(self) -> None: del self.__dict__["X"] del self.__dict__["layers"] - del self.__dataclass_fields__["X"] - del self.__dataclass_fields__["layers"] def __getitem__(self, k: str, /) -> ModAcc[R]: return self.mod[k] @@ -273,4 +272,12 @@ def resolve(self, spec: str, *, strict: bool = True) -> R | None: return super().resolve(spec, strict=strict) +del MuAcc.__dataclass_fields__["X"] +del MuAcc.__dataclass_fields__["layers"] +del MuAcc.__dataclass_fields__["layer_cls"] + A: MuAcc[AdRef] = MuAcc() + + +if not TYPE_CHECKING: # https://github.com/tox-dev/sphinx-autodoc-typehints/issues/580 + R = AdRef[Hashable] From 7f98676ce49913027cdd06a027cae353b2a32aa9 Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Mon, 8 Jun 2026 12:01:44 +0200 Subject: [PATCH 11/16] implement to/from_json --- src/mudata/acc/__init__.py | 27 ++++++++++++++++++++++++++- tests/test_accessors.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 55 insertions(+), 1 
deletion(-) diff --git a/src/mudata/acc/__init__.py b/src/mudata/acc/__init__.py index ad4fc18..72bea30 100644 --- a/src/mudata/acc/__init__.py +++ b/src/mudata/acc/__init__.py @@ -1,6 +1,6 @@ from __future__ import annotations -from collections.abc import Hashable +from collections.abc import Hashable, Sequence from dataclasses import KW_ONLY, dataclass, field from typing import TYPE_CHECKING, Any, ClassVar, Literal @@ -193,6 +193,10 @@ def __post_init__(self) -> None: def __repr__(self) -> str: return f"A.mod[{self.mod}]" + def to_json(self, ref: R) -> list[str | int | None]: + """Serialize :class:`~anndata.acc.AdRef` to a JSON-compatible list.""" + return ["mod", self.mod, super().to_json(ref)] + @dataclass(frozen=True, kw_only=True) class MultiModAcc[R: AdRef](MapAcc[ModAcc]): @@ -271,6 +275,27 @@ def resolve(self, spec: str, *, strict: bool = True) -> R | None: case _: return super().resolve(spec, strict=strict) + def to_json(self, ref: R) -> list[str | int | None]: + """Serialize :class:`~anndata.acc.AdRef` to a JSON-compatible list.""" + if isinstance(ref.acc, ModMapAcc): + return [f"{ref.acc.dim}map", ref.idx] + + ret = super().to_json(ref) + if isinstance(ref.acc, _ModalityMixin): + ret = ["mod", ref.acc.mod, ret] + return ret + + def from_json(self, data: Sequence[str | int | None]) -> R: + """Create a :class:`~anndata.acc.AdRef` from a JSON sequence.""" + match data: + case ["mod", str() as modname, list() as inner]: + return self.mod[modname].from_json(inner) + case ["obsmap" | "varmap" as dim, str() as modname]: + acc = self.obsmap if dim == "obsmap" else self.varmap + return acc[modname] + case _: + return super().from_json(data) + del MuAcc.__dataclass_fields__["X"] del MuAcc.__dataclass_fields__["layers"] diff --git a/tests/test_accessors.py b/tests/test_accessors.py index d6e2405..2b8df83 100644 --- a/tests/test_accessors.py +++ b/tests/test_accessors.py @@ -114,3 +114,32 @@ def test_resolve(): with pytest.raises(ValueError, 
match="period-separated"): A.resolve("abcd") + + +@pytest.mark.parametrize("acc", [path[0] for path in PATHS if isinstance(path[0], ad.acc.AdRef)]) +def test_to_from_json(acc): + serialized = A.to_json(acc) + if isinstance(acc.acc, md.acc.ModMapAcc): + assert serialized[0] == f"{acc.acc.dim}map" + assert serialized[1] == acc.idx + else: + assert serialized[0] == "mod" + assert serialized[1] == acc.acc.mod + assert serialized[2] == ad.acc.A.to_json(acc) + + assert A.from_json(serialized) == acc + + +@pytest.mark.parametrize( + "acc", + [path[0] for path in PATHS if isinstance(path[0], ad.acc.AdRef) and not isinstance(path[0].acc, md.acc.ModMapAcc)], +) +def test_to_from_json_mod(acc): + modA = A.mod["foobar"] + + serialized = modA.to_json(acc) + assert serialized[0] == "mod" + assert serialized[1] == "foobar" + assert serialized[2] == ad.acc.A.to_json(acc) + + assert A.from_json(serialized).acc.mod == "foobar" From d77f9d57a9695a533b9e5cc5feaa7d37b84b1a15 Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Mon, 8 Jun 2026 14:38:58 +0200 Subject: [PATCH 12/16] add changelog entry and require anndata 0.13 for docs --- CHANGELOG.md | 5 +++++ docs/conf.py | 2 +- pyproject.toml | 1 + 3 files changed, 7 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c98a891..84783f6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,11 @@ and this project adheres to [Semantic Versioning][]. ## [0.4.0] (Unreleased) +### Added + +- MuData accessors. These are similar to and build on [AnnData accessors](https://anndata.scverse.org/page/accessors.html), but add an additional + level for modalities. + ### Changed - `update()` no longer automatically pulls obs/var columns from individual modalities by default. 
Set `mudata.set_options(pull_on_update=true)` diff --git a/docs/conf.py b/docs/conf.py index 03d5fb5..2ee7a28 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -96,7 +96,7 @@ intersphinx_mapping = { "python": ("https://docs.python.org/3", None), - "anndata": ("https://anndata.readthedocs.io/en/latest/", None), + "anndata": ("https://anndata.readthedocs.io/en/stable/", None), "scanpy": ("https://scanpy.readthedocs.io/en/stable/", None), "numpy": ("https://numpy.org/doc/stable/", None), "pandas": ("https://pandas.pydata.org/docs/", None), diff --git a/pyproject.toml b/pyproject.toml index 277d746..b8e8082 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,7 @@ dev = [ ] test = [ "coverage>=7.10", "mudata[io]", "packaging", "pytest" ] doc = [ + "anndata>=0.13rc1", "docutils>=0.8,!=0.18.*,!=0.19.*", "ipykernel", "ipython", From 8045267d3d60394e1556b8fbba94f2a753648664 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 8 Jun 2026 12:46:30 +0000 Subject: [PATCH 13/16] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/mudata/_core/mudata.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/mudata/_core/mudata.py b/src/mudata/_core/mudata.py index 1d2a4cf..b1c2426 100644 --- a/src/mudata/_core/mudata.py +++ b/src/mudata/_core/mudata.py @@ -555,6 +555,7 @@ def __contains__(self, key) -> bool: return key in self._mod with suppress(ImportError): from anndata.acc import AdRef, MapAcc, RefAcc + from ..acc import ModAcc, MultiModAcc, _ModalityMapAcc, _ModalityMixin if isinstance(key, ModAcc | _ModalityMapAcc): From daee8fe89e4b3db97ca3d532e09f48e63ff78cb5 Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Mon, 8 Jun 2026 14:49:37 +0200 Subject: [PATCH 14/16] docs: don't exclude X and layers from ModAcc --- docs/_templates/autosummary/class-accessor.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git 
a/docs/_templates/autosummary/class-accessor.rst b/docs/_templates/autosummary/class-accessor.rst index cf3deb0..64c380a 100644 --- a/docs/_templates/autosummary/class-accessor.rst +++ b/docs/_templates/autosummary/class-accessor.rst @@ -1,6 +1,8 @@ {{ fullname | escape | underline}} +{%if fullname == "mudata.acc.MuAcc" %} {% set attributes = attributes | select("ne", "X") | select("ne", "layers") %} +{% endif %} .. currentmodule:: {{ module }} From 70a8ed861105318cb8ff3752c920831a9cf5898f Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Mon, 8 Jun 2026 14:54:44 +0200 Subject: [PATCH 15/16] fix test --- tests/test_accessors.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_accessors.py b/tests/test_accessors.py index 2b8df83..0cdb854 100644 --- a/tests/test_accessors.py +++ b/tests/test_accessors.py @@ -7,11 +7,12 @@ from packaging.version import Version import mudata as md -from mudata.acc import A if Version(ad.__version__) < Version("0.13dev0"): pytest.skip("anndata version too old, no accessor support", allow_module_level=True) +from mudata.acc import A + @pytest.fixture def mdata_augmented(mdata: md.MuData, rng: np.random.Generator): From 5ccd5787227cbc30e72e7cfe91fb641f30e20c7c Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Mon, 8 Jun 2026 16:47:02 +0200 Subject: [PATCH 16/16] export mudata.acc when available --- src/mudata/__init__.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/mudata/__init__.py b/src/mudata/__init__.py index e8be205..9ff6687 100644 --- a/src/mudata/__init__.py +++ b/src/mudata/__init__.py @@ -25,9 +25,6 @@ from ._core.to_ import to_anndata, to_mudata from ._version import __version__, __version_tuple__ -with suppress(ImportError): - from . import acc - # file format versions __anndataversion__ = "0.1.0" __mudataversion__ = "0.1.0" @@ -56,3 +53,8 @@ "register_mudata_namespace", "ExtensionNamespace", ] + +with suppress(ImportError): + from . import acc + + __all__.append("acc")