Skip to content

Commit 0dad45d

Browse files
committed
refactor: Migrate authentication logic to the new interceptor pattern and add streaming support to base client interceptors.
1 parent 92bc02f commit 0dad45d

4 files changed

Lines changed: 155 additions & 56 deletions

File tree

src/a2a/client/auth/interceptor.py

Lines changed: 31 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
11
import logging # noqa: I001
2-
from typing import Any
32

43
from a2a.client.auth.credentials import CredentialService
54
from a2a.client.client import ClientCallContext
6-
from a2a.types.a2a_pb2 import AgentCard
5+
from a2a.client.interceptors import (
6+
ClientCallInterceptor,
7+
UnionAfterArgs,
8+
UnionBeforeArgs,
9+
)
710

811
logger = logging.getLogger(__name__)
912

1013

11-
class AuthInterceptor:
14+
class AuthInterceptor(ClientCallInterceptor):
1215
"""An interceptor that automatically adds authentication details to requests.
1316
1417
Based on the agent's security schemes.
@@ -17,79 +20,79 @@ class AuthInterceptor:
1720
def __init__(self, credential_service: CredentialService):
1821
self._credential_service = credential_service
1922

20-
async def intercept(
21-
self,
22-
method_name: str,
23-
request_payload: dict[str, Any],
24-
http_kwargs: dict[str, Any],
25-
agent_card: AgentCard | None,
26-
context: ClientCallContext | None,
27-
) -> tuple[dict[str, Any], dict[str, Any]]:
23+
async def before(self, args: UnionBeforeArgs) -> None:
2824
"""Applies authentication headers to the request if credentials are available."""
25+
agent_card = args.agent_card
26+
2927
# Proto3 repeated fields (security) and maps (security_schemes) do not track presence.
3028
# HasField() raises ValueError for them.
3129
# We check for truthiness to see if they are non-empty.
3230
if (
33-
agent_card is None
34-
or not agent_card.security_requirements
31+
not agent_card.security_requirements
3532
or not agent_card.security_schemes
3633
):
37-
return request_payload, http_kwargs
34+
return
3835

3936
for requirement in agent_card.security_requirements:
4037
for scheme_name in requirement.schemes:
4138
credential = await self._credential_service.get_credentials(
42-
scheme_name, context
39+
scheme_name, args.context
4340
)
4441
if credential and scheme_name in agent_card.security_schemes:
4542
scheme = agent_card.security_schemes.get(scheme_name)
4643
if not scheme:
4744
continue
4845

49-
headers = http_kwargs.get('headers', {})
46+
if args.context is None:
47+
args.context = ClientCallContext()
48+
49+
if args.context.service_parameters is None:
50+
args.context.service_parameters = {}
5051

5152
# HTTP Bearer authentication
5253
if (
5354
scheme.HasField('http_auth_security_scheme')
5455
and scheme.http_auth_security_scheme.scheme.lower()
5556
== 'bearer'
5657
):
57-
headers['Authorization'] = f'Bearer {credential}'
58+
args.context.service_parameters['Authorization'] = (
59+
f'Bearer {credential}'
60+
)
5861
logger.debug(
5962
"Added Bearer token for scheme '%s'.",
6063
scheme_name,
6164
)
62-
http_kwargs['headers'] = headers
63-
return request_payload, http_kwargs
65+
return
6466

6567
# OAuth2 and OIDC schemes are implicitly Bearer
6668
if scheme.HasField(
6769
'oauth2_security_scheme'
6870
) or scheme.HasField('open_id_connect_security_scheme'):
69-
headers['Authorization'] = f'Bearer {credential}'
71+
args.context.service_parameters['Authorization'] = (
72+
f'Bearer {credential}'
73+
)
7074
logger.debug(
7175
"Added Bearer token for scheme '%s'.",
7276
scheme_name,
7377
)
74-
http_kwargs['headers'] = headers
75-
return request_payload, http_kwargs
78+
return
7679

7780
# API Key in Header
7881
if (
7982
scheme.HasField('api_key_security_scheme')
8083
and scheme.api_key_security_scheme.location.lower()
8184
== 'header'
8285
):
83-
headers[scheme.api_key_security_scheme.name] = (
84-
credential
85-
)
86+
args.context.service_parameters[
87+
scheme.api_key_security_scheme.name
88+
] = credential
8689
logger.debug(
8790
"Added API Key Header for scheme '%s'.",
8891
scheme_name,
8992
)
90-
http_kwargs['headers'] = headers
91-
return request_payload, http_kwargs
93+
return
9294

9395
# Note: Other cases like API keys in query/cookie are not handled and will be skipped.
9496

95-
return request_payload, http_kwargs
97+
async def after(self, args: UnionAfterArgs) -> None:
98+
"""Invoked after the method is executed."""

src/a2a/client/base_client.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -515,6 +515,7 @@ async def _intercept_after(
515515
async def _format_stream_event(
516516
self, stream_response: StreamResponse, tracker: ClientTaskManager
517517
) -> ClientEvent:
518+
client_event: ClientEvent
518519
if stream_response.HasField('message'):
519520
client_event = (stream_response, None)
520521
await self.consume(client_event, self._card)
Lines changed: 23 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
ClientFactory,
1818
InMemoryContextCredentialStore,
1919
)
20-
from a2a.utils.constants import TransportProtocol
20+
from a2a.client.interceptors import BeforeArgs, ClientCallInput
2121
from a2a.types.a2a_pb2 import (
2222
APIKeySecurityScheme,
2323
AgentCapabilities,
@@ -36,6 +36,7 @@
3636
SendMessageResponse,
3737
StringList,
3838
)
39+
from a2a.utils.constants import TransportProtocol
3940

4041

4142
class HeaderInterceptor(ClientCallInterceptor):
@@ -64,7 +65,6 @@ async def intercept(
6465

6566
def build_success_response(request: httpx.Request) -> httpx.Response:
6667
"""Creates a valid JSON-RPC success response based on the request."""
67-
from a2a.types.a2a_pb2 import SendMessageResponse
6868

6969
request_payload = json.loads(request.content)
7070
message = Message(
@@ -120,19 +120,17 @@ async def test_auth_interceptor_skips_when_no_agent_card(
120120
store: InMemoryContextCredentialStore,
121121
) -> None:
122122
"""Tests that the AuthInterceptor does not modify the request when no AgentCard is provided."""
123-
request_payload = {'foo': 'bar'}
124-
http_kwargs = {'fizz': 'buzz'}
125123
auth_interceptor = AuthInterceptor(credential_service=store)
126-
127-
new_payload, new_kwargs = await auth_interceptor.intercept(
128-
method_name='SendMessage',
129-
request_payload=request_payload,
130-
http_kwargs=http_kwargs,
131-
agent_card=None,
132-
context=ClientCallContext(state={}),
124+
request = SendMessageRequest(message=Message())
125+
context = ClientCallContext(state={})
126+
args = BeforeArgs(
127+
input=ClientCallInput(method='send_message', value=request),
128+
agent_card=AgentCard(),
129+
context=context,
133130
)
134-
assert new_payload == request_payload
135-
assert new_kwargs == http_kwargs
131+
132+
await auth_interceptor.before(args)
133+
assert context.service_parameters is None
136134

137135

138136
@pytest.mark.asyncio
@@ -210,14 +208,13 @@ def wrap_security_scheme(scheme: Any) -> SecurityScheme:
210208
"""Wraps a security scheme in the correct SecurityScheme proto field."""
211209
if isinstance(scheme, APIKeySecurityScheme):
212210
return SecurityScheme(api_key_security_scheme=scheme)
213-
elif isinstance(scheme, HTTPAuthSecurityScheme):
211+
if isinstance(scheme, HTTPAuthSecurityScheme):
214212
return SecurityScheme(http_auth_security_scheme=scheme)
215-
elif isinstance(scheme, OAuth2SecurityScheme):
213+
if isinstance(scheme, OAuth2SecurityScheme):
216214
return SecurityScheme(oauth2_security_scheme=scheme)
217-
elif isinstance(scheme, OpenIdConnectSecurityScheme):
215+
if isinstance(scheme, OpenIdConnectSecurityScheme):
218216
return SecurityScheme(open_id_connect_security_scheme=scheme)
219-
else:
220-
raise ValueError(f'Unknown security scheme type: {type(scheme)}')
217+
raise ValueError(f'Unknown security scheme type: {type(scheme)}')
221218

222219

223220
@dataclass
@@ -363,8 +360,6 @@ async def test_auth_interceptor_skips_when_scheme_not_in_security_schemes(
363360
scheme_name = 'missing'
364361
session_id = 'session-id'
365362
credential = 'test-token'
366-
request_payload = {'foo': 'bar'}
367-
http_kwargs = {'fizz': 'buzz'}
368363
await store.set_credentials(session_id, scheme_name, credential)
369364
auth_interceptor = AuthInterceptor(credential_service=store)
370365
agent_card = AgentCard(
@@ -386,13 +381,13 @@ async def test_auth_interceptor_skips_when_scheme_not_in_security_schemes(
386381
],
387382
security_schemes={},
388383
)
389-
390-
new_payload, new_kwargs = await auth_interceptor.intercept(
391-
method_name='SendMessage',
392-
request_payload=request_payload,
393-
http_kwargs=http_kwargs,
384+
request = SendMessageRequest(message=Message())
385+
context = ClientCallContext(state={'sessionId': session_id})
386+
args = BeforeArgs(
387+
input=ClientCallInput(method='send_message', value=request),
394388
agent_card=agent_card,
395-
context=ClientCallContext(state={'sessionId': session_id}),
389+
context=context,
396390
)
397-
assert new_payload == request_payload
398-
assert new_kwargs == http_kwargs
391+
392+
await auth_interceptor.before(args)
393+
assert context.service_parameters is None

tests/client/test_base_client_interceptors.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
AgentCapabilities,
1818
AgentCard,
1919
AgentInterface,
20+
Message,
21+
StreamResponse,
2022
)
2123

2224

@@ -142,3 +144,101 @@ async def mock_before_with_early_return(args: BeforeArgs):
142144
assert isinstance(after_args, AfterArgs)
143145
assert after_args.result.value == 'early_result'
144146
assert after_args.context == context
147+
148+
@pytest.mark.asyncio
149+
async def test_execute_stream_with_interceptors_normal_flow(
150+
self,
151+
base_client: BaseClient,
152+
mock_interceptor: AsyncMock,
153+
):
154+
input_data = ClientCallInput(
155+
method='send_message_streaming', value=MagicMock()
156+
)
157+
context = MagicMock()
158+
159+
async def mock_transport_call(*args, **kwargs):
160+
yield StreamResponse(message=Message(message_id='1'))
161+
162+
# Set up mock interceptor to just pass through
163+
mock_interceptor.before.return_value = None
164+
165+
events = [
166+
e
167+
async for e in base_client._execute_stream_with_interceptors(
168+
input_data=input_data,
169+
context=context,
170+
transport_call=mock_transport_call,
171+
)
172+
]
173+
174+
assert len(events) == 1
175+
176+
# Verify before was called
177+
mock_interceptor.before.assert_called_once()
178+
before_args = mock_interceptor.before.call_args[0][0]
179+
assert isinstance(before_args, BeforeArgs)
180+
assert before_args.input == input_data
181+
assert before_args.context == context
182+
183+
# Verify after was called
184+
mock_interceptor.after.assert_called_once()
185+
after_args = mock_interceptor.after.call_args[0][0]
186+
assert isinstance(after_args, AfterArgs)
187+
assert after_args.result.method == 'send_message_streaming'
188+
189+
@pytest.mark.asyncio
190+
async def test_execute_stream_with_interceptors_early_return(
191+
self,
192+
base_client: BaseClient,
193+
mock_interceptor: AsyncMock,
194+
):
195+
input_data = ClientCallInput(
196+
method='send_message_streaming', value=MagicMock()
197+
)
198+
context = MagicMock()
199+
mock_transport_call = AsyncMock()
200+
201+
# Set up early return in before
202+
early_return_result = ClientCallResult(
203+
method='send_message_streaming',
204+
value=StreamResponse(message=Message(message_id='2')),
205+
)
206+
207+
async def mock_before_with_early_return(args: BeforeArgs):
208+
args.early_return = early_return_result
209+
return {
210+
'early_return': early_return_result,
211+
'executed': [mock_interceptor],
212+
}
213+
214+
mock_interceptor.before.side_effect = mock_before_with_early_return
215+
216+
# Override BaseClient's _intercept_before to respect our early return setup
217+
# as the test's mock interceptor replaces the actual list items
218+
base_client._intercept_before = AsyncMock(
219+
return_value={
220+
'early_return': early_return_result,
221+
'executed': [mock_interceptor],
222+
}
223+
)
224+
225+
events = [
226+
e
227+
async for e in base_client._execute_stream_with_interceptors(
228+
input_data=input_data,
229+
context=context,
230+
transport_call=mock_transport_call,
231+
)
232+
]
233+
234+
assert len(events) == 1
235+
236+
# Verify transport call was NOT made
237+
mock_transport_call.assert_not_called()
238+
239+
# Verify after was called with early return value
240+
mock_interceptor.after.assert_called_once()
241+
after_args = mock_interceptor.after.call_args[0][0]
242+
assert isinstance(after_args, AfterArgs)
243+
assert after_args.result.method == 'send_message_streaming'
244+
assert after_args.context == context

0 commit comments

Comments
 (0)