From be3c659184f8a6a46bd44af2b76e317c6b1c59f7 Mon Sep 17 00:00:00 2001 From: Ivan Shymko Date: Fri, 17 Apr 2026 09:57:27 +0000 Subject: [PATCH 1/3] test: add extension propagation test in test_end_to_end.py --- tests/integration/test_end_to_end.py | 83 +++++++++++++++++++++++++++- 1 file changed, 81 insertions(+), 2 deletions(-) diff --git a/tests/integration/test_end_to_end.py b/tests/integration/test_end_to_end.py index 58dce528d..7df5e0256 100644 --- a/tests/integration/test_end_to_end.py +++ b/tests/integration/test_end_to_end.py @@ -5,15 +5,21 @@ import httpx import pytest import pytest_asyncio + +from google.protobuf.struct_pb2 import Struct, Value from starlette.applications import Starlette from a2a.client.base_client import BaseClient -from a2a.client.client import ClientConfig +from a2a.client.client import ClientCallContext, ClientConfig from a2a.client.client_factory import ClientFactory +from a2a.client.service_parameters import ( + ServiceParametersFactory, + with_a2a_extensions, +) from a2a.server.agent_execution import AgentExecutor, RequestContext from a2a.server.events import EventQueue from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager -from a2a.server.request_handlers import GrpcHandler, DefaultRequestHandler +from a2a.server.request_handlers import DefaultRequestHandler, GrpcHandler from a2a.server.routes import create_agent_card_routes, create_jsonrpc_routes from a2a.server.routes.rest_routes import create_rest_routes from a2a.server.tasks import TaskUpdater @@ -87,6 +93,29 @@ class MockAgentExecutor(AgentExecutor): async def execute(self, context: RequestContext, event_queue: EventQueue): user_input = context.get_user_input() + # Extensions echo: report requested extensions and activate them. + if user_input.startswith('Extensions:'): + requested = sorted(context.requested_extensions) + for ext_uri in requested: + context.add_activated_extension(ext_uri) + activated = sorted(context.call_context.activated_extensions) + + payload = Struct() + payload.update( + { + 'requested_extensions': requested, + 'activated_extensions': activated, + } + ) + await event_queue.enqueue_event( + Message( + role=Role.ROLE_AGENT, + message_id='ext-reply-1', + parts=[Part(data=Value(struct_value=payload))], + ) + ) + return + # Direct message response (no task created). if user_input.startswith('Message:'): await event_queue.enqueue_event( @@ -757,3 +786,53 @@ async def test_end_to_end_direct_message_return_immediately(transport_setups): Role.ROLE_AGENT, 'Direct reply to: Message: Quick question', ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'streaming', + [ + pytest.param(False, id='blocking'), + pytest.param(True, id='streaming'), + ], +) +async def test_end_to_end_extensions_propagation(transport_setups, streaming): + """Test that extensions sent by the client reach the agent executor.""" + client = transport_setups.client + client._config.streaming = streaming + + extensions = [ + 'https://example.com/ext/v1', + 'https://example.com/ext/v2', + ] + service_params = ServiceParametersFactory.create( + [with_a2a_extensions(extensions)] + ) + context = ClientCallContext(service_parameters=service_params) + + message_to_send = Message( + role=Role.ROLE_USER, + message_id='msg-ext-propagation', + parts=[Part(text='Extensions: echo')], + ) + + events = [ + event + async for event in client.send_message( + request=SendMessageRequest(message=message_to_send), + context=context, + ) + ] + + assert len(events) == 1 + response = events[0] + assert response.HasField('message') + part = response.message.parts[0] + assert part.HasField('data') + + payload = part.data.struct_value + requested = set(payload['requested_extensions']) + activated = set(payload['activated_extensions']) + + assert requested == set(extensions) + assert activated == set(extensions) From 88e93dc4ff2d5f4eb32927db69c038dfa3edc920 Mon Sep 17 00:00:00 2001 From: Ivan Shymko Date: Fri, 17 Apr 2026 10:06:47 +0000 Subject: [PATCH 2/3] Update --- tests/integration/test_end_to_end.py | 34 +++++++++------------------- 1 file changed, 11 insertions(+), 23 deletions(-) diff --git a/tests/integration/test_end_to_end.py b/tests/integration/test_end_to_end.py index 7df5e0256..ac136df74 100644 --- a/tests/integration/test_end_to_end.py +++ b/tests/integration/test_end_to_end.py @@ -6,7 +6,6 @@ import pytest import pytest_asyncio -from google.protobuf.struct_pb2 import Struct, Value from starlette.applications import Starlette from a2a.client.base_client import BaseClient @@ -93,25 +92,19 @@ class MockAgentExecutor(AgentExecutor): async def execute(self, context: RequestContext, event_queue: EventQueue): user_input = context.get_user_input() - # Extensions echo: report requested extensions and activate them. + # Extensions echo: activate all requested extensions and report them + # back via the Message.extensions field. if user_input.startswith('Extensions:'): - requested = sorted(context.requested_extensions) - for ext_uri in requested: + for ext_uri in context.requested_extensions: context.add_activated_extension(ext_uri) - activated = sorted(context.call_context.activated_extensions) - - payload = Struct() - payload.update( - { - 'requested_extensions': requested, - 'activated_extensions': activated, - } - ) await event_queue.enqueue_event( Message( role=Role.ROLE_AGENT, message_id='ext-reply-1', - parts=[Part(data=Value(struct_value=payload))], + parts=[Part(text='extensions echoed')], + extensions=sorted( + context.call_context.activated_extensions + ), ) ) return @@ -827,12 +820,7 @@ async def test_end_to_end_extensions_propagation(transport_setups, streaming): assert len(events) == 1 response = events[0] assert response.HasField('message') - part = response.message.parts[0] - assert part.HasField('data') - - payload = part.data.struct_value - requested = set(payload['requested_extensions']) - activated = set(payload['activated_extensions']) - - assert requested == set(extensions) - assert activated == set(extensions) + assert_message_matches( + response.message, Role.ROLE_AGENT, 'extensions echoed' + ) + assert set(response.message.extensions) == set(extensions) From 597141b44b4207a8064e88fa3f0dfffb08fca04c Mon Sep 17 00:00:00 2001 From: Ivan Shymko Date: Fri, 17 Apr 2026 10:10:04 +0000 Subject: [PATCH 3/3] Update --- tests/integration/test_end_to_end.py | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/tests/integration/test_end_to_end.py b/tests/integration/test_end_to_end.py index ac136df74..aea9784ad 100644 --- a/tests/integration/test_end_to_end.py +++ b/tests/integration/test_end_to_end.py @@ -26,6 +26,7 @@ from a2a.types import ( AgentCapabilities, AgentCard, + AgentExtension, AgentInterface, CancelTaskRequest, DeleteTaskPushNotificationConfigRequest, @@ -46,6 +47,12 @@ from a2a.utils.errors import InvalidParamsError +SUPPORTED_EXTENSION_URIS = [ + 'https://example.com/ext/v1', + 'https://example.com/ext/v2', +] + + def assert_message_matches(message, expected_role, expected_text): assert message.role == expected_role assert message.parts[0].text == expected_text @@ -164,7 +171,15 @@ def agent_card() -> AgentCard: description='Real in-memory integration testing.', version='1.0.0', capabilities=AgentCapabilities( - streaming=True, push_notifications=False + streaming=True, + push_notifications=False, + extensions=[ + AgentExtension( + uri=uri, + description=f'Test extension {uri}', + ) + for uri in SUPPORTED_EXTENSION_URIS + ], ), skills=[], default_input_modes=['text/plain'], @@ -794,12 +809,8 @@ async def test_end_to_end_extensions_propagation(transport_setups, streaming): client = transport_setups.client client._config.streaming = streaming - extensions = [ - 'https://example.com/ext/v1', - 'https://example.com/ext/v2', - ] service_params = ServiceParametersFactory.create( - [with_a2a_extensions(extensions)] + [with_a2a_extensions(SUPPORTED_EXTENSION_URIS)] ) context = ClientCallContext(service_parameters=service_params) @@ -823,4 +834,4 @@ async def test_end_to_end_extensions_propagation(transport_setups, streaming): assert_message_matches( response.message, Role.ROLE_AGENT, 'extensions echoed' ) - assert set(response.message.extensions) == set(extensions) + assert set(response.message.extensions) == set(SUPPORTED_EXTENSION_URIS)