Skip to content

Commit 7f8d258

Browse files
committed
Merge branch '1.0-dev' of https://github.com/sokoliva/a2a-python into write-0-3
2 parents 123deea + a910cbc commit 7f8d258

25 files changed

Lines changed: 663 additions & 300 deletions

src/a2a/client/__init__.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,21 @@
99
)
1010
from a2a.client.base_client import BaseClient
1111
from a2a.client.card_resolver import A2ACardResolver
12-
from a2a.client.client import Client, ClientConfig, ClientEvent, Consumer
12+
from a2a.client.client import (
13+
Client,
14+
ClientCallContext,
15+
ClientConfig,
16+
ClientEvent,
17+
Consumer,
18+
)
1319
from a2a.client.client_factory import ClientFactory, minimal_agent_card
1420
from a2a.client.errors import (
1521
A2AClientError,
1622
A2AClientTimeoutError,
1723
AgentCardResolutionError,
1824
)
1925
from a2a.client.helpers import create_text_message_object
20-
from a2a.client.middleware import ClientCallContext, ClientCallInterceptor
26+
from a2a.client.interceptors import ClientCallInterceptor
2127

2228

2329
logger = logging.getLogger(__name__)

src/a2a/client/auth/credentials.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from abc import ABC, abstractmethod
22

3-
from a2a.client.middleware import ClientCallContext
3+
from a2a.client.client import ClientCallContext
44

55

66
class CredentialService(ABC):

src/a2a/client/auth/interceptor.py

Lines changed: 31 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
import logging # noqa: I001
2-
from typing import Any
32

43
from a2a.client.auth.credentials import CredentialService
5-
from a2a.client.middleware import ClientCallContext, ClientCallInterceptor
6-
from a2a.types.a2a_pb2 import AgentCard
4+
from a2a.client.client import ClientCallContext
5+
from a2a.client.interceptors import (
6+
AfterArgs,
7+
BeforeArgs,
8+
ClientCallInterceptor,
9+
)
710

811
logger = logging.getLogger(__name__)
912

@@ -17,79 +20,79 @@ class AuthInterceptor(ClientCallInterceptor):
1720
def __init__(self, credential_service: CredentialService):
1821
self._credential_service = credential_service
1922

20-
async def intercept(
21-
self,
22-
method_name: str,
23-
request_payload: dict[str, Any],
24-
http_kwargs: dict[str, Any],
25-
agent_card: AgentCard | None,
26-
context: ClientCallContext | None,
27-
) -> tuple[dict[str, Any], dict[str, Any]]:
23+
async def before(self, args: BeforeArgs) -> None:
2824
"""Applies authentication headers to the request if credentials are available."""
25+
agent_card = args.agent_card
26+
2927
# Proto3 repeated fields (security) and maps (security_schemes) do not track presence.
3028
# HasField() raises ValueError for them.
3129
# We check for truthiness to see if they are non-empty.
3230
if (
33-
agent_card is None
34-
or not agent_card.security_requirements
31+
not agent_card.security_requirements
3532
or not agent_card.security_schemes
3633
):
37-
return request_payload, http_kwargs
34+
return
3835

3936
for requirement in agent_card.security_requirements:
4037
for scheme_name in requirement.schemes:
4138
credential = await self._credential_service.get_credentials(
42-
scheme_name, context
39+
scheme_name, args.context
4340
)
4441
if credential and scheme_name in agent_card.security_schemes:
4542
scheme = agent_card.security_schemes.get(scheme_name)
4643
if not scheme:
4744
continue
4845

49-
headers = http_kwargs.get('headers', {})
46+
if args.context is None:
47+
args.context = ClientCallContext()
48+
49+
if args.context.service_parameters is None:
50+
args.context.service_parameters = {}
5051

5152
# HTTP Bearer authentication
5253
if (
5354
scheme.HasField('http_auth_security_scheme')
5455
and scheme.http_auth_security_scheme.scheme.lower()
5556
== 'bearer'
5657
):
57-
headers['Authorization'] = f'Bearer {credential}'
58+
args.context.service_parameters['Authorization'] = (
59+
f'Bearer {credential}'
60+
)
5861
logger.debug(
5962
"Added Bearer token for scheme '%s'.",
6063
scheme_name,
6164
)
62-
http_kwargs['headers'] = headers
63-
return request_payload, http_kwargs
65+
return
6466

6567
# OAuth2 and OIDC schemes are implicitly Bearer
6668
if scheme.HasField(
6769
'oauth2_security_scheme'
6870
) or scheme.HasField('open_id_connect_security_scheme'):
69-
headers['Authorization'] = f'Bearer {credential}'
71+
args.context.service_parameters['Authorization'] = (
72+
f'Bearer {credential}'
73+
)
7074
logger.debug(
7175
"Added Bearer token for scheme '%s'.",
7276
scheme_name,
7377
)
74-
http_kwargs['headers'] = headers
75-
return request_payload, http_kwargs
78+
return
7679

7780
# API Key in Header
7881
if (
7982
scheme.HasField('api_key_security_scheme')
8083
and scheme.api_key_security_scheme.location.lower()
8184
== 'header'
8285
):
83-
headers[scheme.api_key_security_scheme.name] = (
84-
credential
85-
)
86+
args.context.service_parameters[
87+
scheme.api_key_security_scheme.name
88+
] = credential
8689
logger.debug(
8790
"Added API Key Header for scheme '%s'.",
8891
scheme_name,
8992
)
90-
http_kwargs['headers'] = headers
91-
return request_payload, http_kwargs
93+
return
9294

9395
# Note: Other cases like API keys in query/cookie are not handled and will be skipped.
9496

95-
return request_payload, http_kwargs
97+
async def after(self, args: AfterArgs) -> None:
98+
"""Invoked after the method is executed."""

0 commit comments

Comments
 (0)