From c7f3db747e4a9fb8ef534dbf71b7b8150c7d8306 Mon Sep 17 00:00:00 2001 From: Krzysztof Dziedzic Date: Thu, 30 Apr 2026 09:39:38 +0000 Subject: [PATCH] test: setup itk resubscribe tests --- itk/main.py | 256 ++++++++++++++++++++++++++++++++++++++----------- itk/run_itk.sh | 18 ++++ 2 files changed, 219 insertions(+), 55 deletions(-) diff --git a/itk/main.py b/itk/main.py index 76c72e1c2..58d3180bb 100644 --- a/itk/main.py +++ b/itk/main.py @@ -3,6 +3,7 @@ import base64 import logging import os +import signal import uuid import grpc @@ -10,10 +11,12 @@ import uvicorn from fastapi import FastAPI +from typing import Any from pyproto import instruction_pb2 -from a2a.client import ClientConfig, create_client +from a2a.client import Client, ClientConfig, create_client +from a2a.client.errors import A2AClientError from a2a.compat.v0_3 import a2a_v0_3_pb2_grpc from a2a.compat.v0_3.grpc_handler import CompatGrpcHandler from a2a.server.agent_execution import AgentExecutor, RequestContext @@ -31,20 +34,25 @@ InMemoryPushNotificationConfigStore, ) from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore +from a2a.server.context import ServerCallContext from a2a.types import a2a_pb2_grpc from a2a.types.a2a_pb2 import ( AgentCapabilities, AgentCard, AgentInterface, + CancelTaskRequest, Message, Part, SendMessageRequest, + SubscribeToTaskRequest, Task, TaskState, TaskStatus, TaskPushNotificationConfig, ) from a2a.utils import TransportProtocol +from a2a.utils.errors import TaskNotCancelableError +from a2a.server.tasks.push_notification_sender import PushNotificationEvent log_level_str = os.environ.get('ITK_LOG_LEVEL', 'INFO').upper() log_level = getattr(logging, log_level_str, logging.INFO) @@ -73,7 +81,7 @@ def extract_instruction( # Some clients might send it as base64 in text part raw = base64.b64decode(part.text) inst.ParseFromString(raw) - except Exception: + except Exception: # noqa: BLE001 logger.debug( 'Failed to parse instruction from binary part', exc_info=True, @@ -88,7 +96,7 @@ def extract_instruction( raw = base64.b64decode(part.text) inst = instruction_pb2.Instruction() inst.ParseFromString(raw) - except Exception: + except Exception: # noqa: BLE001 logger.debug( 'Failed to parse instruction from text part', exc_info=True ) @@ -98,6 +106,93 @@ def extract_instruction( return None +def _extract_text_from_event(event: Any) -> list[str]: + """Extracts text parts from an event's message.""" + if isinstance(event, tuple): + results = [] + for item in event: + results.extend(_extract_text_from_event(item)) + return results + + message = None + if hasattr(event, 'HasField'): + if event.HasField('message'): + message = event.message + elif event.HasField('task') and event.task.status.HasField('message'): + message = event.task.status.message + elif event.HasField( + 'status_update' + ) and event.status_update.status.HasField('message'): + message = event.status_update.status.message + + results = [] + if message: + results.extend(part.text for part in message.parts if part.text) + return results + + +async def _handle_call_agent_with_resubscribe( + client: Client, request: SendMessageRequest +) -> list[str]: + """Handles the send-disconnect-resubscribe flow.""" + results = [] + logger.info('Executing re-subscribe behavior') + agen = client.send_message(request) + task_id = None + + async for event in agen: + logger.info('Event before disconnect: %s', event) + if event.HasField('task'): + task_id = event.task.id + elif event.HasField('status_update'): + task_id = event.status_update.task_id + break + + await agen.aclose() + logger.info('Disconnected from task %s. Now re-subscribing.', task_id) + + resub_agen = client.subscribe(SubscribeToTaskRequest(id=task_id)) + + task_obj = None + finished = False + async for event in resub_agen: + logger.info('Event after re-subscribe: %s', event) + if hasattr(event, 'task'): + task_obj = event.task + elif hasattr(event, 'HasField') and event.HasField('task'): + task_obj = event.task + + extracted_text = _extract_text_from_event(event) + for text in extracted_text: + processed_text = text.replace('task-finished', '') + results.append(processed_text) + if any('task-finished' in text for text in extracted_text): + logger.info( + 'Received task-finished after re-subscribe, breaking loop.' + ) + finished = True + break + + if not results and task_obj and hasattr(task_obj, 'history'): + logger.info('Results empty after loop, reading from history.') + for msg in task_obj.history: + if msg.role == 'ROLE_AGENT' or msg.role == 'agent': + for part in msg.parts: + if part.text: + results.append(part.text.replace('task-finished', '')) + + if not finished: + logger.info('Canceling task %s after retrieval.', task_id) + try: + await client.cancel_task(CancelTaskRequest(id=task_id)) + logger.info('Task cancelled successfully: %s', task_id) + except A2AClientError as e: + logger.error('Failed to cancel task %s: %s', task_id, str(e)) + raise + + return results + + def wrap_instruction_to_request(inst: instruction_pb2.Instruction) -> Message: """Wraps an Instruction proto into an A2A Message.""" inst_bytes = inst.SerializeToString() @@ -129,18 +224,22 @@ async def handle_call_agent( 'GRPC': TransportProtocol.GRPC, } - selected_transport = transport_map.get(call.transport.upper()) + selected_transport = transport_map.get( + call.transport.upper(), TransportProtocol.JSONRPC + ) if selected_transport is None: raise ValueError(f'Unsupported transport: {call.transport}') config = ClientConfig() - config.httpx_client = httpx.AsyncClient(timeout=30.0) config.grpc_channel_factory = grpc.aio.insecure_channel config.supported_protocol_bindings = [selected_transport] config.streaming = call.streaming or ( selected_transport == TransportProtocol.GRPC ) + if call.HasField('resubscribe') and not config.streaming: + raise ValueError('Re-subscription requires streaming to be enabled') + if call.HasField('push_notification'): url = call.push_notification.url if not url: @@ -149,47 +248,48 @@ async def handle_call_agent( url = f'http://{url}' config.push_notification_config = TaskPushNotificationConfig( url=f'{url}/notifications', - token='itk-token', # noqa: S106 + token='itk-token', ) - try: - client = await create_client( - call.agent_card_uri, - client_config=config, - ) + async with httpx.AsyncClient(timeout=30.0) as httpx_client: + config.httpx_client = httpx_client + try: + client = await create_client( + call.agent_card_uri, + client_config=config, + ) - # Wrap nested instruction - nested_msg = wrap_instruction_to_request(call.instruction) - request = SendMessageRequest(message=nested_msg) + # Wrap nested instruction + nested_msg = wrap_instruction_to_request(call.instruction) + request = SendMessageRequest(message=nested_msg) - results = [] - async for event in client.send_message(request): - # Event is streaming response and task - logger.info('Event: %s', event) - stream_resp = event - - message = None - if stream_resp.HasField('message'): - message = stream_resp.message - elif stream_resp.HasField( - 'task' - ) and stream_resp.task.status.HasField('message'): - message = stream_resp.task.status.message - elif stream_resp.HasField( - 'status_update' - ) and stream_resp.status_update.status.HasField('message'): - message = stream_resp.status_update.status.message - - if message: - results.extend(part.text for part in message.parts if part.text) - - except Exception as e: - logger.exception('Failed to call outbound agent') - raise RuntimeError( - f'Outbound call to {call.agent_card_uri} failed: {e!s}' - ) from e - else: - return results + results = [] + + if call.HasField('resubscribe'): + results.extend( + await _handle_call_agent_with_resubscribe(client, request) + ) + else: + async for event in client.send_message(request): + logger.info('Event: %s', event) + results.extend(_extract_text_from_event(event)) + + except Exception as e: + logger.exception('Failed to call outbound agent') + raise RuntimeError( + f'Outbound call to {call.agent_card_uri} failed: {e!s}' + ) from e + else: + return results + + +def _should_hold(inst: instruction_pb2.Instruction) -> bool: + """Recursively checks if any part of the instruction requests holding the task.""" + if inst.HasField('return_response') and inst.return_response.hold_task: + return True + if inst.HasField('steps'): + return any(_should_hold(step) for step in inst.steps.instructions) + return False async def handle_instruction( @@ -245,23 +345,57 @@ async def execute( ) return + should_hold_task = _should_hold(instruction) + try: logger.info('Instruction: %s', instruction) results = await handle_instruction(instruction) + response_text = '\n'.join(results) logger.info('Response: %s', response_text) - await task_updater.update_status( - TaskState.TASK_STATE_COMPLETED, - message=task_updater.new_agent_message( - [Part(text=response_text)] - ), - ) - logger.info('Task %s completed', context.task_id) + + if should_hold_task: + logger.info('Holding task %s as requested', context.task_id) + # Emitted event: response + task-finished + logger.info( + 'Emitting response and task-finished for held task %s', context.task_id + ) + await task_updater.update_status( + TaskState.TASK_STATE_WORKING, + message=task_updater.new_agent_message( + [Part(text=response_text + '\n' +'task-finished')] + ), + ) + await asyncio.sleep(2) + + # Continue emitting "task-finished" every 2 seconds + try: + while True: + logger.info( + 'Emitting periodic status update for held task %s', + context.task_id, + ) + await task_updater.update_status( + TaskState.TASK_STATE_WORKING, + message=None, + ) + await asyncio.sleep(2) + except asyncio.CancelledError: + logger.info('Task %s cancelled', context.task_id) + return + else: + await task_updater.update_status( + TaskState.TASK_STATE_COMPLETED, + message=task_updater.new_agent_message( + [Part(text=response_text)] + ), + ) + logger.info('Task %s completed', context.task_id) except Exception as e: logger.exception('Error during instruction handling') await task_updater.update_status( TaskState.TASK_STATE_FAILED, - message=task_updater.new_agent_message([Part(text=str(e))]), + message=None, ) async def cancel( @@ -325,9 +459,7 @@ async def main_async(http_port: int, grpc_port: int) -> None: name='ITK v10 Agent', description='Python agent using SDK 1.0.', version='1.0.0', - capabilities=AgentCapabilities( - streaming=True, push_notifications=True, extended_agent_card=True - ), + capabilities=AgentCapabilities(streaming=True), default_input_modes=['text/plain'], default_output_modes=['text/plain'], supported_interfaces=interfaces, @@ -335,9 +467,11 @@ async def main_async(http_port: int, grpc_port: int) -> None: task_store = InMemoryTaskStore() push_config_store = InMemoryPushNotificationConfigStore() + httpx_client = httpx.AsyncClient() push_sender = BasePushNotificationSender( - httpx_client=httpx.AsyncClient(), + httpx_client=httpx_client, config_store=push_config_store, + context=ServerCallContext(), ) handler = DefaultRequestHandler( @@ -396,10 +530,22 @@ async def main_async(http_port: int, grpc_port: int) -> None: ) config = uvicorn.Config( - app, host='127.0.0.1', port=http_port, log_level=log_level_str.lower() + app, host='127.0.0.1', port=http_port, log_level='info' ) uvicorn_server = uvicorn.Server(config) + # Signal handling + loop = asyncio.get_running_loop() + + async def shutdown() -> None: + logger.info('Shutting down...') + uvicorn_server.should_exit = True + await server.stop(5) + await httpx_client.aclose() + + for sig in (signal.SIGINT, signal.SIGTERM): + loop.add_signal_handler(sig, lambda: asyncio.create_task(shutdown())) + await uvicorn_server.serve() diff --git a/itk/run_itk.sh b/itk/run_itk.sh index 21736f171..5d4b21ef2 100755 --- a/itk/run_itk.sh +++ b/itk/run_itk.sh @@ -163,6 +163,24 @@ RESPONSE=$(curl -s -X POST http://127.0.0.1:8000/run \ "edges": ["0->1", "0->2", "1->0", "2->0"], "protocols": ["http_json"], "behavior": "push_notification" + }, + { + "name": "Resubscribe Test - JSONRPC", + "sdks": ["current", "python_v10", "python_v03", "go_v10", "go_v03"], + "traversal": "euler", + "edges": ["0->1", "0->2", "0->3", "0->4", "1->0", "2->0", "3->0", "4->0"], + "protocols": ["jsonrpc"], + "streaming": true, + "behavior": "resubscribe" + }, + { + "name": "Resubscribe Test - Python & Go Non-JSONRPC Protocols", + "sdks": ["current", "python_v10", "python_v03", "go_v10"], + "traversal": "euler", + "edges": ["0->1", "0->2", "0->3", "1->0", "2->0", "3->0"], + "protocols": ["grpc", "http_json"], + "streaming": true, + "behavior": "resubscribe" } ] }')