Skip to content

Commit 88e93dc

Browse files
committed
Update
1 parent be3c659 commit 88e93dc

1 file changed

Lines changed: 11 additions & 23 deletions

File tree

tests/integration/test_end_to_end.py

Lines changed: 11 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import pytest
77
import pytest_asyncio
88

9-
from google.protobuf.struct_pb2 import Struct, Value
109
from starlette.applications import Starlette
1110

1211
from a2a.client.base_client import BaseClient
@@ -93,25 +92,19 @@ class MockAgentExecutor(AgentExecutor):
9392
async def execute(self, context: RequestContext, event_queue: EventQueue):
9493
user_input = context.get_user_input()
9594

96-
# Extensions echo: report requested extensions and activate them.
95+
# Extensions echo: activate all requested extensions and report them
96+
# back via the Message.extensions field.
9797
if user_input.startswith('Extensions:'):
98-
requested = sorted(context.requested_extensions)
99-
for ext_uri in requested:
98+
for ext_uri in context.requested_extensions:
10099
context.add_activated_extension(ext_uri)
101-
activated = sorted(context.call_context.activated_extensions)
102-
103-
payload = Struct()
104-
payload.update(
105-
{
106-
'requested_extensions': requested,
107-
'activated_extensions': activated,
108-
}
109-
)
110100
await event_queue.enqueue_event(
111101
Message(
112102
role=Role.ROLE_AGENT,
113103
message_id='ext-reply-1',
114-
parts=[Part(data=Value(struct_value=payload))],
104+
parts=[Part(text='extensions echoed')],
105+
extensions=sorted(
106+
context.call_context.activated_extensions
107+
),
115108
)
116109
)
117110
return
@@ -827,12 +820,7 @@ async def test_end_to_end_extensions_propagation(transport_setups, streaming):
827820
assert len(events) == 1
828821
response = events[0]
829822
assert response.HasField('message')
830-
part = response.message.parts[0]
831-
assert part.HasField('data')
832-
833-
payload = part.data.struct_value
834-
requested = set(payload['requested_extensions'])
835-
activated = set(payload['activated_extensions'])
836-
837-
assert requested == set(extensions)
838-
assert activated == set(extensions)
823+
assert_message_matches(
824+
response.message, Role.ROLE_AGENT, 'extensions echoed'
825+
)
826+
assert set(response.message.extensions) == set(extensions)

0 commit comments

Comments
 (0)