Skip to content

Commit 44fdc31

Browse files
committed
test: add more scenarios to test_end_to_end
Based on https://a2a-protocol.org/latest/specification/#312-send-streaming-message: 1. `Message` based flow. 2. Emit `Task` as a first event. WIP: switches to the old request handler as there are known issues in the new one.
1 parent 01b3b2c commit 44fdc31

1 file changed

Lines changed: 134 additions & 35 deletions

File tree

tests/integration/test_end_to_end.py

Lines changed: 134 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,17 @@
66
import pytest
77
import pytest_asyncio
88

9+
from starlette.applications import Starlette
10+
911
from a2a.client.base_client import BaseClient
1012
from a2a.client.client import ClientConfig
1113
from a2a.client.client_factory import ClientFactory
1214
from a2a.server.agent_execution import AgentExecutor, RequestContext
13-
from a2a.server.routes.rest_routes import create_rest_routes
14-
from starlette.applications import Starlette
15-
from a2a.server.routes import create_jsonrpc_routes, create_agent_card_routes
1615
from a2a.server.events import EventQueue
1716
from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager
18-
from a2a.server.request_handlers import DefaultRequestHandler, GrpcHandler
17+
from a2a.server.request_handlers import GrpcHandler, LegacyRequestHandler
18+
from a2a.server.routes import create_agent_card_routes, create_jsonrpc_routes
19+
from a2a.server.routes.rest_routes import create_rest_routes
1920
from a2a.server.tasks import TaskUpdater
2021
from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore
2122
from a2a.types import (
@@ -37,7 +38,7 @@
3738
TaskState,
3839
a2a_pb2_grpc,
3940
)
40-
from a2a.utils import TransportProtocol
41+
from a2a.utils import TransportProtocol, new_task
4142
from a2a.utils.errors import InvalidParamsError
4243

4344

@@ -69,7 +70,9 @@ def assert_events_match(events, expected_events):
6970
events, expected_events, strict=True
7071
):
7172
assert event.HasField(expected_type)
72-
if expected_type == 'status_update':
73+
if expected_type == 'task':
74+
assert event.task.status.state == expected_val
75+
elif expected_type == 'status_update':
7376
assert event.status_update.status.state == expected_val
7477
elif expected_type == 'artifact_update':
7578
if expected_val is not None:
@@ -83,26 +86,30 @@ def assert_events_match(events, expected_events):
8386

8487
class MockAgentExecutor(AgentExecutor):
8588
async def execute(self, context: RequestContext, event_queue: EventQueue):
86-
task_updater = TaskUpdater(
87-
event_queue,
88-
context.task_id,
89-
context.context_id,
90-
)
9189
user_input = context.get_user_input()
9290

93-
is_input_required_resumption = (
94-
context.current_task is not None
95-
and context.current_task.status.state
96-
== TaskState.TASK_STATE_INPUT_REQUIRED
97-
)
98-
99-
if not is_input_required_resumption:
100-
await task_updater.update_status(
101-
TaskState.TASK_STATE_SUBMITTED,
102-
message=task_updater.new_agent_message(
103-
[Part(text='task submitted')]
104-
),
91+
# Direct message response (no task created).
92+
if user_input.startswith('Message:'):
93+
await event_queue.enqueue_event(
94+
Message(
95+
role=Role.ROLE_AGENT,
96+
message_id='direct-reply-1',
97+
parts=[Part(text=f'Direct reply to: {user_input}')],
98+
)
10599
)
100+
return
101+
102+
# Task-based response.
103+
task = context.current_task
104+
if not task:
105+
task = new_task(context.message)
106+
await event_queue.enqueue_event(task)
107+
108+
task_updater = TaskUpdater(
109+
event_queue,
110+
task.id,
111+
task.context_id,
112+
)
106113

107114
await task_updater.update_status(
108115
TaskState.TASK_STATE_WORKING,
@@ -168,7 +175,7 @@ class ClientSetup(NamedTuple):
168175
@pytest.fixture
169176
def base_e2e_setup(agent_card):
170177
task_store = InMemoryTaskStore()
171-
handler = DefaultRequestHandler(
178+
handler = LegacyRequestHandler(
172179
agent_executor=MockAgentExecutor(),
173180
task_store=task_store,
174181
agent_card=agent_card,
@@ -328,7 +335,6 @@ async def test_end_to_end_send_message_blocking(transport_setups):
328335
response.task.history,
329336
[
330337
(Role.ROLE_USER, 'Run dummy agent!'),
331-
(Role.ROLE_AGENT, 'task submitted'),
332338
(Role.ROLE_AGENT, 'task working'),
333339
],
334340
)
@@ -386,20 +392,19 @@ async def test_end_to_end_send_message_streaming(transport_setups):
386392
assert_events_match(
387393
events,
388394
[
389-
('status_update', TaskState.TASK_STATE_SUBMITTED),
395+
('task', TaskState.TASK_STATE_SUBMITTED),
390396
('status_update', TaskState.TASK_STATE_WORKING),
391397
('artifact_update', [('test-artifact', 'artifact content')]),
392398
('status_update', TaskState.TASK_STATE_COMPLETED),
393399
],
394400
)
395401

396-
task_id = events[0].status_update.task_id
402+
task_id = events[0].task.id
397403
task = await client.get_task(request=GetTaskRequest(id=task_id))
398404
assert_history_matches(
399405
task.history,
400406
[
401407
(Role.ROLE_USER, 'Run dummy agent!'),
402-
(Role.ROLE_AGENT, 'task submitted'),
403408
(Role.ROLE_AGENT, 'task working'),
404409
],
405410
)
@@ -423,7 +428,7 @@ async def test_end_to_end_get_task(transport_setups):
423428
)
424429
]
425430
response = events[0]
426-
task_id = response.status_update.task_id
431+
task_id = response.task.id
427432

428433
get_request = GetTaskRequest(id=task_id)
429434
retrieved_task = await client.get_task(request=get_request)
@@ -438,7 +443,6 @@ async def test_end_to_end_get_task(transport_setups):
438443
retrieved_task.history,
439444
[
440445
(Role.ROLE_USER, 'Test Get Task'),
441-
(Role.ROLE_AGENT, 'task submitted'),
442446
(Role.ROLE_AGENT, 'task working'),
443447
],
444448
)
@@ -465,7 +469,7 @@ async def test_end_to_end_list_tasks(transport_setups):
465469
)
466470
)
467471
)
468-
expected_task_ids.append(response.status_update.task_id)
472+
expected_task_ids.append(response.task.id)
469473

470474
list_request = ListTasksRequest(page_size=page_size)
471475

@@ -514,21 +518,20 @@ async def test_end_to_end_input_required(transport_setups):
514518
assert_events_match(
515519
events,
516520
[
517-
('status_update', TaskState.TASK_STATE_SUBMITTED),
521+
('task', TaskState.TASK_STATE_SUBMITTED),
518522
('status_update', TaskState.TASK_STATE_WORKING),
519523
('status_update', TaskState.TASK_STATE_INPUT_REQUIRED),
520524
],
521525
)
522526

523-
task_id = events[0].status_update.task_id
527+
task_id = events[0].task.id
524528
task = await client.get_task(request=GetTaskRequest(id=task_id))
525529

526530
assert task.status.state == TaskState.TASK_STATE_INPUT_REQUIRED
527531
assert_history_matches(
528532
task.history,
529533
[
530534
(Role.ROLE_USER, 'Need input'),
531-
(Role.ROLE_AGENT, 'task submitted'),
532535
(Role.ROLE_AGENT, 'task working'),
533536
],
534537
)
@@ -572,7 +575,6 @@ async def test_end_to_end_input_required(transport_setups):
572575
task.history,
573576
[
574577
(Role.ROLE_USER, 'Need input'),
575-
(Role.ROLE_AGENT, 'task submitted'),
576578
(Role.ROLE_AGENT, 'task working'),
577579
(Role.ROLE_AGENT, 'Please provide input'),
578580
(Role.ROLE_USER, 'Here is the input'),
@@ -681,3 +683,100 @@ async def test_end_to_end_subscribe_validation_error(
681683
assert {e['field'] for e in errors} == {'id'}
682684

683685
await client.close()
686+
687+
688+
@pytest.mark.asyncio
689+
async def test_end_to_end_direct_message_blocking(transport_setups):
690+
"""Test that an executor can return a direct Message without creating a Task."""
691+
client = transport_setups.client
692+
client._config.streaming = False
693+
694+
message_to_send = Message(
695+
role=Role.ROLE_USER,
696+
message_id='msg-direct-blocking',
697+
parts=[Part(text='Message: Hello agent')],
698+
)
699+
700+
events = [
701+
event
702+
async for event in client.send_message(
703+
request=SendMessageRequest(message=message_to_send)
704+
)
705+
]
706+
707+
assert len(events) == 1
708+
response = events[0]
709+
assert response.HasField('message')
710+
assert not response.HasField('task')
711+
assert_message_matches(
712+
response.message,
713+
Role.ROLE_AGENT,
714+
'Direct reply to: Message: Hello agent',
715+
)
716+
717+
718+
@pytest.mark.asyncio
719+
async def test_end_to_end_direct_message_return_immediately(transport_setups):
720+
"""Test that return_immediately still returns the Message for direct replies.
721+
722+
When the executor responds with a direct Message, the response is
723+
inherently immediate -- there is no async task to defer to. The client
724+
should receive the Message regardless of the return_immediately flag.
725+
"""
726+
client = transport_setups.client
727+
client._config.streaming = False
728+
729+
message_to_send = Message(
730+
role=Role.ROLE_USER,
731+
message_id='msg-direct-return-immediately',
732+
parts=[Part(text='Message: Quick question')],
733+
)
734+
configuration = SendMessageConfiguration(return_immediately=True)
735+
736+
events = [
737+
event
738+
async for event in client.send_message(
739+
request=SendMessageRequest(
740+
message=message_to_send, configuration=configuration
741+
)
742+
)
743+
]
744+
745+
assert len(events) == 1
746+
response = events[0]
747+
assert response.HasField('message')
748+
assert not response.HasField('task')
749+
assert_message_matches(
750+
response.message,
751+
Role.ROLE_AGENT,
752+
'Direct reply to: Message: Quick question',
753+
)
754+
755+
756+
@pytest.mark.asyncio
757+
async def test_end_to_end_direct_message_streaming(transport_setups):
758+
"""Test that streaming returns a direct Message and terminates the stream."""
759+
client = transport_setups.client
760+
761+
message_to_send = Message(
762+
role=Role.ROLE_USER,
763+
message_id='msg-direct-streaming',
764+
parts=[Part(text='Message: Hello streaming')],
765+
)
766+
767+
events = [
768+
event
769+
async for event in client.send_message(
770+
request=SendMessageRequest(message=message_to_send)
771+
)
772+
]
773+
774+
assert len(events) == 1
775+
response = events[0]
776+
assert response.HasField('message')
777+
assert not response.HasField('task')
778+
assert_message_matches(
779+
response.message,
780+
Role.ROLE_AGENT,
781+
'Direct reply to: Message: Hello streaming',
782+
)

0 commit comments

Comments
 (0)