66import pytest
77import pytest_asyncio
88
9+ from starlette .applications import Starlette
10+
911from a2a .client .base_client import BaseClient
1012from a2a .client .client import ClientConfig
1113from a2a .client .client_factory import ClientFactory
1214from a2a .server .agent_execution import AgentExecutor , RequestContext
13- from a2a .server .routes .rest_routes import create_rest_routes
14- from starlette .applications import Starlette
15- from a2a .server .routes import create_jsonrpc_routes , create_agent_card_routes
1615from a2a .server .events import EventQueue
1716from a2a .server .events .in_memory_queue_manager import InMemoryQueueManager
18- from a2a .server .request_handlers import DefaultRequestHandler , GrpcHandler
17+ from a2a .server .request_handlers import GrpcHandler , LegacyRequestHandler
18+ from a2a .server .routes import create_agent_card_routes , create_jsonrpc_routes
19+ from a2a .server .routes .rest_routes import create_rest_routes
1920from a2a .server .tasks import TaskUpdater
2021from a2a .server .tasks .inmemory_task_store import InMemoryTaskStore
2122from a2a .types import (
3738 TaskState ,
3839 a2a_pb2_grpc ,
3940)
40- from a2a .utils import TransportProtocol
41+ from a2a .utils import TransportProtocol , new_task
4142from a2a .utils .errors import InvalidParamsError
4243
4344
@@ -69,7 +70,9 @@ def assert_events_match(events, expected_events):
6970 events , expected_events , strict = True
7071 ):
7172 assert event .HasField (expected_type )
72- if expected_type == 'status_update' :
73+ if expected_type == 'task' :
74+ assert event .task .status .state == expected_val
75+ elif expected_type == 'status_update' :
7376 assert event .status_update .status .state == expected_val
7477 elif expected_type == 'artifact_update' :
7578 if expected_val is not None :
@@ -83,26 +86,30 @@ def assert_events_match(events, expected_events):
8386
8487class MockAgentExecutor (AgentExecutor ):
8588 async def execute (self , context : RequestContext , event_queue : EventQueue ):
86- task_updater = TaskUpdater (
87- event_queue ,
88- context .task_id ,
89- context .context_id ,
90- )
9189 user_input = context .get_user_input ()
9290
93- is_input_required_resumption = (
94- context .current_task is not None
95- and context .current_task .status .state
96- == TaskState .TASK_STATE_INPUT_REQUIRED
97- )
98-
99- if not is_input_required_resumption :
100- await task_updater .update_status (
101- TaskState .TASK_STATE_SUBMITTED ,
102- message = task_updater .new_agent_message (
103- [Part (text = 'task submitted' )]
104- ),
91+ # Direct message response (no task created).
92+ if user_input .startswith ('Message:' ):
93+ await event_queue .enqueue_event (
94+ Message (
95+ role = Role .ROLE_AGENT ,
96+ message_id = 'direct-reply-1' ,
97+ parts = [Part (text = f'Direct reply to: { user_input } ' )],
98+ )
10599 )
100+ return
101+
102+ # Task-based response.
103+ task = context .current_task
104+ if not task :
105+ task = new_task (context .message )
106+ await event_queue .enqueue_event (task )
107+
108+ task_updater = TaskUpdater (
109+ event_queue ,
110+ task .id ,
111+ task .context_id ,
112+ )
106113
107114 await task_updater .update_status (
108115 TaskState .TASK_STATE_WORKING ,
@@ -168,7 +175,7 @@ class ClientSetup(NamedTuple):
168175@pytest .fixture
169176def base_e2e_setup (agent_card ):
170177 task_store = InMemoryTaskStore ()
171- handler = DefaultRequestHandler (
178+ handler = LegacyRequestHandler (
172179 agent_executor = MockAgentExecutor (),
173180 task_store = task_store ,
174181 agent_card = agent_card ,
@@ -328,7 +335,6 @@ async def test_end_to_end_send_message_blocking(transport_setups):
328335 response .task .history ,
329336 [
330337 (Role .ROLE_USER , 'Run dummy agent!' ),
331- (Role .ROLE_AGENT , 'task submitted' ),
332338 (Role .ROLE_AGENT , 'task working' ),
333339 ],
334340 )
@@ -386,20 +392,19 @@ async def test_end_to_end_send_message_streaming(transport_setups):
386392 assert_events_match (
387393 events ,
388394 [
389- ('status_update ' , TaskState .TASK_STATE_SUBMITTED ),
395+ ('task ' , TaskState .TASK_STATE_SUBMITTED ),
390396 ('status_update' , TaskState .TASK_STATE_WORKING ),
391397 ('artifact_update' , [('test-artifact' , 'artifact content' )]),
392398 ('status_update' , TaskState .TASK_STATE_COMPLETED ),
393399 ],
394400 )
395401
396- task_id = events [0 ].status_update . task_id
402+ task_id = events [0 ].task . id
397403 task = await client .get_task (request = GetTaskRequest (id = task_id ))
398404 assert_history_matches (
399405 task .history ,
400406 [
401407 (Role .ROLE_USER , 'Run dummy agent!' ),
402- (Role .ROLE_AGENT , 'task submitted' ),
403408 (Role .ROLE_AGENT , 'task working' ),
404409 ],
405410 )
@@ -423,7 +428,7 @@ async def test_end_to_end_get_task(transport_setups):
423428 )
424429 ]
425430 response = events [0 ]
426- task_id = response .status_update . task_id
431+ task_id = response .task . id
427432
428433 get_request = GetTaskRequest (id = task_id )
429434 retrieved_task = await client .get_task (request = get_request )
@@ -438,7 +443,6 @@ async def test_end_to_end_get_task(transport_setups):
438443 retrieved_task .history ,
439444 [
440445 (Role .ROLE_USER , 'Test Get Task' ),
441- (Role .ROLE_AGENT , 'task submitted' ),
442446 (Role .ROLE_AGENT , 'task working' ),
443447 ],
444448 )
@@ -465,7 +469,7 @@ async def test_end_to_end_list_tasks(transport_setups):
465469 )
466470 )
467471 )
468- expected_task_ids .append (response .status_update . task_id )
472+ expected_task_ids .append (response .task . id )
469473
470474 list_request = ListTasksRequest (page_size = page_size )
471475
@@ -514,21 +518,20 @@ async def test_end_to_end_input_required(transport_setups):
514518 assert_events_match (
515519 events ,
516520 [
517- ('status_update ' , TaskState .TASK_STATE_SUBMITTED ),
521+ ('task ' , TaskState .TASK_STATE_SUBMITTED ),
518522 ('status_update' , TaskState .TASK_STATE_WORKING ),
519523 ('status_update' , TaskState .TASK_STATE_INPUT_REQUIRED ),
520524 ],
521525 )
522526
523- task_id = events [0 ].status_update . task_id
527+ task_id = events [0 ].task . id
524528 task = await client .get_task (request = GetTaskRequest (id = task_id ))
525529
526530 assert task .status .state == TaskState .TASK_STATE_INPUT_REQUIRED
527531 assert_history_matches (
528532 task .history ,
529533 [
530534 (Role .ROLE_USER , 'Need input' ),
531- (Role .ROLE_AGENT , 'task submitted' ),
532535 (Role .ROLE_AGENT , 'task working' ),
533536 ],
534537 )
@@ -572,7 +575,6 @@ async def test_end_to_end_input_required(transport_setups):
572575 task .history ,
573576 [
574577 (Role .ROLE_USER , 'Need input' ),
575- (Role .ROLE_AGENT , 'task submitted' ),
576578 (Role .ROLE_AGENT , 'task working' ),
577579 (Role .ROLE_AGENT , 'Please provide input' ),
578580 (Role .ROLE_USER , 'Here is the input' ),
@@ -681,3 +683,100 @@ async def test_end_to_end_subscribe_validation_error(
681683 assert {e ['field' ] for e in errors } == {'id' }
682684
683685 await client .close ()
686+
687+
688+ @pytest .mark .asyncio
689+ async def test_end_to_end_direct_message_blocking (transport_setups ):
690+ """Test that an executor can return a direct Message without creating a Task."""
691+ client = transport_setups .client
692+ client ._config .streaming = False
693+
694+ message_to_send = Message (
695+ role = Role .ROLE_USER ,
696+ message_id = 'msg-direct-blocking' ,
697+ parts = [Part (text = 'Message: Hello agent' )],
698+ )
699+
700+ events = [
701+ event
702+ async for event in client .send_message (
703+ request = SendMessageRequest (message = message_to_send )
704+ )
705+ ]
706+
707+ assert len (events ) == 1
708+ response = events [0 ]
709+ assert response .HasField ('message' )
710+ assert not response .HasField ('task' )
711+ assert_message_matches (
712+ response .message ,
713+ Role .ROLE_AGENT ,
714+ 'Direct reply to: Message: Hello agent' ,
715+ )
716+
717+
718+ @pytest .mark .asyncio
719+ async def test_end_to_end_direct_message_return_immediately (transport_setups ):
720+ """Test that return_immediately still returns the Message for direct replies.
721+
722+ When the executor responds with a direct Message, the response is
723+ inherently immediate -- there is no async task to defer to. The client
724+ should receive the Message regardless of the return_immediately flag.
725+ """
726+ client = transport_setups .client
727+ client ._config .streaming = False
728+
729+ message_to_send = Message (
730+ role = Role .ROLE_USER ,
731+ message_id = 'msg-direct-return-immediately' ,
732+ parts = [Part (text = 'Message: Quick question' )],
733+ )
734+ configuration = SendMessageConfiguration (return_immediately = True )
735+
736+ events = [
737+ event
738+ async for event in client .send_message (
739+ request = SendMessageRequest (
740+ message = message_to_send , configuration = configuration
741+ )
742+ )
743+ ]
744+
745+ assert len (events ) == 1
746+ response = events [0 ]
747+ assert response .HasField ('message' )
748+ assert not response .HasField ('task' )
749+ assert_message_matches (
750+ response .message ,
751+ Role .ROLE_AGENT ,
752+ 'Direct reply to: Message: Quick question' ,
753+ )
754+
755+
756+ @pytest .mark .asyncio
757+ async def test_end_to_end_direct_message_streaming (transport_setups ):
758+ """Test that streaming returns a direct Message and terminates the stream."""
759+ client = transport_setups .client
760+
761+ message_to_send = Message (
762+ role = Role .ROLE_USER ,
763+ message_id = 'msg-direct-streaming' ,
764+ parts = [Part (text = 'Message: Hello streaming' )],
765+ )
766+
767+ events = [
768+ event
769+ async for event in client .send_message (
770+ request = SendMessageRequest (message = message_to_send )
771+ )
772+ ]
773+
774+ assert len (events ) == 1
775+ response = events [0 ]
776+ assert response .HasField ('message' )
777+ assert not response .HasField ('task' )
778+ assert_message_matches (
779+ response .message ,
780+ Role .ROLE_AGENT ,
781+ 'Direct reply to: Message: Hello streaming' ,
782+ )
0 commit comments