1717from a2a .compat .v0_3 import a2a_v0_3_pb2_grpc
1818from a2a .compat .v0_3 .grpc_handler import CompatGrpcHandler
1919from a2a .server .agent_execution import AgentExecutor , RequestContext
20- from a2a .server .routes import create_agent_card_routes , create_jsonrpc_routes
21- from a2a .server .routes .rest_routes import create_rest_routes
2220from a2a .server .events import EventQueue
21+ from a2a .server .routes import (
22+ create_agent_card_routes ,
23+ create_jsonrpc_routes ,
24+ create_rest_routes ,
25+ )
2326from a2a .server .events .in_memory_queue_manager import InMemoryQueueManager
2427from a2a .server .request_handlers import DefaultRequestHandler , GrpcHandler
25- from a2a .server .tasks import TaskUpdater
28+ from a2a .server .tasks import (
29+ TaskUpdater ,
30+ BasePushNotificationSender ,
31+ InMemoryPushNotificationConfigStore ,
32+ )
2633from a2a .server .tasks .inmemory_task_store import InMemoryTaskStore
34+ from a2a .server .context import ServerCallContext
2735from a2a .types import a2a_pb2_grpc
2836from a2a .types .a2a_pb2 import (
2937 AgentCapabilities ,
3543 Task ,
3644 TaskState ,
3745 TaskStatus ,
46+ TaskPushNotificationConfig ,
3847)
3948from a2a .utils import TransportProtocol
4049
41-
42- log_level = os . environ . get ( 'ITK_LOG_LEVEL' , 'INFO' ). upper ( )
50+ log_level_str = os . environ . get ( 'ITK_LOG_LEVEL' , 'INFO' ). upper ()
51+ log_level = getattr ( logging , log_level_str , logging . INFO )
4352logging .basicConfig (level = log_level )
4453logger = logging .getLogger (__name__ )
4554
@@ -106,7 +115,9 @@ def wrap_instruction_to_request(inst: instruction_pb2.Instruction) -> Message:
106115 )
107116
108117
109- async def handle_call_agent (call : instruction_pb2 .CallAgent ) -> list [str ]:
118+ async def handle_call_agent (
119+ call : instruction_pb2 .CallAgent ,
120+ ) -> list [str ]:
110121 """Handles the CallAgent instruction by invoking another agent."""
111122 logger .info ('Calling agent %s via %s' , call .agent_card_uri , call .transport )
112123
@@ -131,36 +142,47 @@ async def handle_call_agent(call: instruction_pb2.CallAgent) -> list[str]:
131142 selected_transport == TransportProtocol .GRPC
132143 )
133144
145+ if call .HasField ('push_notification' ):
146+ url = call .push_notification .url
147+ if not url :
148+ raise ValueError ('URL not specified in push_notification behavior' )
149+ if not url .startswith (('http://' , 'https://' )):
150+ url = f'http://{ url } '
151+ config .push_notification_config = TaskPushNotificationConfig (
152+ url = f'{ url } /notifications' ,
153+ token = 'itk-token' , # noqa: S106
154+ )
155+
134156 try :
135- client = await create_client (call .agent_card_uri , client_config = config )
157+ client = await create_client (
158+ call .agent_card_uri ,
159+ client_config = config ,
160+ )
136161
137162 # Wrap nested instruction
138- async with client :
139- nested_msg = wrap_instruction_to_request (call .instruction )
140- request = SendMessageRequest (message = nested_msg )
141-
142- results : list [str ] = []
143- async for event in client .send_message (request ):
144- # Event is StreamResponse
145- logger .info ('Event: %s' , event )
146- stream_resp = event
147-
148- message = None
149- if stream_resp .HasField ('message' ):
150- message = stream_resp .message
151- elif stream_resp .HasField (
152- 'task'
153- ) and stream_resp .task .status .HasField ('message' ):
154- message = stream_resp .task .status .message
155- elif stream_resp .HasField (
156- 'status_update'
157- ) and stream_resp .status_update .status .HasField ('message' ):
158- message = stream_resp .status_update .status .message
159-
160- if message :
161- results .extend (
162- part .text for part in message .parts if part .text
163- )
163+ nested_msg = wrap_instruction_to_request (call .instruction )
164+ request = SendMessageRequest (message = nested_msg )
165+
166+ results = []
167+ async for event in client .send_message (request ):
168+ # Event is streaming response and task
169+ logger .info ('Event: %s' , event )
170+ stream_resp = event
171+
172+ message = None
173+ if stream_resp .HasField ('message' ):
174+ message = stream_resp .message
175+ elif stream_resp .HasField (
176+ 'task'
177+ ) and stream_resp .task .status .HasField ('message' ):
178+ message = stream_resp .task .status .message
179+ elif stream_resp .HasField (
180+ 'status_update'
181+ ) and stream_resp .status_update .status .HasField ('message' ):
182+ message = stream_resp .status_update .status .message
183+
184+ if message :
185+ results .extend (part .text for part in message .parts if part .text )
164186
165187 except Exception as e :
166188 logger .exception ('Failed to call outbound agent' )
@@ -171,7 +193,9 @@ async def handle_call_agent(call: instruction_pb2.CallAgent) -> list[str]:
171193 return results
172194
173195
174- async def handle_instruction (inst : instruction_pb2 .Instruction ) -> list [str ]:
196+ async def handle_instruction (
197+ inst : instruction_pb2 .Instruction ,
198+ ) -> list [str ]:
175199 """Recursively handles instructions."""
176200 if inst .HasField ('call_agent' ):
177201 return await handle_call_agent (inst .call_agent )
@@ -303,33 +327,40 @@ async def main_async(http_port: int, grpc_port: int) -> None:
303327 description = 'Python agent using SDK 1.0.' ,
304328 version = '1.0.0' ,
305329 capabilities = AgentCapabilities (
306- streaming = True ,
307- push_notifications = True ,
308- extended_agent_card = True ,
330+ streaming = True , push_notifications = True , extended_agent_card = True
309331 ),
310332 default_input_modes = ['text/plain' ],
311333 default_output_modes = ['text/plain' ],
312334 supported_interfaces = interfaces ,
313335 )
314336
315337 task_store = InMemoryTaskStore ()
338+ push_config_store = InMemoryPushNotificationConfigStore ()
339+ push_sender = BasePushNotificationSender (
340+ httpx_client = httpx .AsyncClient (),
341+ config_store = push_config_store ,
342+ context = ServerCallContext (),
343+ )
344+
316345 handler = DefaultRequestHandler (
317346 agent_executor = V10AgentExecutor (),
318- task_store = task_store ,
319347 agent_card = agent_card ,
348+ task_store = task_store ,
320349 queue_manager = InMemoryQueueManager (),
350+ push_config_store = push_config_store ,
351+ push_sender = push_sender ,
321352 )
322353
323354 handler_extended = DefaultRequestHandler (
324355 agent_executor = V10AgentExecutor (),
325- task_store = task_store ,
326356 agent_card = agent_card ,
357+ task_store = task_store ,
327358 queue_manager = InMemoryQueueManager (),
359+ push_config_store = push_config_store ,
360+ push_sender = push_sender ,
328361 extended_agent_card = agent_card ,
329362 )
330363
331- app = FastAPI ()
332-
333364 agent_card_routes = create_agent_card_routes (
334365 agent_card = agent_card , card_url = '/.well-known/agent-card.json'
335366 )
@@ -338,15 +369,16 @@ async def main_async(http_port: int, grpc_port: int) -> None:
338369 rpc_url = '/' ,
339370 enable_v0_3_compat = True ,
340371 )
341- app .mount (
342- '/jsonrpc' ,
343- FastAPI (routes = jsonrpc_routes + agent_card_routes ),
344- )
345-
346372 rest_routes = create_rest_routes (
347373 request_handler = handler ,
348374 enable_v0_3_compat = True ,
349375 )
376+
377+ app = FastAPI ()
378+ app .mount (
379+ '/jsonrpc' ,
380+ FastAPI (routes = jsonrpc_routes + agent_card_routes ),
381+ )
350382 app .mount ('/rest' , FastAPI (routes = rest_routes + agent_card_routes ))
351383
352384 server = grpc .aio .server ()
@@ -365,9 +397,8 @@ async def main_async(http_port: int, grpc_port: int) -> None:
365397 grpc_port ,
366398 )
367399
368- uvicorn_log_level = os .environ .get ('ITK_LOG_LEVEL' , 'INFO' ).lower ()
369400 config = uvicorn .Config (
370- app , host = '127.0.0.1' , port = http_port , log_level = uvicorn_log_level
401+ app , host = '127.0.0.1' , port = http_port , log_level = log_level_str . lower ()
371402 )
372403 uvicorn_server = uvicorn .Server (config )
373404
0 commit comments