Skip to content

Commit 2fd8508

Browse files
authored
Merge branch '1.0-dev' into feat/pydantic-codegen
2 parents 29e61b3 + ca7edc3 commit 2fd8508

8 files changed

Lines changed: 192 additions & 24 deletions

File tree

src/a2a/client/transports/http_helpers.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import httpx
88

9-
from httpx_sse import SSEError, aconnect_sse
9+
from httpx_sse import EventSource, SSEError
1010

1111
from a2a.client.client import ClientCallContext
1212
from a2a.client.errors import A2AClientError, A2AClientTimeoutError
@@ -75,7 +75,7 @@ async def send_http_stream_request(
7575
) -> AsyncGenerator[str]:
7676
"""Sends a streaming HTTP request, yielding SSE data strings and handling exceptions."""
7777
with handle_http_exceptions(status_error_handler):
78-
async with aconnect_sse(
78+
async with _SSEEventSource(
7979
httpx_client, method, url, **kwargs
8080
) as event_source:
8181
try:
@@ -98,3 +98,39 @@ async def send_http_stream_request(
9898
if not sse.data:
9999
continue
100100
yield sse.data
101+
102+
103+
class _SSEEventSource:
104+
"""Class-based replacement for ``httpx_sse.aconnect_sse``.
105+
106+
``aconnect_sse`` is an ``@asynccontextmanager`` whose internal async
107+
generator gets tracked by the event loop. When the enclosing async
108+
generator is abandoned, the event loop's generator cleanup collides
109+
with the cascading cleanup — see https://bugs.python.org/issue38559.
110+
111+
Plain ``__aenter__``/``__aexit__`` coroutines avoid this entirely.
112+
"""
113+
114+
def __init__(
115+
self,
116+
client: httpx.AsyncClient,
117+
method: str,
118+
url: str,
119+
**kwargs: Any,
120+
) -> None:
121+
headers = httpx.Headers(kwargs.pop('headers', None))
122+
headers.setdefault('Accept', 'text/event-stream')
123+
headers.setdefault('Cache-Control', 'no-store')
124+
self._request = client.build_request(
125+
method, url, headers=headers, **kwargs
126+
)
127+
self._client = client
128+
self._response: httpx.Response | None = None
129+
130+
async def __aenter__(self) -> EventSource:
131+
self._response = await self._client.send(self._request, stream=True)
132+
return EventSource(self._response)
133+
134+
async def __aexit__(self, *args: object) -> None:
135+
if self._response is not None:
136+
await self._response.aclose()

src/a2a/compat/v0_3/rest_adapter.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ async def event_generator(
109109
)
110110

111111
async def handle_get_agent_card(
112-
self, request: Request, call_context: ServerCallContext | None = None
112+
self, request: Request, call_context: ServerCallContext
113113
) -> dict[str, Any]:
114114
"""Handles GET requests for the agent card endpoint."""
115115
card_to_serve = self.agent_card
@@ -119,7 +119,7 @@ async def handle_get_agent_card(
119119
return v03_card.model_dump(mode='json', exclude_none=True)
120120

121121
async def handle_authenticated_agent_card(
122-
self, request: Request, call_context: ServerCallContext | None = None
122+
self, request: Request, call_context: ServerCallContext
123123
) -> dict[str, Any]:
124124
"""Hook for per credential agent card response."""
125125
if not self.agent_card.capabilities.extended_agent_card:
@@ -132,9 +132,8 @@ async def handle_authenticated_agent_card(
132132
card_to_serve = self.agent_card
133133

134134
if self.extended_card_modifier:
135-
context = self._context_builder.build(request)
136135
card_to_serve = await maybe_await(
137-
self.extended_card_modifier(card_to_serve, context)
136+
self.extended_card_modifier(card_to_serve, call_context)
138137
)
139138
elif self.card_modifier:
140139
card_to_serve = await maybe_await(self.card_modifier(card_to_serve))

src/a2a/server/routes/jsonrpc_dispatcher.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -215,8 +215,8 @@ def __init__( # noqa: PLR0913
215215
extended_agent_card: An optional, distinct AgentCard to be served
216216
at the authenticated extended card endpoint.
217217
context_builder: The CallContextBuilder used to construct the
218-
ServerCallContext passed to the request_handler. If None, no
219-
ServerCallContext is passed.
218+
ServerCallContext passed to the request_handler. If None the
219+
DefaultCallContextBuilder is used.
220220
card_modifier: An optional callback to dynamically modify the public
221221
agent card before it is served.
222222
extended_card_modifier: An optional callback to dynamically modify

src/a2a/server/routes/jsonrpc_routes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,8 @@ def create_jsonrpc_routes( # noqa: PLR0913
5454
extended_agent_card: An optional, distinct AgentCard to be served
5555
at the authenticated extended card endpoint.
5656
context_builder: The CallContextBuilder used to construct the
57-
ServerCallContext passed to the request_handler. If None, no
58-
ServerCallContext is passed.
57+
ServerCallContext passed to the request_handler. If None the
58+
DefaultCallContextBuilder is used.
5959
card_modifier: An optional callback to dynamically modify the public
6060
agent card before it is served.
6161
extended_card_modifier: An optional callback to dynamically modify

src/a2a/server/routes/rest_routes.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,8 @@ def create_rest_routes( # noqa: PLR0913
7676
extended_agent_card: An optional, distinct AgentCard to be served
7777
at the authenticated extended card endpoint.
7878
context_builder: The CallContextBuilder used to construct the
79-
ServerCallContext passed to the request_handler. If None, no
80-
ServerCallContext is passed.
79+
ServerCallContext passed to the request_handler. If None the
80+
DefaultCallContextBuilder is used.
8181
card_modifier: An optional callback to dynamically modify the public
8282
agent card before it is served.
8383
extended_card_modifier: An optional callback to dynamically modify
@@ -176,7 +176,7 @@ async def event_generator() -> AsyncIterator[str]:
176176
return EventSourceResponse(event_generator())
177177

178178
async def _handle_authenticated_agent_card(
179-
request: 'Request', call_context: ServerCallContext | None = None
179+
request: 'Request', call_context: ServerCallContext
180180
) -> dict[str, Any]:
181181
if not agent_card.capabilities.extended_agent_card:
182182
raise ExtendedAgentCardNotConfiguredError(
@@ -185,10 +185,8 @@ async def _handle_authenticated_agent_card(
185185
card_to_serve = extended_agent_card or agent_card
186186

187187
if extended_card_modifier:
188-
# Re-generate context if none passed to replicate RESTAdapter exact logic
189-
context = call_context or _build_call_context(request)
190188
card_to_serve = await maybe_await(
191-
extended_card_modifier(card_to_serve, context)
189+
extended_card_modifier(card_to_serve, call_context)
192190
)
193191
elif card_modifier:
194192
card_to_serve = await maybe_await(card_modifier(card_to_serve))

tests/client/transports/test_jsonrpc_client.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -433,7 +433,7 @@ async def test_close(self, transport, mock_httpx_client):
433433

434434
class TestStreamingErrors:
435435
@pytest.mark.asyncio
436-
@patch('a2a.client.transports.http_helpers.aconnect_sse')
436+
@patch('a2a.client.transports.http_helpers._SSEEventSource')
437437
async def test_send_message_streaming_sse_error(
438438
self,
439439
mock_aconnect_sse: AsyncMock,
@@ -457,7 +457,7 @@ async def test_send_message_streaming_sse_error(
457457
pass
458458

459459
@pytest.mark.asyncio
460-
@patch('a2a.client.transports.http_helpers.aconnect_sse')
460+
@patch('a2a.client.transports.http_helpers._SSEEventSource')
461461
async def test_send_message_streaming_request_error(
462462
self,
463463
mock_aconnect_sse: AsyncMock,
@@ -483,7 +483,7 @@ async def test_send_message_streaming_request_error(
483483
pass
484484

485485
@pytest.mark.asyncio
486-
@patch('a2a.client.transports.http_helpers.aconnect_sse')
486+
@patch('a2a.client.transports.http_helpers._SSEEventSource')
487487
async def test_send_message_streaming_timeout(
488488
self,
489489
mock_aconnect_sse: AsyncMock,
@@ -560,7 +560,7 @@ async def test_extensions_added_to_request(
560560
)
561561

562562
@pytest.mark.asyncio
563-
@patch('a2a.client.transports.http_helpers.aconnect_sse')
563+
@patch('a2a.client.transports.http_helpers._SSEEventSource')
564564
async def test_send_message_streaming_server_error_propagates(
565565
self,
566566
mock_aconnect_sse: AsyncMock,

tests/client/transports/test_rest_client.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def _assert_extensions_header(mock_kwargs: dict, expected_extensions: set[str]):
7070

7171
class TestRestTransport:
7272
@pytest.mark.asyncio
73-
@patch('a2a.client.transports.http_helpers.aconnect_sse')
73+
@patch('a2a.client.transports.http_helpers._SSEEventSource')
7474
async def test_send_message_streaming_timeout(
7575
self,
7676
mock_aconnect_sse: AsyncMock,
@@ -280,7 +280,7 @@ async def test_send_message_with_default_extensions(
280280
)
281281

282282
@pytest.mark.asyncio
283-
@patch('a2a.client.transports.http_helpers.aconnect_sse')
283+
@patch('a2a.client.transports.http_helpers._SSEEventSource')
284284
async def test_send_message_streaming_with_new_extensions(
285285
self,
286286
mock_aconnect_sse: AsyncMock,
@@ -329,7 +329,7 @@ async def test_send_message_streaming_with_new_extensions(
329329
)
330330

331331
@pytest.mark.asyncio
332-
@patch('a2a.client.transports.http_helpers.aconnect_sse')
332+
@patch('a2a.client.transports.http_helpers._SSEEventSource')
333333
async def test_send_message_streaming_server_error_propagates(
334334
self,
335335
mock_aconnect_sse: AsyncMock,
@@ -693,7 +693,7 @@ async def test_rest_get_task_prepend_empty_tenant(
693693
],
694694
)
695695
@pytest.mark.asyncio
696-
@patch('a2a.client.transports.http_helpers.aconnect_sse')
696+
@patch('a2a.client.transports.http_helpers._SSEEventSource')
697697
async def test_rest_streaming_methods_prepend_tenant( # noqa: PLR0913
698698
self,
699699
mock_aconnect_sse,
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
"""Test that streaming SSE responses clean up without athrow() errors.
2+
3+
Reproduces https://github.com/a2aproject/a2a-python/issues/911 —
4+
``RuntimeError: athrow(): asynchronous generator is already running``
5+
during event-loop shutdown after consuming a streaming response.
6+
"""
7+
8+
import asyncio
9+
import gc
10+
11+
from typing import Any
12+
from uuid import uuid4
13+
14+
import httpx
15+
import pytest
16+
17+
from starlette.applications import Starlette
18+
19+
from a2a.client.base_client import BaseClient
20+
from a2a.client.client import ClientConfig
21+
from a2a.client.client_factory import ClientFactory
22+
from a2a.server.agent_execution import AgentExecutor, RequestContext
23+
from a2a.server.events import EventQueue
24+
from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager
25+
from a2a.server.request_handlers import DefaultRequestHandler
26+
from a2a.server.routes import create_agent_card_routes, create_jsonrpc_routes
27+
from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore
28+
from a2a.types import (
29+
AgentCapabilities,
30+
AgentCard,
31+
AgentInterface,
32+
Message,
33+
Part,
34+
Role,
35+
SendMessageRequest,
36+
)
37+
from a2a.utils import TransportProtocol
38+
39+
40+
class _MessageExecutor(AgentExecutor):
41+
"""Responds with a single Message event."""
42+
43+
async def execute(self, ctx: RequestContext, eq: EventQueue) -> None:
44+
await eq.enqueue_event(
45+
Message(
46+
role=Role.ROLE_AGENT,
47+
message_id=str(uuid4()),
48+
parts=[Part(text='Hello')],
49+
context_id=ctx.context_id,
50+
task_id=ctx.task_id,
51+
)
52+
)
53+
54+
async def cancel(self, ctx: RequestContext, eq: EventQueue) -> None:
55+
pass
56+
57+
58+
@pytest.fixture
59+
def client():
60+
"""Creates a JSON-RPC client backed by an in-process ASGI server."""
61+
card = AgentCard(
62+
name='T',
63+
description='T',
64+
version='1',
65+
capabilities=AgentCapabilities(streaming=True),
66+
default_input_modes=['text/plain'],
67+
default_output_modes=['text/plain'],
68+
supported_interfaces=[
69+
AgentInterface(
70+
protocol_binding=TransportProtocol.JSONRPC,
71+
url='http://test',
72+
),
73+
],
74+
)
75+
handler = DefaultRequestHandler(
76+
agent_executor=_MessageExecutor(),
77+
task_store=InMemoryTaskStore(),
78+
queue_manager=InMemoryQueueManager(),
79+
)
80+
app = Starlette(
81+
routes=[
82+
*create_agent_card_routes(agent_card=card, card_url='/card'),
83+
*create_jsonrpc_routes(
84+
agent_card=card,
85+
request_handler=handler,
86+
extended_agent_card=card,
87+
rpc_url='/',
88+
),
89+
]
90+
)
91+
return ClientFactory(
92+
config=ClientConfig(
93+
httpx_client=httpx.AsyncClient(
94+
transport=httpx.ASGITransport(app=app),
95+
base_url='http://test',
96+
)
97+
)
98+
).create(card)
99+
100+
101+
@pytest.mark.asyncio
102+
async def test_stream_message_no_athrow(client: BaseClient) -> None:
103+
"""Consuming a streamed Message must not leave broken async generators."""
104+
errors: list[dict[str, Any]] = []
105+
loop = asyncio.get_event_loop()
106+
orig = loop.get_exception_handler()
107+
loop.set_exception_handler(lambda _l, ctx: errors.append(ctx))
108+
109+
try:
110+
msg = Message(
111+
role=Role.ROLE_USER,
112+
message_id=f'msg-{uuid4()}',
113+
parts=[Part(text='hi')],
114+
)
115+
events = [
116+
e
117+
async for e in client.send_message(
118+
request=SendMessageRequest(message=msg)
119+
)
120+
]
121+
assert events
122+
assert events[0][0].HasField('message')
123+
124+
gc.collect()
125+
await loop.shutdown_asyncgens()
126+
127+
bad = [
128+
e
129+
for e in errors
130+
if 'asynchronous generator' in str(e.get('message', ''))
131+
]
132+
assert not bad, '\n'.join(str(e.get('message', '')) for e in bad)
133+
finally:
134+
loop.set_exception_handler(orig)
135+
await client.close()

0 commit comments

Comments
 (0)