@@ -39,16 +39,45 @@ async def execute(self, context: RequestContext, event_queue: EventQueue):
3939 context .task_id ,
4040 context .context_id ,
4141 )
42- await task_updater .update_status (TaskState .TASK_STATE_SUBMITTED )
43- await task_updater .update_status (TaskState .TASK_STATE_WORKING )
44- await task_updater .add_artifact (
45- parts = [Part (text = 'artifact content' )], name = 'test-artifact'
42+ user_input = context .get_user_input ()
43+
44+ is_input_required_resumption = (
45+ context .current_task is not None
46+ and context .current_task .status .state
47+ == TaskState .TASK_STATE_INPUT_REQUIRED
4648 )
49+
50+ if not is_input_required_resumption :
51+ await task_updater .update_status (
52+ TaskState .TASK_STATE_SUBMITTED ,
53+ message = task_updater .new_agent_message (
54+ [Part (text = 'task submitted' )]
55+ ),
56+ )
57+
4758 await task_updater .update_status (
48- TaskState .TASK_STATE_COMPLETED ,
49- message = task_updater .new_agent_message ([Part (text = 'done' )]),
59+ TaskState .TASK_STATE_WORKING ,
60+ message = task_updater .new_agent_message (
61+ [Part (text = 'task working' )]
62+ ),
5063 )
5164
65+ if user_input == 'Need input' :
66+ await task_updater .update_status (
67+ TaskState .TASK_STATE_INPUT_REQUIRED ,
68+ message = task_updater .new_agent_message (
69+ [Part (text = 'Please provide input' )]
70+ ),
71+ )
72+ else :
73+ await task_updater .add_artifact (
74+ parts = [Part (text = 'artifact content' )], name = 'test-artifact'
75+ )
76+ await task_updater .update_status (
77+ TaskState .TASK_STATE_COMPLETED ,
78+ message = task_updater .new_agent_message ([Part (text = 'done' )]),
79+ )
80+
5281 async def cancel (self , context : RequestContext , event_queue : EventQueue ):
5382 raise NotImplementedError ('Cancellation is not supported' )
5483
@@ -221,9 +250,13 @@ async def test_end_to_end_send_message_blocking(transport_setups):
221250 assert len (response .task .artifacts ) == 1
222251 assert response .task .artifacts [0 ].name == 'test-artifact'
223252 assert response .task .artifacts [0 ].parts [0 ].text == 'artifact content'
224- assert len (response .task .history ) == 1
253+ assert len (response .task .history ) == 3
225254 assert response .task .history [0 ].role == Role .ROLE_USER
226255 assert response .task .history [0 ].parts [0 ].text == 'Run dummy agent!'
256+ assert response .task .history [1 ].role == Role .ROLE_AGENT
257+ assert response .task .history [1 ].parts [0 ].text == 'task submitted'
258+ assert response .task .history [2 ].role == Role .ROLE_AGENT
259+ assert response .task .history [2 ].parts [0 ].text == 'task working'
227260
228261
229262@pytest .mark .asyncio
@@ -275,7 +308,7 @@ async def test_end_to_end_send_message_streaming(transport_setups):
275308 ]
276309
277310 assert len (events ) == len (expected_events )
278- for (event , task ), (expected_type , expected_state ) in zip (
311+ for (event , _ ), (expected_type , expected_state ) in zip (
279312 events , expected_events , strict = True
280313 ):
281314 assert event .HasField (expected_type )
@@ -289,9 +322,13 @@ async def test_end_to_end_send_message_streaming(transport_setups):
289322 )
290323
291324 last_task = events [- 1 ][1 ]
292- assert len (last_task .history ) == 1
325+ assert len (last_task .history ) == 3
293326 assert last_task .history [0 ].role == Role .ROLE_AGENT
294- assert last_task .history [0 ].parts [0 ].text == 'done'
327+ assert last_task .history [0 ].parts [0 ].text == 'task submitted'
328+ assert last_task .history [1 ].role == Role .ROLE_AGENT
329+ assert last_task .history [1 ].parts [0 ].text == 'task working'
330+ assert last_task .history [2 ].role == Role .ROLE_AGENT
331+ assert last_task .history [2 ].parts [0 ].text == 'done'
295332
296333
297334@pytest .mark .asyncio
@@ -318,9 +355,13 @@ async def test_end_to_end_get_task(transport_setups):
318355 TaskState .TASK_STATE_WORKING ,
319356 TaskState .TASK_STATE_COMPLETED ,
320357 }
321- assert len (retrieved_task .history ) == 1
358+ assert len (retrieved_task .history ) == 3
322359 assert retrieved_task .history [0 ].role == Role .ROLE_USER
323360 assert retrieved_task .history [0 ].parts [0 ].text == 'Test Get Task'
361+ assert retrieved_task .history [1 ].role == Role .ROLE_AGENT
362+ assert retrieved_task .history [1 ].parts [0 ].text == 'task submitted'
363+ assert retrieved_task .history [2 ].role == Role .ROLE_AGENT
364+ assert retrieved_task .history [2 ].parts [0 ].text == 'task working'
324365
325366
326367@pytest .mark .asyncio
@@ -361,11 +402,104 @@ async def test_end_to_end_list_tasks(transport_setups):
361402 actual_task_ids .extend ([task .id for task in list_response .tasks ])
362403
363404 for task in list_response .tasks :
364- assert len (task .history ) = = 1
405+ assert len (task .history ) > = 1
365406 assert task .history [0 ].role == Role .ROLE_USER
366407 assert task .history [0 ].parts [0 ].text .startswith ('Test List Tasks ' )
367408
368409 token = list_response .next_page_token
369410
370411 assert len (actual_task_ids ) == total_items
371412 assert sorted (actual_task_ids ) == sorted (expected_task_ids )
413+
414+
415+ @pytest .mark .asyncio
416+ async def test_end_to_end_input_required (transport_setups ):
417+ client = transport_setups .client
418+
419+ message_to_send = Message (
420+ role = Role .ROLE_USER ,
421+ message_id = 'msg-e2e-input-req-1' ,
422+ parts = [Part (text = 'Need input' )],
423+ )
424+
425+ events = [
426+ event async for event in client .send_message (request = message_to_send )
427+ ]
428+
429+ expected_first_events = [
430+ ('status_update' , TaskState .TASK_STATE_SUBMITTED ),
431+ ('status_update' , TaskState .TASK_STATE_WORKING ),
432+ ('status_update' , TaskState .TASK_STATE_INPUT_REQUIRED ),
433+ ]
434+
435+ assert len (events ) == len (expected_first_events )
436+ for (event , _ ), (expected_type , expected_state ) in zip (
437+ events , expected_first_events , strict = True
438+ ):
439+ assert event .HasField (expected_type )
440+ if expected_type == 'status_update' :
441+ assert event .status_update .status .state == expected_state
442+
443+ _ , task = events [- 1 ]
444+
445+ assert task .status .state == TaskState .TASK_STATE_INPUT_REQUIRED
446+ assert len (task .history ) == 3
447+ assert task .history [0 ].role == Role .ROLE_AGENT
448+ assert task .history [0 ].parts [0 ].text == 'task submitted'
449+ assert task .history [1 ].role == Role .ROLE_AGENT
450+ assert task .history [1 ].parts [0 ].text == 'task working'
451+ assert task .history [2 ].role == Role .ROLE_AGENT
452+ assert task .history [2 ].parts [0 ].text == 'Please provide input'
453+ assert task .status .message .role == Role .ROLE_AGENT
454+ assert task .status .message .parts [0 ].text == 'Please provide input'
455+
456+ task_id = task .id
457+
458+ # Follow-up message
459+ follow_up_message = Message (
460+ task_id = task_id ,
461+ role = Role .ROLE_USER ,
462+ message_id = 'msg-e2e-input-req-2' ,
463+ parts = [Part (text = 'Here is the input' )],
464+ )
465+
466+ follow_up_events = [
467+ event async for event in client .send_message (request = follow_up_message )
468+ ]
469+
470+ expected_second_events = [
471+ ('status_update' , TaskState .TASK_STATE_WORKING ),
472+ ('artifact_update' , None ),
473+ ('status_update' , TaskState .TASK_STATE_COMPLETED ),
474+ ]
475+
476+ assert len (follow_up_events ) == len (expected_second_events )
477+ for (event , _ ), (expected_type , expected_state ) in zip (
478+ follow_up_events , expected_second_events , strict = True
479+ ):
480+ assert event .HasField (expected_type )
481+ if expected_type == 'status_update' :
482+ assert event .status_update .status .state == expected_state
483+ elif expected_type == 'artifact_update' :
484+ assert event .artifact_update .artifact .name == 'test-artifact'
485+ assert (
486+ event .artifact_update .artifact .parts [0 ].text
487+ == 'artifact content'
488+ )
489+
490+ _ , task_end = follow_up_events [- 1 ]
491+ print (task_end )
492+
493+ assert task_end .status .state == TaskState .TASK_STATE_COMPLETED
494+ assert task_end .id == task_id
495+ assert len (task_end .artifacts ) == 1
496+ assert task_end .artifacts [0 ].name == 'test-artifact'
497+ assert task_end .artifacts [0 ].parts [0 ].text == 'artifact content'
498+
499+ assert len (task_end .history ) == 2
500+ assert task_end .history [0 ].role == Role .ROLE_AGENT
501+ assert task_end .history [0 ].parts [0 ].text == 'task working'
502+ assert task_end .history [1 ].role == Role .ROLE_AGENT
503+ assert task_end .history [1 ].parts [0 ].text == 'done'
504+ assert task_end .status .message .role == Role .ROLE_AGENT
505+ assert task_end .status .message .parts [0 ].text == 'done'
0 commit comments