From f5b0d1d3a18d72a3fe7846ff460f3a2b57bc42e2 Mon Sep 17 00:00:00 2001 From: Bartek Wolowiec Date: Tue, 7 Apr 2026 13:22:25 +0000 Subject: [PATCH] Task status cleanup. --- src/a2a/server/agent_execution/active_task.py | 152 ++++++++++-------- tests/integration/test_scenarios.py | 24 ++- 2 files changed, 97 insertions(+), 79 deletions(-) diff --git a/src/a2a/server/agent_execution/active_task.py b/src/a2a/server/agent_execution/active_task.py index f313ca11e..bf9e129a6 100644 --- a/src/a2a/server/agent_execution/active_task.py +++ b/src/a2a/server/agent_execution/active_task.py @@ -32,6 +32,8 @@ Message, Task, TaskState, + TaskStatus, + TaskStatusUpdateEvent, ) from a2a.utils.errors import ( InvalidParamsError, @@ -252,80 +254,75 @@ async def _run_producer(self) -> None: """ logger.debug('Producer[%s]: Started', self._task_id) try: - try: - try: - while True: - ( - request_context, - request_id, - ) = await self._request_queue.get() - await self._request_lock.acquire() - # TODO: Should we create task manager every time? - self._task_manager._call_context = ( - request_context.call_context - ) - request_context.current_task = ( - await self._task_manager.get_task() - ) + active = True + while active: + ( + request_context, + request_id, + ) = await self._request_queue.get() + await self._request_lock.acquire() + # TODO: Should we create task manager every time? + self._task_manager._call_context = request_context.call_context + request_context.current_task = ( + await self._task_manager.get_task() + ) - message = request_context.message - if message: - request_context.current_task = ( - self._task_manager.update_with_message( - message, - cast('Task', request_context.current_task), - ) - ) - await self._task_manager.save_task_event( - request_context.current_task - ) - self._task_created.set() - logger.debug( - 'Producer[%s]: Executing agent task %s', - self._task_id, - request_context.current_task, + message = request_context.message + if message: + request_context.current_task = ( + self._task_manager.update_with_message( + message, + cast('Task', request_context.current_task), ) + ) + await self._task_manager.save_task_event( + request_context.current_task + ) + self._task_created.set() + logger.debug( + 'Producer[%s]: Executing agent task %s', + self._task_id, + request_context.current_task, + ) - try: - await self._agent_executor.execute( - request_context, self._event_queue_agent - ) - logger.debug( - 'Producer[%s]: Execution finished successfully', - self._task_id, - ) - except Exception as e: - async with self._lock: - if self._exception is None: - self._exception = e - raise - finally: - logger.debug( - 'Producer[%s]: Enqueuing request completed event', - self._task_id, - ) - # TODO: Hide from external consumers - await self._event_queue_agent.enqueue_event( - cast('Event', _RequestCompleted(request_id)) - ) - self._request_queue.task_done() + try: + await self._agent_executor.execute( + request_context, self._event_queue_agent + ) + logger.debug( + 'Producer[%s]: Execution finished successfully', + self._task_id, + ) except QueueShutDown: logger.debug( 'Producer[%s]: Request queue shut down', self._task_id ) - except asyncio.CancelledError: - logger.debug('Producer[%s]: Cancelled', self._task_id) - raise - except Exception as e: - logger.exception('Producer[%s]: Failed', self._task_id) - async with self._lock: - if self._exception is None: - self._exception = e - finally: - self._request_queue.shutdown(immediate=True) - await self._event_queue_agent.close(immediate=False) - await self._event_queue_subscribers.close(immediate=False) + raise + except asyncio.CancelledError: + logger.debug('Producer[%s]: Cancelled', self._task_id) + raise + except Exception as e: + logger.exception( + 'Producer[%s]: Execution failed', + self._task_id, + ) + async with self._lock: + await self._mark_task_as_failed(e) + active = False + finally: + logger.debug( + 'Producer[%s]: Enqueuing request completed event', + self._task_id, + ) + # TODO: Hide from external consumers + await self._event_queue_agent.enqueue_event( + cast('Event', _RequestCompleted(request_id)) + ) + self._request_queue.task_done() finally: + self._request_queue.shutdown(immediate=True) + await self._event_queue_agent.close(immediate=False) + await self._event_queue_subscribers.close(immediate=False) logger.debug('Producer[%s]: Completed', self._task_id) async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912 @@ -443,8 +440,7 @@ async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912 except Exception as e: logger.exception('Consumer[%s]: Failed', self._task_id) async with self._lock: - if self._exception is None: - self._exception = e + await self._mark_task_as_failed(e) finally: # The consumer is dead. The ActiveTask is permanently finished. self._is_finished.set() @@ -581,9 +577,7 @@ async def cancel(self, call_context: ServerCallContext) -> Task | Message: logger.exception( 'Cancel[%s]: Agent cancel failed', self._task_id ) - if not self._exception: - self._exception = e - + await self._mark_task_as_failed(e) raise else: logger.debug( @@ -619,6 +613,22 @@ async def _maybe_cleanup(self) -> None: logger.debug('Cleanup[%s]: Triggering cleanup', self._task_id) self._on_cleanup(self) + async def _mark_task_as_failed(self, exception: Exception) -> None: + if self._exception is None: + self._exception = exception + if self._task_created.is_set(): + task = await self._task_manager.get_task() + if task is not None: + await self._event_queue_agent.enqueue_event( + TaskStatusUpdateEvent( + task_id=task.id, + context_id=task.context_id, + status=TaskStatus( + state=TaskState.TASK_STATE_FAILED, + ), + ) + ) + async def get_task(self) -> Task: """Get task from db.""" # TODO: THERE IS ZERO CONCURRENCY SAFETY HERE (Except inital task creation). diff --git a/tests/integration/test_scenarios.py b/tests/integration/test_scenarios.py index 94774e29a..a7d85a28c 100644 --- a/tests/integration/test_scenarios.py +++ b/tests/integration/test_scenarios.py @@ -437,9 +437,8 @@ async def cancel( # Legacy is not creating tasks for agent failures. assert len((await client.list_tasks(ListTasksRequest())).tasks) == 0 else: - # TODO: should it be TASK_STATE_FAILED ? (task,) = (await client.list_tasks(ListTasksRequest())).tasks - assert task.status.state == TaskState.TASK_STATE_SUBMITTED + assert task.status.state == TaskState.TASK_STATE_FAILED # Scenario 12/13: Exception after initial event @@ -503,9 +502,12 @@ async def release_agent(): await asyncio.gather(*tasks) - # TODO: should it be TASK_STATE_FAILED ? (task,) = (await client.list_tasks(ListTasksRequest())).tasks - assert task.status.state == TaskState.TASK_STATE_WORKING + if use_legacy: + # Legacy does not update task state on exception. + assert task.status.state == TaskState.TASK_STATE_WORKING + else: + assert task.status.state == TaskState.TASK_STATE_FAILED # Scenario 14: Exception in Cancel @@ -563,9 +565,12 @@ async def cancel( with pytest.raises(A2AClientError, match='TEST_ERROR_IN_CANCEL'): await client.cancel_task(CancelTaskRequest(id=task_id)) - # TODO: should it be TASK_STATE_CANCELED or TASK_STATE_FAILED? (task,) = (await client.list_tasks(ListTasksRequest())).tasks - assert task.status.state == TaskState.TASK_STATE_WORKING + if use_legacy: + # Legacy does not update task state on exception. + assert task.status.state == TaskState.TASK_STATE_WORKING + else: + assert task.status.state == TaskState.TASK_STATE_FAILED # Scenario 15: Subscribe to task that errors out @@ -632,9 +637,12 @@ async def consume_events(): with pytest.raises(A2AClientError, match='TEST_ERROR_IN_EXECUTE'): await consume_task - # TODO: should it be TASK_STATE_FAILED? (task,) = (await client.list_tasks(ListTasksRequest())).tasks - assert task.status.state == TaskState.TASK_STATE_WORKING + if use_legacy: + # Legacy does not update task state on exception. + assert task.status.state == TaskState.TASK_STATE_WORKING + else: + assert task.status.state == TaskState.TASK_STATE_FAILED # Scenario 16: Slow execution and return_immediately=True