Skip to content

Commit b141ecf

Browse files
committed
Refactor to use common header value splitter
1 parent 75849c0 commit b141ecf

6 files changed

Lines changed: 125 additions & 36 deletions

File tree

src/a2a/extensions/common.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,20 @@
44
HTTP_EXTENSION_HEADER = 'X-A2A-Extensions'
55

66

7+
def get_requested_extensions(values: list[str]) -> set[str]:
8+
"""Get the set of requested extensions from an input list.
9+
10+
This handles the list containing potentially comma-separated values, as
11+
occurs when using a list in an HTTP header.
12+
"""
13+
return {
14+
stripped
15+
for v in values
16+
for ext in v.split(',')
17+
if (stripped := ext.strip())
18+
}
19+
20+
721
def find_extension_by_uri(card: AgentCard, uri: str) -> AgentExtension | None:
822
"""Find an AgentExtension in an AgentCard given a uri."""
923
for ext in card.capabilities.extensions or []:

src/a2a/server/apps/jsonrpc/jsonrpc_app.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,10 @@
1919

2020
from a2a.auth.user import UnauthenticatedUser
2121
from a2a.auth.user import User as A2AUser
22-
from a2a.extensions.common import HTTP_EXTENSION_HEADER
22+
from a2a.extensions.common import (
23+
HTTP_EXTENSION_HEADER,
24+
get_requested_extensions,
25+
)
2326
from a2a.server.context import ServerCallContext
2427
from a2a.server.request_handlers.jsonrpc_handler import JSONRPCHandler
2528
from a2a.server.request_handlers.request_handler import RequestHandler
@@ -103,12 +106,9 @@ def build(self, request: Request) -> ServerCallContext:
103106
return ServerCallContext(
104107
user=user,
105108
state=state,
106-
requested_extensions={
107-
stripped
108-
for h in request.headers.getlist(HTTP_EXTENSION_HEADER)
109-
for ext in h.split(',')
110-
if (stripped := ext.strip())
111-
},
109+
requested_extensions=get_requested_extensions(
110+
request.headers.getlist(HTTP_EXTENSION_HEADER)
111+
),
112112
)
113113

114114

src/a2a/server/request_handlers/grpc_handler.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,10 @@
2222

2323
from a2a import types
2424
from a2a.auth.user import UnauthenticatedUser
25-
from a2a.extensions.common import HTTP_EXTENSION_HEADER
25+
from a2a.extensions.common import (
26+
HTTP_EXTENSION_HEADER,
27+
get_requested_extensions,
28+
)
2629
from a2a.grpc import a2a_pb2
2730
from a2a.server.context import ServerCallContext
2831
from a2a.server.request_handlers.request_handler import RequestHandler
@@ -76,7 +79,7 @@ def build(self, context: grpc.aio.ServicerContext) -> ServerCallContext:
7679
return ServerCallContext(
7780
user=user,
7881
state=state,
79-
requested_extensions=set(
82+
requested_extensions=get_requested_extensions(
8083
_get_metadata_value(context, HTTP_EXTENSION_HEADER)
8184
),
8285
)

tests/extensions/test_common.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,33 @@
1-
from a2a.extensions.common import find_extension_by_uri
2-
from a2a.types import AgentCard, AgentExtension, AgentCapabilities
1+
from a2a.extensions.common import (
2+
find_extension_by_uri,
3+
get_requested_extensions,
4+
)
5+
from a2a.types import AgentCapabilities, AgentCard, AgentExtension
6+
7+
8+
def test_get_requested_extensions():
9+
assert get_requested_extensions([]) == set()
10+
assert get_requested_extensions(['foo']) == {'foo'}
11+
assert get_requested_extensions(['foo', 'bar']) == {'foo', 'bar'}
12+
assert get_requested_extensions(['foo, bar']) == {'foo', 'bar'}
13+
assert get_requested_extensions(['foo,bar']) == {'foo', 'bar'}
14+
assert get_requested_extensions(['foo', 'bar,baz']) == {'foo', 'bar', 'baz'}
15+
assert get_requested_extensions(['foo,, bar', 'baz']) == {
16+
'foo',
17+
'bar',
18+
'baz',
19+
}
20+
assert get_requested_extensions([' foo , bar ', 'baz']) == {
21+
'foo',
22+
'bar',
23+
'baz',
24+
}
325

426

527
def test_find_extension_by_uri():
6-
ext1 = AgentExtension(
7-
uri='foo', name='Foo', description='The Foo extension'
8-
)
9-
ext2 = AgentExtension(
10-
uri='bar', name='Bar', description='The Bar extension'
11-
)
28+
ext1 = AgentExtension(uri='foo', description='The Foo extension')
29+
ext2 = AgentExtension(uri='bar', description='The Bar extension')
1230
card = AgentCard(
13-
agent_id='test-agent',
1431
name='Test Agent',
1532
description='Test Agent Description',
1633
version='1.0',
@@ -28,7 +45,6 @@ def test_find_extension_by_uri():
2845

2946
def test_find_extension_by_uri_no_extensions():
3047
card = AgentCard(
31-
agent_id='test-agent',
3248
name='Test Agent',
3349
description='Test Agent Description',
3450
version='1.0',

tests/server/apps/jsonrpc/test_jsonrpc_app.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,33 @@ def test_request_with_comma_separated_extensions(
177177
call_context = mock_handler.on_message_send.call_args[0][1]
178178
assert call_context.requested_extensions == {'foo', 'bar'}
179179

180+
def test_request_with_comma_separated_extensions_no_space(
181+
self, client, mock_handler
182+
):
183+
headers = [
184+
(HTTP_EXTENSION_HEADER, 'foo, bar'),
185+
(HTTP_EXTENSION_HEADER, 'baz'),
186+
]
187+
response = client.post(
188+
'/',
189+
headers=headers,
190+
json=SendMessageRequest(
191+
id='1',
192+
params=MessageSendParams(
193+
message=Message(
194+
messageId='1',
195+
role=Role.user,
196+
parts=[TextPart(text='hi')],
197+
)
198+
),
199+
).model_dump(),
200+
)
201+
response.raise_for_status()
202+
203+
mock_handler.on_message_send.assert_called_once()
204+
call_context = mock_handler.on_message_send.call_args[0][1]
205+
assert call_context.requested_extensions == {'foo', 'bar', 'baz'}
206+
180207
def test_request_with_multiple_extension_headers(
181208
self, client, mock_handler
182209
):

tests/server/request_handlers/test_grpc_handler.py

Lines changed: 46 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from unittest.mock import AsyncMock, MagicMock, patch
22

33
import grpc
4+
import grpc.aio
45
import pytest
56

67
from a2a import types
@@ -22,7 +23,6 @@ def mock_request_handler() -> AsyncMock:
2223
def mock_grpc_context() -> AsyncMock:
2324
context = AsyncMock(spec=grpc.aio.ServicerContext)
2425
context.abort = AsyncMock()
25-
context.invocation_metadata = MagicMock()
2626
context.set_trailing_metadata = MagicMock()
2727
return context
2828

@@ -286,18 +286,15 @@ async def test_abort_context_error_mapping(
286286

287287
@pytest.mark.asyncio
288288
class TestGrpcExtensions:
289-
@patch(
290-
'a2a.server.request_handlers.grpc_handler.DefaultCallContextBuilder.build'
291-
)
292289
async def test_send_message_with_extensions(
293290
self,
294-
mock_build,
295291
grpc_handler: GrpcHandler,
296292
mock_request_handler: AsyncMock,
297293
mock_grpc_context: AsyncMock,
298294
):
299-
mock_build.return_value = ServerCallContext(
300-
requested_extensions={'foo', 'bar'}
295+
mock_grpc_context.invocation_metadata = grpc.aio.Metadata(
296+
(HTTP_EXTENSION_HEADER, 'foo'),
297+
(HTTP_EXTENSION_HEADER, 'bar'),
301298
)
302299

303300
def side_effect(request, context: ServerCallContext):
@@ -321,21 +318,48 @@ def side_effect(request, context: ServerCallContext):
321318
assert call_context.requested_extensions == {'foo', 'bar'}
322319

323320
mock_grpc_context.set_trailing_metadata.assert_called_once()
324-
called_metadata = mock_grpc_context.set_trailing_metadata.call_args.args[0]
325-
assert set(called_metadata) == {(HTTP_EXTENSION_HEADER, 'foo'), (HTTP_EXTENSION_HEADER, 'baz')}
321+
called_metadata = (
322+
mock_grpc_context.set_trailing_metadata.call_args.args[0]
323+
)
324+
assert set(called_metadata) == {
325+
(HTTP_EXTENSION_HEADER, 'foo'),
326+
(HTTP_EXTENSION_HEADER, 'baz'),
327+
}
328+
329+
async def test_send_message_with_comma_separated_extensions(
330+
self,
331+
grpc_handler: GrpcHandler,
332+
mock_request_handler: AsyncMock,
333+
mock_grpc_context: AsyncMock,
334+
):
335+
mock_grpc_context.invocation_metadata = grpc.aio.Metadata(
336+
(HTTP_EXTENSION_HEADER, 'foo ,, bar,'),
337+
(HTTP_EXTENSION_HEADER, 'baz , bar'),
338+
)
339+
mock_request_handler.on_message_send.return_value = types.Message(
340+
messageId='1',
341+
role=types.Role.agent,
342+
parts=[types.TextPart(text='test')],
343+
)
344+
345+
await grpc_handler.SendMessage(
346+
a2a_pb2.SendMessageRequest(), mock_grpc_context
347+
)
348+
349+
mock_request_handler.on_message_send.assert_awaited_once()
350+
call_context = mock_request_handler.on_message_send.call_args[0][1]
351+
assert isinstance(call_context, ServerCallContext)
352+
assert call_context.requested_extensions == {'foo', 'bar', 'baz'}
326353

327-
@patch(
328-
'a2a.server.request_handlers.grpc_handler.DefaultCallContextBuilder.build'
329-
)
330354
async def test_send_streaming_message_with_extensions(
331355
self,
332-
mock_build,
333356
grpc_handler: GrpcHandler,
334357
mock_request_handler: AsyncMock,
335358
mock_grpc_context: AsyncMock,
336359
):
337-
mock_build.return_value = ServerCallContext(
338-
requested_extensions={'foo', 'bar'}
360+
mock_grpc_context.invocation_metadata = grpc.aio.Metadata(
361+
(HTTP_EXTENSION_HEADER, 'foo'),
362+
(HTTP_EXTENSION_HEADER, 'bar'),
339363
)
340364

341365
async def side_effect(request, context: ServerCallContext):
@@ -365,5 +389,10 @@ async def side_effect(request, context: ServerCallContext):
365389
assert call_context.requested_extensions == {'foo', 'bar'}
366390

367391
mock_grpc_context.set_trailing_metadata.assert_called_once()
368-
called_metadata = mock_grpc_context.set_trailing_metadata.call_args.args[0]
369-
assert set(called_metadata) == {(HTTP_EXTENSION_HEADER, 'foo'), (HTTP_EXTENSION_HEADER, 'baz')}
392+
called_metadata = (
393+
mock_grpc_context.set_trailing_metadata.call_args.args[0]
394+
)
395+
assert set(called_metadata) == {
396+
(HTTP_EXTENSION_HEADER, 'foo'),
397+
(HTTP_EXTENSION_HEADER, 'baz'),
398+
}

0 commit comments

Comments
 (0)