Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/a2a/client/transports/jsonrpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,8 @@ async def send_message_streaming(
try:
event_source.response.raise_for_status()
async for sse in event_source.aiter_sse():
if not sse.data:
continue
response = SendStreamingMessageResponse.model_validate(
json.loads(sse.data)
)
Expand Down
2 changes: 2 additions & 0 deletions src/a2a/client/transports/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,23 +154,25 @@
try:
event_source.response.raise_for_status()
async for sse in event_source.aiter_sse():
if not sse.data:
continue
event = a2a_pb2.StreamResponse()
Parse(sse.data, event)
yield proto_utils.FromProto.stream_response(event)
except httpx.HTTPStatusError as e:
raise A2AClientHTTPError(e.response.status_code, str(e)) from e
except SSEError as e:
raise A2AClientHTTPError(
400, f'Invalid SSE response or protocol error: {e}'
) from e
except json.JSONDecodeError as e:
raise A2AClientJSONError(str(e)) from e
except httpx.RequestError as e:
raise A2AClientHTTPError(
503, f'Network communication error: {e}'
) from e

async def _send_request(self, request: httpx.Request) -> dict[str, Any]:

Check notice on line 175 in src/a2a/client/transports/rest.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/client/transports/jsonrpc.py (187-202)
try:
response = await self.httpx_client.send(request)
response.raise_for_status()
Expand Down
58 changes: 58 additions & 0 deletions tests/client/transports/test_jsonrpc_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import httpx
import pytest
import respx

from httpx_sse import EventSource, SSEError, ServerSentEvent

Expand Down Expand Up @@ -466,6 +467,63 @@ async def test_send_message_streaming_success(
== mock_stream_response_2.result.model_dump()
)

# Repro of https://github.com/a2aproject/a2a-python/issues/540
@pytest.mark.asyncio
@respx.mock
async def test_send_message_streaming_comment_success(
self,
mock_agent_card: MagicMock,
):
async with httpx.AsyncClient() as client:
transport = JsonRpcTransport(
httpx_client=client, agent_card=mock_agent_card
)
params = MessageSendParams(
message=create_text_message_object(content='Hello stream')
)
mock_stream_response_1 = SendMessageSuccessResponse(
id='stream_id_123',
jsonrpc='2.0',
result=create_text_message_object(
content='First part', role=Role.agent
),
)
mock_stream_response_2 = SendMessageSuccessResponse(
id='stream_id_123',
jsonrpc='2.0',
result=create_text_message_object(
content='Second part', role=Role.agent
),
)

sse_content = (
'id: stream_id_1\n'
f'data: {mock_stream_response_1.model_dump_json()}\n\n'
': keep-alive\n\n'
'id: stream_id_2\n'
f'data: {mock_stream_response_2.model_dump_json()}\n\n'
': keep-alive\n\n'
)

respx.post(mock_agent_card.url).mock(
return_value=httpx.Response(
200,
headers={'Content-Type': 'text/event-stream'},
content=sse_content,
)
)

results = [
item
async for item in transport.send_message_streaming(
request=params
)
]

assert len(results) == 2
assert results[0] == mock_stream_response_1.result
assert results[1] == mock_stream_response_2.result

@pytest.mark.asyncio
async def test_send_request_http_status_error(
self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock
Expand Down
65 changes: 64 additions & 1 deletion tests/client/transports/test_rest_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,23 @@

import httpx
import pytest
import respx

from google.protobuf.json_format import MessageToJson
from httpx_sse import EventSource, ServerSentEvent

from a2a.client import create_text_message_object
from a2a.client.errors import A2AClientHTTPError
from a2a.client.transports.rest import RestTransport
from a2a.extensions.common import HTTP_EXTENSION_HEADER
from a2a.grpc import a2a_pb2
from a2a.types import (
AgentCapabilities,
AgentCard,
MessageSendParams,
Role,
)
from a2a.utils import proto_utils


@pytest.fixture
Expand Down Expand Up @@ -50,7 +55,7 @@ class TestRestTransportExtensions:
async def test_send_message_with_default_extensions(
self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock
):
"""Test that send_message adds extensions to headers."""
"""Test that SSE comments are ignored."""
extensions = [
'https://example.com/test-ext/v1',
'https://example.com/test-ext/v2',
Expand Down Expand Up @@ -88,6 +93,64 @@ async def test_send_message_with_default_extensions(
},
)

# Repro of https://github.com/a2aproject/a2a-python/issues/540
@pytest.mark.asyncio
@respx.mock
async def test_send_message_streaming_comment_success(
self,
mock_agent_card: MagicMock,
):
"""Test successful streaming in RestTransport."""
async with httpx.AsyncClient() as client:
transport = RestTransport(
httpx_client=client, agent_card=mock_agent_card
)
params = MessageSendParams(
message=create_text_message_object(content='Hello stream')
)

mock_stream_response_1 = a2a_pb2.StreamResponse(
msg=proto_utils.ToProto.message(
create_text_message_object(
content='First part', role=Role.agent
)
)
)
mock_stream_response_2 = a2a_pb2.StreamResponse(
msg=proto_utils.ToProto.message(
create_text_message_object(
content='Second part', role=Role.agent
)
)
)

sse_content = (
'id: stream_id_1\n'
f'data: {MessageToJson(mock_stream_response_1, indent=None)}\n\n'
': keep-alive\n\n'
'id: stream_id_2\n'
f'data: {MessageToJson(mock_stream_response_2, indent=None)}\n\n'
': keep-alive\n\n'
)

respx.post(
f'{mock_agent_card.url.rstrip("/")}/v1/message:stream'
).mock(
return_value=httpx.Response(
200,
headers={'Content-Type': 'text/event-stream'},
content=sse_content,
)
)

results = []
async for item in transport.send_message_streaming(request=params):
results.append(item)

assert len(results) == 2
assert results[0].parts[0].root.text == 'First part'
assert results[1].parts[0].root.text == 'Second part'

@pytest.mark.asyncio
@patch('a2a.client.transports.rest.aconnect_sse')
async def test_send_message_streaming_with_new_extensions(
Expand Down
Loading