@@ -644,6 +644,62 @@ async def cancelling_sleep(*_args: object, **_kwargs: object) -> None:
644644 # First attempt ran; cancel hit on the sleep before the second.
645645 assert mock_transport .send_message .call_count == 1
646646
647+ @pytest .mark .asyncio
648+ async def test_cancelled_error_from_transport_propagates (
649+ self , mock_transport : AsyncMock
650+ ) -> None :
651+ """CancelledError raised by the inner transport bypasses retry."""
652+ mock_transport .send_message .side_effect = asyncio .CancelledError
653+ transport = RetryTransport (
654+ mock_transport , max_retries = 3 , base_delay = 0.01 , jitter = False
655+ )
656+ with pytest .raises (asyncio .CancelledError ):
657+ await transport .send_message (SendMessageRequest ())
658+ assert mock_transport .send_message .call_count == 1
659+
660+ @pytest .mark .asyncio
661+ async def test_cancelled_error_from_streaming_transport_propagates (
662+ self , mock_transport : AsyncMock
663+ ) -> None :
664+ """CancelledError raised by the streaming transport bypasses retry."""
665+ mock_transport .send_message_streaming .side_effect = (
666+ asyncio .CancelledError
667+ )
668+ transport = RetryTransport (
669+ mock_transport , max_retries = 3 , base_delay = 0.01 , jitter = False
670+ )
671+ with pytest .raises (asyncio .CancelledError ):
672+ async for _event in transport .send_message_streaming (
673+ SendMessageRequest ()
674+ ):
675+ pass
676+ assert mock_transport .send_message_streaming .call_count == 1
677+
678+ @pytest .mark .asyncio
679+ async def test_on_retry_cancelled_error_propagates (
680+ self , mock_transport : AsyncMock
681+ ) -> None :
682+ """CancelledError from on_retry must not be swallowed by the catch-all."""
683+
684+ async def cancelling_callback (
685+ * _args : object , ** _kwargs : object
686+ ) -> None :
687+ raise asyncio .CancelledError
688+
689+ transport = RetryTransport (
690+ mock_transport ,
691+ max_retries = 2 ,
692+ base_delay = 0.01 ,
693+ jitter = False ,
694+ on_retry = cancelling_callback ,
695+ )
696+ mock_transport .send_message .side_effect = A2AClientTimeoutError (
697+ 'timeout'
698+ )
699+ with pytest .raises (asyncio .CancelledError ):
700+ await transport .send_message (SendMessageRequest ())
701+ assert mock_transport .send_message .call_count == 1
702+
647703 @pytest .mark .asyncio
648704 async def test_streaming_inner_generator_closed_on_consumer_break (
649705 self , mock_transport : AsyncMock
0 commit comments