@@ -62,15 +62,40 @@ def base_client(
6262
6363
6464@pytest .mark .asyncio
65- async def test_transport_aenter_returns_self (mock_transport : AsyncMock ) -> None :
66- result = await ClientTransport .__aenter__ (mock_transport )
67- assert result is mock_transport
65+ async def test_transport_async_context_manager () -> None :
66+ class TestTransport (ClientTransport ):
67+ def __init__ (self ) -> None :
68+ self .closed = False
69+
70+ async def close (self ) -> None :
71+ self .closed = True
72+
73+ TestTransport .__abstractmethods__ = set () # type: ignore[attr-defined]
74+
75+ transport = TestTransport ()
76+ async with transport as t :
77+ assert t is transport
78+
79+ assert transport .closed
6880
6981
7082@pytest .mark .asyncio
71- async def test_transport_aexit_calls_close (mock_transport : AsyncMock ) -> None :
72- await ClientTransport .__aexit__ (mock_transport , None , None , None )
73- mock_transport .close .assert_awaited_once ()
83+ async def test_transport_async_context_manager_on_exception () -> None :
84+ class TestTransport (ClientTransport ):
85+ def __init__ (self ) -> None :
86+ self .closed = False
87+
88+ async def close (self ) -> None :
89+ self .closed = True
90+
91+ TestTransport .__abstractmethods__ = set () # type: ignore[attr-defined]
92+
93+ transport = TestTransport ()
94+ with pytest .raises (RuntimeError , match = 'boom' ):
95+ async with transport :
96+ raise RuntimeError ('boom' )
97+
98+ assert transport .closed
7499
75100
76101@pytest .mark .asyncio
0 commit comments