Skip to content

Commit 60e5b5d

Browse files
committed
fix: resolve deadlock in DefaultRequestHandler on consumer failure
Fixes #609 by ensuring producer task cancellation and immediate queue closure when the consumer fails. Includes regression test.
1 parent cb7cdb3 commit 60e5b5d

3 files changed

Lines changed: 126 additions & 2 deletions

File tree

src/a2a/server/request_handlers/default_request_handler.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,12 @@ async def push_notification_callback() -> None:
330330

331331
except Exception:
332332
logger.exception('Agent execution failed')
333+
# If the consumer fails, we must cancel the producer to prevent it from hanging
334+
# on queue operations (e.g., waiting for the queue to drain).
335+
producer_task.cancel()
336+
# Force the queue to close immediately, discarding any pending events.
337+
# This ensures that any producers waiting on the queue are unblocked.
338+
await queue.close(immediate=True)
333339
raise
334340
finally:
335341
if interrupted_or_non_blocking:
@@ -392,6 +398,12 @@ async def on_message_send_stream(
392398
bg_task.set_name(f'background_consume:{task_id}')
393399
self._track_background_task(bg_task)
394400
raise
401+
except Exception:
402+
# If the consumer fails (e.g. database error), we must cleanup.
403+
logger.exception('Agent execution failed during streaming')
404+
producer_task.cancel()
405+
await queue.close(immediate=True)
406+
raise
395407
finally:
396408
cleanup_task = asyncio.create_task(
397409
self._cleanup_producer(producer_task, task_id)
@@ -435,7 +447,11 @@ async def _cleanup_producer(
435447
task_id: str,
436448
) -> None:
437449
"""Cleans up the agent execution task and queue manager entry."""
438-
await producer_task
450+
try:
451+
await producer_task
452+
except (Exception, asyncio.CancelledError):
453+
# We don't want to stop cleanup if the producer task failed or was cancelled
454+
pass
439455
await self._queue_manager.close(task_id)
440456
async with self._running_agents_lock:
441457
self._running_agents.pop(task_id, None)

tests/server/request_handlers/test_jsonrpc_handler.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,6 @@ async def streaming_coro():
322322

323323
self.assertIsInstance(response.root, JSONRPCErrorResponse)
324324
assert response.root.error == UnsupportedOperationError() # type: ignore
325-
mock_agent_executor.execute.assert_called_once()
326325

327326
@patch(
328327
'a2a.server.agent_execution.simple_request_context_builder.SimpleRequestContextBuilder.build'

tests/server/test_gh_issue_609.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
import asyncio
2+
import uuid
3+
import pytest
4+
5+
from a2a.server.agent_execution import AgentExecutor, RequestContext
6+
from a2a.server.context import ServerCallContext
7+
from a2a.server.events import EventQueue
8+
from a2a.server.request_handlers.default_request_handler import (
9+
DefaultRequestHandler,
10+
)
11+
from a2a.server.tasks import TaskStore
12+
from a2a.server.tasks.task_updater import TaskUpdater
13+
from a2a.types import (
14+
Message,
15+
MessageSendParams,
16+
Part,
17+
Role,
18+
Task,
19+
TaskState,
20+
TextPart,
21+
)
22+
23+
24+
class FailingTaskStore(TaskStore):
25+
"""Task store that fails on save to simulate a poisoned configuration."""
26+
27+
async def get(
28+
self, task_id: str, context: ServerCallContext | None = None
29+
) -> Task | None:
30+
"""Return None for simplicity."""
31+
return None
32+
33+
async def save(
34+
self, task: Task, context: ServerCallContext | None = None
35+
) -> None:
36+
"""Always fail to simulate task store error."""
37+
raise RuntimeError(
38+
'This is an Error!'
39+
)
40+
41+
async def delete(
42+
self, task_id: str, context: ServerCallContext | None = None
43+
) -> None:
44+
"""No-op for simplicity."""
45+
46+
47+
class HelloWorldAgent:
48+
"""Hello World Agent."""
49+
50+
async def invoke(self) -> str:
51+
return 'Hello World'
52+
53+
class HelloWorldAgentExecutor(AgentExecutor):
54+
"""Test Agent Implementation."""
55+
56+
def __init__(self):
57+
self.agent = HelloWorldAgent()
58+
59+
async def execute(
60+
self,
61+
context: RequestContext,
62+
event_queue: EventQueue,
63+
) -> None:
64+
updater = TaskUpdater(
65+
event_queue,
66+
task_id=context.task_id or str(uuid.uuid4()),
67+
context_id=context.context_id or str(uuid.uuid4()),
68+
)
69+
# raise ValueError("Simulated error during task execution")
70+
if not context.task_id:
71+
await updater.submit()
72+
await updater.update_status(TaskState.working)
73+
result = await self.agent.invoke()
74+
await updater.add_artifact([Part(root=TextPart(text=result))])
75+
await updater.complete()
76+
77+
async def cancel(
78+
self, context: RequestContext, event_queue: EventQueue
79+
) -> None:
80+
raise NotImplementedError('cancel not supported')
81+
82+
@pytest.mark.asyncio
83+
async def test_hanging_on_task_save_error() -> None:
84+
"""Test that demonstrates hanging when task save fails.
85+
"""
86+
agent = HelloWorldAgentExecutor()
87+
task_store = FailingTaskStore()
88+
handler = DefaultRequestHandler(
89+
agent_executor=agent, task_store=task_store
90+
)
91+
92+
params = MessageSendParams(
93+
message=Message(
94+
role=Role.user,
95+
parts=[TextPart(text='Test message')],
96+
message_id=str(uuid.uuid4()),
97+
)
98+
)
99+
100+
try:
101+
# Use a short timeout to fail fast
102+
await asyncio.wait_for(
103+
handler.on_message_send(params), timeout=2.0
104+
)
105+
except RuntimeError as e:
106+
assert str(e) == 'This is an Error!'
107+
except asyncio.TimeoutError:
108+
# If we get here, it means it hung!
109+
pytest.fail("Test hung and timed out! Fix failed.")

0 commit comments

Comments
 (0)