Skip to content

Commit a587390

Browse files
committed
Update
1 parent f5e0f72 commit a587390

1 file changed

Lines changed: 2 additions & 129 deletions

File tree

Lines changed: 2 additions & 129 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,13 @@
1-
import json
2-
31
from unittest.mock import AsyncMock
42

5-
import httpx
63
import pytest
7-
84
from starlette.applications import Starlette
9-
from starlette.routing import BaseRoute
105
from starlette.testclient import TestClient
6+
from starlette.routing import BaseRoute, Route
117

128
from a2a.server.request_handlers.request_handler import RequestHandler
139
from a2a.server.routes.rest_routes import create_rest_routes
14-
from a2a.types.a2a_pb2 import (
15-
AgentCard,
16-
ListTasksResponse,
17-
Message,
18-
Part,
19-
Role,
20-
Task,
21-
)
22-
from a2a.utils.errors import InternalError
10+
from a2a.types.a2a_pb2 import AgentCard, Task, ListTasksResponse
2311

2412

2513
@pytest.fixture
@@ -104,118 +92,3 @@ def test_rest_list_tasks(agent_card, mock_handler):
10492
response = client.get('/tasks', headers={'A2A-Version': '1.0'})
10593
assert response.status_code == 200
10694
assert mock_handler.on_list_tasks.called
107-
108-
109-
@pytest.fixture
110-
def streaming_app(mock_handler):
111-
routes = create_rest_routes(
112-
request_handler=mock_handler,
113-
)
114-
return Starlette(routes=routes)
115-
116-
117-
@pytest.fixture
118-
def streaming_client(streaming_app):
119-
return httpx.AsyncClient(
120-
transport=httpx.ASGITransport(app=streaming_app),
121-
base_url='http://test',
122-
headers={'A2A-Version': '1.0'},
123-
)
124-
125-
126-
@pytest.mark.asyncio
127-
async def test_streaming_mid_stream_error_emits_sse_error_event(
128-
streaming_client, mock_handler
129-
):
130-
"""Test that mid-stream errors are sent as SSE error events."""
131-
132-
async def mock_stream_then_error(*args, **kwargs):
133-
yield Message(
134-
message_id='stream_msg_1',
135-
role=Role.ROLE_AGENT,
136-
parts=[Part(text='First chunk')],
137-
)
138-
raise InternalError(message='Something went wrong mid-stream')
139-
140-
mock_handler.on_message_send_stream.side_effect = mock_stream_then_error
141-
142-
response = await streaming_client.post(
143-
'/message:stream',
144-
headers={'Accept': 'text/event-stream'},
145-
json={},
146-
)
147-
148-
response.raise_for_status()
149-
assert 'text/event-stream' in response.headers.get('content-type', '')
150-
151-
lines = [line.strip() for line in response.text.strip().splitlines()]
152-
153-
# Should have a normal data event followed by an error event
154-
data_lines = [
155-
json.loads(line[6:]) for line in lines if line.startswith('data: ')
156-
]
157-
assert len(data_lines) >= 1
158-
assert 'message' in data_lines[0]
159-
assert data_lines[0]['message']['messageId'] == 'stream_msg_1'
160-
161-
# Should contain an SSE error event
162-
error_event_lines = [line for line in lines if line == 'event: error']
163-
assert len(error_event_lines) == 1
164-
165-
# Find the error data after the error event
166-
error_data = None
167-
for i, line in enumerate(lines):
168-
if line == 'event: error':
169-
for j in range(i + 1, len(lines)):
170-
if lines[j].startswith('data: '):
171-
error_data = json.loads(lines[j][6:])
172-
break
173-
break
174-
175-
assert error_data is not None
176-
assert error_data['error']['code'] == 500
177-
assert error_data['error']['status'] == 'INTERNAL'
178-
assert 'Something went wrong mid-stream' in error_data['error']['message']
179-
180-
181-
@pytest.mark.asyncio
182-
async def test_streaming_mid_stream_unknown_error_emits_sse_error_event(
183-
streaming_client, mock_handler
184-
):
185-
"""Test that non-A2AError mid-stream errors also produce SSE error events."""
186-
187-
async def mock_stream_then_error(*args, **kwargs):
188-
yield Message(
189-
message_id='stream_msg_1',
190-
role=Role.ROLE_AGENT,
191-
parts=[Part(text='First chunk')],
192-
)
193-
raise RuntimeError('Unexpected failure')
194-
195-
mock_handler.on_message_send_stream.side_effect = mock_stream_then_error
196-
197-
response = await streaming_client.post(
198-
'/message:stream',
199-
headers={'Accept': 'text/event-stream'},
200-
json={},
201-
)
202-
203-
response.raise_for_status()
204-
205-
lines = [line.strip() for line in response.text.strip().splitlines()]
206-
207-
error_event_lines = [line for line in lines if line == 'event: error']
208-
assert len(error_event_lines) == 1
209-
210-
error_data = None
211-
for i, line in enumerate(lines):
212-
if line == 'event: error':
213-
for j in range(i + 1, len(lines)):
214-
if lines[j].startswith('data: '):
215-
error_data = json.loads(lines[j][6:])
216-
break
217-
break
218-
219-
assert error_data is not None
220-
assert error_data['error']['code'] == 500
221-
assert error_data['error']['status'] == 'INTERNAL'

0 commit comments

Comments
 (0)