Skip to content

Commit 35310ac

Browse files
committed
enforce first streaming event to be Task
1 parent b23c112 commit 35310ac

4 files changed

Lines changed: 183 additions & 59 deletions

File tree

samples/cli.py

Lines changed: 42 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import os
44
import signal
55
import uuid
6+
import warnings
67

78
from typing import Any
89

@@ -18,35 +19,50 @@ async def _handle_stream(
1819
stream: Any, current_task_id: str | None
1920
) -> str | None:
2021
async for event in stream:
21-
if not current_task_id:
22-
current_task_id = event.task.id
23-
if event:
24-
if event.HasField('status_update'):
25-
state_name = TaskState.Name(event.status_update.status.state)
26-
print(f'TaskStatusUpdate [state={state_name}]:', end=' ')
27-
if event.status_update.status.HasField('message'):
28-
message = event.status_update.status.message
29-
print(get_message_text(message, delimiter=' '))
30-
print()
22+
if event.HasField('message'):
23+
print(f'Message: {get_message_text(event.message, delimiter=" ")}')
24+
continue
3125

32-
if (
33-
event.status_update.status.state
34-
== TaskState.TASK_STATE_COMPLETED
35-
):
36-
current_task_id = None
37-
print('--- Task Completed ---')
38-
39-
elif event.HasField('artifact_update'):
40-
print(
41-
f'TaskArtifactUpdate [name={event.artifact_update.artifact.name}]:',
42-
end=' ',
43-
)
44-
print(
45-
get_artifact_text(
46-
event.artifact_update.artifact, delimiter=' '
47-
)
26+
if not current_task_id:
27+
if event.HasField('task'):
28+
# V2 handler emits Task(SUBMITTED) first per A2A spec §3.1.2.
29+
current_task_id = event.task.id
30+
state_name = TaskState.Name(event.task.status.state)
31+
print(f'Task [state={state_name}]')
32+
elif event.HasField('status_update'):
33+
# Legacy handler streams status updates directly without a
34+
# leading Task event; extract the task ID from the update.
35+
current_task_id = event.status_update.task_id
36+
else:
37+
warnings.warn(
38+
'Unexpected first streaming event type. '
39+
'Cannot determine task ID.',
40+
stacklevel=2,
4841
)
42+
43+
if event.HasField('status_update'):
44+
state_name = TaskState.Name(event.status_update.status.state)
45+
print(f'TaskStatusUpdate [state={state_name}]:', end=' ')
46+
if event.status_update.status.HasField('message'):
47+
message = event.status_update.status.message
48+
print(get_message_text(message, delimiter=' '))
49+
else:
4950
print()
51+
if (
52+
event.status_update.status.state
53+
== TaskState.TASK_STATE_COMPLETED
54+
):
55+
current_task_id = None
56+
print('--- Task Completed ---')
57+
58+
elif event.HasField('artifact_update'):
59+
print(
60+
f'TaskArtifactUpdate [name={event.artifact_update.artifact.name}]:',
61+
end=' ',
62+
)
63+
print(
64+
get_artifact_text(event.artifact_update.artifact, delimiter=' ')
65+
)
5066

5167
return current_task_id
5268

src/a2a/server/request_handlers/default_request_handler_v2.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@
3737
SubscribeToTaskRequest,
3838
Task,
3939
TaskPushNotificationConfig,
40+
TaskState,
41+
TaskStatus,
4042
TaskStatusUpdateEvent,
4143
)
4244
from a2a.utils.errors import (
@@ -302,16 +304,35 @@ async def on_message_send_stream( # noqa: D102
302304
params: SendMessageRequest,
303305
context: ServerCallContext,
304306
) -> AsyncGenerator[Event, None]:
307+
is_new_task = not params.message.task_id
308+
305309
active_task, request_context = await self._setup_active_task(
306310
params, context
307311
)
308-
309312
task_id = cast('str', request_context.task_id)
313+
context_id = cast('str', request_context.context_id)
314+
first_event = True
310315

311316
async for event in active_task.subscribe(
312317
request=request_context,
313318
include_initial_task=False,
314319
):
320+
if (
321+
first_event
322+
and is_new_task
323+
and not isinstance(event, (Task, Message))
324+
):
325+
# Agent didn't emit a Task/Message first.
326+
# The stream MUST begin with a Task or Message.
327+
submitted_task = Task(
328+
id=task_id,
329+
context_id=context_id,
330+
status=TaskStatus(state=TaskState.TASK_STATE_SUBMITTED),
331+
history=[params.message],
332+
)
333+
yield apply_history_length(submitted_task, params.configuration)
334+
first_event = False
335+
315336
if isinstance(event, Task):
316337
self._validate_task_id_match(task_id, event.id)
317338
yield apply_history_length(event, params.configuration)

tests/integration/test_scenarios.py

Lines changed: 117 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -261,10 +261,21 @@ async def cancel(
261261
event
262262
async for event in client.send_message(SendMessageRequest(message=msg))
263263
]
264-
assert [event.status_update.status.state for event in events] == [
265-
TaskState.TASK_STATE_WORKING,
266-
TaskState.TASK_STATE_COMPLETED,
267-
]
264+
if use_legacy:
265+
# Legacy handler streams events as-is (no Task(SUBMITTED) injection).
266+
assert [event.status_update.status.state for event in events] == [
267+
TaskState.TASK_STATE_WORKING,
268+
TaskState.TASK_STATE_COMPLETED,
269+
]
270+
else:
271+
# V2 handler injects Task(SUBMITTED) first per A2A spec §3.1.2.
272+
assert events[0].HasField('task'), (
273+
'First streaming event must be a Task or Message'
274+
)
275+
assert [event.status_update.status.state for event in events[1:]] == [
276+
TaskState.TASK_STATE_WORKING,
277+
TaskState.TASK_STATE_COMPLETED,
278+
]
268279

269280

270281
# Scenario 5: Re-subscribing to a finished task
@@ -374,15 +385,32 @@ async def cancel(
374385
configuration=SendMessageConfiguration(return_immediately=False),
375386
)
376387
)
377-
(event,) = [event async for event in it]
388+
events = [event async for event in it]
378389

379390
if streaming:
380-
assert event.HasField('status_update')
381-
task_id = event.status_update.task_id
382-
assert (
383-
event.status_update.status.state == TaskState.TASK_STATE_COMPLETED
384-
)
391+
if use_legacy:
392+
# Legacy streams events as-is: just the status_update(COMPLETED).
393+
(event,) = events
394+
assert event.HasField('status_update')
395+
task_id = event.status_update.task_id
396+
assert (
397+
event.status_update.status.state
398+
== TaskState.TASK_STATE_COMPLETED
399+
)
400+
else:
401+
# V2 injects Task(SUBMITTED) first per A2A spec §3.1.2.
402+
assert len(events) == 2
403+
assert events[0].HasField('task'), (
404+
'First streaming event must be a Task or Message'
405+
)
406+
task_id = events[0].task.id
407+
assert events[1].HasField('status_update')
408+
assert (
409+
events[1].status_update.status.state
410+
== TaskState.TASK_STATE_COMPLETED
411+
)
385412
else:
413+
(event,) = events
386414
assert event.HasField('task')
387415
task_id = event.task.id
388416
assert event.task.status.state == TaskState.TASK_STATE_COMPLETED
@@ -498,8 +526,23 @@ async def cancel(
498526
tasks = []
499527

500528
if streaming:
501-
res = await it.__anext__()
502-
assert res.status_update.status.state == TaskState.TASK_STATE_WORKING
529+
if use_legacy:
530+
# Legacy streams events as-is; first event is the WORKING status update.
531+
first = await it.__anext__()
532+
assert (
533+
first.status_update.status.state == TaskState.TASK_STATE_WORKING
534+
)
535+
else:
536+
# V2 injects Task(SUBMITTED) first per A2A spec §3.1.2.
537+
first = await it.__anext__()
538+
assert first.HasField('task'), (
539+
'First streaming event must be a Task or Message'
540+
)
541+
second = await it.__anext__()
542+
assert (
543+
second.status_update.status.state
544+
== TaskState.TASK_STATE_WORKING
545+
)
503546
continue_event.set()
504547
else:
505548

@@ -1082,10 +1125,19 @@ async def cancel(
10821125
states = [get_state(event) async for event in it]
10831126

10841127
if streaming:
1085-
assert states == [
1086-
TaskState.TASK_STATE_WORKING,
1087-
TaskState.TASK_STATE_COMPLETED,
1088-
]
1128+
if use_legacy:
1129+
# Legacy streams events as-is (no Task(SUBMITTED) injection).
1130+
assert states == [
1131+
TaskState.TASK_STATE_WORKING,
1132+
TaskState.TASK_STATE_COMPLETED,
1133+
]
1134+
else:
1135+
# V2 injects Task(SUBMITTED) first per A2A spec §3.1.2.
1136+
assert states == [
1137+
TaskState.TASK_STATE_SUBMITTED,
1138+
TaskState.TASK_STATE_WORKING,
1139+
TaskState.TASK_STATE_COMPLETED,
1140+
]
10891141
else:
10901142
assert states == [TaskState.TASK_STATE_WORKING]
10911143

@@ -1151,11 +1203,27 @@ async def cancel(
11511203
)
11521204

11531205
events1 = [event async for event in it]
1154-
assert [get_state(event) for event in events1] == [
1155-
TaskState.TASK_STATE_INPUT_REQUIRED,
1156-
]
1157-
task_id = events1[0].status_update.task_id
1158-
context_id = events1[0].status_update.context_id
1206+
if streaming and not use_legacy:
1207+
# V2 injects Task(SUBMITTED) first per A2A spec §3.1.2.
1208+
assert [get_state(event) for event in events1] == [
1209+
TaskState.TASK_STATE_SUBMITTED,
1210+
TaskState.TASK_STATE_INPUT_REQUIRED,
1211+
]
1212+
task_id = events1[0].task.id
1213+
context_id = events1[0].task.context_id
1214+
elif streaming and use_legacy:
1215+
# Legacy streams events as-is; first event is the INPUT_REQUIRED status update.
1216+
assert [get_state(event) for event in events1] == [
1217+
TaskState.TASK_STATE_INPUT_REQUIRED,
1218+
]
1219+
task_id = events1[0].status_update.task_id
1220+
context_id = events1[0].status_update.context_id
1221+
else:
1222+
assert [get_state(event) for event in events1] == [
1223+
TaskState.TASK_STATE_INPUT_REQUIRED,
1224+
]
1225+
task_id = events1[0].task.id
1226+
context_id = events1[0].task.context_id
11591227

11601228
# Now send another message to resume
11611229
msg2 = Message(
@@ -1240,19 +1308,38 @@ async def cancel(
12401308
)
12411309

12421310
if streaming:
1243-
event1 = await asyncio.wait_for(it.__anext__(), timeout=1.0)
1244-
assert get_state(event1) == TaskState.TASK_STATE_WORKING
1311+
if use_legacy:
1312+
# Legacy streams events as-is: WORKING → AUTH_REQUIRED → COMPLETED.
1313+
event1 = await asyncio.wait_for(it.__anext__(), timeout=1.0)
1314+
assert get_state(event1) == TaskState.TASK_STATE_WORKING
12451315

1246-
event2 = await asyncio.wait_for(it.__anext__(), timeout=1.0)
1247-
assert get_state(event2) == TaskState.TASK_STATE_AUTH_REQUIRED
1316+
event2 = await asyncio.wait_for(it.__anext__(), timeout=1.0)
1317+
assert get_state(event2) == TaskState.TASK_STATE_AUTH_REQUIRED
12481318

1249-
task_id = event2.status_update.task_id
1319+
task_id = event2.status_update.task_id
12501320

1251-
side_channel_event.set()
1321+
side_channel_event.set()
1322+
1323+
(event3,) = [event async for event in it]
1324+
assert get_state(event3) == TaskState.TASK_STATE_COMPLETED
1325+
else:
1326+
# V2 injects Task(SUBMITTED) first per A2A spec §3.1.2.
1327+
# Full sequence: Task(SUBMITTED) → WORKING → AUTH_REQUIRED → COMPLETED.
1328+
event1 = await asyncio.wait_for(it.__anext__(), timeout=1.0)
1329+
assert get_state(event1) == TaskState.TASK_STATE_SUBMITTED
1330+
1331+
event2 = await asyncio.wait_for(it.__anext__(), timeout=1.0)
1332+
assert get_state(event2) == TaskState.TASK_STATE_WORKING
1333+
1334+
event3 = await asyncio.wait_for(it.__anext__(), timeout=1.0)
1335+
assert get_state(event3) == TaskState.TASK_STATE_AUTH_REQUIRED
1336+
1337+
task_id = event3.status_update.task_id
1338+
1339+
side_channel_event.set()
12521340

1253-
# Remaining event.
1254-
(event3,) = [event async for event in it]
1255-
assert get_state(event3) == TaskState.TASK_STATE_COMPLETED
1341+
(event4,) = [event async for event in it]
1342+
assert get_state(event4) == TaskState.TASK_STATE_COMPLETED
12561343
else:
12571344
(event,) = [event async for event in it]
12581345
assert get_state(event) == TaskState.TASK_STATE_AUTH_REQUIRED

tests/server/request_handlers/test_default_request_handler_v2.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -560,14 +560,14 @@ async def consume_stream():
560560
message_params, create_server_call_context()
561561
):
562562
events.append(event)
563-
if len(events) >= 3:
563+
if len(events) >= 4:
564564
break
565565
return events
566566

567567
start = time.perf_counter()
568568
events = await consume_stream()
569569
elapsed = time.perf_counter() - start
570-
assert len(events) == 3
570+
assert len(events) == 4
571571
assert elapsed < 0.5
572572
texts = [p.text for e in events for p in e.status.message.parts]
573573
assert texts == ['Event 0', 'Event 1', 'Event 2']

0 commit comments

Comments
 (0)