|
| 1 | +import json |
| 2 | + |
1 | 3 | from unittest.mock import AsyncMock |
2 | 4 |
|
| 5 | +import httpx |
3 | 6 | import pytest |
| 7 | + |
4 | 8 | from starlette.applications import Starlette |
| 9 | +from starlette.routing import BaseRoute |
5 | 10 | from starlette.testclient import TestClient |
6 | | -from starlette.routing import BaseRoute, Route |
7 | 11 |
|
8 | 12 | from a2a.server.request_handlers.request_handler import RequestHandler |
9 | 13 | from a2a.server.routes.rest_routes import create_rest_routes |
10 | | -from a2a.types.a2a_pb2 import AgentCard, Task, ListTasksResponse |
| 14 | +from a2a.types.a2a_pb2 import ( |
| 15 | + AgentCapabilities, |
| 16 | + AgentCard, |
| 17 | + ListTasksResponse, |
| 18 | + Message, |
| 19 | + Part, |
| 20 | + Role, |
| 21 | + Task, |
| 22 | +) |
| 23 | +from a2a.utils.errors import InternalError |
11 | 24 |
|
12 | 25 |
|
13 | 26 | @pytest.fixture |
@@ -103,3 +116,125 @@ def test_rest_list_tasks(agent_card, mock_handler): |
103 | 116 | response = client.get('/tasks', headers={'A2A-Version': '1.0'}) |
104 | 117 | assert response.status_code == 200 |
105 | 118 | assert mock_handler.on_list_tasks.called |
| 119 | + |
| 120 | + |
| 121 | +@pytest.fixture |
| 122 | +def streaming_agent_card(): |
| 123 | + return AgentCard( |
| 124 | + capabilities=AgentCapabilities(streaming=True), |
| 125 | + ) |
| 126 | + |
| 127 | + |
| 128 | +@pytest.fixture |
| 129 | +def streaming_app(streaming_agent_card, mock_handler): |
| 130 | + routes = create_rest_routes( |
| 131 | + agent_card=streaming_agent_card, request_handler=mock_handler |
| 132 | + ) |
| 133 | + return Starlette(routes=routes) |
| 134 | + |
| 135 | + |
| 136 | +@pytest.fixture |
| 137 | +def streaming_client(streaming_app): |
| 138 | + return httpx.AsyncClient( |
| 139 | + transport=httpx.ASGITransport(app=streaming_app), |
| 140 | + base_url='http://test', |
| 141 | + headers={'A2A-Version': '1.0'}, |
| 142 | + ) |
| 143 | + |
| 144 | + |
| 145 | +@pytest.mark.asyncio |
| 146 | +async def test_streaming_mid_stream_error_emits_sse_error_event( |
| 147 | + streaming_client, mock_handler |
| 148 | +): |
| 149 | + """Test that mid-stream errors are sent as SSE error events.""" |
| 150 | + |
| 151 | + async def mock_stream_then_error(*args, **kwargs): |
| 152 | + yield Message( |
| 153 | + message_id='stream_msg_1', |
| 154 | + role=Role.ROLE_AGENT, |
| 155 | + parts=[Part(text='First chunk')], |
| 156 | + ) |
| 157 | + raise InternalError(message='Something went wrong mid-stream') |
| 158 | + |
| 159 | + mock_handler.on_message_send_stream.side_effect = mock_stream_then_error |
| 160 | + |
| 161 | + response = await streaming_client.post( |
| 162 | + '/message:stream', |
| 163 | + headers={'Accept': 'text/event-stream'}, |
| 164 | + json={}, |
| 165 | + ) |
| 166 | + |
| 167 | + response.raise_for_status() |
| 168 | + assert 'text/event-stream' in response.headers.get('content-type', '') |
| 169 | + |
| 170 | + lines = [line.strip() for line in response.text.strip().splitlines()] |
| 171 | + |
| 172 | + # Should have a normal data event followed by an error event |
| 173 | + data_lines = [ |
| 174 | + json.loads(line[6:]) for line in lines if line.startswith('data: ') |
| 175 | + ] |
| 176 | + assert len(data_lines) >= 1 |
| 177 | + assert 'message' in data_lines[0] |
| 178 | + assert data_lines[0]['message']['messageId'] == 'stream_msg_1' |
| 179 | + |
| 180 | + # Should contain an SSE error event |
| 181 | + error_event_lines = [line for line in lines if line == 'event: error'] |
| 182 | + assert len(error_event_lines) == 1 |
| 183 | + |
| 184 | + # Find the error data after the error event |
| 185 | + error_data = None |
| 186 | + for i, line in enumerate(lines): |
| 187 | + if line == 'event: error': |
| 188 | + for j in range(i + 1, len(lines)): |
| 189 | + if lines[j].startswith('data: '): |
| 190 | + error_data = json.loads(lines[j][6:]) |
| 191 | + break |
| 192 | + break |
| 193 | + |
| 194 | + assert error_data is not None |
| 195 | + assert error_data['error']['code'] == 500 |
| 196 | + assert error_data['error']['status'] == 'INTERNAL' |
| 197 | + assert 'Something went wrong mid-stream' in error_data['error']['message'] |
| 198 | + |
| 199 | + |
| 200 | +@pytest.mark.asyncio |
| 201 | +async def test_streaming_mid_stream_unknown_error_emits_sse_error_event( |
| 202 | + streaming_client, mock_handler |
| 203 | +): |
| 204 | + """Test that non-A2AError mid-stream errors also produce SSE error events.""" |
| 205 | + |
| 206 | + async def mock_stream_then_error(*args, **kwargs): |
| 207 | + yield Message( |
| 208 | + message_id='stream_msg_1', |
| 209 | + role=Role.ROLE_AGENT, |
| 210 | + parts=[Part(text='First chunk')], |
| 211 | + ) |
| 212 | + raise RuntimeError('Unexpected failure') |
| 213 | + |
| 214 | + mock_handler.on_message_send_stream.side_effect = mock_stream_then_error |
| 215 | + |
| 216 | + response = await streaming_client.post( |
| 217 | + '/message:stream', |
| 218 | + headers={'Accept': 'text/event-stream'}, |
| 219 | + json={}, |
| 220 | + ) |
| 221 | + |
| 222 | + response.raise_for_status() |
| 223 | + |
| 224 | + lines = [line.strip() for line in response.text.strip().splitlines()] |
| 225 | + |
| 226 | + error_event_lines = [line for line in lines if line == 'event: error'] |
| 227 | + assert len(error_event_lines) == 1 |
| 228 | + |
| 229 | + error_data = None |
| 230 | + for i, line in enumerate(lines): |
| 231 | + if line == 'event: error': |
| 232 | + for j in range(i + 1, len(lines)): |
| 233 | + if lines[j].startswith('data: '): |
| 234 | + error_data = json.loads(lines[j][6:]) |
| 235 | + break |
| 236 | + break |
| 237 | + |
| 238 | + assert error_data is not None |
| 239 | + assert error_data['error']['code'] == 500 |
| 240 | + assert error_data['error']['status'] == 'INTERNAL' |
0 commit comments