-
Notifications
You must be signed in to change notification settings - Fork 429
Expand file tree
/
Copy pathserver_0_3.py
More file actions
238 lines (208 loc) · 7.82 KB
/
server_0_3.py
File metadata and controls
238 lines (208 loc) · 7.82 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
import argparse
import uvicorn
from fastapi import FastAPI
import asyncio
import grpc
import sys
import time
from a2a.server.agent_execution.agent_executor import AgentExecutor
from a2a.server.agent_execution.context import RequestContext
from a2a.server.apps.jsonrpc.fastapi_app import A2AFastAPIApplication
from a2a.server.apps.rest.fastapi_app import A2ARESTFastAPIApplication
from a2a.server.events.event_queue import EventQueue
from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager
from a2a.server.request_handlers.default_request_handler import (
DefaultRequestHandler,
)
from a2a.server.request_handlers.grpc_handler import GrpcHandler
from a2a.server.tasks.task_updater import TaskUpdater
from a2a.server.tasks.inmemory_push_notification_config_store import (
InMemoryPushNotificationConfigStore,
)
from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore
from a2a.types import (
AgentCapabilities,
AgentCard,
AgentInterface,
Part,
TaskState,
TextPart,
FilePart,
TransportProtocol,
FileWithBytes,
FileWithUri,
DataPart,
)
from a2a.grpc import a2a_pb2_grpc
from starlette.requests import Request
from starlette.concurrency import iterate_in_threadpool
import time
from a2a.utils.task import new_task
from server_common import CustomLoggingMiddleware
class MockAgentExecutor(AgentExecutor):
    """Mock agent executor for the a2a v0.3 server.

    ``execute`` validates the incoming message against the metadata the client
    is expected to send, then either completes immediately or -- when the first
    text part contains ``'stream'`` -- keeps emitting ``ping`` status updates and
    artifact chunks until ``cancel`` is called for the same task id.
    """

    def __init__(self):
        # task_id -> asyncio.Event; ``cancel`` sets the event to release a
        # streaming ``execute`` call. Entries are never removed.
        self.events = {}

    async def execute(self, context: RequestContext, event_queue: EventQueue):
        """Run the mock agent for one request.

        Args:
            context: Request context carrying the task/context ids and the
                client message.
            event_queue: Queue that receives the task and its status/artifact
                updates.

        Raises:
            ValueError: If the message metadata does not have ``test_key`` set
                to ``'full_message'`` or ``'simple_message'``.
            AssertionError: If a ``'full_message'`` request does not carry
                exactly the expected parts.
        """
        print(f'SERVER: execute called for task {context.task_id}')
        task = new_task(context.message)
        task.id = context.task_id
        task.context_id = context.context_id
        task.status.state = TaskState.working
        await event_queue.enqueue_event(task)
        task_updater = TaskUpdater(
            event_queue,
            context.task_id,
            context.context_id,
        )
        await task_updater.update_status(TaskState.working)

        text = ''
        if context.message and context.message.parts:
            part = context.message.parts[0]
            # A part may expose its text via a wrapped ``root`` object or
            # directly as ``part.text``; accept either form.
            if hasattr(part, 'root') and hasattr(part.root, 'text'):
                text = part.root.text
            elif hasattr(part, 'text'):
                text = part.text

        metadata = (
            dict(context.message.metadata)
            if context.message and context.message.metadata
            else {}
        )
        if metadata.get('test_key') not in ('full_message', 'simple_message'):
            print(f'SERVER: WARNING: Missing or incorrect metadata: {metadata}')
            raise ValueError(
                f'Missing expected metadata from client. Got: {metadata}'
            )

        if metadata.get('test_key') == 'full_message':
            expected_parts = [
                Part(root=TextPart(text='stream')),
                Part(
                    root=FilePart(
                        file=FileWithUri(
                            uri='https://example.com/file.txt',
                            mime_type='text/plain',
                        )
                    )
                ),
                Part(
                    root=FilePart(
                        file=FileWithBytes(
                            bytes=b'aGVsbG8=',
                            mime_type='application/octet-stream',
                        )
                    )
                ),
                Part(root=DataPart(data={'key': 'value'})),
            ]
            # Explicit raise (not ``assert``) so the check survives ``python -O``.
            if context.message.parts != expected_parts:
                raise AssertionError(
                    f'Unexpected message parts: {context.message.parts!r}'
                )

        print(f"SERVER: request message text='{text}'")
        if 'stream' in text:
            print(f'SERVER: waiting on stream event for task {context.task_id}')
            event = asyncio.Event()
            self.events[context.task_id] = event

            async def emit_periodic():
                # Emit a 'ping' status plus an artifact chunk every 0.1s until
                # the event is set; cancellation ends the loop quietly.
                try:
                    while not event.is_set():
                        await task_updater.update_status(
                            TaskState.working,
                            message=task_updater.new_agent_message(
                                [Part(root=TextPart(text='ping'))]
                            ),
                        )
                        await task_updater.add_artifact(
                            [Part(root=TextPart(text='artifact-chunk'))],
                            name='test-artifact',
                            metadata={'artifact_key': 'artifact_value'},
                        )
                        await asyncio.sleep(0.1)
                except asyncio.CancelledError:
                    pass

            bg_task = asyncio.create_task(emit_periodic())
            try:
                await event.wait()
            finally:
                # Stop the emitter even when ``execute`` itself is cancelled,
                # and wait for it to finish so no further 'ping'/artifact
                # events are enqueued after this point. ``return_exceptions``
                # keeps the emitter's own cancellation from propagating here.
                bg_task.cancel()
                await asyncio.gather(bg_task, return_exceptions=True)
            print(f'SERVER: stream event triggered for task {context.task_id}')

        await task_updater.update_status(
            TaskState.completed,
            message=task_updater.new_agent_message(
                [Part(root=TextPart(text='done'))],
                metadata={'response_key': 'response_value'},
            ),
        )
        print(f'SERVER: execute finished for task {context.task_id}')

    async def cancel(self, context: RequestContext, event_queue: EventQueue):
        """Release a streaming ``execute`` and report the task as canceled.

        Args:
            context: Request context identifying the task to cancel.
            event_queue: Queue that receives the ``canceled`` status update.

        Raises:
            AssertionError: If no streaming ``execute`` registered an event
                for ``context.task_id``.
        """
        print(f'SERVER: cancel called for task {context.task_id}')
        event = self.events.get(context.task_id)
        # Explicit raise (not ``assert``) so the check survives ``python -O``.
        if event is None:
            raise AssertionError(
                f'No streaming execute registered for task {context.task_id}'
            )
        event.set()
        task_updater = TaskUpdater(
            event_queue,
            context.task_id,
            context.context_id,
        )
        await task_updater.update_status(TaskState.canceled)
async def main_async(http_port: int, grpc_port: int):
    """Serve the mock a2a v0.3 agent over JSON-RPC, REST and gRPC.

    JSON-RPC is mounted at ``/jsonrpc`` and REST at ``/rest`` on a FastAPI app
    served by uvicorn on ``127.0.0.1:http_port``; gRPC listens on
    ``127.0.0.1:grpc_port``. All three transports share one
    ``DefaultRequestHandler`` and therefore one in-memory task store.
    Returns once the uvicorn server stops, after shutting down gRPC as well.

    Args:
        http_port: Port for the uvicorn HTTP server (JSON-RPC and REST).
        grpc_port: Port for the gRPC server.
    """
    print(
        f'SERVER: Starting server on http_port={http_port}, grpc_port={grpc_port}'
    )
    agent_card = AgentCard(
        name='Server 0.3',
        description='Server running on a2a v0.3.0',
        version='1.0.0',
        url=f'http://127.0.0.1:{http_port}/jsonrpc/',
        preferred_transport=TransportProtocol.jsonrpc,
        skills=[],
        capabilities=AgentCapabilities(streaming=True, push_notifications=True),
        default_input_modes=['text/plain'],
        default_output_modes=['text/plain'],
        additional_interfaces=[
            AgentInterface(
                transport=TransportProtocol.http_json,
                url=f'http://127.0.0.1:{http_port}/rest/',
            ),
            AgentInterface(
                transport=TransportProtocol.grpc,
                url=f'127.0.0.1:{grpc_port}',
            ),
        ],
        supports_authenticated_extended_card=False,
    )
    task_store = InMemoryTaskStore()
    handler = DefaultRequestHandler(
        agent_executor=MockAgentExecutor(),
        task_store=task_store,
        queue_manager=InMemoryQueueManager(),
        push_config_store=InMemoryPushNotificationConfigStore(),
    )

    # HTTP transports: JSON-RPC and REST apps mounted side by side.
    app = FastAPI()
    app.mount(
        '/jsonrpc',
        A2AFastAPIApplication(
            http_handler=handler, agent_card=agent_card
        ).build(),
    )
    app.mount(
        '/rest',
        A2ARESTFastAPIApplication(
            http_handler=handler, agent_card=agent_card
        ).build(),
    )
    app.add_middleware(CustomLoggingMiddleware)

    # Start gRPC Server
    server = grpc.aio.server()
    servicer = GrpcHandler(agent_card, handler)
    a2a_pb2_grpc.add_A2AServiceServicer_to_server(servicer, server)
    server.add_insecure_port(f'127.0.0.1:{grpc_port}')
    await server.start()

    try:
        # Start Uvicorn; serve() returns when the HTTP server shuts down.
        config = uvicorn.Config(
            app, host='127.0.0.1', port=http_port, log_level='info', access_log=True
        )
        uvicorn_server = uvicorn.Server(config)
        await uvicorn_server.serve()
    finally:
        # Without this the gRPC server outlives the HTTP server and its port
        # is only released when the event loop is torn down.
        await server.stop(None)
def main():
    """Read ``--http-port`` / ``--grpc-port`` from argv and run the server."""
    print('Starting server_0_3...')
    cli = argparse.ArgumentParser()
    # Both ports are mandatory integers.
    for flag in ('--http-port', '--grpc-port'):
        cli.add_argument(flag, type=int, required=True)
    opts = cli.parse_args()
    server_coro = main_async(opts.http_port, opts.grpc_port)
    asyncio.run(server_coro)


if __name__ == '__main__':
    main()