Skip to content

Commit abfdd8e

Browse files
committed
eager eval on jsonrpc
1 parent be4c5ff commit abfdd8e

2 files changed

Lines changed: 84 additions & 1 deletion

File tree

src/a2a/server/routes/jsonrpc_dispatcher.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -376,10 +376,34 @@ async def _process_streaming_request(
376376
if stream is None:
377377
raise UnsupportedOperationError(message='Stream not supported')
378378

379+
# Eagerly fetch the first item from the stream so that errors raised
380+
# before any event is yielded (e.g. task terminal state checks)
381+
# propagate here and result in a standard JSON-RPC error response
382+
# instead of establishing a broken SSE stream.
383+
stream = aiter(stream)
384+
try:
385+
first_event = await anext(stream)
386+
except StopAsyncIteration:
387+
388+
async def _empty_gen() -> AsyncGenerator[dict[str, Any], None]:
389+
if False:
390+
yield {}
391+
392+
return _empty_gen()
393+
379394
async def _wrap_stream(
395+
first_evt: Any,
380396
st: AsyncGenerator,
381397
) -> AsyncGenerator[dict[str, Any], None]:
382398
try:
399+
# Yield the first event
400+
stream_response = proto_utils.to_stream_response(first_evt)
401+
result = MessageToDict(
402+
stream_response, preserving_proto_field_name=False
403+
)
404+
yield JSONRPC20Response(result=result, _id=request_id).data
405+
406+
# Yield the rest of the events
383407
async for event in st:
384408
stream_response = proto_utils.to_stream_response(event)
385409
result = MessageToDict(
@@ -389,7 +413,7 @@ async def _wrap_stream(
389413
except A2AError as e:
390414
yield build_error_response(request_id, e)
391415

392-
return _wrap_stream(stream)
416+
return _wrap_stream(first_event, stream)
393417

394418
async def _handle_send_message(
395419
self, request_obj: SendMessageRequest, context: ServerCallContext

tests/integration/test_client_server_integration.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1019,6 +1019,65 @@ async def mock_generator(*args, **kwargs):
10191019
await client.close()
10201020

10211021

1022+
@pytest.mark.asyncio
1023+
@pytest.mark.parametrize(
1024+
'error_cls,handler_attr,client_method,request_params',
1025+
[
1026+
pytest.param(
1027+
UnsupportedOperationError,
1028+
'on_subscribe_to_task',
1029+
'subscribe',
1030+
SubscribeToTaskRequest(id='some-id'),
1031+
id='subscribe',
1032+
),
1033+
],
1034+
)
1035+
async def test_server_rejects_stream_on_validation_error(
1036+
transport_setups, error_cls, handler_attr, client_method, request_params
1037+
) -> None:
1038+
"""Verify that the server returns an error directly and doesn't open a stream on validation error."""
1039+
client = transport_setups.client
1040+
handler = transport_setups.handler
1041+
1042+
async def mock_generator(*args, **kwargs):
1043+
raise error_cls('Validation failed')
1044+
yield
1045+
1046+
getattr(handler, handler_attr).side_effect = mock_generator
1047+
1048+
transport = client._transport
1049+
1050+
if isinstance(transport, (RestTransport, JsonRpcTransport)):
1051+
# Spy on httpx client to check response headers
1052+
original_send = transport.httpx_client.send
1053+
response_headers = {}
1054+
1055+
async def mock_send(*args, **kwargs):
1056+
resp = await original_send(*args, **kwargs)
1057+
response_headers['Content-Type'] = resp.headers.get('Content-Type')
1058+
return resp
1059+
1060+
transport.httpx_client.send = mock_send
1061+
1062+
try:
1063+
with pytest.raises(error_cls):
1064+
async for _ in getattr(client, client_method)(request=request_params):
1065+
pass
1066+
finally:
1067+
transport.httpx_client.send = original_send
1068+
1069+
# Verify that the response content type was NOT text/event-stream
1070+
assert not response_headers.get('Content-Type', '').startswith('text/event-stream')
1071+
else:
1072+
# For gRPC, we just verify it raises the error
1073+
with pytest.raises(error_cls):
1074+
async for _ in getattr(client, client_method)(request=request_params):
1075+
pass
1076+
1077+
getattr(handler, handler_attr).side_effect = None
1078+
await client.close()
1079+
1080+
10221081
@pytest.mark.asyncio
10231082
@pytest.mark.parametrize(
10241083
'request_kwargs, expected_error_code',

0 commit comments

Comments
 (0)