3232from a2a .utils import TransportProtocol
3333
3434
35+ def assert_history_matches (history , expected_history ):
36+ assert len (history ) == len (expected_history )
37+ for msg , (expected_role , expected_text ) in zip (
38+ history , expected_history , strict = True
39+ ):
40+ assert msg .role == expected_role
41+ assert msg .parts [0 ].text == expected_text
42+
43+
44+ def assert_artifacts_match (artifacts , expected_artifacts ):
45+ assert len (artifacts ) == len (expected_artifacts )
46+ for artifact , (expected_name , expected_text ) in zip (
47+ artifacts , expected_artifacts , strict = True
48+ ):
49+ assert artifact .name == expected_name
50+ assert artifact .parts [0 ].text == expected_text
51+
52+
53+ def assert_events_match (events , expected_events ):
54+ assert len (events ) == len (expected_events )
55+ for (event , _ ), (expected_type , expected_val ) in zip (
56+ events , expected_events , strict = True
57+ ):
58+ assert event .HasField (expected_type )
59+ if expected_type == 'status_update' :
60+ assert event .status_update .status .state == expected_val
61+ elif expected_type == 'artifact_update' :
62+ if expected_val is not None :
63+ assert_artifacts_match (
64+ [event .artifact_update .artifact ],
65+ expected_val ,
66+ )
67+ else :
68+ raise ValueError (f'Unexpected event type: { expected_type } ' )
69+
70+
3571class MockAgentExecutor (AgentExecutor ):
3672 async def execute (self , context : RequestContext , event_queue : EventQueue ):
3773 task_updater = TaskUpdater (
3874 event_queue ,
3975 context .task_id ,
4076 context .context_id ,
4177 )
42- await task_updater .update_status (TaskState .TASK_STATE_SUBMITTED )
43- await task_updater .update_status (TaskState .TASK_STATE_WORKING )
44- await task_updater .add_artifact (
45- parts = [Part (text = 'artifact content' )], name = 'test-artifact'
78+ user_input = context .get_user_input ()
79+
80+ is_input_required_resumption = (
81+ context .current_task is not None
82+ and context .current_task .status .state
83+ == TaskState .TASK_STATE_INPUT_REQUIRED
4684 )
85+
86+ if not is_input_required_resumption :
87+ await task_updater .update_status (
88+ TaskState .TASK_STATE_SUBMITTED ,
89+ message = task_updater .new_agent_message (
90+ [Part (text = 'task submitted' )]
91+ ),
92+ )
93+
4794 await task_updater .update_status (
48- TaskState .TASK_STATE_COMPLETED ,
49- message = task_updater .new_agent_message ([Part (text = 'done ' )]),
95+ TaskState .TASK_STATE_WORKING ,
96+ message = task_updater .new_agent_message ([Part (text = 'task working ' )]),
5097 )
5198
99+ if user_input == 'Need input' :
100+ await task_updater .update_status (
101+ TaskState .TASK_STATE_INPUT_REQUIRED ,
102+ message = task_updater .new_agent_message (
103+ [Part (text = 'Please provide input' )]
104+ ),
105+ )
106+ else :
107+ await task_updater .add_artifact (
108+ parts = [Part (text = 'artifact content' )], name = 'test-artifact'
109+ )
110+ await task_updater .update_status (
111+ TaskState .TASK_STATE_COMPLETED ,
112+ message = task_updater .new_agent_message ([Part (text = 'done' )]),
113+ )
114+
52115 async def cancel (self , context : RequestContext , event_queue : EventQueue ):
53116 raise NotImplementedError ('Cancellation is not supported' )
54117
@@ -218,12 +281,18 @@ async def test_end_to_end_send_message_blocking(transport_setups):
218281 response , _ = events [0 ]
219282 assert response .task .id
220283 assert response .task .status .state == TaskState .TASK_STATE_COMPLETED
221- assert len (response .task .artifacts ) == 1
222- assert response .task .artifacts [0 ].name == 'test-artifact'
223- assert response .task .artifacts [0 ].parts [0 ].text == 'artifact content'
224- assert len (response .task .history ) == 1
225- assert response .task .history [0 ].role == Role .ROLE_USER
226- assert response .task .history [0 ].parts [0 ].text == 'Run dummy agent!'
284+ assert_artifacts_match (
285+ response .task .artifacts ,
286+ [('test-artifact' , 'artifact content' )],
287+ )
288+ assert_history_matches (
289+ response .task .history ,
290+ [
291+ (Role .ROLE_USER , 'Run dummy agent!' ),
292+ (Role .ROLE_AGENT , 'task submitted' ),
293+ (Role .ROLE_AGENT , 'task working' ),
294+ ],
295+ )
227296
228297
229298@pytest .mark .asyncio
@@ -248,9 +317,12 @@ async def test_end_to_end_send_message_non_blocking(transport_setups):
248317 response , _ = events [0 ]
249318 assert response .task .id
250319 assert response .task .status .state == TaskState .TASK_STATE_SUBMITTED
251- assert len (response .task .history ) == 1
252- assert response .task .history [0 ].role == Role .ROLE_USER
253- assert response .task .history [0 ].parts [0 ].text == 'Run dummy agent!'
320+ assert_history_matches (
321+ response .task .history ,
322+ [
323+ (Role .ROLE_USER , 'Run dummy agent!' ),
324+ ],
325+ )
254326
255327
256328@pytest .mark .asyncio
@@ -270,28 +342,23 @@ async def test_end_to_end_send_message_streaming(transport_setups):
270342 expected_events = [
271343 ('status_update' , TaskState .TASK_STATE_SUBMITTED ),
272344 ('status_update' , TaskState .TASK_STATE_WORKING ),
273- ('artifact_update' , None ),
345+ ('artifact_update' , [( 'test-artifact' , 'artifact content' )] ),
274346 ('status_update' , TaskState .TASK_STATE_COMPLETED ),
275347 ]
276348
277- assert len (events ) == len (expected_events )
278- for (event , task ), (expected_type , expected_state ) in zip (
279- events , expected_events , strict = True
280- ):
281- assert event .HasField (expected_type )
282- if expected_type == 'status_update' :
283- assert event .status_update .status .state == expected_state
284- elif expected_type == 'artifact_update' :
285- assert event .artifact_update .artifact .name == 'test-artifact'
286- assert (
287- event .artifact_update .artifact .parts [0 ].text
288- == 'artifact content'
289- )
349+ assert_events_match (events , expected_events )
290350
291- last_task = events [- 1 ][1 ]
292- assert len (last_task .history ) == 1
293- assert last_task .history [0 ].role == Role .ROLE_AGENT
294- assert last_task .history [0 ].parts [0 ].text == 'done'
351+ task = await client .get_task (request = GetTaskRequest (id = events [0 ][1 ].id ))
352+ assert_history_matches (
353+ task .history ,
354+ [
355+ (Role .ROLE_USER , 'Run dummy agent!' ),
356+ (Role .ROLE_AGENT , 'task submitted' ),
357+ (Role .ROLE_AGENT , 'task working' ),
358+ ],
359+ )
360+ assert task .status .state == TaskState .TASK_STATE_COMPLETED
361+ assert task .status .message .parts [0 ].text == 'done'
295362
296363
297364@pytest .mark .asyncio
@@ -318,9 +385,14 @@ async def test_end_to_end_get_task(transport_setups):
318385 TaskState .TASK_STATE_WORKING ,
319386 TaskState .TASK_STATE_COMPLETED ,
320387 }
321- assert len (retrieved_task .history ) == 1
322- assert retrieved_task .history [0 ].role == Role .ROLE_USER
323- assert retrieved_task .history [0 ].parts [0 ].text == 'Test Get Task'
388+ assert_history_matches (
389+ retrieved_task .history ,
390+ [
391+ (Role .ROLE_USER , 'Test Get Task' ),
392+ (Role .ROLE_AGENT , 'task submitted' ),
393+ (Role .ROLE_AGENT , 'task working' ),
394+ ],
395+ )
324396
325397
326398@pytest .mark .asyncio
@@ -361,11 +433,90 @@ async def test_end_to_end_list_tasks(transport_setups):
361433 actual_task_ids .extend ([task .id for task in list_response .tasks ])
362434
363435 for task in list_response .tasks :
364- assert len (task .history ) = = 1
436+ assert len (task .history ) > = 1
365437 assert task .history [0 ].role == Role .ROLE_USER
366438 assert task .history [0 ].parts [0 ].text .startswith ('Test List Tasks ' )
367439
368440 token = list_response .next_page_token
369441
370442 assert len (actual_task_ids ) == total_items
371443 assert sorted (actual_task_ids ) == sorted (expected_task_ids )
444+
445+
446+ @pytest .mark .asyncio
447+ async def test_end_to_end_input_required (transport_setups ):
448+ client = transport_setups .client
449+
450+ message_to_send = Message (
451+ role = Role .ROLE_USER ,
452+ message_id = 'msg-e2e-input-req-1' ,
453+ parts = [Part (text = 'Need input' )],
454+ )
455+
456+ events = [
457+ event async for event in client .send_message (request = message_to_send )
458+ ]
459+
460+ expected_first_events = [
461+ ('status_update' , TaskState .TASK_STATE_SUBMITTED ),
462+ ('status_update' , TaskState .TASK_STATE_WORKING ),
463+ ('status_update' , TaskState .TASK_STATE_INPUT_REQUIRED ),
464+ ]
465+
466+ assert_events_match (events , expected_first_events )
467+
468+ task = await client .get_task (request = GetTaskRequest (id = events [0 ][1 ].id ))
469+
470+ assert task .status .state == TaskState .TASK_STATE_INPUT_REQUIRED
471+ assert_history_matches (
472+ task .history ,
473+ [
474+ (Role .ROLE_USER , 'Need input' ),
475+ (Role .ROLE_AGENT , 'task submitted' ),
476+ (Role .ROLE_AGENT , 'task working' ),
477+ ],
478+ )
479+ assert task .status .message .role == Role .ROLE_AGENT
480+ assert task .status .message .parts [0 ].text == 'Please provide input'
481+
482+ # Follow-up message
483+ follow_up_message = Message (
484+ task_id = task .id ,
485+ role = Role .ROLE_USER ,
486+ message_id = 'msg-e2e-input-req-2' ,
487+ parts = [Part (text = 'Here is the input' )],
488+ )
489+
490+ follow_up_events = [
491+ event async for event in client .send_message (request = follow_up_message )
492+ ]
493+
494+ expected_second_events = [
495+ ('status_update' , TaskState .TASK_STATE_WORKING ),
496+ ('artifact_update' , [('test-artifact' , 'artifact content' )]),
497+ ('status_update' , TaskState .TASK_STATE_COMPLETED ),
498+ ]
499+
500+ assert_events_match (follow_up_events , expected_second_events )
501+
502+ task = await client .get_task (request = GetTaskRequest (id = task .id ))
503+
504+ assert task .status .state == TaskState .TASK_STATE_COMPLETED
505+ assert_artifacts_match (
506+ task .artifacts ,
507+ [('test-artifact' , 'artifact content' )],
508+ )
509+
510+ assert_history_matches (
511+ task .history ,
512+ [
513+ (Role .ROLE_USER , 'Need input' ),
514+ (Role .ROLE_AGENT , 'task submitted' ),
515+ (Role .ROLE_AGENT , 'task working' ),
516+ (Role .ROLE_AGENT , 'Please provide input' ),
517+ (Role .ROLE_USER , 'Here is the input' ),
518+ (Role .ROLE_AGENT , 'task working' ),
519+ ],
520+ )
521+ assert task .status .message .role == Role .ROLE_AGENT
522+ assert task .status .message .parts [0 ].text == 'done'
0 commit comments