Skip to content

Commit 5c5720e

Browse files
authored
Merge branch '1.0-dev' into GetExtendedAgentCard
2 parents 3697191 + 4fc6b54 commit 5c5720e

2 files changed

Lines changed: 97 additions & 79 deletions

File tree

src/a2a/server/agent_execution/active_task.py

Lines changed: 81 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232
Message,
3333
Task,
3434
TaskState,
35+
TaskStatus,
36+
TaskStatusUpdateEvent,
3537
)
3638
from a2a.utils.errors import (
3739
InvalidParamsError,
@@ -252,80 +254,75 @@ async def _run_producer(self) -> None:
252254
"""
253255
logger.debug('Producer[%s]: Started', self._task_id)
254256
try:
255-
try:
256-
try:
257-
while True:
258-
(
259-
request_context,
260-
request_id,
261-
) = await self._request_queue.get()
262-
await self._request_lock.acquire()
263-
# TODO: Should we create task manager every time?
264-
self._task_manager._call_context = (
265-
request_context.call_context
266-
)
267-
request_context.current_task = (
268-
await self._task_manager.get_task()
269-
)
257+
active = True
258+
while active:
259+
(
260+
request_context,
261+
request_id,
262+
) = await self._request_queue.get()
263+
await self._request_lock.acquire()
264+
# TODO: Should we create task manager every time?
265+
self._task_manager._call_context = request_context.call_context
266+
request_context.current_task = (
267+
await self._task_manager.get_task()
268+
)
270269

271-
message = request_context.message
272-
if message:
273-
request_context.current_task = (
274-
self._task_manager.update_with_message(
275-
message,
276-
cast('Task', request_context.current_task),
277-
)
278-
)
279-
await self._task_manager.save_task_event(
280-
request_context.current_task
281-
)
282-
self._task_created.set()
283-
logger.debug(
284-
'Producer[%s]: Executing agent task %s',
285-
self._task_id,
286-
request_context.current_task,
270+
message = request_context.message
271+
if message:
272+
request_context.current_task = (
273+
self._task_manager.update_with_message(
274+
message,
275+
cast('Task', request_context.current_task),
287276
)
277+
)
278+
await self._task_manager.save_task_event(
279+
request_context.current_task
280+
)
281+
self._task_created.set()
282+
logger.debug(
283+
'Producer[%s]: Executing agent task %s',
284+
self._task_id,
285+
request_context.current_task,
286+
)
288287

289-
try:
290-
await self._agent_executor.execute(
291-
request_context, self._event_queue_agent
292-
)
293-
logger.debug(
294-
'Producer[%s]: Execution finished successfully',
295-
self._task_id,
296-
)
297-
except Exception as e:
298-
async with self._lock:
299-
if self._exception is None:
300-
self._exception = e
301-
raise
302-
finally:
303-
logger.debug(
304-
'Producer[%s]: Enqueuing request completed event',
305-
self._task_id,
306-
)
307-
# TODO: Hide from external consumers
308-
await self._event_queue_agent.enqueue_event(
309-
cast('Event', _RequestCompleted(request_id))
310-
)
311-
self._request_queue.task_done()
288+
try:
289+
await self._agent_executor.execute(
290+
request_context, self._event_queue_agent
291+
)
292+
logger.debug(
293+
'Producer[%s]: Execution finished successfully',
294+
self._task_id,
295+
)
312296
except QueueShutDown:
313297
logger.debug(
314298
'Producer[%s]: Request queue shut down', self._task_id
315299
)
316-
except asyncio.CancelledError:
317-
logger.debug('Producer[%s]: Cancelled', self._task_id)
318-
raise
319-
except Exception as e:
320-
logger.exception('Producer[%s]: Failed', self._task_id)
321-
async with self._lock:
322-
if self._exception is None:
323-
self._exception = e
324-
finally:
325-
self._request_queue.shutdown(immediate=True)
326-
await self._event_queue_agent.close(immediate=False)
327-
await self._event_queue_subscribers.close(immediate=False)
300+
raise
301+
except asyncio.CancelledError:
302+
logger.debug('Producer[%s]: Cancelled', self._task_id)
303+
raise
304+
except Exception as e:
305+
logger.exception(
306+
'Producer[%s]: Execution failed',
307+
self._task_id,
308+
)
309+
async with self._lock:
310+
await self._mark_task_as_failed(e)
311+
active = False
312+
finally:
313+
logger.debug(
314+
'Producer[%s]: Enqueuing request completed event',
315+
self._task_id,
316+
)
317+
# TODO: Hide from external consumers
318+
await self._event_queue_agent.enqueue_event(
319+
cast('Event', _RequestCompleted(request_id))
320+
)
321+
self._request_queue.task_done()
328322
finally:
323+
self._request_queue.shutdown(immediate=True)
324+
await self._event_queue_agent.close(immediate=False)
325+
await self._event_queue_subscribers.close(immediate=False)
329326
logger.debug('Producer[%s]: Completed', self._task_id)
330327

331328
async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912
@@ -443,8 +440,7 @@ async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912
443440
except Exception as e:
444441
logger.exception('Consumer[%s]: Failed', self._task_id)
445442
async with self._lock:
446-
if self._exception is None:
447-
self._exception = e
443+
await self._mark_task_as_failed(e)
448444
finally:
449445
# The consumer is dead. The ActiveTask is permanently finished.
450446
self._is_finished.set()
@@ -581,9 +577,7 @@ async def cancel(self, call_context: ServerCallContext) -> Task | Message:
581577
logger.exception(
582578
'Cancel[%s]: Agent cancel failed', self._task_id
583579
)
584-
if not self._exception:
585-
self._exception = e
586-
580+
await self._mark_task_as_failed(e)
587581
raise
588582
else:
589583
logger.debug(
@@ -619,6 +613,22 @@ async def _maybe_cleanup(self) -> None:
619613
logger.debug('Cleanup[%s]: Triggering cleanup', self._task_id)
620614
self._on_cleanup(self)
621615

616+
async def _mark_task_as_failed(self, exception: Exception) -> None:
617+
if self._exception is None:
618+
self._exception = exception
619+
if self._task_created.is_set():
620+
task = await self._task_manager.get_task()
621+
if task is not None:
622+
await self._event_queue_agent.enqueue_event(
623+
TaskStatusUpdateEvent(
624+
task_id=task.id,
625+
context_id=task.context_id,
626+
status=TaskStatus(
627+
state=TaskState.TASK_STATE_FAILED,
628+
),
629+
)
630+
)
631+
622632
async def get_task(self) -> Task:
623633
"""Get task from db."""
624634
# TODO: THERE IS ZERO CONCURRENCY SAFETY HERE (Except inital task creation).

tests/integration/test_scenarios.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -447,9 +447,8 @@ async def cancel(
447447
# Legacy is not creating tasks for agent failures.
448448
assert len((await client.list_tasks(ListTasksRequest())).tasks) == 0
449449
else:
450-
# TODO: should it be TASK_STATE_FAILED ?
451450
(task,) = (await client.list_tasks(ListTasksRequest())).tasks
452-
assert task.status.state == TaskState.TASK_STATE_SUBMITTED
451+
assert task.status.state == TaskState.TASK_STATE_FAILED
453452

454453

455454
# Scenario 12/13: Exception after initial event
@@ -513,9 +512,12 @@ async def release_agent():
513512

514513
await asyncio.gather(*tasks)
515514

516-
# TODO: should it be TASK_STATE_FAILED ?
517515
(task,) = (await client.list_tasks(ListTasksRequest())).tasks
518-
assert task.status.state == TaskState.TASK_STATE_WORKING
516+
if use_legacy:
517+
# Legacy does not update task state on exception.
518+
assert task.status.state == TaskState.TASK_STATE_WORKING
519+
else:
520+
assert task.status.state == TaskState.TASK_STATE_FAILED
519521

520522

521523
# Scenario 14: Exception in Cancel
@@ -573,9 +575,12 @@ async def cancel(
573575
with pytest.raises(A2AClientError, match='TEST_ERROR_IN_CANCEL'):
574576
await client.cancel_task(CancelTaskRequest(id=task_id))
575577

576-
# TODO: should it be TASK_STATE_CANCELED or TASK_STATE_FAILED?
577578
(task,) = (await client.list_tasks(ListTasksRequest())).tasks
578-
assert task.status.state == TaskState.TASK_STATE_WORKING
579+
if use_legacy:
580+
# Legacy does not update task state on exception.
581+
assert task.status.state == TaskState.TASK_STATE_WORKING
582+
else:
583+
assert task.status.state == TaskState.TASK_STATE_FAILED
579584

580585

581586
# Scenario 15: Subscribe to task that errors out
@@ -642,9 +647,12 @@ async def consume_events():
642647
with pytest.raises(A2AClientError, match='TEST_ERROR_IN_EXECUTE'):
643648
await consume_task
644649

645-
# TODO: should it be TASK_STATE_FAILED?
646650
(task,) = (await client.list_tasks(ListTasksRequest())).tasks
647-
assert task.status.state == TaskState.TASK_STATE_WORKING
651+
if use_legacy:
652+
# Legacy does not update task state on exception.
653+
assert task.status.state == TaskState.TASK_STATE_WORKING
654+
else:
655+
assert task.status.state == TaskState.TASK_STATE_FAILED
648656

649657

650658
# Scenario 16: Slow execution and return_immediately=True

0 commit comments

Comments
 (0)