66
77from google .protobuf .json_format import MessageToDict , Parse
88
9+ from a2a .extensions .common import HTTP_EXTENSION_HEADER
910from a2a .server .context import ServerCallContext
1011from a2a .server .request_handlers .request_handler import RequestHandler
1112from a2a .server .routes .common import (
@@ -99,14 +100,29 @@ def _build_call_context(self, request: Request) -> ServerCallContext:
99100 call_context .tenant = request .path_params ['tenant' ]
100101 return call_context
101102
103+ def _extension_headers (self , context : ServerCallContext ) -> dict [str , str ]:
104+ """Builds response headers carrying the activated extensions, if any."""
105+ if exts := context .activated_extensions :
106+ return {HTTP_EXTENSION_HEADER : ', ' .join (sorted (exts ))}
107+ return {}
108+
102109 async def _handle_non_streaming (
103110 self ,
104111 request : Request ,
105112 handler_func : Callable [[ServerCallContext ], Awaitable [TResponse ]],
106- ) -> TResponse :
107- """Centralized error handling and context management for unary calls."""
113+ serializer : Callable [[TResponse ], Any ] = MessageToDict ,
114+ ) -> JSONResponse :
115+ """Centralized error handling and context management for unary calls.
116+
117+ Builds the call context, invokes the handler, and wraps the result in
118+ a `JSONResponse` carrying any activated-extension headers.
119+ """
108120 context = self ._build_call_context (request )
109- return await handler_func (context )
121+ response = await handler_func (context )
122+ return JSONResponse (
123+ content = serializer (response ),
124+ headers = self ._extension_headers (context ),
125+ )
110126
111127 async def _handle_streaming (
112128 self ,
@@ -137,7 +153,9 @@ async def _handle_streaming(
137153 try :
138154 first_item = await anext (stream )
139155 except StopAsyncIteration :
140- return EventSourceResponse (iter ([]))
156+ return EventSourceResponse (
157+ iter ([]), headers = self ._extension_headers (context )
158+ )
141159
142160 async def event_generator () -> AsyncIterator [ServerSentEvent ]:
143161 yield ServerSentEvent (data = json .dumps (first_item ))
@@ -151,7 +169,9 @@ async def event_generator() -> AsyncIterator[ServerSentEvent]:
151169 event = 'error' ,
152170 )
153171
154- return EventSourceResponse (event_generator ())
172+ return EventSourceResponse (
173+ event_generator (), headers = self ._extension_headers (context )
174+ )
155175
156176 @rest_error_handler
157177 async def on_message_send (self , request : Request ) -> Response :
@@ -171,8 +191,7 @@ async def _handler(
171191 return a2a_pb2 .SendMessageResponse (task = task_or_message )
172192 return a2a_pb2 .SendMessageResponse (message = task_or_message )
173193
174- response = await self ._handle_non_streaming (request , _handler )
175- return JSONResponse (content = MessageToDict (response ))
194+ return await self ._handle_non_streaming (request , _handler )
176195
177196 @rest_stream_error_handler
178197 async def on_message_send_stream (
@@ -209,8 +228,7 @@ async def _handler(context: ServerCallContext) -> a2a_pb2.Task:
209228 return task
210229 raise TaskNotFoundError
211230
212- response = await self ._handle_non_streaming (request , _handler )
213- return JSONResponse (content = MessageToDict (response ))
231+ return await self ._handle_non_streaming (request , _handler )
214232
215233 @rest_stream_error_handler
216234 async def on_subscribe_to_task (
@@ -245,8 +263,7 @@ async def _handler(context: ServerCallContext) -> a2a_pb2.Task:
245263 return task
246264 raise TaskNotFoundError
247265
248- response = await self ._handle_non_streaming (request , _handler )
249- return JSONResponse (content = MessageToDict (response ))
266+ return await self ._handle_non_streaming (request , _handler )
250267
251268 @rest_error_handler
252269 async def get_push_notification (self , request : Request ) -> Response :
@@ -267,8 +284,7 @@ async def _handler(
267284 )
268285 )
269286
270- response = await self ._handle_non_streaming (request , _handler )
271- return JSONResponse (content = MessageToDict (response ))
287+ return await self ._handle_non_streaming (request , _handler )
272288
273289 @rest_error_handler
274290 async def delete_push_notification (self , request : Request ) -> Response :
@@ -285,8 +301,9 @@ async def _handler(context: ServerCallContext) -> None:
285301 params , context
286302 )
287303
288- await self ._handle_non_streaming (request , _handler )
289- return JSONResponse (content = {})
304+ return await self ._handle_non_streaming (
305+ request , _handler , serializer = lambda _ : {}
306+ )
290307
291308 @rest_error_handler
292309 async def set_push_notification (self , request : Request ) -> Response :
@@ -304,8 +321,7 @@ async def _handler(
304321 params , context
305322 )
306323
307- response = await self ._handle_non_streaming (request , _handler )
308- return JSONResponse (content = MessageToDict (response ))
324+ return await self ._handle_non_streaming (request , _handler )
309325
310326 @rest_error_handler
311327 async def list_push_notifications (self , request : Request ) -> Response :
@@ -322,8 +338,7 @@ async def _handler(
322338 params , context
323339 )
324340
325- response = await self ._handle_non_streaming (request , _handler )
326- return JSONResponse (content = MessageToDict (response ))
341+ return await self ._handle_non_streaming (request , _handler )
327342
328343 @rest_error_handler
329344 async def list_tasks (self , request : Request ) -> Response :
@@ -337,11 +352,12 @@ async def _handler(
337352 proto_utils .parse_params (request .query_params , params )
338353 return await self .request_handler .on_list_tasks (params , context )
339354
340- response = await self ._handle_non_streaming (request , _handler )
341- return JSONResponse (
342- content = MessageToDict (
343- response , always_print_fields_with_no_presence = True
344- )
355+ return await self ._handle_non_streaming (
356+ request ,
357+ _handler ,
358+ serializer = lambda r : MessageToDict (
359+ r , always_print_fields_with_no_presence = True
360+ ),
345361 )
346362
347363 @rest_error_handler
@@ -359,5 +375,4 @@ async def _handler(
359375 params , context
360376 )
361377
362- response = await self ._handle_non_streaming (request , _handler )
363- return JSONResponse (content = MessageToDict (response ))
378+ return await self ._handle_non_streaming (request , _handler )
0 commit comments