55import logging
66import uuid
77
8- from typing import TYPE_CHECKING , cast
8+ from typing import TYPE_CHECKING , Any , cast
99
1010from a2a .server .agent_execution .context import RequestContext
1111
5656}
5757
5858
59+ class _RequestStarted :
60+ def __init__ (self , request_id : uuid .UUID , request_context : RequestContext ):
61+ self .request_id = request_id
62+ self .request_context = request_context
63+
64+
5965class _RequestCompleted :
6066 def __init__ (self , request_id : uuid .UUID ):
6167 self .request_id = request_id
@@ -199,25 +205,13 @@ async def start(
199205 logger .debug ('TASK (start): %s' , task )
200206
201207 if task :
208+ self ._task_created .set ()
202209 if task .status .state in TERMINAL_TASK_STATES :
203210 raise InvalidParamsError (
204211 message = f'Task { task .id } is in terminal state: { task .status .state } '
205212 )
206- else :
207- if not create_task_if_missing :
208- raise TaskNotFoundError
209-
210- # New task. Create and save it so it's not "missing" if queried immediately
211- # (especially important for return_immediately=True)
212- if self ._task_manager .context_id is None :
213- raise ValueError ('Context ID is required for new tasks' )
214- task = self ._task_manager ._init_task_obj (
215- self ._task_id ,
216- self ._task_manager .context_id ,
217- )
218- await self ._task_manager .save_task_event (task )
219- if self ._push_sender :
220- await self ._push_sender .send_notification (task .id , task )
213+ elif not create_task_if_missing :
214+ raise TaskNotFoundError
221215
222216 except Exception :
223217 logger .debug (
@@ -253,72 +247,72 @@ async def _run_producer(self) -> None:
253247 Runs as a detached asyncio.Task. Safe to cancel.
254248 """
255249 logger .debug ('Producer[%s]: Started' , self ._task_id )
250+ request_context = None
256251 try :
257- active = True
258- while active :
252+ while True :
259253 (
260254 request_context ,
261255 request_id ,
262256 ) = await self ._request_queue .get ()
263257 await self ._request_lock .acquire ()
264258 # TODO: Should we create task manager every time?
265259 self ._task_manager ._call_context = request_context .call_context
260+
266261 request_context .current_task = (
267262 await self ._task_manager .get_task ()
268263 )
269264
270- message = request_context .message
271- if message :
272- request_context .current_task = (
273- self ._task_manager .update_with_message (
274- message ,
275- cast ('Task' , request_context .current_task ),
276- )
277- )
278- await self ._task_manager .save_task_event (
279- request_context .current_task
280- )
281- self ._task_created .set ()
282265 logger .debug (
283266 'Producer[%s]: Executing agent task %s' ,
284267 self ._task_id ,
285268 request_context .current_task ,
286269 )
287270
288271 try :
272+ await self ._event_queue_agent .enqueue_event (
273+ cast (
274+ 'Event' ,
275+ _RequestStarted (request_id , request_context ),
276+ )
277+ )
278+
289279 await self ._agent_executor .execute (
290280 request_context , self ._event_queue_agent
291281 )
292282 logger .debug (
293283 'Producer[%s]: Execution finished successfully' ,
294284 self ._task_id ,
295285 )
296- except QueueShutDown :
297- logger .debug (
298- 'Producer[%s]: Request queue shut down' , self ._task_id
299- )
300- raise
301- except asyncio .CancelledError :
302- logger .debug ('Producer[%s]: Cancelled' , self ._task_id )
303- raise
304- except Exception as e :
305- logger .exception (
306- 'Producer[%s]: Execution failed' ,
307- self ._task_id ,
308- )
309- async with self ._lock :
310- await self ._mark_task_as_failed (e )
311- active = False
312286 finally :
313287 logger .debug (
314288 'Producer[%s]: Enqueuing request completed event' ,
315289 self ._task_id ,
316290 )
317- # TODO: Hide from external consumers
318291 await self ._event_queue_agent .enqueue_event (
319292 cast ('Event' , _RequestCompleted (request_id ))
320293 )
321294 self ._request_queue .task_done ()
295+ except asyncio .CancelledError :
296+ logger .debug ('Producer[%s]: Cancelled' , self ._task_id )
297+
298+ except QueueShutDown :
299+ logger .debug ('Producer[%s]: Queue shut down' , self ._task_id )
300+
301+ except Exception as e :
302+ logger .exception (
303+ 'Producer[%s]: Execution failed' ,
304+ self ._task_id ,
305+ )
306+ # Create task and mark as failed.
307+ if request_context :
308+ await self ._task_manager .ensure_task_id (
309+ self ._task_id ,
310+ request_context .context_id or '' ,
311+ )
312+ self ._task_created .set ()
313+ async with self ._lock :
314+ await self ._mark_task_as_failed (e )
315+
322316 finally :
323317 self ._request_queue .shutdown (immediate = True )
324318 await self ._event_queue_agent .close (immediate = False )
@@ -338,6 +332,10 @@ async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912
338332 `_is_finished`, unblocking all global subscribers and wait() calls.
339333 """
340334 logger .debug ('Consumer[%s]: Started' , self ._task_id )
335+ task_mode = None
336+ message_to_save = None
337+ # TODO: Make helper methods
338+ # TODO: Support Task enqueue
341339 try :
342340 try :
343341 try :
@@ -347,6 +345,7 @@ async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912
347345 'Consumer[%s]: Waiting for event' ,
348346 self ._task_id ,
349347 )
348+ new_task = None
350349 event = await self ._event_queue_agent .dequeue_event ()
351350 logger .debug (
352351 'Consumer[%s]: Dequeued event %s' ,
@@ -361,24 +360,78 @@ async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912
361360 self ._task_id ,
362361 )
363362 self ._request_lock .release ()
363+ elif isinstance (event , _RequestStarted ):
364+ logger .debug (
365+ 'Consumer[%s]: Request started' ,
366+ self ._task_id ,
367+ )
368+ message_to_save = event .request_context .message
369+
364370 elif isinstance (event , Message ):
371+ if task_mode is not None :
372+ if task_mode :
373+ logger .error (
374+ 'Received Message() object in task mode.'
375+ )
376+ else :
377+ logger .error (
378+ 'Multiple Message() objects received.'
379+ )
380+ task_mode = False
365381 logger .debug (
366382 'Consumer[%s]: Setting result to Message: %s' ,
367383 self ._task_id ,
368384 event ,
369385 )
370386 else :
387+ if task_mode is False :
388+ logger .error (
389+ 'Received %s in message mode.' ,
390+ type (event ).__name__ ,
391+ )
392+
393+ if isinstance (event , Task ):
394+ new_task = event
395+ await self ._task_manager .save_task_event (
396+ new_task
397+ )
398+ else :
399+ new_task = (
400+ await self ._task_manager .ensure_task_id (
401+ self ._task_id ,
402+ event .context_id ,
403+ )
404+ )
405+
406+ if message_to_save is not None :
407+ new_task = self ._task_manager .update_with_message (
408+ message_to_save ,
409+ new_task ,
410+ )
411+ await (
412+ self ._task_manager .save_task_event (
413+ new_task
414+ )
415+ )
416+ message_to_save = None
417+
418+ task_mode = True
371419 # Save structural events (like TaskStatusUpdate) to DB.
372- # TODO: Create task manager every time ?
420+
373421 self ._task_manager .context_id = event .context_id
374- await self ._task_manager .process (event )
422+ if not isinstance (event , Task ):
423+ await self ._task_manager .process (event )
424+
425+ self ._task_created .set ()
375426
376427 # Check for AUTH_REQUIRED or INPUT_REQUIRED or TERMINAL states
377428 new_task = await self ._task_manager .get_task ()
378429 if new_task is None :
379430 raise RuntimeError (
380431 f'Task { self .task_id } not found'
381432 )
433+ if isinstance (event , Task ):
434+ event = new_task
382435 is_interrupted = (
383436 new_task .status .state
384437 in INTERRUPTED_TASK_STATES
@@ -432,8 +485,23 @@ async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912
432485 self ._task_id , event
433486 )
434487 finally :
488+ if new_task is not None :
489+ new_task_copy = Task ()
490+ new_task_copy .CopyFrom (new_task )
491+ new_task = new_task_copy
492+ if isinstance (event , Task ):
493+ new_task_copy = Task ()
494+ new_task_copy .CopyFrom (event )
495+ event = new_task_copy
496+
497+ logger .debug (
498+ 'Consumer[%s]: Enqueuing\n Event: %s\n New Task: %s\n ' ,
499+ self ._task_id ,
500+ event ,
501+ new_task ,
502+ )
435503 await self ._event_queue_subscribers .enqueue_event (
436- event
504+ cast ( 'Any' , ( event , new_task ))
437505 )
438506 self ._event_queue_agent .task_done ()
439507 except QueueShutDown :
@@ -459,6 +527,7 @@ async def subscribe( # noqa: PLR0912, PLR0915
459527 * ,
460528 request : RequestContext | None = None ,
461529 include_initial_task : bool = False ,
530+ replace_status_update_with_task : bool = False ,
462531 ) -> AsyncGenerator [Event , None ]:
463532 """Creates a queue tap and yields events as they are produced.
464533
@@ -506,9 +575,25 @@ async def subscribe( # noqa: PLR0912, PLR0915
506575
507576 # Wait for next event or task completion
508577 try :
509- event = await asyncio .wait_for (
578+ dequeued = await asyncio .wait_for (
510579 tapped_queue .dequeue_event (), timeout = 0.1
511580 )
581+ event , updated_task = cast ('Any' , dequeued )
582+ logger .debug (
583+ 'Subscriber[%s]\n Dequeued event %s\n Updated task %s\n ' ,
584+ self ._task_id ,
585+ event ,
586+ updated_task ,
587+ )
588+ if replace_status_update_with_task and isinstance (
589+ event , TaskStatusUpdateEvent
590+ ):
591+ logger .debug (
592+ 'Subscriber[%s]: Replacing TaskStatusUpdateEvent with Task: %s' ,
593+ self ._task_id ,
594+ updated_task ,
595+ )
596+ event = updated_task
512597 if self ._exception :
513598 raise self ._exception from None
514599 if isinstance (event , _RequestCompleted ):
@@ -522,6 +607,12 @@ async def subscribe( # noqa: PLR0912, PLR0915
522607 )
523608 return
524609 continue
610+ elif isinstance (event , _RequestStarted ):
611+ logger .debug (
612+ 'Subscriber[%s]: Request started' ,
613+ self ._task_id ,
614+ )
615+ continue
525616 except (asyncio .TimeoutError , TimeoutError ):
526617 if self ._is_finished .is_set ():
527618 if self ._exception :
@@ -545,7 +636,7 @@ async def subscribe( # noqa: PLR0912, PLR0915
545636 # Evaluate if this was the last subscriber on a finished task.
546637 await self ._maybe_cleanup ()
547638
548- async def cancel (self , call_context : ServerCallContext ) -> Task | Message :
639+ async def cancel (self , call_context : ServerCallContext ) -> Task :
549640 """Cancels the running active task.
550641
551642 Concurrency Guarantee:
@@ -558,11 +649,11 @@ async def cancel(self, call_context: ServerCallContext) -> Task | Message:
558649 # TODO: Conflicts with call_context on the pending request.
559650 self ._task_manager ._call_context = call_context
560651
561- task = await self .get_task ()
652+ task = await self ._task_manager . get_task ()
562653 request_context = RequestContext (
563654 call_context = call_context ,
564655 task_id = self ._task_id ,
565- context_id = task .context_id ,
656+ context_id = task .context_id if task else None ,
566657 task = task ,
567658 )
568659
@@ -591,7 +682,10 @@ async def cancel(self, call_context: ServerCallContext) -> Task | Message:
591682 )
592683
593684 await self ._is_finished .wait ()
594- return await self .get_task ()
685+ task = await self ._task_manager .get_task ()
686+ if not task :
687+ raise RuntimeError ('Task should have been created' )
688+ return task
595689
596690 async def _maybe_cleanup (self ) -> None :
597691 """Triggers cleanup if task is finished and has no subscribers.
0 commit comments