1717from a2a .compat .v0_3 import a2a_v0_3_pb2_grpc
1818from a2a .compat .v0_3 .grpc_handler import CompatGrpcHandler
1919from a2a .server .agent_execution import AgentExecutor , RequestContext
20- from a2a .server .routes import create_agent_card_routes , create_jsonrpc_routes
21- from a2a .server .routes .rest_routes import create_rest_routes
2220from a2a .server .events import EventQueue
21+ from a2a .server .routes import (
22+ create_agent_card_routes ,
23+ create_jsonrpc_routes ,
24+ create_rest_routes ,
25+ )
2326from a2a .server .events .in_memory_queue_manager import InMemoryQueueManager
2427from a2a .server .request_handlers import DefaultRequestHandler , GrpcHandler
25- from a2a .server .tasks import TaskUpdater
28+ from a2a .server .tasks import (
29+ TaskUpdater ,
30+ BasePushNotificationSender ,
31+ InMemoryPushNotificationConfigStore ,
32+ )
2633from a2a .server .tasks .inmemory_task_store import InMemoryTaskStore
34+ from a2a .server .context import ServerCallContext
2735from a2a .types import a2a_pb2_grpc
2836from a2a .types .a2a_pb2 import (
2937 AgentCapabilities ,
3543 Task ,
3644 TaskState ,
3745 TaskStatus ,
46+ TaskPushNotificationConfig ,
3847)
3948from a2a .utils import TransportProtocol
49+ from a2a .server .tasks .push_notification_sender import PushNotificationEvent
4050
41-
42- log_level = os . environ . get ( 'ITK_LOG_LEVEL' , 'INFO' ). upper ( )
51+ log_level_str = os . environ . get ( 'ITK_LOG_LEVEL' , 'INFO' ). upper ()
52+ log_level = getattr ( logging , log_level_str , logging . INFO )
4353logging .basicConfig (level = log_level )
4454logger = logging .getLogger (__name__ )
4555
4656
57+
58+
59+
4760def extract_instruction (
4861 message : Message | None ,
4962) -> instruction_pb2 .Instruction | None :
@@ -65,7 +78,7 @@ def extract_instruction(
6578 # Some clients might send it as base64 in text part
6679 raw = base64 .b64decode (part .text )
6780 inst .ParseFromString (raw )
68- except Exception :
81+ except Exception : # noqa: BLE001
6982 logger .debug (
7083 'Failed to parse instruction from binary part' ,
7184 exc_info = True ,
@@ -80,7 +93,7 @@ def extract_instruction(
8093 raw = base64 .b64decode (part .text )
8194 inst = instruction_pb2 .Instruction ()
8295 inst .ParseFromString (raw )
83- except Exception :
96+ except Exception : # noqa: BLE001
8497 logger .debug (
8598 'Failed to parse instruction from text part' , exc_info = True
8699 )
@@ -106,7 +119,9 @@ def wrap_instruction_to_request(inst: instruction_pb2.Instruction) -> Message:
106119 )
107120
108121
109- async def handle_call_agent (call : instruction_pb2 .CallAgent ) -> list [str ]:
122+ async def handle_call_agent (
123+ call : instruction_pb2 .CallAgent ,
124+ ) -> list [str ]:
110125 """Handles the CallAgent instruction by invoking another agent."""
111126 logger .info ('Calling agent %s via %s' , call .agent_card_uri , call .transport )
112127
@@ -119,7 +134,9 @@ async def handle_call_agent(call: instruction_pb2.CallAgent) -> list[str]:
119134 'GRPC' : TransportProtocol .GRPC ,
120135 }
121136
122- selected_transport = transport_map .get (call .transport .upper ())
137+ selected_transport = transport_map .get (
138+ call .transport .upper (), TransportProtocol .JSONRPC
139+ )
123140 if selected_transport is None :
124141 raise ValueError (f'Unsupported transport: { call .transport } ' )
125142
@@ -131,36 +148,46 @@ async def handle_call_agent(call: instruction_pb2.CallAgent) -> list[str]:
131148 selected_transport == TransportProtocol .GRPC
132149 )
133150
151+ if call .HasField ('push_notification' ):
152+ url = call .push_notification .url
153+ if not url :
154+ raise ValueError ('URL not specified in push_notification behavior' )
155+ if not url .startswith (('http://' , 'https://' )):
156+ url = f'http://{ url } '
157+ config .push_notification_config = TaskPushNotificationConfig (
158+ url = f'{ url } /notifications' ,
159+ token = 'itk-token' ,
160+ )
161+
134162 try :
135- client = await create_client (call .agent_card_uri , client_config = config )
163+ client = await create_client (
164+ call .agent_card_uri ,
165+ client_config = config ,
166+ )
136167
137168 # Wrap nested instruction
138- async with client :
139- nested_msg = wrap_instruction_to_request (call .instruction )
140- request = SendMessageRequest (message = nested_msg )
141-
142- results : list [str ] = []
143- async for event in client .send_message (request ):
144- # Event is StreamResponse
145- logger .info ('Event: %s' , event )
146- stream_resp = event
147-
148- message = None
149- if stream_resp .HasField ('message' ):
150- message = stream_resp .message
151- elif stream_resp .HasField (
152- 'task'
153- ) and stream_resp .task .status .HasField ('message' ):
154- message = stream_resp .task .status .message
155- elif stream_resp .HasField (
156- 'status_update'
157- ) and stream_resp .status_update .status .HasField ('message' ):
158- message = stream_resp .status_update .status .message
159-
160- if message :
161- results .extend (
162- part .text for part in message .parts if part .text
163- )
169+ nested_msg = wrap_instruction_to_request (call .instruction )
170+ request = SendMessageRequest (message = nested_msg )
171+
172+
173+ results = []
174+ async for event in client .send_message (request ):
175+ # Event is streaming response and task
176+ logger .info ('Event: %s' , event )
177+ stream_resp = event
178+
179+ message = None
180+ if stream_resp .HasField ('message' ):
181+ message = stream_resp .message
182+ elif stream_resp .HasField ('task' ) and stream_resp .task .status .HasField ('message' ):
183+ message = stream_resp .task .status .message
184+ elif stream_resp .HasField (
185+ 'status_update'
186+ ) and stream_resp .status_update .status .HasField ('message' ):
187+ message = stream_resp .status_update .status .message
188+
189+ if message :
190+ results .extend (part .text for part in message .parts if part .text )
164191
165192 except Exception as e :
166193 logger .exception ('Failed to call outbound agent' )
@@ -171,7 +198,9 @@ async def handle_call_agent(call: instruction_pb2.CallAgent) -> list[str]:
171198 return results
172199
173200
174- async def handle_instruction (inst : instruction_pb2 .Instruction ) -> list [str ]:
201+ async def handle_instruction (
202+ inst : instruction_pb2 .Instruction ,
203+ ) -> list [str ]:
175204 """Recursively handles instructions."""
176205 if inst .HasField ('call_agent' ):
177206 return await handle_call_agent (inst .call_agent )
@@ -302,34 +331,39 @@ async def main_async(http_port: int, grpc_port: int) -> None:
302331 name = 'ITK v10 Agent' ,
303332 description = 'Python agent using SDK 1.0.' ,
304333 version = '1.0.0' ,
305- capabilities = AgentCapabilities (
306- streaming = True ,
307- push_notifications = True ,
308- extended_agent_card = True ,
309- ),
334+ capabilities = AgentCapabilities (streaming = True ),
310335 default_input_modes = ['text/plain' ],
311336 default_output_modes = ['text/plain' ],
312337 supported_interfaces = interfaces ,
313338 )
314339
315340 task_store = InMemoryTaskStore ()
341+ push_config_store = InMemoryPushNotificationConfigStore ()
342+ push_sender = BasePushNotificationSender (
343+ httpx_client = httpx .AsyncClient (),
344+ config_store = push_config_store ,
345+ context = ServerCallContext (),
346+ )
347+
316348 handler = DefaultRequestHandler (
317349 agent_executor = V10AgentExecutor (),
318- task_store = task_store ,
319350 agent_card = agent_card ,
351+ task_store = task_store ,
320352 queue_manager = InMemoryQueueManager (),
353+ push_config_store = push_config_store ,
354+ push_sender = push_sender ,
321355 )
322356
323357 handler_extended = DefaultRequestHandler (
324358 agent_executor = V10AgentExecutor (),
325- task_store = task_store ,
326359 agent_card = agent_card ,
360+ task_store = task_store ,
327361 queue_manager = InMemoryQueueManager (),
362+ push_config_store = push_config_store ,
363+ push_sender = push_sender ,
328364 extended_agent_card = agent_card ,
329365 )
330366
331- app = FastAPI ()
332-
333367 agent_card_routes = create_agent_card_routes (
334368 agent_card = agent_card , card_url = '/.well-known/agent-card.json'
335369 )
@@ -338,15 +372,16 @@ async def main_async(http_port: int, grpc_port: int) -> None:
338372 rpc_url = '/' ,
339373 enable_v0_3_compat = True ,
340374 )
341- app .mount (
342- '/jsonrpc' ,
343- FastAPI (routes = jsonrpc_routes + agent_card_routes ),
344- )
345-
346375 rest_routes = create_rest_routes (
347376 request_handler = handler ,
348377 enable_v0_3_compat = True ,
349378 )
379+
380+ app = FastAPI ()
381+ app .mount (
382+ '/jsonrpc' ,
383+ FastAPI (routes = jsonrpc_routes + agent_card_routes ),
384+ )
350385 app .mount ('/rest' , FastAPI (routes = rest_routes + agent_card_routes ))
351386
352387 server = grpc .aio .server ()
@@ -365,9 +400,8 @@ async def main_async(http_port: int, grpc_port: int) -> None:
365400 grpc_port ,
366401 )
367402
368- uvicorn_log_level = os .environ .get ('ITK_LOG_LEVEL' , 'INFO' ).lower ()
369403 config = uvicorn .Config (
370- app , host = '127.0.0.1' , port = http_port , log_level = uvicorn_log_level
404+ app , host = '127.0.0.1' , port = http_port , log_level = 'info'
371405 )
372406 uvicorn_server = uvicorn .Server (config )
373407
0 commit comments