forked from a2aproject/a2a-python
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathinterceptor.py
More file actions
98 lines (82 loc) · 3.79 KB
/
interceptor.py
File metadata and controls
98 lines (82 loc) · 3.79 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import logging # noqa: I001
from a2a.client.auth.credentials import CredentialService
from a2a.client.client import ClientCallContext
from a2a.client.interceptors import (
AfterArgs,
BeforeArgs,
ClientCallInterceptor,
)
logger = logging.getLogger(__name__)


class AuthInterceptor(ClientCallInterceptor):
    """An interceptor that automatically adds authentication details to requests.

    Authentication is driven by the agent card's security requirements and
    security schemes. For each scheme name listed in the card's security
    requirements, the credential service is asked for a credential. The first
    scheme that yields a credential, is declared in the card's
    ``security_schemes``, and maps to a supported header-based mechanism is
    applied; no further schemes are consulted after that.

    Supported mechanisms:
      * HTTP auth with scheme ``bearer``, OAuth2 and OpenID Connect:
        ``Authorization: Bearer <credential>``.
      * API key whose location is ``header``:
        ``<api key name>: <credential>``.

    Other cases (API keys in query/cookie, non-bearer HTTP auth) are not
    handled and are skipped.
    """

    def __init__(self, credential_service: CredentialService):
        """Initializes the interceptor.

        Args:
            credential_service: Service queried for a credential per scheme
                name, together with the current call context.
        """
        self._credential_service = credential_service

    async def before(self, args: BeforeArgs) -> None:
        """Applies authentication headers to the request if credentials are available.

        Args:
            args: Arguments of the outgoing call. ``args.agent_card`` supplies
                the security requirements and schemes. When a header is
                applied, ``args.context`` is created if it is ``None`` and the
                header is written into its ``service_parameters``; otherwise
                ``args.context`` is left untouched.
        """
        agent_card = args.agent_card
        schemes = agent_card.security_schemes

        # Proto3 repeated fields (security_requirements) and maps
        # (security_schemes) do not track presence; HasField() raises
        # ValueError for them. Check truthiness to see if they are non-empty.
        if not agent_card.security_requirements or not schemes:
            return

        for requirement in agent_card.security_requirements:
            for scheme_name in requirement.schemes:
                credential = await self._credential_service.get_credentials(
                    scheme_name, args.context
                )
                if not credential or scheme_name not in schemes:
                    continue
                scheme = schemes.get(scheme_name)
                if not scheme:
                    continue

                header = self._auth_header(scheme, credential)
                if header is None:
                    # Unsupported scheme type: do not touch the context,
                    # just try the next scheme.
                    continue
                header_name, header_value, description = header

                if args.context is None:
                    args.context = ClientCallContext()
                if args.context.service_parameters is None:
                    args.context.service_parameters = {}
                args.context.service_parameters[header_name] = header_value
                logger.debug(
                    "Added %s for scheme '%s'.", description, scheme_name
                )
                return

    @staticmethod
    def _auth_header(
        scheme, credential: str
    ) -> tuple[str, str, str] | None:
        """Maps a security scheme to the header that carries ``credential``.

        Args:
            scheme: The security scheme message from the agent card.
            credential: The credential returned by the credential service.

        Returns:
            ``(header name, header value, description for logging)``, or
            ``None`` if the scheme is not a supported header-based mechanism.
        """
        # HTTP Bearer authentication
        if (
            scheme.HasField('http_auth_security_scheme')
            and scheme.http_auth_security_scheme.scheme.lower() == 'bearer'
        ):
            return 'Authorization', f'Bearer {credential}', 'Bearer token'

        # OAuth2 and OIDC schemes are implicitly Bearer
        if scheme.HasField('oauth2_security_scheme') or scheme.HasField(
            'open_id_connect_security_scheme'
        ):
            return 'Authorization', f'Bearer {credential}', 'Bearer token'

        # API Key in Header
        if (
            scheme.HasField('api_key_security_scheme')
            and scheme.api_key_security_scheme.location.lower() == 'header'
        ):
            return (
                scheme.api_key_security_scheme.name,
                credential,
                'API Key Header',
            )

        # Note: Other cases like API keys in query/cookie are not handled
        # and will be skipped.
        return None

    async def after(self, args: AfterArgs) -> None:
        """Invoked after the method is executed.

        This interceptor does nothing after the call.
        """