Skip to content

Commit c842e3e

Browse files
committed
Subscribe post fix v2.
1 parent f41dc7b commit c842e3e

6 files changed

Lines changed: 99 additions & 48 deletions

File tree

src/a2a/client/transports/rest.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,6 @@ async def subscribe(
262262
f'/tasks/{request.id}:subscribe',
263263
request.tenant,
264264
context=context,
265-
json=MessageToDict(request),
266265
):
267266
yield event
268267

src/a2a/compat/v0_3/rest_transport.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -64,13 +64,14 @@ def __init__(
6464
httpx_client: httpx.AsyncClient,
6565
agent_card: AgentCard | None,
6666
url: str,
67+
subscribe_method_override: str | None = None,
6768
):
6869
"""Initializes the CompatRestTransport."""
6970
self.url = url.removesuffix('/')
7071
self.httpx_client = httpx_client
7172
self.agent_card = agent_card
72-
self._subscribe_method = 'POST'
73-
self._subscribe_retry_attempted = False
73+
self._subscribe_method_override = subscribe_method_override
74+
self._subscribe_auto_method_override = subscribe_method_override is None
7475

7576
async def send_message(
7677
self,
@@ -285,17 +286,12 @@ async def subscribe(
285286
on this transport instance. If both fail with 405, it will default back
286287
to POST for next calls but will not retry again.
287288
"""
288-
if self._subscribe_method == 'POST':
289-
json_body = MessageToDict(request, preserving_proto_field_name=True)
290-
else:
291-
json_body = None
292-
289+
subscribe_method = self._subscribe_method_override or 'POST'
293290
try:
294291
async for event in self._send_stream_request(
295-
self._subscribe_method,
292+
subscribe_method,
296293
f'/v1/tasks/{request.id}:subscribe',
297294
context=context,
298-
json=json_body,
299295
):
300296
yield event
301297
except A2AClientError as e:
@@ -305,12 +301,13 @@ async def subscribe(
305301
isinstance(cause, httpx.HTTPStatusError)
306302
and cause.response.status_code == httpx.codes.METHOD_NOT_ALLOWED
307303
):
308-
if self._subscribe_retry_attempted:
309-
self._subscribe_method = 'POST'
304+
if self._subscribe_method_override:
305+
if self._subscribe_auto_method_override:
306+
self._subscribe_auto_method_override = False
307+
self._subscribe_method_override = 'POST'
310308
raise
311309
else:
312-
self._subscribe_method = 'GET'
313-
self._subscribe_retry_attempted = True
310+
self._subscribe_method_override = 'GET'
314311
async for event in self.subscribe(request, context=context):
315312
yield event
316313
else:

src/a2a/server/request_handlers/rest_handler.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -159,15 +159,9 @@ async def on_subscribe_to_task(
159159
Yields:
160160
JSON serialized objects containing streaming events
161161
"""
162-
params = SubscribeToTaskRequest()
163-
if request.method == 'POST':
164-
body = await request.body()
165-
if body:
166-
Parse(body, params)
167-
168-
params.id = request.path_params['id']
162+
task_id = request.path_params['id']
169163
async for event in self.request_handler.on_subscribe_to_task(
170-
params, context
164+
SubscribeToTaskRequest(id=task_id), context
171165
):
172166
yield MessageToDict(proto_utils.to_stream_response(event))
173167

tests/client/transports/test_rest_client.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -735,7 +735,10 @@ async def empty_aiter():
735735
args, kwargs = mock_aconnect_sse.call_args
736736
# method is 2nd positional argument
737737
assert args[1] == 'POST'
738-
assert kwargs.get('json') == json_format.MessageToDict(request_obj)
738+
if method_name == 'subscribe':
739+
assert kwargs.get('json') is None
740+
else:
741+
assert kwargs.get('json') == json_format.MessageToDict(request_obj)
739742

740743
# url is 3rd positional argument in aconnect_sse(client, method, url, ...)
741744
assert args[2] == f'http://agent.example.com/api{expected_path}'

tests/compat/v0_3/test_rest_transport.py

Lines changed: 81 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ async def test_compat_rest_transport_subscribe_post_works_no_retry(transport):
270270

271271
async def mock_stream(method, path, context=None, json=None):
272272
assert method == 'POST'
273-
assert json == {'id': 'task-123'}
273+
assert json is None
274274
task = Task(id='task-123')
275275
task.status.message.role = Role.ROLE_AGENT
276276
yield StreamResponse(task=task)
@@ -284,8 +284,7 @@ async def mock_stream(method, path, context=None, json=None):
284284
expected_task = Task(id='task-123')
285285
expected_task.status.message.role = Role.ROLE_AGENT
286286
assert events[0] == StreamResponse(task=expected_task)
287-
assert transport._subscribe_method == 'POST'
288-
assert transport._subscribe_retry_attempted is False
287+
assert transport._subscribe_method_override is None
289288

290289

291290
@pytest.mark.asyncio
@@ -299,7 +298,7 @@ async def mock_stream(method, path, context=None, json=None):
299298
nonlocal call_count
300299
call_count += 1
301300
if method == 'POST':
302-
assert json == {'id': 'task-123'}
301+
assert json is None
303302
create_405_error()
304303
if method == 'GET':
305304
assert json is None
@@ -314,29 +313,28 @@ async def mock_stream(method, path, context=None, json=None):
314313

315314
assert len(events) == 1
316315
assert call_count == 2
317-
assert transport._subscribe_method == 'GET'
318-
assert transport._subscribe_retry_attempted is True
316+
assert transport._subscribe_method_override == 'GET'
319317

320318
# Second call should use GET directly
321319
call_count = 0
322320
events = [event async for event in transport.subscribe(req)]
323321
assert len(events) == 1
324322
assert call_count == 1 # Only GET called
325-
assert transport._subscribe_method == 'GET'
323+
assert transport._subscribe_method_override == 'GET'
326324

327325

328326
@pytest.mark.asyncio
329327
async def test_compat_rest_transport_subscribe_post_405_get_405_fails(
330328
transport,
331329
):
332330
"""Scenario: POST return 405, retry GET, return 405 - error. Second call is just POST."""
333-
call_count = 0
331+
332+
method_count = {}
334333

335334
async def mock_stream(method, path, context=None, json=None):
336-
nonlocal call_count
337-
call_count += 1
335+
method_count[method] = method_count.get(method, 0) + 1
338336
if method == 'POST':
339-
assert json == {'id': 'task-123'}
337+
assert json is None
340338
elif method == 'GET':
341339
assert json is None
342340
# To make it an async generator even when it raises
@@ -351,16 +349,16 @@ async def mock_stream(method, path, context=None, json=None):
351349
[event async for event in transport.subscribe(req)]
352350

353351
assert '405' in str(exc_info.value)
354-
assert call_count == 2 # Tried POST then GET
355-
assert transport._subscribe_method == 'POST'
356-
assert transport._subscribe_retry_attempted is True
352+
assert transport._subscribe_method_override == 'POST'
353+
assert method_count == {'POST': 1, 'GET': 1}
354+
assert transport._subscribe_auto_method_override is False
357355

358356
# Second call should try POST directly and fail without retry
359-
call_count = 0
360357
with pytest.raises(A2AClientError):
361358
[event async for event in transport.subscribe(req)]
362-
assert call_count == 1
363-
assert transport._subscribe_method == 'POST'
359+
assert transport._subscribe_auto_method_override is False
360+
assert transport._subscribe_method_override == 'POST'
361+
assert method_count == {'POST': 2, 'GET': 1}
364362

365363

366364
@pytest.mark.asyncio
@@ -372,7 +370,7 @@ async def mock_stream(method, path, context=None, json=None):
372370
nonlocal call_count
373371
call_count += 1
374372
assert method == 'POST'
375-
assert json == {'id': 'task-123'}
373+
assert json is None
376374
if False:
377375
yield
378376
create_500_error()
@@ -385,8 +383,71 @@ async def mock_stream(method, path, context=None, json=None):
385383

386384
assert '500' in str(exc_info.value)
387385
assert call_count == 1 # No retry on 500
388-
assert transport._subscribe_method == 'POST'
389-
assert transport._subscribe_retry_attempted is False
386+
assert transport._subscribe_method_override is None
387+
388+
389+
@pytest.mark.asyncio
390+
async def test_compat_rest_transport_subscribe_method_override_avoids_retry_get(
391+
mock_httpx_client, agent_card
392+
):
393+
"""Scenario: Init with GET override, server returns 405, no automatic retry."""
394+
transport = CompatRestTransport(
395+
httpx_client=mock_httpx_client,
396+
agent_card=agent_card,
397+
url='http://example.com',
398+
subscribe_method_override='GET',
399+
)
400+
call_count = 0
401+
402+
async def mock_stream(method, path, context=None, json=None):
403+
nonlocal call_count
404+
call_count += 1
405+
assert method == 'GET'
406+
assert json is None
407+
if False:
408+
yield
409+
create_405_error()
410+
411+
transport._send_stream_request = mock_stream
412+
413+
req = SubscribeToTaskRequest(id='task-123')
414+
with pytest.raises(A2AClientError) as exc_info:
415+
[event async for event in transport.subscribe(req)]
416+
417+
assert '405' in str(exc_info.value)
418+
assert call_count == 1
419+
420+
421+
@pytest.mark.asyncio
422+
async def test_compat_rest_transport_subscribe_method_override_avoids_retry_post(
423+
mock_httpx_client, agent_card
424+
):
425+
"""Scenario: Init with POST override, server returns 405, no automatic retry."""
426+
transport = CompatRestTransport(
427+
httpx_client=mock_httpx_client,
428+
agent_card=agent_card,
429+
url='http://example.com',
430+
subscribe_method_override='POST',
431+
)
432+
call_count = 0
433+
434+
async def mock_stream(method, path, context=None, json=None):
435+
nonlocal call_count
436+
call_count += 1
437+
assert method == 'POST'
438+
assert json is None
439+
if False:
440+
yield
441+
create_405_error()
442+
443+
transport._send_stream_request = mock_stream
444+
445+
req = SubscribeToTaskRequest(id='task-123')
446+
with pytest.raises(A2AClientError) as exc_info:
447+
[event async for event in transport.subscribe(req)]
448+
449+
assert '405' in str(exc_info.value)
450+
assert call_count == 1
390451

391452

392453
def test_compat_rest_transport_handle_http_error(transport):

tests/server/apps/rest/test_rest_fastapi_app.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -438,7 +438,7 @@ async def mock_stream_response():
438438
async def test_subscribe_to_task_post(
439439
streaming_client: AsyncClient, request_handler: MagicMock
440440
) -> None:
441-
"""Test that POST /tasks/{id}:subscribe works and parses body."""
441+
"""Test that POST /tasks/{id}:subscribe works."""
442442

443443
async def mock_stream_response():
444444
yield Task(
@@ -449,11 +449,8 @@ async def mock_stream_response():
449449

450450
request_handler.on_subscribe_to_task.return_value = mock_stream_response()
451451

452-
request = a2a_pb2.SubscribeToTaskRequest(id='task-1')
453-
454452
response = await streaming_client.post(
455453
'/tasks/task-1:subscribe',
456-
json=json_format.MessageToDict(request),
457454
headers={'Accept': 'text/event-stream'},
458455
)
459456

@@ -595,7 +592,7 @@ def extended_card_modifier(self) -> MagicMock:
595592
('/message:send', 'POST', 'on_message_send', {'message': {}}),
596593
('/tasks/1:cancel', 'POST', 'on_cancel_task', None),
597594
('/tasks/1:subscribe', 'GET', 'on_subscribe_to_task', None),
598-
('/tasks/1:subscribe', 'POST', 'on_subscribe_to_task', {'id': '1'}),
595+
('/tasks/1:subscribe', 'POST', 'on_subscribe_to_task', None),
599596
('/tasks/1', 'GET', 'on_get_task', None),
600597
('/tasks', 'GET', 'on_list_tasks', None),
601598
(

0 commit comments

Comments
 (0)