Skip to content

Commit 3d0b60a

Browse files
feat: add async context manager support to BaseClient
1 parent 390b763 commit 3d0b60a

2 files changed

Lines changed: 36 additions & 0 deletions

File tree

src/a2a/client/base_client.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
from collections.abc import AsyncIterator, Callable
2+
from types import TracebackType
23
from typing import Any
34

5+
from typing_extensions import Self
6+
47
from a2a.client.client import (
58
Client,
69
ClientCallContext,
@@ -43,6 +46,19 @@ def __init__(
4346
self._config = config
4447
self._transport = transport
4548

49+
async def __aenter__(self) -> Self:
50+
"""Enters the async context manager, returning the client itself."""
51+
return self
52+
53+
async def __aexit__(
54+
self,
55+
exc_type: type[BaseException] | None,
56+
exc_val: BaseException | None,
57+
exc_tb: TracebackType | None,
58+
) -> None:
59+
"""Exits the async context manager, ensuring close() is called."""
60+
await self.close()
61+
4662
async def send_message(
4763
self,
4864
request: Message,

tests/client/test_base_client.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,26 @@ def base_client(
6161
)
6262

6363

64+
@pytest.mark.asyncio
65+
async def test_base_client_async_context_manager(
66+
base_client: BaseClient, mock_transport: AsyncMock
67+
) -> None:
68+
async with base_client as client:
69+
assert client is base_client
70+
mock_transport.close.assert_not_awaited()
71+
mock_transport.close.assert_awaited_once()
72+
73+
74+
@pytest.mark.asyncio
75+
async def test_base_client_async_context_manager_on_exception(
76+
base_client: BaseClient, mock_transport: AsyncMock
77+
) -> None:
78+
with pytest.raises(RuntimeError, match='boom'):
79+
async with base_client:
80+
raise RuntimeError('boom')
81+
mock_transport.close.assert_awaited_once()
82+
83+
6484
@pytest.mark.asyncio
6585
async def test_send_message_streaming(
6686
base_client: BaseClient, mock_transport: MagicMock, sample_message: Message

0 commit comments

Comments
 (0)