Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions src/a2a/server/request_handlers/default_request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,14 +180,21 @@ async def on_cancel_task(

consumer = EventConsumer(queue)
result = await result_aggregator.consume_all(consumer)
if isinstance(result, Task):
return result
if not isinstance(result, Task):
raise ServerError(
error=InternalError(
message='Agent did not return valid response for cancel'
)
)

raise ServerError(
error=InternalError(
message='Agent did not return valid response for cancel'
if result.status.state != TaskState.canceled:
raise ServerError(
error=TaskNotCancelableError(
message=f'Task cannot be canceled - current state: {result.status.state}'
)
)
)

return result

async def _run_event_stream(
self, request: RequestContext, queue: EventQueue
Expand Down
50 changes: 50 additions & 0 deletions tests/server/request_handlers/test_default_request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,56 @@ async def test_on_cancel_task_cancels_running_agent():
mock_agent_executor.cancel.assert_awaited_once()


@pytest.mark.asyncio
async def test_on_cancel_task_completes_during_cancellation():
"""Test on_cancel_task fails to cancel a task due to concurrent task completion."""
task_id = 'running_agent_task_to_cancel'
sample_task = create_sample_task(task_id=task_id)
mock_task_store = AsyncMock(spec=TaskStore)
mock_task_store.get.return_value = sample_task

mock_queue_manager = AsyncMock(spec=QueueManager)
mock_event_queue = AsyncMock(spec=EventQueue)
mock_queue_manager.tap.return_value = mock_event_queue

mock_agent_executor = AsyncMock(spec=AgentExecutor)

# Mock ResultAggregator
mock_result_aggregator_instance = AsyncMock(spec=ResultAggregator)
mock_result_aggregator_instance.consume_all.return_value = (
create_sample_task(task_id=task_id, status_state=TaskState.completed)
)

request_handler = DefaultRequestHandler(
agent_executor=mock_agent_executor,
task_store=mock_task_store,
queue_manager=mock_queue_manager,
)

# Simulate a running agent task
mock_producer_task = AsyncMock(spec=asyncio.Task)
request_handler._running_agents[task_id] = mock_producer_task

from a2a.utils.errors import (
ServerError, # Local import
TaskNotCancelableError, # Local import
)

with patch(
'a2a.server.request_handlers.default_request_handler.ResultAggregator',
return_value=mock_result_aggregator_instance,
):
params = TaskIdParams(id=task_id)
with pytest.raises(ServerError) as exc_info:
await request_handler.on_cancel_task(
params, create_server_call_context()
)

mock_producer_task.cancel.assert_called_once()
mock_agent_executor.cancel.assert_awaited_once()
assert isinstance(exc_info.value.error, TaskNotCancelableError)
Comment thread
yarolegovich marked this conversation as resolved.


@pytest.mark.asyncio
async def test_on_cancel_task_invalid_result_type():
"""Test on_cancel_task when result_aggregator returns a Message instead of a Task."""
Expand Down
Loading