diff --git a/tests/test_figures.py b/tests/test_figures.py index 254a3bd..7789daa 100644 --- a/tests/test_figures.py +++ b/tests/test_figures.py @@ -9,7 +9,13 @@ import pytest import xarray as xr -from xarray_plotly import add_secondary_y, overlay, subplots, xpx +from xarray_plotly import ( + add_secondary_y, + overlay, + simplify_facet_titles, + subplots, + xpx, +) class TestOverlayBasic: @@ -922,3 +928,64 @@ def test_source_not_modified(self) -> None: original_count = len(fig.data) _ = subplots(fig, fig, cols=2) assert len(fig.data) == original_count + + +class TestSimplifyFacetTitles: + """Tests for the simplify_facet_titles helper and the `facet_titles` kwarg.""" + + @pytest.fixture(autouse=True) + def setup(self) -> None: + self.da = xr.DataArray( + np.random.rand(10, 3), + dims=["x", "country"], + coords={"country": ["United States", "China", "Brazil"]}, + name="value", + ) + + def test_helper_strips_dim_prefix(self) -> None: + fig = xpx(self.da).line(facet_col="country") + # PX writes annotations like "country=United States" + original_texts = [a.text for a in fig.layout.annotations] + assert any(t and t.startswith("country=") for t in original_texts) + + simplify_facet_titles(fig) + + for ann in fig.layout.annotations: + if ann.text: + assert "=" not in ann.text or ann.text.split("=", 1)[0] != "country" + + def test_helper_full_is_noop(self) -> None: + fig = xpx(self.da).line(facet_col="country") + before = [a.text for a in fig.layout.annotations] + simplify_facet_titles(fig, mode="default") + after = [a.text for a in fig.layout.annotations] + assert before == after + + def test_helper_invalid_mode_raises(self) -> None: + fig = xpx(self.da).line(facet_col="country") + with pytest.raises(ValueError, match="facet_titles must be"): + simplify_facet_titles(fig, mode="bogus") # type: ignore[arg-type] + + def test_helper_leaves_user_annotations_alone(self) -> None: + """User-added annotations without a Python-identifier prefix are preserved.""" + fig = xpx(self.da).line(facet_col="country") + fig.add_annotation(text="Some note", x=0, y=0, showarrow=False) + simplify_facet_titles(fig) + texts = [a.text for a in fig.layout.annotations] + assert "Some note" in texts + + def test_kwarg_default_keeps_px_format(self) -> None: + fig = xpx(self.da).line(facet_col="country") + # At least one annotation still carries the dim= prefix. + assert any(a.text and a.text.startswith("country=") for a in fig.layout.annotations) + + def test_kwarg_value_strips_prefix(self) -> None: + fig = xpx(self.da).line(facet_col="country", facet_titles="value") + for ann in fig.layout.annotations: + if ann.text: + # Should not start with "country="; the dim prefix is stripped. + assert not ann.text.startswith("country=") + + def test_kwarg_invalid_raises(self) -> None: + with pytest.raises(ValueError, match="facet_titles must be"): + xpx(self.da).line(facet_col="country", facet_titles="bogus") # type: ignore[arg-type] diff --git a/xarray_plotly/__init__.py b/xarray_plotly/__init__.py index c526173..49250e1 100644 --- a/xarray_plotly/__init__.py +++ b/xarray_plotly/__init__.py @@ -56,6 +56,7 @@ from xarray_plotly.figures import ( add_secondary_y, overlay, + simplify_facet_titles, subplots, update_traces, ) @@ -67,6 +68,7 @@ "auto", "config", "overlay", + "simplify_facet_titles", "subplots", "update_traces", "xpx", diff --git a/xarray_plotly/accessor.py b/xarray_plotly/accessor.py index eb3ddf0..231e4e6 100644 --- a/xarray_plotly/accessor.py +++ b/xarray_plotly/accessor.py @@ -6,7 +6,7 @@ from xarray import DataArray, Dataset from xarray_plotly import plotting -from xarray_plotly.common import Colors, SlotValue, auto +from xarray_plotly.common import Colors, FacetTitlesMode, SlotValue, auto from xarray_plotly.config import _options @@ -54,6 +54,7 @@ def line( facet_row: SlotValue = auto, animation_frame: SlotValue = auto, colors: Colors = None, + facet_titles: FacetTitlesMode = "default", **px_kwargs: Any, ) -> go.Figure: """Create an interactive line plot. @@ -84,6 +85,7 @@ def line( facet_row=facet_row, animation_frame=animation_frame, colors=colors, + facet_titles=facet_titles, **px_kwargs, ) @@ -97,6 +99,7 @@ def bar( facet_row: SlotValue = auto, animation_frame: SlotValue = auto, colors: Colors = None, + facet_titles: FacetTitlesMode = "default", **px_kwargs: Any, ) -> go.Figure: """Create an interactive bar chart. @@ -125,6 +128,7 @@ def bar( facet_row=facet_row, animation_frame=animation_frame, colors=colors, + facet_titles=facet_titles, **px_kwargs, ) @@ -138,6 +142,7 @@ def area( facet_row: SlotValue = auto, animation_frame: SlotValue = auto, colors: Colors = None, + facet_titles: FacetTitlesMode = "default", **px_kwargs: Any, ) -> go.Figure: """Create an interactive stacked area chart. @@ -166,6 +171,7 @@ def area( facet_row=facet_row, animation_frame=animation_frame, colors=colors, + facet_titles=facet_titles, **px_kwargs, ) @@ -178,6 +184,7 @@ def fast_bar( facet_row: SlotValue = auto, animation_frame: SlotValue = auto, colors: Colors = None, + facet_titles: FacetTitlesMode = "default", **px_kwargs: Any, ) -> go.Figure: """Create a bar-like chart using stacked areas for better performance. @@ -204,6 +211,7 @@ def fast_bar( facet_row=facet_row, animation_frame=animation_frame, colors=colors, + facet_titles=facet_titles, **px_kwargs, ) @@ -218,6 +226,7 @@ def scatter( facet_row: SlotValue = auto, animation_frame: SlotValue = auto, colors: Colors = None, + facet_titles: FacetTitlesMode = "default", **px_kwargs: Any, ) -> go.Figure: """Create an interactive scatter plot. @@ -252,6 +261,7 @@ def scatter( facet_row=facet_row, animation_frame=animation_frame, colors=colors, + facet_titles=facet_titles, **px_kwargs, ) @@ -264,6 +274,7 @@ def box( facet_row: SlotValue = None, animation_frame: SlotValue = None, colors: Colors = None, + facet_titles: FacetTitlesMode = "default", **px_kwargs: Any, ) -> go.Figure: """Create an interactive box plot. @@ -293,6 +304,7 @@ def box( facet_row=facet_row, animation_frame=animation_frame, colors=colors, + facet_titles=facet_titles, **px_kwargs, ) @@ -305,6 +317,7 @@ def imshow( animation_frame: SlotValue = auto, robust: bool = False, colors: Colors = None, + facet_titles: FacetTitlesMode = "default", **px_kwargs: Any, ) -> go.Figure: """Create an interactive heatmap image. @@ -337,6 +350,7 @@ def imshow( animation_frame=animation_frame, robust=robust, colors=colors, + facet_titles=facet_titles, **px_kwargs, ) @@ -348,6 +362,7 @@ def pie( facet_col: SlotValue = auto, facet_row: SlotValue = auto, colors: Colors = None, + facet_titles: FacetTitlesMode = "default", **px_kwargs: Any, ) -> go.Figure: """Create an interactive pie chart. @@ -372,6 +387,7 @@ def pie( facet_col=facet_col, facet_row=facet_row, colors=colors, + facet_titles=facet_titles, **px_kwargs, ) @@ -452,6 +468,7 @@ def line( facet_row: SlotValue = auto, animation_frame: SlotValue = auto, colors: Colors = None, + facet_titles: FacetTitlesMode = "default", **px_kwargs: Any, ) -> go.Figure: """Create an interactive line plot. @@ -482,6 +499,7 @@ def line( facet_row=facet_row, animation_frame=animation_frame, colors=colors, + facet_titles=facet_titles, **px_kwargs, ) @@ -496,6 +514,7 @@ def bar( facet_row: SlotValue = auto, animation_frame: SlotValue = auto, colors: Colors = None, + facet_titles: FacetTitlesMode = "default", **px_kwargs: Any, ) -> go.Figure: """Create an interactive bar chart. @@ -524,6 +543,7 @@ def bar( facet_row=facet_row, animation_frame=animation_frame, colors=colors, + facet_titles=facet_titles, **px_kwargs, ) @@ -538,6 +558,7 @@ def area( facet_row: SlotValue = auto, animation_frame: SlotValue = auto, colors: Colors = None, + facet_titles: FacetTitlesMode = "default", **px_kwargs: Any, ) -> go.Figure: """Create an interactive stacked area chart. @@ -566,6 +587,7 @@ def area( facet_row=facet_row, animation_frame=animation_frame, colors=colors, + facet_titles=facet_titles, **px_kwargs, ) @@ -579,6 +601,7 @@ def fast_bar( facet_row: SlotValue = auto, animation_frame: SlotValue = auto, colors: Colors = None, + facet_titles: FacetTitlesMode = "default", **px_kwargs: Any, ) -> go.Figure: """Create a bar-like chart using stacked areas for better performance. @@ -605,6 +628,7 @@ def fast_bar( facet_row=facet_row, animation_frame=animation_frame, colors=colors, + facet_titles=facet_titles, **px_kwargs, ) @@ -620,6 +644,7 @@ def scatter( facet_row: SlotValue = auto, animation_frame: SlotValue = auto, colors: Colors = None, + facet_titles: FacetTitlesMode = "default", **px_kwargs: Any, ) -> go.Figure: """Create an interactive scatter plot. @@ -650,6 +675,7 @@ def scatter( facet_row=facet_row, animation_frame=animation_frame, colors=colors, + facet_titles=facet_titles, **px_kwargs, ) @@ -663,6 +689,7 @@ def box( facet_row: SlotValue = None, animation_frame: SlotValue = None, colors: Colors = None, + facet_titles: FacetTitlesMode = "default", **px_kwargs: Any, ) -> go.Figure: """Create an interactive box plot. @@ -689,6 +716,7 @@ def box( facet_row=facet_row, animation_frame=animation_frame, colors=colors, + facet_titles=facet_titles, **px_kwargs, ) @@ -701,6 +729,7 @@ def pie( facet_col: SlotValue = auto, facet_row: SlotValue = auto, colors: Colors = None, + facet_titles: FacetTitlesMode = "default", **px_kwargs: Any, ) -> go.Figure: """Create an interactive pie chart. @@ -725,5 +754,6 @@ def pie( facet_col=facet_col, facet_row=facet_row, colors=colors, + facet_titles=facet_titles, **px_kwargs, ) diff --git a/xarray_plotly/common.py b/xarray_plotly/common.py index 7805d4b..40e0012 100644 --- a/xarray_plotly/common.py +++ b/xarray_plotly/common.py @@ -5,7 +5,7 @@ import functools import warnings from collections.abc import Hashable, Mapping, Sequence -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Literal import plotly.express as px @@ -39,6 +39,13 @@ def __repr__(self) -> str: - None: Use Plotly defaults """ +FacetTitlesMode = Literal["value", "default"] +"""Type alias for facet_titles parameter. + +- "default" (default): keep PX's ``"="`` subplot titles. +- "value": strip the ``=`` prefix, leaving just the value. +""" + # Re-export for backward compatibility SLOT_ORDERS = DEFAULT_SLOT_ORDERS """Slot orders per plot type. diff --git a/xarray_plotly/figures.py b/xarray_plotly/figures.py index 8a4bf00..2d9b231 100644 --- a/xarray_plotly/figures.py +++ b/xarray_plotly/figures.py @@ -5,6 +5,7 @@ from __future__ import annotations import copy +import re from typing import TYPE_CHECKING, Any if TYPE_CHECKING: @@ -12,6 +13,8 @@ import plotly.graph_objects as go + from xarray_plotly.common import FacetTitlesMode + def _get_yaxis_title(fig: go.Figure) -> str: """Extract the primary y-axis title text from a figure. @@ -922,3 +925,49 @@ def update_traces( trace.update(**kwargs) return fig + + +# Matches an identifier-style PX facet prefix like "country=" at the start of +# annotation text. Defensive: ignores annotations a user added themselves +# whose text doesn't look like a dim assignment. +_FACET_TITLE_PREFIX_RE = re.compile(r"^[A-Za-z_]\w*=") + + +def simplify_facet_titles( + fig: go.Figure, + mode: FacetTitlesMode = "value", +) -> go.Figure: + """Strip the ``=`` prefix from Plotly Express facet subplot titles. + + PX renders faceted subplot titles as annotations like ``"country=Brazil"``. + With ``mode="value"`` (default), the prefix is stripped to just the value + (``"Brazil"``). With ``mode="default"``, the figure is returned unchanged. + + Only annotations whose text matches a Python-identifier prefix followed by + ``=`` are touched, so user-added annotations are left alone. + + Args: + fig: A Plotly figure (mutated in place). + mode: ``"value"`` to strip the prefix, ``"default"`` to keep PX's format. + + Returns: + The (possibly mutated) figure, for chaining. + + Raises: + ValueError: If ``mode`` is not ``"value"`` or ``"default"``. + + Example: + >>> from xarray_plotly import xpx, simplify_facet_titles + >>> fig = xpx(da).line(facet_col="country") + >>> simplify_facet_titles(fig) # "country=Brazil" -> "Brazil" + """ + if mode == "default": + return fig + if mode != "value": + raise ValueError(f"facet_titles must be 'value' or 'full', got {mode!r}") + + for ann in fig.layout.annotations or (): + text = ann.text + if text and _FACET_TITLE_PREFIX_RE.match(text): + ann.text = text.split("=", 1)[1] + return fig diff --git a/xarray_plotly/plotting.py b/xarray_plotly/plotting.py index a45cbd5..5ec8a90 100644 --- a/xarray_plotly/plotting.py +++ b/xarray_plotly/plotting.py @@ -13,6 +13,7 @@ from xarray_plotly.common import ( Colors, + FacetTitlesMode, SlotValue, assign_slots, auto, @@ -24,6 +25,7 @@ ) from xarray_plotly.figures import ( _iter_all_traces, + simplify_facet_titles, ) if TYPE_CHECKING: @@ -42,6 +44,7 @@ def line( facet_row: SlotValue = auto, animation_frame: SlotValue = auto, colors: Colors = None, + facet_titles: FacetTitlesMode = "default", **px_kwargs: Any, ) -> go.Figure: """ @@ -99,7 +102,7 @@ def line( value_col = get_value_col(darray) labels = {**build_labels(darray, slots, value_col), **px_kwargs.pop("labels", {})} - return px.line( + fig = px.line( df, x=slots.get("x"), y=value_col, @@ -112,6 +115,7 @@ def line( labels=labels, **px_kwargs, ) + return simplify_facet_titles(fig, facet_titles) def bar( @@ -124,6 +128,7 @@ def bar( facet_row: SlotValue = auto, animation_frame: SlotValue = auto, colors: Colors = None, + facet_titles: FacetTitlesMode = "default", **px_kwargs: Any, ) -> go.Figure: """ @@ -178,7 +183,7 @@ def bar( value_col = get_value_col(darray) labels = {**build_labels(darray, slots, value_col), **px_kwargs.pop("labels", {})} - return px.bar( + fig = px.bar( df, x=slots.get("x"), y=value_col, @@ -190,6 +195,7 @@ def bar( labels=labels, **px_kwargs, ) + return simplify_facet_titles(fig, facet_titles) def _classify_trace_sign(y_values: npt.ArrayLike) -> str: @@ -285,6 +291,7 @@ def fast_bar( facet_row: SlotValue = auto, animation_frame: SlotValue = auto, colors: Colors = None, + facet_titles: FacetTitlesMode = "default", **px_kwargs: Any, ) -> go.Figure: """ @@ -359,7 +366,7 @@ def fast_bar( _style_traces_as_bars(fig) - return fig + return simplify_facet_titles(fig, facet_titles) def area( @@ -372,6 +379,7 @@ def area( facet_row: SlotValue = auto, animation_frame: SlotValue = auto, colors: Colors = None, + facet_titles: FacetTitlesMode = "default", **px_kwargs: Any, ) -> go.Figure: """ @@ -426,7 +434,7 @@ def area( value_col = get_value_col(darray) labels = {**build_labels(darray, slots, value_col), **px_kwargs.pop("labels", {})} - return px.area( + fig = px.area( df, x=slots.get("x"), y=value_col, @@ -438,6 +446,7 @@ def area( labels=labels, **px_kwargs, ) + return simplify_facet_titles(fig, facet_titles) def box( @@ -449,6 +458,7 @@ def box( facet_row: SlotValue = None, animation_frame: SlotValue = None, colors: Colors = None, + facet_titles: FacetTitlesMode = "default", **px_kwargs: Any, ) -> go.Figure: """ @@ -503,7 +513,7 @@ def box( value_col = get_value_col(darray) labels = {**build_labels(darray, slots, value_col), **px_kwargs.pop("labels", {})} - return px.box( + fig = px.box( df, x=slots.get("x"), y=value_col, @@ -514,6 +524,7 @@ def box( labels=labels, **px_kwargs, ) + return simplify_facet_titles(fig, facet_titles) def scatter( @@ -527,6 +538,7 @@ def scatter( facet_row: SlotValue = auto, animation_frame: SlotValue = auto, colors: Colors = None, + facet_titles: FacetTitlesMode = "default", **px_kwargs: Any, ) -> go.Figure: """ @@ -600,7 +612,7 @@ def scatter( if y_is_dim and str(y) not in labels: labels[str(y)] = get_label(darray, y) - return px.scatter( + fig = px.scatter( df, x=slots.get("x"), y=y_col, @@ -612,6 +624,7 @@ def scatter( labels=labels, **px_kwargs, ) + return simplify_facet_titles(fig, facet_titles) def imshow( @@ -623,6 +636,7 @@ def imshow( animation_frame: SlotValue = auto, robust: bool = False, colors: Colors = None, + facet_titles: FacetTitlesMode = "default", **px_kwargs: Any, ) -> go.Figure: """ @@ -697,12 +711,13 @@ def imshow( px_kwargs.setdefault("zmin", zmin) px_kwargs.setdefault("zmax", zmax) - return px.imshow( + fig = px.imshow( plot_data, facet_col=slots.get("facet_col"), animation_frame=slots.get("animation_frame"), **px_kwargs, ) + return simplify_facet_titles(fig, facet_titles) def pie( @@ -713,6 +728,7 @@ def pie( facet_col: SlotValue = auto, facet_row: SlotValue = auto, colors: Colors = None, + facet_titles: FacetTitlesMode = "default", **px_kwargs: Any, ) -> go.Figure: """ @@ -762,7 +778,7 @@ def pie( # Use names dimension for color if not explicitly set color_col = color if color is not None else slots.get("names") - return px.pie( + fig = px.pie( df, names=slots.get("names"), values=value_col, @@ -772,3 +788,4 @@ def pie( labels=labels, **px_kwargs, ) + return simplify_facet_titles(fig, facet_titles)