Skip to content

Commit e6462ca

Browse files
committed
WIP
1 parent 81d1e16 commit e6462ca

3 files changed

Lines changed: 62 additions & 21 deletions

File tree

src/a2a/client/transports/http_helpers.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,20 @@ async def send_http_stream_request(
7878
async with aconnect_sse(
7979
httpx_client, method, url, **kwargs
8080
) as event_source:
81-
event_source.response.raise_for_status()
82-
async for sse in event_source.aiter_sse():
83-
if not sse.data:
84-
continue
85-
yield sse.data
81+
try:
82+
event_source.response.raise_for_status()
83+
except httpx.HTTPStatusError as e:
84+
await event_source.response.aread()
85+
raise e
86+
87+
try:
88+
async for sse in event_source.aiter_sse():
89+
if not sse.data:
90+
continue
91+
yield sse.data
92+
except SSEError as e:
93+
if 'application/json' in event_source.response.headers.get('content-type', ''):
94+
content = await event_source.response.aread()
95+
yield content.decode('utf-8')
96+
else:
97+
raise e

src/a2a/server/apps/jsonrpc/jsonrpc_app.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -482,7 +482,7 @@ async def _process_streaming_request(
482482
request_obj, context
483483
)
484484

485-
return self._create_response(context, handler_result)
485+
return await self._create_response(context, handler_result)
486486

487487
async def _process_non_streaming_request(
488488
self,
@@ -562,9 +562,9 @@ async def _process_non_streaming_request(
562562
)
563563
return self._generate_error_response(request_id, error)
564564

565-
return self._create_response(context, handler_result)
565+
return await self._create_response(context, handler_result)
566566

567-
def _create_response(
567+
async def _create_response(
568568
self,
569569
context: ServerCallContext,
570570
handler_result: AsyncGenerator[dict[str, Any]] | dict[str, Any],
@@ -587,15 +587,35 @@ def _create_response(
587587
if exts := context.activated_extensions:
588588
headers[HTTP_EXTENSION_HEADER] = ', '.join(sorted(exts))
589589
if isinstance(handler_result, AsyncGenerator):
590+
try:
591+
# Prime to see if it fails upfront
592+
first_item = await handler_result.__anext__()
593+
except StopAsyncIteration:
594+
595+
async def empty_generator() -> AsyncGenerator[dict[str, str], None]:
596+
if False:
597+
yield {}
598+
599+
return EventSourceResponse(empty_generator(), headers=headers)
600+
except Exception as e:
601+
logger.debug('Upfront exception in streaming handler: %s', e)
602+
if not isinstance(e, A2AError | JSONRPCError):
603+
e = InternalError(message=str(e))
604+
request_id = context.state.get('request_id')
605+
error_payload = build_error_response(request_id, e)
606+
return JSONResponse(error_payload, headers=headers)
607+
590608
# Result is a stream of dict objects
591609
async def event_generator(
592610
stream: AsyncGenerator[dict[str, Any]],
593-
) -> AsyncGenerator[dict[str, str]]:
611+
first_item: dict[str, Any],
612+
) -> AsyncGenerator[dict[str, str], None]:
613+
yield {'data': json.dumps(first_item)}
594614
async for item in stream:
595615
yield {'data': json.dumps(item)}
596616

597617
return EventSourceResponse(
598-
event_generator(handler_result), headers=headers
618+
event_generator(handler_result, first_item), headers=headers
599619
)
600620

601621
# handler_result is a dict (JSON-RPC response)

src/a2a/utils/error_handlers.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -165,17 +165,26 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any:
165165
):
166166
original_iterator = response.body_iterator
167167

168-
async def error_catching_iterator() -> AsyncGenerator[
169-
Any, None
170-
]:
171-
try:
172-
async for item in original_iterator:
173-
yield item
174-
except Exception as stream_error:
175-
_log_error(stream_error)
176-
raise stream_error
177-
178-
response.body_iterator = error_catching_iterator()
168+
try:
169+
# Prime the stream to catch upfront errors
170+
first_item = await original_iterator.__anext__()
171+
except StopAsyncIteration:
172+
# Stream is empty
173+
pass
174+
except Exception as e: # noqa: BLE001
175+
return _create_error_response(e)
176+
else:
177+
178+
async def error_catching_iterator() -> AsyncGenerator[Any, None]:
179+
yield first_item
180+
try:
181+
async for item in original_iterator:
182+
yield item
183+
except Exception as stream_error:
184+
_log_error(stream_error)
185+
raise stream_error
186+
187+
response.body_iterator = error_catching_iterator()
179188

180189
except Exception as e: # noqa: BLE001
181190
return _create_error_response(e)

0 commit comments

Comments
 (0)