Skip to content

Commit 6214946

Browse files
author
Krzysztof Dziedzic
committed
Add ITK test suite
1 parent 734d062 commit 6214946

5 files changed

Lines changed: 528 additions & 0 deletions

File tree

.github/actions/spelling/allow.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,10 @@ initdb
6363
inmemory
6464
INR
6565
isready
66+
itk
67+
ITK
6668
jcs
69+
jit
6770
jku
6871
JOSE
6972
JPY
@@ -107,11 +110,13 @@ protoc
107110
pydantic
108111
pyi
109112
pypistats
113+
pyproto
110114
pyupgrade
111115
pyversions
112116
redef
113117
respx
114118
resub
119+
rmi
115120
RS256
116121
RUF
117122
SECP256R1

itk/__init__.py

Whitespace-only changes.

itk/main.py

Lines changed: 365 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,365 @@
1+
import argparse # noqa: I001
2+
import asyncio
3+
import base64
4+
import logging
5+
import uuid
6+
7+
import grpc
8+
import httpx
9+
import uvicorn
10+
11+
from fastapi import FastAPI
12+
13+
from pyproto import instruction_pb2
14+
15+
from a2a.client import ClientConfig, ClientFactory
16+
from a2a.compat.v0_3 import a2a_v0_3_pb2_grpc
17+
from a2a.compat.v0_3.grpc_handler import CompatGrpcHandler
18+
from a2a.server.agent_execution import AgentExecutor, RequestContext
19+
from a2a.server.apps import A2ARESTFastAPIApplication
20+
from a2a.server.routes import create_agent_card_routes, create_jsonrpc_routes
21+
from a2a.server.events import EventQueue
22+
from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager
23+
from a2a.server.request_handlers import DefaultRequestHandler, GrpcHandler
24+
from a2a.server.tasks import TaskUpdater
25+
from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore
26+
from a2a.types import a2a_pb2_grpc
27+
from a2a.types.a2a_pb2 import (
28+
AgentCapabilities,
29+
AgentCard,
30+
AgentInterface,
31+
Message,
32+
Part,
33+
SendMessageRequest,
34+
TaskState,
35+
)
36+
from a2a.utils import TransportProtocol
37+
38+
39+
logging.basicConfig(level=logging.INFO)
40+
logger = logging.getLogger(__name__)
41+
42+
43+
def extract_instruction(
44+
message: Message | None,
45+
) -> instruction_pb2.Instruction | None:
46+
"""Extracts an Instruction proto from an A2A Message."""
47+
if not message or not message.parts:
48+
return None
49+
50+
for part in message.parts:
51+
# 1. Handle binary protobuf part (media_type or filename)
52+
if (
53+
part.media_type == 'application/x-protobuf'
54+
or part.filename == 'instruction.bin'
55+
):
56+
try:
57+
inst = instruction_pb2.Instruction()
58+
if part.raw:
59+
inst.ParseFromString(part.raw)
60+
elif part.text:
61+
# Some clients might send it as base64 in text part
62+
raw = base64.b64decode(part.text)
63+
inst.ParseFromString(raw)
64+
except Exception: # noqa: BLE001
65+
logger.debug(
66+
'Failed to parse instruction from binary part',
67+
exc_info=True,
68+
)
69+
continue
70+
else:
71+
return inst
72+
73+
# 2. Handle base64 encoded instruction in any text part
74+
if part.text:
75+
try:
76+
raw = base64.b64decode(part.text)
77+
inst = instruction_pb2.Instruction()
78+
inst.ParseFromString(raw)
79+
except Exception: # noqa: BLE001
80+
logger.debug(
81+
'Failed to parse instruction from text part', exc_info=True
82+
)
83+
continue
84+
else:
85+
return inst
86+
return None
87+
88+
89+
def wrap_instruction_to_request(inst: instruction_pb2.Instruction) -> Message:
90+
"""Wraps an Instruction proto into an A2A Message."""
91+
inst_bytes = inst.SerializeToString()
92+
return Message(
93+
role='ROLE_USER',
94+
message_id=str(uuid.uuid4()),
95+
parts=[
96+
Part(
97+
raw=inst_bytes,
98+
media_type='application/x-protobuf',
99+
filename='instruction.bin',
100+
)
101+
],
102+
)
103+
104+
105+
async def handle_call_agent(call: instruction_pb2.CallAgent) -> list[str]:
106+
"""Handles the CallAgent instruction by invoking another agent."""
107+
logger.info('Calling agent %s via %s', call.agent_card_uri, call.transport)
108+
109+
# Mapping transport string to TransportProtocol enum
110+
transport_map = {
111+
'JSONRPC': TransportProtocol.JSONRPC,
112+
'HTTP+JSON': TransportProtocol.HTTP_JSON,
113+
'HTTP_JSON': TransportProtocol.HTTP_JSON,
114+
'REST': TransportProtocol.HTTP_JSON,
115+
'GRPC': TransportProtocol.GRPC,
116+
}
117+
118+
selected_transport = transport_map.get(
119+
call.transport.upper(), TransportProtocol.JSONRPC
120+
)
121+
if selected_transport is None:
122+
raise ValueError(f'Unsupported transport: {call.transport}')
123+
124+
config = ClientConfig()
125+
config.httpx_client = httpx.AsyncClient(timeout=30.0)
126+
config.grpc_channel_factory = grpc.aio.insecure_channel
127+
config.supported_protocol_bindings = [selected_transport]
128+
config.streaming = call.streaming or (
129+
selected_transport == TransportProtocol.GRPC
130+
)
131+
132+
try:
133+
client = await ClientFactory.connect(
134+
call.agent_card_uri,
135+
client_config=config,
136+
)
137+
138+
# Wrap nested instruction
139+
nested_msg = wrap_instruction_to_request(call.instruction)
140+
request = SendMessageRequest(message=nested_msg)
141+
142+
results = []
143+
async for event in client.send_message(request):
144+
# Event is streaming response and task
145+
logger.info('Event: %s', event)
146+
stream_resp, task = event
147+
148+
message = None
149+
if stream_resp.HasField('message'):
150+
message = stream_resp.message
151+
elif task and task.status.HasField('message'):
152+
message = task.status.message
153+
elif stream_resp.HasField(
154+
'status_update'
155+
) and stream_resp.status_update.status.HasField('message'):
156+
message = stream_resp.status_update.status.message
157+
158+
if message:
159+
results.extend(part.text for part in message.parts if part.text)
160+
161+
except Exception as e:
162+
logger.exception('Failed to call outbound agent')
163+
raise RuntimeError(
164+
f'Outbound call to {call.agent_card_uri} failed: {e!s}'
165+
) from e
166+
else:
167+
return results
168+
169+
170+
async def handle_instruction(inst: instruction_pb2.Instruction) -> list[str]:
171+
"""Recursively handles instructions."""
172+
if inst.HasField('call_agent'):
173+
return await handle_call_agent(inst.call_agent)
174+
if inst.HasField('return_response'):
175+
return [inst.return_response.response]
176+
if inst.HasField('steps'):
177+
all_results = []
178+
for step in inst.steps.instructions:
179+
results = await handle_instruction(step)
180+
all_results.extend(results)
181+
return all_results
182+
raise ValueError('Unknown instruction type')
183+
184+
185+
class V10AgentExecutor(AgentExecutor):
186+
"""Executor for ITK v10 agent tasks."""
187+
188+
async def execute(
189+
self, context: RequestContext, event_queue: EventQueue
190+
) -> None:
191+
"""Executes a task instruction."""
192+
logger.info('Executing task %s', context.task_id)
193+
task_updater = TaskUpdater(
194+
event_queue,
195+
context.task_id,
196+
context.context_id,
197+
)
198+
199+
await task_updater.update_status(TaskState.TASK_STATE_SUBMITTED)
200+
await task_updater.update_status(TaskState.TASK_STATE_WORKING)
201+
202+
instruction = extract_instruction(context.message)
203+
if not instruction:
204+
error_msg = 'No valid instruction found in request'
205+
logger.error(error_msg)
206+
await task_updater.update_status(
207+
TaskState.TASK_STATE_FAILED,
208+
message=task_updater.new_agent_message([Part(text=error_msg)]),
209+
)
210+
return
211+
212+
try:
213+
logger.info('Instruction: %s', instruction)
214+
results = await handle_instruction(instruction)
215+
response_text = '\n'.join(results)
216+
logger.info('Response: %s', response_text)
217+
await task_updater.update_status(
218+
TaskState.TASK_STATE_COMPLETED,
219+
message=task_updater.new_agent_message(
220+
[Part(text=response_text)]
221+
),
222+
)
223+
logger.info('Task %s completed', context.task_id)
224+
except Exception as e:
225+
logger.exception('Error during instruction handling')
226+
await task_updater.update_status(
227+
TaskState.TASK_STATE_FAILED,
228+
message=task_updater.new_agent_message([Part(text=str(e))]),
229+
)
230+
231+
async def cancel(
232+
self, context: RequestContext, event_queue: EventQueue
233+
) -> None:
234+
"""Cancels a task."""
235+
logger.info('Cancel requested for task %s', context.task_id)
236+
task_updater = TaskUpdater(
237+
event_queue,
238+
context.task_id,
239+
context.context_id,
240+
)
241+
await task_updater.update_status(TaskState.TASK_STATE_CANCELED)
242+
243+
244+
async def main_async(http_port: int, grpc_port: int) -> None:
245+
"""Starts the Agent with HTTP and gRPC interfaces."""
246+
interfaces = [
247+
AgentInterface(
248+
protocol_binding=TransportProtocol.GRPC,
249+
url=f'127.0.0.1:{grpc_port}',
250+
protocol_version='1.0',
251+
),
252+
AgentInterface(
253+
protocol_binding=TransportProtocol.GRPC,
254+
url=f'127.0.0.1:{grpc_port}',
255+
protocol_version='0.3',
256+
),
257+
]
258+
259+
interfaces.append(
260+
AgentInterface(
261+
protocol_binding=TransportProtocol.JSONRPC,
262+
url=f'http://127.0.0.1:{http_port}/jsonrpc/',
263+
)
264+
)
265+
interfaces.append(
266+
AgentInterface(
267+
protocol_binding=TransportProtocol.HTTP_JSON,
268+
url=f'http://127.0.0.1:{http_port}/rest/',
269+
protocol_version='1.0',
270+
)
271+
)
272+
interfaces.append(
273+
AgentInterface(
274+
protocol_binding=TransportProtocol.HTTP_JSON,
275+
url=f'http://127.0.0.1:{http_port}/rest/',
276+
protocol_version='0.3',
277+
)
278+
)
279+
280+
agent_card = AgentCard(
281+
name='ITK v10 Agent',
282+
description='Python agent using SDK 1.0.',
283+
version='1.0.0',
284+
capabilities=AgentCapabilities(streaming=True),
285+
default_input_modes=['text/plain'],
286+
default_output_modes=['text/plain'],
287+
supported_interfaces=interfaces,
288+
)
289+
290+
task_store = InMemoryTaskStore()
291+
handler = DefaultRequestHandler(
292+
agent_executor=V10AgentExecutor(),
293+
task_store=task_store,
294+
queue_manager=InMemoryQueueManager(),
295+
)
296+
297+
app = FastAPI()
298+
299+
agent_card_routes = create_agent_card_routes(
300+
agent_card=agent_card, card_url='/.well-known/agent-card.json'
301+
)
302+
jsonrpc_routes = create_jsonrpc_routes(
303+
agent_card=agent_card,
304+
request_handler=handler,
305+
extended_agent_card=agent_card,
306+
rpc_url='/',
307+
enable_v0_3_compat=True,
308+
)
309+
app.mount(
310+
'/jsonrpc',
311+
FastAPI(routes=jsonrpc_routes + agent_card_routes),
312+
)
313+
314+
rest_app = A2ARESTFastAPIApplication(
315+
http_handler=handler, agent_card=agent_card, enable_v0_3_compat=True
316+
).build()
317+
app.mount('/rest', rest_app)
318+
319+
server = grpc.aio.server()
320+
321+
compat_servicer = CompatGrpcHandler(agent_card, handler)
322+
a2a_v0_3_pb2_grpc.add_A2AServiceServicer_to_server(compat_servicer, server)
323+
servicer = GrpcHandler(agent_card, handler)
324+
a2a_pb2_grpc.add_A2AServiceServicer_to_server(servicer, server)
325+
326+
server.add_insecure_port(f'127.0.0.1:{grpc_port}')
327+
await server.start()
328+
329+
logger.info(
330+
'Starting ITK v10 Agent on HTTP port %s and gRPC port %s',
331+
http_port,
332+
grpc_port,
333+
)
334+
335+
config = uvicorn.Config(
336+
app, host='127.0.0.1', port=http_port, log_level='info'
337+
)
338+
uvicorn_server = uvicorn.Server(config)
339+
340+
await uvicorn_server.serve()
341+
342+
343+
def str2bool(v: str | bool) -> bool:
344+
"""Converts a string to a boolean value."""
345+
if isinstance(v, bool):
346+
return v
347+
if v.lower() in ('yes', 'true', 't', 'y', '1'):
348+
return True
349+
if v.lower() in ('no', 'false', 'f', 'n', '0'):
350+
return False
351+
raise argparse.ArgumentTypeError('Boolean value expected.')
352+
353+
354+
def main() -> None:
355+
"""Main entry point for the agent."""
356+
parser = argparse.ArgumentParser()
357+
parser.add_argument('--httpPort', type=int, default=10102)
358+
parser.add_argument('--grpcPort', type=int, default=11002)
359+
args = parser.parse_args()
360+
361+
asyncio.run(main_async(args.httpPort, args.grpcPort))
362+
363+
364+
if __name__ == '__main__':
365+
main()

0 commit comments

Comments
 (0)