Skip to content

Commit a941547

Browse files
committed
Add tests related to server side extensions.
1 parent 56cb1f6 commit a941547

4 files changed

Lines changed: 306 additions & 8 deletions

File tree

tests/extensions/test_common.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
from a2a.extensions.common import find_extension_by_uri
2+
from a2a.types import AgentCard, AgentExtension, AgentCapabilities
3+
4+
5+
def test_find_extension_by_uri():
6+
ext1 = AgentExtension(uri='foo', name='Foo', description='The Foo extension')
7+
ext2 = AgentExtension(uri='bar', name='Bar', description='The Bar extension')
8+
card = AgentCard(
9+
agent_id='test-agent',
10+
name='Test Agent',
11+
description='Test Agent Description',
12+
version='1.0',
13+
url='http://test.com',
14+
skills=[],
15+
defaultInputModes=['text/plain'],
16+
defaultOutputModes=['text/plain'],
17+
capabilities=AgentCapabilities(extensions=[ext1, ext2]),
18+
)
19+
20+
assert find_extension_by_uri(card, 'foo') == ext1
21+
assert find_extension_by_uri(card, 'bar') == ext2
22+
assert find_extension_by_uri(card, 'baz') is None
23+
24+
25+
def test_find_extension_by_uri_no_extensions():
26+
card = AgentCard(
27+
agent_id='test-agent',
28+
name='Test Agent',
29+
description='Test Agent Description',
30+
version='1.0',
31+
url='http://test.com',
32+
skills=[],
33+
defaultInputModes=['text/plain'],
34+
defaultOutputModes=['text/plain'],
35+
capabilities=AgentCapabilities(extensions=None),
36+
)
37+
38+
assert find_extension_by_uri(card, 'foo') is None

tests/server/agent_execution/test_context.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import pytest
66

77
from a2a.server.agent_execution import RequestContext
8+
from a2a.server.context import ServerCallContext
89
from a2a.types import (
910
Message,
1011
MessageSendParams,
@@ -262,3 +263,16 @@ def test_init_with_context_id_and_existing_context_id_match(
262263

263264
assert context.context_id == mock_task.contextId
264265
assert context.current_task == mock_task
266+
267+
def test_extension_handling(self):
268+
"""Test extension handling in RequestContext."""
269+
call_context = ServerCallContext(requested_extensions={'foo', 'bar'})
270+
context = RequestContext(call_context=call_context)
271+
272+
assert context.requested_extensions == {'foo', 'bar'}
273+
274+
context.add_activated_extension('foo')
275+
assert call_context.activated_extensions == {'foo'}
276+
277+
context.add_activated_extension('baz')
278+
assert call_context.activated_extensions == {'foo', 'baz'}

tests/server/apps/jsonrpc/test_jsonrpc_app.py

Lines changed: 164 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,36 @@
1-
from unittest.mock import MagicMock
1+
from unittest.mock import AsyncMock, MagicMock
22

33
import pytest
4-
4+
from starlette.applications import Starlette
5+
from starlette.testclient import TestClient
56

67
# Attempt to import StarletteBaseUser, fallback to MagicMock if not available
78
try:
89
from starlette.authentication import BaseUser as StarletteBaseUser
910
except ImportError:
1011
StarletteBaseUser = MagicMock() # type: ignore
1112

13+
from a2a.extensions.common import HTTP_EXTENSION_HEADER
1214
from a2a.server.apps.jsonrpc.jsonrpc_app import (
13-
JSONRPCApplication, # Still needed for JSONRPCApplication default constructor arg
15+
JSONRPCApplication,
1416
StarletteUserProxy,
1517
)
18+
from a2a.server.apps.jsonrpc.starlette_app import A2AStarletteApplication
19+
from a2a.server.context import ServerCallContext
1620
from a2a.server.request_handlers.request_handler import (
17-
RequestHandler, # For mock spec
21+
RequestHandler,
22+
) # For mock spec
23+
from a2a.types import (
24+
AgentCapabilities,
25+
AgentCard,
26+
Message,
27+
MessageSendParams,
28+
Role,
29+
SendMessageRequest,
30+
SendMessageResponse,
31+
SendMessageSuccessResponse,
32+
TextPart,
1833
)
19-
from a2a.types import AgentCard # For mock spec
20-
2134

2235
# --- StarletteUserProxy Tests ---
2336

@@ -69,6 +82,7 @@ def test_jsonrpc_app_build_method_abstract_raises_typeerror(
6982
mock_agent_card.url = 'http://mockurl.com'
7083
# Ensure 'supportsAuthenticatedExtendedCard' attribute exists
7184
mock_agent_card.supportsAuthenticatedExtendedCard = False
85+
mock_agent_card.capabilities = AgentCapabilities(streaming=True)
7286

7387
# This will fail at definition time if an abstract method is not implemented
7488
with pytest.raises(
@@ -86,5 +100,149 @@ def some_other_method(self):
86100
)
87101

88102

103+
class TestJSONRPCExtensions:
104+
@pytest.fixture
105+
def mock_handler(self):
106+
handler = AsyncMock(spec=RequestHandler)
107+
handler.on_message_send.return_value = SendMessageResponse(
108+
root=SendMessageSuccessResponse(
109+
id='1',
110+
result=Message(
111+
messageId='test',
112+
role=Role.agent,
113+
parts=[TextPart(text='response message')],
114+
),
115+
)
116+
)
117+
return handler
118+
119+
@pytest.fixture
120+
def test_app(self, mock_handler):
121+
mock_agent_card = MagicMock(spec=AgentCard)
122+
mock_agent_card.url = 'http://mockurl.com'
123+
mock_agent_card.supportsAuthenticatedExtendedCard = False
124+
125+
return A2AStarletteApplication(
126+
agent_card=mock_agent_card, http_handler=mock_handler
127+
)
128+
129+
@pytest.fixture
130+
def client(self, test_app):
131+
return TestClient(test_app.build())
132+
133+
def test_request_with_single_extension(self, client, mock_handler):
134+
headers = {HTTP_EXTENSION_HEADER: 'foo'}
135+
response = client.post(
136+
'/',
137+
headers=headers,
138+
json=SendMessageRequest(
139+
id='1',
140+
params=MessageSendParams(
141+
message=Message(
142+
messageId='1',
143+
role=Role.user,
144+
parts=[TextPart(text='hi')],
145+
)
146+
),
147+
).model_dump(),
148+
)
149+
response.raise_for_status()
150+
151+
mock_handler.on_message_send.assert_called_once()
152+
call_context = mock_handler.on_message_send.call_args[0][1]
153+
assert isinstance(call_context, ServerCallContext)
154+
assert call_context.requested_extensions == {'foo'}
155+
156+
def test_request_with_comma_separated_extensions(
157+
self, client, mock_handler
158+
):
159+
headers = {HTTP_EXTENSION_HEADER: 'foo, bar'}
160+
response = client.post(
161+
'/',
162+
headers=headers,
163+
json=SendMessageRequest(
164+
id='1',
165+
params=MessageSendParams(
166+
message=Message(
167+
messageId='1',
168+
role=Role.user,
169+
parts=[TextPart(text='hi')],
170+
)
171+
),
172+
).model_dump(),
173+
)
174+
response.raise_for_status()
175+
176+
mock_handler.on_message_send.assert_called_once()
177+
call_context = mock_handler.on_message_send.call_args[0][1]
178+
assert call_context.requested_extensions == {'foo', 'bar'}
179+
180+
def test_request_with_multiple_extension_headers(
181+
self, client, mock_handler
182+
):
183+
headers = [
184+
(HTTP_EXTENSION_HEADER, 'foo'),
185+
(HTTP_EXTENSION_HEADER, 'bar'),
186+
]
187+
response = client.post(
188+
'/',
189+
headers=headers,
190+
json=SendMessageRequest(
191+
id='1',
192+
params=MessageSendParams(
193+
message=Message(
194+
messageId='1',
195+
role=Role.user,
196+
parts=[TextPart(text='hi')],
197+
)
198+
),
199+
).model_dump(),
200+
)
201+
response.raise_for_status()
202+
203+
mock_handler.on_message_send.assert_called_once()
204+
call_context = mock_handler.on_message_send.call_args[0][1]
205+
assert call_context.requested_extensions == {'foo', 'bar'}
206+
207+
def test_response_with_activated_extensions(self, client, mock_handler):
208+
def side_effect(request, context: ServerCallContext):
209+
context.activated_extensions.add('foo')
210+
context.activated_extensions.add('baz')
211+
return SendMessageResponse(
212+
root=SendMessageSuccessResponse(
213+
id='1',
214+
result=Message(
215+
messageId='test',
216+
role=Role.agent,
217+
parts=[TextPart(text='response message')],
218+
),
219+
)
220+
)
221+
222+
mock_handler.on_message_send.side_effect = side_effect
223+
224+
response = client.post(
225+
'/',
226+
json=SendMessageRequest(
227+
id='1',
228+
params=MessageSendParams(
229+
message=Message(
230+
messageId='1',
231+
role=Role.user,
232+
parts=[TextPart(text='hi')],
233+
)
234+
),
235+
).model_dump(),
236+
)
237+
response.raise_for_status()
238+
239+
assert response.status_code == 200
240+
assert HTTP_EXTENSION_HEADER in response.headers
241+
assert set(response.headers[HTTP_EXTENSION_HEADER].split(', ')) == {
242+
'foo',
243+
'baz',
244+
}
245+
246+
89247
if __name__ == '__main__':
90248
pytest.main([__file__])

tests/server/request_handlers/test_grpc_handler.py

Lines changed: 90 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
1-
from unittest.mock import AsyncMock
1+
from unittest.mock import AsyncMock, MagicMock, patch
22

33
import grpc
44
import pytest
55

66
from a2a import types
7+
from a2a.extensions.common import HTTP_EXTENSION_HEADER
78
from a2a.grpc import a2a_pb2
9+
from a2a.server.context import ServerCallContext
810
from a2a.server.request_handlers import GrpcHandler, RequestHandler
911
from a2a.utils.errors import ServerError
1012

11-
1213
# --- Fixtures ---
1314

1415

@@ -21,6 +22,8 @@ def mock_request_handler() -> AsyncMock:
2122
def mock_grpc_context() -> AsyncMock:
2223
context = AsyncMock(spec=grpc.aio.ServicerContext)
2324
context.abort = AsyncMock()
25+
context.invocation_metadata = MagicMock()
26+
context.set_trailing_metadata = MagicMock()
2427
return context
2528

2629

@@ -279,3 +282,88 @@ async def test_abort_context_error_mapping(
279282
call_args, _ = mock_grpc_context.abort.call_args
280283
assert call_args[0] == grpc_status_code
281284
assert error_message_part in call_args[1]
285+
286+
287+
@pytest.mark.asyncio
288+
class TestGrpcExtensions:
289+
@patch(
290+
'a2a.server.request_handlers.grpc_handler.DefaultCallContextBuilder.build'
291+
)
292+
async def test_send_message_with_extensions(
293+
self,
294+
mock_build,
295+
grpc_handler: GrpcHandler,
296+
mock_request_handler: AsyncMock,
297+
mock_grpc_context: AsyncMock,
298+
):
299+
mock_build.return_value = ServerCallContext(
300+
requested_extensions={'foo', 'bar'}
301+
)
302+
303+
def side_effect(request, context: ServerCallContext):
304+
context.activated_extensions.add('foo')
305+
context.activated_extensions.add('baz')
306+
return types.Task(
307+
id='task-1',
308+
contextId='ctx-1',
309+
status=types.TaskStatus(state=types.TaskState.completed),
310+
)
311+
312+
mock_request_handler.on_message_send.side_effect = side_effect
313+
314+
await grpc_handler.SendMessage(
315+
a2a_pb2.SendMessageRequest(), mock_grpc_context
316+
)
317+
318+
mock_request_handler.on_message_send.assert_awaited_once()
319+
call_context = mock_request_handler.on_message_send.call_args[0][1]
320+
assert isinstance(call_context, ServerCallContext)
321+
assert call_context.requested_extensions == {'foo', 'bar'}
322+
323+
mock_grpc_context.set_trailing_metadata.assert_called_once_with(
324+
[(HTTP_EXTENSION_HEADER, 'foo'), (HTTP_EXTENSION_HEADER, 'baz')]
325+
)
326+
327+
@patch(
328+
'a2a.server.request_handlers.grpc_handler.DefaultCallContextBuilder.build'
329+
)
330+
async def test_send_streaming_message_with_extensions(
331+
self,
332+
mock_build,
333+
grpc_handler: GrpcHandler,
334+
mock_request_handler: AsyncMock,
335+
mock_grpc_context: AsyncMock,
336+
):
337+
mock_build.return_value = ServerCallContext(
338+
requested_extensions={'foo', 'bar'}
339+
)
340+
341+
async def side_effect(request, context: ServerCallContext):
342+
context.activated_extensions.add('foo')
343+
context.activated_extensions.add('baz')
344+
yield types.Task(
345+
id='task-1',
346+
contextId='ctx-1',
347+
status=types.TaskStatus(state=types.TaskState.working),
348+
)
349+
350+
mock_request_handler.on_message_send_stream.side_effect = side_effect
351+
352+
results = [
353+
result
354+
async for result in grpc_handler.SendStreamingMessage(
355+
a2a_pb2.SendMessageRequest(), mock_grpc_context
356+
)
357+
]
358+
assert results
359+
360+
mock_request_handler.on_message_send_stream.assert_called_once()
361+
call_context = mock_request_handler.on_message_send_stream.call_args[0][
362+
1
363+
]
364+
assert isinstance(call_context, ServerCallContext)
365+
assert call_context.requested_extensions == {'foo', 'bar'}
366+
367+
mock_grpc_context.set_trailing_metadata.assert_called_once_with(
368+
[(HTTP_EXTENSION_HEADER, 'foo'), (HTTP_EXTENSION_HEADER, 'baz')]
369+
)

0 commit comments

Comments
 (0)