diff --git a/docs/examples/combining.ipynb b/docs/examples/combining.ipynb index a72d752..b190828 100644 --- a/docs/examples/combining.ipynb +++ b/docs/examples/combining.ipynb @@ -351,6 +351,70 @@ "cell_type": "markdown", "id": "21", "metadata": {}, + "source": [ + "### Multi-Trace Figures\n", + "\n", + "When both base and secondary figures already split into multiple traces (e.g. via a categorical xarray dimension), `add_secondary_y` keeps every trace visible in the legend. By default (`legend=\"suffix\"`), traces with the same name across the two figures are disambiguated by appending each figure's y-axis title — `\"Brazil (Population)\"` vs `\"Brazil (GDP per Capita)\"` — so each trace is its own legend entry and toggles independently." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "22", + "metadata": {}, + "outputs": [], + "source": [ + "# All four countries on both axes — population on left, GDP per capita on right.\n", + "# Both figures have legendgroups \"United States\", \"China\", ... — without\n", + "# namespacing, the secondary's three traces would be deduped out of the legend.\n", + "pop_fig = xpx(population).line(markers=True)\n", + "gdp_fig = xpx(gdp_per_capita).line()\n", + "\n", + "combined = add_secondary_y(pop_fig, gdp_fig, secondary_y_title=\"GDP per Capita ($)\")\n", + "combined.update_layout(\n", + " title=\"Population (left) vs GDP per Capita (right) — all countries\",\n", + " yaxis_title=\"Population\",\n", + ")\n", + "combined" + ] + }, + { + "cell_type": "markdown", + "id": "23", + "metadata": {}, + "source": [ + "#### Choosing how the legend treats same-named traces\n", + "\n", + "`add_secondary_y` accepts a `legend=` argument controlling how same-named\n", + "traces from the two figures are presented:\n", + "\n", + "- `\"suffix\"` (default): each trace gets its own entry with the source y-axis title appended.\n", + "- `\"merge\"`: same-named traces share a `legendgroup`, collapsing to a single entry that toggles both axes together.\n", + "- `\"separate\"`: PX entries are left as-is; duplicate names are accepted and each trace toggles alone." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "24", + "metadata": {}, + "outputs": [], + "source": [ + "# legend=\"merge\": clicking \"United States\" in the legend toggles both\n", + "# the Population trace (left axis) and the GDP per Capita trace (right axis).\n", + "pop_fig = xpx(population).line()\n", + "gdp_fig = xpx(gdp_per_capita).line()\n", + "combined_merge = add_secondary_y(\n", + " pop_fig, gdp_fig, secondary_y_title=\"GDP per Capita ($)\", legend=\"merge\"\n", + ")\n", + "combined_merge.update_layout(title='legend=\"merge\" — one entry per country, toggles both axes')\n", + "combined_merge" + ] + }, + { + "cell_type": "markdown", + "id": "25", + "metadata": {}, "source": [ "### With Animation\n", "\n", @@ -360,7 +424,7 @@ { "cell_type": "code", "execution_count": null, - "id": "22", + "id": "26", "metadata": {}, "outputs": [], "source": [ @@ -378,7 +442,7 @@ }, { "cell_type": "markdown", - "id": "23", + "id": "27", "metadata": {}, "source": [ "### Static Secondary on Animated Base\n", @@ -389,7 +453,7 @@ { "cell_type": "code", "execution_count": null, - "id": "24", + "id": "28", "metadata": {}, "outputs": [], "source": [ @@ -410,7 +474,7 @@ }, { "cell_type": "markdown", - "id": "25", + "id": "29", "metadata": {}, "source": [ "### With Facets\n", @@ -421,7 +485,7 @@ { "cell_type": "code", "execution_count": null, - "id": "26", + "id": "30", "metadata": {}, "outputs": [], "source": [ @@ -439,7 +503,95 @@ }, { "cell_type": "markdown", - "id": "27", + "id": "31", + "metadata": {}, + "source": [ + "### Multi-Trace + Facets\n", + "\n", + "Both axes carry multiple traces *within* each facet. This combines facet\n", + "structure, multi-trace legendgroups, and the secondary y-axis layout." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "32", + "metadata": {}, + "outputs": [], + "source": [ + "# Synthetic 3D data: month x sensor x site, on two scales (Temperature vs Power).\n", + "# facet_col=\"site\" splits into facets; the \"sensor\" dimension produces multiple\n", + "# traces within each facet, identically named on both axes.\n", + "import numpy as np\n", + "\n", + "rng = np.random.default_rng(0)\n", + "months = np.arange(1, 13)\n", + "sensors = [\"A\", \"B\", \"C\"]\n", + "sites = [\"North\", \"South\"]\n", + "\n", + "temp_da = xr.DataArray(\n", + " 20 + rng.standard_normal((12, 3, 2)) * 5,\n", + " dims=[\"month\", \"sensor\", \"site\"],\n", + " coords={\"month\": months, \"sensor\": sensors, \"site\": sites},\n", + " name=\"Temperature (°C)\",\n", + ")\n", + "power_da = xr.DataArray(\n", + " 400 + rng.standard_normal((12, 3, 2)) * 80,\n", + " dims=[\"month\", \"sensor\", \"site\"],\n", + " coords={\"month\": months, \"sensor\": sensors, \"site\": sites},\n", + " name=\"Power (W)\",\n", + ")\n", + "\n", + "temp_fig = xpx(temp_da).line(facet_col=\"site\", markers=True)\n", + "power_fig = xpx(power_da).line(facet_col=\"site\")\n", + "\n", + "combined = add_secondary_y(temp_fig, power_fig)\n", + "combined.update_layout(title=\"Sensor Temperature (left) vs Power (right) — faceted by site\")\n", + "combined" + ] + }, + { + "cell_type": "markdown", + "id": "33", + "metadata": {}, + "source": [ + "### Composing `overlay` with `add_secondary_y`\n", + "\n", + "`add_secondary_y` accepts any base figure — including one already produced by `overlay`. Useful when you want to overlay a trend on the primary axis, then bring in a second variable on a secondary axis." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "34", + "metadata": {}, + "outputs": [], + "source": [ + "# Primary axis: GOOG daily price + 20-day moving average (via overlay).\n", + "# Secondary axis: AAPL daily price for comparison.\n", + "goog = stocks.sel(company=\"GOOG\")\n", + "goog_ma = goog.rolling(date=20, center=True).mean()\n", + "goog_ma.name = \"GOOG 20-day MA\"\n", + "\n", + "price_fig = xpx(goog).scatter()\n", + "price_fig.update_traces(marker={\"size\": 4, \"opacity\": 0.5}, name=\"GOOG Daily\")\n", + "ma_fig = xpx(goog_ma).line()\n", + "ma_fig.update_traces(line={\"color\": \"red\", \"width\": 2}, name=\"GOOG 20-day MA\")\n", + "\n", + "base = overlay(price_fig, ma_fig)\n", + "\n", + "aapl = stocks.sel(company=\"AAPL\")\n", + "aapl_fig = xpx(aapl).line()\n", + "aapl_fig.update_traces(line={\"color\": \"green\", \"width\": 2}, name=\"AAPL\")\n", + "\n", + "combined = add_secondary_y(base, aapl_fig, secondary_y_title=\"AAPL Price\")\n", + "combined.update_layout(title=\"GOOG (left, raw + MA) vs AAPL (right)\")\n", + "combined" + ] + }, + { + "cell_type": "markdown", + "id": "35", "metadata": {}, "source": [ "## subplots\n", @@ -450,7 +602,7 @@ }, { "cell_type": "markdown", - "id": "28", + "id": "36", "metadata": {}, "source": [ "### Different Variables Side by Side" @@ -459,7 +611,7 @@ { "cell_type": "code", "execution_count": null, - "id": "29", + "id": "37", "metadata": {}, "outputs": [], "source": [ @@ -479,7 +631,7 @@ }, { "cell_type": "markdown", - "id": "30", + "id": "38", "metadata": {}, "source": [ "### 2x2 Grid\n", @@ -490,7 +642,7 @@ { "cell_type": "code", "execution_count": null, - "id": "31", + "id": "39", "metadata": {}, "outputs": [], "source": [ @@ -507,7 +659,7 @@ }, { "cell_type": "markdown", - "id": "32", + "id": "40", "metadata": {}, "source": [ "### Mixed Chart Types\n", @@ -519,7 +671,7 @@ { "cell_type": "code", "execution_count": null, - "id": "33", + "id": "41", "metadata": {}, "outputs": [], "source": [ @@ -535,7 +687,7 @@ }, { "cell_type": "markdown", - "id": "34", + "id": "42", "metadata": {}, "source": [ "### With Facets\n", @@ -546,7 +698,7 @@ { "cell_type": "code", "execution_count": null, - "id": "35", + "id": "43", "metadata": {}, "outputs": [], "source": [ @@ -561,7 +713,7 @@ }, { "cell_type": "markdown", - "id": "36", + "id": "44", "metadata": {}, "source": [ "---\n", @@ -573,7 +725,7 @@ }, { "cell_type": "markdown", - "id": "37", + "id": "45", "metadata": {}, "source": [ "### overlay: Mismatched Facet Structure\n", @@ -584,7 +736,7 @@ { "cell_type": "code", "execution_count": null, - "id": "38", + "id": "46", "metadata": {}, "outputs": [], "source": [ @@ -602,7 +754,7 @@ }, { "cell_type": "markdown", - "id": "39", + "id": "47", "metadata": {}, "source": [ "### overlay: Animated Overlay on Static Base\n", @@ -613,7 +765,7 @@ { "cell_type": "code", "execution_count": null, - "id": "40", + "id": "48", "metadata": {}, "outputs": [], "source": [ @@ -631,7 +783,7 @@ }, { "cell_type": "markdown", - "id": "41", + "id": "49", "metadata": {}, "source": [ "### overlay: Mismatched Animation Frames\n", @@ -642,7 +794,7 @@ { "cell_type": "code", "execution_count": null, - "id": "42", + "id": "50", "metadata": {}, "outputs": [], "source": [ @@ -658,7 +810,7 @@ }, { "cell_type": "markdown", - "id": "43", + "id": "51", "metadata": {}, "source": [ "### add_secondary_y: Mismatched Facet Structure\n", @@ -669,7 +821,7 @@ { "cell_type": "code", "execution_count": null, - "id": "44", + "id": "52", "metadata": {}, "outputs": [], "source": [ @@ -687,7 +839,7 @@ }, { "cell_type": "markdown", - "id": "45", + "id": "53", "metadata": {}, "source": [ "### add_secondary_y: Animated Secondary on Static Base\n", @@ -698,7 +850,7 @@ { "cell_type": "code", "execution_count": null, - "id": "46", + "id": "54", "metadata": {}, "outputs": [], "source": [ @@ -716,7 +868,7 @@ }, { "cell_type": "markdown", - "id": "47", + "id": "55", "metadata": {}, "source": [ "### add_secondary_y: Mismatched Animation Frames" @@ -725,7 +877,7 @@ { "cell_type": "code", "execution_count": null, - "id": "48", + "id": "56", "metadata": {}, "outputs": [], "source": [ @@ -741,7 +893,7 @@ }, { "cell_type": "markdown", - "id": "49", + "id": "57", "metadata": {}, "source": [ "## Summary\n", @@ -761,8 +913,16 @@ "name": "python3" }, "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", "name": "python", - "version": "3.12.0" + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.9" } }, "nbformat": 4, diff --git a/tests/test_figures.py b/tests/test_figures.py index 254a3bd..5dde538 100644 --- a/tests/test_figures.py +++ b/tests/test_figures.py @@ -703,6 +703,127 @@ def test_add_secondary_y_single_trace_with_names(self) -> None: assert combined.data[0].showlegend is True assert combined.data[1].showlegend is True + def _multi_trace_pair(self) -> tuple[xr.DataArray, xr.DataArray]: + da1 = xr.DataArray( + np.random.rand(10, 3), + dims=["x", "cat"], + coords={"cat": ["a", "b", "c"]}, + name="Var1", + ) + da2 = xr.DataArray( + np.random.rand(10, 3) * 100, + dims=["x", "cat"], + coords={"cat": ["a", "b", "c"]}, + name="Var2", + ) + return da1, da2 + + def test_add_secondary_y_legend_suffix_default(self) -> None: + """legend="suffix" (default) gives every trace its own entry, names suffixed.""" + da1, da2 = self._multi_trace_pair() + combined = add_secondary_y(xpx(da1).line(), xpx(da2).line()) + + # All 6 traces visible with distinct legendgroups. + assert all(t.showlegend is True for t in combined.data) + legendgroups = [t.legendgroup for t in combined.data] + assert len(set(legendgroups)) == len(legendgroups) + # Names suffixed with the source y-axis title. + for t in combined.data[:3]: + assert t.name.endswith("(Var1)") + for t in combined.data[3:]: + assert t.name.endswith("(Var2)") + # Axis routing preserved. + assert all(t.yaxis == "y" for t in combined.data[:3]) + assert all(t.yaxis == "y2" for t in combined.data[3:]) + + def test_add_secondary_y_legend_merge(self) -> None: + """legend="merge" collapses same-named traces to a single legend entry.""" + da1, da2 = self._multi_trace_pair() + combined = add_secondary_y(xpx(da1).line(), xpx(da2).line(), legend="merge") + + # 3 visible entries (one per cat), names not suffixed, both axes share legendgroup. + visible = [t for t in combined.data if t.showlegend] + assert len(visible) == 3 + assert {t.name for t in visible} == {"a", "b", "c"} + # Each cat's two traces share a legendgroup so togglegroup toggles both. + for cat in ("a", "b", "c"): + members = [t for t in combined.data if t.legendgroup == cat] + assert len(members) == 2 + + def test_add_secondary_y_legend_separate(self) -> None: + """legend="separate" keeps PX entries as-is, accepting duplicate names.""" + da1, da2 = self._multi_trace_pair() + combined = add_secondary_y(xpx(da1).line(), xpx(da2).line(), legend="separate") + + # All 6 traces stay visible. + assert all(t.showlegend is True for t in combined.data) + # Names are duplicated across the two sources, not suffixed. + names = [t.name for t in combined.data] + assert names == ["a", "b", "c", "a", "b", "c"] + + def test_add_secondary_y_legend_invalid_raises(self) -> None: + da1, da2 = self._multi_trace_pair() + with pytest.raises(ValueError, match="legend mode must be"): + add_secondary_y(xpx(da1).line(), xpx(da2).line(), legend="bogus") # type: ignore[arg-type] + + def test_add_secondary_y_legend_anchored_to_container(self) -> None: + """Default layout anchors the legend to the figure container's right edge, + with automargin on the secondary y-axis so the axis title doesn't overlap.""" + da1 = xr.DataArray([1, 2, 3], dims=["x"], name="Temperature") + da2 = xr.DataArray([100, 200, 300], dims=["x"], name="Precipitation") + + combined = add_secondary_y(xpx(da1).line(), xpx(da2).bar()) + + assert combined.layout.legend.x == 1.0 + assert combined.layout.legend.xanchor == "right" + assert combined.layout.legend.xref == "container" + assert combined.layout.legend.y == 1.0 + assert combined.layout.legend.yanchor == "top" + # yref left as default ("paper") so legend top aligns with plot top, + # not figure top (avoids overlapping the figure title). + assert combined.layout.legend.yref != "container" + # Secondary axis reserves its own margin space. + assert combined.layout.yaxis2.automargin is True + + def test_add_secondary_y_preserves_user_legend_position(self) -> None: + """User-set legend.x/y on the base figure is not overridden.""" + da1 = xr.DataArray([1, 2, 3], dims=["x"], name="Temperature") + da2 = xr.DataArray([100, 200, 300], dims=["x"], name="Precipitation") + + base = xpx(da1).line() + base.update_layout(legend={"x": 0.5, "y": 0.5}) + + combined = add_secondary_y(base, xpx(da2).bar()) + + assert combined.layout.legend.x == 0.5 + assert combined.layout.legend.y == 0.5 + + def test_add_secondary_y_after_overlay_keeps_secondary_visible(self) -> None: + """overlay → add_secondary_y must not hide the secondary's traces.""" + da1 = xr.DataArray( + np.random.rand(10, 3), + dims=["x", "cat"], + coords={"cat": ["a", "b", "c"]}, + name="Var1", + ) + da2 = xr.DataArray( + np.random.rand(10, 3) * 100, + dims=["x", "cat"], + coords={"cat": ["a", "b", "c"]}, + name="Var2", + ) + fig1 = xpx(da1).line() + fig2 = xpx(da1).area() + overlaid = overlay(fig1, fig2) + fig3 = xpx(da2).line() + + combined = add_secondary_y(overlaid, fig3) + + # Secondary traces (last 3) must all be visible in the legend. + for t in combined.data[-3:]: + assert t.showlegend is True + assert t.yaxis == "y2" + def test_overlay_faceted_legendgroup_dedup(self) -> None: """Faceted overlay keeps only one showlegend=True per legendgroup.""" da = xr.DataArray( diff --git a/xarray_plotly/common.py b/xarray_plotly/common.py index 7805d4b..bc76018 100644 --- a/xarray_plotly/common.py +++ b/xarray_plotly/common.py @@ -5,7 +5,7 @@ import functools import warnings from collections.abc import Hashable, Mapping, Sequence -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Literal import plotly.express as px @@ -39,6 +39,21 @@ def __repr__(self) -> str: - None: Use Plotly defaults """ +LegendMode = Literal["suffix", "merge", "separate"] +"""Type alias for the legend kwarg on ``add_secondary_y``. + +- ``"suffix"`` (default): each trace gets its own legend entry with the + source figure's y-axis title appended, e.g. ``"Brazil (Population)"`` + and ``"Brazil (GDP per Capita)"``. Every trace toggles independently. +- ``"merge"``: same-named traces across the two figures share a + legendgroup, collapsing to one legend entry per name. Clicking the + entry toggles both axes together (Plotly's default + ``legend.groupclick="togglegroup"``). +- ``"separate"``: leave Plotly Express legend output untouched, + accepting duplicate names across the two figures. Each trace + toggles alone. +""" + # Re-export for backward compatibility SLOT_ORDERS = DEFAULT_SLOT_ORDERS """Slot orders per plot type. diff --git a/xarray_plotly/figures.py b/xarray_plotly/figures.py index 8a4bf00..a55803b 100644 --- a/xarray_plotly/figures.py +++ b/xarray_plotly/figures.py @@ -12,6 +12,8 @@ import plotly.graph_objects as go + from xarray_plotly.common import LegendMode + def _get_yaxis_title(fig: go.Figure) -> str: """Extract the primary y-axis title text from a figure. @@ -28,28 +30,66 @@ def _get_yaxis_title(fig: go.Figure) -> str: return "" +def _dedup_legend_within_traces(traces: list[Any]) -> None: + """Ensure one ``showlegend=True`` per ``legendgroup`` among the given traces.""" + from collections import defaultdict + + grouped: dict[str, list[Any]] = defaultdict(list) + ungrouped: list[Any] = [] + + for trace in traces: + lg = getattr(trace, "legendgroup", None) or "" + if lg: + grouped[lg].append(trace) + else: + ungrouped.append(trace) + + for group_traces in grouped.values(): + has_visible = False + for t in group_traces: + if has_visible: + t.showlegend = False + elif getattr(t, "name", None): + t.showlegend = True + has_visible = True + + for trace in ungrouped: + if getattr(trace, "name", None): + trace.showlegend = True + + def _ensure_legend_visibility( combined: go.Figure, source_figs: list[go.Figure], trace_slices: list[slice], + *, + mode: LegendMode = "merge", ) -> None: """Fix legend visibility on a combined figure. - Handles three problems that arise when combining Plotly Express figures: + Three modes control how same-named traces from different source figures + are presented in the legend: + + - ``"merge"``: traces sharing a ``legendgroup`` collapse to a single + legend entry (with one ``showlegend=True``). The default for + ``overlay`` and for ``add_secondary_y(legend="merge")``. + - ``"suffix"``: colliding ``legendgroup`` names across slices are + namespaced with the source figure's y-axis title, so each trace + becomes its own legend entry (e.g. ``"Brazil (Population)"`` and + ``"Brazil (GDP per Capita)"``). + - ``"separate"``: each source figure's traces are deduped only within + that source. Across sources, duplicate names are accepted as-is. - 1. **Unnamed traces** — PX sets ``name=""`` on single-trace (no color) - figures. We derive a name from each source figure's y-axis title. - 2. **Hidden named traces** — PX sets ``showlegend=False`` on single-trace - figures. We ensure at least one trace per ``legendgroup`` (or each - ungrouped named trace) has ``showlegend=True``. - 3. **Duplicate legend entries** — when two source figures share the same - ``legendgroup`` names, we deduplicate so only the first trace per - group shows in the legend. + All three modes also: + * Label unnamed traces using each source figure's y-axis title. + * Propagate name/legendgroup/style to animation frame traces, since + Plotly overwrites these on each frame. Args: combined: The combined Plotly figure (mutated in place). source_figs: The original source figures, in trace order. trace_slices: Slices into ``combined.data`` for each source figure. + mode: How to handle cross-source legend entries. """ from collections import defaultdict @@ -70,30 +110,46 @@ def _ensure_legend_visibility( trace.legendgroup = label # --- Step 2 & 3: fix showlegend per legendgroup ----------------------- - grouped: dict[str, list[Any]] = defaultdict(list) - ungrouped: list[Any] = [] - - for trace in combined.data: - lg = getattr(trace, "legendgroup", None) or "" - if lg: - grouped[lg].append(trace) - else: - ungrouped.append(trace) - - for traces in grouped.values(): - has_visible = False - for t in traces: - if has_visible: - # Deduplicate: only first keeps showlegend - t.showlegend = False - elif getattr(t, "name", None): - t.showlegend = True - has_visible = True - - # Ungrouped traces with a name should show in the legend - for trace in ungrouped: - if getattr(trace, "name", None): - trace.showlegend = True + if mode == "merge": + _dedup_legend_within_traces(list(combined.data)) + elif mode == "suffix": + # Namespace legendgroups that collide across slices, so each source + # keeps its own legend entries instead of being deduped away. + slice_groups: list[set[str]] = [] + for sl in trace_slices: + groups: set[str] = set() + for t in combined.data[sl]: + lg = getattr(t, "legendgroup", None) + if lg: + groups.add(lg) + slice_groups.append(groups) + group_counts: dict[str, int] = defaultdict(int) + for sg in slice_groups: + for g in sg: + group_counts[g] += 1 + colliding = {g for g, cnt in group_counts.items() if cnt > 1} + + for label, sl in zip(labels, trace_slices, strict=False): + if not label: + continue + for trace in combined.data[sl]: + lg = getattr(trace, "legendgroup", None) + if lg and lg in colliding: + new_lg = f"{lg} ({label})" + trace.legendgroup = new_lg + if getattr(trace, "name", None) == lg: + trace.name = new_lg + + for sl in trace_slices: + _dedup_legend_within_traces(list(combined.data[sl])) + elif mode == "separate": + # Dedup only within each source slice. Cross-source duplicates are + # left visible — same-named traces from different figures appear as + # distinct legend entries (with possibly identical names). + for sl in trace_slices: + _dedup_legend_within_traces(list(combined.data[sl])) + else: + raise ValueError(f"legend mode must be 'suffix', 'merge', or 'separate', got {mode!r}") # --- Step 4: propagate style properties to animation frame traces ------ # When Plotly animates, frame trace data overwrites fig.data properties. @@ -436,6 +492,7 @@ def add_secondary_y( secondary: go.Figure, *, secondary_y_title: str | None = None, + legend: LegendMode = "suffix", ) -> go.Figure: """Add a secondary y-axis with traces from another figure. @@ -449,6 +506,13 @@ def add_secondary_y( secondary: The figure whose traces use the secondary y-axis (right). secondary_y_title: Optional title for the secondary y-axis. If not provided, uses the secondary figure's y-axis title. + legend: How to handle same-named traces across the two figures. + ``"suffix"`` (default) gives each trace its own legend entry + with the source figure's y-axis title appended. ``"merge"`` + collapses same-named traces into a single legend entry that + toggles both axes together. ``"separate"`` leaves PX legend + output untouched, accepting duplicate names across the two + figures. Returns: A new figure with both primary and secondary y-axes. @@ -475,6 +539,9 @@ def add_secondary_y( >>> fig1 = xpx(data).line(facet_col="facet") >>> fig2 = xpx(data * 100).bar(facet_col="facet") # Different scale >>> combined = add_secondary_y(fig1, fig2) + >>> + >>> # Click "Brazil" in the legend toggles both Population and GDP traces + >>> combined = add_secondary_y(fig1, fig2, legend="merge") """ import plotly.graph_objects as go @@ -538,6 +605,9 @@ def add_secondary_y( "showticklabels": is_rightmost, # Link non-rightmost axes to the rightmost for consistent scaling "matches": None if is_rightmost else rightmost_secondary_y, + # Reserve margin space for tick labels and title so the legend + # placed at x>=1 can't clip them. + "automargin": True, } # Remove None values axis_config = {k: v for k, v in axis_config.items() if v is not None} @@ -557,11 +627,41 @@ def add_secondary_y( combined, [base, secondary], [slice(0, base_n), slice(base_n, base_n + sec_n)], + mode=legend, ) _fix_animation_axis_ranges(combined) + _set_default_secondary_y_layout(combined) return combined +def _set_default_secondary_y_layout(fig: go.Figure) -> None: + """Anchor the legend to the figure container so it doesn't fight the + secondary y-axis for paper-coordinate space. + + With ``xref="container"`` the legend's right edge sits at the figure's + right edge regardless of plot width. Combined with ``automargin=True`` + on the secondary y-axes (set in ``add_secondary_y``), Plotly reserves + space for the axis title between the plot and the legend, so the two + do not overlap. Only fields the user has not already set are touched, + so explicit ``update_layout(legend=...)`` on the source figures wins. + """ + legend_defaults: dict[str, Any] = {} + legend = fig.layout.legend + if legend.x is None: + # Container-relative x so the legend sits at the figure's right edge + # rather than fighting the secondary y-axis title for paper-coord space. + legend_defaults["x"] = 1.0 + legend_defaults["xanchor"] = "right" + legend_defaults["xref"] = "container" + if legend.y is None: + # Paper-relative y so the legend top aligns with the plot top (below + # the figure title) — same vertical position Plotly uses by default. + legend_defaults["y"] = 1.0 + legend_defaults["yanchor"] = "top" + if legend_defaults: + fig.update_layout(legend=legend_defaults) + + def _merge_secondary_y_frames( base: go.Figure, secondary: go.Figure,