Skip to content

Commit be3c659

Browse files
committed
test: add extension propagation test in test_end_to_end.py
1 parent 1863359 commit be3c659

1 file changed

Lines changed: 81 additions & 2 deletions

File tree

tests/integration/test_end_to_end.py

Lines changed: 81 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,21 @@
55
import httpx
66
import pytest
77
import pytest_asyncio
8+
9+
from google.protobuf.struct_pb2 import Struct, Value
810
from starlette.applications import Starlette
911

1012
from a2a.client.base_client import BaseClient
11-
from a2a.client.client import ClientConfig
13+
from a2a.client.client import ClientCallContext, ClientConfig
1214
from a2a.client.client_factory import ClientFactory
15+
from a2a.client.service_parameters import (
16+
ServiceParametersFactory,
17+
with_a2a_extensions,
18+
)
1319
from a2a.server.agent_execution import AgentExecutor, RequestContext
1420
from a2a.server.events import EventQueue
1521
from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager
16-
from a2a.server.request_handlers import GrpcHandler, DefaultRequestHandler
22+
from a2a.server.request_handlers import DefaultRequestHandler, GrpcHandler
1723
from a2a.server.routes import create_agent_card_routes, create_jsonrpc_routes
1824
from a2a.server.routes.rest_routes import create_rest_routes
1925
from a2a.server.tasks import TaskUpdater
@@ -87,6 +93,29 @@ class MockAgentExecutor(AgentExecutor):
8793
async def execute(self, context: RequestContext, event_queue: EventQueue):
8894
user_input = context.get_user_input()
8995

96+
# Extensions echo: report requested extensions and activate them.
97+
if user_input.startswith('Extensions:'):
98+
requested = sorted(context.requested_extensions)
99+
for ext_uri in requested:
100+
context.add_activated_extension(ext_uri)
101+
activated = sorted(context.call_context.activated_extensions)
102+
103+
payload = Struct()
104+
payload.update(
105+
{
106+
'requested_extensions': requested,
107+
'activated_extensions': activated,
108+
}
109+
)
110+
await event_queue.enqueue_event(
111+
Message(
112+
role=Role.ROLE_AGENT,
113+
message_id='ext-reply-1',
114+
parts=[Part(data=Value(struct_value=payload))],
115+
)
116+
)
117+
return
118+
90119
# Direct message response (no task created).
91120
if user_input.startswith('Message:'):
92121
await event_queue.enqueue_event(
@@ -757,3 +786,53 @@ async def test_end_to_end_direct_message_return_immediately(transport_setups):
757786
Role.ROLE_AGENT,
758787
'Direct reply to: Message: Quick question',
759788
)
789+
790+
791+
@pytest.mark.asyncio
792+
@pytest.mark.parametrize(
793+
'streaming',
794+
[
795+
pytest.param(False, id='blocking'),
796+
pytest.param(True, id='streaming'),
797+
],
798+
)
799+
async def test_end_to_end_extensions_propagation(transport_setups, streaming):
800+
"""Test that extensions sent by the client reach the agent executor."""
801+
client = transport_setups.client
802+
client._config.streaming = streaming
803+
804+
extensions = [
805+
'https://example.com/ext/v1',
806+
'https://example.com/ext/v2',
807+
]
808+
service_params = ServiceParametersFactory.create(
809+
[with_a2a_extensions(extensions)]
810+
)
811+
context = ClientCallContext(service_parameters=service_params)
812+
813+
message_to_send = Message(
814+
role=Role.ROLE_USER,
815+
message_id='msg-ext-propagation',
816+
parts=[Part(text='Extensions: echo')],
817+
)
818+
819+
events = [
820+
event
821+
async for event in client.send_message(
822+
request=SendMessageRequest(message=message_to_send),
823+
context=context,
824+
)
825+
]
826+
827+
assert len(events) == 1
828+
response = events[0]
829+
assert response.HasField('message')
830+
part = response.message.parts[0]
831+
assert part.HasField('data')
832+
833+
payload = part.data.struct_value
834+
requested = set(payload['requested_extensions'])
835+
activated = set(payload['activated_extensions'])
836+
837+
assert requested == set(extensions)
838+
assert activated == set(extensions)

0 commit comments

Comments
 (0)