Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 19 additions & 7 deletions src/a2a/server/routes/jsonrpc_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,20 +376,32 @@
if stream is None:
raise UnsupportedOperationError(message='Stream not supported')

# Eagerly fetch the first event to trigger validation/upfront errors
try:
first_event = await anext(stream)
except StopAsyncIteration:
first_event = None

async def _wrap_stream(
st: AsyncGenerator,
st: AsyncGenerator, first_evt: Any | None

Check failure on line 386 in src/a2a/server/routes/jsonrpc_dispatcher.py

View workflow job for this annotation

GitHub Actions / Check Spelling

`evt` is not a recognized word (unrecognized-spelling)
Comment thread
guglielmo-san marked this conversation as resolved.
Outdated
Comment thread
guglielmo-san marked this conversation as resolved.
Outdated
) -> AsyncGenerator[dict[str, Any], None]:
def _map_event(evt: Any) -> dict[str, Any]:

Check failure on line 388 in src/a2a/server/routes/jsonrpc_dispatcher.py

View workflow job for this annotation

GitHub Actions / Check Spelling

`evt` is not a recognized word (unrecognized-spelling)
Comment thread
guglielmo-san marked this conversation as resolved.
Outdated
stream_response = proto_utils.to_stream_response(evt)

Check failure on line 389 in src/a2a/server/routes/jsonrpc_dispatcher.py

View workflow job for this annotation

GitHub Actions / Check Spelling

`evt` is not a recognized word (unrecognized-spelling)
result = MessageToDict(
stream_response, preserving_proto_field_name=False
)
return JSONRPC20Response(result=result, _id=request_id).data

try:
if first_evt is not None:

Check failure on line 396 in src/a2a/server/routes/jsonrpc_dispatcher.py

View workflow job for this annotation

GitHub Actions / Check Spelling

`evt` is not a recognized word (unrecognized-spelling)
yield _map_event(first_evt)

Check failure on line 397 in src/a2a/server/routes/jsonrpc_dispatcher.py

View workflow job for this annotation

GitHub Actions / Check Spelling

`evt` is not a recognized word (unrecognized-spelling)

async for event in st:
stream_response = proto_utils.to_stream_response(event)
result = MessageToDict(
stream_response, preserving_proto_field_name=False
)
yield JSONRPC20Response(result=result, _id=request_id).data
yield _map_event(event)
except A2AError as e:
yield build_error_response(request_id, e)

return _wrap_stream(stream)
return _wrap_stream(stream, first_event)

async def _handle_send_message(
self, request_obj: SendMessageRequest, context: ServerCallContext
Expand Down
65 changes: 65 additions & 0 deletions tests/integration/test_client_server_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -1019,6 +1019,71 @@ async def mock_generator(*args, **kwargs):
await client.close()


@pytest.mark.asyncio
@pytest.mark.parametrize(
'error_cls,handler_attr,client_method,request_params',
[
pytest.param(
UnsupportedOperationError,
'on_subscribe_to_task',
'subscribe',
SubscribeToTaskRequest(id='some-id'),
id='subscribe',
),
],
)
Comment thread
guglielmo-san marked this conversation as resolved.
async def test_server_rejects_stream_on_validation_error(
transport_setups, error_cls, handler_attr, client_method, request_params
) -> None:
"""Verify that the server returns an error directly and doesn't open a stream on validation error."""
client = transport_setups.client
handler = transport_setups.handler

async def mock_generator(*args, **kwargs):
raise error_cls('Validation failed')
yield

getattr(handler, handler_attr).side_effect = mock_generator

transport = client._transport

if isinstance(transport, (RestTransport, JsonRpcTransport)):
# Spy on httpx client to check response headers
original_send = transport.httpx_client.send
response_headers = {}

async def mock_send(*args, **kwargs):
resp = await original_send(*args, **kwargs)
response_headers['Content-Type'] = resp.headers.get('Content-Type')
return resp

transport.httpx_client.send = mock_send

try:
with pytest.raises(error_cls):
async for _ in getattr(client, client_method)(
request=request_params
):
pass
finally:
transport.httpx_client.send = original_send

# Verify that the response content type was NOT text/event-stream
assert not response_headers.get('Content-Type', '').startswith(
'text/event-stream'
)
else:
# For gRPC, we just verify it raises the error
with pytest.raises(error_cls):
async for _ in getattr(client, client_method)(
request=request_params
):
pass

getattr(handler, handler_attr).side_effect = None
await client.close()


@pytest.mark.asyncio
@pytest.mark.parametrize(
'request_kwargs, expected_error_code',
Expand Down
Loading