-
Notifications
You must be signed in to change notification settings - Fork 429
Expand file tree
/
Copy pathtest_rest_client.py
More file actions
329 lines (285 loc) · 11 KB
/
test_rest_client.py
File metadata and controls
329 lines (285 loc) · 11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
from collections.abc import AsyncGenerator
from unittest.mock import AsyncMock, MagicMock, patch
import httpx
import pytest
import respx
from google.protobuf.json_format import MessageToJson
from httpx_sse import EventSource, ServerSentEvent
from a2a.client import create_text_message_object
from a2a.client.errors import A2AClientHTTPError
from a2a.client.transports.rest import RestTransport
from a2a.extensions.common import HTTP_EXTENSION_HEADER
from a2a.grpc import a2a_pb2
from a2a.types import (
AgentCapabilities,
AgentCard,
MessageSendParams,
Role,
)
from a2a.utils import proto_utils
@pytest.fixture
def mock_httpx_client() -> AsyncMock:
    """Provide an ``AsyncMock`` constrained to the ``httpx.AsyncClient`` API."""
    client_double = AsyncMock(spec=httpx.AsyncClient)
    return client_double
@pytest.fixture
def mock_agent_card() -> MagicMock:
    """Provide an ``AgentCard``-spec'd mock with extended-card support off.

    The mock exposes a fixed ``url`` and reports that it does not support an
    authenticated extended card.
    """
    card = MagicMock(spec=AgentCard)
    card.url = 'http://agent.example.com/api'
    card.supports_authenticated_extended_card = False
    return card
async def async_iterable_from_list(
    items: list[ServerSentEvent],
) -> AsyncGenerator[ServerSentEvent, None]:
    """Yield each element of ``items`` in order as an async generator.

    Lets a plain list stand in for an asynchronous event stream in tests.
    """
    for event in items:
        yield event
def _assert_extensions_header(mock_kwargs: dict, expected_extensions: set[str]):
    """Assert that captured request kwargs carry exactly the given extensions.

    Args:
        mock_kwargs: Keyword arguments captured from a mocked HTTP call; the
            ``headers`` entry is inspected (a missing entry counts as empty).
        expected_extensions: The exact set of extension URIs that the
            comma-separated extension header must contain.
    """
    request_headers = mock_kwargs.get('headers', {})
    assert HTTP_EXTENSION_HEADER in request_headers
    # The header is a comma-separated list; surrounding whitespace and the
    # order of entries are not significant for the comparison.
    raw_entries = request_headers[HTTP_EXTENSION_HEADER].split(',')
    assert set(map(str.strip, raw_entries)) == expected_extensions
class TestRestTransportExtensions:
    """Tests for ``RestTransport`` headers, streaming and card retrieval.

    Most tests verify that extension URIs configured on the transport, or
    passed to an individual call, are sent in the ``HTTP_EXTENSION_HEADER``
    request header. Two streaming tests cover SSE comment handling and HTTP
    error reporting.
    """

    @pytest.mark.asyncio
    async def test_send_message_with_default_extensions(
        self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock
    ):
        """Test that send_message adds extensions to headers."""
        extensions = [
            'https://example.com/test-ext/v1',
            'https://example.com/test-ext/v2',
        ]
        client = RestTransport(
            httpx_client=mock_httpx_client,
            extensions=extensions,
            agent_card=mock_agent_card,
        )
        params = MessageSendParams(
            message=create_text_message_object(content='Hello')
        )

        # Mock the build_request method to capture its inputs
        mock_build_request = MagicMock(
            return_value=AsyncMock(spec=httpx.Request)
        )
        mock_httpx_client.build_request = mock_build_request

        # Mock the send method
        mock_response = AsyncMock(spec=httpx.Response)
        mock_response.status_code = 200
        mock_httpx_client.send.return_value = mock_response

        await client.send_message(request=params)

        mock_build_request.assert_called_once()
        _, kwargs = mock_build_request.call_args
        _assert_extensions_header(
            kwargs,
            {
                'https://example.com/test-ext/v1',
                'https://example.com/test-ext/v2',
            },
        )

    # Repro of https://github.com/a2aproject/a2a-python/issues/540
    @pytest.mark.asyncio
    @respx.mock
    async def test_send_message_streaming_comment_success(
        self,
        mock_agent_card: MagicMock,
    ):
        """Test streaming when SSE comment lines are interleaved with data.

        The mocked response body contains two data events, each followed by
        a ``: keep-alive`` comment; only the two data events must be yielded.
        """
        params = MessageSendParams(
            message=create_text_message_object(content='Hello stream')
        )
        mock_stream_response_1 = a2a_pb2.StreamResponse(
            msg=proto_utils.ToProto.message(
                create_text_message_object(
                    content='First part', role=Role.agent
                )
            )
        )
        mock_stream_response_2 = a2a_pb2.StreamResponse(
            msg=proto_utils.ToProto.message(
                create_text_message_object(
                    content='Second part', role=Role.agent
                )
            )
        )
        # ``indent=None`` keeps each JSON payload on one line, as required
        # for a single SSE ``data:`` field.
        sse_content = (
            'id: stream_id_1\n'
            f'data: {MessageToJson(mock_stream_response_1, indent=None)}\n\n'
            ': keep-alive\n\n'
            'id: stream_id_2\n'
            f'data: {MessageToJson(mock_stream_response_2, indent=None)}\n\n'
            ': keep-alive\n\n'
        )

        async with httpx.AsyncClient() as client:
            transport = RestTransport(
                httpx_client=client, agent_card=mock_agent_card
            )
            respx.post(
                f'{mock_agent_card.url.rstrip("/")}/v1/message:stream'
            ).mock(
                return_value=httpx.Response(
                    200,
                    headers={'Content-Type': 'text/event-stream'},
                    content=sse_content,
                )
            )
            results = [
                item
                async for item in transport.send_message_streaming(
                    request=params
                )
            ]

        assert len(results) == 2
        assert results[0].parts[0].root.text == 'First part'
        assert results[1].parts[0].root.text == 'Second part'

    @pytest.mark.asyncio
    @patch('a2a.client.transports.rest.aconnect_sse')
    async def test_send_message_streaming_with_new_extensions(
        self,
        mock_aconnect_sse: AsyncMock,
        mock_httpx_client: AsyncMock,
        mock_agent_card: MagicMock,
    ):
        """Test X-A2A-Extensions header in send_message_streaming.

        Extensions passed to the call are expected in the header instead of
        the ones the transport was constructed with.
        """
        new_extensions = ['https://example.com/test-ext/v2']
        extensions = ['https://example.com/test-ext/v1']
        client = RestTransport(
            httpx_client=mock_httpx_client,
            agent_card=mock_agent_card,
            extensions=extensions,
        )
        params = MessageSendParams(
            message=create_text_message_object(content='Hello stream')
        )

        # An empty event stream is enough: only the call arguments matter.
        mock_event_source = AsyncMock(spec=EventSource)
        mock_event_source.aiter_sse.return_value = async_iterable_from_list([])
        mock_aconnect_sse.return_value.__aenter__.return_value = (
            mock_event_source
        )

        async for _ in client.send_message_streaming(
            request=params, extensions=new_extensions
        ):
            pass

        mock_aconnect_sse.assert_called_once()
        _, kwargs = mock_aconnect_sse.call_args
        _assert_extensions_header(
            kwargs,
            {
                'https://example.com/test-ext/v2',
            },
        )

    @pytest.mark.asyncio
    @patch('a2a.client.transports.rest.aconnect_sse')
    async def test_send_message_streaming_server_error_propagates(
        self,
        mock_aconnect_sse: AsyncMock,
        mock_httpx_client: AsyncMock,
        mock_agent_card: MagicMock,
    ):
        """Test that a server error status fails send_message_streaming.

        A 403 response whose ``raise_for_status`` raises must surface as
        ``A2AClientHTTPError`` carrying the same status code.
        """
        client = RestTransport(
            httpx_client=mock_httpx_client,
            agent_card=mock_agent_card,
        )
        params = MessageSendParams(
            message=create_text_message_object(content='Error stream')
        )

        mock_event_source = AsyncMock(spec=EventSource)
        mock_response = MagicMock(spec=httpx.Response)
        mock_response.status_code = 403
        mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
            'Forbidden',
            request=httpx.Request('POST', 'http://test.url'),
            response=mock_response,
        )
        mock_event_source.response = mock_response
        mock_event_source.aiter_sse.return_value = async_iterable_from_list([])
        mock_aconnect_sse.return_value.__aenter__.return_value = (
            mock_event_source
        )

        with pytest.raises(A2AClientHTTPError) as exc_info:
            async for _ in client.send_message_streaming(request=params):
                pass

        assert exc_info.value.status_code == 403
        mock_aconnect_sse.assert_called_once()

    @pytest.mark.asyncio
    async def test_get_card_no_card_provided_with_extensions(
        self, mock_httpx_client: AsyncMock
    ):
        """Test get_card with client-level extensions and no initial card.

        Tests that the extensions are added to the HTTP GET request.
        """
        extensions = [
            'https://example.com/test-ext/v1',
            'https://example.com/test-ext/v2',
        ]
        client = RestTransport(
            httpx_client=mock_httpx_client,
            url='http://agent.example.com/api',
            extensions=extensions,
        )

        mock_response = AsyncMock(spec=httpx.Response)
        mock_response.status_code = 200
        mock_response.json.return_value = {
            'name': 'Test Agent',
            'description': 'Test Agent Description',
            'url': 'http://agent.example.com/api',
            'version': '1.0.0',
            'default_input_modes': ['text'],
            'default_output_modes': ['text'],
            'capabilities': AgentCapabilities().model_dump(),
            'skills': [],
        }
        mock_httpx_client.get.return_value = mock_response

        await client.get_card()

        mock_httpx_client.get.assert_called_once()
        _, mock_kwargs = mock_httpx_client.get.call_args
        _assert_extensions_header(
            mock_kwargs,
            {
                'https://example.com/test-ext/v1',
                'https://example.com/test-ext/v2',
            },
        )

    @pytest.mark.asyncio
    async def test_get_card_with_extended_card_support_with_extensions(
        self, mock_httpx_client: AsyncMock
    ):
        """Test get_card with per-call extensions and extended card support.

        Tests that the extensions are added to the GET request.
        """
        extensions = [
            'https://example.com/test-ext/v1',
            'https://example.com/test-ext/v2',
        ]
        agent_card = AgentCard(
            name='Test Agent',
            description='Test Agent Description',
            url='http://agent.example.com/api',
            version='1.0.0',
            default_input_modes=['text'],
            default_output_modes=['text'],
            capabilities=AgentCapabilities(),
            skills=[],
            supports_authenticated_extended_card=True,
        )
        client = RestTransport(
            httpx_client=mock_httpx_client,
            agent_card=agent_card,
        )

        # NOTE(review): `_send_get_request` is patched below, so this canned
        # response is presumably never consumed - confirm before removing.
        mock_response = AsyncMock(spec=httpx.Response)
        mock_response.status_code = 200
        mock_response.json.return_value = agent_card.model_dump(mode='json')
        mock_httpx_client.send.return_value = mock_response

        with patch.object(
            client, '_send_get_request', new_callable=AsyncMock
        ) as mock_send_get_request:
            mock_send_get_request.return_value = agent_card.model_dump(
                mode='json'
            )
            await client.get_card(extensions=extensions)

        mock_send_get_request.assert_called_once()
        # The HTTP kwargs are read from the third positional argument.
        _, _, mock_kwargs = mock_send_get_request.call_args[0]
        _assert_extensions_header(
            mock_kwargs,
            {
                'https://example.com/test-ext/v1',
                'https://example.com/test-ext/v2',
            },
        )