forked from a2aproject/a2a-python
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathcli.py
More file actions
134 lines (110 loc) · 4.18 KB
/
cli.py
File metadata and controls
134 lines (110 loc) · 4.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import argparse
import asyncio
import os
import signal
import uuid
from typing import Any
import grpc
import httpx
from a2a.client import A2ACardResolver, ClientConfig, create_client
from a2a.types import Message, Part, Role, SendMessageRequest, TaskState
async def _handle_stream( # noqa: PLR0912
stream: Any, current_task_id: str | None
) -> str | None:
async for event in stream:
if event.HasField('message'):
print('Message:', end=' ')
for part in event.message.parts:
if part.text:
print(part.text, end=' ')
print()
return None
if not current_task_id:
if event.HasField('task'):
current_task_id = event.task.id
print('--- Task Started ---')
print(f'Task [state={TaskState.Name(event.task.status.state)}]')
else:
raise ValueError(f'Unexpected first event: {event}')
if event.HasField('status_update'):
state_name = TaskState.Name(event.status_update.status.state)
print(f'TaskStatusUpdate [state={state_name}]:', end=' ')
if event.status_update.status.HasField('message'):
for part in event.status_update.status.message.parts:
if part.text:
print(part.text, end=' ')
print()
if state_name in (
'TASK_STATE_COMPLETED',
'TASK_STATE_FAILED',
'TASK_STATE_CANCELED',
'TASK_STATE_REJECTED',
):
current_task_id = None
print('--- Task Finished ---')
elif event.HasField('artifact_update'):
print(
f'TaskArtifactUpdate [name={event.artifact_update.artifact.name}]:',
end=' ',
)
for part in event.artifact_update.artifact.parts:
if part.text:
print(part.text, end=' ')
print()
return current_task_id
async def main() -> None:
    """Run the A2A terminal client.

    Parses ``--url`` and ``--transport``, resolves the agent card at the URL
    and creates a client for it. It then loops: read a line from stdin, send
    it as a user message and print the streamed reply.

    The session ends on ``/quit`` or ``/exit`` (case-insensitive, surrounding
    whitespace ignored), on end of input (EOF) or on ``KeyboardInterrupt``.
    The A2A client is closed on every exit path of the loop, including
    unexpected errors.
    """
    parser = argparse.ArgumentParser(description='A2A Terminal Client')
    parser.add_argument(
        '--url', default='http://127.0.0.1:41241', help='Agent base URL'
    )
    parser.add_argument(
        '--transport',
        default=None,
        help='Preferred transport (JSONRPC, HTTP+JSON, GRPC)',
    )
    args = parser.parse_args()

    config = ClientConfig()
    if args.transport:
        # Limit the client's protocol bindings to the one the user asked for.
        config.supported_protocol_bindings = [args.transport]

    print(
        f'Connecting to {args.url} (preferred transport: {args.transport or "Any"})'
    )
    async with httpx.AsyncClient() as httpx_client:
        resolver = A2ACardResolver(httpx_client, args.url)
        card = await resolver.get_agent_card()
        print('\n✓ Agent Card Found:')
        print(f' Name: {card.name}')

        client = await create_client(card, client_config=config)
        # Peeks at the private ``_transport`` attribute (falling back to the
        # client itself) purely to report which transport was picked.
        actual_transport = getattr(client, '_transport', client)
        print(f' Picked Transport: {actual_transport.__class__.__name__}')
        print('\nConnected! Send a message or type /quit to exit.')

        loop = asyncio.get_running_loop()
        current_task_id: str | None = None
        # One context id for the whole session groups all messages together.
        current_context_id = str(uuid.uuid4())
        try:
            while True:
                try:
                    # input() blocks, so run it off the event loop thread.
                    user_input = await loop.run_in_executor(
                        None, input, 'You: '
                    )
                except (KeyboardInterrupt, EOFError):
                    # EOF (Ctrl-D / closed stdin) ends the session cleanly
                    # instead of crashing with a traceback.
                    break

                command = user_input.strip()
                if command.lower() in ('/quit', '/exit'):
                    break
                if not command:
                    continue

                message = Message(
                    role=Role.ROLE_USER,
                    message_id=str(uuid.uuid4()),
                    parts=[Part(text=user_input)],
                    task_id=current_task_id,
                    context_id=current_context_id,
                )
                request = SendMessageRequest(message=message)
                try:
                    stream = client.send_message(request)
                    current_task_id = await _handle_stream(
                        stream, current_task_id
                    )
                except (httpx.RequestError, grpc.RpcError) as e:
                    print(f'Error communicating with agent: {e}')
        finally:
            # Runs on quit, EOF and on any unexpected exception.
            await client.close()
def _exit_immediately(signum: int, frame: object) -> None:
    """SIGINT handler: terminate the process at once with exit status 0.

    ``os._exit`` skips interpreter cleanup (``finally`` blocks, atexit
    handlers, stdio flushing). Any cleanup in ``main``, such as closing the
    A2A client, therefore does not run on Ctrl-C.

    NOTE(review): presumably deliberate, because ``input()`` blocks inside an
    executor thread that a normal shutdown would wait on. Confirm before
    replacing it with a graceful exit.
    """
    os._exit(0)


if __name__ == '__main__':
    signal.signal(signal.SIGINT, _exit_immediately)
    asyncio.run(main())