Skip to content

Commit abb85dd

Browse files
committed
fix: propagate activated extensions for REST
1 parent 2846be6 commit abb85dd

2 files changed

Lines changed: 95 additions & 26 deletions

File tree

src/a2a/server/routes/rest_dispatcher.py

Lines changed: 41 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from google.protobuf.json_format import MessageToDict, Parse
88

9+
from a2a.extensions.common import HTTP_EXTENSION_HEADER
910
from a2a.server.context import ServerCallContext
1011
from a2a.server.request_handlers.request_handler import RequestHandler
1112
from a2a.server.routes.common import (
@@ -99,14 +100,29 @@ def _build_call_context(self, request: Request) -> ServerCallContext:
99100
call_context.tenant = request.path_params['tenant']
100101
return call_context
101102

103+
def _extension_headers(self, context: ServerCallContext) -> dict[str, str]:
104+
"""Builds response headers carrying the activated extensions, if any."""
105+
if exts := context.activated_extensions:
106+
return {HTTP_EXTENSION_HEADER: ', '.join(sorted(exts))}
107+
return {}
108+
102109
async def _handle_non_streaming(
103110
self,
104111
request: Request,
105112
handler_func: Callable[[ServerCallContext], Awaitable[TResponse]],
106-
) -> TResponse:
107-
"""Centralized error handling and context management for unary calls."""
113+
serializer: Callable[[TResponse], Any] = MessageToDict,
114+
) -> JSONResponse:
115+
"""Centralized error handling and context management for unary calls.
116+
117+
Builds the call context, invokes the handler, and wraps the result in
118+
a `JSONResponse` carrying any activated-extension headers.
119+
"""
108120
context = self._build_call_context(request)
109-
return await handler_func(context)
121+
response = await handler_func(context)
122+
return JSONResponse(
123+
content=serializer(response),
124+
headers=self._extension_headers(context),
125+
)
110126

111127
async def _handle_streaming(
112128
self,
@@ -137,7 +153,9 @@ async def _handle_streaming(
137153
try:
138154
first_item = await anext(stream)
139155
except StopAsyncIteration:
140-
return EventSourceResponse(iter([]))
156+
return EventSourceResponse(
157+
iter([]), headers=self._extension_headers(context)
158+
)
141159

142160
async def event_generator() -> AsyncIterator[ServerSentEvent]:
143161
yield ServerSentEvent(data=json.dumps(first_item))
@@ -151,7 +169,9 @@ async def event_generator() -> AsyncIterator[ServerSentEvent]:
151169
event='error',
152170
)
153171

154-
return EventSourceResponse(event_generator())
172+
return EventSourceResponse(
173+
event_generator(), headers=self._extension_headers(context)
174+
)
155175

156176
@rest_error_handler
157177
async def on_message_send(self, request: Request) -> Response:
@@ -171,8 +191,7 @@ async def _handler(
171191
return a2a_pb2.SendMessageResponse(task=task_or_message)
172192
return a2a_pb2.SendMessageResponse(message=task_or_message)
173193

174-
response = await self._handle_non_streaming(request, _handler)
175-
return JSONResponse(content=MessageToDict(response))
194+
return await self._handle_non_streaming(request, _handler)
176195

177196
@rest_stream_error_handler
178197
async def on_message_send_stream(
@@ -209,8 +228,7 @@ async def _handler(context: ServerCallContext) -> a2a_pb2.Task:
209228
return task
210229
raise TaskNotFoundError
211230

212-
response = await self._handle_non_streaming(request, _handler)
213-
return JSONResponse(content=MessageToDict(response))
231+
return await self._handle_non_streaming(request, _handler)
214232

215233
@rest_stream_error_handler
216234
async def on_subscribe_to_task(
@@ -245,8 +263,7 @@ async def _handler(context: ServerCallContext) -> a2a_pb2.Task:
245263
return task
246264
raise TaskNotFoundError
247265

248-
response = await self._handle_non_streaming(request, _handler)
249-
return JSONResponse(content=MessageToDict(response))
266+
return await self._handle_non_streaming(request, _handler)
250267

251268
@rest_error_handler
252269
async def get_push_notification(self, request: Request) -> Response:
@@ -267,8 +284,7 @@ async def _handler(
267284
)
268285
)
269286

270-
response = await self._handle_non_streaming(request, _handler)
271-
return JSONResponse(content=MessageToDict(response))
287+
return await self._handle_non_streaming(request, _handler)
272288

273289
@rest_error_handler
274290
async def delete_push_notification(self, request: Request) -> Response:
@@ -285,8 +301,9 @@ async def _handler(context: ServerCallContext) -> None:
285301
params, context
286302
)
287303

288-
await self._handle_non_streaming(request, _handler)
289-
return JSONResponse(content={})
304+
return await self._handle_non_streaming(
305+
request, _handler, serializer=lambda _: {}
306+
)
290307

291308
@rest_error_handler
292309
async def set_push_notification(self, request: Request) -> Response:
@@ -304,8 +321,7 @@ async def _handler(
304321
params, context
305322
)
306323

307-
response = await self._handle_non_streaming(request, _handler)
308-
return JSONResponse(content=MessageToDict(response))
324+
return await self._handle_non_streaming(request, _handler)
309325

310326
@rest_error_handler
311327
async def list_push_notifications(self, request: Request) -> Response:
@@ -322,8 +338,7 @@ async def _handler(
322338
params, context
323339
)
324340

325-
response = await self._handle_non_streaming(request, _handler)
326-
return JSONResponse(content=MessageToDict(response))
341+
return await self._handle_non_streaming(request, _handler)
327342

328343
@rest_error_handler
329344
async def list_tasks(self, request: Request) -> Response:
@@ -337,11 +352,12 @@ async def _handler(
337352
proto_utils.parse_params(request.query_params, params)
338353
return await self.request_handler.on_list_tasks(params, context)
339354

340-
response = await self._handle_non_streaming(request, _handler)
341-
return JSONResponse(
342-
content=MessageToDict(
343-
response, always_print_fields_with_no_presence=True
344-
)
355+
return await self._handle_non_streaming(
356+
request,
357+
_handler,
358+
serializer=lambda r: MessageToDict(
359+
r, always_print_fields_with_no_presence=True
360+
),
345361
)
346362

347363
@rest_error_handler
@@ -359,5 +375,4 @@ async def _handler(
359375
params, context
360376
)
361377

362-
response = await self._handle_non_streaming(request, _handler)
363-
return JSONResponse(content=MessageToDict(response))
378+
return await self._handle_non_streaming(request, _handler)

tests/integration/test_end_to_end.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
ServiceParametersFactory,
1616
with_a2a_extensions,
1717
)
18+
from a2a.extensions.common import HTTP_EXTENSION_HEADER
1819
from a2a.server.agent_execution import AgentExecutor, RequestContext
1920
from a2a.server.events import EventQueue
2021
from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager
@@ -835,3 +836,56 @@ async def test_end_to_end_extensions_propagation(transport_setups, streaming):
835836
response.message, Role.ROLE_AGENT, 'extensions echoed'
836837
)
837838
assert set(response.message.extensions) == set(SUPPORTED_EXTENSION_URIS)
839+
840+
841+
@pytest.mark.asyncio
842+
@pytest.mark.parametrize(
843+
'transport_fixture',
844+
[
845+
pytest.param('rest_setup', id='REST'),
846+
pytest.param('jsonrpc_setup', id='JSON-RPC'),
847+
],
848+
)
849+
async def test_end_to_end_extensions_response_header(
850+
request, transport_fixture
851+
):
852+
"""Test that activated extensions are returned in the X-A2A-Extensions
853+
response header for HTTP-based transports."""
854+
setup = request.getfixturevalue(transport_fixture)
855+
client = setup.client
856+
client._config.streaming = False
857+
858+
captured_headers: list[httpx.Headers] = []
859+
860+
async def capture_response(response: httpx.Response) -> None:
861+
captured_headers.append(response.headers)
862+
863+
client._transport.httpx_client.event_hooks['response'].append(
864+
capture_response
865+
)
866+
867+
service_params = ServiceParametersFactory.create(
868+
[with_a2a_extensions(SUPPORTED_EXTENSION_URIS)]
869+
)
870+
context = ClientCallContext(service_parameters=service_params)
871+
872+
message_to_send = Message(
873+
role=Role.ROLE_USER,
874+
message_id='msg-ext-header',
875+
parts=[Part(text='Extensions: echo')],
876+
)
877+
878+
async for _ in client.send_message(
879+
request=SendMessageRequest(message=message_to_send),
880+
context=context,
881+
):
882+
pass
883+
884+
assert captured_headers, 'No HTTP response was captured'
885+
response_headers = captured_headers[-1]
886+
assert HTTP_EXTENSION_HEADER in response_headers
887+
returned = {
888+
ext.strip()
889+
for ext in response_headers[HTTP_EXTENSION_HEADER].split(',')
890+
}
891+
assert returned == set(SUPPORTED_EXTENSION_URIS)

0 commit comments

Comments
 (0)