Skip to content

Commit 40ddcb9

Browse files
committed
Subscribe post fix v2.
1 parent 3e36c3c commit 40ddcb9

5 files changed

Lines changed: 9 additions & 15 deletions

File tree

src/a2a/client/transports/rest.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,6 @@ async def subscribe(
262262
f'/tasks/{request.id}:subscribe',
263263
request.tenant,
264264
context=context,
265-
json=MessageToDict(request),
266265
):
267266
yield event
268267

src/a2a/server/request_handlers/rest_handler.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -159,15 +159,9 @@ async def on_subscribe_to_task(
159159
Yields:
160160
JSON serialized objects containing streaming events
161161
"""
162-
params = SubscribeToTaskRequest()
163-
if request.method == 'POST':
164-
body = await request.body()
165-
if body:
166-
Parse(body, params)
167-
168-
params.id = request.path_params['id']
162+
task_id = request.path_params['id']
169163
async for event in self.request_handler.on_subscribe_to_task(
170-
params, context
164+
SubscribeToTaskRequest(id=task_id), context
171165
):
172166
yield MessageToDict(proto_utils.to_stream_response(event))
173167

tests/client/transports/test_rest_client.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -732,7 +732,10 @@ async def empty_aiter():
732732
args, kwargs = mock_aconnect_sse.call_args
733733
# method is 2nd positional argument
734734
assert args[1] == 'POST'
735-
assert kwargs.get('json') == json_format.MessageToDict(request_obj)
735+
if method_name == 'subscribe':
736+
assert kwargs.get('json') is None
737+
else:
738+
assert kwargs.get('json') == json_format.MessageToDict(request_obj)
736739

737740
# url is 3rd positional argument in aconnect_sse(client, method, url, ...)
738741
assert args[2] == f'http://agent.example.com/api{expected_path}'

tests/compat/v0_3/test_rest_transport.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,7 @@ async def test_compat_rest_transport_subscribe_post_405_get_405_fails(
337337

338338
method_count = {}
339339

340+
method_count = {}
340341
async def mock_stream(method, path, context=None, json=None):
341342
method_count[method] = method_count.get(method, 0) + 1
342343
if method == 'POST':

tests/server/apps/rest/test_rest_fastapi_app.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -438,7 +438,7 @@ async def mock_stream_response():
438438
async def test_subscribe_to_task_post(
439439
streaming_client: AsyncClient, request_handler: MagicMock
440440
) -> None:
441-
"""Test that POST /tasks/{id}:subscribe works and parses body."""
441+
"""Test that POST /tasks/{id}:subscribe works."""
442442

443443
async def mock_stream_response():
444444
yield Task(
@@ -449,11 +449,8 @@ async def mock_stream_response():
449449

450450
request_handler.on_subscribe_to_task.return_value = mock_stream_response()
451451

452-
request = a2a_pb2.SubscribeToTaskRequest(id='task-1')
453-
454452
response = await streaming_client.post(
455453
'/tasks/task-1:subscribe',
456-
json=json_format.MessageToDict(request),
457454
headers={'Accept': 'text/event-stream'},
458455
)
459456

@@ -595,7 +592,7 @@ def extended_card_modifier(self) -> MagicMock:
595592
('/message:send', 'POST', 'on_message_send', {'message': {}}),
596593
('/tasks/1:cancel', 'POST', 'on_cancel_task', None),
597594
('/tasks/1:subscribe', 'GET', 'on_subscribe_to_task', None),
598-
('/tasks/1:subscribe', 'POST', 'on_subscribe_to_task', {'id': '1'}),
595+
('/tasks/1:subscribe', 'POST', 'on_subscribe_to_task', None),
599596
('/tasks/1', 'GET', 'on_get_task', None),
600597
('/tasks', 'GET', 'on_list_tasks', None),
601598
(

0 commit comments

Comments
 (0)