@@ -2732,3 +2732,80 @@ async def spy_register_producer(tid, task):
27322732
27332733 # Verify queue closed immediately
27342734 mock_queue .close .assert_awaited_with (immediate = True )
2735+
2736+ @pytest .mark .asyncio
2737+ async def test_on_message_send_consumer_error_cancels_producer_and_closes_queue ():
2738+ """Test that if the consumer raises an exception during blocking wait, the producer is cancelled."""
2739+ mock_task_store = AsyncMock (spec = TaskStore )
2740+ mock_queue_manager = AsyncMock (spec = QueueManager )
2741+ mock_agent_executor = AsyncMock (spec = AgentExecutor )
2742+ mock_request_context_builder = AsyncMock (spec = RequestContextBuilder )
2743+
2744+ task_id = 'error_cleanup_blocking_task'
2745+ context_id = 'error_cleanup_blocking_ctx'
2746+
2747+ mock_request_context = MagicMock (spec = RequestContext )
2748+ mock_request_context .task_id = task_id
2749+ mock_request_context .context_id = context_id
2750+ mock_request_context_builder .build .return_value = mock_request_context
2751+
2752+ mock_queue = AsyncMock (spec = EventQueue )
2753+ mock_queue_manager .create_or_tap .return_value = mock_queue
2754+
2755+ request_handler = DefaultRequestHandler (
2756+ agent_executor = mock_agent_executor ,
2757+ task_store = mock_task_store ,
2758+ queue_manager = mock_queue_manager ,
2759+ request_context_builder = mock_request_context_builder ,
2760+ )
2761+
2762+ params = MessageSendParams (
2763+ message = Message (
2764+ role = Role .user ,
2765+ message_id = 'msg_error_blocking' ,
2766+ parts = [],
2767+ )
2768+ )
2769+
2770+ # Mock ResultAggregator to raise exception
2771+ mock_result_aggregator_instance = MagicMock (spec = ResultAggregator )
2772+ mock_result_aggregator_instance .consume_and_break_on_interrupt .side_effect = ValueError ('Consumer failed!' )
2773+
2774+ # Capture the producer task to verify cancellation
2775+ captured_producer_task = None
2776+ original_register = request_handler ._register_producer
2777+
2778+ async def spy_register_producer (tid , task ):
2779+ nonlocal captured_producer_task
2780+ captured_producer_task = task
2781+ # Wrap the cancel method to spy on it
2782+ task .cancel = MagicMock (wraps = task .cancel )
2783+ await original_register (tid , task )
2784+
2785+ with (
2786+ patch (
2787+ 'a2a.server.request_handlers.default_request_handler.ResultAggregator' ,
2788+ return_value = mock_result_aggregator_instance ,
2789+ ),
2790+ patch (
2791+ 'a2a.server.request_handlers.default_request_handler.TaskManager.get_task' ,
2792+ return_value = None ,
2793+ ),
2794+ patch .object (
2795+ request_handler ,
2796+ '_register_producer' ,
2797+ side_effect = spy_register_producer ,
2798+ ),
2799+ ):
2800+ # Act
2801+ with pytest .raises (ValueError , match = 'Consumer failed!' ):
2802+ await request_handler .on_message_send (
2803+ params , create_server_call_context ()
2804+ )
2805+
2806+ assert captured_producer_task is not None
2807+ # Verify producer was cancelled
2808+ captured_producer_task .cancel .assert_called ()
2809+
2810+ # Verify queue closed immediately
2811+ mock_queue .close .assert_awaited_with (immediate = True )
0 commit comments