55import logging
66import uuid
77
8- from typing import TYPE_CHECKING , cast
8+ from typing import TYPE_CHECKING , Any , cast
99
1010from a2a .server .agent_execution .context import RequestContext
1111
5656}
5757
5858
59+ class _RequestStarted :
60+ def __init__ (self , request_id : uuid .UUID , request_context : RequestContext ):
61+ self .request_id = request_id
62+ self .request_context = request_context
63+
64+
5965class _RequestCompleted :
6066 def __init__ (self , request_id : uuid .UUID ):
6167 self .request_id = request_id
@@ -199,25 +205,13 @@ async def start(
199205 logger .debug ('TASK (start): %s' , task )
200206
201207 if task :
208+ self ._task_created .set ()
202209 if task .status .state in TERMINAL_TASK_STATES :
203210 raise InvalidParamsError (
204211 message = f'Task { task .id } is in terminal state: { task .status .state } '
205212 )
206- else :
207- if not create_task_if_missing :
208- raise TaskNotFoundError
209-
210- # New task. Create and save it so it's not "missing" if queried immediately
211- # (especially important for return_immediately=True)
212- if self ._task_manager .context_id is None :
213- raise ValueError ('Context ID is required for new tasks' )
214- task = self ._task_manager ._init_task_obj (
215- self ._task_id ,
216- self ._task_manager .context_id ,
217- )
218- await self ._task_manager .save_task_event (task )
219- if self ._push_sender :
220- await self ._push_sender .send_notification (task .id , task )
213+ elif not create_task_if_missing :
214+ raise TaskNotFoundError
221215
222216 except Exception :
223217 logger .debug (
@@ -253,72 +247,72 @@ async def _run_producer(self) -> None:
253247 Runs as a detached asyncio.Task. Safe to cancel.
254248 """
255249 logger .debug ('Producer[%s]: Started' , self ._task_id )
250+ request_context = None
256251 try :
257- active = True
258- while active :
252+ while True :
259253 (
260254 request_context ,
261255 request_id ,
262256 ) = await self ._request_queue .get ()
263257 await self ._request_lock .acquire ()
264258 # TODO: Should we create task manager every time?
265259 self ._task_manager ._call_context = request_context .call_context
260+
266261 request_context .current_task = (
267262 await self ._task_manager .get_task ()
268263 )
269264
270- message = request_context .message
271- if message :
272- request_context .current_task = (
273- self ._task_manager .update_with_message (
274- message ,
275- cast ('Task' , request_context .current_task ),
276- )
277- )
278- await self ._task_manager .save_task_event (
279- request_context .current_task
280- )
281- self ._task_created .set ()
282265 logger .debug (
283266 'Producer[%s]: Executing agent task %s' ,
284267 self ._task_id ,
285268 request_context .current_task ,
286269 )
287270
288271 try :
272+ await self ._event_queue_agent .enqueue_event (
273+ cast (
274+ 'Event' ,
275+ _RequestStarted (request_id , request_context ),
276+ )
277+ )
278+
289279 await self ._agent_executor .execute (
290280 request_context , self ._event_queue_agent
291281 )
292282 logger .debug (
293283 'Producer[%s]: Execution finished successfully' ,
294284 self ._task_id ,
295285 )
296- except QueueShutDown :
297- logger .debug (
298- 'Producer[%s]: Request queue shut down' , self ._task_id
299- )
300- raise
301- except asyncio .CancelledError :
302- logger .debug ('Producer[%s]: Cancelled' , self ._task_id )
303- raise
304- except Exception as e :
305- logger .exception (
306- 'Producer[%s]: Execution failed' ,
307- self ._task_id ,
308- )
309- async with self ._lock :
310- await self ._mark_task_as_failed (e )
311- active = False
312286 finally :
313287 logger .debug (
314288 'Producer[%s]: Enqueuing request completed event' ,
315289 self ._task_id ,
316290 )
317- # TODO: Hide from external consumers
318291 await self ._event_queue_agent .enqueue_event (
319292 cast ('Event' , _RequestCompleted (request_id ))
320293 )
321294 self ._request_queue .task_done ()
295+ except asyncio .CancelledError :
296+ logger .debug ('Producer[%s]: Cancelled' , self ._task_id )
297+
298+ except QueueShutDown :
299+ logger .debug ('Producer[%s]: Queue shut down' , self ._task_id )
300+
301+ except Exception as e :
302+ logger .exception (
303+ 'Producer[%s]: Execution failed' ,
304+ self ._task_id ,
305+ )
306+ # Create task and mark as failed.
307+ if request_context :
308+ await self ._task_manager .ensure_task_id (
309+ self ._task_id ,
310+ request_context .context_id or '' ,
311+ )
312+ self ._task_created .set ()
313+ async with self ._lock :
314+ await self ._mark_task_as_failed (e )
315+
322316 finally :
323317 self ._request_queue .shutdown (immediate = True )
324318 await self ._event_queue_agent .close (immediate = False )
@@ -338,6 +332,8 @@ async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912
338332 `_is_finished`, unblocking all global subscribers and wait() calls.
339333 """
340334 logger .debug ('Consumer[%s]: Started' , self ._task_id )
335+ task_mode = None
336+ message_to_save = None
341337 try :
342338 try :
343339 try :
@@ -347,6 +343,7 @@ async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912
347343 'Consumer[%s]: Waiting for event' ,
348344 self ._task_id ,
349345 )
346+ new_task = None
350347 event = await self ._event_queue_agent .dequeue_event ()
351348 logger .debug (
352349 'Consumer[%s]: Dequeued event %s' ,
@@ -361,15 +358,60 @@ async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912
361358 self ._task_id ,
362359 )
363360 self ._request_lock .release ()
361+ elif isinstance (event , _RequestStarted ):
362+ logger .debug (
363+ 'Consumer[%s]: Request started' ,
364+ self ._task_id ,
365+ )
366+ message_to_save = event .request_context .message
367+
364368 elif isinstance (event , Message ):
369+ if task_mode is not None :
370+ if task_mode :
371+ logger .error (
372+ 'Received Message() object in task mode.'
373+ )
374+ else :
375+ logger .error (
376+ 'Multiple Message() objects received.'
377+ )
378+ task_mode = False
365379 logger .debug (
366380 'Consumer[%s]: Setting result to Message: %s' ,
367381 self ._task_id ,
368382 event ,
369383 )
370384 else :
385+ if task_mode is False :
386+ logger .error (
387+ 'Received %s in message mode.' ,
388+ type (event ).__name__ ,
389+ )
390+
391+ new_task = (
392+ await self ._task_manager .ensure_task_id (
393+ self ._task_id ,
394+ event .context_id ,
395+ )
396+ )
397+
398+ if message_to_save is not None :
399+ new_task = (
400+ self ._task_manager .update_with_message (
401+ message_to_save ,
402+ new_task ,
403+ )
404+ )
405+ await self ._task_manager .save_task_event (
406+ new_task
407+ )
408+ message_to_save = None
409+
410+ self ._task_created .set ()
411+
412+ task_mode = True
371413 # Save structural events (like TaskStatusUpdate) to DB.
372- # TODO: Create task manager every time ?
414+
373415 self ._task_manager .context_id = event .context_id
374416 await self ._task_manager .process (event )
375417
@@ -432,8 +474,19 @@ async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912
432474 self ._task_id , event
433475 )
434476 finally :
477+ if new_task is not None :
478+ new_task_copy = Task ()
479+ new_task_copy .CopyFrom (new_task )
480+ new_task = new_task_copy
481+
482+ logger .debug (
483+ 'Consumer[%s]: Enqueuing\n Event: %s\n New Task: %s\n ' ,
484+ self ._task_id ,
485+ event ,
486+ new_task ,
487+ )
435488 await self ._event_queue_subscribers .enqueue_event (
436- event
489+ cast ( 'Any' , ( event , new_task ))
437490 )
438491 self ._event_queue_agent .task_done ()
439492 except QueueShutDown :
@@ -459,6 +512,7 @@ async def subscribe( # noqa: PLR0912, PLR0915
459512 * ,
460513 request : RequestContext | None = None ,
461514 include_initial_task : bool = False ,
515+ replace_status_update_with_task : bool = False ,
462516 ) -> AsyncGenerator [Event , None ]:
463517 """Creates a queue tap and yields events as they are produced.
464518
@@ -506,9 +560,25 @@ async def subscribe( # noqa: PLR0912, PLR0915
506560
507561 # Wait for next event or task completion
508562 try :
509- event = await asyncio .wait_for (
563+ dequeued = await asyncio .wait_for (
510564 tapped_queue .dequeue_event (), timeout = 0.1
511565 )
566+ event , updated_task = cast ('Any' , dequeued )
567+ logger .debug (
568+ 'Subscriber[%s]\n Dequeued event %s\n Updated task %s\n ' ,
569+ self ._task_id ,
570+ event ,
571+ updated_task ,
572+ )
573+ if replace_status_update_with_task and isinstance (
574+ event , TaskStatusUpdateEvent
575+ ):
576+ logger .debug (
577+ 'Subscriber[%s]: Replacing TaskStatusUpdateEvent with Task: %s' ,
578+ self ._task_id ,
579+ updated_task ,
580+ )
581+ event = updated_task
512582 if self ._exception :
513583 raise self ._exception from None
514584 if isinstance (event , _RequestCompleted ):
@@ -522,6 +592,12 @@ async def subscribe( # noqa: PLR0912, PLR0915
522592 )
523593 return
524594 continue
595+ elif isinstance (event , _RequestStarted ):
596+ logger .debug (
597+ 'Subscriber[%s]: Request started' ,
598+ self ._task_id ,
599+ )
600+ continue
525601 except (asyncio .TimeoutError , TimeoutError ):
526602 if self ._is_finished .is_set ():
527603 if self ._exception :
@@ -545,7 +621,7 @@ async def subscribe( # noqa: PLR0912, PLR0915
545621 # Evaluate if this was the last subscriber on a finished task.
546622 await self ._maybe_cleanup ()
547623
548- async def cancel (self , call_context : ServerCallContext ) -> Task | Message :
624+ async def cancel (self , call_context : ServerCallContext ) -> Task :
549625 """Cancels the running active task.
550626
551627 Concurrency Guarantee:
@@ -558,11 +634,11 @@ async def cancel(self, call_context: ServerCallContext) -> Task | Message:
558634 # TODO: Conflicts with call_context on the pending request.
559635 self ._task_manager ._call_context = call_context
560636
561- task = await self .get_task ()
637+ task = await self ._task_manager . get_task ()
562638 request_context = RequestContext (
563639 call_context = call_context ,
564640 task_id = self ._task_id ,
565- context_id = task .context_id ,
641+ context_id = task .context_id if task else None ,
566642 task = task ,
567643 )
568644
@@ -591,7 +667,10 @@ async def cancel(self, call_context: ServerCallContext) -> Task | Message:
591667 )
592668
593669 await self ._is_finished .wait ()
594- return await self .get_task ()
670+ task = await self ._task_manager .get_task ()
671+ if not task :
672+ raise RuntimeError ('Task should have been created' )
673+ return task
595674
596675 async def _maybe_cleanup (self ) -> None :
597676 """Triggers cleanup if task is finished and has no subscribers.
0 commit comments