Skip to content

Commit 3e36c3c

Browse files
committed
Subscribe post.
1 parent 0e583f5 commit 3e36c3c

7 files changed

Lines changed: 40 additions & 22 deletions

File tree

src/a2a/client/transports/rest.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,7 @@ async def subscribe(
262262
f'/tasks/{request.id}:subscribe',
263263
request.tenant,
264264
context=context,
265+
json=MessageToDict(request),
265266
):
266267
yield event
267268

src/a2a/compat/v0_3/rest_transport.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,11 +287,17 @@ async def subscribe(
287287
to POST for next calls but will not retry again.
288288
"""
289289
subscribe_method = self._subscribe_method_override or 'POST'
290+
if subscribe_method == 'POST':
291+
json_body = MessageToDict(request, preserving_proto_field_name=True)
292+
else:
293+
json_body = None
294+
290295
try:
291296
async for event in self._send_stream_request(
292297
subscribe_method,
293298
f'/v1/tasks/{request.id}:subscribe',
294299
context=context,
300+
json=json_body,
295301
):
296302
yield event
297303
except A2AClientError as e:

src/a2a/server/request_handlers/rest_handler.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,9 +159,15 @@ async def on_subscribe_to_task(
159159
Yields:
160160
JSON serialized objects containing streaming events
161161
"""
162-
task_id = request.path_params['id']
162+
params = SubscribeToTaskRequest()
163+
if request.method == 'POST':
164+
body = await request.body()
165+
if body:
166+
Parse(body, params)
167+
168+
params.id = request.path_params['id']
163169
async for event in self.request_handler.on_subscribe_to_task(
164-
SubscribeToTaskRequest(id=task_id), context
170+
params, context
165171
):
166172
yield MessageToDict(proto_utils.to_stream_response(event))
167173

tests/client/transports/__init__.py

Whitespace-only changes.

tests/client/transports/test_rest_client.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from httpx_sse import EventSource, ServerSentEvent
1010

1111
from a2a.client import create_text_message_object
12+
from a2a.client.client import ClientCallContext
1213
from a2a.client.errors import A2AClientError
1314
from a2a.client.transports.rest import RestTransport
1415
from a2a.extensions.common import HTTP_EXTENSION_HEADER
@@ -162,7 +163,6 @@ async def test_send_message_with_timeout_context(
162163
self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock
163164
):
164165
"""Test that send_message passes context timeout to build_request."""
165-
from a2a.client.client import ClientCallContext
166166

167167
client = RestTransport(
168168
httpx_client=mock_httpx_client,
@@ -258,7 +258,6 @@ async def test_send_message_with_default_extensions(
258258
mock_response.status_code = 200
259259
mock_httpx_client.send.return_value = mock_response
260260

261-
from a2a.client.client import ClientCallContext
262261

263262
context = ClientCallContext(
264263
service_parameters={
@@ -302,7 +301,6 @@ async def test_send_message_streaming_with_new_extensions(
302301
mock_event_source
303302
)
304303

305-
from a2a.client.client import ClientCallContext
306304

307305
context = ClientCallContext(
308306
service_parameters={
@@ -404,7 +402,6 @@ async def test_get_card_with_extended_card_support_with_extensions(
404402

405403
request = GetExtendedAgentCardRequest()
406404

407-
from a2a.client.client import ClientCallContext
408405

409406
context = ClientCallContext(
410407
service_parameters={HTTP_EXTENSION_HEADER: extensions_str}
@@ -419,7 +416,6 @@ async def test_get_card_with_extended_card_support_with_extensions(
419416
await client.get_extended_agent_card(request, context=context)
420417

421418
mock_execute_request.assert_called_once()
422-
# _execute_request(method, target, tenant, context)
423419
call_args = mock_execute_request.call_args
424420
assert (
425421
call_args[1].get('context') == context or call_args[0][3] == context
@@ -694,7 +690,8 @@ async def test_rest_get_task_prepend_empty_tenant(
694690
)
695691
@pytest.mark.asyncio
696692
@patch('a2a.client.transports.http_helpers.aconnect_sse')
697-
async def test_rest_streaming_methods_prepend_tenant(
693+
async def test_rest_streaming_methods_prepend_tenant( # noqa: PLR0913
694+
698695
self,
699696
mock_aconnect_sse,
700697
method_name,
@@ -735,10 +732,7 @@ async def empty_aiter():
735732
args, kwargs = mock_aconnect_sse.call_args
736733
# method is 2nd positional argument
737734
assert args[1] == 'POST'
738-
if method_name == 'subscribe':
739-
assert kwargs.get('json') is None
740-
else:
741-
assert kwargs.get('json') == json_format.MessageToDict(request_obj)
735+
assert kwargs.get('json') == json_format.MessageToDict(request_obj)
742736

743737
# url is 3rd positional argument in aconnect_sse(client, method, url, ...)
744738
assert args[2] == f'http://agent.example.com/api{expected_path}'

tests/compat/v0_3/test_rest_transport.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,10 @@ async def test_compat_rest_transport_subscribe_post_works_no_retry(transport):
270270

271271
async def mock_stream(method, path, context=None, json=None):
272272
assert method == 'POST'
273-
assert json is None
273+
if method == 'POST':
274+
assert json == {'id': 'task-123'}
275+
else:
276+
assert json is None
274277
task = Task(id='task-123')
275278
task.status.message.role = Role.ROLE_AGENT
276279
yield StreamResponse(task=task)
@@ -298,7 +301,10 @@ async def mock_stream(method, path, context=None, json=None):
298301
nonlocal call_count
299302
call_count += 1
300303
if method == 'POST':
301-
assert json is None
304+
if method == 'POST':
305+
assert json == {'id': 'task-123'}
306+
else:
307+
assert json is None
302308
create_405_error()
303309
if method == 'GET':
304310
assert json is None
@@ -334,8 +340,8 @@ async def test_compat_rest_transport_subscribe_post_405_get_405_fails(
334340
async def mock_stream(method, path, context=None, json=None):
335341
method_count[method] = method_count.get(method, 0) + 1
336342
if method == 'POST':
337-
assert json is None
338-
elif method == 'GET':
343+
assert json == {'id': 'task-123'}
344+
else:
339345
assert json is None
340346
# To make it an async generator even when it raises
341347
if False:
@@ -370,7 +376,10 @@ async def mock_stream(method, path, context=None, json=None):
370376
nonlocal call_count
371377
call_count += 1
372378
assert method == 'POST'
373-
assert json is None
379+
if method == 'POST':
380+
assert json == {'id': 'task-123'}
381+
else:
382+
assert json is None
374383
if False:
375384
yield
376385
create_500_error()
@@ -435,11 +444,10 @@ async def mock_stream(method, path, context=None, json=None):
435444
nonlocal call_count
436445
call_count += 1
437446
assert method == 'POST'
438-
assert json is None
447+
assert json == {'id': 'task-123'}
439448
if False:
440449
yield
441450
create_405_error()
442-
443451
transport._send_stream_request = mock_stream
444452

445453
req = SubscribeToTaskRequest(id='task-123')

tests/server/apps/rest/test_rest_fastapi_app.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
import logging
21
import json
2+
import logging
33

44
from typing import Any
55
from unittest.mock import MagicMock
@@ -438,7 +438,7 @@ async def mock_stream_response():
438438
async def test_subscribe_to_task_post(
439439
streaming_client: AsyncClient, request_handler: MagicMock
440440
) -> None:
441-
"""Test that POST /tasks/{id}:subscribe works."""
441+
"""Test that POST /tasks/{id}:subscribe works and parses body."""
442442

443443
async def mock_stream_response():
444444
yield Task(
@@ -449,8 +449,11 @@ async def mock_stream_response():
449449

450450
request_handler.on_subscribe_to_task.return_value = mock_stream_response()
451451

452+
request = a2a_pb2.SubscribeToTaskRequest(id='task-1')
453+
452454
response = await streaming_client.post(
453455
'/tasks/task-1:subscribe',
456+
json=json_format.MessageToDict(request),
454457
headers={'Accept': 'text/event-stream'},
455458
)
456459

@@ -592,7 +595,7 @@ def extended_card_modifier(self) -> MagicMock:
592595
('/message:send', 'POST', 'on_message_send', {'message': {}}),
593596
('/tasks/1:cancel', 'POST', 'on_cancel_task', None),
594597
('/tasks/1:subscribe', 'GET', 'on_subscribe_to_task', None),
595-
('/tasks/1:subscribe', 'POST', 'on_subscribe_to_task', None),
598+
('/tasks/1:subscribe', 'POST', 'on_subscribe_to_task', {'id': '1'}),
596599
('/tasks/1', 'GET', 'on_get_task', None),
597600
('/tasks', 'GET', 'on_list_tasks', None),
598601
(

0 commit comments

Comments
 (0)