Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions src/spatialdata_plot/pl/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1054,9 +1054,10 @@ def show(

# Check if user specified only certain elements to be plotted
cs_contents = _get_cs_contents(sdata)
cs_index = cs_contents.set_index("cs")
pending_colorbars: list[tuple[Axes, list[ColorbarSpec]]] = []

elements_to_be_rendered = _get_elements_to_be_rendered(render_cmds, cs_contents, cs)
elements_to_be_rendered = _get_elements_to_be_rendered(render_cmds, cs_index, cs)

# filter out cs without relevant elements
cmds = [cmd for cmd, _ in render_cmds]
Expand All @@ -1079,7 +1080,7 @@ def show(
strict_cs = [
cs_name
for cs_name in coordinate_systems
if all(cs_contents.query(f"cs == '{cs_name}'").iloc[0][flag] for flag in required_flags)
if cs_name in cs_index.index and all(cs_index.loc[cs_name][flag] for flag in required_flags)
]
if strict_cs:
coordinate_systems = strict_cs
Expand Down Expand Up @@ -1197,15 +1198,15 @@ def _draw_colorbar(
elif location == "top":
trackers_axes["top"] = pad_axes + bbox_axes.height

cs_contents = _get_cs_contents(sdata)

# go through tree

for i, cs in enumerate(coordinate_systems):
sdata = self._copy()
_, has_images, has_labels, has_points, has_shapes = (
cs_contents.query(f"cs == '{cs}'").iloc[0, :].values.tolist()
)
cs_row = cs_index.loc[cs]
has_images = cs_row["has_images"]
has_labels = cs_row["has_labels"]
has_points = cs_row["has_points"]
has_shapes = cs_row["has_shapes"]
ax = fig_params.ax if fig_params.axs is None else fig_params.axs[i]
assert isinstance(ax, Axes)
axis_colorbar_requests: list[ColorbarSpec] | None = [] if legend_params.colorbar else None
Expand Down
50 changes: 18 additions & 32 deletions src/spatialdata_plot/pl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,35 +334,21 @@ def _get_cs_contents(sdata: sd.SpatialData) -> pd.DataFrame:
"""Check which coordinate systems contain which elements and return that info."""
cs_mapping = _get_coordinate_system_mapping(sdata)
content_flags = ["has_images", "has_labels", "has_points", "has_shapes"]
cs_contents = pd.DataFrame(columns=["cs"] + content_flags)

rows = []
for cs_name, element_ids in cs_mapping.items():
# determine if coordinate system has the respective elements
cs_has_images = any(e in sdata.images for e in element_ids)
cs_has_labels = any(e in sdata.labels for e in element_ids)
cs_has_points = any(e in sdata.points for e in element_ids)
cs_has_shapes = any(e in sdata.shapes for e in element_ids)

cs_contents = pd.concat(
[
cs_contents,
pd.DataFrame(
{
"cs": cs_name,
"has_images": [cs_has_images],
"has_labels": [cs_has_labels],
"has_points": [cs_has_points],
"has_shapes": [cs_has_shapes],
}
),
]
rows.append(
{
"cs": cs_name,
"has_images": any(e in sdata.images for e in element_ids),
"has_labels": any(e in sdata.labels for e in element_ids),
"has_points": any(e in sdata.points for e in element_ids),
"has_shapes": any(e in sdata.shapes for e in element_ids),
}
)

cs_contents["has_images"] = cs_contents["has_images"].astype("bool")
cs_contents["has_labels"] = cs_contents["has_labels"].astype("bool")
cs_contents["has_points"] = cs_contents["has_points"].astype("bool")
cs_contents["has_shapes"] = cs_contents["has_shapes"].astype("bool")

cs_contents = pd.DataFrame(rows, columns=["cs"] + content_flags)
cs_contents[content_flags] = cs_contents[content_flags].astype("bool")
return cs_contents


Expand Down Expand Up @@ -2106,7 +2092,7 @@ def _get_elements_to_be_rendered(
ImageRenderParams | LabelsRenderParams | PointsRenderParams | ShapesRenderParams,
]
],
cs_contents: pd.DataFrame,
cs_index: pd.DataFrame,
cs: str,
) -> list[str]:
"""
Expand All @@ -2116,23 +2102,23 @@ def _get_elements_to_be_rendered(
----------
render_cmds
List of tuples containing the commands and their respective parameters.
cs_contents
The dataframe indicating for each coordinate system which SpatialElements it contains.
cs_index
The cs_contents dataframe indexed by the "cs" column.
cs
The name of the coordinate system to query cs_contents for.
The name of the coordinate system to query cs_index for.

Returns
-------
List of names of the SpatialElements to be rendered in the plot.
"""
elements_to_be_rendered: list[str] = []

cs_query = cs_contents.query(f"cs == '{cs}'")
cs_row = cs_index.loc[cs] if cs in cs_index.index else None

for cmd, params in render_cmds:
key = _RENDER_CMD_TO_CS_FLAG.get(cmd)
if key and cs_query[key][0]:
elements_to_be_rendered += [params.element]
if key and cs_row is not None and cs_row[key]:
elements_to_be_rendered.append(params.element)

return elements_to_be_rendered

Expand Down
18 changes: 18 additions & 0 deletions tests/pl/test_render.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
import matplotlib.pyplot as plt
import numpy as np
import pytest
from spatialdata import SpatialData
from spatialdata.models import Image2DModel
from spatialdata.transformations import Identity, set_transformation

import spatialdata_plot # noqa: F401


def test_render_images_can_plot_one_cyx_image(request):
Expand Down Expand Up @@ -97,3 +103,15 @@ def test_single_ax_auto_cs_unresolvable_raises(sdata_multi_cs):
with pytest.raises(ValueError, match="coordinate_systems="):
# Only render shapes (present in both CS), so strict filter can't narrow down
sdata_multi_cs.pl.render_shapes("shp").pl.show(ax=ax)


def test_cs_name_with_apostrophe_does_not_crash():
# Regression test for #602: .query(f"cs == '{cs}'") raised TokenError for cs names
# containing single quotes.
data = np.zeros((1, 10, 10), dtype=np.float64)
img = Image2DModel.parse(data, dims=("c", "y", "x"))
sdata = SpatialData(images={"img": img})
set_transformation(sdata["img"], Identity(), to_coordinate_system="patient's_cs")
_, ax = plt.subplots()
sdata.pl.render_images("img").pl.show(ax=ax, coordinate_systems="patient's_cs")
plt.close("all")
Loading