|
5 | 5 | import httpx |
6 | 6 | import pytest |
7 | 7 | import pytest_asyncio |
| 8 | + |
| 9 | +from google.protobuf.struct_pb2 import Struct, Value |
8 | 10 | from starlette.applications import Starlette |
9 | 11 |
|
10 | 12 | from a2a.client.base_client import BaseClient |
11 | | -from a2a.client.client import ClientConfig |
| 13 | +from a2a.client.client import ClientCallContext, ClientConfig |
12 | 14 | from a2a.client.client_factory import ClientFactory |
| 15 | +from a2a.client.service_parameters import ( |
| 16 | + ServiceParametersFactory, |
| 17 | + with_a2a_extensions, |
| 18 | +) |
13 | 19 | from a2a.server.agent_execution import AgentExecutor, RequestContext |
14 | 20 | from a2a.server.events import EventQueue |
15 | 21 | from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager |
16 | | -from a2a.server.request_handlers import GrpcHandler, DefaultRequestHandler |
| 22 | +from a2a.server.request_handlers import DefaultRequestHandler, GrpcHandler |
17 | 23 | from a2a.server.routes import create_agent_card_routes, create_jsonrpc_routes |
18 | 24 | from a2a.server.routes.rest_routes import create_rest_routes |
19 | 25 | from a2a.server.tasks import TaskUpdater |
@@ -87,6 +93,29 @@ class MockAgentExecutor(AgentExecutor): |
87 | 93 | async def execute(self, context: RequestContext, event_queue: EventQueue): |
88 | 94 | user_input = context.get_user_input() |
89 | 95 |
|
| 96 | + # Extensions echo: report requested extensions and activate them. |
| 97 | + if user_input.startswith('Extensions:'): |
| 98 | + requested = sorted(context.requested_extensions) |
| 99 | + for ext_uri in requested: |
| 100 | + context.add_activated_extension(ext_uri) |
| 101 | + activated = sorted(context.call_context.activated_extensions) |
| 102 | + |
| 103 | + payload = Struct() |
| 104 | + payload.update( |
| 105 | + { |
| 106 | + 'requested_extensions': requested, |
| 107 | + 'activated_extensions': activated, |
| 108 | + } |
| 109 | + ) |
| 110 | + await event_queue.enqueue_event( |
| 111 | + Message( |
| 112 | + role=Role.ROLE_AGENT, |
| 113 | + message_id='ext-reply-1', |
| 114 | + parts=[Part(data=Value(struct_value=payload))], |
| 115 | + ) |
| 116 | + ) |
| 117 | + return |
| 118 | + |
90 | 119 | # Direct message response (no task created). |
91 | 120 | if user_input.startswith('Message:'): |
92 | 121 | await event_queue.enqueue_event( |
@@ -757,3 +786,53 @@ async def test_end_to_end_direct_message_return_immediately(transport_setups): |
757 | 786 | Role.ROLE_AGENT, |
758 | 787 | 'Direct reply to: Message: Quick question', |
759 | 788 | ) |
| 789 | + |
| 790 | + |
| 791 | +@pytest.mark.asyncio |
| 792 | +@pytest.mark.parametrize( |
| 793 | + 'streaming', |
| 794 | + [ |
| 795 | + pytest.param(False, id='blocking'), |
| 796 | + pytest.param(True, id='streaming'), |
| 797 | + ], |
| 798 | +) |
| 799 | +async def test_end_to_end_extensions_propagation(transport_setups, streaming): |
| 800 | + """Test that extensions sent by the client reach the agent executor.""" |
| 801 | + client = transport_setups.client |
| 802 | + client._config.streaming = streaming |
| 803 | + |
| 804 | + extensions = [ |
| 805 | + 'https://example.com/ext/v1', |
| 806 | + 'https://example.com/ext/v2', |
| 807 | + ] |
| 808 | + service_params = ServiceParametersFactory.create( |
| 809 | + [with_a2a_extensions(extensions)] |
| 810 | + ) |
| 811 | + context = ClientCallContext(service_parameters=service_params) |
| 812 | + |
| 813 | + message_to_send = Message( |
| 814 | + role=Role.ROLE_USER, |
| 815 | + message_id='msg-ext-propagation', |
| 816 | + parts=[Part(text='Extensions: echo')], |
| 817 | + ) |
| 818 | + |
| 819 | + events = [ |
| 820 | + event |
| 821 | + async for event in client.send_message( |
| 822 | + request=SendMessageRequest(message=message_to_send), |
| 823 | + context=context, |
| 824 | + ) |
| 825 | + ] |
| 826 | + |
| 827 | + assert len(events) == 1 |
| 828 | + response = events[0] |
| 829 | + assert response.HasField('message') |
| 830 | + part = response.message.parts[0] |
| 831 | + assert part.HasField('data') |
| 832 | + |
| 833 | + payload = part.data.struct_value |
| 834 | + requested = set(payload['requested_extensions']) |
| 835 | + activated = set(payload['activated_extensions']) |
| 836 | + |
| 837 | + assert requested == set(extensions) |
| 838 | + assert activated == set(extensions) |
0 commit comments