Skip to content

Commit c6ae449

Browse files
committed
revert
1 parent 9e018a1 commit c6ae449

1 file changed

Lines changed: 25 additions & 10 deletions

File tree

src/a2a/server/routes/rest_routes.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -139,27 +139,42 @@ async def _handle_request(
139139

140140
@rest_stream_error_handler
141141
async def _handle_streaming_request(
142-
method: Callable[['Request', ServerCallContext], AsyncIterable[Any]],
143-
request: 'Request',
144-
) -> 'EventSourceResponse':
142+
self,
143+
method: Callable[[Request, ServerCallContext], AsyncIterable[Any]],
144+
request: Request,
145+
) -> EventSourceResponse:
146+
# Pre-consume and cache the request body to prevent deadlock in streaming context
147+
# This is required because Starlette's request.body() can only be consumed once,
148+
# and attempting to consume it after EventSourceResponse starts causes deadlock
145149
try:
146150
await request.body()
147151
except (ValueError, RuntimeError, OSError) as e:
148152
raise InvalidRequestError(
149153
message=f'Failed to pre-consume request body: {e}'
150154
) from e
151155

152-
call_context = _build_call_context(request)
156+
call_context = self._build_call_context(request)
153157

154-
async def event_generator(
155-
stream: AsyncIterable[Any],
156-
) -> AsyncIterator[str]:
158+
# Eagerly fetch the first item from the stream so that errors raised
159+
# before any event is yielded (e.g. validation, parsing, or handler
160+
# failures) propagate here and are caught by
161+
# @rest_stream_error_handler, which returns a JSONResponse with
162+
# the correct HTTP status code instead of starting an SSE stream.
163+
# Without this, the error would be raised after SSE headers are
164+
# already sent, and the client would see a broken stream instead
165+
# of a proper error response.
166+
stream = aiter(method(request, call_context))
167+
try:
168+
first_item = await anext(stream)
169+
except StopAsyncIteration:
170+
return EventSourceResponse(iter([]))
171+
172+
async def event_generator() -> AsyncIterator[str]:
173+
yield json.dumps(first_item)
157174
async for item in stream:
158175
yield json.dumps(item)
159176

160-
return EventSourceResponse(
161-
event_generator(method(request, call_context))
162-
)
177+
return EventSourceResponse(event_generator())
163178

164179
async def _handle_authenticated_agent_card(
165180
request: 'Request', call_context: ServerCallContext | None = None

0 commit comments

Comments
 (0)