Skip to content

Commit fd629cc

Browse files
author
Łukasz Bobiński
committed
Simplify fix for consumer error handling
1 parent 3c5f101 commit fd629cc

2 files changed

Lines changed: 91 additions & 10 deletions

File tree

src/a2a/server/request_handlers/default_request_handler.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -329,13 +329,7 @@ async def push_notification_callback() -> None:
329329
)
330330

331331
except Exception:
332-
logger.exception('Agent execution failed')
333-
# If the consumer fails, we must cancel the producer to prevent it from hanging
334-
# on queue operations (e.g., waiting for the queue to drain).
335-
producer_task.cancel()
336-
# Force the queue to close immediately, discarding any pending events.
337-
# This ensures that any producers waiting on the queue are unblocked.
338-
await queue.close(immediate=True)
332+
await self._handle_execution_failure(producer_task, queue)
339333
raise
340334
finally:
341335
if interrupted_or_non_blocking:
@@ -400,9 +394,7 @@ async def on_message_send_stream(
400394
raise
401395
except Exception:
402396
# If the consumer fails (e.g. database error), we must cleanup.
403-
logger.exception('Agent execution failed during streaming')
404-
producer_task.cancel()
405-
await queue.close(immediate=True)
397+
await self._handle_execution_failure(producer_task, queue)
406398
raise
407399
finally:
408400
cleanup_task = asyncio.create_task(
@@ -441,6 +433,18 @@ def _on_done(completed: asyncio.Task) -> None:
441433

442434
task.add_done_callback(_on_done)
443435

436+
async def _handle_execution_failure(
437+
self, producer_task: asyncio.Task, queue: EventQueue
438+
) -> None:
439+
"""Cancels the producer and closes the queue immediately on failure."""
440+
logger.exception('Agent execution failed')
441+
# If the consumer fails, we must cancel the producer to prevent it from hanging
442+
# on queue operations (e.g., waiting for the queue to drain).
443+
producer_task.cancel()
444+
# Force the queue to close immediately, discarding any pending events.
445+
# This ensures that any producers waiting on the queue are unblocked.
446+
await queue.close(immediate=True)
447+
444448
async def _cleanup_producer(
445449
self,
446450
producer_task: asyncio.Task,

tests/server/request_handlers/test_default_request_handler.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2732,3 +2732,80 @@ async def spy_register_producer(tid, task):
27322732

27332733
# Verify queue closed immediately
27342734
mock_queue.close.assert_awaited_with(immediate=True)
2735+
2736+
@pytest.mark.asyncio
2737+
async def test_on_message_send_consumer_error_cancels_producer_and_closes_queue():
2738+
"""Test that if the consumer raises an exception during blocking wait, the producer is cancelled."""
2739+
mock_task_store = AsyncMock(spec=TaskStore)
2740+
mock_queue_manager = AsyncMock(spec=QueueManager)
2741+
mock_agent_executor = AsyncMock(spec=AgentExecutor)
2742+
mock_request_context_builder = AsyncMock(spec=RequestContextBuilder)
2743+
2744+
task_id = 'error_cleanup_blocking_task'
2745+
context_id = 'error_cleanup_blocking_ctx'
2746+
2747+
mock_request_context = MagicMock(spec=RequestContext)
2748+
mock_request_context.task_id = task_id
2749+
mock_request_context.context_id = context_id
2750+
mock_request_context_builder.build.return_value = mock_request_context
2751+
2752+
mock_queue = AsyncMock(spec=EventQueue)
2753+
mock_queue_manager.create_or_tap.return_value = mock_queue
2754+
2755+
request_handler = DefaultRequestHandler(
2756+
agent_executor=mock_agent_executor,
2757+
task_store=mock_task_store,
2758+
queue_manager=mock_queue_manager,
2759+
request_context_builder=mock_request_context_builder,
2760+
)
2761+
2762+
params = MessageSendParams(
2763+
message=Message(
2764+
role=Role.user,
2765+
message_id='msg_error_blocking',
2766+
parts=[],
2767+
)
2768+
)
2769+
2770+
# Mock ResultAggregator to raise exception
2771+
mock_result_aggregator_instance = MagicMock(spec=ResultAggregator)
2772+
mock_result_aggregator_instance.consume_and_break_on_interrupt.side_effect = ValueError('Consumer failed!')
2773+
2774+
# Capture the producer task to verify cancellation
2775+
captured_producer_task = None
2776+
original_register = request_handler._register_producer
2777+
2778+
async def spy_register_producer(tid, task):
2779+
nonlocal captured_producer_task
2780+
captured_producer_task = task
2781+
# Wrap the cancel method to spy on it
2782+
task.cancel = MagicMock(wraps=task.cancel)
2783+
await original_register(tid, task)
2784+
2785+
with (
2786+
patch(
2787+
'a2a.server.request_handlers.default_request_handler.ResultAggregator',
2788+
return_value=mock_result_aggregator_instance,
2789+
),
2790+
patch(
2791+
'a2a.server.request_handlers.default_request_handler.TaskManager.get_task',
2792+
return_value=None,
2793+
),
2794+
patch.object(
2795+
request_handler,
2796+
'_register_producer',
2797+
side_effect=spy_register_producer,
2798+
),
2799+
):
2800+
# Act
2801+
with pytest.raises(ValueError, match='Consumer failed!'):
2802+
await request_handler.on_message_send(
2803+
params, create_server_call_context()
2804+
)
2805+
2806+
assert captured_producer_task is not None
2807+
# Verify producer was cancelled
2808+
captured_producer_task.cancel.assert_called()
2809+
2810+
# Verify queue closed immediately
2811+
mock_queue.close.assert_awaited_with(immediate=True)

0 commit comments

Comments
 (0)