Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion src/spatialdata_plot/pl/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -1354,6 +1354,8 @@ def _render_images(
)

_ax_show_and_transform(stacked, trans_data, ax, **show_kwargs)
if render_params.channels_as_legend:
logger.warning("channels_as_legend is not supported for true RGB images and will be ignored.")
return

# 1) Image has only 1 channel
Expand Down Expand Up @@ -1386,7 +1388,13 @@ def _render_images(
is_continuous=True,
auto_condition=n_channels == 1,
)
if wants_colorbar and legend_params.colorbar and colorbar_requests is not None:
if render_params.channels_as_legend and channel_legend_entries is not None:
# Sample at 0.75 (upper quarter) for a vivid, non-extreme representative color;
# consistent with the multi-channel composite path below.
_collect_channel_legend_entries(
[channels[0]], [matplotlib.colors.to_hex(cmap(0.75))], channel_legend_entries
)
elif wants_colorbar and legend_params.colorbar and colorbar_requests is not None:
sm = plt.cm.ScalarMappable(cmap=cmap, norm=render_params.cmap_params.norm)
colorbar_requests.append(
ColorbarSpec(
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
39 changes: 38 additions & 1 deletion tests/pl/test_render_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from spatialdata.models import Image2DModel

import spatialdata_plot # noqa: F401
from spatialdata_plot._logging import logger, logger_warns
from spatialdata_plot.pl.render import _is_rgb_image
from tests.conftest import DPI, PlotTester, PlotTesterMeta, _viridis_with_under_over

Expand Down Expand Up @@ -545,13 +546,49 @@ def test_plot_channels_as_legend_legend_lower_right(self, sdata_blobs: SpatialDa
legend_loc="lower right"
)

def test_plot_channels_as_legend_single_channel(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_images(element="blobs_image", channel=0, channels_as_legend=True).pl.show()

def test_plot_channels_as_legend_sequential_single_channels(self, sdata_blobs_str: SpatialData):
(
sdata_blobs_str.pl.render_images(
element="blobs_image",
channel="c1",
palette=["cyan"],
alpha=0.5,
channels_as_legend=True,
)
.pl.render_images(
element="blobs_image",
channel="c2",
palette=["magenta"],
alpha=0.5,
channels_as_legend=True,
)
.pl.show()
)


class TestChannelsAsCategoriesNonVisual:
"""Non-visual tests for channels_as_legend edge cases."""

def test_channels_as_legend_ignored_for_single_channel(self, sdata_blobs: SpatialData):
def test_channels_as_legend_single_channel_shows_legend_no_colorbar(self, sdata_blobs: SpatialData):
fig, ax = plt.subplots()
sdata_blobs.pl.render_images(element="blobs_image", channel=0, channels_as_legend=True).pl.show(ax=ax)
legend = ax.get_legend()
assert legend is not None
assert "0" in [t.get_text() for t in legend.get_texts()]
assert len(fig.axes) == 1 # no colorbar inset axes
plt.close("all")

def test_channels_as_legend_rgb_warns_and_no_legend(self, caplog):
data = np.zeros((3, 50, 50), dtype=np.float64)
data[0], data[1], data[2] = 0.8, 0.2, 0.1
img = Image2DModel.parse(data, dims=("c", "y", "x"), c_coords=["r", "g", "b"])
sdata = SpatialData(images={"img": img})
fig, ax = plt.subplots()
with logger_warns(caplog, logger, match="not supported for true RGB"):
sdata.pl.render_images("img", channels_as_legend=True).pl.show(ax=ax)
assert ax.get_legend() is None
plt.close("all")

Expand Down
Loading