From 32574ec2280251447a5e4759a4e0bdd5464f4085 Mon Sep 17 00:00:00 2001 From: Tim Niemueller Date: Tue, 3 Mar 2026 14:08:41 +0100 Subject: [PATCH] Add extra grpc header which is all lower case According to https://httpwg.org/specs/rfc7540.html#HttpHeaders headers in HTTP/2 (grpc transport) must all be lower case. Thus, using A2A via grpc fails when using the HTTP header. Add specific all lowercase header for gRPC. --- src/a2a/extensions/common.py | 1 + .../server/request_handlers/grpc_handler.py | 6 ++--- .../request_handlers/test_grpc_handler.py | 22 +++++++++---------- 3 files changed, 15 insertions(+), 14 deletions(-) diff --git a/src/a2a/extensions/common.py b/src/a2a/extensions/common.py index cba3517e4..10526f02a 100644 --- a/src/a2a/extensions/common.py +++ b/src/a2a/extensions/common.py @@ -4,6 +4,7 @@ HTTP_EXTENSION_HEADER = 'X-A2A-Extensions' +GRPC_EXTENSION_HEADER = 'x-a2a-extensions' def get_requested_extensions(values: list[str]) -> set[str]: diff --git a/src/a2a/server/request_handlers/grpc_handler.py b/src/a2a/server/request_handlers/grpc_handler.py index 1280b92aa..b2a039fa1 100644 --- a/src/a2a/server/request_handlers/grpc_handler.py +++ b/src/a2a/server/request_handlers/grpc_handler.py @@ -23,7 +23,7 @@ from a2a import types from a2a.auth.user import UnauthenticatedUser from a2a.extensions.common import ( - HTTP_EXTENSION_HEADER, + GRPC_EXTENSION_HEADER, get_requested_extensions, ) from a2a.grpc import a2a_pb2 @@ -76,7 +76,7 @@ def build(self, context: grpc.aio.ServicerContext) -> ServerCallContext: user=user, state=state, requested_extensions=get_requested_extensions( - _get_metadata_value(context, HTTP_EXTENSION_HEADER) + _get_metadata_value(context, GRPC_EXTENSION_HEADER) ), ) @@ -417,7 +417,7 @@ def _set_extension_metadata( if server_context.activated_extensions: context.set_trailing_metadata( [ - (HTTP_EXTENSION_HEADER.lower(), e) + (GRPC_EXTENSION_HEADER.lower(), e) for e in sorted(server_context.activated_extensions) ] ) diff --git a/tests/server/request_handlers/test_grpc_handler.py b/tests/server/request_handlers/test_grpc_handler.py index 9d8da2bb4..6c9828079 100644 --- a/tests/server/request_handlers/test_grpc_handler.py +++ b/tests/server/request_handlers/test_grpc_handler.py @@ -5,7 +5,7 @@ import pytest from a2a import types -from a2a.extensions.common import HTTP_EXTENSION_HEADER +from a2a.extensions.common import GRPC_EXTENSION_HEADER from a2a.grpc import a2a_pb2 from a2a.server.context import ServerCallContext from a2a.server.request_handlers import GrpcHandler, RequestHandler @@ -350,8 +350,8 @@ async def test_send_message_with_extensions( mock_grpc_context: AsyncMock, ) -> None: mock_grpc_context.invocation_metadata.return_value = grpc.aio.Metadata( - (HTTP_EXTENSION_HEADER.lower(), 'foo'), - (HTTP_EXTENSION_HEADER.lower(), 'bar'), + (GRPC_EXTENSION_HEADER.lower(), 'foo'), + (GRPC_EXTENSION_HEADER.lower(), 'bar'), ) def side_effect(request, context: ServerCallContext): @@ -379,8 +379,8 @@ def side_effect(request, context: ServerCallContext): mock_grpc_context.set_trailing_metadata.call_args.args[0] ) assert set(called_metadata) == { - (HTTP_EXTENSION_HEADER.lower(), 'foo'), - (HTTP_EXTENSION_HEADER.lower(), 'baz'), + (GRPC_EXTENSION_HEADER.lower(), 'foo'), + (GRPC_EXTENSION_HEADER.lower(), 'baz'), } async def test_send_message_with_comma_separated_extensions( @@ -390,8 +390,8 @@ async def test_send_message_with_comma_separated_extensions( mock_grpc_context: AsyncMock, ) -> None: mock_grpc_context.invocation_metadata.return_value = grpc.aio.Metadata( - (HTTP_EXTENSION_HEADER.lower(), 'foo ,, bar,'), - (HTTP_EXTENSION_HEADER.lower(), 'baz , bar'), + (GRPC_EXTENSION_HEADER.lower(), 'foo ,, bar,'), + (GRPC_EXTENSION_HEADER.lower(), 'baz , bar'), ) mock_request_handler.on_message_send.return_value = types.Message( message_id='1', @@ -415,8 +415,8 @@ async def test_send_streaming_message_with_extensions( mock_grpc_context: AsyncMock, ) -> None: mock_grpc_context.invocation_metadata.return_value = grpc.aio.Metadata( - (HTTP_EXTENSION_HEADER.lower(), 'foo'), - (HTTP_EXTENSION_HEADER.lower(), 'bar'), + (GRPC_EXTENSION_HEADER.lower(), 'foo'), + (GRPC_EXTENSION_HEADER.lower(), 'bar'), ) async def side_effect(request, context: ServerCallContext): @@ -450,6 +450,6 @@ async def side_effect(request, context: ServerCallContext): mock_grpc_context.set_trailing_metadata.call_args.args[0] ) assert set(called_metadata) == { - (HTTP_EXTENSION_HEADER.lower(), 'foo'), - (HTTP_EXTENSION_HEADER.lower(), 'baz'), + (GRPC_EXTENSION_HEADER.lower(), 'foo'), + (GRPC_EXTENSION_HEADER.lower(), 'baz'), }