-
Notifications
You must be signed in to change notification settings - Fork 428
Expand file tree
/
Copy pathinterceptor.py
More file actions
93 lines (79 loc) · 3.69 KB
/
interceptor.py
File metadata and controls
93 lines (79 loc) · 3.69 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import logging # noqa: I001
from typing import Any
from a2a.client.auth.credentials import CredentialService
from a2a.client.middleware import ClientCallContext, ClientCallInterceptor
from a2a.types import (
AgentCard,
APIKeySecurityScheme,
HTTPAuthSecurityScheme,
In,
OAuth2SecurityScheme,
OpenIdConnectSecurityScheme,
)
# Module-level logger named after this module; used for debug messages when
# authentication headers are attached to a request.
logger = logging.getLogger(__name__)
class AuthInterceptor(ClientCallInterceptor):
"""An interceptor that automatically adds authentication details to requests.
Based on the agent's security schemes.
"""
def __init__(self, credential_service: CredentialService):
self._credential_service = credential_service
async def intercept(
self,
method_name: str,
request_payload: dict[str, Any],
http_kwargs: dict[str, Any],
agent_card: AgentCard | None,
context: ClientCallContext | None,
) -> tuple[dict[str, Any], dict[str, Any]]:
"""Applies authentication headers to the request if credentials are available."""
if (
agent_card is None
or agent_card.security is None
or agent_card.security_schemes is None
):
return request_payload, http_kwargs
for requirement in agent_card.security:
for scheme_name in requirement:
credential = await self._credential_service.get_credentials(
scheme_name, context
)
if credential and scheme_name in agent_card.security_schemes:
scheme_def_union = agent_card.security_schemes.get(
scheme_name
)
if not scheme_def_union:
continue
scheme_def = scheme_def_union.root
headers = http_kwargs.get('headers', {})
match scheme_def:
# Case 1a: HTTP Bearer scheme with an if guard
case HTTPAuthSecurityScheme() if (
scheme_def.scheme.lower() == 'bearer'
):
headers['Authorization'] = f'Bearer {credential}'
logger.debug(
f"Added Bearer token for scheme '{scheme_name}' (type: {scheme_def.type})."
)
http_kwargs['headers'] = headers
return request_payload, http_kwargs
# Case 1b: OAuth2 and OIDC schemes, which are implicitly Bearer
case (
OAuth2SecurityScheme()
| OpenIdConnectSecurityScheme()
):
headers['Authorization'] = f'Bearer {credential}'
logger.debug(
f"Added Bearer token for scheme '{scheme_name}' (type: {scheme_def.type})."
)
http_kwargs['headers'] = headers
return request_payload, http_kwargs
# Case 2: API Key in Header
case APIKeySecurityScheme(in_=In.header):
headers[scheme_def.name] = credential
logger.debug(
f"Added API Key Header for scheme '{scheme_name}'."
)
http_kwargs['headers'] = headers
return request_payload, http_kwargs
# Note: Other cases like API keys in query/cookie are not handled and will be skipped.
return request_payload, http_kwargs