Skip to content

Commit 958486e

Browse files
committed
Add input required test
1 parent cc90642 commit 958486e

1 file changed

Lines changed: 146 additions & 12 deletions

File tree

tests/integration/test_end_to_end.py

Lines changed: 146 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -39,16 +39,45 @@ async def execute(self, context: RequestContext, event_queue: EventQueue):
3939
context.task_id,
4040
context.context_id,
4141
)
42-
await task_updater.update_status(TaskState.TASK_STATE_SUBMITTED)
43-
await task_updater.update_status(TaskState.TASK_STATE_WORKING)
44-
await task_updater.add_artifact(
45-
parts=[Part(text='artifact content')], name='test-artifact'
42+
user_input = context.get_user_input()
43+
44+
is_input_required_resumption = (
45+
context.current_task is not None
46+
and context.current_task.status.state
47+
== TaskState.TASK_STATE_INPUT_REQUIRED
4648
)
49+
50+
if not is_input_required_resumption:
51+
await task_updater.update_status(
52+
TaskState.TASK_STATE_SUBMITTED,
53+
message=task_updater.new_agent_message(
54+
[Part(text='task submitted')]
55+
),
56+
)
57+
4758
await task_updater.update_status(
48-
TaskState.TASK_STATE_COMPLETED,
49-
message=task_updater.new_agent_message([Part(text='done')]),
59+
TaskState.TASK_STATE_WORKING,
60+
message=task_updater.new_agent_message(
61+
[Part(text='task working')]
62+
),
5063
)
5164

65+
if user_input == 'Need input':
66+
await task_updater.update_status(
67+
TaskState.TASK_STATE_INPUT_REQUIRED,
68+
message=task_updater.new_agent_message(
69+
[Part(text='Please provide input')]
70+
),
71+
)
72+
else:
73+
await task_updater.add_artifact(
74+
parts=[Part(text='artifact content')], name='test-artifact'
75+
)
76+
await task_updater.update_status(
77+
TaskState.TASK_STATE_COMPLETED,
78+
message=task_updater.new_agent_message([Part(text='done')]),
79+
)
80+
5281
async def cancel(self, context: RequestContext, event_queue: EventQueue):
5382
raise NotImplementedError('Cancellation is not supported')
5483

@@ -221,9 +250,13 @@ async def test_end_to_end_send_message_blocking(transport_setups):
221250
assert len(response.task.artifacts) == 1
222251
assert response.task.artifacts[0].name == 'test-artifact'
223252
assert response.task.artifacts[0].parts[0].text == 'artifact content'
224-
assert len(response.task.history) == 1
253+
assert len(response.task.history) == 3
225254
assert response.task.history[0].role == Role.ROLE_USER
226255
assert response.task.history[0].parts[0].text == 'Run dummy agent!'
256+
assert response.task.history[1].role == Role.ROLE_AGENT
257+
assert response.task.history[1].parts[0].text == 'task submitted'
258+
assert response.task.history[2].role == Role.ROLE_AGENT
259+
assert response.task.history[2].parts[0].text == 'task working'
227260

228261

229262
@pytest.mark.asyncio
@@ -275,7 +308,7 @@ async def test_end_to_end_send_message_streaming(transport_setups):
275308
]
276309

277310
assert len(events) == len(expected_events)
278-
for (event, task), (expected_type, expected_state) in zip(
311+
for (event, _), (expected_type, expected_state) in zip(
279312
events, expected_events, strict=True
280313
):
281314
assert event.HasField(expected_type)
@@ -289,9 +322,13 @@ async def test_end_to_end_send_message_streaming(transport_setups):
289322
)
290323

291324
last_task = events[-1][1]
292-
assert len(last_task.history) == 1
325+
assert len(last_task.history) == 3
293326
assert last_task.history[0].role == Role.ROLE_AGENT
294-
assert last_task.history[0].parts[0].text == 'done'
327+
assert last_task.history[0].parts[0].text == 'task submitted'
328+
assert last_task.history[1].role == Role.ROLE_AGENT
329+
assert last_task.history[1].parts[0].text == 'task working'
330+
assert last_task.history[2].role == Role.ROLE_AGENT
331+
assert last_task.history[2].parts[0].text == 'done'
295332

296333

297334
@pytest.mark.asyncio
@@ -318,9 +355,13 @@ async def test_end_to_end_get_task(transport_setups):
318355
TaskState.TASK_STATE_WORKING,
319356
TaskState.TASK_STATE_COMPLETED,
320357
}
321-
assert len(retrieved_task.history) == 1
358+
assert len(retrieved_task.history) == 3
322359
assert retrieved_task.history[0].role == Role.ROLE_USER
323360
assert retrieved_task.history[0].parts[0].text == 'Test Get Task'
361+
assert retrieved_task.history[1].role == Role.ROLE_AGENT
362+
assert retrieved_task.history[1].parts[0].text == 'task submitted'
363+
assert retrieved_task.history[2].role == Role.ROLE_AGENT
364+
assert retrieved_task.history[2].parts[0].text == 'task working'
324365

325366

326367
@pytest.mark.asyncio
@@ -361,11 +402,104 @@ async def test_end_to_end_list_tasks(transport_setups):
361402
actual_task_ids.extend([task.id for task in list_response.tasks])
362403

363404
for task in list_response.tasks:
364-
assert len(task.history) == 1
405+
assert len(task.history) >= 1
365406
assert task.history[0].role == Role.ROLE_USER
366407
assert task.history[0].parts[0].text.startswith('Test List Tasks ')
367408

368409
token = list_response.next_page_token
369410

370411
assert len(actual_task_ids) == total_items
371412
assert sorted(actual_task_ids) == sorted(expected_task_ids)
413+
414+
415+
@pytest.mark.asyncio
416+
async def test_end_to_end_input_required(transport_setups):
417+
client = transport_setups.client
418+
419+
message_to_send = Message(
420+
role=Role.ROLE_USER,
421+
message_id='msg-e2e-input-req-1',
422+
parts=[Part(text='Need input')],
423+
)
424+
425+
events = [
426+
event async for event in client.send_message(request=message_to_send)
427+
]
428+
429+
expected_first_events = [
430+
('status_update', TaskState.TASK_STATE_SUBMITTED),
431+
('status_update', TaskState.TASK_STATE_WORKING),
432+
('status_update', TaskState.TASK_STATE_INPUT_REQUIRED),
433+
]
434+
435+
assert len(events) == len(expected_first_events)
436+
for (event, _), (expected_type, expected_state) in zip(
437+
events, expected_first_events, strict=True
438+
):
439+
assert event.HasField(expected_type)
440+
if expected_type == 'status_update':
441+
assert event.status_update.status.state == expected_state
442+
443+
_, task = events[-1]
444+
445+
assert task.status.state == TaskState.TASK_STATE_INPUT_REQUIRED
446+
assert len(task.history) == 3
447+
assert task.history[0].role == Role.ROLE_AGENT
448+
assert task.history[0].parts[0].text == 'task submitted'
449+
assert task.history[1].role == Role.ROLE_AGENT
450+
assert task.history[1].parts[0].text == 'task working'
451+
assert task.history[2].role == Role.ROLE_AGENT
452+
assert task.history[2].parts[0].text == 'Please provide input'
453+
assert task.status.message.role == Role.ROLE_AGENT
454+
assert task.status.message.parts[0].text == 'Please provide input'
455+
456+
task_id = task.id
457+
458+
# Follow-up message
459+
follow_up_message = Message(
460+
task_id=task_id,
461+
role=Role.ROLE_USER,
462+
message_id='msg-e2e-input-req-2',
463+
parts=[Part(text='Here is the input')],
464+
)
465+
466+
follow_up_events = [
467+
event async for event in client.send_message(request=follow_up_message)
468+
]
469+
470+
expected_second_events = [
471+
('status_update', TaskState.TASK_STATE_WORKING),
472+
('artifact_update', None),
473+
('status_update', TaskState.TASK_STATE_COMPLETED),
474+
]
475+
476+
assert len(follow_up_events) == len(expected_second_events)
477+
for (event, _), (expected_type, expected_state) in zip(
478+
follow_up_events, expected_second_events, strict=True
479+
):
480+
assert event.HasField(expected_type)
481+
if expected_type == 'status_update':
482+
assert event.status_update.status.state == expected_state
483+
elif expected_type == 'artifact_update':
484+
assert event.artifact_update.artifact.name == 'test-artifact'
485+
assert (
486+
event.artifact_update.artifact.parts[0].text
487+
== 'artifact content'
488+
)
489+
490+
_, task_end = follow_up_events[-1]
491+
print(task_end)
492+
493+
assert task_end.status.state == TaskState.TASK_STATE_COMPLETED
494+
assert task_end.id == task_id
495+
assert len(task_end.artifacts) == 1
496+
assert task_end.artifacts[0].name == 'test-artifact'
497+
assert task_end.artifacts[0].parts[0].text == 'artifact content'
498+
499+
assert len(task_end.history) == 2
500+
assert task_end.history[0].role == Role.ROLE_AGENT
501+
assert task_end.history[0].parts[0].text == 'task working'
502+
assert task_end.history[1].role == Role.ROLE_AGENT
503+
assert task_end.history[1].parts[0].text == 'done'
504+
assert task_end.status.message.role == Role.ROLE_AGENT
505+
assert task_end.status.message.parts[0].text == 'done'

0 commit comments

Comments
 (0)