From a24eaa5b746a63a4e30e7f32c894884b4414c9ab Mon Sep 17 00:00:00 2001 From: anon Date: Thu, 7 May 2026 18:48:59 +0200 Subject: [PATCH 1/5] =?UTF-8?q?Fix=20O(n=C2=B2)=20pd.concat=20in=20=5Fget?= =?UTF-8?q?=5Fcs=5Fcontents;=20eliminate=20.query()=20injection=20risk=20(?= =?UTF-8?q?#602)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/spatialdata_plot/pl/basic.py | 13 +++++----- src/spatialdata_plot/pl/utils.py | 42 ++++++++++++-------------------- 2 files changed, 22 insertions(+), 33 deletions(-) diff --git a/src/spatialdata_plot/pl/basic.py b/src/spatialdata_plot/pl/basic.py index 41965dec..5153d762 100644 --- a/src/spatialdata_plot/pl/basic.py +++ b/src/spatialdata_plot/pl/basic.py @@ -1054,6 +1054,7 @@ def show( # Check if user specified only certain elements to be plotted cs_contents = _get_cs_contents(sdata) + cs_index = cs_contents.set_index("cs") pending_colorbars: list[tuple[Axes, list[ColorbarSpec]]] = [] elements_to_be_rendered = _get_elements_to_be_rendered(render_cmds, cs_contents, cs) @@ -1079,7 +1080,7 @@ def show( strict_cs = [ cs_name for cs_name in coordinate_systems - if all(cs_contents.query(f"cs == '{cs_name}'").iloc[0][flag] for flag in required_flags) + if cs_name in cs_index.index and all(cs_index.loc[cs_name][flag] for flag in required_flags) ] if strict_cs: coordinate_systems = strict_cs @@ -1197,15 +1198,15 @@ def _draw_colorbar( elif location == "top": trackers_axes["top"] = pad_axes + bbox_axes.height - cs_contents = _get_cs_contents(sdata) - # go through tree for i, cs in enumerate(coordinate_systems): sdata = self._copy() - _, has_images, has_labels, has_points, has_shapes = ( - cs_contents.query(f"cs == '{cs}'").iloc[0, :].values.tolist() - ) + cs_row = cs_index.loc[cs] + has_images = cs_row["has_images"] + has_labels = cs_row["has_labels"] + has_points = cs_row["has_points"] + has_shapes = cs_row["has_shapes"] ax = fig_params.ax if fig_params.axs is None else fig_params.axs[i] assert isinstance(ax, Axes) axis_colorbar_requests: list[ColorbarSpec] | None = [] if legend_params.colorbar else None diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py index 1aa283e9..2bd6a748 100644 --- a/src/spatialdata_plot/pl/utils.py +++ b/src/spatialdata_plot/pl/utils.py @@ -334,35 +334,22 @@ def _get_cs_contents(sdata: sd.SpatialData) -> pd.DataFrame: """Check which coordinate systems contain which elements and return that info.""" cs_mapping = _get_coordinate_system_mapping(sdata) content_flags = ["has_images", "has_labels", "has_points", "has_shapes"] - cs_contents = pd.DataFrame(columns=["cs"] + content_flags) + rows = [] for cs_name, element_ids in cs_mapping.items(): - # determine if coordinate system has the respective elements - cs_has_images = any(e in sdata.images for e in element_ids) - cs_has_labels = any(e in sdata.labels for e in element_ids) - cs_has_points = any(e in sdata.points for e in element_ids) - cs_has_shapes = any(e in sdata.shapes for e in element_ids) - - cs_contents = pd.concat( - [ - cs_contents, - pd.DataFrame( - { - "cs": cs_name, - "has_images": [cs_has_images], - "has_labels": [cs_has_labels], - "has_points": [cs_has_points], - "has_shapes": [cs_has_shapes], - } - ), - ] + rows.append( + { + "cs": cs_name, + "has_images": any(e in sdata.images for e in element_ids), + "has_labels": any(e in sdata.labels for e in element_ids), + "has_points": any(e in sdata.points for e in element_ids), + "has_shapes": any(e in sdata.shapes for e in element_ids), + } ) - cs_contents["has_images"] = cs_contents["has_images"].astype("bool") - cs_contents["has_labels"] = cs_contents["has_labels"].astype("bool") - cs_contents["has_points"] = cs_contents["has_points"].astype("bool") - cs_contents["has_shapes"] = cs_contents["has_shapes"].astype("bool") - + cs_contents = pd.DataFrame(rows, columns=["cs"] + content_flags) + for flag in content_flags: + cs_contents[flag] = cs_contents[flag].astype("bool") return cs_contents @@ -2127,11 +2114,12 @@ def _get_elements_to_be_rendered( """ elements_to_be_rendered: list[str] = [] - cs_query = cs_contents.query(f"cs == '{cs}'") + cs_index = cs_contents.set_index("cs") + cs_row = cs_index.loc[cs] if cs in cs_index.index else None for cmd, params in render_cmds: key = _RENDER_CMD_TO_CS_FLAG.get(cmd) - if key and cs_query[key][0]: + if key and cs_row is not None and cs_row[key]: elements_to_be_rendered += [params.element] return elements_to_be_rendered From c0f1e76d5eaccaebe604440e0607c8810c98ac10 Mon Sep 17 00:00:00 2001 From: anon Date: Thu, 7 May 2026 22:05:13 +0200 Subject: [PATCH 2/5] Add regression tests for apostrophe cs name and O(n) performance (#602) Co-Authored-By: Claude Sonnet 4.6 --- tests/pl/test_render.py | 44 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/tests/pl/test_render.py b/tests/pl/test_render.py index 83c6ee3c..2db4587a 100644 --- a/tests/pl/test_render.py +++ b/tests/pl/test_render.py @@ -1,5 +1,14 @@ +import timeit + import matplotlib.pyplot as plt +import numpy as np import pytest +from spatialdata import SpatialData +from spatialdata.models import Image2DModel +from spatialdata.transformations import Identity, set_transformation + +import spatialdata_plot # noqa: F401 +from spatialdata_plot.pl.utils import _get_cs_contents def test_render_images_can_plot_one_cyx_image(request): @@ -97,3 +106,38 @@ def test_single_ax_auto_cs_unresolvable_raises(sdata_multi_cs): with pytest.raises(ValueError, match="coordinate_systems="): # Only render shapes (present in both CS), so strict filter can't narrow down sdata_multi_cs.pl.render_shapes("shp").pl.show(ax=ax) + + +def test_cs_name_with_apostrophe_does_not_crash(): + # Regression test for #602: .query(f"cs == '{cs}'") raised TokenError for names + # containing single quotes (e.g. "patient's_sample"). + data = np.zeros((1, 10, 10), dtype=np.float64) + img = Image2DModel.parse(data, dims=("c", "y", "x")) + sdata = SpatialData(images={"img": img}) + set_transformation(sdata["img"], Identity(), to_coordinate_system="patient's_cs") + _, ax = plt.subplots() + sdata.pl.render_images("img").pl.show(ax=ax, coordinate_systems="patient's_cs") + plt.close("all") + + +def test_get_cs_contents_is_linear(): + # Regression test for #602: pd.concat inside loop was O(n²). + # Build two SpatialData objects: n=10 and n=50 coordinate systems. + def build(n: int) -> SpatialData: + data = np.zeros((1, 4, 4), dtype=np.float64) + images = {} + for i in range(n): + img_i = Image2DModel.parse(data.copy(), dims=("c", "y", "x")) + set_transformation(img_i, Identity(), to_coordinate_system=f"cs_{i}") + images[f"img_{i}"] = img_i + return SpatialData(images=images) + + sd10 = build(10) + sd50 = build(50) + t10 = timeit.timeit(lambda: _get_cs_contents(sd10), number=20) / 20 + t50 = timeit.timeit(lambda: _get_cs_contents(sd50), number=20) / 20 + ratio = t50 / t10 + # O(n) → ratio ≈ 5×; O(n²) → ratio ≈ 25×. Allow generous headroom for CI variance. + assert ratio < 15, ( + f"_get_cs_contents appears quadratic: n=10 {t10 * 1e3:.1f}ms, n=50 {t50 * 1e3:.1f}ms, ratio={ratio:.1f}x" + ) From d198cf30c0883e0d0fefe532acedc4840e7ce8c4 Mon Sep 17 00:00:00 2001 From: anon Date: Thu, 7 May 2026 22:07:53 +0200 Subject: [PATCH 3/5] simplify: eliminate duplicate set_index, vectorize astype, use list.append - Pass pre-built cs_index to _get_elements_to_be_rendered instead of rebuilding set_index("cs") inside the function - Vectorize astype("bool") to a single assignment over content_flags - Replace += [elem] with .append(elem) in render cmd loop Co-Authored-By: Claude Sonnet 4.6 --- src/spatialdata_plot/pl/basic.py | 2 +- src/spatialdata_plot/pl/utils.py | 14 ++++++-------- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/src/spatialdata_plot/pl/basic.py b/src/spatialdata_plot/pl/basic.py index 5153d762..1f11f20b 100644 --- a/src/spatialdata_plot/pl/basic.py +++ b/src/spatialdata_plot/pl/basic.py @@ -1057,7 +1057,7 @@ def show( cs_index = cs_contents.set_index("cs") pending_colorbars: list[tuple[Axes, list[ColorbarSpec]]] = [] - elements_to_be_rendered = _get_elements_to_be_rendered(render_cmds, cs_contents, cs) + elements_to_be_rendered = _get_elements_to_be_rendered(render_cmds, cs_index, cs) # filter out cs without relevant elements cmds = [cmd for cmd, _ in render_cmds] diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py index 2bd6a748..d36698d6 100644 --- a/src/spatialdata_plot/pl/utils.py +++ b/src/spatialdata_plot/pl/utils.py @@ -348,8 +348,7 @@ def _get_cs_contents(sdata: sd.SpatialData) -> pd.DataFrame: ) cs_contents = pd.DataFrame(rows, columns=["cs"] + content_flags) - for flag in content_flags: - cs_contents[flag] = cs_contents[flag].astype("bool") + cs_contents[content_flags] = cs_contents[content_flags].astype("bool") return cs_contents @@ -2093,7 +2092,7 @@ def _get_elements_to_be_rendered( ImageRenderParams | LabelsRenderParams | PointsRenderParams | ShapesRenderParams, ] ], - cs_contents: pd.DataFrame, + cs_index: pd.DataFrame, cs: str, ) -> list[str]: """ @@ -2103,10 +2102,10 @@ def _get_elements_to_be_rendered( ---------- render_cmds List of tuples containing the commands and their respective parameters. - cs_contents - The dataframe indicating for each coordinate system which SpatialElements it contains. + cs_index + The cs_contents dataframe indexed by the "cs" column. cs - The name of the coordinate system to query cs_contents for. + The name of the coordinate system to query cs_index for. Returns ------- @@ -2114,13 +2113,12 @@ def _get_elements_to_be_rendered( """ elements_to_be_rendered: list[str] = [] - cs_index = cs_contents.set_index("cs") cs_row = cs_index.loc[cs] if cs in cs_index.index else None for cmd, params in render_cmds: key = _RENDER_CMD_TO_CS_FLAG.get(cmd) if key and cs_row is not None and cs_row[key]: - elements_to_be_rendered += [params.element] + elements_to_be_rendered.append(params.element) return elements_to_be_rendered From 82f207d3511c9809a1f341ca8703faf95d765ac7 Mon Sep 17 00:00:00 2001 From: anon Date: Thu, 7 May 2026 22:19:40 +0200 Subject: [PATCH 4/5] Remove superfluous comments from regression tests Co-Authored-By: Claude Sonnet 4.6 --- tests/pl/test_render.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/pl/test_render.py b/tests/pl/test_render.py index 2db4587a..4e45dbed 100644 --- a/tests/pl/test_render.py +++ b/tests/pl/test_render.py @@ -109,8 +109,8 @@ def test_single_ax_auto_cs_unresolvable_raises(sdata_multi_cs): def test_cs_name_with_apostrophe_does_not_crash(): - # Regression test for #602: .query(f"cs == '{cs}'") raised TokenError for names - # containing single quotes (e.g. "patient's_sample"). + # Regression test for #602: .query(f"cs == '{cs}'") raised TokenError for cs names + # containing single quotes. data = np.zeros((1, 10, 10), dtype=np.float64) img = Image2DModel.parse(data, dims=("c", "y", "x")) sdata = SpatialData(images={"img": img}) @@ -122,7 +122,6 @@ def test_cs_name_with_apostrophe_does_not_crash(): def test_get_cs_contents_is_linear(): # Regression test for #602: pd.concat inside loop was O(n²). - # Build two SpatialData objects: n=10 and n=50 coordinate systems. def build(n: int) -> SpatialData: data = np.zeros((1, 4, 4), dtype=np.float64) images = {} @@ -137,7 +136,6 @@ def build(n: int) -> SpatialData: t10 = timeit.timeit(lambda: _get_cs_contents(sd10), number=20) / 20 t50 = timeit.timeit(lambda: _get_cs_contents(sd50), number=20) / 20 ratio = t50 / t10 - # O(n) → ratio ≈ 5×; O(n²) → ratio ≈ 25×. Allow generous headroom for CI variance. assert ratio < 15, ( f"_get_cs_contents appears quadratic: n=10 {t10 * 1e3:.1f}ms, n=50 {t50 * 1e3:.1f}ms, ratio={ratio:.1f}x" ) From 6ffb163e586d0a4dd87337cfb7fee543382db543 Mon Sep 17 00:00:00 2001 From: anon Date: Thu, 7 May 2026 22:21:04 +0200 Subject: [PATCH 5/5] Remove timing-based linearity test for _get_cs_contents Co-Authored-By: Claude Sonnet 4.6 --- tests/pl/test_render.py | 24 ------------------------ 1 file changed, 24 deletions(-) diff --git a/tests/pl/test_render.py b/tests/pl/test_render.py index 4e45dbed..29dd36b1 100644 --- a/tests/pl/test_render.py +++ b/tests/pl/test_render.py @@ -1,5 +1,3 @@ -import timeit - import matplotlib.pyplot as plt import numpy as np import pytest @@ -8,7 +6,6 @@ from spatialdata.transformations import Identity, set_transformation import spatialdata_plot # noqa: F401 -from spatialdata_plot.pl.utils import _get_cs_contents def test_render_images_can_plot_one_cyx_image(request): @@ -118,24 +115,3 @@ def test_cs_name_with_apostrophe_does_not_crash(): _, ax = plt.subplots() sdata.pl.render_images("img").pl.show(ax=ax, coordinate_systems="patient's_cs") plt.close("all") - - -def test_get_cs_contents_is_linear(): - # Regression test for #602: pd.concat inside loop was O(n²). - def build(n: int) -> SpatialData: - data = np.zeros((1, 4, 4), dtype=np.float64) - images = {} - for i in range(n): - img_i = Image2DModel.parse(data.copy(), dims=("c", "y", "x")) - set_transformation(img_i, Identity(), to_coordinate_system=f"cs_{i}") - images[f"img_{i}"] = img_i - return SpatialData(images=images) - - sd10 = build(10) - sd50 = build(50) - t10 = timeit.timeit(lambda: _get_cs_contents(sd10), number=20) / 20 - t50 = timeit.timeit(lambda: _get_cs_contents(sd50), number=20) / 20 - ratio = t50 / t10 - assert ratio < 15, ( - f"_get_cs_contents appears quadratic: n=10 {t10 * 1e3:.1f}ms, n=50 {t50 * 1e3:.1f}ms, ratio={ratio:.1f}x" - )