Skip to content

Commit 49f3e91

Browse files
authored
Merge branch '1.0-dev' into fix-run-itk-locally
2 parents 3154e0f + f0e1d74 commit 49f3e91

18 files changed

Lines changed: 354 additions & 203 deletions

src/a2a/client/transports/http_helpers.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@
1212
from a2a.client.errors import A2AClientError, A2AClientTimeoutError
1313

1414

15+
def _default_sse_error_handler(sse_data: str) -> NoReturn:
16+
raise A2AClientError(f'SSE stream error event received: {sse_data}')
17+
18+
1519
@contextmanager
1620
def handle_http_exceptions(
1721
status_error_handler: Callable[[httpx.HTTPStatusError], NoReturn]
@@ -71,9 +75,22 @@ async def send_http_stream_request(
7175
url: str,
7276
status_error_handler: Callable[[httpx.HTTPStatusError], NoReturn]
7377
| None = None,
78+
sse_error_handler: Callable[[str], NoReturn] = _default_sse_error_handler,
7479
**kwargs: Any,
7580
) -> AsyncGenerator[str]:
76-
"""Sends a streaming HTTP request, yielding SSE data strings and handling exceptions."""
81+
"""Sends a streaming HTTP request, yielding SSE data strings and handling exceptions.
82+
83+
Args:
84+
httpx_client: The async HTTP client.
85+
method: The HTTP method (e.g. 'POST', 'GET').
86+
url: The URL to send the request to.
87+
status_error_handler: Handler for HTTP status errors. Should raise an
88+
appropriate domain-specific exception.
89+
sse_error_handler: Handler for SSE error events. Called with the
90+
raw SSE data string when an ``event: error`` SSE event is received.
91+
Should raise an appropriate domain-specific exception.
92+
**kwargs: Additional keyword arguments forwarded to ``aconnect_sse``.
93+
"""
7794
with handle_http_exceptions(status_error_handler):
7895
async with _SSEEventSource(
7996
httpx_client, method, url, **kwargs
@@ -97,6 +114,8 @@ async def send_http_stream_request(
97114
async for sse in event_source.aiter_sse():
98115
if not sse.data:
99116
continue
117+
if sse.event == 'error':
118+
sse_error_handler(sse.data)
100119
yield sse.data
101120

102121

src/a2a/client/transports/jsonrpc.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import logging
22

33
from collections.abc import AsyncGenerator
4-
from typing import Any
4+
from typing import Any, NoReturn
55
from uuid import uuid4
66

77
import httpx
@@ -350,6 +350,7 @@ async def _send_stream_request(
350350
'POST',
351351
self.url,
352352
None,
353+
self._handle_sse_error,
353354
json=rpc_request_payload,
354355
**http_kwargs,
355356
):
@@ -360,3 +361,10 @@ async def _send_stream_request(
360361
json_rpc_response.result, StreamResponse()
361362
)
362363
yield response
364+
365+
def _handle_sse_error(self, sse_data: str) -> NoReturn:
366+
"""Handles SSE error events by parsing JSON-RPC error payload and raising the appropriate domain error."""
367+
json_rpc_response = JSONRPC20Response.from_json(sse_data)
368+
if json_rpc_response.error:
369+
raise self._create_jsonrpc_error(json_rpc_response.error)
370+
raise A2AClientError(f'SSE stream error: {sse_data}')

src/a2a/client/transports/rest.py

Lines changed: 53 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,47 @@
4141
logger = logging.getLogger(__name__)
4242

4343

44+
def _parse_rest_error(
45+
error_payload: dict[str, Any],
46+
fallback_message: str,
47+
) -> Exception | None:
48+
"""Parses a REST error payload and returns the appropriate A2AError.
49+
50+
Args:
51+
error_payload: The parsed JSON error payload.
52+
fallback_message: Message to use if the payload has no ``message``.
53+
54+
Returns:
55+
The mapped A2AError if a known reason was found, otherwise ``None``.
56+
"""
57+
error_data = error_payload.get('error', {})
58+
message = error_data.get('message', fallback_message)
59+
details = error_data.get('details', [])
60+
if not isinstance(details, list):
61+
return None
62+
63+
# The `details` array can contain multiple different error objects.
64+
# We extract the first `ErrorInfo` object because it contains the
65+
# specific `reason` code needed to map this back to a Python A2AError.
66+
for d in details:
67+
if (
68+
isinstance(d, dict)
69+
and d.get('@type') == 'type.googleapis.com/google.rpc.ErrorInfo'
70+
):
71+
reason = d.get('reason')
72+
metadata = d.get('metadata') or {}
73+
if isinstance(reason, str):
74+
exception_cls = A2A_REASON_TO_ERROR.get(reason)
75+
if exception_cls:
76+
exc = exception_cls(message)
77+
if metadata:
78+
exc.data = metadata
79+
return exc
80+
break
81+
82+
return None
83+
84+
4485
@trace_class(kind=SpanKind.CLIENT)
4586
class RestTransport(ClientTransport):
4687
"""A REST transport for the A2A client."""
@@ -294,39 +335,12 @@ def _handle_http_error(self, e: httpx.HTTPStatusError) -> NoReturn:
294335
"""Handles HTTP status errors and raises the appropriate A2AError."""
295336
try:
296337
error_payload = e.response.json()
297-
error_data = error_payload.get('error', {})
298-
299-
message = error_data.get('message', str(e))
300-
details = error_data.get('details', [])
301-
if not isinstance(details, list):
302-
details = []
303-
304-
# The `details` array can contain multiple different error objects.
305-
# We extract the first `ErrorInfo` object because it contains the
306-
# specific `reason` code needed to map this back to a Python A2AError.
307-
error_info = {}
308-
for d in details:
309-
if (
310-
isinstance(d, dict)
311-
and d.get('@type')
312-
== 'type.googleapis.com/google.rpc.ErrorInfo'
313-
):
314-
error_info = d
315-
break
316-
reason = error_info.get('reason')
317-
metadata = error_info.get('metadata') or {}
318-
319-
if isinstance(reason, str):
320-
exception_cls = A2A_REASON_TO_ERROR.get(reason)
321-
if exception_cls:
322-
exc = exception_cls(message)
323-
if metadata:
324-
exc.data = metadata
325-
raise exc from e
338+
mapped = _parse_rest_error(error_payload, str(e))
339+
if mapped:
340+
raise mapped from e
326341
except (json.JSONDecodeError, ValueError):
327342
pass
328343

329-
# Fallback mappings for status codes if 'type' is missing or unknown
330344
status_code = e.response.status_code
331345
if status_code == httpx.codes.NOT_FOUND:
332346
raise MethodNotFoundError(
@@ -335,6 +349,14 @@ def _handle_http_error(self, e: httpx.HTTPStatusError) -> NoReturn:
335349

336350
raise A2AClientError(f'HTTP Error {status_code}: {e}') from e
337351

352+
def _handle_sse_error(self, sse_data: str) -> NoReturn:
353+
"""Handles SSE error events by parsing the REST error payload and raising the appropriate A2AError."""
354+
error_payload = json.loads(sse_data)
355+
mapped = _parse_rest_error(error_payload, sse_data)
356+
if mapped:
357+
raise mapped
358+
raise A2AClientError(sse_data)
359+
338360
async def _send_stream_request(
339361
self,
340362
method: str,
@@ -352,6 +374,7 @@ async def _send_stream_request(
352374
method,
353375
f'{self.url}{path}',
354376
self._handle_http_error,
377+
self._handle_sse_error,
355378
json=json,
356379
**http_kwargs,
357380
):

src/a2a/server/agent_execution/active_task.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -374,30 +374,33 @@ async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912
374374
await self._task_manager.process(event)
375375

376376
# Check for AUTH_REQUIRED or INPUT_REQUIRED or TERMINAL states
377-
res = await self._task_manager.get_task()
377+
new_task = await self._task_manager.get_task()
378+
if new_task is None:
379+
raise RuntimeError(
380+
f'Task {self.task_id} not found'
381+
)
378382
is_interrupted = (
379-
res
380-
and res.status.state
383+
new_task.status.state
381384
in INTERRUPTED_TASK_STATES
382385
)
383386
is_terminal = (
384-
res
385-
and res.status.state in TERMINAL_TASK_STATES
387+
new_task.status.state
388+
in TERMINAL_TASK_STATES
386389
)
387390

388391
# If we hit a breakpoint or terminal state, lock in the result.
389-
if (is_interrupted or is_terminal) and res:
392+
if is_interrupted or is_terminal:
390393
logger.debug(
391394
'Consumer[%s]: Setting first result as Task (state=%s)',
392395
self._task_id,
393-
res.status.state,
396+
new_task.status.state,
394397
)
395398

396399
if is_terminal:
397400
logger.debug(
398401
'Consumer[%s]: Reached terminal state %s',
399402
self._task_id,
400-
res.status.state if res else 'unknown',
403+
new_task.status.state,
401404
)
402405
if not self._is_finished.is_set():
403406
async with self._lock:
@@ -413,7 +416,7 @@ async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912
413416
logger.debug(
414417
'Consumer[%s]: Interrupted with state %s',
415418
self._task_id,
416-
res.status.state if res else 'unknown',
419+
new_task.status.state,
417420
)
418421

419422
if (

src/a2a/server/events/event_consumer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from pydantic import ValidationError
77

8-
from a2a.server.events.event_queue import Event, EventQueue, QueueShutDown
8+
from a2a.server.events.event_queue import Event, EventQueueLegacy, QueueShutDown
99
from a2a.types.a2a_pb2 import (
1010
Message,
1111
Task,
@@ -22,7 +22,7 @@
2222
class EventConsumer:
2323
"""Consumer to read events from the agent event queue."""
2424

25-
def __init__(self, queue: EventQueue):
25+
def __init__(self, queue: EventQueueLegacy):
2626
"""Initializes the EventConsumer.
2727
2828
Args:

src/a2a/server/events/event_queue.py

Lines changed: 1 addition & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -92,73 +92,6 @@ async def enqueue_event(self, event: Event) -> None:
9292
Only main queue can enqueue events. Child queues can only dequeue events.
9393
"""
9494

95-
@abstractmethod
96-
async def dequeue_event(self) -> Event:
97-
"""Pulls an event from the queue."""
98-
99-
@abstractmethod
100-
def task_done(self) -> None:
101-
"""Signals that a work on dequeued event is complete."""
102-
103-
@abstractmethod
104-
async def tap(
105-
self, max_queue_size: int = DEFAULT_MAX_QUEUE_SIZE
106-
) -> 'EventQueue':
107-
"""Creates a child queue that receives future events.
108-
109-
Note: The tapped queue may receive some old events if the incoming event
110-
queue is lagging behind and hasn't dispatched them yet.
111-
"""
112-
113-
@abstractmethod
114-
async def close(self, immediate: bool = False) -> None:
115-
"""Closes the queue.
116-
117-
For parent queue: it closes the main queue and all its child queues.
118-
For child queue: it closes only child queue.
119-
120-
It is safe to call it multiple times.
121-
If immediate is True, the queue will be closed without waiting for all events to be processed.
122-
If immediate is False, the queue will be closed after all events are processed (and confirmed with task_done() calls).
123-
124-
WARNING: Closing the parent queue with immediate=False is a deadlock risk if there are unconsumed events
125-
in any of the child sinks and the consumer has crashed without draining its queue.
126-
It is highly recommended to wrap graceful shutdowns with a timeout, e.g.,
127-
`asyncio.wait_for(queue.close(immediate=False), timeout=...)`.
128-
"""
129-
130-
@abstractmethod
131-
def is_closed(self) -> bool:
132-
"""[DEPRECATED] Checks if the queue is closed.
133-
134-
NOTE: Relying on this for enqueue logic introduces race conditions.
135-
It is maintained primarily for backwards compatibility, workarounds for
136-
Python 3.10/3.12 async queues in consumers, and for the test suite.
137-
"""
138-
139-
@abstractmethod
140-
async def __aenter__(self) -> Self:
141-
"""Enters the async context manager, returning the queue itself.
142-
143-
WARNING: See `__aexit__` for important deadlock risks associated with
144-
exiting this context manager if unconsumed events remain.
145-
"""
146-
147-
@abstractmethod
148-
async def __aexit__(
149-
self,
150-
exc_type: type[BaseException] | None,
151-
exc_val: BaseException | None,
152-
exc_tb: TracebackType | None,
153-
) -> None:
154-
"""Exits the async context manager, ensuring close() is called.
155-
156-
WARNING: The context manager calls `close(immediate=False)` by default.
157-
If a consumer exits the `async with` block early (e.g., due to an exception
158-
or an explicit `break`) while unconsumed events remain in the queue,
159-
`__aexit__` will deadlock waiting for `task_done()` to be called on those events.
160-
"""
161-
16295

16396
@trace_class(kind=SpanKind.SERVER)
16497
class EventQueueLegacy(EventQueue):
@@ -180,7 +113,7 @@ def __init__(self, max_queue_size: int = DEFAULT_MAX_QUEUE_SIZE) -> None:
180113
self._queue: AsyncQueue[Event] = _create_async_queue(
181114
maxsize=max_queue_size
182115
)
183-
self._children: list[EventQueue] = []
116+
self._children: list[EventQueueLegacy] = []
184117
self._is_closed = False
185118
self._lock = asyncio.Lock()
186119
logger.debug('EventQueue initialized.')

0 commit comments

Comments
 (0)