@@ -2646,169 +2646,53 @@ async def test_on_message_send_stream_task_id_provided_but_task_not_found():
26462646 )
26472647
26482648
2649- @pytest .mark .asyncio
2650- async def test_on_message_send_stream_consumer_error_cancels_producer_and_closes_queue ():
2651- """Test that if the consumer (result aggregator) raises an exception, the producer is cancelled and queue is closed immediately."""
2652- mock_task_store = AsyncMock (spec = TaskStore )
2653- mock_queue_manager = AsyncMock (spec = QueueManager )
2654- mock_agent_executor = AsyncMock (spec = AgentExecutor )
2655- mock_request_context_builder = AsyncMock (spec = RequestContextBuilder )
2656-
2657- task_id = 'error_cleanup_task'
2658- context_id = 'error_cleanup_ctx'
2659-
2660- mock_request_context = MagicMock (spec = RequestContext )
2661- mock_request_context .task_id = task_id
2662- mock_request_context .context_id = context_id
2663- mock_request_context_builder .build .return_value = mock_request_context
2664-
2665- mock_queue = AsyncMock (spec = EventQueue )
2666- mock_queue_manager .create_or_tap .return_value = mock_queue
2667-
2668- request_handler = DefaultRequestHandler (
2669- agent_executor = mock_agent_executor ,
2670- task_store = mock_task_store ,
2671- queue_manager = mock_queue_manager ,
2672- request_context_builder = mock_request_context_builder ,
2673- )
2674-
2675- params = MessageSendParams (
2676- message = Message (
2677- role = Role .user ,
2678- message_id = 'msg_error_cleanup' ,
2679- parts = [],
2680- # Do NOT provide task_id here to avoid "Task ... was specified but does not exist" error
2649+ class HelloWorldAgentExecutor (AgentExecutor ):
2650+ """Test Agent Implementation."""
2651+
2652+ async def execute (
2653+ self ,
2654+ context : RequestContext ,
2655+ event_queue : EventQueue ,
2656+ ) -> None :
2657+ updater = TaskUpdater (
2658+ event_queue ,
2659+ task_id = context .task_id or str (uuid .uuid4 ()),
2660+ context_id = context .context_id or str (uuid .uuid4 ()),
26812661 )
2682- )
2683-
2684- # Mock ResultAggregator to raise exception
2685- mock_result_aggregator_instance = MagicMock (spec = ResultAggregator )
2686-
2687- async def raise_error_gen (_consumer ):
2688- # Raise an exception to simulate consumer failure
2689- raise ValueError ('Consumer failed!' )
2690- yield # unreachable
2691-
2692- mock_result_aggregator_instance .consume_and_emit .side_effect = (
2693- raise_error_gen
2694- )
2695-
2696- # Capture the producer task to verify cancellation
2697- captured_producer_task = None
2698- original_register = request_handler ._register_producer
2699-
2700- async def spy_register_producer (tid , task ):
2701- nonlocal captured_producer_task
2702- captured_producer_task = task
2703- # Wrap the cancel method to spy on it
2704- task .cancel = MagicMock (wraps = task .cancel )
2705- await original_register (tid , task )
2706-
2707- with (
2708- patch (
2709- 'a2a.server.request_handlers.default_request_handler.ResultAggregator' ,
2710- return_value = mock_result_aggregator_instance ,
2711- ),
2712- patch (
2713- 'a2a.server.request_handlers.default_request_handler.TaskManager.get_task' ,
2714- return_value = None ,
2715- ),
2716- patch .object (
2717- request_handler ,
2718- '_register_producer' ,
2719- side_effect = spy_register_producer ,
2720- ),
2721- ):
2722- # Act
2723- with pytest .raises (ValueError , match = 'Consumer failed!' ):
2724- async for _ in request_handler .on_message_send_stream (
2725- params , create_server_call_context ()
2726- ):
2727- pass
2728-
2729- assert captured_producer_task is not None
2730- # Verify producer was cancelled
2731- captured_producer_task .cancel .assert_called ()
2662+ await updater .update_status (TaskState .working )
2663+ await updater .complete ()
27322664
2733- # Verify queue closed immediately
2734- mock_queue .close .assert_awaited_with (immediate = True )
2665+ async def cancel (
2666+ self , context : RequestContext , event_queue : EventQueue
2667+ ) -> None :
2668+ raise NotImplementedError ('cancel not supported' )
27352669
27362670
2671+ # Repro is straight from the https://github.com/a2aproject/a2a-python/issues/609.
2672+ # It uses timeout to test against infinite wait, if it's going to be flaky,
2673+ # we should reconsider the approach.
27372674@pytest .mark .asyncio
2738- async def test_on_message_send_consumer_error_cancels_producer_and_closes_queue ():
2739- """Test that if the consumer raises an exception during blocking wait, the producer is cancelled."""
2740- mock_task_store = AsyncMock (spec = TaskStore )
2741- mock_queue_manager = AsyncMock (spec = QueueManager )
2742- mock_agent_executor = AsyncMock (spec = AgentExecutor )
2743- mock_request_context_builder = AsyncMock (spec = RequestContextBuilder )
2744-
2745- task_id = 'error_cleanup_blocking_task'
2746- context_id = 'error_cleanup_blocking_ctx'
2747-
2748- mock_request_context = MagicMock (spec = RequestContext )
2749- mock_request_context .task_id = task_id
2750- mock_request_context .context_id = context_id
2751- mock_request_context_builder .build .return_value = mock_request_context
2752-
2753- mock_queue = AsyncMock (spec = EventQueue )
2754- mock_queue_manager .create_or_tap .return_value = mock_queue
2675+ @pytest .mark .timeout (1 )
2676+ async def test_on_message_send_error_should_not_hang ():
2677+ """Test that if the consumer raises an exception during blocking wait, the producer is cancelled and no deadlock occurs."""
2678+ agent = HelloWorldAgentExecutor ()
2679+ task_store = AsyncMock (spec = TaskStore )
2680+ task_store .save .side_effect = RuntimeError ('This is an Error!' )
27552681
27562682 request_handler = DefaultRequestHandler (
2757- agent_executor = mock_agent_executor ,
2758- task_store = mock_task_store ,
2759- queue_manager = mock_queue_manager ,
2760- request_context_builder = mock_request_context_builder ,
2683+ agent_executor = agent , task_store = task_store
27612684 )
27622685
27632686 params = MessageSendParams (
27642687 message = Message (
27652688 role = Role .user ,
27662689 message_id = 'msg_error_blocking' ,
2767- parts = [],
2690+ parts = [Part ( root = TextPart ( text = 'Test message' )) ],
27682691 )
27692692 )
27702693
2771- # Mock ResultAggregator to raise exception
2772- mock_result_aggregator_instance = MagicMock (spec = ResultAggregator )
2773- mock_result_aggregator_instance .consume_and_break_on_interrupt .side_effect = ValueError (
2774- 'Consumer failed!'
2775- )
2776-
2777- # Capture the producer task to verify cancellation
2778- captured_producer_task = None
2779- original_register = request_handler ._register_producer
2780-
2781- async def spy_register_producer (tid , task ):
2782- nonlocal captured_producer_task
2783- captured_producer_task = task
2784- # Wrap the cancel method to spy on it
2785- task .cancel = MagicMock (wraps = task .cancel )
2786- await original_register (tid , task )
2787-
2788- with (
2789- patch (
2790- 'a2a.server.request_handlers.default_request_handler.ResultAggregator' ,
2791- return_value = mock_result_aggregator_instance ,
2792- ),
2793- patch (
2794- 'a2a.server.request_handlers.default_request_handler.TaskManager.get_task' ,
2795- return_value = None ,
2796- ),
2797- patch .object (
2798- request_handler ,
2799- '_register_producer' ,
2800- side_effect = spy_register_producer ,
2801- ),
2802- ):
2803- # Act
2804- with pytest .raises (ValueError , match = 'Consumer failed!' ):
2805- await request_handler .on_message_send (
2806- params , create_server_call_context ()
2807- )
2808-
2809- assert captured_producer_task is not None
2810- # Verify producer was cancelled
2811- captured_producer_task .cancel .assert_called ()
2812-
2813- # Verify queue closed immediately
2814- mock_queue .close .assert_awaited_with (immediate = True )
2694+ with pytest .raises (RuntimeError , match = 'This is an Error!' ):
2695+ async for _ in request_handler .on_message_send_stream (
2696+ params , create_server_call_context ()
2697+ ):
2698+ pass
0 commit comments