Skip to content

Commit 63f986f

Browse files
committed
Merge
1 parent e7c62c3 commit 63f986f

3 files changed

Lines changed: 152 additions & 6 deletions

File tree

src/a2a/server/routes/rest_routes.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from a2a.server.routes import CallContextBuilder, DefaultCallContextBuilder
1515
from a2a.types.a2a_pb2 import AgentCard
1616
from a2a.utils.error_handlers import (
17+
build_rest_error_payload,
1718
rest_error_handler,
1819
rest_stream_error_handler,
1920
)
@@ -25,6 +26,7 @@
2526

2627

2728
if TYPE_CHECKING:
29+
from sse_starlette.event import ServerSentEvent
2830
from sse_starlette.sse import EventSourceResponse
2931
from starlette.requests import Request
3032
from starlette.responses import JSONResponse, Response
@@ -33,6 +35,7 @@
3335
_package_starlette_installed = True
3436
else:
3537
try:
38+
from sse_starlette.event import ServerSentEvent
3639
from sse_starlette.sse import EventSourceResponse
3740
from starlette.requests import Request
3841
from starlette.responses import JSONResponse, Response
@@ -47,6 +50,7 @@
4750
Route = Any
4851
Mount = Any
4952
BaseRoute = Any
53+
ServerSentEvent = Any
5054

5155
_package_starlette_installed = False
5256

@@ -168,10 +172,17 @@ async def _handle_streaming_request(
168172
except StopAsyncIteration:
169173
return EventSourceResponse(iter([]))
170174

171-
async def event_generator() -> AsyncIterator[str]:
175+
async def event_generator() -> AsyncIterator[str | ServerSentEvent]:
172176
yield json.dumps(first_item)
173-
async for item in stream:
174-
yield json.dumps(item)
177+
try:
178+
async for item in stream:
179+
yield json.dumps(item)
180+
except Exception as e:
181+
logger.exception('Error during REST SSE stream')
182+
yield ServerSentEvent(
183+
data=json.dumps(build_rest_error_payload(e)),
184+
event='error',
185+
)
175186

176187
return EventSourceResponse(event_generator())
177188

src/a2a/utils/error_handlers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ async def error_catching_iterator() -> AsyncGenerator[
173173
try:
174174
async for item in original_iterator:
175175
yield item
176-
except Exception as stream_error: # noqa: BLE001
176+
except Exception as stream_error:
177177
_log_error(stream_error)
178178
raise stream_error
179179

tests/server/routes/test_rest_routes.py

Lines changed: 137 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,26 @@
1+
import json
2+
13
from unittest.mock import AsyncMock
24

5+
import httpx
36
import pytest
7+
48
from starlette.applications import Starlette
9+
from starlette.routing import BaseRoute
510
from starlette.testclient import TestClient
6-
from starlette.routing import BaseRoute, Route
711

812
from a2a.server.request_handlers.request_handler import RequestHandler
913
from a2a.server.routes.rest_routes import create_rest_routes
10-
from a2a.types.a2a_pb2 import AgentCard, Task, ListTasksResponse
14+
from a2a.types.a2a_pb2 import (
15+
AgentCapabilities,
16+
AgentCard,
17+
ListTasksResponse,
18+
Message,
19+
Part,
20+
Role,
21+
Task,
22+
)
23+
from a2a.utils.errors import InternalError
1124

1225

1326
@pytest.fixture
@@ -103,3 +116,125 @@ def test_rest_list_tasks(agent_card, mock_handler):
103116
response = client.get('/tasks', headers={'A2A-Version': '1.0'})
104117
assert response.status_code == 200
105118
assert mock_handler.on_list_tasks.called
119+
120+
121+
@pytest.fixture
122+
def streaming_agent_card():
123+
return AgentCard(
124+
capabilities=AgentCapabilities(streaming=True),
125+
)
126+
127+
128+
@pytest.fixture
129+
def streaming_app(streaming_agent_card, mock_handler):
130+
routes = create_rest_routes(
131+
agent_card=streaming_agent_card, request_handler=mock_handler
132+
)
133+
return Starlette(routes=routes)
134+
135+
136+
@pytest.fixture
137+
def streaming_client(streaming_app):
138+
return httpx.AsyncClient(
139+
transport=httpx.ASGITransport(app=streaming_app),
140+
base_url='http://test',
141+
headers={'A2A-Version': '1.0'},
142+
)
143+
144+
145+
@pytest.mark.asyncio
146+
async def test_streaming_mid_stream_error_emits_sse_error_event(
147+
streaming_client, mock_handler
148+
):
149+
"""Test that mid-stream errors are sent as SSE error events."""
150+
151+
async def mock_stream_then_error(*args, **kwargs):
152+
yield Message(
153+
message_id='stream_msg_1',
154+
role=Role.ROLE_AGENT,
155+
parts=[Part(text='First chunk')],
156+
)
157+
raise InternalError(message='Something went wrong mid-stream')
158+
159+
mock_handler.on_message_send_stream.side_effect = mock_stream_then_error
160+
161+
response = await streaming_client.post(
162+
'/message:stream',
163+
headers={'Accept': 'text/event-stream'},
164+
json={},
165+
)
166+
167+
response.raise_for_status()
168+
assert 'text/event-stream' in response.headers.get('content-type', '')
169+
170+
lines = [line.strip() for line in response.text.strip().splitlines()]
171+
172+
# Should have a normal data event followed by an error event
173+
data_lines = [
174+
json.loads(line[6:]) for line in lines if line.startswith('data: ')
175+
]
176+
assert len(data_lines) >= 1
177+
assert 'message' in data_lines[0]
178+
assert data_lines[0]['message']['messageId'] == 'stream_msg_1'
179+
180+
# Should contain an SSE error event
181+
error_event_lines = [line for line in lines if line == 'event: error']
182+
assert len(error_event_lines) == 1
183+
184+
# Find the error data after the error event
185+
error_data = None
186+
for i, line in enumerate(lines):
187+
if line == 'event: error':
188+
for j in range(i + 1, len(lines)):
189+
if lines[j].startswith('data: '):
190+
error_data = json.loads(lines[j][6:])
191+
break
192+
break
193+
194+
assert error_data is not None
195+
assert error_data['error']['code'] == 500
196+
assert error_data['error']['status'] == 'INTERNAL'
197+
assert 'Something went wrong mid-stream' in error_data['error']['message']
198+
199+
200+
@pytest.mark.asyncio
201+
async def test_streaming_mid_stream_unknown_error_emits_sse_error_event(
202+
streaming_client, mock_handler
203+
):
204+
"""Test that non-A2AError mid-stream errors also produce SSE error events."""
205+
206+
async def mock_stream_then_error(*args, **kwargs):
207+
yield Message(
208+
message_id='stream_msg_1',
209+
role=Role.ROLE_AGENT,
210+
parts=[Part(text='First chunk')],
211+
)
212+
raise RuntimeError('Unexpected failure')
213+
214+
mock_handler.on_message_send_stream.side_effect = mock_stream_then_error
215+
216+
response = await streaming_client.post(
217+
'/message:stream',
218+
headers={'Accept': 'text/event-stream'},
219+
json={},
220+
)
221+
222+
response.raise_for_status()
223+
224+
lines = [line.strip() for line in response.text.strip().splitlines()]
225+
226+
error_event_lines = [line for line in lines if line == 'event: error']
227+
assert len(error_event_lines) == 1
228+
229+
error_data = None
230+
for i, line in enumerate(lines):
231+
if line == 'event: error':
232+
for j in range(i + 1, len(lines)):
233+
if lines[j].startswith('data: '):
234+
error_data = json.loads(lines[j][6:])
235+
break
236+
break
237+
238+
assert error_data is not None
239+
assert error_data['error']['code'] == 500
240+
assert error_data['error']['status'] == 'INTERNAL'

0 commit comments

Comments
 (0)