Skip to content

Commit 48d9976

Browse files
author
Łukasz Bobiński
committed
test: added test case for deadlock prevention
1 parent d3ec1b3 commit 48d9976

1 file changed

Lines changed: 88 additions & 0 deletions

File tree

tests/server/request_handlers/test_default_request_handler.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2644,3 +2644,91 @@ async def test_on_message_send_stream_task_id_provided_but_task_not_found():
26442644
f'Task {task_id} was specified but does not exist'
26452645
in exc_info.value.error.message
26462646
)
2647+
2648+
2649+
@pytest.mark.asyncio
2650+
async def test_on_message_send_stream_consumer_error_cancels_producer_and_closes_queue():
2651+
"""Test that if the consumer (result aggregator) raises an exception, the producer is cancelled and queue is closed immediately."""
2652+
mock_task_store = AsyncMock(spec=TaskStore)
2653+
mock_queue_manager = AsyncMock(spec=QueueManager)
2654+
mock_agent_executor = AsyncMock(spec=AgentExecutor)
2655+
mock_request_context_builder = AsyncMock(spec=RequestContextBuilder)
2656+
2657+
task_id = 'error_cleanup_task'
2658+
context_id = 'error_cleanup_ctx'
2659+
2660+
mock_request_context = MagicMock(spec=RequestContext)
2661+
mock_request_context.task_id = task_id
2662+
mock_request_context.context_id = context_id
2663+
mock_request_context_builder.build.return_value = mock_request_context
2664+
2665+
mock_queue = AsyncMock(spec=EventQueue)
2666+
mock_queue_manager.create_or_tap.return_value = mock_queue
2667+
2668+
request_handler = DefaultRequestHandler(
2669+
agent_executor=mock_agent_executor,
2670+
task_store=mock_task_store,
2671+
queue_manager=mock_queue_manager,
2672+
request_context_builder=mock_request_context_builder,
2673+
)
2674+
2675+
params = MessageSendParams(
2676+
message=Message(
2677+
role=Role.user,
2678+
message_id='msg_error_cleanup',
2679+
parts=[],
2680+
# Do NOT provide task_id here to avoid "Task ... was specified but does not exist" error
2681+
)
2682+
)
2683+
2684+
# Mock ResultAggregator to raise exception
2685+
mock_result_aggregator_instance = MagicMock(spec=ResultAggregator)
2686+
2687+
async def raise_error_gen(_consumer):
2688+
# Raise an exception to simulate consumer failure
2689+
raise ValueError('Consumer failed!')
2690+
yield # unreachable
2691+
2692+
mock_result_aggregator_instance.consume_and_emit.side_effect = (
2693+
raise_error_gen
2694+
)
2695+
2696+
# Capture the producer task to verify cancellation
2697+
captured_producer_task = None
2698+
original_register = request_handler._register_producer
2699+
2700+
async def spy_register_producer(tid, task):
2701+
nonlocal captured_producer_task
2702+
captured_producer_task = task
2703+
# Wrap the cancel method to spy on it
2704+
task.cancel = MagicMock(wraps=task.cancel)
2705+
await original_register(tid, task)
2706+
2707+
with (
2708+
patch(
2709+
'a2a.server.request_handlers.default_request_handler.ResultAggregator',
2710+
return_value=mock_result_aggregator_instance,
2711+
),
2712+
patch(
2713+
'a2a.server.request_handlers.default_request_handler.TaskManager.get_task',
2714+
return_value=None,
2715+
),
2716+
patch.object(
2717+
request_handler,
2718+
'_register_producer',
2719+
side_effect=spy_register_producer,
2720+
),
2721+
):
2722+
# Act
2723+
with pytest.raises(ValueError, match='Consumer failed!'):
2724+
async for _ in request_handler.on_message_send_stream(
2725+
params, create_server_call_context()
2726+
):
2727+
pass
2728+
2729+
assert captured_producer_task is not None
2730+
# Verify producer was cancelled
2731+
captured_producer_task.cancel.assert_called()
2732+
2733+
# Verify queue closed immediately
2734+
mock_queue.close.assert_awaited_with(immediate=True)

0 commit comments

Comments
 (0)