forked from a2aproject/a2a-python
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathinterceptor.py
More file actions
95 lines (81 loc) · 3.81 KB
/
interceptor.py
File metadata and controls
95 lines (81 loc) · 3.81 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import logging # noqa: I001
from typing import Any
from a2a.client.auth.credentials import CredentialService
from a2a.client.client import ClientCallContext
from a2a.types.a2a_pb2 import AgentCard
# Module-level logger; AuthInterceptor emits a DEBUG record through it each
# time it attaches an authentication header to a request.
logger = logging.getLogger(__name__)
class AuthInterceptor:
    """Interceptor that adds authentication headers to outgoing requests.

    The header is chosen from the security schemes declared on the agent
    card, using credentials looked up through the supplied credential
    service. Only the first scheme that has both a credential and a
    supported header-based mechanism is applied; the remaining schemes are
    not consulted after that.
    """

    def __init__(self, credential_service: CredentialService):
        """Store the service used to look up a credential per scheme name."""
        self._credential_service = credential_service

    async def intercept(
        self,
        method_name: str,
        request_payload: dict[str, Any],
        http_kwargs: dict[str, Any],
        agent_card: AgentCard | None,
        context: ClientCallContext | None,
    ) -> tuple[dict[str, Any], dict[str, Any]]:
        """Applies authentication headers to the request if credentials are available.

        Args:
            method_name: Name of the method being invoked (not used here).
            request_payload: Request body; returned as received.
            http_kwargs: Keyword arguments for the HTTP call. When a scheme
                is applied, its ``'headers'`` entry is updated in place
                (and created if missing or ``None``).
            agent_card: Card whose ``security_requirements`` and
                ``security_schemes`` drive the scheme selection.
            context: Call context forwarded to the credential service.

        Returns:
            The ``(request_payload, http_kwargs)`` pair, with at most one
            authentication header added to ``http_kwargs['headers']``.
        """
        # Proto3 repeated fields (security) and maps (security_schemes) do not track presence.
        # HasField() raises ValueError for them.
        # We check for truthiness to see if they are non-empty.
        if (
            agent_card is None
            or not agent_card.security_requirements
            or not agent_card.security_schemes
        ):
            return request_payload, http_kwargs

        for requirement in agent_card.security_requirements:
            for scheme_name in requirement.schemes:
                credential = await self._credential_service.get_credentials(
                    scheme_name, context
                )
                if (
                    not credential
                    or scheme_name not in agent_card.security_schemes
                ):
                    continue
                scheme = agent_card.security_schemes.get(scheme_name)
                if not scheme:
                    continue

                # An explicit ``headers=None`` is treated like a missing
                # entry so that the item assignment below cannot fail.
                headers = http_kwargs.get('headers')
                if headers is None:
                    headers = {}

                if self._apply_scheme(headers, scheme_name, scheme, credential):
                    http_kwargs['headers'] = headers
                    # First applicable scheme wins.
                    return request_payload, http_kwargs
                # Note: Other cases like API keys in query/cookie are not handled and will be skipped.

        return request_payload, http_kwargs

    @staticmethod
    def _apply_scheme(
        headers: Any,
        scheme_name: str,
        scheme: Any,
        credential: Any,
    ) -> bool:
        """Write the header for ``scheme`` into ``headers`` if it is supported.

        Args:
            headers: Mutable mapping of request headers, updated in place.
            scheme_name: Name of the scheme; used only in the log message.
            scheme: Security scheme message taken from the agent card.
            credential: Credential value returned by the credential service.

        Returns:
            True when a header was written, False when the scheme is not
            one of the supported header-based mechanisms.
        """
        is_http_bearer = (
            scheme.HasField('http_auth_security_scheme')
            and scheme.http_auth_security_scheme.scheme.lower() == 'bearer'
        )
        # OAuth2 and OIDC schemes are implicitly Bearer
        if (
            is_http_bearer
            or scheme.HasField('oauth2_security_scheme')
            or scheme.HasField('open_id_connect_security_scheme')
        ):
            headers['Authorization'] = f'Bearer {credential}'
            logger.debug(
                "Added Bearer token for scheme '%s'.",
                scheme_name,
            )
            return True

        # API Key in Header
        if (
            scheme.HasField('api_key_security_scheme')
            and scheme.api_key_security_scheme.location.lower() == 'header'
        ):
            headers[scheme.api_key_security_scheme.name] = credential
            logger.debug(
                "Added API Key Header for scheme '%s'.",
                scheme_name,
            )
            return True

        return False