Skip to content

Commit 13f7dd3

Browse files
committed
make card modifier and extended card modifier async
1 parent ffe31e2 commit 13f7dd3

9 files changed

Lines changed: 70 additions & 59 deletions

File tree

src/a2a/server/apps/jsonrpc/fastapi_app.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import logging
22

3-
from collections.abc import Callable
3+
from collections.abc import Awaitable, Callable
44
from typing import TYPE_CHECKING, Any
55

66

@@ -72,9 +72,10 @@ def __init__( # noqa: PLR0913
7272
http_handler: RequestHandler,
7373
extended_agent_card: AgentCard | None = None,
7474
context_builder: CallContextBuilder | None = None,
75-
card_modifier: Callable[[AgentCard], AgentCard] | None = None,
75+
card_modifier: Callable[[AgentCard], Awaitable[AgentCard]]
76+
| None = None,
7677
extended_card_modifier: Callable[
77-
[AgentCard, ServerCallContext], AgentCard
78+
[AgentCard, ServerCallContext], Awaitable[AgentCard]
7879
]
7980
| None = None,
8081
max_content_length: int | None = 10 * 1024 * 1024, # 10MB
@@ -90,9 +91,9 @@ def __init__( # noqa: PLR0913
9091
context_builder: The CallContextBuilder used to construct the
9192
ServerCallContext passed to the http_handler. If None, no
9293
ServerCallContext is passed.
93-
card_modifier: An optional callback to dynamically modify the public
94+
card_modifier: An optional async callback to dynamically modify the public
9495
agent card before it is served.
95-
extended_card_modifier: An optional callback to dynamically modify
96+
extended_card_modifier: An optional async callback to dynamically modify
9697
the extended agent card before it is served. It receives the
9798
call context.
9899
max_content_length: The maximum allowed content length for incoming

src/a2a/server/apps/jsonrpc/jsonrpc_app.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import traceback
55

66
from abc import ABC, abstractmethod
7-
from collections.abc import AsyncGenerator, Callable
7+
from collections.abc import AsyncGenerator, Awaitable, Callable
88
from typing import TYPE_CHECKING, Any
99

1010
from pydantic import ValidationError
@@ -178,9 +178,10 @@ def __init__( # noqa: PLR0913
178178
http_handler: RequestHandler,
179179
extended_agent_card: AgentCard | None = None,
180180
context_builder: CallContextBuilder | None = None,
181-
card_modifier: Callable[[AgentCard], AgentCard] | None = None,
181+
card_modifier: Callable[[AgentCard], Awaitable[AgentCard]]
182+
| None = None,
182183
extended_card_modifier: Callable[
183-
[AgentCard, ServerCallContext], AgentCard
184+
[AgentCard, ServerCallContext], Awaitable[AgentCard]
184185
]
185186
| None = None,
186187
max_content_length: int | None = 10 * 1024 * 1024, # 10MB
@@ -196,9 +197,9 @@ def __init__( # noqa: PLR0913
196197
context_builder: The CallContextBuilder used to construct the
197198
ServerCallContext passed to the http_handler. If None, no
198199
ServerCallContext is passed.
199-
card_modifier: An optional callback to dynamically modify the public
200+
card_modifier: An optional async callback to dynamically modify the public
200201
agent card before it is served.
201-
extended_card_modifier: An optional callback to dynamically modify
202+
extended_card_modifier: An optional async callback to dynamically modify
202203
the extended agent card before it is served. It receives the
203204
call context.
204205
max_content_length: The maximum allowed content length for incoming
@@ -576,7 +577,7 @@ async def _handle_get_agent_card(self, request: Request) -> JSONResponse:
576577

577578
card_to_serve = self.agent_card
578579
if self.card_modifier:
579-
card_to_serve = self.card_modifier(card_to_serve)
580+
card_to_serve = await self.card_modifier(card_to_serve)
580581

581582
return JSONResponse(
582583
card_to_serve.model_dump(
@@ -605,7 +606,9 @@ async def _handle_get_authenticated_extended_agent_card(
605606
context = self._context_builder.build(request)
606607
# If no base extended card is provided, pass the public card to the modifier
607608
base_card = card_to_serve if card_to_serve else self.agent_card
608-
card_to_serve = self.extended_card_modifier(base_card, context)
609+
card_to_serve = await self.extended_card_modifier(
610+
base_card, context
611+
)
609612

610613
if card_to_serve:
611614
return JSONResponse(

src/a2a/server/apps/jsonrpc/starlette_app.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import logging
22

3-
from collections.abc import Callable
3+
from collections.abc import Awaitable, Callable
44
from typing import TYPE_CHECKING, Any
55

66

@@ -54,9 +54,10 @@ def __init__( # noqa: PLR0913
5454
http_handler: RequestHandler,
5555
extended_agent_card: AgentCard | None = None,
5656
context_builder: CallContextBuilder | None = None,
57-
card_modifier: Callable[[AgentCard], AgentCard] | None = None,
57+
card_modifier: Callable[[AgentCard], Awaitable[AgentCard]]
58+
| None = None,
5859
extended_card_modifier: Callable[
59-
[AgentCard, ServerCallContext], AgentCard
60+
[AgentCard, ServerCallContext], Awaitable[AgentCard]
6061
]
6162
| None = None,
6263
max_content_length: int | None = 10 * 1024 * 1024, # 10MB
@@ -72,9 +73,9 @@ def __init__( # noqa: PLR0913
7273
context_builder: The CallContextBuilder used to construct the
7374
ServerCallContext passed to the http_handler. If None, no
7475
ServerCallContext is passed.
75-
card_modifier: An optional callback to dynamically modify the public
76+
card_modifier: An optional async callback to dynamically modify the public
7677
agent card before it is served.
77-
extended_card_modifier: An optional callback to dynamically modify
78+
extended_card_modifier: An optional async callback to dynamically modify
7879
the extended agent card before it is served. It receives the
7980
call context.
8081
max_content_length: The maximum allowed content length for incoming

src/a2a/server/apps/rest/fastapi_app.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import logging
22

3-
from collections.abc import Callable
3+
from collections.abc import Awaitable, Callable
44
from typing import TYPE_CHECKING, Any
55

66

@@ -49,9 +49,10 @@ def __init__( # noqa: PLR0913
4949
http_handler: RequestHandler,
5050
extended_agent_card: AgentCard | None = None,
5151
context_builder: CallContextBuilder | None = None,
52-
card_modifier: Callable[[AgentCard], AgentCard] | None = None,
52+
card_modifier: Callable[[AgentCard], Awaitable[AgentCard]]
53+
| None = None,
5354
extended_card_modifier: Callable[
54-
[AgentCard, ServerCallContext], AgentCard
55+
[AgentCard, ServerCallContext], Awaitable[AgentCard]
5556
]
5657
| None = None,
5758
):
@@ -66,9 +67,9 @@ def __init__( # noqa: PLR0913
6667
context_builder: The CallContextBuilder used to construct the
6768
ServerCallContext passed to the http_handler. If None, no
6869
ServerCallContext is passed.
69-
card_modifier: An optional callback to dynamically modify the public
70+
card_modifier: An optional async callback to dynamically modify the public
7071
agent card before it is served.
71-
extended_card_modifier: An optional callback to dynamically modify
72+
extended_card_modifier: An optional async callback to dynamically modify
7273
the extended agent card before it is served. It receives the
7374
call context.
7475
"""

src/a2a/server/apps/rest/rest_adapter.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,10 @@ def __init__( # noqa: PLR0913
5858
http_handler: RequestHandler,
5959
extended_agent_card: AgentCard | None = None,
6060
context_builder: CallContextBuilder | None = None,
61-
card_modifier: Callable[[AgentCard], AgentCard] | None = None,
61+
card_modifier: Callable[[AgentCard], Awaitable[AgentCard]]
62+
| None = None,
6263
extended_card_modifier: Callable[
63-
[AgentCard, ServerCallContext], AgentCard
64+
[AgentCard, ServerCallContext], Awaitable[AgentCard]
6465
]
6566
| None = None,
6667
):
@@ -75,9 +76,9 @@ def __init__( # noqa: PLR0913
7576
context_builder: The CallContextBuilder used to construct the
7677
ServerCallContext passed to the http_handler. If None, no
7778
ServerCallContext is passed.
78-
card_modifier: An optional callback to dynamically modify the public
79+
card_modifier: An optional async callback to dynamically modify the public
7980
agent card before it is served.
80-
extended_card_modifier: An optional callback to dynamically modify
81+
extended_card_modifier: An optional async callback to dynamically modify
8182
the extended agent card before it is served. It receives the
8283
call context.
8384
"""
@@ -150,7 +151,7 @@ async def handle_get_agent_card(
150151
"""
151152
card_to_serve = self.agent_card
152153
if self.card_modifier:
153-
card_to_serve = self.card_modifier(card_to_serve)
154+
card_to_serve = await self.card_modifier(card_to_serve)
154155

155156
return card_to_serve.model_dump(mode='json', exclude_none=True)
156157

@@ -182,9 +183,11 @@ async def handle_authenticated_agent_card(
182183

183184
if self.extended_card_modifier:
184185
context = self._context_builder.build(request)
185-
card_to_serve = self.extended_card_modifier(card_to_serve, context)
186+
card_to_serve = await self.extended_card_modifier(
187+
card_to_serve, context
188+
)
186189
elif self.card_modifier:
187-
card_to_serve = self.card_modifier(card_to_serve)
190+
card_to_serve = await self.card_modifier(card_to_serve)
188191

189192
return card_to_serve.model_dump(mode='json', exclude_none=True)
190193

src/a2a/server/request_handlers/grpc_handler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def __init__(
9999
delegate requests to.
100100
context_builder: The CallContextBuilder object. If none the
101101
DefaultCallContextBuilder is used.
102-
card_modifier: An optional callback to dynamically modify the public
102+
card_modifier: An optional async callback to dynamically modify the public
103103
agent card before it is served.
104104
"""
105105
self.agent_card = agent_card
@@ -339,7 +339,7 @@ async def GetAgentCard(
339339
"""Get the agent card for the agent served."""
340340
card_to_serve = self.agent_card
341341
if self.card_modifier:
342-
card_to_serve = self.card_modifier(card_to_serve)
342+
card_to_serve = await self.card_modifier(card_to_serve)
343343
return proto_utils.ToProto.agent_card(card_to_serve)
344344

345345
async def abort_context(

src/a2a/server/request_handlers/jsonrpc_handler.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import logging
22

3-
from collections.abc import AsyncIterable, Callable
3+
from collections.abc import AsyncIterable, Awaitable, Callable
44

55
from a2a.server.context import ServerCallContext
66
from a2a.server.request_handlers.request_handler import RequestHandler
@@ -63,21 +63,22 @@ def __init__(
6363
request_handler: RequestHandler,
6464
extended_agent_card: AgentCard | None = None,
6565
extended_card_modifier: Callable[
66-
[AgentCard, ServerCallContext], AgentCard
66+
[AgentCard, ServerCallContext], Awaitable[AgentCard]
6767
]
6868
| None = None,
69-
card_modifier: Callable[[AgentCard], AgentCard] | None = None,
69+
card_modifier: Callable[[AgentCard], Awaitable[AgentCard]]
70+
| None = None,
7071
):
7172
"""Initializes the JSONRPCHandler.
7273
7374
Args:
7475
agent_card: The AgentCard describing the agent's capabilities.
7576
request_handler: The underlying `RequestHandler` instance to delegate requests to.
7677
extended_agent_card: An optional, distinct Extended AgentCard to be served
77-
extended_card_modifier: An optional callback to dynamically modify
78+
extended_card_modifier: An optional async callback to dynamically modify
7879
the extended agent card before it is served. It receives the
7980
call context.
80-
card_modifier: An optional callback to dynamically modify the public
81+
card_modifier: An optional async callback to dynamically modify the public
8182
agent card before it is served.
8283
"""
8384
self.agent_card = agent_card
@@ -450,9 +451,11 @@ async def get_authenticated_extended_card(
450451

451452
card_to_serve = base_card
452453
if self.extended_card_modifier and context:
453-
card_to_serve = self.extended_card_modifier(base_card, context)
454+
card_to_serve = await self.extended_card_modifier(
455+
base_card, context
456+
)
454457
elif self.card_modifier:
455-
card_to_serve = self.card_modifier(base_card)
458+
card_to_serve = await self.card_modifier(base_card)
456459

457460
return GetAuthenticatedExtendedCardResponse(
458461
root=GetAuthenticatedExtendedCardSuccessResponse(

src/a2a/utils/signing.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import json
22

3-
from collections.abc import Callable
3+
from collections.abc import Awaitable, Callable
44
from typing import Any, TypedDict
55

66
from a2a.utils.helpers import canonicalize_agent_card
@@ -54,7 +54,7 @@ def create_agent_card_signer(
5454
signing_key: PyJWK | str | bytes,
5555
protected_header: ProtectedHeader,
5656
header: dict[str, Any] | None = None,
57-
) -> Callable[[AgentCard], AgentCard]:
57+
) -> Callable[[AgentCard], Awaitable[AgentCard]]:
5858
"""Creates a function that signs an AgentCard and adds the signature.
5959
6060
Args:
@@ -63,10 +63,10 @@ def create_agent_card_signer(
6363
header: Unprotected header parameters.
6464
6565
Returns:
66-
A callable that takes an AgentCard and returns the modified AgentCard with a signature.
66+
A callable that takes an AgentCard and returns the awaitable modified AgentCard with a signature.
6767
"""
6868

69-
def agent_card_signer(agent_card: AgentCard) -> AgentCard:
69+
async def agent_card_signer(agent_card: AgentCard) -> AgentCard:
7070
"""Signs agent card."""
7171
canonical_payload = canonicalize_agent_card(agent_card)
7272
payload_dict = json.loads(canonical_payload)

tests/utils/test_signing.py

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,16 @@
1+
from typing import Any
2+
3+
import pytest
4+
from cryptography.hazmat.primitives import asymmetric
5+
from jwt.utils import base64url_encode
6+
17
from a2a.types import (
2-
AgentCard,
38
AgentCapabilities,
4-
AgentSkill,
5-
)
6-
from a2a.types import (
79
AgentCard,
8-
AgentCapabilities,
9-
AgentSkill,
1010
AgentCardSignature,
11+
AgentSkill,
1112
)
1213
from a2a.utils import signing
13-
from typing import Any
14-
from jwt.utils import base64url_encode
15-
16-
import pytest
17-
from cryptography.hazmat.primitives import asymmetric
1814

1915

2016
def create_key_provider(verification_key: str | bytes | dict[str, Any]):
@@ -53,7 +49,8 @@ def sample_agent_card() -> AgentCard:
5349
)
5450

5551

56-
def test_signer_and_verifier_symmetric(sample_agent_card: AgentCard):
52+
@pytest.mark.asyncio
53+
async def test_signer_and_verifier_symmetric(sample_agent_card: AgentCard):
5754
"""Test the agent card signing and verification process with symmetric key encryption."""
5855
key = 'key12345' # Using a simple symmetric key for HS256
5956
wrong_key = 'wrongkey'
@@ -67,7 +64,7 @@ def test_signer_and_verifier_symmetric(sample_agent_card: AgentCard):
6764
'typ': 'JOSE',
6865
},
6966
)
70-
signed_card = agent_card_signer(sample_agent_card)
67+
signed_card = await agent_card_signer(sample_agent_card)
7168

7269
assert signed_card.signatures is not None
7370
assert len(signed_card.signatures) == 1
@@ -92,7 +89,8 @@ def test_signer_and_verifier_symmetric(sample_agent_card: AgentCard):
9289
verifier_wrong_key(signed_card)
9390

9491

95-
def test_signer_and_verifier_symmetric_multiple_signatures(
92+
@pytest.mark.asyncio
93+
async def test_signer_and_verifier_symmetric_multiple_signatures(
9694
sample_agent_card: AgentCard,
9795
):
9896
"""Test the agent card signing and verification process with symmetric key encryption.
@@ -115,7 +113,7 @@ def test_signer_and_verifier_symmetric_multiple_signatures(
115113
'typ': 'JOSE',
116114
},
117115
)
118-
signed_card = agent_card_signer(sample_agent_card)
116+
signed_card = await agent_card_signer(sample_agent_card)
119117

120118
assert signed_card.signatures is not None
121119
assert len(signed_card.signatures) == 2
@@ -140,7 +138,8 @@ def test_signer_and_verifier_symmetric_multiple_signatures(
140138
verifier_wrong_key(signed_card)
141139

142140

143-
def test_signer_and_verifier_asymmetric(sample_agent_card: AgentCard):
141+
@pytest.mark.asyncio
142+
async def test_signer_and_verifier_asymmetric(sample_agent_card: AgentCard):
144143
"""Test the agent card signing and verification process with an asymmetric key encryption."""
145144
# Generate a dummy EC private key for ES256
146145
private_key = asymmetric.ec.generate_private_key(asymmetric.ec.SECP256R1())
@@ -160,7 +159,7 @@ def test_signer_and_verifier_asymmetric(sample_agent_card: AgentCard):
160159
'typ': 'JOSE',
161160
},
162161
)
163-
signed_card = agent_card_signer(sample_agent_card)
162+
signed_card = await agent_card_signer(sample_agent_card)
164163

165164
assert signed_card.signatures is not None
166165
assert len(signed_card.signatures) == 1

0 commit comments

Comments
 (0)