Skip to content

Commit 0c7874a

Browse files
committed
fix: fix athrow() RuntimeError on streaming responses
Fixes #911
1 parent 8d18d3d commit 0c7874a

2 files changed

Lines changed: 179 additions & 4 deletions

File tree

src/a2a/client/transports/http_helpers.py

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
11
import json
2-
32
from collections.abc import AsyncGenerator, Callable, Iterator
43
from contextlib import contextmanager
54
from typing import Any, NoReturn
65

76
import httpx
8-
9-
from httpx_sse import SSEError, aconnect_sse
7+
from httpx_sse import EventSource, SSEError
108

119
from a2a.client.client import ClientCallContext
1210
from a2a.client.errors import A2AClientError, A2AClientTimeoutError
@@ -75,7 +73,7 @@ async def send_http_stream_request(
7573
) -> AsyncGenerator[str]:
7674
"""Sends a streaming HTTP request, yielding SSE data strings and handling exceptions."""
7775
with handle_http_exceptions(status_error_handler):
78-
async with aconnect_sse(
76+
async with _SSEEventSource(
7977
httpx_client, method, url, **kwargs
8078
) as event_source:
8179
try:
@@ -98,3 +96,45 @@ async def send_http_stream_request(
9896
if not sse.data:
9997
continue
10098
yield sse.data
99+
100+
101+
class _SSEEventSource:
102+
"""Class-based async context manager for SSE connections.
103+
104+
A drop-in replacement for ``httpx_sse.aconnect_sse`` that is safe to use
105+
inside async generators. ``aconnect_sse`` uses ``@asynccontextmanager``
106+
which internally creates an async generator registered with the event
107+
loop. When the enclosing async generator is abandoned,
108+
``shutdown_asyncgens`` tries to finalize both generators independently
109+
and the nested ``athrow()`` calls collide with
110+
``RuntimeError: athrow(): asynchronous generator is already running``.
111+
112+
This class avoids the problem because ``__aenter__``/``__aexit__`` are
113+
plain coroutines - no async generators are created, nothing is registered
114+
with ``loop._asyncgens``, and ``shutdown_asyncgens`` has nothing to
115+
collide with.
116+
"""
117+
118+
def __init__(
119+
self,
120+
client: httpx.AsyncClient,
121+
method: str,
122+
url: str,
123+
**kwargs: Any,
124+
) -> None:
125+
headers = kwargs.pop('headers', {})
126+
headers['Accept'] = 'text/event-stream'
127+
headers['Cache-Control'] = 'no-store'
128+
self._request = client.build_request(
129+
method, url, headers=headers, **kwargs
130+
)
131+
self._client = client
132+
self._response: httpx.Response | None = None
133+
134+
async def __aenter__(self) -> EventSource:
135+
self._response = await self._client.send(self._request, stream=True)
136+
return EventSource(self._response)
137+
138+
async def __aexit__(self, *args: object) -> None:
139+
if self._response is not None:
140+
await self._response.aclose()
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
"""Test that streaming SSE responses clean up without athrow() errors.
2+
3+
Reproduces https://github.com/a2aproject/a2a-python/issues/XXX —
4+
``RuntimeError: athrow(): asynchronous generator is already running``
5+
during event-loop shutdown after consuming a streaming response.
6+
"""
7+
8+
import asyncio
9+
import gc
10+
11+
from typing import Any
12+
from uuid import uuid4
13+
14+
import httpx
15+
import pytest
16+
17+
from starlette.applications import Starlette
18+
19+
from a2a.client.base_client import BaseClient
20+
from a2a.client.client import ClientConfig
21+
from a2a.client.client_factory import ClientFactory
22+
from a2a.server.agent_execution import AgentExecutor, RequestContext
23+
from a2a.server.events import EventQueue
24+
from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager
25+
from a2a.server.request_handlers import DefaultRequestHandler
26+
from a2a.server.routes import create_agent_card_routes, create_jsonrpc_routes
27+
from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore
28+
from a2a.types import (
29+
AgentCapabilities,
30+
AgentCard,
31+
AgentInterface,
32+
Message,
33+
Part,
34+
Role,
35+
SendMessageRequest,
36+
)
37+
from a2a.utils import TransportProtocol
38+
39+
40+
class _MessageExecutor(AgentExecutor):
41+
"""Responds with a single Message event."""
42+
43+
async def execute(self, ctx: RequestContext, eq: EventQueue) -> None:
44+
await eq.enqueue_event(
45+
Message(
46+
role=Role.ROLE_AGENT,
47+
message_id=str(uuid4()),
48+
parts=[Part(text='Hello')],
49+
context_id=ctx.context_id,
50+
task_id=ctx.task_id,
51+
)
52+
)
53+
54+
async def cancel(self, ctx: RequestContext, eq: EventQueue) -> None:
55+
pass
56+
57+
58+
@pytest.fixture
59+
def client():
60+
"""Creates a JSON-RPC client backed by an in-process ASGI server."""
61+
card = AgentCard(
62+
name='T',
63+
description='T',
64+
version='1',
65+
capabilities=AgentCapabilities(streaming=True),
66+
default_input_modes=['text/plain'],
67+
default_output_modes=['text/plain'],
68+
supported_interfaces=[
69+
AgentInterface(
70+
protocol_binding=TransportProtocol.JSONRPC,
71+
url='http://test',
72+
),
73+
],
74+
)
75+
handler = DefaultRequestHandler(
76+
agent_executor=_MessageExecutor(),
77+
task_store=InMemoryTaskStore(),
78+
queue_manager=InMemoryQueueManager(),
79+
)
80+
app = Starlette(
81+
routes=[
82+
*create_agent_card_routes(agent_card=card, card_url='/card'),
83+
*create_jsonrpc_routes(
84+
agent_card=card,
85+
request_handler=handler,
86+
extended_agent_card=card,
87+
rpc_url='/',
88+
),
89+
]
90+
)
91+
return ClientFactory(
92+
config=ClientConfig(
93+
httpx_client=httpx.AsyncClient(
94+
transport=httpx.ASGITransport(app=app),
95+
base_url='http://test',
96+
)
97+
)
98+
).create(card)
99+
100+
101+
@pytest.mark.asyncio
102+
async def test_stream_message_no_athrow(client: BaseClient) -> None:
103+
"""Consuming a streamed Message must not leave broken async generators."""
104+
errors: list[dict[str, Any]] = []
105+
loop = asyncio.get_event_loop()
106+
orig = loop.get_exception_handler()
107+
loop.set_exception_handler(lambda _l, ctx: errors.append(ctx))
108+
109+
try:
110+
msg = Message(
111+
role=Role.ROLE_USER,
112+
message_id=f'msg-{uuid4()}',
113+
parts=[Part(text='hi')],
114+
)
115+
events = [
116+
e
117+
async for e in client.send_message(
118+
request=SendMessageRequest(message=msg)
119+
)
120+
]
121+
assert events
122+
assert events[0][0].HasField('message')
123+
124+
gc.collect()
125+
await loop.shutdown_asyncgens()
126+
127+
bad = [
128+
e
129+
for e in errors
130+
if 'asynchronous generator' in str(e.get('message', ''))
131+
]
132+
assert not bad, '\n'.join(str(e.get('message', '')) for e in bad)
133+
finally:
134+
loop.set_exception_handler(orig)
135+
await client.close()

0 commit comments

Comments
 (0)