Skip to content

Commit 22ef4db

Browse files
committed
Address issues in legacy client, review comments
1 parent ff0ad3b commit 22ef4db

3 files changed

Lines changed: 156 additions & 8 deletions

File tree

src/a2a/client/legacy.py

Lines changed: 137 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
CancelTaskRequest,
1616
CancelTaskResponse,
1717
CancelTaskSuccessResponse,
18+
GetTaskPushNotificationConfigParams,
1819
GetTaskPushNotificationConfigRequest,
1920
GetTaskPushNotificationConfigResponse,
2021
GetTaskPushNotificationConfigSuccessResponse,
@@ -31,6 +32,7 @@
3132
SetTaskPushNotificationConfigRequest,
3233
SetTaskPushNotificationConfigResponse,
3334
SetTaskPushNotificationConfigSuccessResponse,
35+
TaskIdParams,
3436
TaskResubscriptionRequest,
3537
)
3638

@@ -62,6 +64,21 @@ async def send_message(
6264
http_kwargs: dict[str, Any] | None = None,
6365
context: ClientCallContext | None = None,
6466
) -> SendMessageResponse:
67+
"""Sends a non-streaming message request to the agent.
68+
69+
Args:
70+
request: The `SendMessageRequest` object containing the message and configuration.
71+
http_kwargs: Optional dictionary of keyword arguments to pass to the
72+
underlying httpx.post request.
73+
context: The client call context.
74+
75+
Returns:
76+
A `SendMessageResponse` object containing the agent's response (Task or Message) or an error.
77+
78+
Raises:
79+
A2AClientHTTPError: If an HTTP error occurs during the request.
80+
A2AClientJSONError: If the response body cannot be decoded as JSON or validated.
81+
"""
6582
if not context and http_kwargs:
6683
context = ClientCallContext(state={'http_kwargs': http_kwargs})
6784

@@ -75,7 +92,7 @@ async def send_message(
7592
)
7693
)
7794
except A2AClientJSONRPCError as e:
78-
return SendMessageResponse(root=JSONRPCErrorResponse(error=e.error))
95+
return SendMessageResponse(JSONRPCErrorResponse(error=e.error))
7996

8097
async def send_message_streaming(
8198
self,
@@ -84,6 +101,24 @@ async def send_message_streaming(
84101
http_kwargs: dict[str, Any] | None = None,
85102
context: ClientCallContext | None = None,
86103
) -> AsyncGenerator[SendStreamingMessageResponse, None]:
104+
"""Sends a streaming message request to the agent and yields responses as they arrive.
105+
106+
This method uses Server-Sent Events (SSE) to receive a stream of updates from the agent.
107+
108+
Args:
109+
request: The `SendStreamingMessageRequest` object containing the message and configuration.
110+
http_kwargs: Optional dictionary of keyword arguments to pass to the
111+
underlying httpx.post request. A default `timeout=None` is set but can be overridden.
112+
context: The client call context.
113+
114+
Yields:
115+
`SendStreamingMessageResponse` objects as they are received in the SSE stream.
116+
These can be Task, Message, TaskStatusUpdateEvent, or TaskArtifactUpdateEvent.
117+
118+
Raises:
119+
A2AClientHTTPError: If an HTTP or SSE protocol error occurs during the request.
120+
A2AClientJSONError: If an SSE event data cannot be decoded as JSON or validated.
121+
"""
87122
if not context and http_kwargs:
88123
context = ClientCallContext(state={'http_kwargs': http_kwargs})
89124

@@ -103,6 +138,21 @@ async def get_task(
103138
http_kwargs: dict[str, Any] | None = None,
104139
context: ClientCallContext | None = None,
105140
) -> GetTaskResponse:
141+
"""Retrieves the current state and history of a specific task.
142+
143+
Args:
144+
request: The `GetTaskRequest` object specifying the task ID and history length.
145+
http_kwargs: Optional dictionary of keyword arguments to pass to the
146+
underlying httpx.post request.
147+
context: The client call context.
148+
149+
Returns:
150+
A `GetTaskResponse` object containing the Task or an error.
151+
152+
Raises:
153+
A2AClientHTTPError: If an HTTP error occurs during the request.
154+
A2AClientJSONError: If the response body cannot be decoded as JSON or validated.
155+
"""
106156
if not context and http_kwargs:
107157
context = ClientCallContext(state={'http_kwargs': http_kwargs})
108158
try:
@@ -124,6 +174,21 @@ async def cancel_task(
124174
http_kwargs: dict[str, Any] | None = None,
125175
context: ClientCallContext | None = None,
126176
) -> CancelTaskResponse:
177+
"""Requests the agent to cancel a specific task.
178+
179+
Args:
180+
request: The `CancelTaskRequest` object specifying the task ID.
181+
http_kwargs: Optional dictionary of keyword arguments to pass to the
182+
underlying httpx.post request.
183+
context: The client call context.
184+
185+
Returns:
186+
A `CancelTaskResponse` object containing the updated Task with canceled status or an error.
187+
188+
Raises:
189+
A2AClientHTTPError: If an HTTP error occurs during the request.
190+
A2AClientJSONError: If the response body cannot be decoded as JSON or validated.
191+
"""
127192
if not context and http_kwargs:
128193
context = ClientCallContext(state={'http_kwargs': http_kwargs})
129194
try:
@@ -136,7 +201,7 @@ async def cancel_task(
136201
)
137202
)
138203
except A2AClientJSONRPCError as e:
139-
return CancelTaskResponse(root=JSONRPCErrorResponse(error=e.error))
204+
return CancelTaskResponse(JSONRPCErrorResponse(error=e.error))
140205

141206
async def set_task_callback(
142207
self,
@@ -145,6 +210,21 @@ async def set_task_callback(
145210
http_kwargs: dict[str, Any] | None = None,
146211
context: ClientCallContext | None = None,
147212
) -> SetTaskPushNotificationConfigResponse:
213+
"""Sets or updates the push notification configuration for a specific task.
214+
215+
Args:
216+
request: The `SetTaskPushNotificationConfigRequest` object specifying the task ID and configuration.
217+
http_kwargs: Optional dictionary of keyword arguments to pass to the
218+
underlying httpx.post request.
219+
context: The client call context.
220+
221+
Returns:
222+
A `SetTaskPushNotificationConfigResponse` object containing the confirmation or an error.
223+
224+
Raises:
225+
A2AClientHTTPError: If an HTTP error occurs during the request.
226+
A2AClientJSONError: If the response body cannot be decoded as JSON or validated.
227+
"""
148228
if not context and http_kwargs:
149229
context = ClientCallContext(state={'http_kwargs': http_kwargs})
150230
try:
@@ -158,7 +238,7 @@ async def set_task_callback(
158238
)
159239
except A2AClientJSONRPCError as e:
160240
return SetTaskPushNotificationConfigResponse(
161-
root=JSONRPCErrorResponse(error=e.error)
241+
JSONRPCErrorResponse(error=e.error)
162242
)
163243

164244
async def get_task_callback(
@@ -168,11 +248,31 @@ async def get_task_callback(
168248
http_kwargs: dict[str, Any] | None = None,
169249
context: ClientCallContext | None = None,
170250
) -> GetTaskPushNotificationConfigResponse:
251+
"""Retrieves the push notification configuration for a specific task.
252+
253+
Args:
254+
request: The `GetTaskPushNotificationConfigRequest` object specifying the task ID.
255+
http_kwargs: Optional dictionary of keyword arguments to pass to the
256+
underlying httpx.post request.
257+
context: The client call context.
258+
259+
Returns:
260+
A `GetTaskPushNotificationConfigResponse` object containing the configuration or an error.
261+
262+
Raises:
263+
A2AClientHTTPError: If an HTTP error occurs during the request.
264+
A2AClientJSONError: If the response body cannot be decoded as JSON or validated.
265+
"""
171266
if not context and http_kwargs:
172267
context = ClientCallContext(state={'http_kwargs': http_kwargs})
268+
params = request.params
269+
if isinstance(params, TaskIdParams):
270+
params = GetTaskPushNotificationConfigParams(
271+
id=request.params.task_id
272+
)
173273
try:
174274
result = await self._transport.get_task_callback(
175-
request.params, context=context
275+
params, context=context
176276
)
177277
return GetTaskPushNotificationConfigResponse(
178278
root=GetTaskPushNotificationConfigSuccessResponse(
@@ -181,7 +281,7 @@ async def get_task_callback(
181281
)
182282
except A2AClientJSONRPCError as e:
183283
return GetTaskPushNotificationConfigResponse(
184-
root=JSONRPCErrorResponse(error=e.error)
284+
JSONRPCErrorResponse(error=e.error)
185285
)
186286

187287
async def resubscribe(
@@ -191,6 +291,24 @@ async def resubscribe(
191291
http_kwargs: dict[str, Any] | None = None,
192292
context: ClientCallContext | None = None,
193293
) -> AsyncGenerator[SendStreamingMessageResponse, None]:
294+
"""Reconnects to get task updates.
295+
296+
This method uses Server-Sent Events (SSE) to receive a stream of updates from the agent.
297+
298+
Args:
299+
request: The `TaskResubscriptionRequest` object containing the task information to reconnect to.
300+
http_kwargs: Optional dictionary of keyword arguments to pass to the
301+
underlying httpx.post request. A default `timeout=None` is set but can be overridden.
302+
context: The client call context.
303+
304+
Yields:
305+
`SendStreamingMessageResponse` objects as they are received in the SSE stream.
306+
These can be Task, Message, TaskStatusUpdateEvent, or TaskArtifactUpdateEvent.
307+
308+
Raises:
309+
A2AClientHTTPError: If an HTTP or SSE protocol error occurs during the request.
310+
A2AClientJSONError: If an SSE event data cannot be decoded as JSON or validated.
311+
"""
194312
if not context and http_kwargs:
195313
context = ClientCallContext(state={'http_kwargs': http_kwargs})
196314

@@ -209,6 +327,20 @@ async def get_card(
209327
http_kwargs: dict[str, Any] | None = None,
210328
context: ClientCallContext | None = None,
211329
) -> AgentCard:
330+
"""Retrieves the authenticated card (if necessary) or the public one.
331+
332+
Args:
333+
http_kwargs: Optional dictionary of keyword arguments to pass to the
334+
underlying httpx.post request.
335+
context: The client call context.
336+
337+
Returns:
338+
A `AgentCard` object containing the card or an error.
339+
340+
Raises:
341+
A2AClientHTTPError: If an HTTP error occurs during the request.
342+
A2AClientJSONError: If the response body cannot be decoded as JSON or validated.
343+
"""
212344
if not context and http_kwargs:
213345
context = ClientCallContext(state={'http_kwargs': http_kwargs})
214346
return await self._transport.get_card(context=context)

src/a2a/client/transports/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ async def get_card(
9696
*,
9797
context: ClientCallContext | None = None,
9898
) -> AgentCard:
99-
"""Retrieves the agent's card."""
99+
"""Retrieves the AgentCard."""
100100

101101
@abstractmethod
102102
async def close(self) -> None:

tests/client/test_legacy_client.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,26 @@
11
"""Tests for the legacy client compatibility layer."""
22

3-
from unittest.mock import AsyncMock
3+
from unittest.mock import AsyncMock, MagicMock
44

5+
import httpx
56
import pytest
67

7-
from a2a.types import AgentCard
8+
from a2a.client import A2AClient, A2AGrpcClient
9+
from a2a.types import (
10+
AgentCard,
11+
AgentCapabilities,
12+
Message,
13+
Role,
14+
TextPart,
15+
Part,
16+
Task,
17+
TaskStatus,
18+
TaskState,
19+
TaskQueryParams,
20+
SendMessageRequest,
21+
MessageSendParams,
22+
GetTaskRequest,
23+
)
824

925

1026
@pytest.fixture

0 commit comments

Comments
 (0)