|
5 | 5 | import httpx |
6 | 6 | import pytest |
7 | 7 | import pytest_asyncio |
| 8 | + |
8 | 9 | from starlette.applications import Starlette |
9 | 10 |
|
10 | 11 | from a2a.client.base_client import BaseClient |
11 | | -from a2a.client.client import ClientConfig |
| 12 | +from a2a.client.client import ClientCallContext, ClientConfig |
12 | 13 | from a2a.client.client_factory import ClientFactory |
| 14 | +from a2a.client.service_parameters import ( |
| 15 | + ServiceParametersFactory, |
| 16 | + with_a2a_extensions, |
| 17 | +) |
13 | 18 | from a2a.server.agent_execution import AgentExecutor, RequestContext |
14 | 19 | from a2a.server.events import EventQueue |
15 | 20 | from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager |
16 | | -from a2a.server.request_handlers import GrpcHandler, DefaultRequestHandler |
| 21 | +from a2a.server.request_handlers import DefaultRequestHandler, GrpcHandler |
17 | 22 | from a2a.server.routes import create_agent_card_routes, create_jsonrpc_routes |
18 | 23 | from a2a.server.routes.rest_routes import create_rest_routes |
19 | 24 | from a2a.server.tasks import TaskUpdater |
20 | 25 | from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore |
21 | 26 | from a2a.types import ( |
22 | 27 | AgentCapabilities, |
23 | 28 | AgentCard, |
| 29 | + AgentExtension, |
24 | 30 | AgentInterface, |
25 | 31 | CancelTaskRequest, |
26 | 32 | DeleteTaskPushNotificationConfigRequest, |
|
41 | 47 | from a2a.utils.errors import InvalidParamsError |
42 | 48 |
|
43 | 49 |
|
| 50 | +SUPPORTED_EXTENSION_URIS = [ |
| 51 | + 'https://example.com/ext/v1', |
| 52 | + 'https://example.com/ext/v2', |
| 53 | +] |
| 54 | + |
| 55 | + |
44 | 56 | def assert_message_matches(message, expected_role, expected_text): |
45 | 57 | assert message.role == expected_role |
46 | 58 | assert message.parts[0].text == expected_text |
@@ -87,6 +99,23 @@ class MockAgentExecutor(AgentExecutor): |
87 | 99 | async def execute(self, context: RequestContext, event_queue: EventQueue): |
88 | 100 | user_input = context.get_user_input() |
89 | 101 |
|
| 102 | + # Extensions echo: activate all requested extensions and report them |
| 103 | + # back via the Message.extensions field. |
| 104 | + if user_input.startswith('Extensions:'): |
| 105 | + for ext_uri in context.requested_extensions: |
| 106 | + context.add_activated_extension(ext_uri) |
| 107 | + await event_queue.enqueue_event( |
| 108 | + Message( |
| 109 | + role=Role.ROLE_AGENT, |
| 110 | + message_id='ext-reply-1', |
| 111 | + parts=[Part(text='extensions echoed')], |
| 112 | + extensions=sorted( |
| 113 | + context.call_context.activated_extensions |
| 114 | + ), |
| 115 | + ) |
| 116 | + ) |
| 117 | + return |
| 118 | + |
90 | 119 | # Direct message response (no task created). |
91 | 120 | if user_input.startswith('Message:'): |
92 | 121 | await event_queue.enqueue_event( |
@@ -142,7 +171,15 @@ def agent_card() -> AgentCard: |
142 | 171 | description='Real in-memory integration testing.', |
143 | 172 | version='1.0.0', |
144 | 173 | capabilities=AgentCapabilities( |
145 | | - streaming=True, push_notifications=False |
| 174 | + streaming=True, |
| 175 | + push_notifications=False, |
| 176 | + extensions=[ |
| 177 | + AgentExtension( |
| 178 | + uri=uri, |
| 179 | + description=f'Test extension {uri}', |
| 180 | + ) |
| 181 | + for uri in SUPPORTED_EXTENSION_URIS |
| 182 | + ], |
146 | 183 | ), |
147 | 184 | skills=[], |
148 | 185 | default_input_modes=['text/plain'], |
@@ -757,3 +794,44 @@ async def test_end_to_end_direct_message_return_immediately(transport_setups): |
757 | 794 | Role.ROLE_AGENT, |
758 | 795 | 'Direct reply to: Message: Quick question', |
759 | 796 | ) |
| 797 | + |
| 798 | + |
| 799 | +@pytest.mark.asyncio |
| 800 | +@pytest.mark.parametrize( |
| 801 | + 'streaming', |
| 802 | + [ |
| 803 | + pytest.param(False, id='blocking'), |
| 804 | + pytest.param(True, id='streaming'), |
| 805 | + ], |
| 806 | +) |
| 807 | +async def test_end_to_end_extensions_propagation(transport_setups, streaming): |
| 808 | + """Test that extensions sent by the client reach the agent executor.""" |
| 809 | + client = transport_setups.client |
| 810 | + client._config.streaming = streaming |
| 811 | + |
| 812 | + service_params = ServiceParametersFactory.create( |
| 813 | + [with_a2a_extensions(SUPPORTED_EXTENSION_URIS)] |
| 814 | + ) |
| 815 | + context = ClientCallContext(service_parameters=service_params) |
| 816 | + |
| 817 | + message_to_send = Message( |
| 818 | + role=Role.ROLE_USER, |
| 819 | + message_id='msg-ext-propagation', |
| 820 | + parts=[Part(text='Extensions: echo')], |
| 821 | + ) |
| 822 | + |
| 823 | + events = [ |
| 824 | + event |
| 825 | + async for event in client.send_message( |
| 826 | + request=SendMessageRequest(message=message_to_send), |
| 827 | + context=context, |
| 828 | + ) |
| 829 | + ] |
| 830 | + |
| 831 | + assert len(events) == 1 |
| 832 | + response = events[0] |
| 833 | + assert response.HasField('message') |
| 834 | + assert_message_matches( |
| 835 | + response.message, Role.ROLE_AGENT, 'extensions echoed' |
| 836 | + ) |
| 837 | + assert set(response.message.extensions) == set(SUPPORTED_EXTENSION_URIS) |
0 commit comments