|
6 | 6 | import pytest |
7 | 7 | import pytest_asyncio |
8 | 8 |
|
9 | | -from google.protobuf.struct_pb2 import Struct, Value |
10 | 9 | from starlette.applications import Starlette |
11 | 10 |
|
12 | 11 | from a2a.client.base_client import BaseClient |
@@ -93,25 +92,19 @@ class MockAgentExecutor(AgentExecutor): |
93 | 92 | async def execute(self, context: RequestContext, event_queue: EventQueue): |
94 | 93 | user_input = context.get_user_input() |
95 | 94 |
|
96 | | - # Extensions echo: report requested extensions and activate them. |
| 95 | + # Extensions echo: activate all requested extensions and report them |
| 96 | + # back via the Message.extensions field. |
97 | 97 | if user_input.startswith('Extensions:'): |
98 | | - requested = sorted(context.requested_extensions) |
99 | | - for ext_uri in requested: |
| 98 | + for ext_uri in context.requested_extensions: |
100 | 99 | context.add_activated_extension(ext_uri) |
101 | | - activated = sorted(context.call_context.activated_extensions) |
102 | | - |
103 | | - payload = Struct() |
104 | | - payload.update( |
105 | | - { |
106 | | - 'requested_extensions': requested, |
107 | | - 'activated_extensions': activated, |
108 | | - } |
109 | | - ) |
110 | 100 | await event_queue.enqueue_event( |
111 | 101 | Message( |
112 | 102 | role=Role.ROLE_AGENT, |
113 | 103 | message_id='ext-reply-1', |
114 | | - parts=[Part(data=Value(struct_value=payload))], |
| 104 | + parts=[Part(text='extensions echoed')], |
| 105 | + extensions=sorted( |
| 106 | + context.call_context.activated_extensions |
| 107 | + ), |
115 | 108 | ) |
116 | 109 | ) |
117 | 110 | return |
@@ -827,12 +820,7 @@ async def test_end_to_end_extensions_propagation(transport_setups, streaming): |
827 | 820 | assert len(events) == 1 |
828 | 821 | response = events[0] |
829 | 822 | assert response.HasField('message') |
830 | | - part = response.message.parts[0] |
831 | | - assert part.HasField('data') |
832 | | - |
833 | | - payload = part.data.struct_value |
834 | | - requested = set(payload['requested_extensions']) |
835 | | - activated = set(payload['activated_extensions']) |
836 | | - |
837 | | - assert requested == set(extensions) |
838 | | - assert activated == set(extensions) |
| 823 | + assert_message_matches( |
| 824 | + response.message, Role.ROLE_AGENT, 'extensions echoed' |
| 825 | + ) |
| 826 | + assert set(response.message.extensions) == set(extensions) |
0 commit comments