5555 TaskState .rejected ,
5656}
5757
58+
5859@trace_class (kind = SpanKind .SERVER )
5960class DefaultRequestHandler (RequestHandler ):
6061 """Default request handler for all incoming requests.
@@ -168,23 +169,25 @@ async def _run_event_stream(
168169 await self .agent_executor .execute (request , queue )
169170 await queue .close ()
170171
171- async def on_message_send (
172+ async def _setup_message_execution (
172173 self ,
173174 params : MessageSendParams ,
174175 context : ServerCallContext | None = None ,
175- ) -> Message | Task :
176- """Default handler for 'message/send' interface ( non-streaming) .
176+ ) -> tuple [ TaskManager , str , EventQueue , ResultAggregator , asyncio . Task ] :
177+ """Common setup logic for both streaming and non-streaming message handling .
177178
178- Starts the agent execution for the message and waits for the final
179- result (Task or Message).
179+ Returns:
180+ A tuple of (task_manager, task_id, queue, result_aggregator, producer_task)
180181 """
182+ # Create task manager and validate existing task
181183 task_manager = TaskManager (
182184 task_id = params .message .taskId ,
183185 context_id = params .message .contextId ,
184186 task_store = self .task_store ,
185187 initial_message = params .message ,
186188 )
187189 task : Task | None = await task_manager .get_task ()
190+
188191 if task :
189192 if task .status .state in TERMINAL_TASK_STATES :
190193 raise ServerError (
@@ -206,6 +209,8 @@ async def on_message_send(
206209 await self ._push_notifier .set_info (
207210 task .id , params .configuration .pushNotificationConfig
208211 )
212+
213+ # Build request context
209214 request_context = await self ._request_context_builder .build (
210215 params = params ,
211216 task_id = task .id if task else None ,
@@ -222,13 +227,49 @@ async def on_message_send(
222227 result_aggregator = ResultAggregator (task_manager )
223228 # TODO: to manage the non-blocking flows.
224229 producer_task = asyncio .create_task (
225- self ._run_event_stream (
226- request_context ,
227- queue ,
228- )
230+ self ._run_event_stream (request_context , queue )
229231 )
230232 await self ._register_producer (task_id , producer_task )
231233
234+ return task_manager , task_id , queue , result_aggregator , producer_task
235+
236+ def _validate_task_id_match (self , task_id : str , event_task_id : str ) -> None :
237+ """Validates that agent-generated task ID matches the expected task ID."""
238+ if task_id != event_task_id :
239+ logger .error (
240+ f'Agent generated task_id={ event_task_id } does not match the RequestContext task_id={ task_id } .'
241+ )
242+ raise ServerError (
243+ InternalError (message = 'Task ID mismatch in agent response' )
244+ )
245+
246+ async def _send_push_notification_if_needed (
247+ self , task_id : str , result_aggregator : ResultAggregator
248+ ) -> None :
249+ """Sends push notification if configured and task is available."""
250+ if self ._push_notifier and task_id :
251+ latest_task = await result_aggregator .current_result
252+ if isinstance (latest_task , Task ):
253+ await self ._push_notifier .send_notification (latest_task )
254+
255+ async def on_message_send (
256+ self ,
257+ params : MessageSendParams ,
258+ context : ServerCallContext | None = None ,
259+ ) -> Message | Task :
260+ """Default handler for 'message/send' interface (non-streaming).
261+
262+ Starts the agent execution for the message and waits for the final
263+ result (Task or Message).
264+ """
265+ (
266+ task_manager ,
267+ task_id ,
268+ queue ,
269+ result_aggregator ,
270+ producer_task ,
271+ ) = await self ._setup_message_execution (params , context )
272+
232273 consumer = EventConsumer (queue )
233274 producer_task .add_done_callback (consumer .agent_task_callback )
234275
@@ -241,14 +282,16 @@ async def on_message_send(
241282 if not result :
242283 raise ServerError (error = InternalError ())
243284
244- if isinstance (result , Task ) and task_id != result .id :
245- logger .error (
246- f'Agent generated task_id={ result .id } does not match the RequestContext task_id={ task_id } .'
247- )
248- raise ServerError (
249- InternalError (message = 'Task ID mismatch in agent response' )
250- )
285+ if isinstance (result , Task ):
286+ self ._validate_task_id_match (task_id , result .id )
251287
288+ await self ._send_push_notification_if_needed (
289+ task_id , result_aggregator
290+ )
291+
292+ except Exception as e :
293+ logger .error (f'Agent execution failed. Error: { e } ' )
294+ raise
252295 finally :
253296 if interrupted :
254297 # TODO: Track this disconnected cleanup task.
@@ -270,85 +313,34 @@ async def on_message_send_stream(
270313 Starts the agent execution and yields events as they are produced
271314 by the agent.
272315 """
273- task_manager = TaskManager (
274- task_id = params .message .taskId ,
275- context_id = params .message .contextId ,
276- task_store = self .task_store ,
277- initial_message = params .message ,
278- )
279- task : Task | None = await task_manager .get_task ()
280-
281- if task :
282- if task .status .state in TERMINAL_TASK_STATES :
283- raise ServerError (
284- error = InvalidParamsError (
285- message = f'Task { task .id } is in terminal state: { task .status .state } '
286- )
287- )
288-
289- task = task_manager .update_with_message (params .message , task )
290- if self .should_add_push_info (params ):
291- assert isinstance (self ._push_notifier , PushNotifier )
292- assert isinstance (
293- params .configuration , MessageSendConfiguration
294- )
295- assert isinstance (
296- params .configuration .pushNotificationConfig ,
297- PushNotificationConfig ,
298- )
299- await self ._push_notifier .set_info (
300- task .id , params .configuration .pushNotificationConfig
301- )
302- else :
303- queue = EventQueue ()
304- result_aggregator = ResultAggregator (task_manager )
305- request_context = await self ._request_context_builder .build (
306- params = params ,
307- task_id = task .id if task else None ,
308- context_id = params .message .contextId ,
309- task = task ,
310- context = context ,
311- )
312-
313- task_id = cast ('str' , request_context .task_id )
314- queue = await self ._queue_manager .create_or_tap (task_id )
315- producer_task = asyncio .create_task (
316- self ._run_event_stream (
317- request_context ,
318- queue ,
319- )
320- )
321- await self ._register_producer (task_id , producer_task )
316+ (
317+ task_manager ,
318+ task_id ,
319+ queue ,
320+ result_aggregator ,
321+ producer_task ,
322+ ) = await self ._setup_message_execution (params , context )
322323
323324 try :
324325 consumer = EventConsumer (queue )
325326 producer_task .add_done_callback (consumer .agent_task_callback )
326327 async for event in result_aggregator .consume_and_emit (consumer ):
327328 if isinstance (event , Task ):
328- if task_id != event .id :
329- logger .error (
330- f'Agent generated task_id={ event .id } does not match the RequestContext task_id={ task_id } .'
331- )
332- raise ServerError (
333- InternalError (
334- message = 'Task ID mismatch in agent response'
335- )
336- )
337-
338- if (
339- self ._push_notifier
340- and params .configuration
341- and params .configuration .pushNotificationConfig
342- ):
343- await self ._push_notifier .set_info (
344- task_id ,
345- params .configuration .pushNotificationConfig ,
346- )
347-
348- if self ._push_notifier and task_id :
349- latest_task = await result_aggregator .current_result
350- if isinstance (latest_task , Task ):
351- await self ._push_notifier .send_notification (latest_task )
329+ self ._validate_task_id_match (task_id , event .id )
330+
331+ if (
332+ self ._push_notifier
333+ and params .configuration
334+ and params .configuration .pushNotificationConfig
335+ ):
336+ await self ._push_notifier .set_info (
337+ task_id ,
338+ params .configuration .pushNotificationConfig ,
339+ )
340+
341+ await self ._send_push_notification_if_needed (
342+ task_id , result_aggregator
343+ )
352344 yield event
353345 finally :
354346 await self ._cleanup_producer (producer_task , task_id )
0 commit comments