5757 TaskState .rejected ,
5858}
5959
60+
6061@trace_class (kind = SpanKind .SERVER )
6162class DefaultRequestHandler (RequestHandler ):
6263 """Default request handler for all incoming requests.
@@ -173,23 +174,25 @@ async def _run_event_stream(
173174 await self .agent_executor .execute (request , queue )
174175 await queue .close ()
175176
176- async def on_message_send (
177+ async def _setup_message_execution (
177178 self ,
178179 params : MessageSendParams ,
179180 context : ServerCallContext | None = None ,
180- ) -> Message | Task :
181- """Default handler for 'message/send' interface ( non-streaming) .
181+ ) -> tuple [ TaskManager , str , EventQueue , ResultAggregator , asyncio . Task ] :
182+ """Common setup logic for both streaming and non-streaming message handling .
182183
183- Starts the agent execution for the message and waits for the final
184- result (Task or Message).
184+ Returns:
185+ A tuple of (task_manager, task_id, queue, result_aggregator, producer_task)
185186 """
187+ # Create task manager and validate existing task
186188 task_manager = TaskManager (
187189 task_id = params .message .taskId ,
188190 context_id = params .message .contextId ,
189191 task_store = self .task_store ,
190192 initial_message = params .message ,
191193 )
192194 task : Task | None = await task_manager .get_task ()
195+
193196 if task :
194197 if task .status .state in TERMINAL_TASK_STATES :
195198 raise ServerError (
@@ -211,6 +214,8 @@ async def on_message_send(
211214 await self ._push_config_store .set_info (
212215 task .id , params .configuration .pushNotificationConfig
213216 )
217+
218+ # Build request context
214219 request_context = await self ._request_context_builder .build (
215220 params = params ,
216221 task_id = task .id if task else None ,
@@ -227,13 +232,49 @@ async def on_message_send(
227232 result_aggregator = ResultAggregator (task_manager )
228233 # TODO: to manage the non-blocking flows.
229234 producer_task = asyncio .create_task (
230- self ._run_event_stream (
231- request_context ,
232- queue ,
233- )
235+ self ._run_event_stream (request_context , queue )
234236 )
235237 await self ._register_producer (task_id , producer_task )
236238
239+ return task_manager , task_id , queue , result_aggregator , producer_task
240+
241+ def _validate_task_id_match (self , task_id : str , event_task_id : str ) -> None :
242+ """Validates that agent-generated task ID matches the expected task ID."""
243+ if task_id != event_task_id :
244+ logger .error (
245+ f'Agent generated task_id={ event_task_id } does not match the RequestContext task_id={ task_id } .'
246+ )
247+ raise ServerError (
248+ InternalError (message = 'Task ID mismatch in agent response' )
249+ )
250+
251+ async def _send_push_notification_if_needed (
252+ self , task_id : str , result_aggregator : ResultAggregator
253+ ) -> None :
254+ """Sends push notification if configured and task is available."""
255+ if self ._push_sender and task_id :
256+ latest_task = await result_aggregator .current_result
257+ if isinstance (latest_task , Task ):
258+ await self ._push_sender .send_notification (latest_task )
259+
260+ async def on_message_send (
261+ self ,
262+ params : MessageSendParams ,
263+ context : ServerCallContext | None = None ,
264+ ) -> Message | Task :
265+ """Default handler for 'message/send' interface (non-streaming).
266+
267+ Starts the agent execution for the message and waits for the final
268+ result (Task or Message).
269+ """
270+ (
271+ task_manager ,
272+ task_id ,
273+ queue ,
274+ result_aggregator ,
275+ producer_task ,
276+ ) = await self ._setup_message_execution (params , context )
277+
237278 consumer = EventConsumer (queue )
238279 producer_task .add_done_callback (consumer .agent_task_callback )
239280
@@ -246,18 +287,20 @@ async def on_message_send(
246287 if not result :
247288 raise ServerError (error = InternalError ())
248289
249- if isinstance (result , Task ) and task_id != result .id :
250- logger .error (
251- f'Agent generated task_id={ result .id } does not match the RequestContext task_id={ task_id } .'
252- )
253- raise ServerError (
254- InternalError (message = 'Task ID mismatch in agent response' )
255- )
290+ if isinstance (result , Task ):
291+ self ._validate_task_id_match (task_id , result .id )
292+
293+ await self ._send_push_notification_if_needed (
294+ task_id , result_aggregator
295+ )
256296
297+ except Exception as e :
298+ logger .error (f'Agent execution failed. Error: { e } ' )
299+ raise
257300 finally :
258301 if interrupted :
259302 # TODO: Track this disconnected cleanup task.
260- asyncio .create_task ( # noqa: RUF006
303+ asyncio .create_task ( # noqa: RUF006
261304 self ._cleanup_producer (producer_task , task_id )
262305 )
263306 else :
@@ -275,85 +318,34 @@ async def on_message_send_stream(
275318 Starts the agent execution and yields events as they are produced
276319 by the agent.
277320 """
278- task_manager = TaskManager (
279- task_id = params .message .taskId ,
280- context_id = params .message .contextId ,
281- task_store = self .task_store ,
282- initial_message = params .message ,
283- )
284- task : Task | None = await task_manager .get_task ()
285-
286- if task :
287- if task .status .state in TERMINAL_TASK_STATES :
288- raise ServerError (
289- error = InvalidParamsError (
290- message = f'Task { task .id } is in terminal state: { task .status .state } '
291- )
292- )
293-
294- task = task_manager .update_with_message (params .message , task )
295- if self .should_add_push_info (params ):
296- assert self ._push_config_store is not None
297- assert isinstance (
298- params .configuration , MessageSendConfiguration
299- )
300- assert isinstance (
301- params .configuration .pushNotificationConfig ,
302- PushNotificationConfig ,
303- )
304- await self ._push_config_store .set_info (
305- task .id , params .configuration .pushNotificationConfig
306- )
307- else :
308- queue = EventQueue ()
309- result_aggregator = ResultAggregator (task_manager )
310- request_context = await self ._request_context_builder .build (
311- params = params ,
312- task_id = task .id if task else None ,
313- context_id = params .message .contextId ,
314- task = task ,
315- context = context ,
316- )
317-
318- task_id = cast ('str' , request_context .task_id )
319- queue = await self ._queue_manager .create_or_tap (task_id )
320- producer_task = asyncio .create_task (
321- self ._run_event_stream (
322- request_context ,
323- queue ,
324- )
325- )
326- await self ._register_producer (task_id , producer_task )
321+ (
322+ task_manager ,
323+ task_id ,
324+ queue ,
325+ result_aggregator ,
326+ producer_task ,
327+ ) = await self ._setup_message_execution (params , context )
327328
328329 try :
329330 consumer = EventConsumer (queue )
330331 producer_task .add_done_callback (consumer .agent_task_callback )
331332 async for event in result_aggregator .consume_and_emit (consumer ):
332333 if isinstance (event , Task ):
333- if task_id != event .id :
334- logger .error (
335- f'Agent generated task_id={ event .id } does not match the RequestContext task_id={ task_id } .'
336- )
337- raise ServerError (
338- InternalError (
339- message = 'Task ID mismatch in agent response'
340- )
341- )
342-
343- if (
344- self ._push_config_store # Check if store is available for config
345- and params .configuration
346- and params .configuration .pushNotificationConfig
347- ):
348- await self ._push_config_store .set_info (
349- task_id ,
350- params .configuration .pushNotificationConfig ,
351- )
352-
353- if self ._push_sender and task_id : # Check if sender is available
354- latest_task = await result_aggregator .current_result
355- if isinstance (latest_task , Task ):
356- await self ._push_sender .send_notification (latest_task )
334+ self ._validate_task_id_match (task_id , event .id )
335+
336+ if (
337+ self ._push_config_store
338+ and params .configuration
339+ and params .configuration .pushNotificationConfig
340+ ):
341+ await self ._push_config_store .set_info (
342+ task_id ,
343+ params .configuration .pushNotificationConfig ,
344+ )
345+
346+ await self ._send_push_notification_if_needed (
347+ task_id , result_aggregator
348+ )
357349 yield event
358350 finally :
359351 await self ._cleanup_producer (producer_task , task_id )
@@ -415,7 +407,9 @@ async def on_get_task_push_notification_config(
415407 if not task :
416408 raise ServerError (error = TaskNotFoundError ())
417409
418- push_notification_config = await self ._push_config_store .get_info (params .id )
410+ push_notification_config = await self ._push_config_store .get_info (
411+ params .id
412+ )
419413 if not push_notification_config or not push_notification_config [0 ]:
420414 raise ServerError (error = InternalError ())
421415
@@ -477,14 +471,18 @@ async def on_list_task_push_notification_config(
477471 if not task :
478472 raise ServerError (error = TaskNotFoundError ())
479473
480- push_notification_config_list = await self ._push_config_store .get_info (params .id )
474+ push_notification_config_list = await self ._push_config_store .get_info (
475+ params .id
476+ )
481477
482478 task_push_notification_config = []
483479 if push_notification_config_list :
484480 for config in push_notification_config_list :
485- task_push_notification_config .append (TaskPushNotificationConfig (
486- taskId = params .id , pushNotificationConfig = config
487- ))
481+ task_push_notification_config .append (
482+ TaskPushNotificationConfig (
483+ taskId = params .id , pushNotificationConfig = config
484+ )
485+ )
488486
489487 return task_push_notification_config
490488
@@ -504,7 +502,9 @@ async def on_delete_task_push_notification_config(
504502 if not task :
505503 raise ServerError (error = TaskNotFoundError ())
506504
507- await self ._push_config_store .delete_info (params .id , params .pushNotificationConfigId )
505+ await self ._push_config_store .delete_info (
506+ params .id , params .pushNotificationConfigId
507+ )
508508
509509 def should_add_push_info (self , params : MessageSendParams ) -> bool :
510510 """Determines if push notification info should be set for a task."""
0 commit comments