Skip to content

Commit 58911a9

Browse files
Merge branch '1.0-dev' into guglielmoc/remove_helpers
2 parents 3a16982 + b8df210 commit 58911a9

25 files changed

Lines changed: 395 additions & 155 deletions
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
"""Context builders that add v0.3 backwards-compatibility for extensions.
2+
3+
The current spec uses ``A2A-Extensions`` (RFC 6648, no ``X-`` prefix). v0.3
4+
clients still send the old ``X-A2A-Extensions`` name, so the v0.3 compat
5+
adapters wrap the default builders with these classes to recognize both names.
6+
"""
7+
8+
from typing import TYPE_CHECKING, Any
9+
10+
import grpc
11+
12+
from a2a.compat.v0_3.extension_headers import LEGACY_HTTP_EXTENSION_HEADER
13+
from a2a.extensions.common import get_requested_extensions
14+
from a2a.server.context import ServerCallContext
15+
16+
17+
if TYPE_CHECKING:
18+
from starlette.requests import Request
19+
20+
from a2a.server.request_handlers.grpc_handler import (
21+
GrpcServerCallContextBuilder,
22+
)
23+
from a2a.server.routes.common import ServerCallContextBuilder
24+
else:
25+
try:
26+
from starlette.requests import Request
27+
except ImportError:
28+
Request = Any
29+
30+
31+
def _get_legacy_grpc_extensions(
32+
context: grpc.aio.ServicerContext,
33+
) -> list[str]:
34+
md = context.invocation_metadata()
35+
if md is None:
36+
return []
37+
lower_key = LEGACY_HTTP_EXTENSION_HEADER.lower()
38+
return [
39+
e if isinstance(e, str) else e.decode('utf-8')
40+
for k, e in md
41+
if k.lower() == lower_key
42+
]
43+
44+
45+
class V03ServerCallContextBuilder:
46+
"""Wraps a ServerCallContextBuilder to also accept the legacy header.
47+
48+
Recognizes the v0.3 ``X-A2A-Extensions`` HTTP header in addition to the
49+
spec ``A2A-Extensions``.
50+
"""
51+
52+
def __init__(self, inner: 'ServerCallContextBuilder') -> None:
53+
self._inner = inner
54+
55+
def build(self, request: 'Request') -> ServerCallContext:
56+
"""Builds a ServerCallContext, merging legacy extension headers."""
57+
context = self._inner.build(request)
58+
context.requested_extensions |= get_requested_extensions(
59+
request.headers.getlist(LEGACY_HTTP_EXTENSION_HEADER)
60+
)
61+
return context
62+
63+
64+
class V03GrpcServerCallContextBuilder:
65+
"""Wraps a GrpcServerCallContextBuilder to also accept the legacy metadata.
66+
67+
Recognizes the v0.3 ``X-A2A-Extensions`` gRPC metadata key in addition to
68+
the spec ``A2A-Extensions``.
69+
"""
70+
71+
def __init__(self, inner: 'GrpcServerCallContextBuilder') -> None:
72+
self._inner = inner
73+
74+
def build(self, context: grpc.aio.ServicerContext) -> ServerCallContext:
75+
"""Builds a ServerCallContext, merging legacy extension metadata."""
76+
server_context = self._inner.build(context)
77+
server_context.requested_extensions |= get_requested_extensions(
78+
_get_legacy_grpc_extensions(context)
79+
)
80+
return server_context
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
"""Shared header name constants for v0.3 extension compatibility.
2+
3+
The current spec uses ``A2A-Extensions``. v0.3 used the ``X-`` prefixed
4+
``X-A2A-Extensions`` form. v0.3 compat servers and clients accept/emit both
5+
names so they can interoperate with peers that only know the legacy one.
6+
"""
7+
8+
from a2a.client.service_parameters import ServiceParameters
9+
from a2a.extensions.common import HTTP_EXTENSION_HEADER
10+
11+
12+
LEGACY_HTTP_EXTENSION_HEADER = f'X-{HTTP_EXTENSION_HEADER}'
13+
14+
15+
def add_legacy_extension_header(parameters: ServiceParameters) -> None:
16+
"""Mirrors the ``A2A-Extensions`` parameter under its legacy name in-place.
17+
18+
Used by v0.3 compat client transports so that requests can be understood
19+
by older v0.3 servers that only recognize ``X-A2A-Extensions``.
20+
"""
21+
if (
22+
HTTP_EXTENSION_HEADER in parameters
23+
and LEGACY_HTTP_EXTENSION_HEADER not in parameters
24+
):
25+
parameters[LEGACY_HTTP_EXTENSION_HEADER] = parameters[
26+
HTTP_EXTENSION_HEADER
27+
]

src/a2a/compat/v0_3/grpc_handler.py

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717
from a2a.compat.v0_3 import (
1818
types as types_v03,
1919
)
20+
from a2a.compat.v0_3.context_builders import V03GrpcServerCallContextBuilder
2021
from a2a.compat.v0_3.request_handler import RequestHandler03
21-
from a2a.extensions.common import HTTP_EXTENSION_HEADER
2222
from a2a.server.context import ServerCallContext
2323
from a2a.server.request_handlers.grpc_handler import (
2424
_ERROR_CODE_MAP,
@@ -51,7 +51,7 @@ def __init__(
5151
DefaultCallContextBuilder is used.
5252
"""
5353
self.handler03 = RequestHandler03(request_handler=request_handler)
54-
self._context_builder = (
54+
self._context_builder = V03GrpcServerCallContextBuilder(
5555
context_builder or DefaultGrpcServerCallContextBuilder()
5656
)
5757

@@ -65,7 +65,6 @@ async def _handle_unary(
6565
try:
6666
server_context = self._context_builder.build(context)
6767
result = await handler_func(server_context)
68-
self._set_extension_metadata(context, server_context)
6968
except A2AError as e:
7069
await self.abort_context(e, context)
7170
else:
@@ -82,7 +81,6 @@ async def _handle_stream(
8281
server_context = self._context_builder.build(context)
8382
async for item in handler_func(server_context):
8483
yield item
85-
self._set_extension_metadata(context, server_context)
8684
except A2AError as e:
8785
await self.abort_context(e, context)
8886

@@ -120,19 +118,6 @@ async def abort_context(
120118
f'Unknown error type: {error}',
121119
)
122120

123-
def _set_extension_metadata(
124-
self,
125-
context: grpc.aio.ServicerContext,
126-
server_context: ServerCallContext,
127-
) -> None:
128-
if server_context.activated_extensions:
129-
context.set_trailing_metadata(
130-
[
131-
(HTTP_EXTENSION_HEADER.lower(), e)
132-
for e in sorted(server_context.activated_extensions)
133-
]
134-
)
135-
136121
async def SendMessage(
137122
self,
138123
request: a2a_v0_3_pb2.SendMessageRequest,

src/a2a/compat/v0_3/grpc_transport.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from a2a.compat.v0_3 import (
3131
types as types_v03,
3232
)
33+
from a2a.compat.v0_3.extension_headers import add_legacy_extension_header
3334
from a2a.types import a2a_pb2
3435
from a2a.utils.constants import PROTOCOL_VERSION_0_3, VERSION_HEADER
3536
from a2a.utils.telemetry import SpanKind, trace_class
@@ -361,7 +362,9 @@ def _get_grpc_metadata(
361362
metadata = [(VERSION_HEADER.lower(), PROTOCOL_VERSION_0_3)]
362363

363364
if context and context.service_parameters:
364-
for key, value in context.service_parameters.items():
365+
params = dict(context.service_parameters)
366+
add_legacy_extension_header(params)
367+
for key, value in params.items():
365368
metadata.append((key.lower(), value))
366369

367370
return metadata

src/a2a/compat/v0_3/jsonrpc_adapter.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
_package_starlette_installed = False
2525

2626
from a2a.compat.v0_3 import types as types_v03
27+
from a2a.compat.v0_3.context_builders import V03ServerCallContextBuilder
2728
from a2a.compat.v0_3.request_handler import RequestHandler03
2829
from a2a.server.context import ServerCallContext
2930
from a2a.server.jsonrpc_models import (
@@ -70,7 +71,7 @@ def __init__(
7071
self.handler = RequestHandler03(
7172
request_handler=http_handler,
7273
)
73-
self._context_builder = (
74+
self._context_builder = V03ServerCallContextBuilder(
7475
context_builder or DefaultServerCallContextBuilder()
7576
)
7677

src/a2a/compat/v0_3/jsonrpc_transport.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
)
2020
from a2a.compat.v0_3 import conversions
2121
from a2a.compat.v0_3 import types as types_v03
22+
from a2a.compat.v0_3.extension_headers import add_legacy_extension_header
2223
from a2a.types.a2a_pb2 import (
2324
AgentCard,
2425
CancelTaskRequest,
@@ -424,6 +425,7 @@ async def _send_stream_request(
424425
http_kwargs = get_http_args(context)
425426
http_kwargs.setdefault('headers', {})
426427
http_kwargs['headers'][VERSION_HEADER.lower()] = PROTOCOL_VERSION_0_3
428+
add_legacy_extension_header(http_kwargs['headers'])
427429

428430
async for sse_data in send_http_stream_request(
429431
self.httpx_client,
@@ -485,6 +487,7 @@ async def _send_request(
485487
http_kwargs = get_http_args(context)
486488
http_kwargs.setdefault('headers', {})
487489
http_kwargs['headers'][VERSION_HEADER.lower()] = PROTOCOL_VERSION_0_3
490+
add_legacy_extension_header(http_kwargs['headers'])
488491

489492
request = self.httpx_client.build_request(
490493
'POST',

src/a2a/compat/v0_3/rest_adapter.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
_package_starlette_installed = False
3232

3333

34+
from a2a.compat.v0_3.context_builders import V03ServerCallContextBuilder
3435
from a2a.compat.v0_3.rest_handler import REST03Handler
3536
from a2a.server.routes.common import (
3637
DefaultServerCallContextBuilder,
@@ -60,7 +61,7 @@ def __init__(
6061
context_builder: 'ServerCallContextBuilder | None' = None,
6162
):
6263
self.handler = REST03Handler(request_handler=http_handler)
63-
self._context_builder = (
64+
self._context_builder = V03ServerCallContextBuilder(
6465
context_builder or DefaultServerCallContextBuilder()
6566
)
6667

src/a2a/compat/v0_3/rest_transport.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from a2a.compat.v0_3 import (
2626
types as types_v03,
2727
)
28+
from a2a.compat.v0_3.extension_headers import add_legacy_extension_header
2829
from a2a.types.a2a_pb2 import (
2930
AgentCard,
3031
CancelTaskRequest,
@@ -380,6 +381,7 @@ async def _send_stream_request(
380381
http_kwargs = get_http_args(context)
381382
http_kwargs.setdefault('headers', {})
382383
http_kwargs['headers'][VERSION_HEADER.lower()] = PROTOCOL_VERSION_0_3
384+
add_legacy_extension_header(http_kwargs['headers'])
383385

384386
async for sse_data in send_http_stream_request(
385387
self.httpx_client,
@@ -414,6 +416,7 @@ async def _execute_request(
414416
http_kwargs = get_http_args(context)
415417
http_kwargs.setdefault('headers', {})
416418
http_kwargs['headers'][VERSION_HEADER.lower()] = PROTOCOL_VERSION_0_3
419+
add_legacy_extension_header(http_kwargs['headers'])
417420

418421
request = self.httpx_client.build_request(
419422
method,

src/a2a/extensions/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from a2a.types.a2a_pb2 import AgentCard, AgentExtension
22

33

4-
HTTP_EXTENSION_HEADER = 'X-A2A-Extensions'
4+
HTTP_EXTENSION_HEADER = 'A2A-Extensions'
55

66

77
def get_requested_extensions(values: list[str]) -> set[str]:

src/a2a/server/agent_execution/context.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -151,22 +151,14 @@ def metadata(self) -> dict[str, Any]:
151151
return dict(self._params.metadata)
152152
return {}
153153

154-
def add_activated_extension(self, uri: str) -> None:
155-
"""Add an extension to the set of activated extensions for this request.
156-
157-
This causes the extension to be indicated back to the client in the
158-
response.
159-
"""
160-
self._call_context.activated_extensions.add(uri)
161-
162154
@property
163155
def tenant(self) -> str:
164156
"""The tenant associated with this request."""
165157
return self._call_context.tenant
166158

167159
@property
168160
def requested_extensions(self) -> set[str]:
169-
"""Extensions that the client requested to activate."""
161+
"""Extensions that the client requested for this interaction."""
170162
return self._call_context.requested_extensions
171163

172164
def _check_or_generate_task_id(self) -> None:

0 commit comments

Comments
 (0)