Skip to content

Commit a544768

Browse files
committed
Task status cleanup.
1 parent a61f6d4 commit a544768

2 files changed

Lines changed: 96 additions & 79 deletions

File tree

src/a2a/server/agent_execution/active_task.py

Lines changed: 80 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232
Message,
3333
Task,
3434
TaskState,
35+
TaskStatus,
36+
TaskStatusUpdateEvent,
3537
)
3638
from a2a.utils.errors import (
3739
InvalidParamsError,
@@ -252,80 +254,74 @@ async def _run_producer(self) -> None:
252254
"""
253255
logger.debug('Producer[%s]: Started', self._task_id)
254256
try:
255-
try:
256-
try:
257-
while True:
258-
(
259-
request_context,
260-
request_id,
261-
) = await self._request_queue.get()
262-
await self._request_lock.acquire()
263-
# TODO: Should we create task manager every time?
264-
self._task_manager._call_context = (
265-
request_context.call_context
266-
)
267-
request_context.current_task = (
268-
await self._task_manager.get_task()
269-
)
257+
active = True
258+
while active:
259+
(
260+
request_context,
261+
request_id,
262+
) = await self._request_queue.get()
263+
await self._request_lock.acquire()
264+
# TODO: Should we create task manager every time?
265+
self._task_manager._call_context = request_context.call_context
266+
request_context.current_task = (
267+
await self._task_manager.get_task()
268+
)
270269

271-
message = request_context.message
272-
if message:
273-
request_context.current_task = (
274-
self._task_manager.update_with_message(
275-
message,
276-
cast('Task', request_context.current_task),
277-
)
278-
)
279-
await self._task_manager.save_task_event(
280-
request_context.current_task
281-
)
282-
self._task_created.set()
283-
logger.debug(
284-
'Producer[%s]: Executing agent task %s',
285-
self._task_id,
286-
request_context.current_task,
270+
message = request_context.message
271+
if message:
272+
request_context.current_task = (
273+
self._task_manager.update_with_message(
274+
message,
275+
cast('Task', request_context.current_task),
287276
)
277+
)
278+
await self._task_manager.save_task_event(
279+
request_context.current_task
280+
)
281+
self._task_created.set()
282+
logger.debug(
283+
'Producer[%s]: Executing agent task %s',
284+
self._task_id,
285+
request_context.current_task,
286+
)
288287

289-
try:
290-
await self._agent_executor.execute(
291-
request_context, self._event_queue_agent
292-
)
293-
logger.debug(
294-
'Producer[%s]: Execution finished successfully',
295-
self._task_id,
296-
)
297-
except Exception as e:
298-
async with self._lock:
299-
if self._exception is None:
300-
self._exception = e
301-
raise
302-
finally:
303-
logger.debug(
304-
'Producer[%s]: Enqueuing request completed event',
305-
self._task_id,
306-
)
307-
# TODO: Hide from external consumers
308-
await self._event_queue_agent.enqueue_event(
309-
cast('Event', _RequestCompleted(request_id))
310-
)
311-
self._request_queue.task_done()
288+
try:
289+
await self._agent_executor.execute(
290+
request_context, self._event_queue_agent
291+
)
292+
logger.debug(
293+
'Producer[%s]: Execution finished successfully',
294+
self._task_id,
295+
)
312296
except QueueShutDown:
313297
logger.debug(
314298
'Producer[%s]: Request queue shut down', self._task_id
315299
)
316-
except asyncio.CancelledError:
317-
logger.debug('Producer[%s]: Cancelled', self._task_id)
318-
raise
319-
except Exception as e:
320-
logger.exception('Producer[%s]: Failed', self._task_id)
321-
async with self._lock:
322-
if self._exception is None:
323-
self._exception = e
324-
finally:
325-
self._request_queue.shutdown(immediate=True)
326-
await self._event_queue_agent.close(immediate=False)
327-
await self._event_queue_subscribers.close(immediate=False)
300+
raise
301+
except asyncio.CancelledError:
302+
logger.debug('Producer[%s]: Cancelled', self._task_id)
303+
raise
304+
except Exception as e: # noqa: BLE001 -Catch all other exceptions from agent.
305+
logger.debug(
306+
'Producer[%s]: Execution failed: %s', self._task_id, e
307+
)
308+
async with self._lock:
309+
await self._mark_task_as_failed(e)
310+
active = False
311+
finally:
312+
logger.debug(
313+
'Producer[%s]: Enqueuing request completed event',
314+
self._task_id,
315+
)
316+
# TODO: Hide from external consumers
317+
await self._event_queue_agent.enqueue_event(
318+
cast('Event', _RequestCompleted(request_id))
319+
)
320+
self._request_queue.task_done()
328321
finally:
322+
self._request_queue.shutdown(immediate=True)
323+
await self._event_queue_agent.close(immediate=False)
324+
await self._event_queue_subscribers.close(immediate=False)
329325
logger.debug('Producer[%s]: Completed', self._task_id)
330326

331327
async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912
@@ -443,8 +439,7 @@ async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912
443439
except Exception as e:
444440
logger.exception('Consumer[%s]: Failed', self._task_id)
445441
async with self._lock:
446-
if self._exception is None:
447-
self._exception = e
442+
await self._mark_task_as_failed(e)
448443
finally:
449444
# The consumer is dead. The ActiveTask is permanently finished.
450445
self._is_finished.set()
@@ -581,9 +576,7 @@ async def cancel(self, call_context: ServerCallContext) -> Task | Message:
581576
logger.exception(
582577
'Cancel[%s]: Agent cancel failed', self._task_id
583578
)
584-
if not self._exception:
585-
self._exception = e
586-
579+
await self._mark_task_as_failed(e)
587580
raise
588581
else:
589582
logger.debug(
@@ -619,6 +612,22 @@ async def _maybe_cleanup(self) -> None:
619612
logger.debug('Cleanup[%s]: Triggering cleanup', self._task_id)
620613
self._on_cleanup(self)
621614

615+
async def _mark_task_as_failed(self, exception: Exception) -> None:
616+
if self._exception is None:
617+
self._exception = exception
618+
if self._task_created.is_set():
619+
task = await self._task_manager.get_task()
620+
if task is not None:
621+
await self._event_queue_agent.enqueue_event(
622+
TaskStatusUpdateEvent(
623+
task_id=task.id,
624+
context_id=task.context_id,
625+
status=TaskStatus(
626+
state=TaskState.TASK_STATE_FAILED,
627+
),
628+
)
629+
)
630+
622631
async def get_task(self) -> Task:
623632
"""Get task from db."""
624633
# TODO: THERE IS ZERO CONCURRENCY SAFETY HERE (Except initial task creation).

tests/integration/test_scenarios.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -437,9 +437,8 @@ async def cancel(
437437
# Legacy does not create tasks for agent failures.
438438
assert len((await client.list_tasks(ListTasksRequest())).tasks) == 0
439439
else:
440-
# TODO: should it be TASK_STATE_FAILED ?
441440
(task,) = (await client.list_tasks(ListTasksRequest())).tasks
442-
assert task.status.state == TaskState.TASK_STATE_SUBMITTED
441+
assert task.status.state == TaskState.TASK_STATE_FAILED
443442

444443

445444
# Scenario 12/13: Exception after initial event
@@ -503,9 +502,12 @@ async def release_agent():
503502

504503
await asyncio.gather(*tasks)
505504

506-
# TODO: should it be TASK_STATE_FAILED ?
507505
(task,) = (await client.list_tasks(ListTasksRequest())).tasks
508-
assert task.status.state == TaskState.TASK_STATE_WORKING
506+
if use_legacy:
507+
# Legacy does not update task state on exception.
508+
assert task.status.state == TaskState.TASK_STATE_WORKING
509+
else:
510+
assert task.status.state == TaskState.TASK_STATE_FAILED
509511

510512

511513
# Scenario 14: Exception in Cancel
@@ -563,9 +565,12 @@ async def cancel(
563565
with pytest.raises(A2AClientError, match='TEST_ERROR_IN_CANCEL'):
564566
await client.cancel_task(CancelTaskRequest(id=task_id))
565567

566-
# TODO: should it be TASK_STATE_CANCELED or TASK_STATE_FAILED?
567568
(task,) = (await client.list_tasks(ListTasksRequest())).tasks
568-
assert task.status.state == TaskState.TASK_STATE_WORKING
569+
if use_legacy:
570+
# Legacy does not update task state on exception.
571+
assert task.status.state == TaskState.TASK_STATE_WORKING
572+
else:
573+
assert task.status.state == TaskState.TASK_STATE_FAILED
569574

570575

571576
# Scenario 15: Subscribe to task that errors out
@@ -632,9 +637,12 @@ async def consume_events():
632637
with pytest.raises(A2AClientError, match='TEST_ERROR_IN_EXECUTE'):
633638
await consume_task
634639

635-
# TODO: should it be TASK_STATE_FAILED?
636640
(task,) = (await client.list_tasks(ListTasksRequest())).tasks
637-
assert task.status.state == TaskState.TASK_STATE_WORKING
641+
if use_legacy:
642+
# Legacy does not update task state on exception.
643+
assert task.status.state == TaskState.TASK_STATE_WORKING
644+
else:
645+
assert task.status.state == TaskState.TASK_STATE_FAILED
638646

639647

640648
# Scenario 16: Slow execution and return_immediately=True

0 commit comments

Comments
 (0)