Skip to content

Commit 8c5e9a5

Browse files
committed
make card modifier and extended card modifier async
1 parent ffe31e2 commit 8c5e9a5

12 files changed

Lines changed: 66 additions & 50 deletions

File tree

src/a2a/server/apps/jsonrpc/fastapi_app.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import logging
22

3-
from collections.abc import Callable
3+
from collections.abc import Awaitable, Callable
44
from typing import TYPE_CHECKING, Any
55

66

@@ -72,9 +72,10 @@ def __init__( # noqa: PLR0913
7272
http_handler: RequestHandler,
7373
extended_agent_card: AgentCard | None = None,
7474
context_builder: CallContextBuilder | None = None,
75-
card_modifier: Callable[[AgentCard], AgentCard] | None = None,
75+
card_modifier: Callable[[AgentCard], Awaitable[AgentCard]]
76+
| None = None,
7677
extended_card_modifier: Callable[
77-
[AgentCard, ServerCallContext], AgentCard
78+
[AgentCard, ServerCallContext], Awaitable[AgentCard]
7879
]
7980
| None = None,
8081
max_content_length: int | None = 10 * 1024 * 1024, # 10MB

src/a2a/server/apps/jsonrpc/jsonrpc_app.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import traceback
55

66
from abc import ABC, abstractmethod
7-
from collections.abc import AsyncGenerator, Callable
7+
from collections.abc import AsyncGenerator, Awaitable, Callable
88
from typing import TYPE_CHECKING, Any
99

1010
from pydantic import ValidationError
@@ -178,9 +178,10 @@ def __init__( # noqa: PLR0913
178178
http_handler: RequestHandler,
179179
extended_agent_card: AgentCard | None = None,
180180
context_builder: CallContextBuilder | None = None,
181-
card_modifier: Callable[[AgentCard], AgentCard] | None = None,
181+
card_modifier: Callable[[AgentCard], Awaitable[AgentCard]]
182+
| None = None,
182183
extended_card_modifier: Callable[
183-
[AgentCard, ServerCallContext], AgentCard
184+
[AgentCard, ServerCallContext], Awaitable[AgentCard]
184185
]
185186
| None = None,
186187
max_content_length: int | None = 10 * 1024 * 1024, # 10MB
@@ -576,7 +577,7 @@ async def _handle_get_agent_card(self, request: Request) -> JSONResponse:
576577

577578
card_to_serve = self.agent_card
578579
if self.card_modifier:
579-
card_to_serve = self.card_modifier(card_to_serve)
580+
card_to_serve = await self.card_modifier(card_to_serve)
580581

581582
return JSONResponse(
582583
card_to_serve.model_dump(
@@ -605,7 +606,9 @@ async def _handle_get_authenticated_extended_agent_card(
605606
context = self._context_builder.build(request)
606607
# If no base extended card is provided, pass the public card to the modifier
607608
base_card = card_to_serve if card_to_serve else self.agent_card
608-
card_to_serve = self.extended_card_modifier(base_card, context)
609+
card_to_serve = await self.extended_card_modifier(
610+
base_card, context
611+
)
609612

610613
if card_to_serve:
611614
return JSONResponse(

src/a2a/server/apps/jsonrpc/starlette_app.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import logging
22

3-
from collections.abc import Callable
3+
from collections.abc import Awaitable, Callable
44
from typing import TYPE_CHECKING, Any
55

66

@@ -54,9 +54,10 @@ def __init__( # noqa: PLR0913
5454
http_handler: RequestHandler,
5555
extended_agent_card: AgentCard | None = None,
5656
context_builder: CallContextBuilder | None = None,
57-
card_modifier: Callable[[AgentCard], AgentCard] | None = None,
57+
card_modifier: Callable[[AgentCard], Awaitable[AgentCard]]
58+
| None = None,
5859
extended_card_modifier: Callable[
59-
[AgentCard, ServerCallContext], AgentCard
60+
[AgentCard, ServerCallContext], Awaitable[AgentCard]
6061
]
6162
| None = None,
6263
max_content_length: int | None = 10 * 1024 * 1024, # 10MB

src/a2a/server/apps/rest/fastapi_app.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import logging
22

3-
from collections.abc import Callable
3+
from collections.abc import Awaitable, Callable
44
from typing import TYPE_CHECKING, Any
55

66

@@ -49,9 +49,10 @@ def __init__( # noqa: PLR0913
4949
http_handler: RequestHandler,
5050
extended_agent_card: AgentCard | None = None,
5151
context_builder: CallContextBuilder | None = None,
52-
card_modifier: Callable[[AgentCard], AgentCard] | None = None,
52+
card_modifier: Callable[[AgentCard], Awaitable[AgentCard]]
53+
| None = None,
5354
extended_card_modifier: Callable[
54-
[AgentCard, ServerCallContext], AgentCard
55+
[AgentCard, ServerCallContext], Awaitable[AgentCard]
5556
]
5657
| None = None,
5758
):

src/a2a/server/apps/rest/rest_adapter.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,10 @@ def __init__( # noqa: PLR0913
5858
http_handler: RequestHandler,
5959
extended_agent_card: AgentCard | None = None,
6060
context_builder: CallContextBuilder | None = None,
61-
card_modifier: Callable[[AgentCard], AgentCard] | None = None,
61+
card_modifier: Callable[[AgentCard], Awaitable[AgentCard]]
62+
| None = None,
6263
extended_card_modifier: Callable[
63-
[AgentCard, ServerCallContext], AgentCard
64+
[AgentCard, ServerCallContext], Awaitable[AgentCard]
6465
]
6566
| None = None,
6667
):
@@ -150,7 +151,7 @@ async def handle_get_agent_card(
150151
"""
151152
card_to_serve = self.agent_card
152153
if self.card_modifier:
153-
card_to_serve = self.card_modifier(card_to_serve)
154+
card_to_serve = await self.card_modifier(card_to_serve)
154155

155156
return card_to_serve.model_dump(mode='json', exclude_none=True)
156157

@@ -182,9 +183,11 @@ async def handle_authenticated_agent_card(
182183

183184
if self.extended_card_modifier:
184185
context = self._context_builder.build(request)
185-
card_to_serve = self.extended_card_modifier(card_to_serve, context)
186+
card_to_serve = await self.extended_card_modifier(
187+
card_to_serve, context
188+
)
186189
elif self.card_modifier:
187-
card_to_serve = self.card_modifier(card_to_serve)
190+
card_to_serve = await self.card_modifier(card_to_serve)
188191

189192
return card_to_serve.model_dump(mode='json', exclude_none=True)
190193

src/a2a/server/request_handlers/grpc_handler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,7 @@ async def GetAgentCard(
339339
"""Get the agent card for the agent served."""
340340
card_to_serve = self.agent_card
341341
if self.card_modifier:
342-
card_to_serve = self.card_modifier(card_to_serve)
342+
card_to_serve = await self.card_modifier(card_to_serve)
343343
return proto_utils.ToProto.agent_card(card_to_serve)
344344

345345
async def abort_context(

src/a2a/server/request_handlers/jsonrpc_handler.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import logging
22

3-
from collections.abc import AsyncIterable, Callable
3+
from collections.abc import AsyncIterable, Awaitable, Callable
44

55
from a2a.server.context import ServerCallContext
66
from a2a.server.request_handlers.request_handler import RequestHandler
@@ -63,10 +63,11 @@ def __init__(
6363
request_handler: RequestHandler,
6464
extended_agent_card: AgentCard | None = None,
6565
extended_card_modifier: Callable[
66-
[AgentCard, ServerCallContext], AgentCard
66+
[AgentCard, ServerCallContext], Awaitable[AgentCard]
6767
]
6868
| None = None,
69-
card_modifier: Callable[[AgentCard], AgentCard] | None = None,
69+
card_modifier: Callable[[AgentCard], Awaitable[AgentCard]]
70+
| None = None,
7071
):
7172
"""Initializes the JSONRPCHandler.
7273
@@ -450,9 +451,11 @@ async def get_authenticated_extended_card(
450451

451452
card_to_serve = base_card
452453
if self.extended_card_modifier and context:
453-
card_to_serve = self.extended_card_modifier(base_card, context)
454+
card_to_serve = await self.extended_card_modifier(
455+
base_card, context
456+
)
454457
elif self.card_modifier:
455-
card_to_serve = self.card_modifier(base_card)
458+
card_to_serve = await self.card_modifier(base_card)
456459

457460
return GetAuthenticatedExtendedCardResponse(
458461
root=GetAuthenticatedExtendedCardSuccessResponse(

src/a2a/utils/signing.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import json
22

3-
from collections.abc import Callable
3+
from collections.abc import Awaitable, Callable
44
from typing import Any, TypedDict
55

66
from a2a.utils.helpers import canonicalize_agent_card
@@ -54,7 +54,7 @@ def create_agent_card_signer(
5454
signing_key: PyJWK | str | bytes,
5555
protected_header: ProtectedHeader,
5656
header: dict[str, Any] | None = None,
57-
) -> Callable[[AgentCard], AgentCard]:
57+
) -> Callable[[AgentCard], Awaitable[AgentCard]]:
5858
"""Creates a function that signs an AgentCard and adds the signature.
5959
6060
Args:
@@ -66,7 +66,7 @@ def create_agent_card_signer(
6666
A callable that takes an AgentCard and returns the modified AgentCard with a signature.
6767
"""
6868

69-
def agent_card_signer(agent_card: AgentCard) -> AgentCard:
69+
async def agent_card_signer(agent_card: AgentCard) -> AgentCard:
7070
"""Signs agent card."""
7171
canonical_payload = canonicalize_agent_card(agent_card)
7272
payload_dict = json.loads(canonical_payload)

tests/server/request_handlers/test_grpc_handler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ async def test_get_agent_card_with_modifier(
209209
) -> None:
210210
"""Test GetAgentCard call with a card_modifier."""
211211

212-
def modifier(card: types.AgentCard) -> types.AgentCard:
212+
async def modifier(card: types.AgentCard) -> types.AgentCard:
213213
modified_card = card.model_copy(deep=True)
214214
modified_card.name = 'Modified gRPC Agent'
215215
return modified_card

tests/server/request_handlers/test_jsonrpc_handler.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1295,7 +1295,9 @@ async def test_get_authenticated_extended_card_with_modifier(self) -> None:
12951295
skills=[],
12961296
)
12971297

1298-
def modifier(card: AgentCard, context: ServerCallContext) -> AgentCard:
1298+
async def modifier(
1299+
card: AgentCard, context: ServerCallContext
1300+
) -> AgentCard:
12991301
modified_card = card.model_copy(deep=True)
13001302
modified_card.name = 'Modified Card'
13011303
modified_card.description = (

0 commit comments

Comments
 (0)