Skip to content

Commit 2846be6

Browse files
authored
test: add extension propagation test in test_end_to_end.py (#981)
1 parent f922ff6 commit 2846be6

1 file changed

Lines changed: 81 additions & 3 deletions

File tree

tests/integration/test_end_to_end.py

Lines changed: 81 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,22 +5,28 @@
55
import httpx
66
import pytest
77
import pytest_asyncio
8+
89
from starlette.applications import Starlette
910

1011
from a2a.client.base_client import BaseClient
11-
from a2a.client.client import ClientConfig
12+
from a2a.client.client import ClientCallContext, ClientConfig
1213
from a2a.client.client_factory import ClientFactory
14+
from a2a.client.service_parameters import (
15+
ServiceParametersFactory,
16+
with_a2a_extensions,
17+
)
1318
from a2a.server.agent_execution import AgentExecutor, RequestContext
1419
from a2a.server.events import EventQueue
1520
from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager
16-
from a2a.server.request_handlers import GrpcHandler, DefaultRequestHandler
21+
from a2a.server.request_handlers import DefaultRequestHandler, GrpcHandler
1722
from a2a.server.routes import create_agent_card_routes, create_jsonrpc_routes
1823
from a2a.server.routes.rest_routes import create_rest_routes
1924
from a2a.server.tasks import TaskUpdater
2025
from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore
2126
from a2a.types import (
2227
AgentCapabilities,
2328
AgentCard,
29+
AgentExtension,
2430
AgentInterface,
2531
CancelTaskRequest,
2632
DeleteTaskPushNotificationConfigRequest,
@@ -41,6 +47,12 @@
4147
from a2a.utils.errors import InvalidParamsError
4248

4349

50+
SUPPORTED_EXTENSION_URIS = [
51+
'https://example.com/ext/v1',
52+
'https://example.com/ext/v2',
53+
]
54+
55+
4456
def assert_message_matches(message, expected_role, expected_text):
4557
assert message.role == expected_role
4658
assert message.parts[0].text == expected_text
@@ -87,6 +99,23 @@ class MockAgentExecutor(AgentExecutor):
8799
async def execute(self, context: RequestContext, event_queue: EventQueue):
88100
user_input = context.get_user_input()
89101

102+
# Extensions echo: activate all requested extensions and report them
103+
# back via the Message.extensions field.
104+
if user_input.startswith('Extensions:'):
105+
for ext_uri in context.requested_extensions:
106+
context.add_activated_extension(ext_uri)
107+
await event_queue.enqueue_event(
108+
Message(
109+
role=Role.ROLE_AGENT,
110+
message_id='ext-reply-1',
111+
parts=[Part(text='extensions echoed')],
112+
extensions=sorted(
113+
context.call_context.activated_extensions
114+
),
115+
)
116+
)
117+
return
118+
90119
# Direct message response (no task created).
91120
if user_input.startswith('Message:'):
92121
await event_queue.enqueue_event(
@@ -142,7 +171,15 @@ def agent_card() -> AgentCard:
142171
description='Real in-memory integration testing.',
143172
version='1.0.0',
144173
capabilities=AgentCapabilities(
145-
streaming=True, push_notifications=False
174+
streaming=True,
175+
push_notifications=False,
176+
extensions=[
177+
AgentExtension(
178+
uri=uri,
179+
description=f'Test extension {uri}',
180+
)
181+
for uri in SUPPORTED_EXTENSION_URIS
182+
],
146183
),
147184
skills=[],
148185
default_input_modes=['text/plain'],
@@ -757,3 +794,44 @@ async def test_end_to_end_direct_message_return_immediately(transport_setups):
757794
Role.ROLE_AGENT,
758795
'Direct reply to: Message: Quick question',
759796
)
797+
798+
799+
@pytest.mark.asyncio
800+
@pytest.mark.parametrize(
801+
'streaming',
802+
[
803+
pytest.param(False, id='blocking'),
804+
pytest.param(True, id='streaming'),
805+
],
806+
)
807+
async def test_end_to_end_extensions_propagation(transport_setups, streaming):
808+
"""Test that extensions sent by the client reach the agent executor."""
809+
client = transport_setups.client
810+
client._config.streaming = streaming
811+
812+
service_params = ServiceParametersFactory.create(
813+
[with_a2a_extensions(SUPPORTED_EXTENSION_URIS)]
814+
)
815+
context = ClientCallContext(service_parameters=service_params)
816+
817+
message_to_send = Message(
818+
role=Role.ROLE_USER,
819+
message_id='msg-ext-propagation',
820+
parts=[Part(text='Extensions: echo')],
821+
)
822+
823+
events = [
824+
event
825+
async for event in client.send_message(
826+
request=SendMessageRequest(message=message_to_send),
827+
context=context,
828+
)
829+
]
830+
831+
assert len(events) == 1
832+
response = events[0]
833+
assert response.HasField('message')
834+
assert_message_matches(
835+
response.message, Role.ROLE_AGENT, 'extensions echoed'
836+
)
837+
assert set(response.message.extensions) == set(SUPPORTED_EXTENSION_URIS)

0 commit comments

Comments
 (0)