Skip to content

Commit 6af1b8b

Browse files
feat(client): add RetryTransport for automatic retry with exponential backoff
1 parent eb37091 commit 6af1b8b

4 files changed

Lines changed: 1337 additions & 0 deletions

File tree

src/a2a/client/transports/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,12 @@
33
from a2a.client.transports.base import ClientTransport
44
from a2a.client.transports.jsonrpc import JsonRpcTransport
55
from a2a.client.transports.rest import RestTransport
6+
from a2a.client.transports.retry import (
7+
OnRetryCallback,
8+
RetryPredicate,
9+
RetryTransport,
10+
default_retry_predicate,
11+
)
612

713

814
try:
@@ -15,5 +21,9 @@
1521
'ClientTransport',
1622
'GrpcTransport',
1723
'JsonRpcTransport',
24+
'OnRetryCallback',
1825
'RestTransport',
26+
'RetryPredicate',
27+
'RetryTransport',
28+
'default_retry_predicate',
1929
]

src/a2a/client/transports/retry.py

Lines changed: 374 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,374 @@
1+
import asyncio
2+
import inspect
3+
import logging
4+
import random
5+
6+
from collections.abc import AsyncGenerator, Awaitable, Callable
7+
from typing import Any, TypeVar
8+
9+
import httpx
10+
11+
from a2a.client.client import ClientCallContext
12+
from a2a.client.errors import A2AClientError, A2AClientTimeoutError
13+
from a2a.client.transports.base import ClientTransport
14+
from a2a.types.a2a_pb2 import (
15+
AgentCard,
16+
CancelTaskRequest,
17+
DeleteTaskPushNotificationConfigRequest,
18+
GetExtendedAgentCardRequest,
19+
GetTaskPushNotificationConfigRequest,
20+
GetTaskRequest,
21+
ListTaskPushNotificationConfigsRequest,
22+
ListTaskPushNotificationConfigsResponse,
23+
ListTasksRequest,
24+
ListTasksResponse,
25+
SendMessageRequest,
26+
SendMessageResponse,
27+
StreamResponse,
28+
SubscribeToTaskRequest,
29+
Task,
30+
TaskPushNotificationConfig,
31+
)
32+
33+
34+
logger = logging.getLogger(__name__)
35+
36+
T = TypeVar('T')
37+
38+
RetryPredicate = Callable[[Exception], bool]
39+
OnRetryCallback = Callable[[int, Exception, float], Awaitable[None] | None]
40+
41+
_RETRYABLE_HTTP_STATUS: frozenset[int] = frozenset({408, 429, 502, 503, 504})
42+
43+
# grpc is an optional dependency.
44+
try:
45+
import grpc as _grpc
46+
47+
_AioRpcError: Any = _grpc.aio.AioRpcError
48+
_RETRYABLE_GRPC_CODES: frozenset[Any] = frozenset({
49+
_grpc.StatusCode.UNAVAILABLE,
50+
_grpc.StatusCode.RESOURCE_EXHAUSTED,
51+
})
52+
except ImportError:
53+
_AioRpcError = None
54+
_RETRYABLE_GRPC_CODES = frozenset()
55+
56+
57+
def default_retry_predicate(error: Exception) -> bool:
    """Decides whether ``error`` is a transient failure worth retrying.

    Retried:
        * any ``A2AClientTimeoutError``;
        * an ``A2AClientError`` whose underlying cause is an
          ``httpx.RequestError``, an ``httpx.HTTPStatusError`` with status
          408/429/502/503/504, or a gRPC ``AioRpcError`` with code
          UNAVAILABLE/RESOURCE_EXHAUSTED (only when grpc is importable).

    Not retried: exceptions that are not ``A2AClientError``, an
    ``A2AClientError`` without an underlying cause, and any other HTTP
    status, gRPC code or cause type.

    The underlying cause is ``__cause__`` (set by ``raise ... from e``) when
    present, otherwise ``__context__`` (set implicitly when raising inside an
    ``except`` block).
    """
    if isinstance(error, A2AClientTimeoutError):
        return True
    if not isinstance(error, A2AClientError):
        return False

    # Prefer the explicit chain; fall back to the implicit one.
    underlying = error.__cause__
    if not underlying:
        underlying = error.__context__

    # A missing cause (None) matches none of the branches and yields False.
    if isinstance(underlying, httpx.HTTPStatusError):
        retryable = underlying.response.status_code in _RETRYABLE_HTTP_STATUS
    elif isinstance(underlying, httpx.RequestError):
        retryable = True
    elif _AioRpcError is not None and isinstance(underlying, _AioRpcError):
        retryable = underlying.code() in _RETRYABLE_GRPC_CODES
    else:
        retryable = False
    return retryable
84+
85+
86+
class RetryTransport(ClientTransport):
    """A transport decorator that retries transient failures with exponential backoff.

    Every call is delegated to the wrapped transport. When a call raises an
    exception accepted by the retry predicate, it is issued again after a
    delay of ``base_delay * 2**attempt`` seconds capped at ``max_delay``
    (drawn uniformly from ``[0, delay]`` when jitter is enabled), for at most
    ``max_retries`` retries. ``asyncio.CancelledError`` is always propagated
    and never retried.

    Streaming methods only retry before the first event is yielded.
    """

    def __init__(  # noqa: PLR0913
        self,
        base: ClientTransport,
        *,
        max_retries: int = 3,
        base_delay: float = 1.0,
        max_delay: float = 30.0,
        jitter: bool = True,
        retry_predicate: RetryPredicate | None = None,
        on_retry: OnRetryCallback | None = None,
    ) -> None:
        """Wraps ``base`` with retry behaviour.

        Args:
            base: The transport every call is delegated to.
            max_retries: Maximum number of retries after the initial attempt;
                0 disables retrying.
            base_delay: Delay in seconds before the first retry; doubled for
                each subsequent retry.
            max_delay: Upper bound in seconds for any single delay.
            jitter: If True, each delay is drawn uniformly from
                ``[0, delay]``.
            retry_predicate: Returns True for errors that should be retried.
                Defaults to ``default_retry_predicate``.
            on_retry: Optional sync or async callback invoked before each
                retry with ``(retry_number, error, delay)``.

        Raises:
            ValueError: If ``max_retries`` is negative, or ``base_delay`` or
                ``max_delay`` is not positive.
        """
        if max_retries < 0:
            raise ValueError('max_retries must be >= 0')
        if base_delay <= 0:
            raise ValueError('base_delay must be > 0')
        if max_delay <= 0:
            raise ValueError('max_delay must be > 0')
        self._base = base
        self._max_retries = max_retries
        self._base_delay = base_delay
        self._max_delay = max_delay
        self._jitter = jitter
        self._retry_predicate = retry_predicate or default_retry_predicate
        self._on_retry = on_retry

    def _calculate_delay(self, attempt_index: int) -> float:
        """Returns the delay in seconds before retry ``attempt_index + 1``.

        The delay is ``base_delay * 2**attempt_index`` capped at
        ``max_delay``; with jitter enabled the result is drawn uniformly from
        ``[0, capped delay]``.
        """
        try:
            backoff = self._base_delay * (2**attempt_index)
        except OverflowError:
            # For very large attempt indexes the power of two is too large to
            # convert to float. The uncapped value would exceed any finite
            # max_delay, so use the cap instead of aborting the retry loop.
            backoff = self._max_delay
        delay = min(backoff, self._max_delay)
        if self._jitter:
            delay = random.uniform(0, delay)  # noqa: S311
        return delay

    async def _delay_and_notify(
        self,
        attempt_index: int,
        error: Exception,
        method_name: str,
    ) -> None:
        """Logs the retry, invokes ``on_retry`` if set, then sleeps.

        Args:
            attempt_index: Zero-based index of the attempt that just failed.
            error: The exception that triggered the retry.
            method_name: Name of the transport method, used in log messages.
        """
        retry_number = attempt_index + 1
        delay = self._calculate_delay(attempt_index)
        logger.warning(
            'Retry %d/%d for %s after %.2fs: %s',
            retry_number,
            self._max_retries,
            method_name,
            delay,
            error,
        )
        if self._on_retry is not None:
            try:
                # The callback may be sync or async; await only if needed.
                result: Any = self._on_retry(retry_number, error, delay)
                if inspect.isawaitable(result):
                    await result
            except asyncio.CancelledError:
                raise
            except Exception:
                # A buggy callback must not break the retry loop.
                logger.exception(
                    'on_retry callback raised for %s; continuing retry',
                    method_name,
                )
        await asyncio.sleep(delay)

    @staticmethod
    async def _safe_aclose(stream: Any) -> None:
        """Closes ``stream`` via ``aclose()`` if it has one.

        Errors raised while closing are logged at debug level and ignored;
        ``asyncio.CancelledError`` is propagated.
        """
        aclose = getattr(stream, 'aclose', None)
        if aclose is None:
            return
        try:
            await aclose()
        except asyncio.CancelledError:
            raise
        except Exception:
            logger.debug(
                'Ignoring error while closing stream during retry cleanup',
                exc_info=True,
            )

    async def _execute_with_retry(
        self,
        operation: Callable[[], Awaitable[T]],
        method_name: str,
    ) -> T:
        """Awaits ``operation()`` and retries it on retryable errors.

        Args:
            operation: Zero-argument callable producing a fresh awaitable for
                each attempt.
            method_name: Name of the transport method, used in log messages.

        Returns:
            The result of the first successful attempt.

        Raises:
            Exception: The last error, once retries are exhausted or the
                predicate rejects it.
        """
        attempt = 0
        while True:
            try:
                return await operation()
            except asyncio.CancelledError:  # noqa: PERF203
                raise
            except Exception as e:
                if attempt >= self._max_retries or not self._retry_predicate(e):
                    raise
                await self._delay_and_notify(attempt, e, method_name)
                attempt += 1

    async def _execute_streaming_with_retry(
        self,
        operation: Callable[[], AsyncGenerator[StreamResponse]],
        method_name: str,
    ) -> AsyncGenerator[StreamResponse]:
        """Yields events from ``operation()``, retrying pre-stream failures.

        A failure is retried only if no event of the current attempt has been
        yielded yet, retries remain, and the predicate accepts the error.
        Otherwise the error is re-raised.
        """
        # Retry only pre-stream failures. The inner finally closes the inner
        # generator on every exit path (success, retry, exception, consumer
        # break) so transport resources are not leaked.
        attempt = 0
        while True:
            first = True
            stream: Any = None
            try:
                stream = operation()
                try:
                    async for event in stream:
                        first = False
                        yield event
                finally:
                    await self._safe_aclose(stream)
            except asyncio.CancelledError:
                raise
            except Exception as e:
                if (
                    not first
                    or attempt >= self._max_retries
                    or not self._retry_predicate(e)
                ):
                    raise
                await self._delay_and_notify(attempt, e, method_name)
                attempt += 1
            else:
                # The stream finished without error.
                return

    async def send_message(
        self,
        request: SendMessageRequest,
        *,
        context: ClientCallContext | None = None,
    ) -> SendMessageResponse:
        """Sends a non-streaming message request to the agent."""
        return await self._execute_with_retry(
            lambda: self._base.send_message(request, context=context),
            'send_message',
        )

    async def send_message_streaming(
        self,
        request: SendMessageRequest,
        *,
        context: ClientCallContext | None = None,
    ) -> AsyncGenerator[StreamResponse]:
        """Sends a streaming message request to the agent and yields responses as they arrive."""
        inner = self._execute_streaming_with_retry(
            lambda: self._base.send_message_streaming(request, context=context),
            'send_message_streaming',
        )
        try:
            async for event in inner:
                yield event
        finally:
            # Close the retrying generator even if the consumer stops early.
            await inner.aclose()

    async def get_task(
        self,
        request: GetTaskRequest,
        *,
        context: ClientCallContext | None = None,
    ) -> Task:
        """Retrieves the current state and history of a specific task."""
        return await self._execute_with_retry(
            lambda: self._base.get_task(request, context=context),
            'get_task',
        )

    async def list_tasks(
        self,
        request: ListTasksRequest,
        *,
        context: ClientCallContext | None = None,
    ) -> ListTasksResponse:
        """Retrieves tasks for an agent."""
        return await self._execute_with_retry(
            lambda: self._base.list_tasks(request, context=context),
            'list_tasks',
        )

    async def cancel_task(
        self,
        request: CancelTaskRequest,
        *,
        context: ClientCallContext | None = None,
    ) -> Task:
        """Requests the agent to cancel a specific task."""
        return await self._execute_with_retry(
            lambda: self._base.cancel_task(request, context=context),
            'cancel_task',
        )

    async def create_task_push_notification_config(
        self,
        request: TaskPushNotificationConfig,
        *,
        context: ClientCallContext | None = None,
    ) -> TaskPushNotificationConfig:
        """Sets or updates the push notification configuration for a specific task."""
        return await self._execute_with_retry(
            lambda: self._base.create_task_push_notification_config(
                request, context=context
            ),
            'create_task_push_notification_config',
        )

    async def get_task_push_notification_config(
        self,
        request: GetTaskPushNotificationConfigRequest,
        *,
        context: ClientCallContext | None = None,
    ) -> TaskPushNotificationConfig:
        """Retrieves the push notification configuration for a specific task."""
        return await self._execute_with_retry(
            lambda: self._base.get_task_push_notification_config(
                request, context=context
            ),
            'get_task_push_notification_config',
        )

    async def list_task_push_notification_configs(
        self,
        request: ListTaskPushNotificationConfigsRequest,
        *,
        context: ClientCallContext | None = None,
    ) -> ListTaskPushNotificationConfigsResponse:
        """Lists push notification configurations for a specific task."""
        return await self._execute_with_retry(
            lambda: self._base.list_task_push_notification_configs(
                request, context=context
            ),
            'list_task_push_notification_configs',
        )

    async def delete_task_push_notification_config(
        self,
        request: DeleteTaskPushNotificationConfigRequest,
        *,
        context: ClientCallContext | None = None,
    ) -> None:
        """Deletes the push notification configuration for a specific task."""
        await self._execute_with_retry(
            lambda: self._base.delete_task_push_notification_config(
                request, context=context
            ),
            'delete_task_push_notification_config',
        )

    async def subscribe(
        self,
        request: SubscribeToTaskRequest,
        *,
        context: ClientCallContext | None = None,
    ) -> AsyncGenerator[StreamResponse]:
        """Reconnects to get task updates."""
        inner = self._execute_streaming_with_retry(
            lambda: self._base.subscribe(request, context=context),
            'subscribe',
        )
        try:
            async for event in inner:
                yield event
        finally:
            # Close the retrying generator even if the consumer stops early.
            await inner.aclose()

    async def get_extended_agent_card(
        self,
        request: GetExtendedAgentCardRequest,
        *,
        context: ClientCallContext | None = None,
    ) -> AgentCard:
        """Retrieves the Extended AgentCard."""
        return await self._execute_with_retry(
            lambda: self._base.get_extended_agent_card(
                request, context=context
            ),
            'get_extended_agent_card',
        )

    async def close(self) -> None:
        """Closes the wrapped transport."""
        await self._base.close()

0 commit comments

Comments
 (0)