Skip to content

Commit f41dc7b

Browse files
committed
Subscribe post.
1 parent ea7d3ad commit f41dc7b

9 files changed

Lines changed: 327 additions & 18 deletions

File tree

src/a2a/client/transports/rest.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,10 +258,11 @@ async def subscribe(
258258
) -> AsyncGenerator[StreamResponse]:
259259
"""Reconnects to get task updates."""
260260
async for event in self._send_stream_request(
261-
'GET',
261+
'POST',
262262
f'/tasks/{request.id}:subscribe',
263263
request.tenant,
264264
context=context,
265+
json=MessageToDict(request),
265266
):
266267
yield event
267268

src/a2a/compat/v0_3/rest_adapter.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,10 @@ def routes(self) -> dict[tuple[str, str], Callable[[Request], Any]]:
163163
self._handle_streaming_request,
164164
self.handler.on_subscribe_to_task,
165165
),
166+
('/v1/tasks/{id}:subscribe', 'POST'): functools.partial(
167+
self._handle_streaming_request,
168+
self.handler.on_subscribe_to_task,
169+
),
166170
('/v1/tasks/{id}', 'GET'): functools.partial(
167171
self._handle_request, self.handler.on_get_task
168172
),

src/a2a/compat/v0_3/rest_transport.py

Lines changed: 50 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import contextlib
12
import json
23
import logging
34

@@ -68,6 +69,8 @@ def __init__(
6869
self.url = url.removesuffix('/')
6970
self.httpx_client = httpx_client
7071
self.agent_card = agent_card
72+
self._subscribe_method = 'POST'
73+
self._subscribe_retry_attempted = False
7174

7275
async def send_message(
7376
self,
@@ -273,13 +276,45 @@ async def subscribe(
273276
*,
274277
context: ClientCallContext | None = None,
275278
) -> AsyncGenerator[StreamResponse]:
276-
"""Reconnects to get task updates."""
277-
async for event in self._send_stream_request(
278-
'GET',
279-
f'/v1/tasks/{request.id}:subscribe',
280-
context=context,
281-
):
282-
yield event
279+
"""Reconnects to get task updates.
280+
281+
This method implements backward compatibility logic for the subscribe
282+
endpoint. It first attempts to use POST, which is the official method
283+
for A2A subscribe endpoint. If the server returns 405 Method Not Allowed,
284+
it falls back to GET and remembers this preference for future calls
285+
on this transport instance. If both fail with 405, it will default back
286+
to POST for next calls but will not retry again.
287+
"""
288+
if self._subscribe_method == 'POST':
289+
json_body = MessageToDict(request, preserving_proto_field_name=True)
290+
else:
291+
json_body = None
292+
293+
try:
294+
async for event in self._send_stream_request(
295+
self._subscribe_method,
296+
f'/v1/tasks/{request.id}:subscribe',
297+
context=context,
298+
json=json_body,
299+
):
300+
yield event
301+
except A2AClientError as e:
302+
# Check for 405 Method Not Allowed in the cause (httpx.HTTPStatusError)
303+
cause = e.__cause__
304+
if (
305+
isinstance(cause, httpx.HTTPStatusError)
306+
and cause.response.status_code == httpx.codes.METHOD_NOT_ALLOWED
307+
):
308+
if self._subscribe_retry_attempted:
309+
self._subscribe_method = 'POST'
310+
raise
311+
else:
312+
self._subscribe_method = 'GET'
313+
self._subscribe_retry_attempted = True
314+
async for event in self.subscribe(request, context=context):
315+
yield event
316+
else:
317+
raise
283318

284319
async def get_extended_agent_card(
285320
self,
@@ -311,7 +346,14 @@ async def close(self) -> None:
311346
def _handle_http_error(self, e: httpx.HTTPStatusError) -> NoReturn:
312347
"""Handles HTTP status errors and raises the appropriate A2AError."""
313348
try:
314-
error_data = e.response.json()
349+
with contextlib.suppress(httpx.StreamClosed):
350+
e.response.read()
351+
352+
try:
353+
error_data = e.response.json()
354+
except (json.JSONDecodeError, ValueError, httpx.ResponseNotRead):
355+
error_data = {}
356+
315357
error_type = error_data.get('type')
316358
message = error_data.get('message', str(e))
317359

src/a2a/server/apps/rest/rest_adapter.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,10 @@ def routes(self) -> dict[tuple[str, str], Callable[[Request], Any]]:
237237
self._handle_streaming_request,
238238
self.handler.on_subscribe_to_task,
239239
),
240+
('/tasks/{id}:subscribe', 'POST'): functools.partial(
241+
self._handle_streaming_request,
242+
self.handler.on_subscribe_to_task,
243+
),
240244
('/tasks/{id}', 'GET'): functools.partial(
241245
self._handle_request, self.handler.on_get_task
242246
),

src/a2a/server/request_handlers/rest_handler.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,9 +159,15 @@ async def on_subscribe_to_task(
159159
Yields:
160160
JSON serialized objects containing streaming events
161161
"""
162-
task_id = request.path_params['id']
162+
params = SubscribeToTaskRequest()
163+
if request.method == 'POST':
164+
body = await request.body()
165+
if body:
166+
Parse(body, params)
167+
168+
params.id = request.path_params['id']
163169
async for event in self.request_handler.on_subscribe_to_task(
164-
SubscribeToTaskRequest(id=task_id), context
170+
params, context
165171
):
166172
yield MessageToDict(proto_utils.to_stream_response(event))
167173

tests/client/transports/test_rest_client.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -730,8 +730,12 @@ async def empty_aiter():
730730
async for _ in method(request=request_obj):
731731
pass
732732

733-
# 4. Verify the URL
733+
# 4. Verify the URL and method
734734
mock_aconnect_sse.assert_called_once()
735-
args, _ = mock_aconnect_sse.call_args
735+
args, kwargs = mock_aconnect_sse.call_args
736+
# method is 2nd positional argument
737+
assert args[1] == 'POST'
738+
assert kwargs.get('json') == json_format.MessageToDict(request_obj)
739+
736740
# url is 3rd positional argument in aconnect_sse(client, method, url, ...)
737741
assert args[2] == f'http://agent.example.com/api{expected_path}'

tests/compat/v0_3/test_rest_handler.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,44 @@ async def mock_stream(*args, **kwargs):
186186
]
187187

188188

189+
@pytest.mark.anyio
190+
async def test_on_subscribe_to_task_post(
191+
rest_handler, mock_request, mock_context
192+
):
193+
mock_request.path_params = {'id': 'task-1'}
194+
mock_request.method = 'POST'
195+
request_body = {'name': 'tasks/task-1'}
196+
mock_request.body = AsyncMock(
197+
return_value=json.dumps(request_body).encode('utf-8')
198+
)
199+
200+
async def mock_stream(*args, **kwargs):
201+
yield types_v03.SendStreamingMessageSuccessResponse(
202+
id='req-1',
203+
result=types_v03.Message(
204+
message_id='msg-2',
205+
role='agent',
206+
parts=[types_v03.TextPart(text='Update')],
207+
),
208+
)
209+
210+
rest_handler.handler03.on_subscribe_to_task = MagicMock(
211+
side_effect=mock_stream
212+
)
213+
214+
results = [
215+
chunk
216+
async for chunk in rest_handler.on_subscribe_to_task(
217+
mock_request, mock_context
218+
)
219+
]
220+
221+
assert len(results) == 1
222+
rest_handler.handler03.on_subscribe_to_task.assert_called_once()
223+
called_req = rest_handler.handler03.on_subscribe_to_task.call_args[0][0]
224+
assert called_req.params.id == 'task-1'
225+
226+
189227
@pytest.mark.anyio
190228
async def test_get_push_notification(rest_handler, mock_request, mock_context):
191229
mock_request.path_params = {'id': 'task-1', 'push_id': 'push-1'}

tests/compat/v0_3/test_rest_transport.py

Lines changed: 142 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import json
2+
23
from unittest.mock import AsyncMock, MagicMock, patch
34

45
import httpx
@@ -232,14 +233,49 @@ async def mock_send_stream_request(*args, **kwargs):
232233
assert events[1] == StreamResponse(message=Message(message_id='msg-123'))
233234

234235

236+
def create_405_error():
237+
mock_response = MagicMock(spec=httpx.Response)
238+
mock_response.status_code = 405
239+
mock_response.json.return_value = {
240+
'type': 'MethodNotAllowed',
241+
'message': 'Method Not Allowed',
242+
}
243+
mock_request = MagicMock(spec=httpx.Request)
244+
mock_request.url = 'http://example.com/v1/tasks/task-123:subscribe'
245+
246+
status_error = httpx.HTTPStatusError(
247+
'405 Method Not Allowed', request=mock_request, response=mock_response
248+
)
249+
raise A2AClientError('HTTP Error 405') from status_error
250+
251+
252+
def create_500_error():
253+
mock_response = MagicMock(spec=httpx.Response)
254+
mock_response.status_code = 500
255+
mock_response.json.return_value = {
256+
'type': 'InternalError',
257+
'message': 'Internal Error',
258+
}
259+
mock_request = MagicMock(spec=httpx.Request)
260+
261+
status_error = httpx.HTTPStatusError(
262+
'500 Internal Error', request=mock_request, response=mock_response
263+
)
264+
raise A2AClientError('HTTP Error 500') from status_error
265+
266+
235267
@pytest.mark.asyncio
236-
async def test_compat_rest_transport_subscribe(transport):
237-
async def mock_send_stream_request(*args, **kwargs):
268+
async def test_compat_rest_transport_subscribe_post_works_no_retry(transport):
269+
"""Scenario: POST works, no retry."""
270+
271+
async def mock_stream(method, path, context=None, json=None):
272+
assert method == 'POST'
273+
assert json == {'id': 'task-123'}
238274
task = Task(id='task-123')
239275
task.status.message.role = Role.ROLE_AGENT
240276
yield StreamResponse(task=task)
241277

242-
transport._send_stream_request = mock_send_stream_request
278+
transport._send_stream_request = mock_stream
243279

244280
req = SubscribeToTaskRequest(id='task-123')
245281
events = [event async for event in transport.subscribe(req)]
@@ -248,6 +284,109 @@ async def mock_send_stream_request(*args, **kwargs):
248284
expected_task = Task(id='task-123')
249285
expected_task.status.message.role = Role.ROLE_AGENT
250286
assert events[0] == StreamResponse(task=expected_task)
287+
assert transport._subscribe_method == 'POST'
288+
assert transport._subscribe_retry_attempted is False
289+
290+
291+
@pytest.mark.asyncio
292+
async def test_compat_rest_transport_subscribe_post_405_retry_get_success(
293+
transport,
294+
):
295+
"""Scenario: POST returns 405, automatic retry GET. Second call uses GET directly."""
296+
call_count = 0
297+
298+
async def mock_stream(method, path, context=None, json=None):
299+
nonlocal call_count
300+
call_count += 1
301+
if method == 'POST':
302+
assert json == {'id': 'task-123'}
303+
create_405_error()
304+
if method == 'GET':
305+
assert json is None
306+
task = Task(id='task-123')
307+
task.status.message.role = Role.ROLE_AGENT
308+
yield StreamResponse(task=task)
309+
310+
transport._send_stream_request = mock_stream
311+
312+
req = SubscribeToTaskRequest(id='task-123')
313+
events = [event async for event in transport.subscribe(req)]
314+
315+
assert len(events) == 1
316+
assert call_count == 2
317+
assert transport._subscribe_method == 'GET'
318+
assert transport._subscribe_retry_attempted is True
319+
320+
# Second call should use GET directly
321+
call_count = 0
322+
events = [event async for event in transport.subscribe(req)]
323+
assert len(events) == 1
324+
assert call_count == 1 # Only GET called
325+
assert transport._subscribe_method == 'GET'
326+
327+
328+
@pytest.mark.asyncio
329+
async def test_compat_rest_transport_subscribe_post_405_get_405_fails(
330+
transport,
331+
):
332+
"""Scenario: POST return 405, retry GET, return 405 - error. Second call is just POST."""
333+
call_count = 0
334+
335+
async def mock_stream(method, path, context=None, json=None):
336+
nonlocal call_count
337+
call_count += 1
338+
if method == 'POST':
339+
assert json == {'id': 'task-123'}
340+
elif method == 'GET':
341+
assert json is None
342+
# To make it an async generator even when it raises
343+
if False:
344+
yield
345+
create_405_error()
346+
347+
transport._send_stream_request = mock_stream
348+
349+
req = SubscribeToTaskRequest(id='task-123')
350+
with pytest.raises(A2AClientError) as exc_info:
351+
[event async for event in transport.subscribe(req)]
352+
353+
assert '405' in str(exc_info.value)
354+
assert call_count == 2 # Tried POST then GET
355+
assert transport._subscribe_method == 'POST'
356+
assert transport._subscribe_retry_attempted is True
357+
358+
# Second call should try POST directly and fail without retry
359+
call_count = 0
360+
with pytest.raises(A2AClientError):
361+
[event async for event in transport.subscribe(req)]
362+
assert call_count == 1
363+
assert transport._subscribe_method == 'POST'
364+
365+
366+
@pytest.mark.asyncio
367+
async def test_compat_rest_transport_subscribe_post_500_no_retry(transport):
368+
"""Scenario: POST return 500, no automatic retry."""
369+
call_count = 0
370+
371+
async def mock_stream(method, path, context=None, json=None):
372+
nonlocal call_count
373+
call_count += 1
374+
assert method == 'POST'
375+
assert json == {'id': 'task-123'}
376+
if False:
377+
yield
378+
create_500_error()
379+
380+
transport._send_stream_request = mock_stream
381+
382+
req = SubscribeToTaskRequest(id='task-123')
383+
with pytest.raises(A2AClientError) as exc_info:
384+
[event async for event in transport.subscribe(req)]
385+
386+
assert '500' in str(exc_info.value)
387+
assert call_count == 1 # No retry on 500
388+
assert transport._subscribe_method == 'POST'
389+
assert transport._subscribe_retry_attempted is False
251390

252391

253392
def test_compat_rest_transport_handle_http_error(transport):

0 commit comments

Comments
 (0)