Skip to content

Commit 56cb1f6

Browse files
committed
Add server support for propagating extensions, both input and output.
This commit adds support for extracting extensions requested by a client, and marking an extension as activated, which causes it to be returned to clients in a header. This is purely plumbing, no particular support for actually using extensions.
1 parent a3e6071 commit 56cb1f6

5 files changed

Lines changed: 102 additions & 8 deletions

File tree

src/a2a/extensions/common.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from a2a.types import AgentCard, AgentExtension
2+
3+
4+
HTTP_EXTENSION_HEADER = 'X-A2A-Extensions'
5+
6+
7+
def find_extension_by_uri(card: AgentCard, uri: str) -> AgentExtension | None:
8+
"""Find an AgentExtension in an AgentCard given a uri."""
9+
for ext in card.capabilities.extensions or []:
10+
if ext.uri == uri:
11+
return ext
12+
13+
return None

src/a2a/server/agent_execution/context.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,24 @@ def metadata(self) -> dict[str, Any]:
143143
return {}
144144
return self._params.metadata or {}
145145

146+
def add_activated_extension(self, uri: str) -> None:
147+
"""Add an extension to the set of activated extensions for this request.
148+
149+
This causes the extension to be indicated back to the client in the
150+
response.
151+
"""
152+
if self._call_context:
153+
self._call_context.activated_extensions.add(uri)
154+
155+
@property
156+
def requested_extensions(self) -> set[str]:
157+
"""Extensions that the client requested to activate."""
158+
return (
159+
self._call_context.requested_extensions
160+
if self._call_context
161+
else set()
162+
)
163+
146164
def _check_or_generate_task_id(self) -> None:
147165
"""Ensures a task ID is present, generating one if necessary."""
148166
if not self._params:

src/a2a/server/apps/jsonrpc/jsonrpc_app.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
from a2a.auth.user import UnauthenticatedUser
2121
from a2a.auth.user import User as A2AUser
22+
from a2a.extensions.common import HTTP_EXTENSION_HEADER
2223
from a2a.server.context import ServerCallContext
2324
from a2a.server.request_handlers.jsonrpc_handler import JSONRPCHandler
2425
from a2a.server.request_handlers.request_handler import RequestHandler
@@ -99,7 +100,15 @@ def build(self, request: Request) -> ServerCallContext:
99100
user = StarletteUserProxy(request.user)
100101
state['auth'] = request.auth
101102
state['headers'] = dict(request.headers)
102-
return ServerCallContext(user=user, state=state)
103+
return ServerCallContext(
104+
user=user,
105+
state=state,
106+
requested_extensions={
107+
ext
108+
for h in request.headers.getlist(HTTP_EXTENSION_HEADER)
109+
for ext in h.split(', ')
110+
},
111+
)
103112

104113

105114
class JSONRPCApplication(ABC):
@@ -281,7 +290,7 @@ async def _process_streaming_request(
281290
request_obj, context
282291
)
283292

284-
return self._create_response(handler_result)
293+
return self._create_response(context, handler_result)
285294

286295
async def _process_non_streaming_request(
287296
self,
@@ -353,10 +362,11 @@ async def _process_non_streaming_request(
353362
id=request_id, error=error
354363
)
355364

356-
return self._create_response(handler_result)
365+
return self._create_response(context, handler_result)
357366

358367
def _create_response(
359368
self,
369+
context: ServerCallContext,
360370
handler_result: (
361371
AsyncGenerator[SendStreamingMessageResponse]
362372
| JSONRPCErrorResponse
@@ -372,12 +382,16 @@ def _create_response(
372382
payloads.
373383
374384
Args:
385+
context: The ServerCallContext provided to the request handler.
375386
handler_result: The result from a request handler method. Can be an
376387
async generator for streaming or a Pydantic model for non-streaming.
377388
378389
Returns:
379390
A Starlette JSONResponse or EventSourceResponse.
380391
"""
392+
headers = {}
393+
if exts := context.activated_extensions:
394+
headers[HTTP_EXTENSION_HEADER] = ', '.join(exts)
381395
if isinstance(handler_result, AsyncGenerator):
382396
# Result is a stream of SendStreamingMessageResponse objects
383397
async def event_generator(
@@ -386,17 +400,21 @@ async def event_generator(
386400
async for item in stream:
387401
yield {'data': item.root.model_dump_json(exclude_none=True)}
388402

389-
return EventSourceResponse(event_generator(handler_result))
403+
return EventSourceResponse(
404+
event_generator(handler_result), headers=headers
405+
)
390406
if isinstance(handler_result, JSONRPCErrorResponse):
391407
return JSONResponse(
392408
handler_result.model_dump(
393409
mode='json',
394410
exclude_none=True,
395-
)
411+
),
412+
headers=headers,
396413
)
397414

398415
return JSONResponse(
399-
handler_result.root.model_dump(mode='json', exclude_none=True)
416+
handler_result.root.model_dump(mode='json', exclude_none=True),
417+
headers=headers,
400418
)
401419

402420
async def _handle_get_agent_card(self, request: Request) -> JSONResponse:

src/a2a/server/context.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,5 @@ class ServerCallContext(BaseModel):
2121

2222
state: State = Field(default={})
2323
user: User = Field(default=UnauthenticatedUser())
24+
requested_extensions: set[str] = Field(default=set())
25+
activated_extensions: set[str] = Field(default=set())

src/a2a/server/request_handlers/grpc_handler.py

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import logging
44

55
from abc import ABC, abstractmethod
6-
from collections.abc import AsyncIterable
6+
from collections.abc import AsyncIterable, Sequence
77

88

99
try:
@@ -16,10 +16,13 @@
1616
"'pip install a2a-sdk[grpc]'"
1717
) from e
1818

19+
from grpc.aio import Metadata
20+
1921
import a2a.grpc.a2a_pb2_grpc as a2a_grpc
2022

2123
from a2a import types
2224
from a2a.auth.user import UnauthenticatedUser
25+
from a2a.extensions.common import HTTP_EXTENSION_HEADER
2326
from a2a.grpc import a2a_pb2
2427
from a2a.server.context import ServerCallContext
2528
from a2a.server.request_handlers.request_handler import RequestHandler
@@ -42,6 +45,25 @@ def build(self, context: grpc.aio.ServicerContext) -> ServerCallContext:
4245
"""Builds a ServerCallContext from a gRPC Request."""
4346

4447

48+
def _get_metadata_value(
49+
context: grpc.aio.ServicerContext, key: str
50+
) -> list[str]:
51+
md = context.invocation_metadata
52+
vs = []
53+
if isinstance(md, Metadata):
54+
vs = [
55+
e if isinstance(e, str) else e.decode('utf-8')
56+
for e in md.get_all(key)
57+
]
58+
elif isinstance(md, Sequence):
59+
vs = [
60+
e if isinstance(e, str) else e.decode('utf-8')
61+
for (k, e) in md
62+
if k == key.lower()
63+
]
64+
return vs
65+
66+
4567
class DefaultCallContextBuilder(CallContextBuilder):
4668
"""A default implementation of CallContextBuilder."""
4769

@@ -51,7 +73,13 @@ def build(self, context: grpc.aio.ServicerContext) -> ServerCallContext:
5173
state = {}
5274
with contextlib.suppress(Exception):
5375
state['grpc_context'] = context
54-
return ServerCallContext(user=user, state=state)
76+
return ServerCallContext(
77+
user=user,
78+
state=state,
79+
requested_extensions=set(
80+
_get_metadata_value(context, HTTP_EXTENSION_HEADER)
81+
),
82+
)
5583

5684

5785
class GrpcHandler(a2a_grpc.A2AServiceServicer):
@@ -102,6 +130,7 @@ async def SendMessage(
102130
task_or_message = await self.request_handler.on_message_send(
103131
a2a_request, server_context
104132
)
133+
self._set_extension_metadata(context, server_context)
105134
return proto_utils.ToProto.task_or_message(task_or_message)
106135
except ServerError as e:
107136
await self.abort_context(e, context)
@@ -140,6 +169,7 @@ async def SendStreamingMessage(
140169
a2a_request, server_context
141170
):
142171
yield proto_utils.ToProto.stream_response(event)
172+
self._set_extension_metadata(context, server_context)
143173
except ServerError as e:
144174
await self.abort_context(e, context)
145175
return
@@ -371,3 +401,16 @@ async def abort_context(
371401
grpc.StatusCode.UNKNOWN,
372402
f'Unknown error type: {error.error}',
373403
)
404+
405+
def _set_extension_metadata(
406+
self,
407+
context: grpc.aio.ServicerContext,
408+
server_context: ServerCallContext,
409+
) -> None:
410+
if server_context.activated_extensions:
411+
context.set_trailing_metadata(
412+
[
413+
(HTTP_EXTENSION_HEADER, e)
414+
for e in server_context.activated_extensions
415+
]
416+
)

0 commit comments

Comments
 (0)