diff --git a/tests/integration/test_end_to_end.py b/tests/integration/test_end_to_end.py index 58dce528d..aea9784ad 100644 --- a/tests/integration/test_end_to_end.py +++ b/tests/integration/test_end_to_end.py @@ -5,15 +5,20 @@ import httpx import pytest import pytest_asyncio + from starlette.applications import Starlette from a2a.client.base_client import BaseClient -from a2a.client.client import ClientConfig +from a2a.client.client import ClientCallContext, ClientConfig from a2a.client.client_factory import ClientFactory +from a2a.client.service_parameters import ( + ServiceParametersFactory, + with_a2a_extensions, +) from a2a.server.agent_execution import AgentExecutor, RequestContext from a2a.server.events import EventQueue from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager -from a2a.server.request_handlers import GrpcHandler, DefaultRequestHandler +from a2a.server.request_handlers import DefaultRequestHandler, GrpcHandler from a2a.server.routes import create_agent_card_routes, create_jsonrpc_routes from a2a.server.routes.rest_routes import create_rest_routes from a2a.server.tasks import TaskUpdater @@ -21,6 +26,7 @@ from a2a.types import ( AgentCapabilities, AgentCard, + AgentExtension, AgentInterface, CancelTaskRequest, DeleteTaskPushNotificationConfigRequest, @@ -41,6 +47,12 @@ from a2a.utils.errors import InvalidParamsError +SUPPORTED_EXTENSION_URIS = [ + 'https://example.com/ext/v1', + 'https://example.com/ext/v2', +] + + def assert_message_matches(message, expected_role, expected_text): assert message.role == expected_role assert message.parts[0].text == expected_text @@ -87,6 +99,23 @@ class MockAgentExecutor(AgentExecutor): async def execute(self, context: RequestContext, event_queue: EventQueue): user_input = context.get_user_input() + # Extensions echo: activate all requested extensions and report them + # back via the Message.extensions field. + if user_input.startswith('Extensions:'): + for ext_uri in context.requested_extensions: + context.add_activated_extension(ext_uri) + await event_queue.enqueue_event( + Message( + role=Role.ROLE_AGENT, + message_id='ext-reply-1', + parts=[Part(text='extensions echoed')], + extensions=sorted( + context.call_context.activated_extensions + ), + ) + ) + return + # Direct message response (no task created). if user_input.startswith('Message:'): await event_queue.enqueue_event( @@ -142,7 +171,15 @@ def agent_card() -> AgentCard: description='Real in-memory integration testing.', version='1.0.0', capabilities=AgentCapabilities( - streaming=True, push_notifications=False + streaming=True, + push_notifications=False, + extensions=[ + AgentExtension( + uri=uri, + description=f'Test extension {uri}', + ) + for uri in SUPPORTED_EXTENSION_URIS + ], ), skills=[], default_input_modes=['text/plain'], @@ -757,3 +794,44 @@ async def test_end_to_end_direct_message_return_immediately(transport_setups): Role.ROLE_AGENT, 'Direct reply to: Message: Quick question', ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'streaming', + [ + pytest.param(False, id='blocking'), + pytest.param(True, id='streaming'), + ], +) +async def test_end_to_end_extensions_propagation(transport_setups, streaming): + """Test that extensions sent by the client reach the agent executor.""" + client = transport_setups.client + client._config.streaming = streaming + + service_params = ServiceParametersFactory.create( + [with_a2a_extensions(SUPPORTED_EXTENSION_URIS)] + ) + context = ClientCallContext(service_parameters=service_params) + + message_to_send = Message( + role=Role.ROLE_USER, + message_id='msg-ext-propagation', + parts=[Part(text='Extensions: echo')], + ) + + events = [ + event + async for event in client.send_message( + request=SendMessageRequest(message=message_to_send), + context=context, + ) + ] + + assert len(events) == 1 + response = events[0] + assert response.HasField('message') + assert_message_matches( + response.message, Role.ROLE_AGENT, 'extensions echoed' + ) + assert set(response.message.extensions) == set(SUPPORTED_EXTENSION_URIS)