|
1 | | -import json |
2 | | - |
3 | 1 | from unittest.mock import AsyncMock |
4 | 2 |
|
5 | | -import httpx |
6 | 3 | import pytest |
7 | | - |
8 | 4 | from starlette.applications import Starlette |
9 | | -from starlette.routing import BaseRoute |
10 | 5 | from starlette.testclient import TestClient |
| 6 | +from starlette.routing import BaseRoute, Route |
11 | 7 |
|
12 | 8 | from a2a.server.request_handlers.request_handler import RequestHandler |
13 | 9 | from a2a.server.routes.rest_routes import create_rest_routes |
14 | | -from a2a.types.a2a_pb2 import ( |
15 | | - AgentCard, |
16 | | - ListTasksResponse, |
17 | | - Message, |
18 | | - Part, |
19 | | - Role, |
20 | | - Task, |
21 | | -) |
22 | | -from a2a.utils.errors import InternalError |
| 10 | +from a2a.types.a2a_pb2 import AgentCard, Task, ListTasksResponse |
23 | 11 |
|
24 | 12 |
|
25 | 13 | @pytest.fixture |
@@ -104,118 +92,3 @@ def test_rest_list_tasks(agent_card, mock_handler): |
104 | 92 | response = client.get('/tasks', headers={'A2A-Version': '1.0'}) |
105 | 93 | assert response.status_code == 200 |
106 | 94 | assert mock_handler.on_list_tasks.called |
107 | | - |
108 | | - |
109 | | -@pytest.fixture |
110 | | -def streaming_app(mock_handler): |
111 | | - routes = create_rest_routes( |
112 | | - request_handler=mock_handler, |
113 | | - ) |
114 | | - return Starlette(routes=routes) |
115 | | - |
116 | | - |
117 | | -@pytest.fixture |
118 | | -def streaming_client(streaming_app): |
119 | | - return httpx.AsyncClient( |
120 | | - transport=httpx.ASGITransport(app=streaming_app), |
121 | | - base_url='http://test', |
122 | | - headers={'A2A-Version': '1.0'}, |
123 | | - ) |
124 | | - |
125 | | - |
126 | | -@pytest.mark.asyncio |
127 | | -async def test_streaming_mid_stream_error_emits_sse_error_event( |
128 | | - streaming_client, mock_handler |
129 | | -): |
130 | | - """Test that mid-stream errors are sent as SSE error events.""" |
131 | | - |
132 | | - async def mock_stream_then_error(*args, **kwargs): |
133 | | - yield Message( |
134 | | - message_id='stream_msg_1', |
135 | | - role=Role.ROLE_AGENT, |
136 | | - parts=[Part(text='First chunk')], |
137 | | - ) |
138 | | - raise InternalError(message='Something went wrong mid-stream') |
139 | | - |
140 | | - mock_handler.on_message_send_stream.side_effect = mock_stream_then_error |
141 | | - |
142 | | - response = await streaming_client.post( |
143 | | - '/message:stream', |
144 | | - headers={'Accept': 'text/event-stream'}, |
145 | | - json={}, |
146 | | - ) |
147 | | - |
148 | | - response.raise_for_status() |
149 | | - assert 'text/event-stream' in response.headers.get('content-type', '') |
150 | | - |
151 | | - lines = [line.strip() for line in response.text.strip().splitlines()] |
152 | | - |
153 | | - # Should have a normal data event followed by an error event |
154 | | - data_lines = [ |
155 | | - json.loads(line[6:]) for line in lines if line.startswith('data: ') |
156 | | - ] |
157 | | - assert len(data_lines) >= 1 |
158 | | - assert 'message' in data_lines[0] |
159 | | - assert data_lines[0]['message']['messageId'] == 'stream_msg_1' |
160 | | - |
161 | | - # Should contain an SSE error event |
162 | | - error_event_lines = [line for line in lines if line == 'event: error'] |
163 | | - assert len(error_event_lines) == 1 |
164 | | - |
165 | | - # Find the error data after the error event |
166 | | - error_data = None |
167 | | - for i, line in enumerate(lines): |
168 | | - if line == 'event: error': |
169 | | - for j in range(i + 1, len(lines)): |
170 | | - if lines[j].startswith('data: '): |
171 | | - error_data = json.loads(lines[j][6:]) |
172 | | - break |
173 | | - break |
174 | | - |
175 | | - assert error_data is not None |
176 | | - assert error_data['error']['code'] == 500 |
177 | | - assert error_data['error']['status'] == 'INTERNAL' |
178 | | - assert 'Something went wrong mid-stream' in error_data['error']['message'] |
179 | | - |
180 | | - |
181 | | -@pytest.mark.asyncio |
182 | | -async def test_streaming_mid_stream_unknown_error_emits_sse_error_event( |
183 | | - streaming_client, mock_handler |
184 | | -): |
185 | | - """Test that non-A2AError mid-stream errors also produce SSE error events.""" |
186 | | - |
187 | | - async def mock_stream_then_error(*args, **kwargs): |
188 | | - yield Message( |
189 | | - message_id='stream_msg_1', |
190 | | - role=Role.ROLE_AGENT, |
191 | | - parts=[Part(text='First chunk')], |
192 | | - ) |
193 | | - raise RuntimeError('Unexpected failure') |
194 | | - |
195 | | - mock_handler.on_message_send_stream.side_effect = mock_stream_then_error |
196 | | - |
197 | | - response = await streaming_client.post( |
198 | | - '/message:stream', |
199 | | - headers={'Accept': 'text/event-stream'}, |
200 | | - json={}, |
201 | | - ) |
202 | | - |
203 | | - response.raise_for_status() |
204 | | - |
205 | | - lines = [line.strip() for line in response.text.strip().splitlines()] |
206 | | - |
207 | | - error_event_lines = [line for line in lines if line == 'event: error'] |
208 | | - assert len(error_event_lines) == 1 |
209 | | - |
210 | | - error_data = None |
211 | | - for i, line in enumerate(lines): |
212 | | - if line == 'event: error': |
213 | | - for j in range(i + 1, len(lines)): |
214 | | - if lines[j].startswith('data: '): |
215 | | - error_data = json.loads(lines[j][6:]) |
216 | | - break |
217 | | - break |
218 | | - |
219 | | - assert error_data is not None |
220 | | - assert error_data['error']['code'] == 500 |
221 | | - assert error_data['error']['status'] == 'INTERNAL' |
0 commit comments