diff --git a/CHANGELOG.md b/CHANGELOG.md index 03b6641..d6a0243 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ The format follows [Keep a Changelog](https://keepachangelog.com/en/1.1.0/). The ### Added - **`FailureIsolationMiddleware`** (proposal 0050, pipeline-utilities §6.3). A third bundled middleware primitive alongside `RetryMiddleware` and `TimingMiddleware`. It catches exceptions escaping the wrapped node's inner chain and returns a configured degraded partial update, so a non-critical node can fail without aborting the whole invocation. Configuration: `degraded_update` (a static mapping or a `state -> partial_update` callable, resolved at catch time), `event_name` (required, no default, since a generic name makes downstream telemetry strictly worse), an optional `predicate` (`Exception -> bool`; only matching exceptions are caught, others propagate), and an optional async `on_caught` hook. It catches `Exception`; `BaseException` (cancellation) propagates, matching `RetryMiddleware`. On a catch it dispatches a new framework-emitted `FailureIsolatedEvent` (a distinct observer-event variant carrying `event_name`, the wrapped node's lineage identity, `pre_state` / `post_state`, and a `CaughtException` record of category plus message) onto the observer delivery queue; the bundled OTel and Langfuse observers render it as a marker span / observation. Compose it OUTER of `RetryMiddleware` for the "retry transients, degrade gracefully on exhaustion" pattern. Additive: existing pipelines see no behavior change, and the spec pin is unchanged (0050 is already within the v0.53.0 pin). +- **Call-level retry on `Provider.complete()`** (proposal 0050, llm-provider §7). The provider's `complete()` gains an optional `retry: RetryConfig | None` parameter. 
When supplied, the wire call is retried in-call on transient provider errors per the config (classifier, backoff, `on_retry`, `max_attempts`), so a node issuing several LLM calls in a loop does not re-run the already-successful calls when a later call hits a transient failure. The request is built and validated once (pre-send validation errors are never retried), and the call stays terminal-only on the observability surface: exactly one `LlmCompletionEvent` (eventual success) or `LlmFailedEvent` (retry exhaustion or a non-transient error) fires per `complete()` call, with a single `call_id` shared across attempts. The per-attempt span surface (N per-attempt spans and the `openarmature.llm.attempt_index` attribute) is deferred to a future cycle; `conformance.toml` marks proposal 0050 `partial` accordingly. No spec-pin change. ### Changed diff --git a/conformance.toml b/conformance.toml index 587505c..1a4b602 100644 --- a/conformance.toml +++ b/conformance.toml @@ -372,10 +372,23 @@ status = "implemented" since = "0.13.0" # Spec v0.42.0 (proposal 0050). Retry & degradation primitives — -# failure-isolation middleware + call-level retry. Queued for -# v0.14.0 (largest single piece in the roadmap). +# failure-isolation middleware (§6.3) + call-level retry (§7). Both +# primitives implemented across the v0.14.0 cycle: +# FailureIsolationMiddleware (distinct FailureIsolatedEvent + +# CaughtException) and the call-level ``retry`` parameter on +# ``Provider.complete()`` — an in-call loop over transient §7 errors +# reusing the §6.1 RetryConfig record. ``partial`` because §7.1's +# per-attempt span surface — N ``openarmature.llm.complete`` spans + +# the ``openarmature.llm.attempt_index`` attribute — is DEFERRED: the +# python LLM span is rendered from the typed event, which is +# terminal-only per the graph-engine §6 mutual-exclusion contract, so +# per-attempt spans require a dedicated within-call sub-event +# (LlmRetryAttemptEvent) scoped to a future cycle. 
Call-level retry +# ships terminal-only: exactly one LlmCompletionEvent / LlmFailedEvent +# per ``complete()`` call. [proposals."0050"] -status = "not-yet" +status = "partial" +since = "0.14.0" # Spec v0.43.0 (proposal 0051). Langfuse trace.input/trace.output # implementation-surface caveat. Purely textual: documents that the diff --git a/docs/concepts/llms.md b/docs/concepts/llms.md index af8c67c..be9c2f7 100644 --- a/docs/concepts/llms.md +++ b/docs/concepts/llms.md @@ -85,6 +85,51 @@ stateless calls. Conversational memory (if you want it) is the caller's responsibility: thread it through state and pass the accumulated message list into each call. +## Retrying transient failures + +LLM endpoints fail in transient ways (rate limits, 503s, brief +outages). Pass a `RetryConfig` to `complete(retry=...)` to retry the +call in-place on those transient categories, without re-running any +surrounding work: + +```python +from openarmature.graph import RetryConfig + +response = await provider.complete( + messages, + retry=RetryConfig(max_attempts=3), +) +``` + +When `retry` is omitted the call is a single attempt (the default). +With a config, the request is built and validated once, then the wire +call is retried on transient errors per the config's classifier and +backoff; a non-transient error (a bad request, an auth failure) +propagates immediately without retrying. From observability's point of +view the call stays a single unit: exactly one completion-or-failure +event fires for the terminal outcome, regardless of how many attempts +it took. + +### Call-level vs node-level retry + +There are two retry layers, for different jobs: + +- **Call-level** (`complete(retry=...)`) retries one LLM call. Reach + for it when a node issues several LLM calls in a loop (chunked + processing, multi-step) and you do not want a transient failure on + the fifth call to re-run the four that already succeeded. 
+- **Node-level** (`RetryMiddleware`, see [Middleware](middleware.md)) + retries a whole node. Reach for it when the node does LLM work plus + other work (a DB write, a parse) and you want to re-run the entire + body on failure. + +They use the same `RetryConfig` shape and compose: a node-level retry +re-runs the node, and each fresh run gets its own call-level budget. +The thing to avoid is stacking both with overlapping budgets without +meaning to: a 3-attempt node retry wrapping a 5-call node with +3-attempt call-level retry can issue up to 3 × 5 × 3 = 45 calls in the worst case. +Pick intentional budgets per layer. + ## Pre-flight readiness check `Provider.ready()` is the optional pre-flight call you make before diff --git a/src/openarmature/llm/provider.py b/src/openarmature/llm/provider.py index 42dfb89..71e03ac 100644 --- a/src/openarmature/llm/provider.py +++ b/src/openarmature/llm/provider.py @@ -38,7 +38,7 @@ from __future__ import annotations from collections.abc import Sequence -from typing import Any, Protocol, cast +from typing import TYPE_CHECKING, Any, Protocol, cast from urllib.parse import unquote import jsonschema @@ -58,6 +58,9 @@ ) from .response import Response, RuntimeConfig +if TYPE_CHECKING: + from openarmature.graph.middleware import RetryConfig + class Provider(Protocol): """The shape of any llm-provider implementation. @@ -78,6 +81,7 @@ async def complete( config: RuntimeConfig | None = None, response_schema: dict[str, Any] | type[BaseModel] | None = None, tool_choice: ToolChoice | None = None, + retry: RetryConfig | None = None, ) -> Response: """Perform a single completion call. @@ -102,6 +106,12 @@ async def complete( the wire ``tool_choice`` field is omitted and the provider's own default applies. Pre-send validation routes through ``provider_invalid_request``. + retry: Optional call-level retry configuration.
When + supplied, transient provider errors are retried in-call + per the config; the request is built and validated once, + and exactly one observability event fires for the + terminal outcome. ``None`` (the default) performs a + single attempt. """ ... diff --git a/src/openarmature/llm/providers/openai.py b/src/openarmature/llm/providers/openai.py index 0e68b95..b1551fb 100644 --- a/src/openarmature/llm/providers/openai.py +++ b/src/openarmature/llm/providers/openai.py @@ -50,13 +50,14 @@ from __future__ import annotations +import asyncio import hashlib import json import re import time import uuid from collections.abc import Mapping, Sequence -from typing import Any, Literal, cast +from typing import TYPE_CHECKING, Any, Literal, cast from urllib.parse import urlparse import httpx @@ -116,6 +117,9 @@ ) from ..response import FinishReason, ParsedValue, Response, RuntimeConfig, Usage +if TYPE_CHECKING: + from openarmature.graph.middleware import RetryConfig + # Runtime guard for ``OpenAIProvider(..., readiness_probe=...)``. The # Literal type narrows callers under static checkers but is not enforced # at runtime, so an unknown string would silently no-op both dispatch @@ -348,6 +352,7 @@ async def complete( config: RuntimeConfig | None = None, response_schema: dict[str, Any] | type[BaseModel] | None = None, tool_choice: ToolChoice | None = None, + retry: RetryConfig | None = None, ) -> Response: """Single completion call. @@ -370,6 +375,18 @@ async def complete( non-empty ``tools``, and ``ForceTool.name`` must appear in the supplied list. Violations raise ``provider_invalid_request`` BEFORE any HTTP request is sent. + + When ``retry`` is supplied, the wire call is retried on + transient provider errors per the config's classifier and + backoff (defaulting to the canonical transient categories with + exponential jittered backoff). The request is built and + validated once; pre-send validation errors are never retried. 
+ Exactly one observability event fires for the call's terminal + outcome regardless of attempt count, and its ``latency_ms`` + covers the whole call, retries and backoff included. The + ``on_retry`` hook is not exception-isolated (mirroring + ``RetryMiddleware``); an exception raised by it propagates out + of the call. """ # Spec observability §5.5 LLM provider span: when an # observability backend is active in the current invocation, @@ -464,7 +481,7 @@ async def complete( include_response_format=(schema_dict is None or not self._force_prompt_augmentation_fallback), tool_choice=tool_choice, ) - response = await self._do_complete(body, schema_dict, schema_class) + response = await self._do_complete_with_retry(body, schema_dict, schema_class, retry) except LlmProviderError as exc: # Failure path: dispatch a typed LlmFailedEvent per # proposal 0058. Only §7 category exceptions @@ -510,6 +527,52 @@ async def complete( ) return response + async def _do_complete_with_retry( + self, + body: dict[str, Any], + schema_dict: dict[str, Any] | None, + schema_class: type[BaseModel] | None, + retry: RetryConfig | None, + ) -> Response: + """Run the wire call with optional call-level retry. + + Loops the underlying wire call on transient provider errors per + the retry config. Intermediate transient attempts are caught + here and emit no observability event; only the terminal outcome + (success, retry exhaustion, or a non-transient error) reaches + ``complete()``'s typed-event dispatch, so exactly one event + fires per ``complete()`` call. With ``retry=None`` the wire call is made exactly once. ``max_attempts`` counts total wire calls, the first attempt included (at least one call is always made); ``on_retry`` and the backoff callable both receive the 0-based index of the attempt that just failed, and ``on_retry`` is awaited before the backoff sleep. The last ``LlmProviderError`` is re-raised unchanged on exhaustion or when the classifier rejects it. + """ + if retry is None: + return await self._do_complete(body, schema_dict, schema_class) + # Lazy import avoids a module-load cycle: graph.middleware.retry + # imports llm.errors. Resolve None config fields to the canonical + # defaults, mirroring RetryMiddleware.
+ from openarmature.graph.middleware.retry import ( + default_classifier, + exponential_jitter_backoff, + ) + + classifier = retry.classifier or default_classifier + backoff = retry.backoff or exponential_jitter_backoff + attempt = 0 + while True: + try: + return await self._do_complete(body, schema_dict, schema_class) + except LlmProviderError as exc: + # No graph state at the call boundary; pass None (the + # default classifier ignores it). Re-raise on exhaustion + # or a non-transient category so complete() emits the + # single terminal LlmFailedEvent. + if attempt + 1 >= retry.max_attempts or not classifier(exc, None): + raise + # on_retry is not exception-isolated (matches + # RetryMiddleware); a raise propagates out of the call. + if retry.on_retry is not None: + await retry.on_retry(exc, attempt) + await asyncio.sleep(backoff(attempt)) + attempt += 1 + def _build_llm_completion_event( self, response: Response, diff --git a/tests/unit/test_llm_provider.py b/tests/unit/test_llm_provider.py index 4d93bb5..11f330a 100644 --- a/tests/unit/test_llm_provider.py +++ b/tests/unit/test_llm_provider.py @@ -21,6 +21,7 @@ from pydantic import ValidationError from openarmature.graph.events import LlmCompletionEvent, LlmFailedEvent, NodeEvent +from openarmature.graph.middleware import RetryConfig, deterministic_backoff from openarmature.graph.observer import ObserverEvent from openarmature.llm import ( PROVIDER_AUTHENTICATION, @@ -1336,6 +1337,138 @@ def _503(_req: httpx.Request) -> httpx.Response: assert failed_events[0].error_type == "ProviderUnavailable" +# --------------------------------------------------------------------------- +# Call-level retry (proposal 0050) +# --------------------------------------------------------------------------- + + +def _ok_chat_completion() -> dict[str, object]: + return { + "id": "x", + "object": "chat.completion", + "created": 0, + "model": "m", + "choices": [{"index": 0, "message": {"role": "assistant", "content": "ok"},
"finish_reason": "stop"}], + "usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}, + } + + +def _fail_n_then_ok(calls: list[int], fail_count: int) -> Callable[[httpx.Request], httpx.Response]: + def handler(_req: httpx.Request) -> httpx.Response: + calls[0] += 1 + if calls[0] <= fail_count: + return httpx.Response(503, json={"error": {"message": "down"}}) + return httpx.Response(200, json=_ok_chat_completion()) + + return handler + + +async def test_call_level_retry_succeeds_after_transient() -> None: + calls = [0] + events, token = _collecting_dispatch() + provider = OpenAIProvider( + base_url="http://test", + model="m", + api_key="k", + transport=httpx.MockTransport(_fail_n_then_ok(calls, fail_count=1)), + ) + try: + response = await provider.complete( + [UserMessage(content="hi")], + retry=RetryConfig(max_attempts=2, backoff=deterministic_backoff(0)), + ) + finally: + await provider.aclose() + _release_dispatch(token) + + # One transient failure then success: the wire call was retried. + assert calls[0] == 2 + assert response.message.content == "ok" + # Terminal-only: one LlmCompletionEvent, no LlmFailedEvent for the + # intermediate transient attempt. + assert len([e for e in events if isinstance(e, LlmCompletionEvent)]) == 1 + assert [e for e in events if isinstance(e, LlmFailedEvent)] == [] + + +async def test_call_level_retry_exhaustion_emits_one_failed_event() -> None: + calls = [0] + events, token = _collecting_dispatch() + provider = OpenAIProvider( + base_url="http://test", + model="m", + api_key="k", + transport=httpx.MockTransport(_fail_n_then_ok(calls, fail_count=99)), + ) + try: + with pytest.raises(ProviderUnavailable): + await provider.complete( + [UserMessage(content="hi")], + retry=RetryConfig(max_attempts=3, backoff=deterministic_backoff(0)), + ) + finally: + await provider.aclose() + _release_dispatch(token) + + # Exhausted all 3 attempts, then propagated. 
Terminal-only: one + # LlmFailedEvent (not one per attempt), no LlmCompletionEvent. + assert calls[0] == 3 + assert [e for e in events if isinstance(e, LlmCompletionEvent)] == [] + assert len([e for e in events if isinstance(e, LlmFailedEvent)]) == 1 + + +async def test_call_level_retry_skips_non_transient() -> None: + calls = [0] + events, token = _collecting_dispatch() + + def _400(_req: httpx.Request) -> httpx.Response: + calls[0] += 1 + return httpx.Response(400, json={"error": {"message": "bad"}}) + + provider = OpenAIProvider( + base_url="http://test", model="m", api_key="k", transport=httpx.MockTransport(_400) + ) + try: + with pytest.raises(ProviderInvalidRequest): + await provider.complete( + [UserMessage(content="hi")], + retry=RetryConfig(max_attempts=5, backoff=deterministic_backoff(0)), + ) + finally: + await provider.aclose() + _release_dispatch(token) + + # provider_invalid_request is non-transient: no retry, single attempt. + assert calls[0] == 1 + assert len([e for e in events if isinstance(e, LlmFailedEvent)]) == 1 + + +async def test_call_level_retry_invokes_on_retry_per_attempt() -> None: + calls = [0] + retries: list[tuple[str, int]] = [] + + async def _on_retry(exc: Exception, attempt: int) -> None: + retries.append((type(exc).__name__, attempt)) + + provider = OpenAIProvider( + base_url="http://test", + model="m", + api_key="k", + transport=httpx.MockTransport(_fail_n_then_ok(calls, fail_count=2)), + ) + try: + await provider.complete( + [UserMessage(content="hi")], + retry=RetryConfig(max_attempts=3, backoff=deterministic_backoff(0), on_retry=_on_retry), + ) + finally: + await provider.aclose() + + # Two transient failures then success: on_retry fires once per + # retried attempt (before each backoff), with the 0-based index. 
+ assert calls[0] == 3 + assert retries == [("ProviderUnavailable", 0), ("ProviderUnavailable", 1)] + + # --------------------------------------------------------------------------- # Proposal 0058: per-category field-mapping + pre-send + mutual exclusion # ---------------------------------------------------------------------------