Skip to content

Commit b4e7671

Browse files
Merge branch '1.0-dev' into guglielmoc/refactor_utils_and_helpers
2 parents 12baed7 + 2846be6 commit b4e7671

1 file changed

Lines changed: 81 additions & 3 deletions

File tree

tests/integration/test_end_to_end.py

Lines changed: 81 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,22 +5,28 @@
55
import httpx
66
import pytest
77
import pytest_asyncio
8+
89
from starlette.applications import Starlette
910

1011
from a2a.client.base_client import BaseClient
11-
from a2a.client.client import ClientConfig
12+
from a2a.client.client import ClientCallContext, ClientConfig
1213
from a2a.client.client_factory import ClientFactory
14+
from a2a.client.service_parameters import (
15+
ServiceParametersFactory,
16+
with_a2a_extensions,
17+
)
1318
from a2a.server.agent_execution import AgentExecutor, RequestContext
1419
from a2a.server.events import EventQueue
1520
from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager
16-
from a2a.server.request_handlers import GrpcHandler, DefaultRequestHandler
21+
from a2a.server.request_handlers import DefaultRequestHandler, GrpcHandler
1722
from a2a.server.routes import create_agent_card_routes, create_jsonrpc_routes
1823
from a2a.server.routes.rest_routes import create_rest_routes
1924
from a2a.server.tasks import TaskUpdater
2025
from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore
2126
from a2a.types import (
2227
AgentCapabilities,
2328
AgentCard,
29+
AgentExtension,
2430
AgentInterface,
2531
CancelTaskRequest,
2632
DeleteTaskPushNotificationConfigRequest,
@@ -42,6 +48,12 @@
4248
from a2a.utils.errors import InvalidParamsError
4349

4450

51+
SUPPORTED_EXTENSION_URIS = [
52+
'https://example.com/ext/v1',
53+
'https://example.com/ext/v2',
54+
]
55+
56+
4557
def assert_message_matches(message, expected_role, expected_text):
4658
assert message.role == expected_role
4759
assert message.parts[0].text == expected_text
@@ -88,6 +100,23 @@ class MockAgentExecutor(AgentExecutor):
88100
async def execute(self, context: RequestContext, event_queue: EventQueue):
89101
user_input = context.get_user_input()
90102

103+
# Extensions echo: activate all requested extensions and report them
104+
# back via the Message.extensions field.
105+
if user_input.startswith('Extensions:'):
106+
for ext_uri in context.requested_extensions:
107+
context.add_activated_extension(ext_uri)
108+
await event_queue.enqueue_event(
109+
Message(
110+
role=Role.ROLE_AGENT,
111+
message_id='ext-reply-1',
112+
parts=[Part(text='extensions echoed')],
113+
extensions=sorted(
114+
context.call_context.activated_extensions
115+
),
116+
)
117+
)
118+
return
119+
91120
# Direct message response (no task created).
92121
if user_input.startswith('Message:'):
93122
await event_queue.enqueue_event(
@@ -143,7 +172,15 @@ def agent_card() -> AgentCard:
143172
description='Real in-memory integration testing.',
144173
version='1.0.0',
145174
capabilities=AgentCapabilities(
146-
streaming=True, push_notifications=False
175+
streaming=True,
176+
push_notifications=False,
177+
extensions=[
178+
AgentExtension(
179+
uri=uri,
180+
description=f'Test extension {uri}',
181+
)
182+
for uri in SUPPORTED_EXTENSION_URIS
183+
],
147184
),
148185
skills=[],
149186
default_input_modes=['text/plain'],
@@ -758,3 +795,44 @@ async def test_end_to_end_direct_message_return_immediately(transport_setups):
758795
Role.ROLE_AGENT,
759796
'Direct reply to: Message: Quick question',
760797
)
798+
799+
800+
@pytest.mark.asyncio
801+
@pytest.mark.parametrize(
802+
'streaming',
803+
[
804+
pytest.param(False, id='blocking'),
805+
pytest.param(True, id='streaming'),
806+
],
807+
)
808+
async def test_end_to_end_extensions_propagation(transport_setups, streaming):
809+
"""Test that extensions sent by the client reach the agent executor."""
810+
client = transport_setups.client
811+
client._config.streaming = streaming
812+
813+
service_params = ServiceParametersFactory.create(
814+
[with_a2a_extensions(SUPPORTED_EXTENSION_URIS)]
815+
)
816+
context = ClientCallContext(service_parameters=service_params)
817+
818+
message_to_send = Message(
819+
role=Role.ROLE_USER,
820+
message_id='msg-ext-propagation',
821+
parts=[Part(text='Extensions: echo')],
822+
)
823+
824+
events = [
825+
event
826+
async for event in client.send_message(
827+
request=SendMessageRequest(message=message_to_send),
828+
context=context,
829+
)
830+
]
831+
832+
assert len(events) == 1
833+
response = events[0]
834+
assert response.HasField('message')
835+
assert_message_matches(
836+
response.message, Role.ROLE_AGENT, 'extensions echoed'
837+
)
838+
assert set(response.message.extensions) == set(SUPPORTED_EXTENSION_URIS)

0 commit comments

Comments
 (0)