Skip to content

Commit f5e0f72

Browse files
committed
Merge
1 parent a3fa1d1 commit f5e0f72

2 files changed

Lines changed: 16 additions & 13 deletions

File tree

src/a2a/server/routes/rest_dispatcher.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
)
2121
from a2a.utils import constants, proto_utils
2222
from a2a.utils.error_handlers import (
23+
build_rest_error_payload,
2324
rest_error_handler,
2425
rest_stream_error_handler,
2526
)
@@ -32,20 +33,23 @@
3233

3334

3435
if TYPE_CHECKING:
36+
from sse_starlette.event import ServerSentEvent
3537
from sse_starlette.sse import EventSourceResponse
3638
from starlette.requests import Request
3739
from starlette.responses import JSONResponse, Response
3840

3941
_package_starlette_installed = True
4042
else:
4143
try:
44+
from sse_starlette.event import ServerSentEvent
4245
from sse_starlette.sse import EventSourceResponse
4346
from starlette.requests import Request
4447
from starlette.responses import JSONResponse, Response
4548

4649
_package_starlette_installed = True
4750
except ImportError:
4851
EventSourceResponse = Any
52+
ServerSentEvent = Any
4953
Request = Any
5054
JSONResponse = Any
5155
Response = Any
@@ -135,10 +139,17 @@ async def _handle_streaming(
135139
except StopAsyncIteration:
136140
return EventSourceResponse(iter([]))
137141

138-
async def event_generator() -> AsyncIterator[str]:
142+
async def event_generator() -> AsyncIterator[str | ServerSentEvent]:
139143
yield json.dumps(first_item)
140-
async for item in stream:
141-
yield json.dumps(item)
144+
try:
145+
async for item in stream:
146+
yield json.dumps(item)
147+
except Exception as e:
148+
logger.exception('Error during REST SSE stream')
149+
yield ServerSentEvent(
150+
data=json.dumps(build_rest_error_payload(e)),
151+
event='error',
152+
)
142153

143154
return EventSourceResponse(event_generator())
144155

tests/server/routes/test_rest_routes.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from a2a.server.request_handlers.request_handler import RequestHandler
1313
from a2a.server.routes.rest_routes import create_rest_routes
1414
from a2a.types.a2a_pb2 import (
15-
AgentCapabilities,
1615
AgentCard,
1716
ListTasksResponse,
1817
Message,
@@ -108,16 +107,9 @@ def test_rest_list_tasks(agent_card, mock_handler):
108107

109108

110109
@pytest.fixture
111-
def streaming_agent_card():
112-
return AgentCard(
113-
capabilities=AgentCapabilities(streaming=True),
114-
)
115-
116-
117-
@pytest.fixture
118-
def streaming_app(streaming_agent_card, mock_handler):
110+
def streaming_app(mock_handler):
119111
routes = create_rest_routes(
120-
agent_card=streaming_agent_card, request_handler=mock_handler
112+
request_handler=mock_handler,
121113
)
122114
return Starlette(routes=routes)
123115

0 commit comments

Comments
 (0)