Skip to content

Commit e280a63

Browse files
committed
Add issue repro as a test, simplify
1 parent 0fc2511 commit e280a63

4 files changed

Lines changed: 51 additions & 169 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ dev = [
109109
"no_implicit_optional",
110110
"trio",
111111
"uvicorn>=0.35.0",
112+
"pytest-timeout>=2.4.0",
112113
]
113114

114115
[[tool.uv.index]]

src/a2a/server/request_handlers/default_request_handler.py

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,8 @@ async def push_notification_callback() -> None:
329329
)
330330

331331
except Exception:
332-
await self._handle_execution_failure(producer_task, queue)
332+
logger.exception('Agent execution failed')
333+
producer_task.cancel()
333334
raise
334335
finally:
335336
if interrupted_or_non_blocking:
@@ -392,10 +393,6 @@ async def on_message_send_stream(
392393
bg_task.set_name(f'background_consume:{task_id}')
393394
self._track_background_task(bg_task)
394395
raise
395-
except Exception:
396-
# If the consumer fails (e.g. database error), we must cleanup.
397-
await self._handle_execution_failure(producer_task, queue)
398-
raise
399396
finally:
400397
cleanup_task = asyncio.create_task(
401398
self._cleanup_producer(producer_task, task_id)
@@ -433,18 +430,6 @@ def _on_done(completed: asyncio.Task) -> None:
433430

434431
task.add_done_callback(_on_done)
435432

436-
async def _handle_execution_failure(
437-
self, producer_task: asyncio.Task, queue: EventQueue
438-
) -> None:
439-
"""Cancels the producer and closes the queue immediately on failure."""
440-
logger.exception('Agent execution failed')
441-
# If the consumer fails, we must cancel the producer to prevent it from hanging
442-
# on queue operations (e.g., waiting for the queue to drain).
443-
producer_task.cancel()
444-
# Force the queue to close immediately, discarding any pending events.
445-
# This ensures that any producers waiting on the queue are unblocked.
446-
await queue.close(immediate=True)
447-
448433
async def _cleanup_producer(
449434
self,
450435
producer_task: asyncio.Task,
@@ -457,8 +442,6 @@ async def _cleanup_producer(
457442
logger.debug(
458443
'Producer task %s was cancelled during cleanup', task_id
459444
)
460-
except Exception:
461-
logger.exception('Producer task %s failed during cleanup', task_id)
462445
await self._queue_manager.close(task_id)
463446
async with self._running_agents_lock:
464447
self._running_agents.pop(task_id, None)

tests/server/request_handlers/test_default_request_handler.py

Lines changed: 34 additions & 150 deletions
Original file line numberDiff line numberDiff line change
@@ -2646,169 +2646,53 @@ async def test_on_message_send_stream_task_id_provided_but_task_not_found():
26462646
)
26472647

26482648

2649-
@pytest.mark.asyncio
2650-
async def test_on_message_send_stream_consumer_error_cancels_producer_and_closes_queue():
2651-
"""Test that if the consumer (result aggregator) raises an exception, the producer is cancelled and queue is closed immediately."""
2652-
mock_task_store = AsyncMock(spec=TaskStore)
2653-
mock_queue_manager = AsyncMock(spec=QueueManager)
2654-
mock_agent_executor = AsyncMock(spec=AgentExecutor)
2655-
mock_request_context_builder = AsyncMock(spec=RequestContextBuilder)
2656-
2657-
task_id = 'error_cleanup_task'
2658-
context_id = 'error_cleanup_ctx'
2659-
2660-
mock_request_context = MagicMock(spec=RequestContext)
2661-
mock_request_context.task_id = task_id
2662-
mock_request_context.context_id = context_id
2663-
mock_request_context_builder.build.return_value = mock_request_context
2664-
2665-
mock_queue = AsyncMock(spec=EventQueue)
2666-
mock_queue_manager.create_or_tap.return_value = mock_queue
2667-
2668-
request_handler = DefaultRequestHandler(
2669-
agent_executor=mock_agent_executor,
2670-
task_store=mock_task_store,
2671-
queue_manager=mock_queue_manager,
2672-
request_context_builder=mock_request_context_builder,
2673-
)
2674-
2675-
params = MessageSendParams(
2676-
message=Message(
2677-
role=Role.user,
2678-
message_id='msg_error_cleanup',
2679-
parts=[],
2680-
# Do NOT provide task_id here to avoid "Task ... was specified but does not exist" error
2649+
class HelloWorldAgentExecutor(AgentExecutor):
2650+
"""Test Agent Implementation."""
2651+
2652+
async def execute(
2653+
self,
2654+
context: RequestContext,
2655+
event_queue: EventQueue,
2656+
) -> None:
2657+
updater = TaskUpdater(
2658+
event_queue,
2659+
task_id=context.task_id or str(uuid.uuid4()),
2660+
context_id=context.context_id or str(uuid.uuid4()),
26812661
)
2682-
)
2683-
2684-
# Mock ResultAggregator to raise exception
2685-
mock_result_aggregator_instance = MagicMock(spec=ResultAggregator)
2686-
2687-
async def raise_error_gen(_consumer):
2688-
# Raise an exception to simulate consumer failure
2689-
raise ValueError('Consumer failed!')
2690-
yield # unreachable
2691-
2692-
mock_result_aggregator_instance.consume_and_emit.side_effect = (
2693-
raise_error_gen
2694-
)
2695-
2696-
# Capture the producer task to verify cancellation
2697-
captured_producer_task = None
2698-
original_register = request_handler._register_producer
2699-
2700-
async def spy_register_producer(tid, task):
2701-
nonlocal captured_producer_task
2702-
captured_producer_task = task
2703-
# Wrap the cancel method to spy on it
2704-
task.cancel = MagicMock(wraps=task.cancel)
2705-
await original_register(tid, task)
2706-
2707-
with (
2708-
patch(
2709-
'a2a.server.request_handlers.default_request_handler.ResultAggregator',
2710-
return_value=mock_result_aggregator_instance,
2711-
),
2712-
patch(
2713-
'a2a.server.request_handlers.default_request_handler.TaskManager.get_task',
2714-
return_value=None,
2715-
),
2716-
patch.object(
2717-
request_handler,
2718-
'_register_producer',
2719-
side_effect=spy_register_producer,
2720-
),
2721-
):
2722-
# Act
2723-
with pytest.raises(ValueError, match='Consumer failed!'):
2724-
async for _ in request_handler.on_message_send_stream(
2725-
params, create_server_call_context()
2726-
):
2727-
pass
2728-
2729-
assert captured_producer_task is not None
2730-
# Verify producer was cancelled
2731-
captured_producer_task.cancel.assert_called()
2662+
await updater.update_status(TaskState.working)
2663+
await updater.complete()
27322664

2733-
# Verify queue closed immediately
2734-
mock_queue.close.assert_awaited_with(immediate=True)
2665+
async def cancel(
2666+
self, context: RequestContext, event_queue: EventQueue
2667+
) -> None:
2668+
raise NotImplementedError('cancel not supported')
27352669

27362670

2671+
# Repro is straight from the https://github.com/a2aproject/a2a-python/issues/609.
2672+
# It uses timeout to test against infinite wait, if it's going to be flaky,
2673+
# we should reconsider the approach.
27372674
@pytest.mark.asyncio
2738-
async def test_on_message_send_consumer_error_cancels_producer_and_closes_queue():
2739-
"""Test that if the consumer raises an exception during blocking wait, the producer is cancelled."""
2740-
mock_task_store = AsyncMock(spec=TaskStore)
2741-
mock_queue_manager = AsyncMock(spec=QueueManager)
2742-
mock_agent_executor = AsyncMock(spec=AgentExecutor)
2743-
mock_request_context_builder = AsyncMock(spec=RequestContextBuilder)
2744-
2745-
task_id = 'error_cleanup_blocking_task'
2746-
context_id = 'error_cleanup_blocking_ctx'
2747-
2748-
mock_request_context = MagicMock(spec=RequestContext)
2749-
mock_request_context.task_id = task_id
2750-
mock_request_context.context_id = context_id
2751-
mock_request_context_builder.build.return_value = mock_request_context
2752-
2753-
mock_queue = AsyncMock(spec=EventQueue)
2754-
mock_queue_manager.create_or_tap.return_value = mock_queue
2675+
@pytest.mark.timeout(1)
2676+
async def test_on_message_send_error_should_not_hang():
2677+
"""Test that if the consumer raises an exception during blocking wait, the producer is cancelled and no deadlock occurs."""
2678+
agent = HelloWorldAgentExecutor()
2679+
task_store = AsyncMock(spec=TaskStore)
2680+
task_store.save.side_effect = RuntimeError('This is an Error!')
27552681

27562682
request_handler = DefaultRequestHandler(
2757-
agent_executor=mock_agent_executor,
2758-
task_store=mock_task_store,
2759-
queue_manager=mock_queue_manager,
2760-
request_context_builder=mock_request_context_builder,
2683+
agent_executor=agent, task_store=task_store
27612684
)
27622685

27632686
params = MessageSendParams(
27642687
message=Message(
27652688
role=Role.user,
27662689
message_id='msg_error_blocking',
2767-
parts=[],
2690+
parts=[Part(root=TextPart(text='Test message'))],
27682691
)
27692692
)
27702693

2771-
# Mock ResultAggregator to raise exception
2772-
mock_result_aggregator_instance = MagicMock(spec=ResultAggregator)
2773-
mock_result_aggregator_instance.consume_and_break_on_interrupt.side_effect = ValueError(
2774-
'Consumer failed!'
2775-
)
2776-
2777-
# Capture the producer task to verify cancellation
2778-
captured_producer_task = None
2779-
original_register = request_handler._register_producer
2780-
2781-
async def spy_register_producer(tid, task):
2782-
nonlocal captured_producer_task
2783-
captured_producer_task = task
2784-
# Wrap the cancel method to spy on it
2785-
task.cancel = MagicMock(wraps=task.cancel)
2786-
await original_register(tid, task)
2787-
2788-
with (
2789-
patch(
2790-
'a2a.server.request_handlers.default_request_handler.ResultAggregator',
2791-
return_value=mock_result_aggregator_instance,
2792-
),
2793-
patch(
2794-
'a2a.server.request_handlers.default_request_handler.TaskManager.get_task',
2795-
return_value=None,
2796-
),
2797-
patch.object(
2798-
request_handler,
2799-
'_register_producer',
2800-
side_effect=spy_register_producer,
2801-
),
2802-
):
2803-
# Act
2804-
with pytest.raises(ValueError, match='Consumer failed!'):
2805-
await request_handler.on_message_send(
2806-
params, create_server_call_context()
2807-
)
2808-
2809-
assert captured_producer_task is not None
2810-
# Verify producer was cancelled
2811-
captured_producer_task.cancel.assert_called()
2812-
2813-
# Verify queue closed immediately
2814-
mock_queue.close.assert_awaited_with(immediate=True)
2694+
with pytest.raises(RuntimeError, match='This is an Error!'):
2695+
async for _ in request_handler.on_message_send_stream(
2696+
params, create_server_call_context()
2697+
):
2698+
pass

uv.lock

Lines changed: 14 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)