Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 81 additions & 3 deletions tests/integration/test_end_to_end.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,28 @@
import httpx
import pytest
import pytest_asyncio

from starlette.applications import Starlette

from a2a.client.base_client import BaseClient
from a2a.client.client import ClientConfig
from a2a.client.client import ClientCallContext, ClientConfig
from a2a.client.client_factory import ClientFactory
from a2a.client.service_parameters import (
ServiceParametersFactory,
with_a2a_extensions,
)
from a2a.server.agent_execution import AgentExecutor, RequestContext
from a2a.server.events import EventQueue
from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager
from a2a.server.request_handlers import GrpcHandler, DefaultRequestHandler
from a2a.server.request_handlers import DefaultRequestHandler, GrpcHandler
from a2a.server.routes import create_agent_card_routes, create_jsonrpc_routes
from a2a.server.routes.rest_routes import create_rest_routes
from a2a.server.tasks import TaskUpdater
from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore
from a2a.types import (
AgentCapabilities,
AgentCard,
AgentExtension,
AgentInterface,
CancelTaskRequest,
DeleteTaskPushNotificationConfigRequest,
Expand All @@ -41,6 +47,12 @@
from a2a.utils.errors import InvalidParamsError


SUPPORTED_EXTENSION_URIS = [
'https://example.com/ext/v1',
'https://example.com/ext/v2',
]


def assert_message_matches(message, expected_role, expected_text):
assert message.role == expected_role
assert message.parts[0].text == expected_text
Expand Down Expand Up @@ -87,6 +99,23 @@ class MockAgentExecutor(AgentExecutor):
async def execute(self, context: RequestContext, event_queue: EventQueue):
user_input = context.get_user_input()

# Extensions echo: activate all requested extensions and report them
# back via the Message.extensions field.
if user_input.startswith('Extensions:'):
for ext_uri in context.requested_extensions:
context.add_activated_extension(ext_uri)
await event_queue.enqueue_event(
Message(
role=Role.ROLE_AGENT,
message_id='ext-reply-1',
parts=[Part(text='extensions echoed')],
extensions=sorted(
context.call_context.activated_extensions
),
)
)
return

# Direct message response (no task created).
if user_input.startswith('Message:'):
await event_queue.enqueue_event(
Expand Down Expand Up @@ -142,7 +171,15 @@ def agent_card() -> AgentCard:
description='Real in-memory integration testing.',
version='1.0.0',
capabilities=AgentCapabilities(
streaming=True, push_notifications=False
streaming=True,
push_notifications=False,
extensions=[
AgentExtension(
uri=uri,
description=f'Test extension {uri}',
)
for uri in SUPPORTED_EXTENSION_URIS
],
),
skills=[],
default_input_modes=['text/plain'],
Expand Down Expand Up @@ -757,3 +794,44 @@ async def test_end_to_end_direct_message_return_immediately(transport_setups):
Role.ROLE_AGENT,
'Direct reply to: Message: Quick question',
)


@pytest.mark.asyncio
@pytest.mark.parametrize(
'streaming',
[
pytest.param(False, id='blocking'),
pytest.param(True, id='streaming'),
],
)
async def test_end_to_end_extensions_propagation(transport_setups, streaming):
"""Test that extensions sent by the client reach the agent executor."""
client = transport_setups.client
client._config.streaming = streaming

service_params = ServiceParametersFactory.create(
[with_a2a_extensions(SUPPORTED_EXTENSION_URIS)]
)
context = ClientCallContext(service_parameters=service_params)

message_to_send = Message(
role=Role.ROLE_USER,
message_id='msg-ext-propagation',
parts=[Part(text='Extensions: echo')],
)

events = [
event
async for event in client.send_message(
request=SendMessageRequest(message=message_to_send),
context=context,
)
]

assert len(events) == 1
response = events[0]
assert response.HasField('message')
assert_message_matches(
response.message, Role.ROLE_AGENT, 'extensions echoed'
)
assert set(response.message.extensions) == set(SUPPORTED_EXTENSION_URIS)
Loading