Skip to content

Commit 0a889c5

Browse files
committed
Add input required test
1 parent cc90642 commit 0a889c5

1 file changed

Lines changed: 188 additions & 37 deletions

File tree

tests/integration/test_end_to_end.py

Lines changed: 188 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -32,23 +32,86 @@
3232
from a2a.utils import TransportProtocol
3333

3434

35+
def assert_history_matches(history, expected_history):
36+
assert len(history) == len(expected_history)
37+
for msg, (expected_role, expected_text) in zip(
38+
history, expected_history, strict=True
39+
):
40+
assert msg.role == expected_role
41+
assert msg.parts[0].text == expected_text
42+
43+
44+
def assert_artifacts_match(artifacts, expected_artifacts):
45+
assert len(artifacts) == len(expected_artifacts)
46+
for artifact, (expected_name, expected_text) in zip(
47+
artifacts, expected_artifacts, strict=True
48+
):
49+
assert artifact.name == expected_name
50+
assert artifact.parts[0].text == expected_text
51+
52+
53+
def assert_events_match(events, expected_events):
54+
assert len(events) == len(expected_events)
55+
for (event, _), (expected_type, expected_val) in zip(
56+
events, expected_events, strict=True
57+
):
58+
assert event.HasField(expected_type)
59+
if expected_type == 'status_update':
60+
assert event.status_update.status.state == expected_val
61+
elif expected_type == 'artifact_update':
62+
if expected_val is not None:
63+
assert_artifacts_match(
64+
[event.artifact_update.artifact],
65+
expected_val,
66+
)
67+
else:
68+
raise ValueError(f'Unexpected event type: {expected_type}')
69+
70+
3571
class MockAgentExecutor(AgentExecutor):
3672
async def execute(self, context: RequestContext, event_queue: EventQueue):
3773
task_updater = TaskUpdater(
3874
event_queue,
3975
context.task_id,
4076
context.context_id,
4177
)
42-
await task_updater.update_status(TaskState.TASK_STATE_SUBMITTED)
43-
await task_updater.update_status(TaskState.TASK_STATE_WORKING)
44-
await task_updater.add_artifact(
45-
parts=[Part(text='artifact content')], name='test-artifact'
78+
user_input = context.get_user_input()
79+
80+
is_input_required_resumption = (
81+
context.current_task is not None
82+
and context.current_task.status.state
83+
== TaskState.TASK_STATE_INPUT_REQUIRED
4684
)
85+
86+
if not is_input_required_resumption:
87+
await task_updater.update_status(
88+
TaskState.TASK_STATE_SUBMITTED,
89+
message=task_updater.new_agent_message(
90+
[Part(text='task submitted')]
91+
),
92+
)
93+
4794
await task_updater.update_status(
48-
TaskState.TASK_STATE_COMPLETED,
49-
message=task_updater.new_agent_message([Part(text='done')]),
95+
TaskState.TASK_STATE_WORKING,
96+
message=task_updater.new_agent_message([Part(text='task working')]),
5097
)
5198

99+
if user_input == 'Need input':
100+
await task_updater.update_status(
101+
TaskState.TASK_STATE_INPUT_REQUIRED,
102+
message=task_updater.new_agent_message(
103+
[Part(text='Please provide input')]
104+
),
105+
)
106+
else:
107+
await task_updater.add_artifact(
108+
parts=[Part(text='artifact content')], name='test-artifact'
109+
)
110+
await task_updater.update_status(
111+
TaskState.TASK_STATE_COMPLETED,
112+
message=task_updater.new_agent_message([Part(text='done')]),
113+
)
114+
52115
async def cancel(self, context: RequestContext, event_queue: EventQueue):
53116
raise NotImplementedError('Cancellation is not supported')
54117

@@ -218,12 +281,18 @@ async def test_end_to_end_send_message_blocking(transport_setups):
218281
response, _ = events[0]
219282
assert response.task.id
220283
assert response.task.status.state == TaskState.TASK_STATE_COMPLETED
221-
assert len(response.task.artifacts) == 1
222-
assert response.task.artifacts[0].name == 'test-artifact'
223-
assert response.task.artifacts[0].parts[0].text == 'artifact content'
224-
assert len(response.task.history) == 1
225-
assert response.task.history[0].role == Role.ROLE_USER
226-
assert response.task.history[0].parts[0].text == 'Run dummy agent!'
284+
assert_artifacts_match(
285+
response.task.artifacts,
286+
[('test-artifact', 'artifact content')],
287+
)
288+
assert_history_matches(
289+
response.task.history,
290+
[
291+
(Role.ROLE_USER, 'Run dummy agent!'),
292+
(Role.ROLE_AGENT, 'task submitted'),
293+
(Role.ROLE_AGENT, 'task working'),
294+
],
295+
)
227296

228297

229298
@pytest.mark.asyncio
@@ -248,9 +317,12 @@ async def test_end_to_end_send_message_non_blocking(transport_setups):
248317
response, _ = events[0]
249318
assert response.task.id
250319
assert response.task.status.state == TaskState.TASK_STATE_SUBMITTED
251-
assert len(response.task.history) == 1
252-
assert response.task.history[0].role == Role.ROLE_USER
253-
assert response.task.history[0].parts[0].text == 'Run dummy agent!'
320+
assert_history_matches(
321+
response.task.history,
322+
[
323+
(Role.ROLE_USER, 'Run dummy agent!'),
324+
],
325+
)
254326

255327

256328
@pytest.mark.asyncio
@@ -270,28 +342,23 @@ async def test_end_to_end_send_message_streaming(transport_setups):
270342
expected_events = [
271343
('status_update', TaskState.TASK_STATE_SUBMITTED),
272344
('status_update', TaskState.TASK_STATE_WORKING),
273-
('artifact_update', None),
345+
('artifact_update', [('test-artifact', 'artifact content')]),
274346
('status_update', TaskState.TASK_STATE_COMPLETED),
275347
]
276348

277-
assert len(events) == len(expected_events)
278-
for (event, task), (expected_type, expected_state) in zip(
279-
events, expected_events, strict=True
280-
):
281-
assert event.HasField(expected_type)
282-
if expected_type == 'status_update':
283-
assert event.status_update.status.state == expected_state
284-
elif expected_type == 'artifact_update':
285-
assert event.artifact_update.artifact.name == 'test-artifact'
286-
assert (
287-
event.artifact_update.artifact.parts[0].text
288-
== 'artifact content'
289-
)
349+
assert_events_match(events, expected_events)
290350

291-
last_task = events[-1][1]
292-
assert len(last_task.history) == 1
293-
assert last_task.history[0].role == Role.ROLE_AGENT
294-
assert last_task.history[0].parts[0].text == 'done'
351+
task = await client.get_task(request=GetTaskRequest(id=events[0][1].id))
352+
assert_history_matches(
353+
task.history,
354+
[
355+
(Role.ROLE_USER, 'Run dummy agent!'),
356+
(Role.ROLE_AGENT, 'task submitted'),
357+
(Role.ROLE_AGENT, 'task working'),
358+
],
359+
)
360+
assert task.status.state == TaskState.TASK_STATE_COMPLETED
361+
assert task.status.message.parts[0].text == 'done'
295362

296363

297364
@pytest.mark.asyncio
@@ -318,9 +385,14 @@ async def test_end_to_end_get_task(transport_setups):
318385
TaskState.TASK_STATE_WORKING,
319386
TaskState.TASK_STATE_COMPLETED,
320387
}
321-
assert len(retrieved_task.history) == 1
322-
assert retrieved_task.history[0].role == Role.ROLE_USER
323-
assert retrieved_task.history[0].parts[0].text == 'Test Get Task'
388+
assert_history_matches(
389+
retrieved_task.history,
390+
[
391+
(Role.ROLE_USER, 'Test Get Task'),
392+
(Role.ROLE_AGENT, 'task submitted'),
393+
(Role.ROLE_AGENT, 'task working'),
394+
],
395+
)
324396

325397

326398
@pytest.mark.asyncio
@@ -361,11 +433,90 @@ async def test_end_to_end_list_tasks(transport_setups):
361433
actual_task_ids.extend([task.id for task in list_response.tasks])
362434

363435
for task in list_response.tasks:
364-
assert len(task.history) == 1
436+
assert len(task.history) >= 1
365437
assert task.history[0].role == Role.ROLE_USER
366438
assert task.history[0].parts[0].text.startswith('Test List Tasks ')
367439

368440
token = list_response.next_page_token
369441

370442
assert len(actual_task_ids) == total_items
371443
assert sorted(actual_task_ids) == sorted(expected_task_ids)
444+
445+
446+
@pytest.mark.asyncio
447+
async def test_end_to_end_input_required(transport_setups):
448+
client = transport_setups.client
449+
450+
message_to_send = Message(
451+
role=Role.ROLE_USER,
452+
message_id='msg-e2e-input-req-1',
453+
parts=[Part(text='Need input')],
454+
)
455+
456+
events = [
457+
event async for event in client.send_message(request=message_to_send)
458+
]
459+
460+
expected_first_events = [
461+
('status_update', TaskState.TASK_STATE_SUBMITTED),
462+
('status_update', TaskState.TASK_STATE_WORKING),
463+
('status_update', TaskState.TASK_STATE_INPUT_REQUIRED),
464+
]
465+
466+
assert_events_match(events, expected_first_events)
467+
468+
task = await client.get_task(request=GetTaskRequest(id=events[0][1].id))
469+
470+
assert task.status.state == TaskState.TASK_STATE_INPUT_REQUIRED
471+
assert_history_matches(
472+
task.history,
473+
[
474+
(Role.ROLE_USER, 'Need input'),
475+
(Role.ROLE_AGENT, 'task submitted'),
476+
(Role.ROLE_AGENT, 'task working'),
477+
],
478+
)
479+
assert task.status.message.role == Role.ROLE_AGENT
480+
assert task.status.message.parts[0].text == 'Please provide input'
481+
482+
# Follow-up message
483+
follow_up_message = Message(
484+
task_id=task.id,
485+
role=Role.ROLE_USER,
486+
message_id='msg-e2e-input-req-2',
487+
parts=[Part(text='Here is the input')],
488+
)
489+
490+
follow_up_events = [
491+
event async for event in client.send_message(request=follow_up_message)
492+
]
493+
494+
expected_second_events = [
495+
('status_update', TaskState.TASK_STATE_WORKING),
496+
('artifact_update', [('test-artifact', 'artifact content')]),
497+
('status_update', TaskState.TASK_STATE_COMPLETED),
498+
]
499+
500+
assert_events_match(follow_up_events, expected_second_events)
501+
502+
task = await client.get_task(request=GetTaskRequest(id=task.id))
503+
504+
assert task.status.state == TaskState.TASK_STATE_COMPLETED
505+
assert_artifacts_match(
506+
task.artifacts,
507+
[('test-artifact', 'artifact content')],
508+
)
509+
510+
assert_history_matches(
511+
task.history,
512+
[
513+
(Role.ROLE_USER, 'Need input'),
514+
(Role.ROLE_AGENT, 'task submitted'),
515+
(Role.ROLE_AGENT, 'task working'),
516+
(Role.ROLE_AGENT, 'Please provide input'),
517+
(Role.ROLE_USER, 'Here is the input'),
518+
(Role.ROLE_AGENT, 'task working'),
519+
],
520+
)
521+
assert task.status.message.role == Role.ROLE_AGENT
522+
assert task.status.message.parts[0].text == 'done'

0 commit comments

Comments
 (0)