Skip to content

Commit ff0ad3b

Browse files
committed
Fix broken tests
1 parent 9182fdb commit ff0ad3b

5 files changed

Lines changed: 22 additions & 124 deletions

File tree

src/a2a/client/transports/rest.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -310,24 +310,14 @@ async def resubscribe(
310310
Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent | Message
311311
]:
312312
"""Reconnects to get task updates."""
313-
pb = a2a_pb2.TaskSubscriptionRequest(
314-
name=f'tasks/{request.id}',
315-
)
316-
payload = MessageToDict(pb)
317-
payload, modified_kwargs = await self._apply_interceptors(
318-
payload,
319-
self._get_http_args(context),
320-
context,
321-
)
322-
323-
modified_kwargs.setdefault('timeout', None)
313+
http_kwargs = self._get_http_args(context) or {}
314+
http_kwargs.setdefault('timeout', None)
324315

325316
async with aconnect_sse(
326317
self.httpx_client,
327318
'GET',
328319
f'{self.url}/v1/tasks/{request.id}:subscribe',
329-
json=payload,
330-
**modified_kwargs,
320+
**http_kwargs,
331321
) as event_source:
332322
try:
333323
async for sse in event_source.aiter_sse():

src/a2a/server/apps/rest/rest_app.py

Lines changed: 6 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,7 @@
1515
from a2a.server.context import ServerCallContext
1616
from a2a.server.request_handlers.request_handler import RequestHandler
1717
from a2a.server.request_handlers.rest_handler import RESTHandler
18-
from a2a.types import (
19-
AgentCard,
20-
AuthenticatedExtendedCardNotConfiguredError,
21-
)
18+
from a2a.types import AgentCard, AuthenticatedExtendedCardNotConfiguredError
2219
from a2a.utils.error_handlers import (
2320
rest_error_handler,
2421
rest_stream_error_handler,
@@ -61,21 +58,7 @@ def __init__(
6158
@rest_error_handler
6259
async def _handle_request(
6360
self,
64-
method: Callable[
65-
[Request, ServerCallContext], Awaitable[dict[str, Any]]
66-
],
67-
request: Request,
68-
) -> Response:
69-
call_context = self._context_builder.build(request)
70-
response = await method(request, call_context)
71-
return JSONResponse(content=response)
72-
73-
@rest_error_handler
74-
async def _handle_list_request(
75-
self,
76-
method: Callable[
77-
[Request, ServerCallContext], Awaitable[list[dict[str, Any]]]
78-
],
61+
method: Callable[[Request, ServerCallContext], Awaitable[Any]],
7962
request: Request,
8063
) -> Response:
8164
call_context = self._context_builder.build(request)
@@ -85,15 +68,13 @@ async def _handle_list_request(
8568
@rest_stream_error_handler
8669
async def _handle_streaming_request(
8770
self,
88-
method: Callable[
89-
[Request, ServerCallContext], AsyncIterable[dict[str, Any]]
90-
],
71+
method: Callable[[Request, ServerCallContext], AsyncIterable[Any]],
9172
request: Request,
9273
) -> EventSourceResponse:
9374
call_context = self._context_builder.build(request)
9475

9576
async def event_generator(
96-
stream: AsyncIterable[dict[str, Any]],
77+
stream: AsyncIterable[Any],
9778
) -> AsyncIterator[dict[str, dict[str, Any]]]:
9879
async for item in stream:
9980
yield {'data': item}
@@ -188,10 +169,10 @@ def routes(self) -> dict[tuple[str, str], Callable[[Request], Any]]:
188169
'/v1/tasks/{id}/pushNotificationConfigs',
189170
'GET',
190171
): functools.partial(
191-
self._handle_list_request, self.handler.list_push_notifications
172+
self._handle_request, self.handler.list_push_notifications
192173
),
193174
('/v1/tasks', 'GET'): functools.partial(
194-
self._handle_list_request, self.handler.list_tasks
175+
self._handle_request, self.handler.list_tasks
195176
),
196177
}
197178
if self.agent_card.supports_authenticated_extended_card:

src/a2a/server/request_handlers/rest_handler.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from collections.abc import AsyncIterable, AsyncIterator
44
from typing import Any
55

6-
from google.protobuf.json_format import MessageToDict, Parse
6+
from google.protobuf.json_format import MessageToDict, MessageToJson, Parse
77
from starlette.requests import Request
88

99
from a2a.grpc import a2a_pb2
@@ -86,7 +86,7 @@ async def on_message_send_stream(
8686
self,
8787
request: Request,
8888
context: ServerCallContext,
89-
) -> AsyncIterator[dict[str, Any]]:
89+
) -> AsyncIterator[str]:
9090
"""Handles the 'message/stream' REST method.
9191
9292
Yields response objects as they are produced by the underlying handler's stream.
@@ -96,7 +96,7 @@ async def on_message_send_stream(
9696
context: Context provided by the server.
9797
9898
Yields:
99-
`dict` objects containing streaming events
99+
JSON serialized objects containing streaming events
100100
(Task, Message, TaskStatusUpdateEvent, TaskArtifactUpdateEvent) as JSON
101101
"""
102102
body = await request.body()
@@ -110,7 +110,7 @@ async def on_message_send_stream(
110110
a2a_request, context
111111
):
112112
response = proto_utils.ToProto.stream_response(event)
113-
yield MessageToDict(response)
113+
yield MessageToJson(response)
114114

115115
async def on_cancel_task(
116116
self,
@@ -142,7 +142,7 @@ async def on_resubscribe_to_task(
142142
self,
143143
request: Request,
144144
context: ServerCallContext,
145-
) -> AsyncIterable[dict[str, Any]]:
145+
) -> AsyncIterable[str]:
146146
"""Handles the 'tasks/resubscribe' REST method.
147147
148148
Yields response objects as they are produced by the underlying handler's stream.
@@ -152,13 +152,13 @@ async def on_resubscribe_to_task(
152152
context: Context provided by the server.
153153
154154
Yields:
155-
`dict` containing streaming events
155+
JSON serialized objects containing streaming events
156156
"""
157157
task_id = request.path_params['id']
158158
async for event in self.request_handler.on_resubscribe_to_task(
159159
TaskIdParams(id=task_id), context
160160
):
161-
yield (MessageToDict(proto_utils.ToProto.stream_response(event)))
161+
yield MessageToJson(proto_utils.ToProto.stream_response(event))
162162

163163
async def get_push_notification(
164164
self,
@@ -262,7 +262,7 @@ async def list_push_notifications(
262262
self,
263263
request: Request,
264264
context: ServerCallContext,
265-
) -> list[dict[str, Any]]:
265+
) -> dict[str, Any]:
266266
"""Handles the 'tasks/pushNotificationConfig/list' REST method.
267267
268268
This method is currently not implemented.
@@ -283,7 +283,7 @@ async def list_tasks(
283283
self,
284284
request: Request,
285285
context: ServerCallContext,
286-
) -> list[dict[str, Any]]:
286+
) -> dict[str, Any]:
287287
"""Handles the 'tasks/list' REST method.
288288
289289
This method is currently not implemented.

tests/client/test_jsonrpc_client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -334,15 +334,15 @@ def test_init_with_url(self, mock_httpx_client: AsyncMock):
334334
assert client.url == self.AGENT_URL
335335
assert client.httpx_client == mock_httpx_client
336336

337-
def test_init_with_agent_card_and_url_prioritizes_agent_card(
337+
def test_init_with_agent_card_and_url_prioritizes_url(
338338
self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock
339339
):
340340
client = JsonRpcTransport(
341341
httpx_client=mock_httpx_client,
342342
agent_card=mock_agent_card,
343343
url='http://otherurl.com',
344344
)
345-
assert client.url == mock_agent_card.url
345+
assert client.url == 'http://otherurl.com'
346346

347347
def test_init_raises_value_error_if_no_card_or_url(
348348
self, mock_httpx_client: AsyncMock

tests/client/test_legacy_client.py

Lines changed: 2 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -1,83 +1,10 @@
11
"""Tests for the legacy client compatibility layer."""
22

3-
from unittest.mock import AsyncMock, MagicMock
4-
5-
import httpx
6-
import pytest
7-
8-
from a2a.client import A2AClient, A2AGrpcClient
9-
from a2a.types import (
10-
AgentCapabilities,
11-
AgentCard,
12-
Message,
13-
MessageSendParams,
14-
Part,
15-
Role,
16-
SendMessageRequest,
17-
Task,
18-
TaskQueryParams,
19-
TaskState,
20-
TaskStatus,
21-
TextPart,
22-
)
23-
24-
25-
@pytest.fixture
26-
def mock_httpx_client() -> AsyncMock:
27-
return AsyncMock(spec=httpx.AsyncClient)
28-
29-
30-
@pytest.fixture
31-
def mock_grpc_stub() -> AsyncMock:
32-
stub = AsyncMock()
33-
stub._channel = MagicMock()
34-
return stub
35-
36-
37-
@pytest.fixture
38-
def jsonrpc_agent_card() -> AgentCard:
39-
return AgentCard(
40-
name='Test Agent',
41-
description='A test agent',
42-
url='http://test.agent.com/rpc',
43-
version='1.0.0',
44-
capabilities=AgentCapabilities(streaming=True),
45-
skills=[],
46-
default_input_modes=[],
47-
default_output_modes=[],
48-
preferred_transport='jsonrpc',
49-
)
50-
51-
52-
@pytest.fixture
53-
def grpc_agent_card() -> AgentCard:
54-
return AgentCard(
55-
name='Test Agent',
56-
description='A test agent',
57-
url='http://test.agent.com/rpc',
58-
version='1.0.0',
59-
capabilities=AgentCapabilities(streaming=True),
60-
skills=[],
61-
default_input_modes=[],
62-
default_output_modes=[],
63-
preferred_transport='grpc',
64-
)
65-
66-
67-
@pytest.mark.asyncio
68-
async def test_a2a_client_send_message(
69-
mock_httpx_client: AsyncMock, jsonrpc_agent_card: AgentCard
70-
):
71-
"""Tests for the legacy client compatibility layer."""
72-
73-
743
from unittest.mock import AsyncMock
754

765
import pytest
776

78-
from a2a.types import (
79-
AgentCard,
80-
)
7+
from a2a.types import AgentCard
818

829

8310
@pytest.fixture
@@ -164,7 +91,7 @@ async def test_a2a_grpc_client_get_task(
16491
status=TaskStatus(state=TaskState.working),
16592
)
16693

167-
client.get_task.return_value = mock_response_task
94+
client.get_task = AsyncMock(return_value=mock_response_task)
16895

16996
params = TaskQueryParams(id='task-456')
17097
response = await client.get_task(params)

0 commit comments

Comments
 (0)