Skip to content

Commit 9084d61

Browse files
committed
Cosmetics
1 parent 35ab621 commit 9084d61

1 file changed

Lines changed: 50 additions & 41 deletions

File tree

samples/cli.py

Lines changed: 50 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,58 @@
11
import argparse
22
import asyncio
33
import contextlib
4+
import os
5+
import signal
46
import uuid
57

8+
from typing import Any
9+
10+
import grpc
611
import httpx
712

813
from a2a.client import A2ACardResolver, ClientConfig, ClientFactory
914
from a2a.types import Message, Part, Role, SendMessageRequest, TaskState
1015

1116

17+
async def _handle_stream(
18+
stream: Any, current_task_id: str | None
19+
) -> str | None:
20+
async for event, task in stream:
21+
if not task:
22+
continue
23+
if not current_task_id:
24+
current_task_id = task.id
25+
26+
if event:
27+
if event.HasField('status_update'):
28+
state_name = TaskState.Name(event.status_update.status.state)
29+
print(f'TaskStatusUpdate [state={state_name}]:', end=' ')
30+
if event.status_update.status.HasField('message'):
31+
for part in event.status_update.status.message.parts:
32+
if part.text:
33+
print(part.text, end=' ')
34+
print()
35+
36+
if (
37+
event.status_update.status.state
38+
== TaskState.TASK_STATE_COMPLETED
39+
):
40+
current_task_id = None
41+
print('--- Task Completed ---')
42+
43+
elif event.HasField('artifact_update'):
44+
print(
45+
f'TaskArtifactUpdate [name={event.artifact_update.artifact.name}]:',
46+
end=' ',
47+
)
48+
for part in event.artifact_update.artifact.parts:
49+
if part.text:
50+
print(part.text, end=' ')
51+
print()
52+
53+
return current_task_id
54+
55+
1256
async def main() -> None:
1357
"""Run the A2A terminal client."""
1458
parser = argparse.ArgumentParser(description='A2A Terminal Client')
@@ -48,7 +92,8 @@ async def main() -> None:
4892

4993
while True:
5094
try:
51-
user_input = input('You: ')
95+
loop = asyncio.get_running_loop()
96+
user_input = await loop.run_in_executor(None, input, 'You: ')
5297
except KeyboardInterrupt:
5398
break
5499

@@ -69,49 +114,13 @@ async def main() -> None:
69114

70115
try:
71116
stream = client.send_message(request)
72-
async for event, task in stream:
73-
if not task:
74-
continue
75-
if not current_task_id:
76-
current_task_id = task.id
77-
78-
if event:
79-
if event.HasField('status_update'):
80-
state_name = TaskState.Name(
81-
event.status_update.status.state
82-
)
83-
print(f'TaskStatusUpdate [{state_name}]:', end=' ')
84-
if event.status_update.status.HasField('message'):
85-
for (
86-
part
87-
) in event.status_update.status.message.parts:
88-
if part.text:
89-
print(part.text, end=' ')
90-
print()
91-
92-
if (
93-
event.status_update.status.state
94-
== TaskState.TASK_STATE_COMPLETED
95-
):
96-
current_task_id = None
97-
print('--- Task Completed ---')
98-
99-
elif event.HasField('artifact_update'):
100-
print(
101-
f'TaskArtifactUpdate [{event.artifact_update.artifact.name}]:',
102-
end=' ',
103-
)
104-
for part in event.artifact_update.artifact.parts:
105-
if part.text:
106-
print(part.text, end=' ')
107-
print()
108-
109-
except Exception as e:
117+
current_task_id = await _handle_stream(stream, current_task_id)
118+
except (httpx.RequestError, grpc.RpcError) as e:
110119
print(f'Error communicating with agent: {e}')
111120

112121
await client.close()
113122

114123

115124
if __name__ == '__main__':
116-
with contextlib.suppress(KeyboardInterrupt, asyncio.CancelledError):
117-
asyncio.run(main())
125+
signal.signal(signal.SIGINT, lambda sig, frame: os._exit(0))
126+
asyncio.run(main())

0 commit comments

Comments
 (0)