11import argparse
22import asyncio
33import contextlib
4+ import os
5+ import signal
46import uuid
57
8+ from typing import Any
9+
10+ import grpc
611import httpx
712
813from a2a .client import A2ACardResolver , ClientConfig , ClientFactory
914from a2a .types import Message , Part , Role , SendMessageRequest , TaskState
1015
1116
17+ async def _handle_stream (
18+ stream : Any , current_task_id : str | None
19+ ) -> str | None :
20+ async for event , task in stream :
21+ if not task :
22+ continue
23+ if not current_task_id :
24+ current_task_id = task .id
25+
26+ if event :
27+ if event .HasField ('status_update' ):
28+ state_name = TaskState .Name (event .status_update .status .state )
29+ print (f'TaskStatusUpdate [state={ state_name } ]:' , end = ' ' )
30+ if event .status_update .status .HasField ('message' ):
31+ for part in event .status_update .status .message .parts :
32+ if part .text :
33+ print (part .text , end = ' ' )
34+ print ()
35+
36+ if (
37+ event .status_update .status .state
38+ == TaskState .TASK_STATE_COMPLETED
39+ ):
40+ current_task_id = None
41+ print ('--- Task Completed ---' )
42+
43+ elif event .HasField ('artifact_update' ):
44+ print (
45+ f'TaskArtifactUpdate [name={ event .artifact_update .artifact .name } ]:' ,
46+ end = ' ' ,
47+ )
48+ for part in event .artifact_update .artifact .parts :
49+ if part .text :
50+ print (part .text , end = ' ' )
51+ print ()
52+
53+ return current_task_id
54+
55+
1256async def main () -> None :
1357 """Run the A2A terminal client."""
1458 parser = argparse .ArgumentParser (description = 'A2A Terminal Client' )
@@ -48,7 +92,8 @@ async def main() -> None:
4892
4993 while True :
5094 try :
51- user_input = input ('You: ' )
95+ loop = asyncio .get_running_loop ()
96+ user_input = await loop .run_in_executor (None , input , 'You: ' )
5297 except KeyboardInterrupt :
5398 break
5499
@@ -69,49 +114,13 @@ async def main() -> None:
69114
70115 try :
71116 stream = client .send_message (request )
72- async for event , task in stream :
73- if not task :
74- continue
75- if not current_task_id :
76- current_task_id = task .id
77-
78- if event :
79- if event .HasField ('status_update' ):
80- state_name = TaskState .Name (
81- event .status_update .status .state
82- )
83- print (f'TaskStatusUpdate [{ state_name } ]:' , end = ' ' )
84- if event .status_update .status .HasField ('message' ):
85- for (
86- part
87- ) in event .status_update .status .message .parts :
88- if part .text :
89- print (part .text , end = ' ' )
90- print ()
91-
92- if (
93- event .status_update .status .state
94- == TaskState .TASK_STATE_COMPLETED
95- ):
96- current_task_id = None
97- print ('--- Task Completed ---' )
98-
99- elif event .HasField ('artifact_update' ):
100- print (
101- f'TaskArtifactUpdate [{ event .artifact_update .artifact .name } ]:' ,
102- end = ' ' ,
103- )
104- for part in event .artifact_update .artifact .parts :
105- if part .text :
106- print (part .text , end = ' ' )
107- print ()
108-
109- except Exception as e :
117+ current_task_id = await _handle_stream (stream , current_task_id )
118+ except (httpx .RequestError , grpc .RpcError ) as e :
110119 print (f'Error communicating with agent: { e } ' )
111120
112121 await client .close ()
113122
114123
115124if __name__ == '__main__' :
116- with contextlib . suppress ( KeyboardInterrupt , asyncio . CancelledError ):
117- asyncio .run (main ())
125+ signal . signal ( signal . SIGINT , lambda sig , frame : os . _exit ( 0 ))
126+ asyncio .run (main ())
0 commit comments