|
20 | 20 | ) |
21 | 21 | from a2a.utils import constants, proto_utils |
22 | 22 | from a2a.utils.error_handlers import ( |
| 23 | + build_rest_error_payload, |
23 | 24 | rest_error_handler, |
24 | 25 | rest_stream_error_handler, |
25 | 26 | ) |
|
32 | 33 |
|
33 | 34 |
|
34 | 35 | if TYPE_CHECKING: |
| 36 | + from sse_starlette.event import ServerSentEvent |
35 | 37 | from sse_starlette.sse import EventSourceResponse |
36 | 38 | from starlette.requests import Request |
37 | 39 | from starlette.responses import JSONResponse, Response |
38 | 40 |
|
39 | 41 | _package_starlette_installed = True |
40 | 42 | else: |
41 | 43 | try: |
| 44 | + from sse_starlette.event import ServerSentEvent |
42 | 45 | from sse_starlette.sse import EventSourceResponse |
43 | 46 | from starlette.requests import Request |
44 | 47 | from starlette.responses import JSONResponse, Response |
45 | 48 |
|
46 | 49 | _package_starlette_installed = True |
47 | 50 | except ImportError: |
48 | 51 | EventSourceResponse = Any |
| 52 | + ServerSentEvent = Any |
49 | 53 | Request = Any |
50 | 54 | JSONResponse = Any |
51 | 55 | Response = Any |
@@ -135,10 +139,17 @@ async def _handle_streaming( |
135 | 139 | except StopAsyncIteration: |
136 | 140 | return EventSourceResponse(iter([])) |
137 | 141 |
|
138 | | - async def event_generator() -> AsyncIterator[str]: |
| 142 | + async def event_generator() -> AsyncIterator[str | ServerSentEvent]: |
139 | 143 | yield json.dumps(first_item) |
140 | | - async for item in stream: |
141 | | - yield json.dumps(item) |
| 144 | + try: |
| 145 | + async for item in stream: |
| 146 | + yield json.dumps(item) |
| 147 | + except Exception as e: |
| 148 | + logger.exception('Error during REST SSE stream') |
| 149 | + yield ServerSentEvent( |
| 150 | + data=json.dumps(build_rest_error_payload(e)), |
| 151 | + event='error', |
| 152 | + ) |
142 | 153 |
|
143 | 154 | return EventSourceResponse(event_generator()) |
144 | 155 |
|
|
0 commit comments