|
1 | 1 | import logging |
| 2 | +import json |
2 | 3 |
|
3 | 4 | from typing import Any |
4 | 5 | from unittest.mock import MagicMock |
@@ -339,6 +340,62 @@ async def mock_stream_response(): |
339 | 340 | request_handler.on_message_send_stream.assert_called_once() |
340 | 341 |
|
341 | 342 |
|
| 343 | +@pytest.mark.anyio |
| 344 | +async def test_streaming_content_verification( |
| 345 | + streaming_client: AsyncClient, request_handler: MagicMock |
| 346 | +) -> None: |
| 347 | + """Test that streaming endpoint returns correct SSE content.""" |
| 348 | + |
| 349 | + async def mock_stream_response(): |
| 350 | + yield Message( |
| 351 | + message_id='stream_msg_1', |
| 352 | + role=Role.ROLE_AGENT, |
| 353 | + parts=[Part(text='First chunk')], |
| 354 | + ) |
| 355 | + yield Message( |
| 356 | + message_id='stream_msg_2', |
| 357 | + role=Role.ROLE_AGENT, |
| 358 | + parts=[Part(text='Second chunk')], |
| 359 | + ) |
| 360 | + |
| 361 | + request_handler.on_message_send_stream.return_value = mock_stream_response() |
| 362 | + |
| 363 | + request = a2a_pb2.SendMessageRequest( |
| 364 | + message=a2a_pb2.Message( |
| 365 | + message_id='test_stream_msg', |
| 366 | + role=a2a_pb2.ROLE_USER, |
| 367 | + parts=[a2a_pb2.Part(text='Test message')], |
| 368 | + ), |
| 369 | + ) |
| 370 | + |
| 371 | + response = await streaming_client.post( |
| 372 | + '/message:stream', |
| 373 | + json=json_format.MessageToDict(request), |
| 374 | + headers={'Accept': 'text/event-stream'}, |
| 375 | + ) |
| 376 | + |
| 377 | + response.raise_for_status() |
| 378 | + |
| 379 | + # Read the response content |
| 380 | + lines = [line async for line in response.aiter_lines()] |
| 381 | + |
| 382 | + # SSE format is "data: <json>\n\n" |
| 383 | + # httpx.aiter_lines() will give us each line. |
| 384 | + |
| 385 | + data_lines = [line for line in lines if line.startswith('data: ')] |
| 386 | + assert len(data_lines) == 2 |
| 387 | + |
| 388 | + # First chunk |
| 389 | + first_data = json.loads(data_lines[0][6:]) |
| 390 | + assert first_data['message']['messageId'] == 'stream_msg_1' |
| 391 | + assert first_data['message']['parts'][0]['text'] == 'First chunk' |
| 392 | + |
| 393 | + # Second chunk |
| 394 | + second_data = json.loads(data_lines[1][6:]) |
| 395 | + assert second_data['message']['messageId'] == 'stream_msg_2' |
| 396 | + assert second_data['message']['parts'][0]['text'] == 'Second chunk' |
| 397 | + |
| 398 | + |
342 | 399 | @pytest.mark.anyio |
343 | 400 | async def test_streaming_endpoint_with_invalid_content_type( |
344 | 401 | streaming_client: AsyncClient, request_handler: MagicMock |
|
0 commit comments