@@ -2644,3 +2644,91 @@ async def test_on_message_send_stream_task_id_provided_but_task_not_found():
26442644 f'Task { task_id } was specified but does not exist'
26452645 in exc_info .value .error .message
26462646 )
2647+
2648+
2649+ @pytest .mark .asyncio
2650+ async def test_on_message_send_stream_consumer_error_cancels_producer_and_closes_queue ():
2651+ """Test that if the consumer (result aggregator) raises an exception, the producer is cancelled and queue is closed immediately."""
2652+ mock_task_store = AsyncMock (spec = TaskStore )
2653+ mock_queue_manager = AsyncMock (spec = QueueManager )
2654+ mock_agent_executor = AsyncMock (spec = AgentExecutor )
2655+ mock_request_context_builder = AsyncMock (spec = RequestContextBuilder )
2656+
2657+ task_id = 'error_cleanup_task'
2658+ context_id = 'error_cleanup_ctx'
2659+
2660+ mock_request_context = MagicMock (spec = RequestContext )
2661+ mock_request_context .task_id = task_id
2662+ mock_request_context .context_id = context_id
2663+ mock_request_context_builder .build .return_value = mock_request_context
2664+
2665+ mock_queue = AsyncMock (spec = EventQueue )
2666+ mock_queue_manager .create_or_tap .return_value = mock_queue
2667+
2668+ request_handler = DefaultRequestHandler (
2669+ agent_executor = mock_agent_executor ,
2670+ task_store = mock_task_store ,
2671+ queue_manager = mock_queue_manager ,
2672+ request_context_builder = mock_request_context_builder ,
2673+ )
2674+
2675+ params = MessageSendParams (
2676+ message = Message (
2677+ role = Role .user ,
2678+ message_id = 'msg_error_cleanup' ,
2679+ parts = [],
2680+ # Do NOT provide task_id here to avoid "Task ... was specified but does not exist" error
2681+ )
2682+ )
2683+
2684+ # Mock ResultAggregator to raise exception
2685+ mock_result_aggregator_instance = MagicMock (spec = ResultAggregator )
2686+
2687+ async def raise_error_gen (_consumer ):
2688+ # Raise an exception to simulate consumer failure
2689+ raise ValueError ('Consumer failed!' )
2690+ yield # unreachable
2691+
2692+ mock_result_aggregator_instance .consume_and_emit .side_effect = (
2693+ raise_error_gen
2694+ )
2695+
2696+ # Capture the producer task to verify cancellation
2697+ captured_producer_task = None
2698+ original_register = request_handler ._register_producer
2699+
2700+ async def spy_register_producer (tid , task ):
2701+ nonlocal captured_producer_task
2702+ captured_producer_task = task
2703+ # Wrap the cancel method to spy on it
2704+ task .cancel = MagicMock (wraps = task .cancel )
2705+ await original_register (tid , task )
2706+
2707+ with (
2708+ patch (
2709+ 'a2a.server.request_handlers.default_request_handler.ResultAggregator' ,
2710+ return_value = mock_result_aggregator_instance ,
2711+ ),
2712+ patch (
2713+ 'a2a.server.request_handlers.default_request_handler.TaskManager.get_task' ,
2714+ return_value = None ,
2715+ ),
2716+ patch .object (
2717+ request_handler ,
2718+ '_register_producer' ,
2719+ side_effect = spy_register_producer ,
2720+ ),
2721+ ):
2722+ # Act
2723+ with pytest .raises (ValueError , match = 'Consumer failed!' ):
2724+ async for _ in request_handler .on_message_send_stream (
2725+ params , create_server_call_context ()
2726+ ):
2727+ pass
2728+
2729+ assert captured_producer_task is not None
2730+ # Verify producer was cancelled
2731+ captured_producer_task .cancel .assert_called ()
2732+
2733+ # Verify queue closed immediately
2734+ mock_queue .close .assert_awaited_with (immediate = True )
0 commit comments