Skip to content

Commit d543213

Browse files
committed
make card modifier and extended card modifier async
1 parent ffe31e2 commit d543213

12 files changed

Lines changed: 269 additions & 47 deletions

File tree

src/a2a/server/apps/jsonrpc/fastapi_app.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import logging
22

3-
from collections.abc import Callable
3+
from collections.abc import Awaitable, Callable
44
from typing import TYPE_CHECKING, Any
55

66

@@ -72,9 +72,10 @@ def __init__( # noqa: PLR0913
7272
http_handler: RequestHandler,
7373
extended_agent_card: AgentCard | None = None,
7474
context_builder: CallContextBuilder | None = None,
75-
card_modifier: Callable[[AgentCard], AgentCard] | None = None,
75+
card_modifier: Callable[[AgentCard], Awaitable[AgentCard]]
76+
| None = None,
7677
extended_card_modifier: Callable[
77-
[AgentCard, ServerCallContext], AgentCard
78+
[AgentCard, ServerCallContext], Awaitable[AgentCard]
7879
]
7980
| None = None,
8081
max_content_length: int | None = 10 * 1024 * 1024, # 10MB

src/a2a/server/apps/jsonrpc/jsonrpc_app.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
import traceback
55

66
from abc import ABC, abstractmethod
7-
from collections.abc import AsyncGenerator, Callable
7+
from collections.abc import AsyncGenerator, Awaitable, Callable
8+
from inspect import iscoroutinefunction
89
from typing import TYPE_CHECKING, Any
910

1011
from pydantic import ValidationError
@@ -178,10 +179,13 @@ def __init__( # noqa: PLR0913
178179
http_handler: RequestHandler,
179180
extended_agent_card: AgentCard | None = None,
180181
context_builder: CallContextBuilder | None = None,
181-
card_modifier: Callable[[AgentCard], AgentCard] | None = None,
182+
card_modifier: Callable[[AgentCard], Awaitable[AgentCard]]
183+
| Callable[[AgentCard], AgentCard]
184+
| None = None,
182185
extended_card_modifier: Callable[
183-
[AgentCard, ServerCallContext], AgentCard
186+
[AgentCard, ServerCallContext], Awaitable[AgentCard]
184187
]
188+
| Callable[[AgentCard, ServerCallContext], AgentCard]
185189
| None = None,
186190
max_content_length: int | None = 10 * 1024 * 1024, # 10MB
187191
) -> None:
@@ -576,7 +580,10 @@ async def _handle_get_agent_card(self, request: Request) -> JSONResponse:
576580

577581
card_to_serve = self.agent_card
578582
if self.card_modifier:
579-
card_to_serve = self.card_modifier(card_to_serve)
583+
if iscoroutinefunction(self.card_modifier):
584+
card_to_serve = await self.card_modifier(card_to_serve)
585+
else:
586+
card_to_serve = self.card_modifier(card_to_serve)
580587

581588
return JSONResponse(
582589
card_to_serve.model_dump(
@@ -605,7 +612,12 @@ async def _handle_get_authenticated_extended_agent_card(
605612
context = self._context_builder.build(request)
606613
# If no base extended card is provided, pass the public card to the modifier
607614
base_card = card_to_serve if card_to_serve else self.agent_card
608-
card_to_serve = self.extended_card_modifier(base_card, context)
615+
if iscoroutinefunction(self.extended_card_modifier):
616+
card_to_serve = await self.extended_card_modifier(
617+
base_card, context
618+
)
619+
else:
620+
card_to_serve = self.extended_card_modifier(base_card, context)
609621

610622
if card_to_serve:
611623
return JSONResponse(

src/a2a/server/apps/jsonrpc/starlette_app.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import logging
22

3-
from collections.abc import Callable
3+
from collections.abc import Awaitable, Callable
44
from typing import TYPE_CHECKING, Any
55

66

@@ -54,9 +54,10 @@ def __init__( # noqa: PLR0913
5454
http_handler: RequestHandler,
5555
extended_agent_card: AgentCard | None = None,
5656
context_builder: CallContextBuilder | None = None,
57-
card_modifier: Callable[[AgentCard], AgentCard] | None = None,
57+
card_modifier: Callable[[AgentCard], Awaitable[AgentCard]]
58+
| None = None,
5859
extended_card_modifier: Callable[
59-
[AgentCard, ServerCallContext], AgentCard
60+
[AgentCard, ServerCallContext], Awaitable[AgentCard]
6061
]
6162
| None = None,
6263
max_content_length: int | None = 10 * 1024 * 1024, # 10MB

src/a2a/server/apps/rest/fastapi_app.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import logging
22

3-
from collections.abc import Callable
3+
from collections.abc import Awaitable, Callable
44
from typing import TYPE_CHECKING, Any
55

66

@@ -49,9 +49,10 @@ def __init__( # noqa: PLR0913
4949
http_handler: RequestHandler,
5050
extended_agent_card: AgentCard | None = None,
5151
context_builder: CallContextBuilder | None = None,
52-
card_modifier: Callable[[AgentCard], AgentCard] | None = None,
52+
card_modifier: Callable[[AgentCard], Awaitable[AgentCard]]
53+
| None = None,
5354
extended_card_modifier: Callable[
54-
[AgentCard, ServerCallContext], AgentCard
55+
[AgentCard, ServerCallContext], Awaitable[AgentCard]
5556
]
5657
| None = None,
5758
):

src/a2a/server/apps/rest/rest_adapter.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import logging
33

44
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable
5+
from inspect import iscoroutinefunction
56
from typing import TYPE_CHECKING, Any
67

78

@@ -58,10 +59,13 @@ def __init__( # noqa: PLR0913
5859
http_handler: RequestHandler,
5960
extended_agent_card: AgentCard | None = None,
6061
context_builder: CallContextBuilder | None = None,
61-
card_modifier: Callable[[AgentCard], AgentCard] | None = None,
62+
card_modifier: Callable[[AgentCard], Awaitable[AgentCard]]
63+
| Callable[[AgentCard], AgentCard]
64+
| None = None,
6265
extended_card_modifier: Callable[
63-
[AgentCard, ServerCallContext], AgentCard
66+
[AgentCard, ServerCallContext], Awaitable[AgentCard]
6467
]
68+
| Callable[[AgentCard, ServerCallContext], AgentCard]
6569
| None = None,
6670
):
6771
"""Initializes the RESTApplication.
@@ -150,7 +154,10 @@ async def handle_get_agent_card(
150154
"""
151155
card_to_serve = self.agent_card
152156
if self.card_modifier:
153-
card_to_serve = self.card_modifier(card_to_serve)
157+
if iscoroutinefunction(self.card_modifier):
158+
card_to_serve = await self.card_modifier(card_to_serve)
159+
else:
160+
card_to_serve = self.card_modifier(card_to_serve)
154161

155162
return card_to_serve.model_dump(mode='json', exclude_none=True)
156163

@@ -182,9 +189,19 @@ async def handle_authenticated_agent_card(
182189

183190
if self.extended_card_modifier:
184191
context = self._context_builder.build(request)
185-
card_to_serve = self.extended_card_modifier(card_to_serve, context)
192+
if iscoroutinefunction(self.extended_card_modifier):
193+
card_to_serve = await self.extended_card_modifier(
194+
card_to_serve, context
195+
)
196+
else:
197+
card_to_serve = self.extended_card_modifier(
198+
card_to_serve, context
199+
)
186200
elif self.card_modifier:
187-
card_to_serve = self.card_modifier(card_to_serve)
201+
if iscoroutinefunction(self.extended_card_modifier):
202+
card_to_serve = await self.card_modifier(card_to_serve)
203+
else:
204+
card_to_serve = self.card_modifier(card_to_serve)
188205

189206
return card_to_serve.model_dump(mode='json', exclude_none=True)
190207

src/a2a/server/request_handlers/grpc_handler.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
import logging
44

55
from abc import ABC, abstractmethod
6-
from collections.abc import AsyncIterable, Sequence
6+
from collections.abc import AsyncIterable, Awaitable, Sequence
7+
from inspect import iscoroutinefunction
78

89

910
try:
@@ -89,7 +90,9 @@ def __init__(
8990
agent_card: AgentCard,
9091
request_handler: RequestHandler,
9192
context_builder: CallContextBuilder | None = None,
92-
card_modifier: Callable[[AgentCard], AgentCard] | None = None,
93+
card_modifier: Callable[[AgentCard], Awaitable[AgentCard]]
94+
| Callable[[AgentCard], AgentCard]
95+
| None = None,
9396
):
9497
"""Initializes the GrpcHandler.
9598
@@ -339,7 +342,10 @@ async def GetAgentCard(
339342
"""Get the agent card for the agent served."""
340343
card_to_serve = self.agent_card
341344
if self.card_modifier:
342-
card_to_serve = self.card_modifier(card_to_serve)
345+
if iscoroutinefunction(self.card_modifier):
346+
card_to_serve = await self.card_modifier(card_to_serve)
347+
else:
348+
card_to_serve = self.card_modifier(card_to_serve)
343349
return proto_utils.ToProto.agent_card(card_to_serve)
344350

345351
async def abort_context(

src/a2a/server/request_handlers/jsonrpc_handler.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import logging
22

3-
from collections.abc import AsyncIterable, Callable
3+
from collections.abc import AsyncIterable, Awaitable, Callable
4+
from inspect import iscoroutinefunction
45

56
from a2a.server.context import ServerCallContext
67
from a2a.server.request_handlers.request_handler import RequestHandler
@@ -63,10 +64,13 @@ def __init__(
6364
request_handler: RequestHandler,
6465
extended_agent_card: AgentCard | None = None,
6566
extended_card_modifier: Callable[
66-
[AgentCard, ServerCallContext], AgentCard
67+
[AgentCard, ServerCallContext], Awaitable[AgentCard]
6768
]
69+
| Callable[[AgentCard, ServerCallContext], AgentCard]
70+
| None = None,
71+
card_modifier: Callable[[AgentCard], Awaitable[AgentCard]]
72+
| Callable[[AgentCard], AgentCard]
6873
| None = None,
69-
card_modifier: Callable[[AgentCard], AgentCard] | None = None,
7074
):
7175
"""Initializes the JSONRPCHandler.
7276
@@ -450,9 +454,17 @@ async def get_authenticated_extended_card(
450454

451455
card_to_serve = base_card
452456
if self.extended_card_modifier and context:
453-
card_to_serve = self.extended_card_modifier(base_card, context)
457+
if iscoroutinefunction(self.extended_card_modifier):
458+
card_to_serve = await self.extended_card_modifier(
459+
base_card, context
460+
)
461+
else:
462+
card_to_serve = self.extended_card_modifier(base_card, context)
454463
elif self.card_modifier:
455-
card_to_serve = self.card_modifier(base_card)
464+
if iscoroutinefunction(self.card_modifier):
465+
card_to_serve = await self.card_modifier(base_card)
466+
else:
467+
card_to_serve = self.card_modifier(base_card)
456468

457469
return GetAuthenticatedExtendedCardResponse(
458470
root=GetAuthenticatedExtendedCardSuccessResponse(

src/a2a/utils/signing.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import json
22

3-
from collections.abc import Callable
3+
from collections.abc import Awaitable, Callable
44
from typing import Any, TypedDict
55

66
from a2a.utils.helpers import canonicalize_agent_card
@@ -54,7 +54,7 @@ def create_agent_card_signer(
5454
signing_key: PyJWK | str | bytes,
5555
protected_header: ProtectedHeader,
5656
header: dict[str, Any] | None = None,
57-
) -> Callable[[AgentCard], AgentCard]:
57+
) -> Callable[[AgentCard], Awaitable[AgentCard]]:
5858
"""Creates a function that signs an AgentCard and adds the signature.
5959
6060
Args:
@@ -66,7 +66,7 @@ def create_agent_card_signer(
6666
A callable that takes an AgentCard and returns the modified AgentCard with a signature.
6767
"""
6868

69-
def agent_card_signer(agent_card: AgentCard) -> AgentCard:
69+
async def agent_card_signer(agent_card: AgentCard) -> AgentCard:
7070
"""Signs agent card."""
7171
canonical_payload = canonicalize_agent_card(agent_card)
7272
payload_dict = json.loads(canonical_payload)

tests/server/request_handlers/test_grpc_handler.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,34 @@ async def test_get_agent_card_with_modifier(
209209
) -> None:
210210
"""Test GetAgentCard call with a card_modifier."""
211211

212+
async def modifier(card: types.AgentCard) -> types.AgentCard:
213+
modified_card = card.model_copy(deep=True)
214+
modified_card.name = 'Modified gRPC Agent'
215+
return modified_card
216+
217+
grpc_handler_modified = GrpcHandler(
218+
agent_card=sample_agent_card,
219+
request_handler=mock_request_handler,
220+
card_modifier=modifier,
221+
)
222+
223+
request_proto = a2a_pb2.GetAgentCardRequest()
224+
response = await grpc_handler_modified.GetAgentCard(
225+
request_proto, mock_grpc_context
226+
)
227+
228+
assert response.name == 'Modified gRPC Agent'
229+
assert response.version == sample_agent_card.version
230+
231+
232+
@pytest.mark.asyncio
233+
async def test_get_agent_card_with_modifier_sync(
234+
mock_request_handler: AsyncMock,
235+
sample_agent_card: types.AgentCard,
236+
mock_grpc_context: AsyncMock,
237+
) -> None:
238+
"""Test GetAgentCard call with a synchronous card_modifier."""
239+
212240
def modifier(card: types.AgentCard) -> types.AgentCard:
213241
modified_card = card.model_copy(deep=True)
214242
modified_card.name = 'Modified gRPC Agent'

tests/server/request_handlers/test_jsonrpc_handler.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1295,6 +1295,57 @@ async def test_get_authenticated_extended_card_with_modifier(self) -> None:
12951295
skills=[],
12961296
)
12971297

1298+
async def modifier(
1299+
card: AgentCard, context: ServerCallContext
1300+
) -> AgentCard:
1301+
modified_card = card.model_copy(deep=True)
1302+
modified_card.name = 'Modified Card'
1303+
modified_card.description = (
1304+
f'Modified for context: {context.state.get("foo")}'
1305+
)
1306+
return modified_card
1307+
1308+
handler = JSONRPCHandler(
1309+
self.mock_agent_card,
1310+
mock_request_handler,
1311+
extended_agent_card=mock_base_card,
1312+
extended_card_modifier=modifier,
1313+
)
1314+
request = GetAuthenticatedExtendedCardRequest(id='ext-card-req-mod')
1315+
call_context = ServerCallContext(state={'foo': 'bar'})
1316+
1317+
# Act
1318+
response: GetAuthenticatedExtendedCardResponse = (
1319+
await handler.get_authenticated_extended_card(request, call_context)
1320+
)
1321+
1322+
# Assert
1323+
self.assertIsInstance(
1324+
response.root, GetAuthenticatedExtendedCardSuccessResponse
1325+
)
1326+
self.assertEqual(response.root.id, 'ext-card-req-mod')
1327+
modified_card = response.root.result
1328+
self.assertEqual(modified_card.name, 'Modified Card')
1329+
self.assertEqual(modified_card.description, 'Modified for context: bar')
1330+
self.assertEqual(modified_card.version, '1.0')
1331+
1332+
async def test_get_authenticated_extended_card_with_modifier_sync(
1333+
self,
1334+
) -> None:
1335+
"""Test successful retrieval of a synchronously dynamically modified extended agent card."""
1336+
# Arrange
1337+
mock_request_handler = AsyncMock(spec=DefaultRequestHandler)
1338+
mock_base_card = AgentCard(
1339+
name='Base Card',
1340+
description='Base details',
1341+
url='http://agent.example.com/api',
1342+
version='1.0',
1343+
capabilities=AgentCapabilities(),
1344+
default_input_modes=['text/plain'],
1345+
default_output_modes=['application/json'],
1346+
skills=[],
1347+
)
1348+
12981349
def modifier(card: AgentCard, context: ServerCallContext) -> AgentCard:
12991350
modified_card = card.model_copy(deep=True)
13001351
modified_card.name = 'Modified Card'

0 commit comments

Comments
 (0)