|
13 | 13 | from a2a.types import Message, Part, Role, SendMessageRequest, TaskState |
14 | 14 |
|
15 | 15 |
|
16 | | -async def _handle_stream( |
| 16 | +async def _handle_stream( # noqa: PLR0912 |
17 | 17 | stream: Any, current_task_id: str | None |
18 | 18 | ) -> str | None: |
19 | | - async for event, task in stream: |
20 | | - if not task: |
21 | | - continue |
| 19 | + async for event in stream: |
| 20 | + if event.HasField('message'): |
| 21 | + print('Message:', end=' ') |
| 22 | + for part in event.message.parts: |
| 23 | + if part.text: |
| 24 | + print(part.text, end=' ') |
| 25 | + print() |
| 26 | + return None |
| 27 | + |
22 | 28 | if not current_task_id: |
23 | | - current_task_id = task.id |
24 | | - |
25 | | - if event: |
26 | | - if event.HasField('status_update'): |
27 | | - state_name = TaskState.Name(event.status_update.status.state) |
28 | | - print(f'TaskStatusUpdate [state={state_name}]:', end=' ') |
29 | | - if event.status_update.status.HasField('message'): |
30 | | - for part in event.status_update.status.message.parts: |
31 | | - if part.text: |
32 | | - print(part.text, end=' ') |
33 | | - print() |
34 | | - |
35 | | - if ( |
36 | | - event.status_update.status.state |
37 | | - == TaskState.TASK_STATE_COMPLETED |
38 | | - ): |
39 | | - current_task_id = None |
40 | | - print('--- Task Completed ---') |
41 | | - |
42 | | - elif event.HasField('artifact_update'): |
43 | | - print( |
44 | | - f'TaskArtifactUpdate [name={event.artifact_update.artifact.name}]:', |
45 | | - end=' ', |
46 | | - ) |
47 | | - for part in event.artifact_update.artifact.parts: |
| 29 | + if event.HasField('task'): |
| 30 | + current_task_id = event.task.id |
| 31 | + print('--- Task Started ---') |
| 32 | + print(f'Task [state={TaskState.Name(event.task.status.state)}]') |
| 33 | + else: |
| 34 | + raise ValueError(f'Unexpected first event: {event}') |
| 35 | + |
| 36 | + if event.HasField('status_update'): |
| 37 | + state_name = TaskState.Name(event.status_update.status.state) |
| 38 | + print(f'TaskStatusUpdate [state={state_name}]:', end=' ') |
| 39 | + if event.status_update.status.HasField('message'): |
| 40 | + for part in event.status_update.status.message.parts: |
48 | 41 | if part.text: |
49 | 42 | print(part.text, end=' ') |
50 | | - print() |
51 | | - |
| 43 | + print() |
| 44 | + if state_name in ( |
| 45 | + 'TASK_STATE_COMPLETED', |
| 46 | + 'TASK_STATE_FAILED', |
| 47 | + 'TASK_STATE_CANCELED', |
| 48 | + 'TASK_STATE_REJECTED', |
| 49 | + ): |
| 50 | + current_task_id = None |
| 51 | + print('--- Task Finished ---') |
| 52 | + elif event.HasField('artifact_update'): |
| 53 | + print( |
| 54 | + f'TaskArtifactUpdate [name={event.artifact_update.artifact.name}]:', |
| 55 | + end=' ', |
| 56 | + ) |
| 57 | + for part in event.artifact_update.artifact.parts: |
| 58 | + if part.text: |
| 59 | + print(part.text, end=' ') |
| 60 | + print() |
52 | 61 | return current_task_id |
53 | 62 |
|
54 | 63 |
|
|
0 commit comments