Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 68 additions & 1 deletion tests/test_figures.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,13 @@
import pytest
import xarray as xr

from xarray_plotly import add_secondary_y, overlay, subplots, xpx
from xarray_plotly import (
add_secondary_y,
overlay,
simplify_facet_titles,
subplots,
xpx,
)


class TestOverlayBasic:
Expand Down Expand Up @@ -922,3 +928,64 @@ def test_source_not_modified(self) -> None:
original_count = len(fig.data)
_ = subplots(fig, fig, cols=2)
assert len(fig.data) == original_count


class TestSimplifyFacetTitles:
"""Tests for the simplify_facet_titles helper and the `facet_titles` kwarg."""

@pytest.fixture(autouse=True)
def setup(self) -> None:
self.da = xr.DataArray(
np.random.rand(10, 3),
dims=["x", "country"],
coords={"country": ["United States", "China", "Brazil"]},
name="value",
)

def test_helper_strips_dim_prefix(self) -> None:
fig = xpx(self.da).line(facet_col="country")
# PX writes annotations like "country=United States"
original_texts = [a.text for a in fig.layout.annotations]
assert any(t and t.startswith("country=") for t in original_texts)

simplify_facet_titles(fig)

for ann in fig.layout.annotations:
if ann.text:
assert "=" not in ann.text or ann.text.split("=", 1)[0] != "country"

def test_helper_full_is_noop(self) -> None:
fig = xpx(self.da).line(facet_col="country")
before = [a.text for a in fig.layout.annotations]
simplify_facet_titles(fig, mode="default")
after = [a.text for a in fig.layout.annotations]
assert before == after

def test_helper_invalid_mode_raises(self) -> None:
fig = xpx(self.da).line(facet_col="country")
with pytest.raises(ValueError, match="facet_titles must be"):
simplify_facet_titles(fig, mode="bogus") # type: ignore[arg-type]

def test_helper_leaves_user_annotations_alone(self) -> None:
"""User-added annotations without a Python-identifier prefix are preserved."""
fig = xpx(self.da).line(facet_col="country")
fig.add_annotation(text="Some note", x=0, y=0, showarrow=False)
simplify_facet_titles(fig)
texts = [a.text for a in fig.layout.annotations]
assert "Some note" in texts

def test_kwarg_default_keeps_px_format(self) -> None:
fig = xpx(self.da).line(facet_col="country")
# At least one annotation still carries the dim= prefix.
assert any(a.text and a.text.startswith("country=") for a in fig.layout.annotations)

def test_kwarg_value_strips_prefix(self) -> None:
fig = xpx(self.da).line(facet_col="country", facet_titles="value")
for ann in fig.layout.annotations:
if ann.text:
# Should not start with "country="; the dim prefix is stripped.
assert not ann.text.startswith("country=")

def test_kwarg_invalid_raises(self) -> None:
with pytest.raises(ValueError, match="facet_titles must be"):
xpx(self.da).line(facet_col="country", facet_titles="bogus") # type: ignore[arg-type]
2 changes: 2 additions & 0 deletions xarray_plotly/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
from xarray_plotly.figures import (
add_secondary_y,
overlay,
simplify_facet_titles,
subplots,
update_traces,
)
Expand All @@ -67,6 +68,7 @@
"auto",
"config",
"overlay",
"simplify_facet_titles",
"subplots",
"update_traces",
"xpx",
Expand Down
32 changes: 31 additions & 1 deletion xarray_plotly/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from xarray import DataArray, Dataset

from xarray_plotly import plotting
from xarray_plotly.common import Colors, SlotValue, auto
from xarray_plotly.common import Colors, FacetTitlesMode, SlotValue, auto
from xarray_plotly.config import _options


Expand Down Expand Up @@ -54,6 +54,7 @@ def line(
facet_row: SlotValue = auto,
animation_frame: SlotValue = auto,
colors: Colors = None,
facet_titles: FacetTitlesMode = "default",
**px_kwargs: Any,
) -> go.Figure:
"""Create an interactive line plot.
Expand Down Expand Up @@ -84,6 +85,7 @@ def line(
facet_row=facet_row,
animation_frame=animation_frame,
colors=colors,
facet_titles=facet_titles,
**px_kwargs,
)

Expand All @@ -97,6 +99,7 @@ def bar(
facet_row: SlotValue = auto,
animation_frame: SlotValue = auto,
colors: Colors = None,
facet_titles: FacetTitlesMode = "default",
**px_kwargs: Any,
) -> go.Figure:
"""Create an interactive bar chart.
Expand Down Expand Up @@ -125,6 +128,7 @@ def bar(
facet_row=facet_row,
animation_frame=animation_frame,
colors=colors,
facet_titles=facet_titles,
**px_kwargs,
)

Expand All @@ -138,6 +142,7 @@ def area(
facet_row: SlotValue = auto,
animation_frame: SlotValue = auto,
colors: Colors = None,
facet_titles: FacetTitlesMode = "default",
**px_kwargs: Any,
) -> go.Figure:
"""Create an interactive stacked area chart.
Expand Down Expand Up @@ -166,6 +171,7 @@ def area(
facet_row=facet_row,
animation_frame=animation_frame,
colors=colors,
facet_titles=facet_titles,
**px_kwargs,
)

Expand All @@ -178,6 +184,7 @@ def fast_bar(
facet_row: SlotValue = auto,
animation_frame: SlotValue = auto,
colors: Colors = None,
facet_titles: FacetTitlesMode = "default",
**px_kwargs: Any,
) -> go.Figure:
"""Create a bar-like chart using stacked areas for better performance.
Expand All @@ -204,6 +211,7 @@ def fast_bar(
facet_row=facet_row,
animation_frame=animation_frame,
colors=colors,
facet_titles=facet_titles,
**px_kwargs,
)

Expand All @@ -218,6 +226,7 @@ def scatter(
facet_row: SlotValue = auto,
animation_frame: SlotValue = auto,
colors: Colors = None,
facet_titles: FacetTitlesMode = "default",
**px_kwargs: Any,
) -> go.Figure:
"""Create an interactive scatter plot.
Expand Down Expand Up @@ -252,6 +261,7 @@ def scatter(
facet_row=facet_row,
animation_frame=animation_frame,
colors=colors,
facet_titles=facet_titles,
**px_kwargs,
)

Expand All @@ -264,6 +274,7 @@ def box(
facet_row: SlotValue = None,
animation_frame: SlotValue = None,
colors: Colors = None,
facet_titles: FacetTitlesMode = "default",
**px_kwargs: Any,
) -> go.Figure:
"""Create an interactive box plot.
Expand Down Expand Up @@ -293,6 +304,7 @@ def box(
facet_row=facet_row,
animation_frame=animation_frame,
colors=colors,
facet_titles=facet_titles,
**px_kwargs,
)

Expand All @@ -305,6 +317,7 @@ def imshow(
animation_frame: SlotValue = auto,
robust: bool = False,
colors: Colors = None,
facet_titles: FacetTitlesMode = "default",
**px_kwargs: Any,
) -> go.Figure:
"""Create an interactive heatmap image.
Expand Down Expand Up @@ -337,6 +350,7 @@ def imshow(
animation_frame=animation_frame,
robust=robust,
colors=colors,
facet_titles=facet_titles,
**px_kwargs,
)

Expand All @@ -348,6 +362,7 @@ def pie(
facet_col: SlotValue = auto,
facet_row: SlotValue = auto,
colors: Colors = None,
facet_titles: FacetTitlesMode = "default",
**px_kwargs: Any,
) -> go.Figure:
"""Create an interactive pie chart.
Expand All @@ -372,6 +387,7 @@ def pie(
facet_col=facet_col,
facet_row=facet_row,
colors=colors,
facet_titles=facet_titles,
**px_kwargs,
)

Expand Down Expand Up @@ -452,6 +468,7 @@ def line(
facet_row: SlotValue = auto,
animation_frame: SlotValue = auto,
colors: Colors = None,
facet_titles: FacetTitlesMode = "default",
**px_kwargs: Any,
) -> go.Figure:
"""Create an interactive line plot.
Expand Down Expand Up @@ -482,6 +499,7 @@ def line(
facet_row=facet_row,
animation_frame=animation_frame,
colors=colors,
facet_titles=facet_titles,
**px_kwargs,
)

Expand All @@ -496,6 +514,7 @@ def bar(
facet_row: SlotValue = auto,
animation_frame: SlotValue = auto,
colors: Colors = None,
facet_titles: FacetTitlesMode = "default",
**px_kwargs: Any,
) -> go.Figure:
"""Create an interactive bar chart.
Expand Down Expand Up @@ -524,6 +543,7 @@ def bar(
facet_row=facet_row,
animation_frame=animation_frame,
colors=colors,
facet_titles=facet_titles,
**px_kwargs,
)

Expand All @@ -538,6 +558,7 @@ def area(
facet_row: SlotValue = auto,
animation_frame: SlotValue = auto,
colors: Colors = None,
facet_titles: FacetTitlesMode = "default",
**px_kwargs: Any,
) -> go.Figure:
"""Create an interactive stacked area chart.
Expand Down Expand Up @@ -566,6 +587,7 @@ def area(
facet_row=facet_row,
animation_frame=animation_frame,
colors=colors,
facet_titles=facet_titles,
**px_kwargs,
)

Expand All @@ -579,6 +601,7 @@ def fast_bar(
facet_row: SlotValue = auto,
animation_frame: SlotValue = auto,
colors: Colors = None,
facet_titles: FacetTitlesMode = "default",
**px_kwargs: Any,
) -> go.Figure:
"""Create a bar-like chart using stacked areas for better performance.
Expand All @@ -605,6 +628,7 @@ def fast_bar(
facet_row=facet_row,
animation_frame=animation_frame,
colors=colors,
facet_titles=facet_titles,
**px_kwargs,
)

Expand All @@ -620,6 +644,7 @@ def scatter(
facet_row: SlotValue = auto,
animation_frame: SlotValue = auto,
colors: Colors = None,
facet_titles: FacetTitlesMode = "default",
**px_kwargs: Any,
) -> go.Figure:
"""Create an interactive scatter plot.
Expand Down Expand Up @@ -650,6 +675,7 @@ def scatter(
facet_row=facet_row,
animation_frame=animation_frame,
colors=colors,
facet_titles=facet_titles,
**px_kwargs,
)

Expand All @@ -663,6 +689,7 @@ def box(
facet_row: SlotValue = None,
animation_frame: SlotValue = None,
colors: Colors = None,
facet_titles: FacetTitlesMode = "default",
**px_kwargs: Any,
) -> go.Figure:
"""Create an interactive box plot.
Expand All @@ -689,6 +716,7 @@ def box(
facet_row=facet_row,
animation_frame=animation_frame,
colors=colors,
facet_titles=facet_titles,
**px_kwargs,
)

Expand All @@ -701,6 +729,7 @@ def pie(
facet_col: SlotValue = auto,
facet_row: SlotValue = auto,
colors: Colors = None,
facet_titles: FacetTitlesMode = "default",
**px_kwargs: Any,
) -> go.Figure:
"""Create an interactive pie chart.
Expand All @@ -725,5 +754,6 @@ def pie(
facet_col=facet_col,
facet_row=facet_row,
colors=colors,
facet_titles=facet_titles,
**px_kwargs,
)
9 changes: 8 additions & 1 deletion xarray_plotly/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import functools
import warnings
from collections.abc import Hashable, Mapping, Sequence
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Literal

import plotly.express as px

Expand Down Expand Up @@ -39,6 +39,13 @@ def __repr__(self) -> str:
- None: Use Plotly defaults
"""

FacetTitlesMode = Literal["value", "default"]
"""Type alias for facet_titles parameter.

- "default" (default): keep PX's ``"<dim>=<value>"`` subplot titles.
- "value": strip the ``<dim>=`` prefix, leaving just the value.
"""

# Re-export for backward compatibility
SLOT_ORDERS = DEFAULT_SLOT_ORDERS
"""Slot orders per plot type.
Expand Down
Loading