Skip to content

Commit eb2016b

Browse files
feat(client): add RetryTransport
1 parent 4586c3e commit eb2016b

4 files changed

Lines changed: 1456 additions & 0 deletions

File tree

src/a2a/client/transports/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from a2a.client.transports.base import ClientTransport
44
from a2a.client.transports.jsonrpc import JsonRpcTransport
55
from a2a.client.transports.rest import RestTransport
6+
from a2a.client.transports.retry import RetryTransport, default_retry_predicate
67

78

89
try:
@@ -16,4 +17,6 @@
1617
'GrpcTransport',
1718
'JsonRpcTransport',
1819
'RestTransport',
20+
'RetryTransport',
21+
'default_retry_predicate',
1922
]

src/a2a/client/transports/retry.py

Lines changed: 372 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,372 @@
1+
"""A transport decorator that adds retry logic with exponential backoff."""
2+
3+
import asyncio
4+
import inspect
5+
import logging
6+
import random
7+
8+
from collections.abc import AsyncGenerator, Awaitable, Callable
9+
from typing import Any, TypeVar
10+
11+
import httpx
12+
13+
from a2a.client.client import ClientCallContext
14+
from a2a.client.errors import A2AClientError, A2AClientTimeoutError
15+
from a2a.client.transports.base import ClientTransport
16+
from a2a.types.a2a_pb2 import (
17+
AgentCard,
18+
CancelTaskRequest,
19+
DeleteTaskPushNotificationConfigRequest,
20+
GetExtendedAgentCardRequest,
21+
GetTaskPushNotificationConfigRequest,
22+
GetTaskRequest,
23+
ListTaskPushNotificationConfigsRequest,
24+
ListTaskPushNotificationConfigsResponse,
25+
ListTasksRequest,
26+
ListTasksResponse,
27+
SendMessageRequest,
28+
SendMessageResponse,
29+
StreamResponse,
30+
SubscribeToTaskRequest,
31+
Task,
32+
TaskPushNotificationConfig,
33+
)
34+
35+
36+
logger = logging.getLogger(__name__)
37+
38+
T = TypeVar('T')
39+
40+
RetryPredicate = Callable[[Exception], bool]
41+
OnRetryCallback = Callable[[int, Exception, float], Awaitable[None] | None]
42+
43+
44+
def default_retry_predicate(error: Exception) -> bool:
    """Decide whether a failed transport call is worth retrying.

    The decision is based on the error's type and on the exception it was
    raised from (``__cause__``).

    Retriable:
    - A2AClientTimeoutError, regardless of its cause.
    - A2AClientError raised from an httpx.RequestError (network failure).
    - A2AClientError raised from an httpx.HTTPStatusError whose status is
      429, 502, 503 or 504.
    - A2AClientError raised from a grpc.aio.AioRpcError whose code is
      UNAVAILABLE or RESOURCE_EXHAUSTED (only checked when grpc is
      importable).

    Not retriable:
    - Any exception that is not an A2AClientError.
    - A2AClientError with no ``__cause__``, or with a cause of any other
      type (for example a JSON decoding error).

    Args:
        error: The exception raised by the transport call.

    Returns:
        True if the call should be retried, False otherwise.
    """
    if isinstance(error, A2AClientTimeoutError):
        return True

    # Only client errors carry a cause we know how to classify; anything
    # else is treated exactly like a client error without a cause.
    root = error.__cause__ if isinstance(error, A2AClientError) else None
    if root is None:
        return False

    if isinstance(root, httpx.RequestError):
        return True

    if isinstance(root, httpx.HTTPStatusError):
        transient_statuses = {429, 502, 503, 504}
        return root.response.status_code in transient_statuses

    # grpc is an optional dependency: probe for it lazily and fall through
    # to "not retriable" when it is not installed.
    try:
        import grpc  # noqa: PLC0415

        if isinstance(root, grpc.aio.AioRpcError):
            transient_codes = {
                grpc.StatusCode.UNAVAILABLE,
                grpc.StatusCode.RESOURCE_EXHAUSTED,
            }
            return root.code() in transient_codes
    except ImportError:
        pass

    return False
86+
87+
88+
class RetryTransport(ClientTransport):
    """A transport decorator that adds retry logic with exponential backoff.

    Wraps any ClientTransport and retries failed operations that match
    the retry predicate. Streaming methods (send_message_streaming,
    subscribe) only retry pre-stream failures; once the first event
    is yielded, errors propagate without retry.
    """

    def __init__(  # noqa: PLR0913
        self,
        base: ClientTransport,
        *,
        max_retries: int = 3,
        base_delay: float = 1.0,
        max_delay: float = 30.0,
        jitter: bool = True,
        retry_predicate: RetryPredicate | None = None,
        on_retry: OnRetryCallback | None = None,
    ) -> None:
        """Initializes the RetryTransport.

        Args:
            base: The transport every call is delegated to.
            max_retries: Maximum number of retries after the initial attempt.
                0 disables retrying.
            base_delay: Backoff in seconds before the first retry; it doubles
                on every subsequent retry.
            max_delay: Upper bound in seconds on the backoff.
            jitter: If True, the actual delay is drawn uniformly between 0
                and the capped backoff.
            retry_predicate: Decides whether an exception is retriable.
                Defaults to default_retry_predicate.
            on_retry: Optional sync or async callback invoked as
                (retry_number, error, delay) before each backoff sleep.

        Raises:
            ValueError: If max_retries is negative, or base_delay or
                max_delay is not strictly positive.
        """
        if max_retries < 0:
            raise ValueError('max_retries must be >= 0')
        if base_delay <= 0:
            raise ValueError('base_delay must be > 0')
        if max_delay <= 0:
            raise ValueError('max_delay must be > 0')
        self._base = base
        self._max_retries = max_retries
        self._base_delay = base_delay
        self._max_delay = max_delay
        self._jitter = jitter
        self._retry_predicate = retry_predicate or default_retry_predicate
        self._on_retry = on_retry

    def _calculate_delay(self, attempt: int) -> float:
        """Calculates the delay for a given retry attempt using exponential backoff.

        Args:
            attempt: The retry attempt number (1-indexed).

        Returns:
            The delay in seconds before the next retry.
        """
        try:
            backoff = self._base_delay * (2 ** (attempt - 1))
        except OverflowError:
            # For very large attempt numbers 2 ** (attempt - 1) no longer
            # fits in a float and the multiplication raises. The value is
            # far beyond max_delay at that point, so saturate instead of
            # letting the arithmetic error replace the real transport error.
            backoff = self._max_delay
        delay = min(backoff, self._max_delay)
        if self._jitter:
            delay = random.uniform(0, delay)  # noqa: S311
        return delay

    async def _wait_before_retry(
        self,
        attempt: int,
        error: Exception,
        method_name: str,
    ) -> None:
        """Logs the retry, notifies the on_retry callback and sleeps.

        Args:
            attempt: The retry attempt number (1-indexed).
            error: The exception that triggered the retry.
            method_name: Name of the method being retried, used for logging.
        """
        delay = self._calculate_delay(attempt)
        logger.warning(
            'Retry %d/%d for %s after %.2fs: %s',
            attempt,
            self._max_retries,
            method_name,
            delay,
            error,
        )
        if self._on_retry is not None:
            # The callback may be sync or async; only await when needed.
            result: Any = self._on_retry(attempt, error, delay)
            if inspect.isawaitable(result):
                await result
        await asyncio.sleep(delay)

    async def _execute_with_retry(
        self,
        operation: Callable[[], Awaitable[T]],
        method_name: str,
    ) -> T:
        """Executes an async operation with retry logic.

        Args:
            operation: A zero-argument async callable that performs the transport call.
            method_name: Name of the method being called, used for logging.

        Returns:
            The result of the operation.

        Raises:
            The last exception if all retry attempts are exhausted.
        """
        last_error: Exception | None = None
        for attempt in range(self._max_retries + 1):
            try:
                return await operation()
            except Exception as e:  # noqa: PERF203
                last_error = e
                if attempt >= self._max_retries or not self._retry_predicate(e):
                    raise
                await self._wait_before_retry(attempt + 1, e, method_name)
        # Not expected to be reached: the final iteration always returns or
        # re-raises. Kept so the function has an explicit terminal statement.
        raise last_error  # type: ignore[misc]

    async def _execute_streaming_with_retry(
        self,
        operation: Callable[[], AsyncGenerator[StreamResponse]],
        method_name: str,
    ) -> AsyncGenerator[StreamResponse]:
        """Executes a streaming operation with retry logic for pre-stream failures.

        Retries only apply before the first event is yielded. Once streaming
        has started, errors propagate to the caller without retry.

        Args:
            operation: A zero-argument callable returning an async generator.
            method_name: Name of the method being called, used for logging.

        Yields:
            StreamResponse events from the underlying transport.
        """
        last_error: Exception | None = None
        for attempt in range(self._max_retries + 1):
            started = False
            try:
                stream = operation()
                async for event in stream:
                    started = True
                    yield event
            except Exception as e:
                # Events already reached the caller; replaying the stream
                # from scratch would duplicate them, so give up here.
                if started:
                    raise
                last_error = e
                if attempt >= self._max_retries or not self._retry_predicate(e):
                    raise
                await self._wait_before_retry(attempt + 1, e, method_name)
            else:
                # The stream finished normally.
                return
        raise last_error  # type: ignore[misc]

    async def send_message(
        self,
        request: SendMessageRequest,
        *,
        context: ClientCallContext | None = None,
    ) -> SendMessageResponse:
        """Sends a non-streaming message request to the agent."""
        return await self._execute_with_retry(
            lambda: self._base.send_message(request, context=context),
            'send_message',
        )

    async def send_message_streaming(
        self,
        request: SendMessageRequest,
        *,
        context: ClientCallContext | None = None,
    ) -> AsyncGenerator[StreamResponse]:
        """Sends a streaming message request to the agent and yields responses."""
        async for event in self._execute_streaming_with_retry(
            lambda: self._base.send_message_streaming(request, context=context),
            'send_message_streaming',
        ):
            yield event

    async def get_task(
        self,
        request: GetTaskRequest,
        *,
        context: ClientCallContext | None = None,
    ) -> Task:
        """Retrieves the current state and history of a specific task."""
        return await self._execute_with_retry(
            lambda: self._base.get_task(request, context=context),
            'get_task',
        )

    async def list_tasks(
        self,
        request: ListTasksRequest,
        *,
        context: ClientCallContext | None = None,
    ) -> ListTasksResponse:
        """Retrieves tasks for an agent."""
        return await self._execute_with_retry(
            lambda: self._base.list_tasks(request, context=context),
            'list_tasks',
        )

    async def cancel_task(
        self,
        request: CancelTaskRequest,
        *,
        context: ClientCallContext | None = None,
    ) -> Task:
        """Requests the agent to cancel a specific task."""
        return await self._execute_with_retry(
            lambda: self._base.cancel_task(request, context=context),
            'cancel_task',
        )

    async def create_task_push_notification_config(
        self,
        request: TaskPushNotificationConfig,
        *,
        context: ClientCallContext | None = None,
    ) -> TaskPushNotificationConfig:
        """Sets or updates the push notification configuration for a specific task."""
        return await self._execute_with_retry(
            lambda: self._base.create_task_push_notification_config(
                request, context=context
            ),
            'create_task_push_notification_config',
        )

    async def get_task_push_notification_config(
        self,
        request: GetTaskPushNotificationConfigRequest,
        *,
        context: ClientCallContext | None = None,
    ) -> TaskPushNotificationConfig:
        """Retrieves the push notification configuration for a specific task."""
        return await self._execute_with_retry(
            lambda: self._base.get_task_push_notification_config(
                request, context=context
            ),
            'get_task_push_notification_config',
        )

    async def list_task_push_notification_configs(
        self,
        request: ListTaskPushNotificationConfigsRequest,
        *,
        context: ClientCallContext | None = None,
    ) -> ListTaskPushNotificationConfigsResponse:
        """Lists push notification configurations for a specific task."""
        return await self._execute_with_retry(
            lambda: self._base.list_task_push_notification_configs(
                request, context=context
            ),
            'list_task_push_notification_configs',
        )

    async def delete_task_push_notification_config(
        self,
        request: DeleteTaskPushNotificationConfigRequest,
        *,
        context: ClientCallContext | None = None,
    ) -> None:
        """Deletes the push notification configuration for a specific task."""
        await self._execute_with_retry(
            lambda: self._base.delete_task_push_notification_config(
                request, context=context
            ),
            'delete_task_push_notification_config',
        )

    async def subscribe(
        self,
        request: SubscribeToTaskRequest,
        *,
        context: ClientCallContext | None = None,
    ) -> AsyncGenerator[StreamResponse]:
        """Reconnects to get task updates."""
        async for event in self._execute_streaming_with_retry(
            lambda: self._base.subscribe(request, context=context),
            'subscribe',
        ):
            yield event

    async def get_extended_agent_card(
        self,
        request: GetExtendedAgentCardRequest,
        *,
        context: ClientCallContext | None = None,
    ) -> AgentCard:
        """Retrieves the Extended AgentCard."""
        return await self._execute_with_retry(
            lambda: self._base.get_extended_agent_card(
                request, context=context
            ),
            'get_extended_agent_card',
        )

    async def close(self) -> None:
        """Closes the transport."""
        await self._base.close()

0 commit comments

Comments
 (0)