From 5dd8fed10dff5e5a585f2fbc8027164230a22375 Mon Sep 17 00:00:00 2001 From: chris-colinsky Date: Wed, 10 Jun 2026 16:44:14 -0700 Subject: [PATCH 1/2] Add RetryConfig record for RetryMiddleware Replace RetryMiddleware's individual constructor kwargs with a single frozen RetryConfig (max_attempts / classifier / backoff / on_retry), constructed as RetryMiddleware(RetryConfig(...)). This is the shared record the upcoming call-level complete(retry=...) parameter will take, so one config serves both the per-node and per-call retry layers. The config fields are Optional and resolve to the canonical defaults (default_classifier / exponential_jitter_backoff) once in the consumer, preserving the prior None-means-default behavior so fixture-driven construction stays robust. Breaking change to the RetryMiddleware constructor; all call sites across tests, examples, and docs are migrated. First of two refactor PRs splitting proposal 0050's remaining work; call-level retry follows. --- CHANGELOG.md | 4 ++ docs/concepts/middleware.md | 14 ++-- examples/fan-out-with-retry/main.py | 11 ++-- examples/parallel-branches/main.py | 7 +- src/openarmature/graph/__init__.py | 2 + src/openarmature/graph/middleware/__init__.py | 2 + src/openarmature/graph/middleware/retry.py | 66 +++++++++++-------- tests/conformance/test_observability.py | 10 +-- tests/conformance/test_pipeline_utilities.py | 9 ++- .../unit/test_failure_isolation_middleware.py | 3 +- tests/unit/test_fan_out.py | 3 +- tests/unit/test_middleware.py | 28 ++++++-- tests/unit/test_observability_metadata.py | 12 ++-- tests/unit/test_observability_otel.py | 12 ++-- 14 files changed, 122 insertions(+), 61 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 05af67e..03b6641 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,10 @@ The format follows [Keep a Changelog](https://keepachangelog.com/en/1.1.0/). The - **`FailureIsolationMiddleware`** (proposal 0050, pipeline-utilities §6.3). A third bundled middleware primitive alongside `RetryMiddleware` and `TimingMiddleware`. It catches exceptions escaping the wrapped node's inner chain and returns a configured degraded partial update, so a non-critical node can fail without aborting the whole invocation. Configuration: `degraded_update` (a static mapping or a `state -> partial_update` callable, resolved at catch time), `event_name` (required, no default, since a generic name makes downstream telemetry strictly worse), an optional `predicate` (`Exception -> bool`; only matching exceptions are caught, others propagate), and an optional async `on_caught` hook. It catches `Exception`; `BaseException` (cancellation) propagates, matching `RetryMiddleware`. On a catch it dispatches a new framework-emitted `FailureIsolatedEvent` (a distinct observer-event variant carrying `event_name`, the wrapped node's lineage identity, `pre_state` / `post_state`, and a `CaughtException` record of category plus message) onto the observer delivery queue; the bundled OTel and Langfuse observers render it as a marker span / observation. Compose it OUTER of `RetryMiddleware` for the "retry transients, degrade gracefully on exhaustion" pattern. Additive: existing pipelines see no behavior change, and the spec pin is unchanged (0050 is already within the v0.53.0 pin). +### Changed + +- **`RetryMiddleware` now takes a `RetryConfig` record** instead of individual constructor kwargs (proposal 0050 prep). The four retry settings (`max_attempts` / `classifier` / `backoff` / `on_retry`, each optional) move onto a frozen `RetryConfig`; construct as `RetryMiddleware(RetryConfig(max_attempts=...))`, while bare `RetryMiddleware()` still applies the defaults. This is a breaking change to the `RetryMiddleware` constructor. The record is the same shape the upcoming call-level `complete(retry=...)` parameter will accept, so one retry config serves both the per-node and per-call layers. `None` fields resolve to the canonical defaults (`default_classifier` / `exponential_jitter_backoff`) at use, preserving the prior behavior. + ## [0.13.0] — 2026-06-09 LLM provider hardening release. The pinned spec advances from v0.46.0 to v0.53.0, absorbing four implemented proposals. Proposal 0049 introduces the first spec-normatively-typed observer event variant, `LlmCompletionEvent`, dispatched on every successful LLM provider call; proposal 0058 adds the failure-side counterpart, `LlmFailedEvent`; proposal 0057 extends the completion variant with eight request-side fields. The bundled `OpenAIProvider` retires its sentinel-namespace `NodeEvent` emission for LLM calls entirely, and the OTel and Langfuse observers now drive their LLM span / Generation from the typed events with back-dated timestamps so durations reflect the adapter boundary. Proposal 0047 closes implicit prefix-cache wire-byte stability: `Response.usage` gains cache-stat fields, the OTel observer emits `openarmature.llm.cache_read` attributes, and the OpenAI Chat Completions request body is byte-stable across equivalent inputs regardless of dict insertion order. Custom observers that filtered LLM calls by sentinel namespace MUST migrate to `isinstance` discrimination; `LLM_NAMESPACE` and `LlmEventPayload` remain as a documented compatibility surface. diff --git a/docs/concepts/middleware.md b/docs/concepts/middleware.md index 1d15c4c..64d1743 100644 --- a/docs/concepts/middleware.md +++ b/docs/concepts/middleware.md @@ -126,7 +126,7 @@ hand a transformed state down the chain, pass a new state instance to ## Built-in: RetryMiddleware ```python -from openarmature.graph import RetryMiddleware, exponential_jitter_backoff +from openarmature.graph import RetryConfig, RetryMiddleware, exponential_jitter_backoff async def on_retry(exc: Exception, attempt: int) -> None: @@ -134,13 +134,15 @@ async def on_retry(exc: Exception, attempt: int) -> None: retry = RetryMiddleware( - max_attempts=3, - backoff=exponential_jitter_backoff, - on_retry=on_retry, + RetryConfig( + max_attempts=3, + backoff=exponential_jitter_backoff, + on_retry=on_retry, + ) ) ``` -Four plug points, all optional: +Configured with a `RetryConfig`; four fields, all optional: - **`max_attempts`** is the total attempt count including the first call. `1` disables retry. Default `3`. @@ -277,7 +279,7 @@ builder.add_node( degraded_update={"summary": ""}, event_name="summary_degraded", ), - RetryMiddleware(max_attempts=3), + RetryMiddleware(RetryConfig(max_attempts=3)), ], ) ``` diff --git a/examples/fan-out-with-retry/main.py b/examples/fan-out-with-retry/main.py index ad5ac60..c777cf7 100644 --- a/examples/fan-out-with-retry/main.py +++ b/examples/fan-out-with-retry/main.py @@ -84,6 +84,7 @@ append, ) from openarmature.graph.middleware import ( + RetryConfig, RetryMiddleware, TimingMiddleware, TimingRecord, @@ -261,10 +262,12 @@ def build_graph(error_policy: str = "fail_fast") -> CompiledGraph[BatchState]: headline_subgraph = build_headline_subgraph() retry = RetryMiddleware( - max_attempts=3, - # Short fixed delay so the demo isn't slow. A production app would - # use exponential_jitter_backoff (the default). - backoff=deterministic_backoff(0.2), + RetryConfig( + max_attempts=3, + # Short fixed delay so the demo isn't slow. A production app would + # use exponential_jitter_backoff (the default). + backoff=deterministic_backoff(0.2), + ) ) timing = TimingMiddleware( node_name="headline_run", diff --git a/examples/parallel-branches/main.py b/examples/parallel-branches/main.py index b1937a5..2aafcb2 100644 --- a/examples/parallel-branches/main.py +++ b/examples/parallel-branches/main.py @@ -76,6 +76,7 @@ append, ) from openarmature.graph.middleware import ( + RetryConfig, RetryMiddleware, deterministic_backoff, ) @@ -268,8 +269,10 @@ def build_graph() -> CompiledGraph[ArticleState]: # the same policy on a longer summarize call (where a retry doubles # cost) or on a topic-extract that has different transient profile. sentiment_retry = RetryMiddleware( - max_attempts=3, - backoff=deterministic_backoff(0.2), + RetryConfig( + max_attempts=3, + backoff=deterministic_backoff(0.2), + ) ) return ( diff --git a/src/openarmature/graph/__init__.py b/src/openarmature/graph/__init__.py index 21d3984..557d37d 100644 --- a/src/openarmature/graph/__init__.py +++ b/src/openarmature/graph/__init__.py @@ -51,6 +51,7 @@ FailureIsolationMiddleware, Middleware, NextCall, + RetryConfig, RetryMiddleware, TimingMiddleware, TimingRecord, @@ -115,6 +116,7 @@ "Reducer", "ReducerError", "RemoveHandle", + "RetryConfig", "RetryMiddleware", "RoutingError", "RuntimeGraphError", diff --git a/src/openarmature/graph/middleware/__init__.py b/src/openarmature/graph/middleware/__init__.py index 80f2e99..e4bc299 100644 --- a/src/openarmature/graph/middleware/__init__.py +++ b/src/openarmature/graph/middleware/__init__.py @@ -24,6 +24,7 @@ BackoffStrategy, Classifier, OnRetryCallback, + RetryConfig, RetryMiddleware, default_classifier, deterministic_backoff, @@ -41,6 +42,7 @@ "NextCall", "OnCompleteCallback", "OnRetryCallback", + "RetryConfig", "RetryMiddleware", "TRANSIENT_CATEGORIES", "TimingMiddleware", diff --git a/src/openarmature/graph/middleware/retry.py b/src/openarmature/graph/middleware/retry.py index f546392..8cf9902 100644 --- a/src/openarmature/graph/middleware/retry.py +++ b/src/openarmature/graph/middleware/retry.py @@ -20,6 +20,7 @@ import asyncio import random from collections.abc import Awaitable, Callable, Mapping +from dataclasses import dataclass from typing import Any from openarmature.llm.errors import TRANSIENT_CATEGORIES @@ -100,39 +101,51 @@ def fn(_attempt: int) -> float: OnRetryCallback = Callable[[Exception, int], Awaitable[None]] -class RetryMiddleware: - """Canonical retry middleware. - - Configuration: +@dataclass(frozen=True) +class RetryConfig: + """Canonical retry configuration record consumed by + :class:`RetryMiddleware`. - ``max_attempts``: total attempts including the first call. ``1`` disables retry. Default ``3``. - - ``classifier``: predicate ``(exception, state) -> bool``. Default - :func:`default_classifier` (matches ``category`` against + - ``classifier``: predicate ``(exception, state) -> bool`` deciding + whether a failure is retry-eligible. ``None`` (the default) + selects :func:`default_classifier` (matches ``category`` against ``TRANSIENT_CATEGORIES``). - - ``backoff``: callable ``(attempt_index) -> seconds``. Default - :func:`exponential_jitter_backoff` (base 1s, cap 30s, full jitter). + - ``backoff``: callable ``(attempt_index) -> seconds``. ``None`` + (the default) selects :func:`exponential_jitter_backoff` (base + 1s, cap 30s, full jitter). - ``on_retry``: optional async callback ``(exception, attempt_index) - -> None``. Fires before each sleep. + -> None`` fired before each backoff sleep. """ - def __init__( - self, - *, - max_attempts: int = 3, - classifier: Classifier | None = None, - backoff: BackoffStrategy | None = None, - on_retry: OnRetryCallback | None = None, - ) -> None: - if max_attempts < 1: + max_attempts: int = 3 + classifier: Classifier | None = None + backoff: BackoffStrategy | None = None + on_retry: OnRetryCallback | None = None + + def __post_init__(self) -> None: + if self.max_attempts < 1: raise ValueError("max_attempts must be >= 1") - self.max_attempts = max_attempts - self.classifier: Classifier = classifier or default_classifier - self.backoff: BackoffStrategy = backoff or exponential_jitter_backoff - self.on_retry: OnRetryCallback | None = on_retry + + +class RetryMiddleware: + """Canonical retry middleware. + + Configured with a :class:`RetryConfig` (or the default + ``RetryConfig()`` when omitted). Construct as + ``RetryMiddleware(RetryConfig(max_attempts=...))``. + """ + + def __init__(self, config: RetryConfig | None = None) -> None: + self.config = config if config is not None else RetryConfig() async def __call__(self, state: Any, next_: NextCall) -> Mapping[str, Any]: attempt = 0 + # ``None`` config fields select the canonical defaults; resolve + # once here so the loop works against concrete callables. + classifier = self.config.classifier or default_classifier + backoff = self.config.backoff or exponential_jitter_backoff # Spec observability §3.4 per-attempt scoping: each retry # attempt sees only the metadata in scope at retry-loop entry # ("pre-attempt baseline") plus that attempt's own writes; @@ -176,11 +189,11 @@ async def __call__(self, state: Any, next_: NextCall) -> Mapping[str, Any]: # metadata for the error span) sees the baseline, # not the failed attempt's transient state. _reset_invocation_metadata(metadata_token) - if attempt + 1 >= self.max_attempts or not self.classifier(exc, state): + if attempt + 1 >= self.config.max_attempts or not classifier(exc, state): raise - if self.on_retry is not None: - await self.on_retry(exc, attempt) - await asyncio.sleep(self.backoff(attempt)) + if self.config.on_retry is not None: + await self.config.on_retry(exc, attempt) + await asyncio.sleep(backoff(attempt)) attempt += 1 except BaseException: # Cancellation path. `CancelledError` (or other @@ -202,6 +215,7 @@ async def __call__(self, state: Any, next_: NextCall) -> Mapping[str, Any]: "BackoffStrategy", "Classifier", "OnRetryCallback", + "RetryConfig", "RetryMiddleware", "TRANSIENT_CATEGORIES", "default_classifier", diff --git a/tests/conformance/test_observability.py b/tests/conformance/test_observability.py index d9a43c4..6dcef62 100644 --- a/tests/conformance/test_observability.py +++ b/tests/conformance/test_observability.py @@ -678,7 +678,7 @@ async def _run_fixture_007_case(case: Mapping[str, Any]) -> None: from opentelemetry.trace import StatusCode from openarmature.graph import RuntimeGraphError - from openarmature.graph.middleware import RetryMiddleware + from openarmature.graph.middleware import RetryConfig, RetryMiddleware from openarmature.graph.middleware.retry import deterministic_backoff observer, exporter = _build_observer() @@ -725,9 +725,11 @@ def _classifier(exc: Exception, _state: Any, _transient: frozenset[str] = transi classifier_fn = None node_middleware.setdefault(flaky_node_name, []).append( RetryMiddleware( - max_attempts=int(mw_spec.get("max_attempts", 3)), - backoff=backoff, - classifier=classifier_fn, + RetryConfig( + max_attempts=int(mw_spec.get("max_attempts", 3)), + backoff=backoff, + classifier=classifier_fn, + ) ) ) diff --git a/tests/conformance/test_pipeline_utilities.py b/tests/conformance/test_pipeline_utilities.py index 7be9b35..80acac5 100644 --- a/tests/conformance/test_pipeline_utilities.py +++ b/tests/conformance/test_pipeline_utilities.py @@ -30,6 +30,7 @@ from openarmature.graph.middleware import ( Middleware, OnCompleteCallback, + RetryConfig, RetryMiddleware, TimingMiddleware, TimingRecord, @@ -234,9 +235,11 @@ def _build_middleware( classifier_cfg = config.get("classifier") classifier = _build_classifier(classifier_cfg) if classifier_cfg is not None else None return RetryMiddleware( - max_attempts=int(config.get("max_attempts", 3)), - backoff=backoff, - classifier=classifier, + RetryConfig( + max_attempts=int(config.get("max_attempts", 3)), + backoff=backoff, + classifier=classifier, + ) ) if mw_type == "timing": on_complete_cfg = cast("dict[str, Any]", config.get("on_complete") or {}) diff --git a/tests/unit/test_failure_isolation_middleware.py b/tests/unit/test_failure_isolation_middleware.py index b29a303..57f3b08 100644 --- a/tests/unit/test_failure_isolation_middleware.py +++ b/tests/unit/test_failure_isolation_middleware.py @@ -22,6 +22,7 @@ FailureIsolationMiddleware, GraphBuilder, ObserverEvent, + RetryConfig, RetryMiddleware, State, append, @@ -290,7 +291,7 @@ async def _flaky(_s: _DocState) -> Mapping[str, Any]: degraded_update={"note": "gave_up"}, event_name="flaky_failed", ), - RetryMiddleware(max_attempts=3, backoff=deterministic_backoff(0.0)), + RetryMiddleware(RetryConfig(max_attempts=3, backoff=deterministic_backoff(0.0))), ], ) .add_edge("flaky", END) diff --git a/tests/unit/test_fan_out.py b/tests/unit/test_fan_out.py index d42689a..6a4dbae 100644 --- a/tests/unit/test_fan_out.py +++ b/tests/unit/test_fan_out.py @@ -37,6 +37,7 @@ FanOutFieldNotList, GraphBuilder, NodeException, + RetryConfig, RetryMiddleware, State, append, @@ -578,7 +579,7 @@ async def maybe_fail(state: WorkerState) -> Mapping[str, Any]: inner_builder.add_edge("compute", END) inner = inner_builder.compile() - retry = RetryMiddleware(max_attempts=3, backoff=deterministic_backoff(0)) + retry = RetryMiddleware(RetryConfig(max_attempts=3, backoff=deterministic_backoff(0))) builder: GraphBuilder[InstanceMwParentState] = GraphBuilder(InstanceMwParentState) builder.set_entry("process") diff --git a/tests/unit/test_middleware.py b/tests/unit/test_middleware.py index b4ee0eb..6d5c309 100644 --- a/tests/unit/test_middleware.py +++ b/tests/unit/test_middleware.py @@ -25,6 +25,7 @@ END, GraphBuilder, Middleware, + RetryConfig, RetryMiddleware, State, TimingMiddleware, @@ -148,7 +149,7 @@ async def innermost(_state: Any) -> Mapping[str, Any]: raise _CategorizedTransient() return {"trace": ["ok"]} - retry = RetryMiddleware(max_attempts=3, backoff=deterministic_backoff(0)) + retry = RetryMiddleware(RetryConfig(max_attempts=3, backoff=deterministic_backoff(0))) chain = compose_chain([retry], innermost) result = await chain(TraceState()) @@ -163,7 +164,7 @@ async def innermost(_state: Any) -> Mapping[str, Any]: attempts[0] += 1 raise _CategorizedTransient() - retry = RetryMiddleware(max_attempts=3, backoff=deterministic_backoff(0)) + retry = RetryMiddleware(RetryConfig(max_attempts=3, backoff=deterministic_backoff(0))) chain = compose_chain([retry], innermost) with pytest.raises(_CategorizedTransient): @@ -180,7 +181,7 @@ async def innermost(_state: Any) -> Mapping[str, Any]: attempts[0] += 1 raise _CategorizedFatal() - retry = RetryMiddleware(max_attempts=5, backoff=deterministic_backoff(0)) + retry = RetryMiddleware(RetryConfig(max_attempts=5, backoff=deterministic_backoff(0))) chain = compose_chain([retry], innermost) with pytest.raises(_CategorizedFatal): @@ -201,7 +202,7 @@ async def innermost(_state: Any) -> Mapping[str, Any]: attempts[0] += 1 raise asyncio.CancelledError("aborted by host") - retry = RetryMiddleware(max_attempts=5, backoff=deterministic_backoff(0)) + retry = RetryMiddleware(RetryConfig(max_attempts=5, backoff=deterministic_backoff(0))) chain = compose_chain([retry], innermost) with pytest.raises(asyncio.CancelledError): @@ -210,6 +211,25 @@ async def innermost(_state: Any) -> Mapping[str, Any]: assert attempts[0] == 1 +# ===== RetryConfig record ===== + + +def test_retry_config_validates_max_attempts() -> None: + with pytest.raises(ValueError, match="max_attempts must be >= 1"): + RetryConfig(max_attempts=0) + + +def test_retry_config_defaults_resolve_at_use() -> None: + # Optional fields default to None; the consumer (RetryMiddleware, and + # the upcoming call-level retry) resolves None to the canonical + # defaults. A bare RetryMiddleware() uses the default RetryConfig(). + cfg = RetryConfig() + assert cfg.max_attempts == 3 + assert cfg.classifier is None + assert cfg.backoff is None + assert RetryMiddleware().config == RetryConfig() + + # ===== 6. General error recovery ===== diff --git a/tests/unit/test_observability_metadata.py b/tests/unit/test_observability_metadata.py index 63763b9..cc07236 100644 --- a/tests/unit/test_observability_metadata.py +++ b/tests/unit/test_observability_metadata.py @@ -619,7 +619,7 @@ class _RetryTransient(Exception): async def test_per_attempt_scoping_under_retry_discards_failed_attempt_writes() -> None: - from openarmature.graph.middleware import RetryMiddleware + from openarmature.graph.middleware import RetryConfig, RetryMiddleware captured_attempt_1_read: dict[str, Any] = {} captured_downstream_read: dict[str, Any] = {} @@ -647,7 +647,7 @@ async def _downstream(_s: _SimpleState) -> dict[str, Any]: .add_node( "retried", _retried, - middleware=[RetryMiddleware(max_attempts=2, backoff=lambda _i: 0.0)], + middleware=[RetryMiddleware(RetryConfig(max_attempts=2, backoff=lambda _i: 0.0))], ) .add_node("downstream", _downstream) .add_edge("retried", "downstream") @@ -675,7 +675,7 @@ async def test_terminal_failure_discards_final_failed_attempt_writes() -> None: # AFTER the retry middleware re-raises a terminal failure, the # metadata ContextVar is back at the pre-attempt baseline — no # leak of the final failed attempt's writes. - from openarmature.graph.middleware import RetryMiddleware, compose_chain + from openarmature.graph.middleware import RetryConfig, RetryMiddleware, compose_chain from openarmature.observability.metadata import ( _reset_invocation_metadata, _set_invocation_metadata, @@ -689,7 +689,7 @@ async def _always_fails(_state: Any) -> Mapping[str, Any]: set_invocation_metadata(attempt_marker=f"attempt_{len(attempts) - 1}") raise _RetryTransient() - retry = RetryMiddleware(max_attempts=2, backoff=lambda _i: 0.0) + retry = RetryMiddleware(RetryConfig(max_attempts=2, backoff=lambda _i: 0.0)) chain = compose_chain([retry], _always_fails) # Establish a baseline outside the middleware so we can read it @@ -716,7 +716,7 @@ async def test_cancellation_discards_in_flight_attempt_writes() -> None: # metadata-scoping perspective. Spec §6.1: cancellation MUST # propagate (no retry, no swallow), so the reset must happen IN # ADDITION to, not instead of, propagating ``CancelledError``. - from openarmature.graph.middleware import RetryMiddleware, compose_chain + from openarmature.graph.middleware import RetryConfig, RetryMiddleware, compose_chain from openarmature.observability.metadata import ( _reset_invocation_metadata, _set_invocation_metadata, @@ -730,7 +730,7 @@ async def _writes_then_cancels(_state: Any) -> Mapping[str, Any]: set_invocation_metadata(attempt_marker="leaked") raise asyncio.CancelledError("aborted") - retry = RetryMiddleware(max_attempts=3, backoff=lambda _i: 0.0) + retry = RetryMiddleware(RetryConfig(max_attempts=3, backoff=lambda _i: 0.0)) chain = compose_chain([retry], _writes_then_cancels) baseline_token = _set_invocation_metadata(validate_invocation_metadata({"tenantId": "T1"})) diff --git a/tests/unit/test_observability_otel.py b/tests/unit/test_observability_otel.py index d7fdfc0..89f5e21 100644 --- a/tests/unit/test_observability_otel.py +++ b/tests/unit/test_observability_otel.py @@ -1370,7 +1370,7 @@ async def test_llm_call_inside_retried_node_parents_per_attempt() -> None: ``innermost`` scope) is what makes this work.""" import httpx - from openarmature.graph.middleware import RetryMiddleware + from openarmature.graph.middleware import RetryConfig, RetryMiddleware from openarmature.llm.errors import ProviderRateLimit from openarmature.llm.messages import UserMessage from openarmature.llm.providers.openai import OpenAIProvider @@ -1419,7 +1419,11 @@ async def _flaky(s: _S) -> dict[str, int]: g = ( GraphBuilder(_S) - .add_node("flaky", _flaky, middleware=[RetryMiddleware(max_attempts=3, backoff=lambda _i: 0.0)]) + .add_node( + "flaky", + _flaky, + middleware=[RetryMiddleware(RetryConfig(max_attempts=3, backoff=lambda _i: 0.0))], + ) .add_edge("flaky", END) .set_entry("flaky") .compile() @@ -2241,7 +2245,7 @@ async def test_parallel_branches_node_under_retry_middleware_emits_per_attempt_d # - two per-branch dispatch spans per branch (one per attempt) # - each attempt's dispatch span parented under THAT attempt's # NODE span (not the wrong attempt's) - from openarmature.graph import BranchSpec, RetryMiddleware + from openarmature.graph import BranchSpec, RetryConfig, RetryMiddleware class _S(State): result: str = "" @@ -2271,7 +2275,7 @@ async def _flaky_branch(_s: _InnerS) -> dict[str, int]: .add_parallel_branches_node( "dispatcher", branches={"only_branch": BranchSpec(subgraph=inner)}, - middleware=[RetryMiddleware(max_attempts=2, classifier=lambda _exc, _state: True)], + middleware=[RetryMiddleware(RetryConfig(max_attempts=2, classifier=lambda _exc, _state: True))], ) .add_edge("dispatcher", END) .set_entry("dispatcher") From 16a4e62c1e59f804499b96d9a884fff49bd72972 Mon Sep 17 00:00:00 2001 From: chris-colinsky Date: Wed, 10 Jun 2026 16:57:14 -0700 Subject: [PATCH 2/2] Guard RetryMiddleware against non-RetryConfig args From CoPilot review of PR #150: RetryMiddleware now takes a positional config, so a non-RetryConfig argument (e.g. RetryMiddleware(3)) would construct and then fail with a cryptic AttributeError at retry time. Raise TypeError eagerly in __init__ with the correct-usage idiom, and add a test. --- src/openarmature/graph/middleware/retry.py | 14 +++++++++++++- tests/unit/test_middleware.py | 5 +++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/src/openarmature/graph/middleware/retry.py b/src/openarmature/graph/middleware/retry.py index 8cf9902..b90e125 100644 --- a/src/openarmature/graph/middleware/retry.py +++ b/src/openarmature/graph/middleware/retry.py @@ -138,7 +138,19 @@ class RetryMiddleware: """ def __init__(self, config: RetryConfig | None = None) -> None: - self.config = config if config is not None else RetryConfig() + if config is None: + config = RetryConfig() + # Defensive guard for untyped callers: the static type already + # rules a non-RetryConfig out (pyright flags this as redundant), + # but an eager TypeError beats a cryptic AttributeError when a + # mistyped value (e.g. ``RetryMiddleware(3)``) reaches ``.config``. + if not isinstance(config, RetryConfig): # pyright: ignore[reportUnnecessaryIsInstance] + raise TypeError( + f"RetryMiddleware expects a RetryConfig (or None); got " + f"{type(config).__name__}. Construct as " + f"RetryMiddleware(RetryConfig(max_attempts=...))." + ) + self.config = config async def __call__(self, state: Any, next_: NextCall) -> Mapping[str, Any]: attempt = 0 diff --git a/tests/unit/test_middleware.py b/tests/unit/test_middleware.py index 6d5c309..efc9b33 100644 --- a/tests/unit/test_middleware.py +++ b/tests/unit/test_middleware.py @@ -230,6 +230,11 @@ def test_retry_config_defaults_resolve_at_use() -> None: assert RetryMiddleware().config == RetryConfig() +def test_retry_middleware_rejects_non_config() -> None: + with pytest.raises(TypeError, match="expects a RetryConfig"): + RetryMiddleware(3) # pyright: ignore[reportArgumentType] + + # ===== 6. General error recovery =====