Skip to content

Commit 04c46ca

Browse files
feat(server): add async context manager support to EventQueue
1 parent cced34d commit 04c46ca

2 files changed

Lines changed: 38 additions & 0 deletions

File tree

src/a2a/server/events/event_queue.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22
import logging
33
import sys
44

5+
from types import TracebackType
6+
7+
from typing_extensions import Self
8+
59
from a2a.types import (
610
Message,
711
Task,
@@ -43,6 +47,19 @@ def __init__(self, max_queue_size: int = DEFAULT_MAX_QUEUE_SIZE) -> None:
4347
self._lock = asyncio.Lock()
4448
logger.debug('EventQueue initialized.')
4549

50+
async def __aenter__(self) -> Self:
51+
"""Enters the async context manager, returning the queue itself."""
52+
return self
53+
54+
async def __aexit__(
55+
self,
56+
exc_type: type[BaseException] | None,
57+
exc_val: BaseException | None,
58+
exc_tb: TracebackType | None,
59+
) -> None:
60+
"""Exits the async context manager, ensuring close() is called."""
61+
await self.close()
62+
4663
async def enqueue_event(self, event: Event) -> None:
4764
"""Enqueues an event to this queue and all its children.
4865

tests/server/events/test_event_queue.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,27 @@ def test_constructor_invalid_max_queue_size() -> None:
6969
):
7070
EventQueue(max_queue_size=-10)
7171

72+
@pytest.mark.asyncio
73+
async def test_event_queue_async_context_manager(
74+
event_queue: EventQueue,
75+
) -> None:
76+
"""Test that EventQueue can be used as an async context manager."""
77+
async with event_queue as q:
78+
assert q is event_queue
79+
assert event_queue.is_closed() is False
80+
assert event_queue.is_closed() is True
81+
82+
83+
@pytest.mark.asyncio
84+
async def test_event_queue_async_context_manager_on_exception(
85+
event_queue: EventQueue,
86+
) -> None:
87+
"""Test that close() is called even when an exception occurs inside the context."""
88+
with pytest.raises(RuntimeError, match='boom'):
89+
async with event_queue:
90+
raise RuntimeError('boom')
91+
assert event_queue.is_closed() is True
92+
7293

7394
@pytest.mark.asyncio
7495
async def test_enqueue_and_dequeue_event(event_queue: EventQueue) -> None:

0 commit comments

Comments
 (0)