Skip to content

Commit 36fd9bd

Browse files
committed
fix
1 parent 91b0112 commit 36fd9bd

1 file changed

Lines changed: 12 additions & 22 deletions

File tree

src/a2a/server/routes/jsonrpc_dispatcher.py

Lines changed: 12 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -376,42 +376,32 @@ async def _process_streaming_request(
376376
if stream is None:
377377
raise UnsupportedOperationError(message='Stream not supported')
378378

379-
# Eagerly fetch the first item from the stream so that errors raised
380-
# before any event is yielded (e.g. task terminal state checks)
381-
# propagate here and result in a standard JSON-RPC error response
382-
# instead of establishing a broken SSE stream.
383-
stream = aiter(stream)
379+
# Eagerly fetch the first event to trigger validation/upfront errors
384380
try:
385381
first_event = await anext(stream)
386382
except StopAsyncIteration:
387-
async def _empty_gen() -> AsyncGenerator[dict[str, Any], None]:
388-
return
389-
yield
390-
return _empty_gen()
383+
first_event = None
391384

392385
async def _wrap_stream(
393-
first_evt: Any,
394-
st: AsyncGenerator,
386+
st: AsyncGenerator, event: Any | None
395387
) -> AsyncGenerator[dict[str, Any], None]:
396-
try:
397-
# Yield the first event
398-
stream_response = proto_utils.to_stream_response(first_evt)
388+
def _map_event(evt: Any) -> dict[str, Any]:
389+
stream_response = proto_utils.to_stream_response(evt)
399390
result = MessageToDict(
400391
stream_response, preserving_proto_field_name=False
401392
)
402-
yield JSONRPC20Response(result=result, _id=request_id).data
393+
return JSONRPC20Response(result=result, _id=request_id).data
394+
395+
try:
396+
if event is not None:
397+
yield _map_event(event)
403398

404-
# Yield the rest of the events
405399
async for event in st:
406-
stream_response = proto_utils.to_stream_response(event)
407-
result = MessageToDict(
408-
stream_response, preserving_proto_field_name=False
409-
)
410-
yield JSONRPC20Response(result=result, _id=request_id).data
400+
yield _map_event(event)
411401
except A2AError as e:
412402
yield build_error_response(request_id, e)
413403

414-
return _wrap_stream(first_event, stream)
404+
return _wrap_stream(stream, first_event)
415405

416406
async def _handle_send_message(
417407
self, request_obj: SendMessageRequest, context: ServerCallContext

0 commit comments

Comments
 (0)