Skip to content
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 134 additions & 35 deletions tests/integration/test_end_to_end.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,17 @@
import pytest
import pytest_asyncio

from starlette.applications import Starlette

from a2a.client.base_client import BaseClient
from a2a.client.client import ClientConfig
from a2a.client.client_factory import ClientFactory
from a2a.server.agent_execution import AgentExecutor, RequestContext
from a2a.server.routes.rest_routes import create_rest_routes
from starlette.applications import Starlette
from a2a.server.routes import create_jsonrpc_routes, create_agent_card_routes
from a2a.server.events import EventQueue
from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager
from a2a.server.request_handlers import DefaultRequestHandler, GrpcHandler
from a2a.server.request_handlers import GrpcHandler, LegacyRequestHandler
from a2a.server.routes import create_agent_card_routes, create_jsonrpc_routes
from a2a.server.routes.rest_routes import create_rest_routes
from a2a.server.tasks import TaskUpdater
from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore
from a2a.types import (
Expand All @@ -37,7 +38,7 @@
TaskState,
a2a_pb2_grpc,
)
from a2a.utils import TransportProtocol
from a2a.utils import TransportProtocol, new_task
from a2a.utils.errors import InvalidParamsError


Expand Down Expand Up @@ -69,7 +70,9 @@ def assert_events_match(events, expected_events):
events, expected_events, strict=True
):
assert event.HasField(expected_type)
if expected_type == 'status_update':
if expected_type == 'task':
assert event.task.status.state == expected_val
elif expected_type == 'status_update':
assert event.status_update.status.state == expected_val
elif expected_type == 'artifact_update':
if expected_val is not None:
Expand All @@ -83,26 +86,30 @@ def assert_events_match(events, expected_events):

class MockAgentExecutor(AgentExecutor):
async def execute(self, context: RequestContext, event_queue: EventQueue):
task_updater = TaskUpdater(
event_queue,
context.task_id,
context.context_id,
)
user_input = context.get_user_input()

is_input_required_resumption = (
context.current_task is not None
and context.current_task.status.state
== TaskState.TASK_STATE_INPUT_REQUIRED
)

if not is_input_required_resumption:
await task_updater.update_status(
TaskState.TASK_STATE_SUBMITTED,
message=task_updater.new_agent_message(
[Part(text='task submitted')]
),
# Direct message response (no task created).
if user_input.startswith('Message:'):
await event_queue.enqueue_event(
Message(
role=Role.ROLE_AGENT,
message_id='direct-reply-1',
parts=[Part(text=f'Direct reply to: {user_input}')],
)
)
return

# Task-based response.
task = context.current_task
if not task:
task = new_task(context.message)
await event_queue.enqueue_event(task)

task_updater = TaskUpdater(
event_queue,
task.id,
task.context_id,
)

await task_updater.update_status(
TaskState.TASK_STATE_WORKING,
Expand Down Expand Up @@ -168,7 +175,7 @@ class ClientSetup(NamedTuple):
@pytest.fixture
def base_e2e_setup(agent_card):
task_store = InMemoryTaskStore()
handler = DefaultRequestHandler(
handler = LegacyRequestHandler(
agent_executor=MockAgentExecutor(),
task_store=task_store,
agent_card=agent_card,
Expand Down Expand Up @@ -328,7 +335,6 @@ async def test_end_to_end_send_message_blocking(transport_setups):
response.task.history,
[
(Role.ROLE_USER, 'Run dummy agent!'),
(Role.ROLE_AGENT, 'task submitted'),
(Role.ROLE_AGENT, 'task working'),
],
)
Expand Down Expand Up @@ -386,20 +392,19 @@ async def test_end_to_end_send_message_streaming(transport_setups):
assert_events_match(
events,
[
('status_update', TaskState.TASK_STATE_SUBMITTED),
('task', TaskState.TASK_STATE_SUBMITTED),
('status_update', TaskState.TASK_STATE_WORKING),
('artifact_update', [('test-artifact', 'artifact content')]),
('status_update', TaskState.TASK_STATE_COMPLETED),
],
)

task_id = events[0].status_update.task_id
task_id = events[0].task.id
task = await client.get_task(request=GetTaskRequest(id=task_id))
assert_history_matches(
task.history,
[
(Role.ROLE_USER, 'Run dummy agent!'),
(Role.ROLE_AGENT, 'task submitted'),
(Role.ROLE_AGENT, 'task working'),
],
)
Expand All @@ -423,7 +428,7 @@ async def test_end_to_end_get_task(transport_setups):
)
]
response = events[0]
task_id = response.status_update.task_id
task_id = response.task.id

get_request = GetTaskRequest(id=task_id)
retrieved_task = await client.get_task(request=get_request)
Expand All @@ -438,7 +443,6 @@ async def test_end_to_end_get_task(transport_setups):
retrieved_task.history,
[
(Role.ROLE_USER, 'Test Get Task'),
(Role.ROLE_AGENT, 'task submitted'),
(Role.ROLE_AGENT, 'task working'),
],
)
Expand All @@ -465,7 +469,7 @@ async def test_end_to_end_list_tasks(transport_setups):
)
)
)
expected_task_ids.append(response.status_update.task_id)
expected_task_ids.append(response.task.id)

list_request = ListTasksRequest(page_size=page_size)

Expand Down Expand Up @@ -514,21 +518,20 @@ async def test_end_to_end_input_required(transport_setups):
assert_events_match(
events,
[
('status_update', TaskState.TASK_STATE_SUBMITTED),
('task', TaskState.TASK_STATE_SUBMITTED),
('status_update', TaskState.TASK_STATE_WORKING),
('status_update', TaskState.TASK_STATE_INPUT_REQUIRED),
],
)

task_id = events[0].status_update.task_id
task_id = events[0].task.id
task = await client.get_task(request=GetTaskRequest(id=task_id))

assert task.status.state == TaskState.TASK_STATE_INPUT_REQUIRED
assert_history_matches(
task.history,
[
(Role.ROLE_USER, 'Need input'),
(Role.ROLE_AGENT, 'task submitted'),
(Role.ROLE_AGENT, 'task working'),
],
)
Expand Down Expand Up @@ -572,7 +575,6 @@ async def test_end_to_end_input_required(transport_setups):
task.history,
[
(Role.ROLE_USER, 'Need input'),
(Role.ROLE_AGENT, 'task submitted'),
(Role.ROLE_AGENT, 'task working'),
(Role.ROLE_AGENT, 'Please provide input'),
(Role.ROLE_USER, 'Here is the input'),
Expand Down Expand Up @@ -681,3 +683,100 @@ async def test_end_to_end_subscribe_validation_error(
assert {e['field'] for e in errors} == {'id'}

await client.close()


@pytest.mark.asyncio
async def test_end_to_end_direct_message_blocking(transport_setups):
"""Test that an executor can return a direct Message without creating a Task."""
client = transport_setups.client
client._config.streaming = False

message_to_send = Message(
role=Role.ROLE_USER,
message_id='msg-direct-blocking',
parts=[Part(text='Message: Hello agent')],
)

events = [
event
async for event in client.send_message(
request=SendMessageRequest(message=message_to_send)
)
]

assert len(events) == 1
response = events[0]
assert response.HasField('message')
assert not response.HasField('task')
assert_message_matches(
response.message,
Role.ROLE_AGENT,
'Direct reply to: Message: Hello agent',
)
Comment thread
ishymko marked this conversation as resolved.
Outdated


@pytest.mark.asyncio
async def test_end_to_end_direct_message_return_immediately(transport_setups):
"""Test that return_immediately still returns the Message for direct replies.

When the executor responds with a direct Message, the response is
inherently immediate -- there is no async task to defer to. The client
should receive the Message regardless of the return_immediately flag.
"""
client = transport_setups.client
client._config.streaming = False

message_to_send = Message(
role=Role.ROLE_USER,
message_id='msg-direct-return-immediately',
parts=[Part(text='Message: Quick question')],
)
configuration = SendMessageConfiguration(return_immediately=True)

events = [
event
async for event in client.send_message(
request=SendMessageRequest(
message=message_to_send, configuration=configuration
)
)
]

assert len(events) == 1
response = events[0]
assert response.HasField('message')
assert not response.HasField('task')
assert_message_matches(
response.message,
Role.ROLE_AGENT,
'Direct reply to: Message: Quick question',
)


@pytest.mark.asyncio
async def test_end_to_end_direct_message_streaming(transport_setups):
"""Test that streaming returns a direct Message and terminates the stream."""
client = transport_setups.client

message_to_send = Message(
role=Role.ROLE_USER,
message_id='msg-direct-streaming',
parts=[Part(text='Message: Hello streaming')],
)

events = [
event
async for event in client.send_message(
request=SendMessageRequest(message=message_to_send)
)
]

assert len(events) == 1
response = events[0]
assert response.HasField('message')
assert not response.HasField('task')
assert_message_matches(
response.message,
Role.ROLE_AGENT,
'Direct reply to: Message: Hello streaming',
)
Loading