diff --git a/CHANGELOG.md b/CHANGELOG.md index c98a891..84783f6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,11 @@ and this project adheres to [Semantic Versioning][]. ## [0.4.0] (Unreleased) +### Added + +- MuData accessors. These are similar to and build on [AnnData accessors](https://anndata.scverse.org/page/accessors.html), but add an additional + level for modalities. + ### Changed - `update()` no longer automatically pulls obs/var columns from individual modalities by default. Set `mudata.set_options(pull_on_update=true)` diff --git a/docs/_templates/autosummary/class-accessor.rst b/docs/_templates/autosummary/class-accessor.rst new file mode 100644 index 0000000..64c380a --- /dev/null +++ b/docs/_templates/autosummary/class-accessor.rst @@ -0,0 +1,28 @@ +{{ fullname | escape | underline}} + +{%if fullname == "mudata.acc.MuAcc" %} +{% set attributes = attributes | select("ne", "X") | select("ne", "layers") %} +{% endif %} + +.. currentmodule:: {{ module }} + +.. autoclass:: {{ objname }} + :show-inheritance: + + {% block attributes %} + {%- for item in attributes %} + {%- if loop.first %} + .. rubric:: Attributes + {% endif %} + .. autoattribute:: {{ item }} + {%- endfor %} + {% endblock %} + + {% block methods %} + {%- for item in methods if item != "__init__" and item not in inherited_members %} + {%- if loop.first %} + .. rubric:: Methods + {% endif %} + .. automethod:: {{ item }} + {%- endfor %} + {% endblock %} diff --git a/docs/accessors.md b/docs/accessors.md new file mode 100644 index 0000000..45ec21a --- /dev/null +++ b/docs/accessors.md @@ -0,0 +1,35 @@ +# Accessors and paths + +```{eval-rst} +.. module:: mudata.acc +``` + +[](#mudata.acc) provides [accessors](inv:anndata:*:term#accessor) that create [references](inv:anndata:*:term#reference) to axis-aligned 1D and 2D arrays in [MuData](#mudata.MuData) objects. +See the corresponding [AnnData documentation](inv:anndata:*:doc#accessors). 
+ +:::{important} +This functionality requires AnnData 0.13 or newer. +::: + +The central [accessor](inv:anndata:*:term#accessor) is [](#A). +```{eval-rst} +.. autodata:: A +``` +See [](#MuAcc) and [AdAcc](#anndata.acc.AdAcc) for examples of how to use it to create [references](inv:anndata:*:term#reference) (i.e. [AdRefs](#anndata.acc.AdRef)). + +```{eval-rst} +.. autosummary:: + :toctree: generated + :template: class-accessor + + MuAcc + MultiModAcc + ModAcc + ModMapAcc + ModMetaAcc + ModLayerAcc + ModGraphAcc + ModMultiAcc + ModMultiMapAcc + ModGraphMapAcc +``` diff --git a/docs/api.md b/docs/api.md index 22a9483..688701a 100644 --- a/docs/api.md +++ b/docs/api.md @@ -35,6 +35,14 @@ write_zarr ``` +## Accessors +```{eval-rst} +.. toctree:: + :hidden: + + mudata.acc +``` + ## Extensions ```{eval-rst} .. autosummary:: diff --git a/docs/conf.py b/docs/conf.py index 96d0e75..2ee7a28 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -133,8 +133,11 @@ pygments_style = "default" katex_prerender = shutil.which(katex.NODEJS_BINARY) is not None -nitpick_ignore = [ +nitpick_ignore = ( # If building the documentation fails because of a missing link that is outside your control, # you can add an exception to this list. 
# ("py:class", "igraph.Graph"), -] + ("py:obj", "typing.R"), + ("py:obj", "anndata.acc.Axes"), + ("py:class", "AdRef"), +) diff --git a/docs/extensions/skip_private_bases.py b/docs/extensions/skip_private_bases.py new file mode 100644 index 0000000..eb4a07c --- /dev/null +++ b/docs/extensions/skip_private_bases.py @@ -0,0 +1,17 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Generic, get_origin + +from sphinx.util.typing import ExtensionMetadata + +if TYPE_CHECKING: + from sphinx.application import Sphinx + + +def skip_private_bases(app: Sphinx, name: str, obj: type, _unused, bases: list[type]) -> None: + bases[:] = [b for b in bases if b is not object if get_origin(b) is not Generic if not b.__name__.startswith("_")] + + +def setup(app: Sphinx) -> ExtensionMetadata: + app.connect("autodoc-process-bases", skip_private_bases) + return ExtensionMetadata(parallel_read_safe=True) diff --git a/pyproject.toml b/pyproject.toml index 364f3d9..b8e8082 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,8 +48,9 @@ dev = [ "pre-commit", "twine>=4.0.2", ] -test = [ "coverage>=7.10", "mudata[io]", "pytest" ] +test = [ "coverage>=7.10", "mudata[io]", "packaging", "pytest" ] doc = [ + "anndata>=0.13rc1", "docutils>=0.8,!=0.18.*,!=0.19.*", "ipykernel", "ipython", @@ -123,7 +124,7 @@ lint.ignore = [ "TID252", # allow relative imports ] lint.per-file-ignores."*/__init__.py" = [ "F401" ] -lint.per-file-ignores."docs/*" = [ "I" ] +lint.per-file-ignores."docs/*" = [ "D", "I" ] lint.per-file-ignores."docs/notebooks/*" = [ "D", "F403", "F405" ] lint.per-file-ignores."tests/*" = [ "D" ] lint.pydocstyle.convention = "numpy" diff --git a/src/mudata/__init__.py b/src/mudata/__init__.py index 6f4b046..9ff6687 100644 --- a/src/mudata/__init__.py +++ b/src/mudata/__init__.py @@ -1,5 +1,7 @@ """Multimodal datasets""" +from contextlib import suppress + from anndata import AnnData from scverse_misc import ExtensionNamespace from scverse_misc import 
make_register_namespace_decorator as _make_register_namespace_decorator @@ -51,3 +53,8 @@ "register_mudata_namespace", "ExtensionNamespace", ] + +with suppress(ImportError): + from . import acc + + __all__.append("acc") diff --git a/src/mudata/_core/mudata.py b/src/mudata/_core/mudata.py index d5fd1f6..b1c2426 100644 --- a/src/mudata/_core/mudata.py +++ b/src/mudata/_core/mudata.py @@ -533,8 +533,41 @@ def strings_to_categoricals(self, df: pd.DataFrame | None = None) -> pd.DataFram def __getitem__(self, index) -> AnnData | MuData: if isinstance(index, str): return self._mod[index] - else: - return MuData(self, as_view=True, index=index) + + with suppress(ImportError): + from anndata.acc import AdRef + + if isinstance(index, AdRef): + try: + return index.acc.get(self, index.idx) + except KeyError as e: + if getattr(index.acc, "dim", None) in ("obs", "var"): + for modname, mod in self._mod.items(): + if index in mod: + raise KeyError( + f"There is no key {index.idx} in MuData .{index.acc.dim} but there is one in {modname} .{index.acc.dim}. Consider running `pull_{index.acc.dim}()` to update global .{index.acc.dim}." 
+ ) from e + raise + return MuData(self, as_view=True, index=index) + + def __contains__(self, key) -> bool: + if isinstance(key, str): + return key in self._mod + with suppress(ImportError): + from anndata.acc import AdRef, MapAcc, RefAcc + + from ..acc import ModAcc, MultiModAcc, _ModalityMapAcc, _ModalityMixin + + if isinstance(key, ModAcc | _ModalityMapAcc): + return key.isin(self) + elif isinstance(key, _ModalityMixin): + return key.mod in self.mod and key in self.mod[key.mod] + elif isinstance(key, MultiModAcc): + return bool(self.mod) + elif isinstance(key, AdRef | RefAcc | MapAcc): + return AnnData.__contains__(self, key) + + raise TypeError(f"Unexpected key {key!r}.") @property def mod(self) -> Mapping[str, AnnData | MuData]: diff --git a/src/mudata/acc/__init__.py b/src/mudata/acc/__init__.py new file mode 100644 index 0000000..72bea30 --- /dev/null +++ b/src/mudata/acc/__init__.py @@ -0,0 +1,308 @@ +from __future__ import annotations + +from collections.abc import Hashable, Sequence +from dataclasses import KW_ONLY, dataclass, field +from typing import TYPE_CHECKING, Any, ClassVar, Literal + +import pandas as pd +from anndata.acc import ( + AdAcc, + AdRef, + Axes, + GraphAcc, + GraphMapAcc, + Idx2D, + LayerAcc, + LayerMapAcc, + MapAcc, + MetaAcc, + MultiAcc, + MultiMapAcc, + RefAcc, +) +from anndata.compat import XVariable +from anndata.typing import InMemoryArray + +if TYPE_CHECKING: + from anndata import AnnData + + from .. 
import MuData + + +@dataclass(frozen=True, kw_only=True) +class _ModalityMixin: + mod: str + """Modality this accessor refers to.""" + + +@dataclass(frozen=True) +class _ModalityMapAcc[I, R](_ModalityMixin): + def isin(self, mdata: MuData, idx: I | None = None) -> bool: + if self.mod not in mdata.mod: + return False + else: + return super().isin(mdata[self.mod], idx) + + def get(self, mdata: MuData, idx: I, /) -> R: + return super().get(mdata[self.mod], idx) + + +@dataclass(frozen=True) +class ModLayerAcc[R: AdRef[Idx2D]](_ModalityMapAcc[Idx2D, InMemoryArray], LayerAcc[R]): + """Reference accessor for arrays in :attr:`~anndata.acc.AdAcc.layers`.""" + + def __repr__(self) -> str: + return f"A.mod[{self.mod!r}].X" if self.k is None else f"A.mod[{self.mod!r}].layers[{self.k!r}]" + + +@dataclass(frozen=True) +class ModLayerMapAcc[R: AdRef](_ModalityMixin, LayerMapAcc[R]): + """Accessor for arrays in :attr:`~anndata.acc.AdAcc.layers`.""" + + ref_acc_cls: type[ModLayerAcc] = ModLayerAcc + + def __getitem__(self, k: str | None, /) -> ModLayerAcc[R]: + if not isinstance(k, str | None): + raise TypeError(f"Unsupported layer {k!r}") + return self.ref_acc_cls(mod=self.mod, k=k, ref_class=self.ref_class) + + def __repr__(self) -> str: + return f"A.mod[{self.mod!r}].layers" + + +@dataclass(frozen=True) +class ModMetaAcc[R: AdRef[str]](_ModalityMapAcc[str, pd.api.extensions.ExtensionArray | XVariable], MetaAcc[R]): + """Reference accessor for arrays from metadata containers (:attr:`~anndata.acc.AdAcc.obs` / :attr:`~anndata.acc.AdAcc.var`).""" + + def __repr__(self) -> str: + return f"A.mod[{self.mod!r}].{self.dim}" + + +@dataclass(frozen=True) +class ModMultiAcc[R: AdRef[int]](_ModalityMapAcc[int, InMemoryArray], MultiAcc[R]): + """Reference accessor for arrays from multi-dimensional containers (:attr:`~anndata.acc.AdAcc.obsm` / :attr:`~anndata.acc.AdAcc.varm`).""" + + def __repr__(self) -> str: + return f"A.mod[{self.mod!r}].{self.dim}m[{self.k!r}]" + + +@dataclass(frozen=True) 
+class ModMultiMapAcc[R: AdRef](_ModalityMixin, MultiMapAcc[R]): + """Accessor for multi-dimensional array containers (:attr:`~anndata.acc.AdAcc.obsm` / :attr:`~anndata.acc.AdAcc.varm`).""" + + ref_acc_cls: type[ModMultiAcc] = ModMultiAcc + + def __getitem__(self, k: str, /) -> ModMultiAcc[R]: + if not isinstance(k, str): + raise TypeError(f"Unsupported {self.dim}m key {k!r}") + return self.ref_acc_cls(mod=self.mod, k=k, dim=self.dim, ref_class=self.ref_class) + + def __repr__(self) -> str: + return f"A.mod[{self.mod!r}].{self.dim}m" + + +@dataclass(frozen=True) +class ModGraphAcc[R: AdRef[Idx2D]](_ModalityMapAcc[Idx2D, InMemoryArray], GraphAcc[R]): + """Reference accessor for arrays from graph containers (:attr:`~anndata.acc.AdAcc.obsp` / :attr:`~anndata.acc.AdAcc.varp`).""" + + def __repr__(self) -> str: + return f"A.mod[{self.mod!r}].{self.dim}p[{self.k!r}]" + + +@dataclass(frozen=True) +class ModGraphMapAcc[R: AdRef](_ModalityMixin, GraphMapAcc[R]): + """Accessor for graph containers (:attr:`~anndata.acc.AdAcc.obsp` / :attr:`~anndata.acc.AdAcc.varp`)""" + + ref_acc_cls: type[ModGraphAcc] = ModGraphAcc + + def __getitem__(self, k: str, /) -> ModGraphAcc[R]: + if not isinstance(k, str): + raise TypeError(f"Unsupported {self.dim}p key {k!r}") + return self.ref_acc_cls(mod=self.mod, k=k, dim=self.dim, ref_class=self.ref_class) + + def __repr__(self) -> str: + return f"A.mod[{self.mod!r}].{self.dim}p" + + +@dataclass(frozen=True) +class ModMapAcc[R: AdRef[str]](RefAcc[R, str]): + """Reference accessor for modality maps (:attr:`~MuAcc.obsmap` / :attr:`~MuAcc.varmap`).""" + + dim: Literal["obs", "var"] + """Axis this accessor refers to, e.g. 
`A.varmap.dim == "var"`.""" + + def dims(self, idx: Any, /) -> Axes: + """Get which dimension this array refers to.""" + return (self.dim,) + + def __repr__(self) -> str: + return f"A.{self.dim}map" + + def idx_repr(self, idx: str, /) -> str: + """Get a string representation of the index.""" + return f"[{idx!r}]" + + def isin(self, mdata: MuData, idx: str | None = None) -> bool: + """Check if the referenced array is in the :class:`~mudata.MuData` object.""" + m = getattr(mdata, f"{self.dim}map") + return idx is None or idx in m + + def get(self, mdata: MuData, idx: str, /) -> InMemoryArray: + """Get the referenced array from the :class:`~mudata.MuData` object.""" + m = getattr(mdata, f"{self.dim}map") + return m[idx] + + +@dataclass(frozen=True, kw_only=True) +class ModAcc[R: AdRef](_ModalityMixin, AdAcc[R]): + """Accessor to create :class:`AdRefs <anndata.acc.AdRef>` (:data:`A`) for modalities (:attr:`~MuAcc.mod`).""" + + layer_cls: type[ModLayerAcc] = ModLayerAcc + """Class to use for `layers` accessors.""" + + meta_cls: type[ModMetaAcc] = ModMetaAcc + """Class to use for `obs`/`var` accessors.""" + + multi_cls: type[ModMultiAcc] = ModMultiAcc + """Class to use for `obsm`/`varm` accessors.""" + + graph_cls: type[ModGraphAcc] = ModGraphAcc + """Class to use for `obsp`/`varp` accessors.""" + + def isin(self, mdata: MuData) -> bool: + """Check if the referenced modality is in the :class:`~mudata.MuData` object.""" + return self.mod in mdata.mod + + def get(self, mdata: MuData) -> AnnData: + """Get the referenced modality from the :class:`~mudata.MuData` object.""" + return mdata.mod[self.mod] + + def __post_init__(self) -> None: + x = self.layer_cls(mod=self.mod, k=None, ref_class=self.ref_class) + layers = ModLayerMapAcc(mod=self.mod, ref_class=self.ref_class, ref_acc_cls=self.layer_cls) + object.__setattr__(self, "X", x) + object.__setattr__(self, "layers", layers) + for dim in ("obs", "var"): + meta = self.meta_cls(mod=self.mod, dim=dim, ref_class=self.ref_class) + multi = 
ModMultiMapAcc(mod=self.mod, dim=dim, ref_class=self.ref_class, ref_acc_cls=self.multi_cls) + graphs = ModGraphMapAcc(mod=self.mod, dim=dim, ref_class=self.ref_class, ref_acc_cls=self.graph_cls) + object.__setattr__(self, dim, meta) + object.__setattr__(self, f"{dim}m", multi) + object.__setattr__(self, f"{dim}p", graphs) + + def __repr__(self) -> str: + return f"A.mod[{self.mod!r}]" + + def to_json(self, ref: R) -> list[str | int | None]: + """Serialize :class:`~anndata.acc.AdRef` to a JSON-compatible list.""" + return ["mod", self.mod, super().to_json(ref)] + + +@dataclass(frozen=True, kw_only=True) +class MultiModAcc[R: AdRef](MapAcc[ModAcc]): + """Accessor for modalities (:attr:`~MuAcc.mod`).""" + + ref_class: type[R] + ref_acc_cls: type[ModAcc] = ModAcc + + def __getitem__(self, k: str, /) -> ModAcc[R]: + if not isinstance(k, str): + raise TypeError(f"Unsupported mod key {k!r}") + return self.ref_acc_cls(mod=k, ref_class=self.ref_class) + + def __repr__(self) -> str: + return "A.mod" + + +@dataclass(frozen=True) +class MuAcc[R: AdRef](AdAcc[R]): + """Accessor to create :class:`AdRefs <anndata.acc.AdRef>` (:data:`A`).""" + + mod_cls: type[ModAcc] = ModAcc + """Class to use for `mod` accessors.""" + + mod: MultiModAcc[R] = field(init=False) + """Access modalities.""" + + obsmap: ModMapAcc[R] = field(init=False) + """Access mappings of observation indices in the MuData to indices in individual modalities.""" + + varmap: ModMapAcc[R] = field(init=False) + """Access mappings of variable indices in the MuData to indices in individual modalities.""" + + ATTRS: ClassVar = frozenset(("mod", "obs", "var", "obsm", "varm", "obsp", "varp", "obsmap", "varmap")) + + def __post_init__(self) -> None: + super().__post_init__() + object.__setattr__(self, "mod", MultiModAcc(ref_class=self.ref_class, ref_acc_cls=self.mod_cls)) + object.__setattr__(self, "obsmap", ModMapAcc("obs", ref_class=self.ref_class)) + object.__setattr__(self, "varmap", ModMapAcc("var", ref_class=self.ref_class)) + + del 
self.__dict__["X"] + del self.__dict__["layers"] + + def __getitem__(self, k: str, /) -> ModAcc[R]: + return self.mod[k] + + def __repr__(self) -> str: + return "A" + + def resolve(self, spec: str, *, strict: bool = True) -> R | None: + """Create :class:`~anndata.acc.AdRef` from a simplified string.""" + if not strict: + try: + return self.resolve(spec) + except ValueError: + return None + + firstdot = spec.find(".") + if firstdot < 0: + raise ValueError(f"Cannot parse accessor {spec!r} that is not period-separated.") + firstattr = spec[:firstdot] + match firstattr: + case "mod": + modend = spec.find(".", firstdot + 1) + mod = spec[firstdot + 1 : modend] if modend >= 0 else "" + if not mod: + raise ValueError(f"Cannot parse accessor {spec!r} that has an empty modality or no attribute.") + acc = self.mod[mod] + return super().resolve.__func__(acc, spec[modend + 1 :], strict=strict) + case "obsmap" | "varmap": + if firstdot == len(spec) - 1: + raise ValueError(f"Cannot parse accessor {spec!r} that has an empty modality.") + mod = spec[firstdot + 1 :] + return getattr(self, firstattr)[mod] + case _: + return super().resolve(spec, strict=strict) + + def to_json(self, ref: R) -> list[str | int | None]: + """Serialize :class:`~anndata.acc.AdRef` to a JSON-compatible list.""" + if isinstance(ref.acc, ModMapAcc): + return [f"{ref.acc.dim}map", ref.idx] + + ret = super().to_json(ref) + if isinstance(ref.acc, _ModalityMixin): + ret = ["mod", ref.acc.mod, ret] + return ret + + def from_json(self, data: Sequence[str | int | None]) -> R: + """Create a :class:`~anndata.acc.AdRef` from a JSON sequence.""" + match data: + case ["mod", str() as modname, list() as inner]: + return self.mod[modname].from_json(inner) + case ["obsmap" | "varmap" as dim, str() as modname]: + acc = self.obsmap if dim == "obsmap" else self.varmap + return acc[modname] + case _: + return super().from_json(data) + + +del MuAcc.__dataclass_fields__["X"] +del MuAcc.__dataclass_fields__["layers"] +del MuAcc.__dataclass_fields__["layer_cls"] + +A: MuAcc[AdRef] = 
MuAcc() + + +if not TYPE_CHECKING: # https://github.com/tox-dev/sphinx-autodoc-typehints/issues/580 + R = AdRef[Hashable] diff --git a/tests/conftest.py b/tests/conftest.py index d637007..9a4bc66 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -37,10 +37,12 @@ def rng() -> np.random.Generator: def mdata(rng: np.random.Generator, request: pytest.FixtureRequest) -> MuData: axis = getattr(request, "param", 0) mod1 = AnnData( - np.arange(0, 200, 0.1).reshape(-1, 20), obs=pd.DataFrame(index=rng.choice(150, size=100, replace=False)) + np.arange(0, 200, 0.1).reshape(-1, 20), + obs=pd.DataFrame(index=rng.choice(150, size=100, replace=False).astype(str)), ) mod2 = AnnData( - np.arange(101, 3101, 1).reshape(-1, 30), obs=pd.DataFrame(index=rng.choice(150, size=100, replace=False)) + np.arange(101, 3101, 1).reshape(-1, 30), + obs=pd.DataFrame(index=rng.choice(150, size=100, replace=False).astype(str)), ) mod1.var["assert-bool"] = True mod2.var["assert-bool"] = False diff --git a/tests/test_accessors.py b/tests/test_accessors.py new file mode 100644 index 0000000..0cdb854 --- /dev/null +++ b/tests/test_accessors.py @@ -0,0 +1,146 @@ +from dataclasses import fields + +import anndata as ad +import numpy as np +import pandas as pd +import pytest +from packaging.version import Version + +import mudata as md + +if Version(ad.__version__) < Version("0.13dev0"): + pytest.skip("anndata version too old, no accessor support", allow_module_level=True) + +from mudata.acc import A + + +@pytest.fixture +def mdata_augmented(mdata: md.MuData, rng: np.random.Generator): + mdata["mod1"].layers["counts"] = rng.poisson(1, size=mdata["mod1"].shape) + mdata["mod2"].obsp["test"] = rng.normal(size=(mdata["mod2"].n_obs, mdata["mod2"].n_obs)) + + return mdata + + +def test_anndata_accessors(mdata: md.MuData): + assert ad.acc.A.obs["arange"] in mdata + assert (mdata[ad.acc.A.obs["arange"]] == mdata.obs["arange"]).all() + with pytest.raises(KeyError, match="test"): + mdata[ad.acc.A.var["test"]] 
+ with pytest.raises(KeyError, match="there is one in"): + mdata[ad.acc.A.var["assert-bool"]] + + +PATHS = [ + (A.mod["mod1"], lambda md: md.mod["mod1"]), + (A["mod1"], lambda md: md["mod1"]), + (A.mod["mod1"].var, lambda md: md.mod["mod1"].var), + (A.mod["mod1"].var["assert-bool"], lambda md: md.mod["mod1"].var["assert-bool"]), + (A.mod["mod1"].X, lambda md: md.mod["mod1"].X), + (A.mod["mod1"].X["obs_2", :], lambda md: md.mod["mod1"]["obs_2", :].X.squeeze()), + (A.mod["mod1"].X[:, "mod1_var_1"], lambda md: md.mod["mod1"][:, "mod1_var_1"].X.squeeze()), + (A["mod1"].layers, lambda md: md["mod1"].layers), + (A["mod1"].layers["counts"], lambda md: md["mod1"].layers["counts"]), + (A["mod1"].layers["counts"]["obs_2", :], lambda md: md["mod1"]["obs_2", :].layers["counts"].squeeze()), + (A["mod2"].obsp, lambda md: md["mod2"].obsp), + (A["mod2"].obsp["test"], lambda md: md["mod2"].obsp["test"]), + (A["mod2"].obsp["test"][:, "obs_3"], lambda md: md["mod2"].obsp["test"][:, md["mod2"].obs_names.get_loc("obs_3")]), + (A.obsmap, lambda md: md.obsmap), + (A.varmap, lambda md: md.varmap), + (A.obsmap["mod1"], lambda md: md.obsmap["mod1"]), + (A.varmap["mod2"], lambda md: md.varmap["mod2"]), +] + + +@pytest.mark.parametrize("acc", [path[0] for path in PATHS]) +def test_in(mdata_augmented: md.MuData, acc): + assert acc in mdata_augmented + + +@pytest.mark.parametrize( + "acc", + [ + A.mod["mod3"], + A["mod3"], + A.mod["mod3"].var, + A.mod["mod1"].var["does_not_exist"], + A.mod["mod3"].X, + A.mod["mod3"].X["obs_2", :], + A.mod["mod3"].X[:, "mod1_var_1"], + A["mod2"].layers, + A["mod1"].layers["does_not_exist"], + A["mod1"].layers["does_not_exist"]["obs_2", :], + A["mod1"].obsp, + A["mod2"].obsp["does_not_exist"], + A["mod2"].obsp["does_not_exist"][:, "obs_3"], + A.obsmap["mod3"], + A.varmap["mod3"], + ], +) +def test_not_in(mdata: md.MuData, acc): + assert acc not in mdata + + +@pytest.mark.parametrize("acc_expected", [path for path in PATHS if isinstance(path[0], ad.acc.AdRef)]) 
+def test_get(mdata_augmented: md.MuData, acc_expected): + acc, expected = acc_expected + + val = mdata_augmented[acc] + expected = expected(mdata_augmented) + if isinstance(expected, pd.DataFrame | pd.Series | np.ndarray): + assert np.all(val == expected) + else: + assert val == expected + + +def test_no_data(): + with pytest.raises(AttributeError): + A.X # noqa: B018 + with pytest.raises(AttributeError): + A.layers # noqa: B018 + + for field in fields(A): + assert field.name not in ("X", "layers") + + +def test_resolve(): + assert A.resolve("mod.rna.X[:, ACT1]") == A.mod["rna"].X[:, "ACT1"] + assert A.resolve("obsmap.rna") == A.obsmap["rna"] + + with pytest.raises(ValueError, match="Unknown accessor"): + A.resolve("rna.X[:, :]") + + with pytest.raises(ValueError, match="empty modality"): + A.resolve("mod..X[:, :]") + + with pytest.raises(ValueError, match="period-separated"): + A.resolve("abcd") + + +@pytest.mark.parametrize("acc", [path[0] for path in PATHS if isinstance(path[0], ad.acc.AdRef)]) +def test_to_from_json(acc): + serialized = A.to_json(acc) + if isinstance(acc.acc, md.acc.ModMapAcc): + assert serialized[0] == f"{acc.acc.dim}map" + assert serialized[1] == acc.idx + else: + assert serialized[0] == "mod" + assert serialized[1] == acc.acc.mod + assert serialized[2] == ad.acc.A.to_json(acc) + + assert A.from_json(serialized) == acc + + +@pytest.mark.parametrize( + "acc", + [path[0] for path in PATHS if isinstance(path[0], ad.acc.AdRef) and not isinstance(path[0].acc, md.acc.ModMapAcc)], +) +def test_to_from_json_mod(acc): + modA = A.mod["foobar"] + + serialized = modA.to_json(acc) + assert serialized[0] == "mod" + assert serialized[1] == "foobar" + assert serialized[2] == ad.acc.A.to_json(acc) + + assert A.from_json(serialized).acc.mod == "foobar" diff --git a/tests/test_update.py b/tests/test_update.py index 0a32eb0..318ff49 100644 --- a/tests/test_update.py +++ b/tests/test_update.py @@ -146,6 +146,9 @@ def test_update_simple(mdata: MuData, axis: 
Axis): for mod in mdata.mod.keys(): assert mdata.obsmap[mod].dtype.kind == "u" assert mdata.varmap[mod].dtype.kind == "u" + assert mod in mdata + with pytest.raises(TypeError): + 1 in mdata # noqa: B015 # names along non-axis are concatenated assert mdata.shape[1 - axis] == sum(mod.shape[1 - axis] for mod in mdata.mod.values())