11import asyncio
2+ import contextlib
23import logging
34import time
45
4849 TaskQueryParams ,
4950 TaskState ,
5051 TaskStatus ,
52+ TaskStatusUpdateEvent ,
5153 TextPart ,
5254 UnsupportedOperationError ,
5355)
@@ -1331,6 +1333,15 @@ async def single_event_stream():
13311333 mock_result_aggregator_instance .consume_and_emit .return_value = (
13321334 single_event_stream ()
13331335 )
1336+ # Signal when background consume_all is started
1337+ bg_started = asyncio .Event ()
1338+
1339+ async def mock_consume_all (_consumer ):
1340+ bg_started .set ()
1341+ # emulate short-running background work
1342+ await asyncio .sleep (0 )
1343+
1344+ mock_result_aggregator_instance .consume_all = mock_consume_all
13341345
13351346 produced_task : asyncio .Task | None = None
13361347 cleanup_task : asyncio .Task | None = None
@@ -1367,6 +1378,9 @@ def create_task_spy(coro):
13671378 assert produced_task is not None
13681379 assert cleanup_task is not None
13691380
1381+ # Assert background consume_all started
1382+ await asyncio .wait_for (bg_started .wait (), timeout = 0.2 )
1383+
13701384 # execute should have started
13711385 await asyncio .wait_for (execute_started .wait (), timeout = 0.1 )
13721386
@@ -1385,6 +1399,91 @@ def create_task_spy(coro):
13851399 # Running agents is cleared
13861400 assert task_id not in request_handler ._running_agents
13871401
1402+ # Cleanup any lingering background tasks started by on_message_send_stream
1403+ # (e.g., background_consume)
1404+ for t in list (request_handler ._background_tasks ):
1405+ t .cancel ()
1406+ with contextlib .suppress (asyncio .CancelledError ):
1407+ await t
1408+
1409+
1410+ @pytest .mark .asyncio
1411+ async def test_disconnect_persists_final_task_to_store ():
1412+ """After client disconnect, ensure background consumer persists final Task to store."""
1413+ task_store = InMemoryTaskStore ()
1414+ queue_manager = InMemoryQueueManager ()
1415+
1416+ # Custom agent that emits a working update then a completed final update
1417+ class FinishingAgent (AgentExecutor ):
1418+ def __init__ (self ):
1419+ self .allow_finish = asyncio .Event ()
1420+
1421+ async def execute (
1422+ self , context : RequestContext , event_queue : EventQueue
1423+ ):
1424+ from typing import cast
1425+
1426+ updater = TaskUpdater (
1427+ event_queue ,
1428+ cast ('str' , context .task_id ),
1429+ cast ('str' , context .context_id ),
1430+ )
1431+ await updater .update_status (TaskState .working )
1432+ await self .allow_finish .wait ()
1433+ await updater .update_status (TaskState .completed )
1434+
1435+ async def cancel (
1436+ self , context : RequestContext , event_queue : EventQueue
1437+ ):
1438+ return None
1439+
1440+ agent = FinishingAgent ()
1441+
1442+ handler = DefaultRequestHandler (
1443+ agent_executor = agent , task_store = task_store , queue_manager = queue_manager
1444+ )
1445+
1446+ params = MessageSendParams (
1447+ message = Message (
1448+ role = Role .user ,
1449+ message_id = 'msg_persist' ,
1450+ parts = [],
1451+ )
1452+ )
1453+
1454+ # Start streaming and consume the first event (working)
1455+ agen = handler .on_message_send_stream (params , create_server_call_context ())
1456+ first = await agen .__anext__ ()
1457+ if isinstance (first , TaskStatusUpdateEvent ):
1458+ assert first .status .state == TaskState .working
1459+ task_id = first .task_id
1460+ else :
1461+ assert (
1462+ isinstance (first , Task ) and first .status .state == TaskState .working
1463+ )
1464+ task_id = first .id
1465+
1466+ # Disconnect client
1467+ await asyncio .wait_for (agen .aclose (), timeout = 0.1 )
1468+
1469+ # Finish agent and allow background consumer to persist final state
1470+ agent .allow_finish .set ()
1471+
1472+ # Wait until background_consume task for this task_id is gone
1473+ await wait_until (
1474+ lambda : all (
1475+ not t .get_name ().startswith (f'background_consume:{ task_id } ' )
1476+ for t in handler ._background_tasks
1477+ ),
1478+ timeout = 1.0 ,
1479+ interval = 0.01 ,
1480+ )
1481+
1482+ # Verify task is persisted as completed
1483+ persisted = await task_store .get (task_id , create_server_call_context ())
1484+ assert persisted is not None
1485+ assert persisted .status .state == TaskState .completed
1486+
13881487
13891488async def wait_until (predicate , timeout : float = 0.2 , interval : float = 0.0 ):
13901489 """Await until predicate() is True or timeout elapses."""
@@ -1505,6 +1604,12 @@ def create_task_spy(coro):
15051604 timeout = 0.1 ,
15061605 )
15071606
1607+ # Cleanup any lingering background tasks
1608+ for t in list (request_handler ._background_tasks ):
1609+ t .cancel ()
1610+ with contextlib .suppress (asyncio .CancelledError ):
1611+ await t
1612+
15081613
15091614@pytest .mark .asyncio
15101615async def test_on_message_send_stream_task_id_mismatch ():
0 commit comments