33import logging
44
55from abc import ABC , abstractmethod
6- from collections .abc import AsyncIterable
6+ from collections .abc import AsyncIterable , Sequence
77
88
99try :
1616 "'pip install a2a-sdk[grpc]'"
1717 ) from e
1818
19+ from grpc .aio import Metadata
20+
1921import a2a .grpc .a2a_pb2_grpc as a2a_grpc
2022
2123from a2a import types
2224from a2a .auth .user import UnauthenticatedUser
25+ from a2a .extensions .common import HTTP_EXTENSION_HEADER
2326from a2a .grpc import a2a_pb2
2427from a2a .server .context import ServerCallContext
2528from a2a .server .request_handlers .request_handler import RequestHandler
@@ -42,6 +45,25 @@ def build(self, context: grpc.aio.ServicerContext) -> ServerCallContext:
4245 """Builds a ServerCallContext from a gRPC Request."""
4346
4447
48+ def _get_metadata_value (
49+ context : grpc .aio .ServicerContext , key : str
50+ ) -> list [str ]:
51+ md = context .invocation_metadata
52+ vs = []
53+ if isinstance (md , Metadata ):
54+ vs = [
55+ e if isinstance (e , str ) else e .decode ('utf-8' )
56+ for e in md .get_all (key )
57+ ]
58+ elif isinstance (md , Sequence ):
59+ vs = [
60+ e if isinstance (e , str ) else e .decode ('utf-8' )
61+ for (k , e ) in md
62+ if k == key .lower ()
63+ ]
64+ return vs
65+
66+
4567class DefaultCallContextBuilder (CallContextBuilder ):
4668 """A default implementation of CallContextBuilder."""
4769
@@ -51,7 +73,13 @@ def build(self, context: grpc.aio.ServicerContext) -> ServerCallContext:
5173 state = {}
5274 with contextlib .suppress (Exception ):
5375 state ['grpc_context' ] = context
54- return ServerCallContext (user = user , state = state )
76+ return ServerCallContext (
77+ user = user ,
78+ state = state ,
79+ requested_extensions = set (
80+ _get_metadata_value (context , HTTP_EXTENSION_HEADER )
81+ ),
82+ )
5583
5684
5785class GrpcHandler (a2a_grpc .A2AServiceServicer ):
@@ -102,6 +130,7 @@ async def SendMessage(
102130 task_or_message = await self .request_handler .on_message_send (
103131 a2a_request , server_context
104132 )
133+ self ._set_extension_metadata (context , server_context )
105134 return proto_utils .ToProto .task_or_message (task_or_message )
106135 except ServerError as e :
107136 await self .abort_context (e , context )
@@ -140,6 +169,7 @@ async def SendStreamingMessage(
140169 a2a_request , server_context
141170 ):
142171 yield proto_utils .ToProto .stream_response (event )
172+ self ._set_extension_metadata (context , server_context )
143173 except ServerError as e :
144174 await self .abort_context (e , context )
145175 return
@@ -371,3 +401,16 @@ async def abort_context(
371401 grpc .StatusCode .UNKNOWN ,
372402 f'Unknown error type: { error .error } ' ,
373403 )
404+
405+ def _set_extension_metadata (
406+ self ,
407+ context : grpc .aio .ServicerContext ,
408+ server_context : ServerCallContext ,
409+ ) -> None :
410+ if server_context .activated_extensions :
411+ context .set_trailing_metadata (
412+ [
413+ (HTTP_EXTENSION_HEADER , e )
414+ for e in server_context .activated_extensions
415+ ]
416+ )
0 commit comments