Skip to content

Commit 9492a0d

Browse files
committed
test(a2a/client/auth/interceptor.py): add tests in test_auth_middleware.py to improve coverage
Signed-off-by: Shingo OKAWA <shingo.okawa.g.h.c@gmail.com>
1 parent f932ab9 commit 9492a0d

2 files changed

Lines changed: 130 additions & 9 deletions

File tree

src/a2a/client/auth/interceptor.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,12 +49,8 @@ async def intercept(
4949
scheme_def_union = agent_card.securitySchemes.get(
5050
scheme_name
5151
)
52-
if not scheme_def_union:
53-
continue
5452
scheme_def = scheme_def_union.root
55-
5653
headers = http_kwargs.get('headers', {})
57-
5854
match scheme_def:
5955
# Case 1a: HTTP Bearer scheme with an if guard
6056
case HTTPAuthSecurityScheme() if (

tests/client/test_auth_middleware.py

Lines changed: 130 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
AgentCapabilities,
1414
AgentCard,
1515
AuthorizationCodeOAuthFlow,
16+
HTTPAuthSecurityScheme,
1617
In,
1718
Message,
1819
MessageSendParams,
@@ -103,6 +104,26 @@ def store():
103104
yield store
104105

105106

107+
@pytest.mark.asyncio
108+
async def test_auth_interceptor_skips_when_no_agent_card(store):
109+
"""
110+
Tests that the AuthInterceptor does not modify the request when no AgentCard is provided.
111+
"""
112+
request_payload = {'foo': 'bar'}
113+
http_kwargs = {'fizz': 'buzz'}
114+
auth_interceptor = AuthInterceptor(credential_service=store)
115+
116+
new_payload, new_kwargs = await auth_interceptor.intercept(
117+
method_name='message/send',
118+
request_payload=request_payload,
119+
http_kwargs=http_kwargs,
120+
agent_card=None,
121+
context=ClientCallContext(state={}),
122+
)
123+
assert new_payload == request_payload
124+
assert new_kwargs == http_kwargs
125+
126+
106127
@pytest.mark.asyncio
107128
async def test_in_memory_context_credential_store(store):
108129
"""
@@ -118,25 +139,21 @@ async def test_in_memory_context_credential_store(store):
118139
context = ClientCallContext(state={'sessionId': session_id})
119140
retrieved_credential = await store.get_credentials(scheme_name, context)
120141
assert retrieved_credential == credential
121-
122142
# Assert: Retrieval with wrong session ID returns None
123143
wrong_context = ClientCallContext(state={'sessionId': 'wrong-session'})
124144
retrieved_credential_wrong = await store.get_credentials(
125145
scheme_name, wrong_context
126146
)
127147
assert retrieved_credential_wrong is None
128-
129148
# Assert: Retrieval with no context returns None
130149
retrieved_credential_none = await store.get_credentials(scheme_name, None)
131150
assert retrieved_credential_none is None
132-
133151
# Assert: Retrieval with context but no sessionId returns None
134152
empty_context = ClientCallContext(state={})
135153
retrieved_credential_empty = await store.get_credentials(
136154
scheme_name, empty_context
137155
)
138156
assert retrieved_credential_empty is None
139-
140157
# Assert: Overwrite the credential when session_id already exists
141158
new_credential = 'new-token'
142159
await store.set_credentials(session_id, scheme_name, new_credential)
@@ -163,13 +180,24 @@ async def test_client_with_simple_interceptor():
163180

164181
@dataclass
165182
class AuthTestCase:
183+
"""
184+
Represents a test scenario for verifying authentication behavior in AuthInterceptor.
185+
"""
186+
166187
url: str
188+
"""The endpoint URL of the agent to which the request is sent."""
167189
session_id: str
190+
"""The client session ID used to fetch credentials from the credential store."""
168191
scheme_name: str
192+
"""The name of the security scheme defined in the agent card."""
169193
credential: str
194+
"""The actual credential value (e.g., API key, access token) to be injected."""
170195
security_scheme: Any
196+
"""The security scheme object (e.g., APIKeySecurityScheme, OAuth2SecurityScheme, etc.) to define behavior."""
171197
expected_header_key: str
198+
"""The expected HTTP header name to be set by the interceptor."""
172199
expected_header_value_func: Callable[[str], str]
200+
"""A function that maps the credential to its expected header value (e.g., lambda c: f"Bearer {c}")."""
173201

174202

175203
api_key_test_case = AuthTestCase(
@@ -223,9 +251,23 @@ class AuthTestCase:
223251
)
224252

225253

254+
bearer_test_case = AuthTestCase(
255+
url='http://agent.com/rpc',
256+
session_id='session-id',
257+
scheme_name='bearer',
258+
credential='bearer-token-123',
259+
security_scheme=HTTPAuthSecurityScheme(
260+
scheme='bearer',
261+
),
262+
expected_header_key='Authorization',
263+
expected_header_value_func=lambda c: f'Bearer {c}',
264+
)
265+
266+
226267
@pytest.mark.asyncio
227268
@pytest.mark.parametrize(
228-
'test_case', [api_key_test_case, oauth2_test_case, oidc_test_case]
269+
'test_case',
270+
[api_key_test_case, oauth2_test_case, oidc_test_case, bearer_test_case],
229271
)
230272
@respx.mock
231273
async def test_auth_interceptor_variants(test_case, store):
@@ -266,3 +308,86 @@ async def test_auth_interceptor_variants(test_case, store):
266308
assert request.headers[
267309
test_case.expected_header_key
268310
] == test_case.expected_header_value_func(test_case.credential)
311+
312+
313+
@pytest.mark.asyncio
314+
async def test_auth_interceptor_falls_back_on_unsupported_scheme(store):
315+
"""
316+
Tests that AuthInterceptor skips applying headers when the scheme type is unsupported.
317+
This ensures the final return statement is hit.
318+
"""
319+
scheme_name = 'unknown'
320+
session_id = 'session-id'
321+
credential = 'ignored-token'
322+
request_payload = {'foo': 'bar'}
323+
http_kwargs = {'fizz': 'buzz'}
324+
await store.set_credentials(session_id, scheme_name, credential)
325+
auth_interceptor = AuthInterceptor(credential_service=store)
326+
agent_card = AgentCard(
327+
url='http://agent.com/rpc',
328+
name='unknownbot',
329+
description='A bot that uses unsupported scheme',
330+
version='1.0',
331+
defaultInputModes=[],
332+
defaultOutputModes=[],
333+
skills=[],
334+
capabilities=AgentCapabilities(),
335+
security=[{scheme_name: []}],
336+
securitySchemes={
337+
'digest': SecurityScheme(
338+
root=HTTPAuthSecurityScheme(
339+
scheme='digest',
340+
type='http',
341+
),
342+
),
343+
},
344+
)
345+
346+
new_payload, new_kwargs = await auth_interceptor.intercept(
347+
method_name='message/send',
348+
request_payload=request_payload,
349+
http_kwargs=http_kwargs,
350+
agent_card=agent_card,
351+
context=ClientCallContext(state={'sessionId': session_id}),
352+
)
353+
assert new_payload == request_payload
354+
assert new_kwargs == http_kwargs
355+
356+
357+
@pytest.mark.asyncio
358+
async def test_auth_interceptor_skips_when_scheme_not_in_security_schemes(
359+
store,
360+
):
361+
"""
362+
Tests that AuthInterceptor skips a scheme if it's listed in security requirements
363+
but not defined in securitySchemes.
364+
"""
365+
scheme_name = 'missing'
366+
session_id = 'session-id'
367+
credential = 'dummy-token'
368+
request_payload = {'foo': 'bar'}
369+
http_kwargs = {'fizz': 'buzz'}
370+
await store.set_credentials(session_id, scheme_name, credential)
371+
auth_interceptor = AuthInterceptor(credential_service=store)
372+
agent_card = AgentCard(
373+
url='http://agent.com/rpc',
374+
name='missingbot',
375+
description='A bot that uses missing scheme definition',
376+
version='1.0',
377+
defaultInputModes=[],
378+
defaultOutputModes=[],
379+
skills=[],
380+
capabilities=AgentCapabilities(),
381+
security=[{scheme_name: []}],
382+
securitySchemes={},
383+
)
384+
385+
new_payload, new_kwargs = await auth_interceptor.intercept(
386+
method_name='message/send',
387+
request_payload=request_payload,
388+
http_kwargs=http_kwargs,
389+
agent_card=agent_card,
390+
context=ClientCallContext(state={'sessionId': session_id}),
391+
)
392+
assert new_payload == request_payload
393+
assert new_kwargs == http_kwargs

0 commit comments

Comments
 (0)