-
Notifications
You must be signed in to change notification settings - Fork 429
Expand file tree
/
Copy pathcli.py
More file actions
125 lines (99 loc) · 3.76 KB
/
cli.py
File metadata and controls
125 lines (99 loc) · 3.76 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import argparse
import asyncio
import os
import signal
import uuid
from typing import Any
import grpc
import httpx
from a2a.client import A2ACardResolver, ClientConfig, create_client
from a2a.types import Message, Part, Role, SendMessageRequest, TaskState
def _print_text_parts(parts: Any) -> None:
    """Print the non-empty ``text`` of each part on one line, space-separated.

    No newline is printed; the caller ends the line.
    """
    for part in parts:
        if part.text:
            print(part.text, end=' ')


async def _handle_stream(
    stream: Any, current_task_id: str | None
) -> str | None:
    """Print streamed task events and return the task id to continue with.

    Args:
        stream: Async iterable of ``(event, task)`` pairs, as returned by
            ``client.send_message`` in ``main``. Pairs whose ``task`` is falsy
            are skipped entirely.
        current_task_id: Id of the task the conversation is continuing, or
            ``None`` to adopt the first still-open task seen in the stream.

    Returns:
        The id of the still-open task the next message should be attached
        to, or ``None`` once that task has reached a terminal state (so the
        next message starts a new task).
    """
    # Names (as reported by TaskState.Name) of states after which a task can
    # no longer accept messages. Matching by name instead of attribute access
    # keeps a spelling difference (CANCELED vs CANCELLED) from raising here.
    terminal_state_names = frozenset({
        'TASK_STATE_COMPLETED',
        'TASK_STATE_FAILED',
        'TASK_STATE_CANCELED',
        'TASK_STATE_CANCELLED',
        'TASK_STATE_REJECTED',
    })
    async for event, task in stream:
        if not task:
            continue
        if TaskState.Name(task.status.state) in terminal_state_names:
            # A finished task must never be (re)adopted, and if it is the one
            # we were continuing it must be dropped; otherwise the next
            # message would target a task the server refuses to modify.
            if current_task_id == task.id:
                current_task_id = None
        elif not current_task_id:
            current_task_id = task.id
        if not event:
            continue
        if event.HasField('status_update'):
            status = event.status_update.status
            state_name = TaskState.Name(status.state)
            print(f'TaskStatusUpdate [state={state_name}]:', end=' ')
            if status.HasField('message'):
                _print_text_parts(status.message.parts)
            print()
            if status.state == TaskState.TASK_STATE_COMPLETED:
                current_task_id = None
                print('--- Task Completed ---')
            elif state_name in terminal_state_names:
                # Failed / canceled / rejected tasks are finished too.
                current_task_id = None
                print(f'--- Task Ended [state={state_name}] ---')
        elif event.HasField('artifact_update'):
            artifact = event.artifact_update.artifact
            print(f'TaskArtifactUpdate [name={artifact.name}]:', end=' ')
            _print_text_parts(artifact.parts)
            print()
    return current_task_id
async def main() -> None:
    """Run the A2A terminal client.

    Parses ``--url`` and ``--transport`` from the command line and fetches
    the agent card from ``--url``. It then creates a client for that card and
    sends each line read from stdin as a user message. The loop ends on
    ``/quit``, ``/exit`` or end of input (Ctrl-D / exhausted pipe).

    All messages share one randomly generated context id. A message is
    attached to the current task for as long as ``_handle_stream`` reports
    that task as still open. The client is closed on every exit path.
    """
    parser = argparse.ArgumentParser(description='A2A Terminal Client')
    parser.add_argument(
        '--url', default='http://127.0.0.1:41241', help='Agent base URL'
    )
    parser.add_argument(
        '--transport',
        default=None,
        help='Preferred transport (JSONRPC, HTTP+JSON, GRPC)',
    )
    args = parser.parse_args()

    config = ClientConfig()
    if args.transport:
        config.supported_protocol_bindings = [args.transport]

    print(
        f'Connecting to {args.url} (preferred transport: {args.transport or "Any"})'
    )
    async with httpx.AsyncClient() as httpx_client:
        resolver = A2ACardResolver(httpx_client, args.url)
        card = await resolver.get_agent_card()
        print('\n✓ Agent Card Found:')
        print(f' Name: {card.name}')
        client = await create_client(card, client_config=config)
        try:
            # Reads a private attribute purely for display; falls back to the
            # client itself if the attribute is absent.
            actual_transport = getattr(client, '_transport', client)
            print(f' Picked Transport: {actual_transport.__class__.__name__}')
            print('\nConnected! Send a message or type /quit to exit.')

            current_task_id: str | None = None
            current_context_id = str(uuid.uuid4())
            loop = asyncio.get_running_loop()
            while True:
                try:
                    # input() blocks, so run it in the default executor to
                    # keep the event loop free.
                    user_input = await loop.run_in_executor(
                        None, input, 'You: '
                    )
                except (EOFError, KeyboardInterrupt):
                    # EOFError: Ctrl-D or end of piped stdin -> exit cleanly.
                    break
                command = user_input.strip()
                if command.lower() in ('/quit', '/exit'):
                    break
                if not command:
                    continue
                message = Message(
                    role=Role.ROLE_USER,
                    message_id=str(uuid.uuid4()),
                    parts=[Part(text=user_input)],
                    task_id=current_task_id,
                    context_id=current_context_id,
                )
                request = SendMessageRequest(message=message)
                try:
                    stream = client.send_message(request)
                    current_task_id = await _handle_stream(
                        stream, current_task_id
                    )
                except (httpx.HTTPError, grpc.RpcError) as e:
                    # httpx.HTTPError covers both transport failures
                    # (RequestError) and non-2xx responses (HTTPStatusError).
                    print(f'Error communicating with agent: {e}')
        finally:
            await client.close()
if __name__ == '__main__':

    def _exit_now(_signum, _frame):
        # Ctrl-C ends the process immediately with status 0. os._exit skips
        # normal interpreter shutdown (no cleanup handlers, no stdio flush).
        # NOTE(review): presumably chosen so a worker thread blocked in
        # input() cannot delay exit; also means client.close() is not run.
        os._exit(0)

    signal.signal(signal.SIGINT, _exit_now)
    asyncio.run(main())