-
Notifications
You must be signed in to change notification settings - Fork 429
Expand file tree
/
Copy pathserver_1_0.py
More file actions
231 lines (200 loc) · 7.77 KB
/
server_1_0.py
File metadata and controls
231 lines (200 loc) · 7.77 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
import argparse
import uvicorn
from fastapi import FastAPI
import asyncio
import grpc
from a2a.server.agent_execution import AgentExecutor, RequestContext
from a2a.server.routes import create_agent_card_routes, create_jsonrpc_routes
from a2a.server.routes.rest_routes import create_rest_routes
from a2a.server.events import EventQueue
from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager
from a2a.server.request_handlers import DefaultRequestHandler, GrpcHandler
from a2a.server.tasks import TaskUpdater
from a2a.server.tasks.inmemory_push_notification_config_store import (
InMemoryPushNotificationConfigStore,
)
from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore
from a2a.types.a2a_pb2 import (
AgentCapabilities,
AgentCard,
AgentInterface,
Part,
TaskState,
)
from a2a.types import a2a_pb2_grpc
from a2a.compat.v0_3 import a2a_v0_3_pb2_grpc
from a2a.compat.v0_3.grpc_handler import CompatGrpcHandler
from a2a.utils import TransportProtocol
from server_common import CustomLoggingMiddleware
from google.protobuf.struct_pb2 import Struct, Value
from a2a.utils.task import new_task
class MockAgentExecutor(AgentExecutor):
    """Scripted agent executor used by this test server.

    ``execute`` checks that the incoming message has the shape the test
    client is expected to send, then completes the task. If the text of the
    message's first part contains ``'stream'``, it first keeps the task
    WORKING and emits periodic status/artifact events until ``cancel`` is
    called for that task.
    """

    def __init__(self):
        # task_id -> asyncio.Event. cancel() sets the event to release a
        # streaming execute() call.
        self.events = {}

    async def execute(self, context: RequestContext, event_queue: EventQueue):
        """Validate the client message and publish events for the task.

        Args:
            context: Request context carrying the task id, context id and the
                client message.
            event_queue: Queue that task, status and artifact events are
                published to.

        Raises:
            ValueError: If the message metadata ``test_key`` is neither
                ``'full_message'`` nor ``'simple_message'``.
            AssertionError: If a raw part is not ``b'hello'``, or a
                ``'full_message'`` request does not carry the expected parts.
        """
        print(f'SERVER: execute called for task {context.task_id}')
        # Publish the initial task snapshot, already in the WORKING state.
        task = new_task(context.message)
        task.id = context.task_id
        task.context_id = context.context_id
        task.status.state = TaskState.TASK_STATE_WORKING
        await event_queue.enqueue_event(task)

        task_updater = TaskUpdater(
            event_queue,
            context.task_id,
            context.context_id,
        )
        await task_updater.update_status(TaskState.TASK_STATE_WORKING)

        text = ''
        if context.message and context.message.parts:
            text = context.message.parts[0].text

        metadata = (
            dict(context.message.metadata)
            if context.message and context.message.metadata
            else {}
        )
        # The client tags every message so the server knows which payload
        # shape to verify. A missing message yields {} and fails here, so
        # context.message is non-None past this check.
        if metadata.get('test_key') not in ('full_message', 'simple_message'):
            print(f'SERVER: WARNING: Missing or incorrect metadata: {metadata}')
            raise ValueError(
                f'Missing expected metadata from client. Got: {metadata}'
            )

        for part in context.message.parts:
            if part.HasField('raw'):
                assert part.raw == b'hello'

        if metadata.get('test_key') == 'full_message':
            s = Struct()
            s.update({'key': 'value'})
            expected_parts = [
                Part(text='stream'),
                Part(
                    url='https://example.com/file.txt', media_type='text/plain'
                ),
                Part(raw=b'hello', media_type='application/octet-stream'),
                Part(data=Value(struct_value=s)),
            ]
            assert context.message.parts == expected_parts

        if 'stream' in text:
            print(f'SERVER: waiting on stream event for task {context.task_id}')
            event = asyncio.Event()
            self.events[context.task_id] = event

            async def emit_periodic():
                # Emit a status update and an artifact chunk roughly every
                # 100 ms until the event is set or this task is cancelled.
                try:
                    while not event.is_set():
                        await task_updater.update_status(
                            TaskState.TASK_STATE_WORKING,
                            message=task_updater.new_agent_message(
                                [Part(text='ping')]
                            ),
                        )
                        await task_updater.add_artifact(
                            [Part(text='artifact-chunk')],
                            name='test-artifact',
                            metadata={'artifact_key': 'artifact_value'},
                        )
                        await asyncio.sleep(0.1)
                except asyncio.CancelledError:
                    pass

            bg_task = asyncio.create_task(emit_periodic())
            try:
                await event.wait()
            finally:
                # Always stop the emitter, including when this coroutine is
                # itself cancelled while waiting. Otherwise the emitter would
                # keep enqueueing events for the task indefinitely.
                bg_task.cancel()
                # Wait for the emitter to unwind before anything else is
                # published. asyncio.wait does not re-raise the task's
                # outcome, so emitter errors are not surfaced here.
                await asyncio.wait({bg_task})
            print(f'SERVER: stream event triggered for task {context.task_id}')

        await task_updater.update_status(
            TaskState.TASK_STATE_COMPLETED,
            message=task_updater.new_agent_message(
                [Part(text='done')], metadata={'response_key': 'response_value'}
            ),
        )
        print(f'SERVER: execute finished for task {context.task_id}')

    async def cancel(self, context: RequestContext, event_queue: EventQueue):
        """Release the streaming execute() for the task and mark it CANCELED.

        Args:
            context: Request context identifying the task to cancel.
            event_queue: Queue the CANCELED status update is published to.

        Raises:
            AssertionError: If no streaming execute() registered an event for
                ``context.task_id``.
        """
        print(f'SERVER: cancel called for task {context.task_id}')
        assert context.task_id in self.events
        self.events[context.task_id].set()
        task_updater = TaskUpdater(
            event_queue,
            context.task_id,
            context.context_id,
        )
        await task_updater.update_status(TaskState.TASK_STATE_CANCELED)
async def main_async(http_port: int, grpc_port: int):
    """Serve the mock agent over JSON-RPC, REST and gRPC until uvicorn exits.

    Args:
        http_port: Port on 127.0.0.1 for the FastAPI app. JSON-RPC is mounted
            under ``/jsonrpc`` and REST under ``/rest``; both mounts also
            serve ``/.well-known/agent-card.json``.
        grpc_port: Port on 127.0.0.1 for the gRPC server. It registers both
            the current A2A service and the v0.3 compatibility service.
    """
    agent_card = AgentCard(
        name='Server 1.0',
        description='Server running on a2a v1.0',
        version='1.0.0',
        skills=[],
        capabilities=AgentCapabilities(streaming=True, push_notifications=True),
        default_input_modes=['text/plain'],
        default_output_modes=['text/plain'],
        supported_interfaces=[
            AgentInterface(
                protocol_binding=TransportProtocol.JSONRPC,
                url=f'http://127.0.0.1:{http_port}/jsonrpc/',
            ),
            # The same REST endpoint is advertised for both protocol versions.
            AgentInterface(
                protocol_binding=TransportProtocol.HTTP_JSON,
                url=f'http://127.0.0.1:{http_port}/rest/',
                protocol_version='1.0',
            ),
            AgentInterface(
                protocol_binding=TransportProtocol.HTTP_JSON,
                url=f'http://127.0.0.1:{http_port}/rest/',
                protocol_version='0.3',
            ),
            AgentInterface(
                protocol_binding=TransportProtocol.GRPC,
                url=f'127.0.0.1:{grpc_port}',
            ),
        ],
    )

    # One request handler (and in-memory stores) is shared by every
    # transport, so all of them see the same tasks.
    task_store = InMemoryTaskStore()
    handler = DefaultRequestHandler(
        MockAgentExecutor(),
        task_store,
        agent_card,
        queue_manager=InMemoryQueueManager(),
        push_config_store=InMemoryPushNotificationConfigStore(),
        extended_agent_card=agent_card,
    )

    app = FastAPI()
    app.add_middleware(CustomLoggingMiddleware)
    agent_card_routes = create_agent_card_routes(
        agent_card=agent_card, card_url='/.well-known/agent-card.json'
    )
    jsonrpc_routes = create_jsonrpc_routes(
        request_handler=handler,
        rpc_url='/',
        enable_v0_3_compat=True,
    )
    app.mount(
        '/jsonrpc',
        FastAPI(routes=jsonrpc_routes + agent_card_routes),
    )
    rest_routes = create_rest_routes(
        request_handler=handler,
        enable_v0_3_compat=True,
    )
    app.mount(
        '/rest',
        FastAPI(routes=rest_routes + agent_card_routes),
    )

    # Start gRPC Server
    server = grpc.aio.server()
    servicer = GrpcHandler(handler)
    a2a_pb2_grpc.add_A2AServiceServicer_to_server(servicer, server)
    compat_servicer = CompatGrpcHandler(handler)
    a2a_v0_3_pb2_grpc.add_A2AServiceServicer_to_server(compat_servicer, server)
    server.add_insecure_port(f'127.0.0.1:{grpc_port}')
    await server.start()

    try:
        # Start Uvicorn; serve() runs until the HTTP server shuts down.
        config = uvicorn.Config(
            app, host='127.0.0.1', port=http_port, log_level='info', access_log=True
        )
        uvicorn_server = uvicorn.Server(config)
        await uvicorn_server.serve()
    finally:
        # Stop the gRPC server together with the HTTP server instead of
        # leaving it running until the event loop is torn down. A grace of
        # None aborts in-flight RPCs immediately.
        await server.stop(None)
def main():
    """Command-line entry point: read both required ports and run the server.

    Expects ``--http-port`` and ``--grpc-port`` (integers). Blocks until
    ``main_async`` returns.
    """
    print('Starting server_1_0...')
    arg_parser = argparse.ArgumentParser()
    for flag in ('--http-port', '--grpc-port'):
        arg_parser.add_argument(flag, type=int, required=True)
    parsed = arg_parser.parse_args()
    server_coro = main_async(
        http_port=parsed.http_port,
        grpc_port=parsed.grpc_port,
    )
    asyncio.run(server_coro)


# Only start the server when run as a script, not when imported.
if __name__ == '__main__':
    main()