33import base64
44import logging
55import os
6+ import signal
67import uuid
78
89import grpc
910import httpx
1011import uvicorn
1112
1213from fastapi import FastAPI
14+ from typing import Any
1315
1416from pyproto import instruction_pb2
1517
16- from a2a .client import ClientConfig , create_client
18+ from a2a .client import Client , ClientConfig , create_client
19+ from a2a .client .errors import A2AClientError
1720from a2a .compat .v0_3 import a2a_v0_3_pb2_grpc
1821from a2a .compat .v0_3 .grpc_handler import CompatGrpcHandler
1922from a2a .server .agent_execution import AgentExecutor , RequestContext
3134 InMemoryPushNotificationConfigStore ,
3235)
3336from a2a .server .tasks .inmemory_task_store import InMemoryTaskStore
37+ from a2a .server .context import ServerCallContext
3438from a2a .types import a2a_pb2_grpc
3539from a2a .types .a2a_pb2 import (
3640 AgentCapabilities ,
3741 AgentCard ,
3842 AgentInterface ,
43+ CancelTaskRequest ,
3944 Message ,
4045 Part ,
4146 SendMessageRequest ,
47+ SubscribeToTaskRequest ,
4248 Task ,
4349 TaskState ,
4450 TaskStatus ,
4551 TaskPushNotificationConfig ,
4652)
4753from a2a .utils import TransportProtocol
54+ from a2a .utils .errors import TaskNotCancelableError
55+ from a2a .server .tasks .push_notification_sender import PushNotificationEvent
4856
4957log_level_str = os .environ .get ('ITK_LOG_LEVEL' , 'INFO' ).upper ()
5058log_level = getattr (logging , log_level_str , logging .INFO )
@@ -73,7 +81,7 @@ def extract_instruction(
7381 # Some clients might send it as base64 in text part
7482 raw = base64 .b64decode (part .text )
7583 inst .ParseFromString (raw )
76- except Exception :
84+ except Exception : # noqa: BLE001
7785 logger .debug (
7886 'Failed to parse instruction from binary part' ,
7987 exc_info = True ,
@@ -88,7 +96,7 @@ def extract_instruction(
8896 raw = base64 .b64decode (part .text )
8997 inst = instruction_pb2 .Instruction ()
9098 inst .ParseFromString (raw )
91- except Exception :
99+ except Exception : # noqa: BLE001
92100 logger .debug (
93101 'Failed to parse instruction from text part' , exc_info = True
94102 )
@@ -98,6 +106,93 @@ def extract_instruction(
98106 return None
99107
100108
109+ def _extract_text_from_event (event : Any ) -> list [str ]:
110+ """Extracts text parts from an event's message."""
111+ if isinstance (event , tuple ):
112+ results = []
113+ for item in event :
114+ results .extend (_extract_text_from_event (item ))
115+ return results
116+
117+ message = None
118+ if hasattr (event , 'HasField' ):
119+ if event .HasField ('message' ):
120+ message = event .message
121+ elif event .HasField ('task' ) and event .task .status .HasField ('message' ):
122+ message = event .task .status .message
123+ elif event .HasField (
124+ 'status_update'
125+ ) and event .status_update .status .HasField ('message' ):
126+ message = event .status_update .status .message
127+
128+ results = []
129+ if message :
130+ results .extend (part .text for part in message .parts if part .text )
131+ return results
132+
133+
134+ async def _handle_call_agent_with_resubscribe (
135+ client : Client , request : SendMessageRequest
136+ ) -> list [str ]:
137+ """Handles the send-disconnect-resubscribe flow."""
138+ results = []
139+ logger .info ('Executing re-subscribe behavior' )
140+ agen = client .send_message (request )
141+ task_id = None
142+
143+ async for event in agen :
144+ logger .info ('Event before disconnect: %s' , event )
145+ if event .HasField ('task' ):
146+ task_id = event .task .id
147+ elif event .HasField ('status_update' ):
148+ task_id = event .status_update .task_id
149+ break
150+
151+ await agen .aclose ()
152+ logger .info ('Disconnected from task %s. Now re-subscribing.' , task_id )
153+
154+ resub_agen = client .subscribe (SubscribeToTaskRequest (id = task_id ))
155+
156+ task_obj = None
157+ finished = False
158+ async for event in resub_agen :
159+ logger .info ('Event after re-subscribe: %s' , event )
160+ if hasattr (event , 'task' ):
161+ task_obj = event .task
162+ elif hasattr (event , 'HasField' ) and event .HasField ('task' ):
163+ task_obj = event .task
164+
165+ extracted_text = _extract_text_from_event (event )
166+ for text in extracted_text :
167+ processed_text = text .replace ('task-finished' , '' )
168+ results .append (processed_text )
169+ if any ('task-finished' in text for text in extracted_text ):
170+ logger .info (
171+ 'Received task-finished after re-subscribe, breaking loop.'
172+ )
173+ finished = True
174+ break
175+
176+ if not results and task_obj and hasattr (task_obj , 'history' ):
177+ logger .info ('Results empty after loop, reading from history.' )
178+ for msg in task_obj .history :
179+ if msg .role == 'ROLE_AGENT' or msg .role == 'agent' :
180+ for part in msg .parts :
181+ if part .text :
182+ results .append (part .text .replace ('task-finished' , '' ))
183+
184+ if not finished :
185+ logger .info ('Canceling task %s after retrieval.' , task_id )
186+ try :
187+ await client .cancel_task (CancelTaskRequest (id = task_id ))
188+ logger .info ('Task cancelled successfully: %s' , task_id )
189+ except A2AClientError as e :
190+ logger .error ('Failed to cancel task %s: %s' , task_id , str (e ))
191+ raise
192+
193+ return results
194+
195+
101196def wrap_instruction_to_request (inst : instruction_pb2 .Instruction ) -> Message :
102197 """Wraps an Instruction proto into an A2A Message."""
103198 inst_bytes = inst .SerializeToString ()
@@ -129,18 +224,22 @@ async def handle_call_agent(
129224 'GRPC' : TransportProtocol .GRPC ,
130225 }
131226
132- selected_transport = transport_map .get (call .transport .upper ())
227+ selected_transport = transport_map .get (
228+ call .transport .upper (), TransportProtocol .JSONRPC
229+ )
133230 if selected_transport is None :
134231 raise ValueError (f'Unsupported transport: { call .transport } ' )
135232
136233 config = ClientConfig ()
137- config .httpx_client = httpx .AsyncClient (timeout = 30.0 )
138234 config .grpc_channel_factory = grpc .aio .insecure_channel
139235 config .supported_protocol_bindings = [selected_transport ]
140236 config .streaming = call .streaming or (
141237 selected_transport == TransportProtocol .GRPC
142238 )
143239
240+ if call .HasField ('resubscribe' ) and not config .streaming :
241+ raise ValueError ('Re-subscription requires streaming to be enabled' )
242+
144243 if call .HasField ('push_notification' ):
145244 url = call .push_notification .url
146245 if not url :
@@ -149,47 +248,48 @@ async def handle_call_agent(
149248 url = f'http://{ url } '
150249 config .push_notification_config = TaskPushNotificationConfig (
151250 url = f'{ url } /notifications' ,
152- token = 'itk-token' , # noqa: S106
251+ token = 'itk-token' ,
153252 )
154253
155- try :
156- client = await create_client (
157- call .agent_card_uri ,
158- client_config = config ,
159- )
254+ async with httpx .AsyncClient (timeout = 30.0 ) as httpx_client :
255+ config .httpx_client = httpx_client
256+ try :
257+ client = await create_client (
258+ call .agent_card_uri ,
259+ client_config = config ,
260+ )
160261
161- # Wrap nested instruction
162- nested_msg = wrap_instruction_to_request (call .instruction )
163- request = SendMessageRequest (message = nested_msg )
262+ # Wrap nested instruction
263+ nested_msg = wrap_instruction_to_request (call .instruction )
264+ request = SendMessageRequest (message = nested_msg )
164265
165- results = []
166- async for event in client .send_message (request ):
167- # Event is streaming response and task
168- logger .info ('Event: %s' , event )
169- stream_resp = event
170-
171- message = None
172- if stream_resp .HasField ('message' ):
173- message = stream_resp .message
174- elif stream_resp .HasField (
175- 'task'
176- ) and stream_resp .task .status .HasField ('message' ):
177- message = stream_resp .task .status .message
178- elif stream_resp .HasField (
179- 'status_update'
180- ) and stream_resp .status_update .status .HasField ('message' ):
181- message = stream_resp .status_update .status .message
182-
183- if message :
184- results .extend (part .text for part in message .parts if part .text )
185-
186- except Exception as e :
187- logger .exception ('Failed to call outbound agent' )
188- raise RuntimeError (
189- f'Outbound call to { call .agent_card_uri } failed: { e !s} '
190- ) from e
191- else :
192- return results
266+ results = []
267+
268+ if call .HasField ('resubscribe' ):
269+ results .extend (
270+ await _handle_call_agent_with_resubscribe (client , request )
271+ )
272+ else :
273+ async for event in client .send_message (request ):
274+ logger .info ('Event: %s' , event )
275+ results .extend (_extract_text_from_event (event ))
276+
277+ except Exception as e :
278+ logger .exception ('Failed to call outbound agent' )
279+ raise RuntimeError (
280+ f'Outbound call to { call .agent_card_uri } failed: { e !s} '
281+ ) from e
282+ else :
283+ return results
284+
285+
286+ def _should_hold (inst : instruction_pb2 .Instruction ) -> bool :
287+ """Recursively checks if any part of the instruction requests holding the task."""
288+ if inst .HasField ('return_response' ) and inst .return_response .hold_task :
289+ return True
290+ if inst .HasField ('steps' ):
291+ return any (_should_hold (step ) for step in inst .steps .instructions )
292+ return False
193293
194294
195295async def handle_instruction (
@@ -245,23 +345,57 @@ async def execute(
245345 )
246346 return
247347
348+ should_hold_task = _should_hold (instruction )
349+
248350 try :
249351 logger .info ('Instruction: %s' , instruction )
250352 results = await handle_instruction (instruction )
353+
251354 response_text = '\n ' .join (results )
252355 logger .info ('Response: %s' , response_text )
253- await task_updater .update_status (
254- TaskState .TASK_STATE_COMPLETED ,
255- message = task_updater .new_agent_message (
256- [Part (text = response_text )]
257- ),
258- )
259- logger .info ('Task %s completed' , context .task_id )
356+
357+ if should_hold_task :
358+ logger .info ('Holding task %s as requested' , context .task_id )
359+ # Emitted event: response + task-finished
360+ logger .info (
361+ 'Emitting response and task-finished for held task %s' , context .task_id
362+ )
363+ await task_updater .update_status (
364+ TaskState .TASK_STATE_WORKING ,
365+ message = task_updater .new_agent_message (
366+ [Part (text = response_text + '\n task-finished' )]
367+ ),
368+ )
369+ await asyncio .sleep (2 )
370+
371+ # Continue emitting "task-finished" every 2 seconds
372+ try :
373+ while True :
374+ logger .info (
375+ 'Emitting periodic status update for held task %s' ,
376+ context .task_id ,
377+ )
378+ await task_updater .update_status (
379+ TaskState .TASK_STATE_WORKING ,
380+ message = None ,
381+ )
382+ await asyncio .sleep (2 )
383+ except asyncio .CancelledError :
384+ logger .info ('Task %s cancelled' , context .task_id )
385+ return
386+ else :
387+ await task_updater .update_status (
388+ TaskState .TASK_STATE_COMPLETED ,
389+ message = task_updater .new_agent_message (
390+ [Part (text = response_text )]
391+ ),
392+ )
393+ logger .info ('Task %s completed' , context .task_id )
260394 except Exception as e :
261395 logger .exception ('Error during instruction handling' )
262396 await task_updater .update_status (
263397 TaskState .TASK_STATE_FAILED ,
264- message = task_updater . new_agent_message ([ Part ( text = str ( e ))]) ,
398+ message = None ,
265399 )
266400
267401 async def cancel (
@@ -325,19 +459,19 @@ async def main_async(http_port: int, grpc_port: int) -> None:
325459 name = 'ITK v10 Agent' ,
326460 description = 'Python agent using SDK 1.0.' ,
327461 version = '1.0.0' ,
328- capabilities = AgentCapabilities (
329- streaming = True , push_notifications = True , extended_agent_card = True
330- ),
462+ capabilities = AgentCapabilities (streaming = True ),
331463 default_input_modes = ['text/plain' ],
332464 default_output_modes = ['text/plain' ],
333465 supported_interfaces = interfaces ,
334466 )
335467
336468 task_store = InMemoryTaskStore ()
337469 push_config_store = InMemoryPushNotificationConfigStore ()
470+ httpx_client = httpx .AsyncClient ()
338471 push_sender = BasePushNotificationSender (
339- httpx_client = httpx . AsyncClient () ,
472+ httpx_client = httpx_client ,
340473 config_store = push_config_store ,
474+ context = ServerCallContext (),
341475 )
342476
343477 handler = DefaultRequestHandler (
@@ -396,10 +530,22 @@ async def main_async(http_port: int, grpc_port: int) -> None:
396530 )
397531
398532 config = uvicorn .Config (
399- app , host = '127.0.0.1' , port = http_port , log_level = log_level_str . lower ()
533+ app , host = '127.0.0.1' , port = http_port , log_level = 'info'
400534 )
401535 uvicorn_server = uvicorn .Server (config )
402536
537+ # Signal handling
538+ loop = asyncio .get_running_loop ()
539+
540+ async def shutdown () -> None :
541+ logger .info ('Shutting down...' )
542+ uvicorn_server .should_exit = True
543+ await server .stop (5 )
544+ await httpx_client .aclose ()
545+
546+ for sig in (signal .SIGINT , signal .SIGTERM ):
547+ loop .add_signal_handler (sig , lambda : asyncio .create_task (shutdown ()))
548+
403549 await uvicorn_server .serve ()
404550
405551
0 commit comments