From 542e6687f756fee67fae65a533ffab0109324621 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Wed, 25 Mar 2026 12:05:54 +0000 Subject: [PATCH 1/7] wip --- src/a2a/server/request_handlers/__init__.py | 2 - src/a2a/server/routes/rest_dispatcher.py | 292 ++++++++++++++++++++ src/a2a/server/routes/rest_routes.py | 172 +++--------- src/a2a/utils/helpers.py | 47 ++-- 4 files changed, 364 insertions(+), 149 deletions(-) create mode 100644 src/a2a/server/routes/rest_dispatcher.py diff --git a/src/a2a/server/request_handlers/__init__.py b/src/a2a/server/request_handlers/__init__.py index 43ebc8e25..9882dc2af 100644 --- a/src/a2a/server/request_handlers/__init__.py +++ b/src/a2a/server/request_handlers/__init__.py @@ -11,7 +11,6 @@ build_error_response, prepare_response_object, ) -from a2a.server.request_handlers.rest_handler import RESTHandler logger = logging.getLogger(__name__) @@ -41,7 +40,6 @@ def __init__(self, *args, **kwargs): 'DefaultRequestHandler', 'GrpcHandler', 'JSONRPCHandler', - 'RESTHandler', 'RequestHandler', 'build_error_response', 'prepare_response_object', diff --git a/src/a2a/server/routes/rest_dispatcher.py b/src/a2a/server/routes/rest_dispatcher.py new file mode 100644 index 000000000..b4b192543 --- /dev/null +++ b/src/a2a/server/routes/rest_dispatcher.py @@ -0,0 +1,292 @@ +import json +import logging +from collections.abc import AsyncIterator, Awaitable, Callable +from typing import TYPE_CHECKING, Any + +from google.protobuf.json_format import MessageToDict, Parse + +from a2a.server.context import ServerCallContext +from a2a.server.request_handlers.request_handler import RequestHandler +from a2a.server.routes import CallContextBuilder, DefaultCallContextBuilder +from a2a.types import a2a_pb2 +from a2a.types.a2a_pb2 import ( + AgentCard, + CancelTaskRequest, + GetTaskPushNotificationConfigRequest, + SubscribeToTaskRequest, +) +from a2a.utils import constants, proto_utils +from a2a.utils.error_handlers import ( + rest_error_handler, + rest_stream_error_handler, +) +from a2a.utils.errors import ( + ExtendedAgentCardNotConfiguredError, + InvalidRequestError, + TaskNotFoundError, +) +from a2a.utils.helpers import maybe_await, validate, validate_version +from a2a.utils.telemetry import SpanKind, trace_class + + +if TYPE_CHECKING: + from sse_starlette.sse import EventSourceResponse + from starlette.requests import Request + from starlette.responses import JSONResponse, Response + + _package_starlette_installed = True +else: + try: + from sse_starlette.sse import EventSourceResponse + from starlette.requests import Request + from starlette.responses import JSONResponse, Response + + _package_starlette_installed = True + except ImportError: + EventSourceResponse = Any + Request = Any + JSONResponse = Any + Response = Any + + _package_starlette_installed = False + +logger = logging.getLogger(__name__) + +@trace_class(kind=SpanKind.SERVER) +class RestDispatcher: + """Dispatches incoming REST requests to the appropriate handler methods. + + Handles context building, routing to RequestHandler directly, and response formatting (JSON/SSE). + """ + + def __init__( # noqa: PLR0913 + self, + agent_card: AgentCard, + request_handler: RequestHandler, + extended_agent_card: AgentCard | None = None, + context_builder: CallContextBuilder | None = None, + card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] + | None = None, + extended_card_modifier: Callable[ + [AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard + ] + | None = None, + ) -> None: + """Initializes the RestDispatcher. + + Args: + agent_card: The AgentCard describing the agent's capabilities. + request_handler: The underlying `RequestHandler` instance to delegate requests to. + extended_agent_card: An optional, distinct AgentCard to be served + at the authenticated extended card endpoint. + context_builder: The CallContextBuilder used to construct the + ServerCallContext passed to the request_handler. If None, no + ServerCallContext is passed. + card_modifier: An optional callback to dynamically modify the public + agent card before it is served. + extended_card_modifier: An optional callback to dynamically modify + the extended agent card before it is served. It receives the + call context. + """ + if not _package_starlette_installed: + raise ImportError( + 'Packages `starlette` and `sse-starlette` are required to use the' + ' `RestDispatcher`. They can be added as a part of `a2a-sdk` ' + 'optional dependencies, `a2a-sdk[http-server]`.' + ) + + self.agent_card = agent_card + self.extended_agent_card = extended_agent_card + self.card_modifier = card_modifier + self.extended_card_modifier = extended_card_modifier + self._context_builder = context_builder or DefaultCallContextBuilder() + self.request_handler = request_handler + + def _build_call_context(self, request: Request) -> ServerCallContext: + call_context = self._context_builder.build(request) + if 'tenant' in request.path_params: + call_context.tenant = request.path_params['tenant'] + return call_context + + @rest_error_handler + @validate_version(constants.PROTOCOL_VERSION_1_0) + async def on_message_send(self, request: Request) -> Response: + """Handles the 'message/send' REST method.""" + context = self._build_call_context(request) + body = await request.body() + params = a2a_pb2.SendMessageRequest() + Parse(body, params) + task_or_message = await self.request_handler.on_message_send(params, context) + if isinstance(task_or_message, a2a_pb2.Task): + response = a2a_pb2.SendMessageResponse(task=task_or_message) + else: + response = a2a_pb2.SendMessageResponse(message=task_or_message) + return JSONResponse(content=MessageToDict(response)) + + @rest_stream_error_handler + @validate_version(constants.PROTOCOL_VERSION_1_0) + @validate( + lambda self: self.agent_card.capabilities.streaming, + 'Streaming is not supported by the agent', + ) + async def on_message_send_stream(self, request: Request) -> EventSourceResponse: + """Handles the 'message/stream' REST method.""" + try: + await request.body() + except (ValueError, RuntimeError, OSError) as e: + raise InvalidRequestError( + message=f'Failed to pre-consume request body: {e}' + ) from e + + context = self._build_call_context(request) + body = await request.body() + params = a2a_pb2.SendMessageRequest() + Parse(body, params) + + stream = aiter(self.request_handler.on_message_send_stream(params, context)) + try: + first_event = await anext(stream) + except StopAsyncIteration: + return EventSourceResponse(iter([])) + + async def event_generator() -> AsyncIterator[str]: + yield json.dumps(MessageToDict(proto_utils.to_stream_response(first_event))) + async for event in stream: + yield json.dumps(MessageToDict(proto_utils.to_stream_response(event))) + + return EventSourceResponse(event_generator()) + + @rest_error_handler + @validate_version(constants.PROTOCOL_VERSION_1_0) + async def on_cancel_task(self, request: Request) -> Response: + """Handles the 'tasks/cancel' REST method.""" + context = self._build_call_context(request) + task_id = request.path_params['id'] + task = await self.request_handler.on_cancel_task(CancelTaskRequest(id=task_id), context) + if task: + return JSONResponse(content=MessageToDict(task)) + raise TaskNotFoundError + + @rest_stream_error_handler + @validate_version(constants.PROTOCOL_VERSION_1_0) + @validate( + lambda self: self.agent_card.capabilities.streaming, + 'Streaming is not supported by the agent', + ) + async def on_subscribe_to_task(self, request: Request) -> EventSourceResponse: + """Handles the 'SubscribeToTask' REST method.""" + try: + await request.body() + except (ValueError, RuntimeError, OSError) as e: + raise InvalidRequestError( + message=f'Failed to pre-consume request body: {e}' + ) from e + + context = self._build_call_context(request) + task_id = request.path_params['id'] + + stream = aiter(self.request_handler.on_subscribe_to_task(SubscribeToTaskRequest(id=task_id), context)) + try: + first_event = await anext(stream) + except StopAsyncIteration: + return EventSourceResponse(iter([])) + + async def event_generator() -> AsyncIterator[str]: + yield json.dumps(MessageToDict(proto_utils.to_stream_response(first_event))) + async for event in stream: + yield json.dumps(MessageToDict(proto_utils.to_stream_response(event))) + + return EventSourceResponse(event_generator()) + + @rest_error_handler + @validate_version(constants.PROTOCOL_VERSION_1_0) + async def on_get_task(self, request: Request) -> Response: + """Handles the 'tasks/{id}' REST method.""" + context = self._build_call_context(request) + params = a2a_pb2.GetTaskRequest() + proto_utils.parse_params(request.query_params, params) + params.id = request.path_params['id'] + task = await self.request_handler.on_get_task(params, context) + if task: + return JSONResponse(content=MessageToDict(task)) + raise TaskNotFoundError + + @rest_error_handler + @validate_version(constants.PROTOCOL_VERSION_1_0) + async def get_push_notification(self, request: Request) -> Response: + """Handles the 'tasks/pushNotificationConfig/get' REST method.""" + context = self._build_call_context(request) + task_id = request.path_params['id'] + push_id = request.path_params['push_id'] + params = GetTaskPushNotificationConfigRequest(task_id=task_id, id=push_id) + config = await self.request_handler.on_get_task_push_notification_config(params, context) + return JSONResponse(content=MessageToDict(config)) + + @rest_error_handler + @validate_version(constants.PROTOCOL_VERSION_1_0) + async def delete_push_notification(self, request: Request) -> Response: + """Handles the 'tasks/pushNotificationConfig/delete' REST method.""" + context = self._build_call_context(request) + task_id = request.path_params['id'] + push_id = request.path_params['push_id'] + params = a2a_pb2.DeleteTaskPushNotificationConfigRequest(task_id=task_id, id=push_id) + await self.request_handler.on_delete_task_push_notification_config(params, context) + return JSONResponse(content={}) + + @rest_error_handler + @validate_version(constants.PROTOCOL_VERSION_1_0) + @validate( + lambda self: self.agent_card.capabilities.push_notifications, + 'Push notifications are not supported by the agent', + ) + async def set_push_notification(self, request: Request) -> Response: + """Handles the 'tasks/pushNotificationConfig/set' REST method.""" + context = self._build_call_context(request) + body = await request.body() + params = a2a_pb2.TaskPushNotificationConfig() + Parse(body, params) + params.task_id = request.path_params['id'] + config = await self.request_handler.on_create_task_push_notification_config(params, context) + return JSONResponse(content=MessageToDict(config)) + + @rest_error_handler + @validate_version(constants.PROTOCOL_VERSION_1_0) + async def list_push_notifications(self, request: Request) -> Response: + """Handles the 'tasks/pushNotificationConfig/list' REST method.""" + context = self._build_call_context(request) + params = a2a_pb2.ListTaskPushNotificationConfigsRequest() + proto_utils.parse_params(request.query_params, params) + params.task_id = request.path_params['id'] + result = await self.request_handler.on_list_task_push_notification_configs(params, context) + return JSONResponse(content=MessageToDict(result)) + + @rest_error_handler + @validate_version(constants.PROTOCOL_VERSION_1_0) + async def list_tasks(self, request: Request) -> Response: + """Handles the 'tasks/list' REST method.""" + context = self._build_call_context(request) + params = a2a_pb2.ListTasksRequest() + proto_utils.parse_params(request.query_params, params) + result = await self.request_handler.on_list_tasks(params, context) + return JSONResponse(content=MessageToDict(result, always_print_fields_with_no_presence=True)) + + @rest_error_handler + async def handle_authenticated_agent_card(self, request: Request) -> Response: + """Handles the 'extendedAgentCard' REST method.""" + if not self.agent_card.capabilities.extended_agent_card: + raise ExtendedAgentCardNotConfiguredError( + message='Authenticated card not supported' + ) + card_to_serve = self.extended_agent_card or self.agent_card + + if self.extended_card_modifier: + context = self._build_call_context(request) + card_to_serve = await maybe_await( + self.extended_card_modifier(card_to_serve, context) + ) + elif self.card_modifier: + card_to_serve = await maybe_await(self.card_modifier(card_to_serve)) + + return JSONResponse( + content=MessageToDict(card_to_serve, preserving_proto_field_name=True) + ) diff --git a/src/a2a/server/routes/rest_routes.py b/src/a2a/server/routes/rest_routes.py index 1923f038a..bab33ed3e 100644 --- a/src/a2a/server/routes/rest_routes.py +++ b/src/a2a/server/routes/rest_routes.py @@ -5,14 +5,19 @@ from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable from typing import TYPE_CHECKING, Any -from google.protobuf.json_format import MessageToDict +from google.protobuf.json_format import MessageToDict, Parse from a2a.compat.v0_3.rest_adapter import REST03Adapter from a2a.server.context import ServerCallContext from a2a.server.request_handlers.request_handler import RequestHandler -from a2a.server.request_handlers.rest_handler import RESTHandler from a2a.server.routes import CallContextBuilder, DefaultCallContextBuilder -from a2a.types.a2a_pb2 import AgentCard +from a2a.server.routes.rest_dispatcher import RestDispatcher +from a2a.types import a2a_pb2 +from a2a.types.a2a_pb2 import ( + AgentCard, + GetTaskPushNotificationConfigRequest, +) +from a2a.utils import constants, proto_utils from a2a.utils.error_handlers import ( rest_error_handler, rest_stream_error_handler, @@ -20,8 +25,10 @@ from a2a.utils.errors import ( ExtendedAgentCardNotConfiguredError, InvalidRequestError, + TaskNotFoundError, + UnsupportedOperationError, ) -from a2a.utils.helpers import maybe_await +from a2a.utils.helpers import maybe_await, validate_version if TYPE_CHECKING: @@ -94,7 +101,16 @@ def create_rest_routes( # noqa: PLR0913 'optional dependencies, `a2a-sdk[http-server]`.' ) - v03_routes = {} + dispatcher = RestDispatcher( + agent_card=agent_card, + request_handler=request_handler, + extended_agent_card=extended_agent_card, + context_builder=context_builder, + card_modifier=card_modifier, + extended_card_modifier=extended_card_modifier, + ) + + routes: list[BaseRoute] = [] if enable_v0_3_compat: v03_adapter = REST03Adapter( agent_card=agent_card, @@ -105,139 +121,34 @@ def create_rest_routes( # noqa: PLR0913 extended_card_modifier=extended_card_modifier, ) v03_routes = v03_adapter.routes() - - routes: list[BaseRoute] = [] - for (path, method), endpoint in v03_routes.items(): - routes.append( - Route( - path=f'{path_prefix}{path}', - endpoint=endpoint, - methods=[method], - ) - ) - - handler = RESTHandler( - agent_card=agent_card, request_handler=request_handler - ) - _context_builder = context_builder or DefaultCallContextBuilder() - - def _build_call_context(request: 'Request') -> ServerCallContext: - call_context = _context_builder.build(request) - if 'tenant' in request.path_params: - call_context.tenant = request.path_params['tenant'] - return call_context - - @rest_error_handler - async def _handle_request( - method: Callable[['Request', ServerCallContext], Awaitable[Any]], - request: 'Request', - ) -> 'Response': - - call_context = _build_call_context(request) - response = await method(request, call_context) - return JSONResponse(content=response) - - @rest_stream_error_handler - async def _handle_streaming_request( - method: Callable[[Request, ServerCallContext], AsyncIterable[Any]], - request: Request, - ) -> EventSourceResponse: - # Pre-consume and cache the request body to prevent deadlock in streaming context - # This is required because Starlette's request.body() can only be consumed once, - # and attempting to consume it after EventSourceResponse starts causes deadlock - try: - await request.body() - except (ValueError, RuntimeError, OSError) as e: - raise InvalidRequestError( - message=f'Failed to pre-consume request body: {e}' - ) from e - - call_context = _build_call_context(request) - - # Eagerly fetch the first item from the stream so that errors raised - # before any event is yielded (e.g. validation, parsing, or handler - # failures) propagate here and are caught by - # @rest_stream_error_handler, which returns a JSONResponse with - # the correct HTTP status code instead of starting an SSE stream. - # Without this, the error would be raised after SSE headers are - # already sent, and the client would see a broken stream instead - # of a proper error response. - stream = aiter(method(request, call_context)) - try: - first_item = await anext(stream) - except StopAsyncIteration: - return EventSourceResponse(iter([])) - - async def event_generator() -> AsyncIterator[str]: - yield json.dumps(first_item) - async for item in stream: - yield json.dumps(item) - - return EventSourceResponse(event_generator()) - - async def _handle_authenticated_agent_card( - request: 'Request', call_context: ServerCallContext | None = None - ) -> dict[str, Any]: - if not agent_card.capabilities.extended_agent_card: - raise ExtendedAgentCardNotConfiguredError( - message='Authenticated card not supported' + for (path, method), endpoint in v03_routes.items(): + routes.append( + Route( + path=f'{path_prefix}{path}', + endpoint=endpoint, + methods=[method], + ) ) - card_to_serve = extended_agent_card or agent_card - if extended_card_modifier: - # Re-generate context if none passed to replicate RESTAdapter exact logic - context = call_context or _build_call_context(request) - card_to_serve = await maybe_await( - extended_card_modifier(card_to_serve, context) - ) - elif card_modifier: - card_to_serve = await maybe_await(card_modifier(card_to_serve)) - - return MessageToDict(card_to_serve, preserving_proto_field_name=True) - - # Dictionary of routes, mapping to bound helper methods - base_routes: dict[tuple[str, str], Callable[[Request], Any]] = { - ('/message:send', 'POST'): functools.partial( - _handle_request, handler.on_message_send - ), - ('/message:stream', 'POST'): functools.partial( - _handle_streaming_request, - handler.on_message_send_stream, - ), - ('/tasks/{id}:cancel', 'POST'): functools.partial( - _handle_request, handler.on_cancel_task - ), - ('/tasks/{id}:subscribe', 'GET'): functools.partial( - _handle_streaming_request, - handler.on_subscribe_to_task, - ), - ('/tasks/{id}:subscribe', 'POST'): functools.partial( - _handle_streaming_request, - handler.on_subscribe_to_task, - ), - ('/tasks/{id}', 'GET'): functools.partial( - _handle_request, handler.on_get_task - ), + base_routes = { + ('/message:send', 'POST'): dispatcher.on_message_send, + ('/message:stream', 'POST'): dispatcher.on_message_send_stream, + ('/tasks/{id}:cancel', 'POST'): dispatcher.on_cancel_task, + ('/tasks/{id}:subscribe', 'GET'): dispatcher.on_subscribe_to_task, + ('/tasks/{id}:subscribe', 'POST'): dispatcher.on_subscribe_to_task, + ('/tasks/{id}', 'GET'): dispatcher.on_get_task, ( '/tasks/{id}/pushNotificationConfigs/{push_id}', 'GET', - ): functools.partial(_handle_request, handler.get_push_notification), + ): dispatcher.get_push_notification, ( '/tasks/{id}/pushNotificationConfigs/{push_id}', 'DELETE', - ): functools.partial(_handle_request, handler.delete_push_notification), - ('/tasks/{id}/pushNotificationConfigs', 'POST'): functools.partial( - _handle_request, handler.set_push_notification - ), - ('/tasks/{id}/pushNotificationConfigs', 'GET'): functools.partial( - _handle_request, handler.list_push_notifications - ), - ('/tasks', 'GET'): functools.partial( - _handle_request, handler.list_tasks - ), - ('/extendedAgentCard', 'GET'): functools.partial( - _handle_request, _handle_authenticated_agent_card - ), + ): dispatcher.delete_push_notification, + ('/tasks/{id}/pushNotificationConfigs', 'POST'): dispatcher.set_push_notification, + ('/tasks/{id}/pushNotificationConfigs', 'GET'): dispatcher.list_push_notifications, + ('/tasks', 'GET'): dispatcher.list_tasks, + ('/extendedAgentCard', 'GET'): dispatcher.handle_authenticated_agent_card, } base_route_objects = [] @@ -253,3 +164,4 @@ async def _handle_authenticated_agent_card( routes.append(Mount(path='/{tenant}', routes=base_route_objects)) return routes + diff --git a/src/a2a/utils/helpers.py b/src/a2a/utils/helpers.py index e5b37e5f4..266fe0428 100644 --- a/src/a2a/utils/helpers.py +++ b/src/a2a/utils/helpers.py @@ -338,23 +338,36 @@ def _get_actual_version( context = arg break - if context is None: - # If no context is found, we can't validate the version. - # In a real scenario, this shouldn't happen for properly routed requests. - # We default to the expected version to allow test call to proceed. - return expected_version - - headers = context.state.get('headers', {}) - # Header names are usually case-insensitive in most frameworks, but dict lookup is case-sensitive. - # We check both standard and lowercase versions. - actual_version = headers.get( - constants.VERSION_HEADER - ) or headers.get(constants.VERSION_HEADER.lower()) - - if not actual_version: - return constants.PROTOCOL_VERSION_0_3 - - return str(actual_version) + if context is not None: + headers = context.state.get('headers', {}) + actual_version = headers.get( + constants.VERSION_HEADER + ) or headers.get(constants.VERSION_HEADER.lower()) + + if not actual_version: + return constants.PROTOCOL_VERSION_0_3 + + return str(actual_version) + + # Fallback to Request + request = kwargs.get('request') + if request is None: + for arg in args: + if hasattr(arg, 'headers') and hasattr(arg, 'path_params'): + request = arg + break + + if request is not None: + headers = dict(request.headers) + actual_version = headers.get( + constants.VERSION_HEADER + ) or headers.get(constants.VERSION_HEADER.lower()) + + if not actual_version: + return constants.PROTOCOL_VERSION_0_3 + return str(actual_version) + + return expected_version def _is_version_compatible(actual: str) -> bool: if actual == expected_version: From d60e25934808836134b044e5d12f25023ce08fcf Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Wed, 25 Mar 2026 16:01:52 +0000 Subject: [PATCH 2/7] fix --- src/a2a/server/routes/rest_dispatcher.py | 318 +++++++++++++------- src/a2a/server/routes/rest_routes.py | 38 +-- src/a2a/utils/helpers.py | 10 +- tests/server/routes/test_rest_dispatcher.py | 288 ++++++++++++++++++ 4 files changed, 514 insertions(+), 140 deletions(-) create mode 100644 tests/server/routes/test_rest_dispatcher.py diff --git a/src/a2a/server/routes/rest_dispatcher.py b/src/a2a/server/routes/rest_dispatcher.py index b4b192543..768315086 100644 --- a/src/a2a/server/routes/rest_dispatcher.py +++ b/src/a2a/server/routes/rest_dispatcher.py @@ -1,7 +1,8 @@ import json import logging + from collections.abc import AsyncIterator, Awaitable, Callable -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, TypeVar from google.protobuf.json_format import MessageToDict, Parse @@ -52,6 +53,9 @@ logger = logging.getLogger(__name__) +TResponse = TypeVar('TResponse') + + @trace_class(kind=SpanKind.SERVER) class RestDispatcher: """Dispatches incoming REST requests to the appropriate handler methods. @@ -108,29 +112,24 @@ def _build_call_context(self, request: Request) -> ServerCallContext: call_context.tenant = request.path_params['tenant'] return call_context - @rest_error_handler - @validate_version(constants.PROTOCOL_VERSION_1_0) - async def on_message_send(self, request: Request) -> Response: - """Handles the 'message/send' REST method.""" + async def _handle_non_streaming( + self, + request: Request, + handler_func: Callable[[ServerCallContext], Awaitable[TResponse]], + ) -> TResponse: + """Centralized error handling and context management for unary calls.""" context = self._build_call_context(request) - body = await request.body() - params = a2a_pb2.SendMessageRequest() - Parse(body, params) - task_or_message = await self.request_handler.on_message_send(params, context) - if isinstance(task_or_message, a2a_pb2.Task): - response = a2a_pb2.SendMessageResponse(task=task_or_message) - else: - response = a2a_pb2.SendMessageResponse(message=task_or_message) - return JSONResponse(content=MessageToDict(response)) + return await handler_func(context) - @rest_stream_error_handler - @validate_version(constants.PROTOCOL_VERSION_1_0) - @validate( - lambda self: self.agent_card.capabilities.streaming, - 'Streaming is not supported by the agent', - ) - async def on_message_send_stream(self, request: Request) -> EventSourceResponse: - """Handles the 'message/stream' REST method.""" + async def _handle_streaming( + self, + request: Request, + handler_func: Callable[[ServerCallContext], AsyncIterator[Any]], + ) -> EventSourceResponse: + """Centralized error handling and context management for streaming calls.""" + # Pre-consume and cache the request body to prevent deadlock in streaming context + # This is required because Starlette's request.body() can only be consumed once, + # and attempting to consume it after EventSourceResponse starts causes deadlock try: await request.body() except (ValueError, RuntimeError, OSError) as e: @@ -139,139 +138,234 @@ async def on_message_send_stream(self, request: Request) -> EventSourceResponse: ) from e context = self._build_call_context(request) - body = await request.body() - params = a2a_pb2.SendMessageRequest() - Parse(body, params) - stream = aiter(self.request_handler.on_message_send_stream(params, context)) + # Eagerly fetch the first item from the stream so that errors raised + # before any event is yielded (e.g. validation, parsing, or handler + # failures) propagate here and are caught by + # @rest_stream_error_handler, which returns a JSONResponse with + # the correct HTTP status code instead of starting an SSE stream. + # Without this, the error would be raised after SSE headers are + # already sent, and the client would see a broken stream instead + stream = aiter(handler_func(context)) try: - first_event = await anext(stream) + first_item = await anext(stream) except StopAsyncIteration: return EventSourceResponse(iter([])) async def event_generator() -> AsyncIterator[str]: - yield json.dumps(MessageToDict(proto_utils.to_stream_response(first_event))) - async for event in stream: - yield json.dumps(MessageToDict(proto_utils.to_stream_response(event))) + yield json.dumps(first_item) + async for item in stream: + yield json.dumps(item) return EventSourceResponse(event_generator()) @rest_error_handler - @validate_version(constants.PROTOCOL_VERSION_1_0) + async def on_message_send(self, request: Request) -> Response: + """Handles the 'message/send' REST method.""" + + @validate_version(constants.PROTOCOL_VERSION_1_0) + async def _handler( + context: ServerCallContext, + ) -> a2a_pb2.SendMessageResponse: + body = await request.body() + params = a2a_pb2.SendMessageRequest() + Parse(body, params) + task_or_message = await self.request_handler.on_message_send( + params, context + ) + if isinstance(task_or_message, a2a_pb2.Task): + return a2a_pb2.SendMessageResponse(task=task_or_message) + return a2a_pb2.SendMessageResponse(message=task_or_message) + + response = await self._handle_non_streaming(request, _handler) + return JSONResponse(content=MessageToDict(response)) + + @rest_stream_error_handler + async def on_message_send_stream( + self, request: Request + ) -> EventSourceResponse: + """Handles the 'message/stream' REST method.""" + + @validate_version(constants.PROTOCOL_VERSION_1_0) + @validate( + lambda _: self.agent_card.capabilities.streaming, + 'Streaming is not supported by the agent', + ) + async def _handler( + context: ServerCallContext, + ) -> AsyncIterator[dict[str, Any]]: + body = await request.body() + params = a2a_pb2.SendMessageRequest() + Parse(body, params) + async for event in self.request_handler.on_message_send_stream( + params, context + ): + response = proto_utils.to_stream_response(event) + yield MessageToDict(response) + + return await self._handle_streaming(request, _handler) + + @rest_error_handler async def on_cancel_task(self, request: Request) -> Response: """Handles the 'tasks/cancel' REST method.""" - context = self._build_call_context(request) - task_id = request.path_params['id'] - task = await self.request_handler.on_cancel_task(CancelTaskRequest(id=task_id), context) - if task: - return JSONResponse(content=MessageToDict(task)) - raise TaskNotFoundError + + @validate_version(constants.PROTOCOL_VERSION_1_0) + async def _handler(context: ServerCallContext) -> a2a_pb2.Task: + task_id = request.path_params['id'] + task = await self.request_handler.on_cancel_task( + CancelTaskRequest(id=task_id), context + ) + if task: + return task + raise TaskNotFoundError + + response = await self._handle_non_streaming(request, _handler) + return JSONResponse(content=MessageToDict(response)) @rest_stream_error_handler - @validate_version(constants.PROTOCOL_VERSION_1_0) - @validate( - lambda self: self.agent_card.capabilities.streaming, - 'Streaming is not supported by the agent', - ) - async def on_subscribe_to_task(self, request: Request) -> EventSourceResponse: + async def on_subscribe_to_task( + self, request: Request + ) -> EventSourceResponse: """Handles the 'SubscribeToTask' REST method.""" - try: - await request.body() - except (ValueError, RuntimeError, OSError) as e: - raise InvalidRequestError( - message=f'Failed to pre-consume request body: {e}' - ) from e - - context = self._build_call_context(request) task_id = request.path_params['id'] - - stream = aiter(self.request_handler.on_subscribe_to_task(SubscribeToTaskRequest(id=task_id), context)) - try: - first_event = await anext(stream) - except StopAsyncIteration: - return EventSourceResponse(iter([])) - async def event_generator() -> AsyncIterator[str]: - yield json.dumps(MessageToDict(proto_utils.to_stream_response(first_event))) - async for event in stream: - yield json.dumps(MessageToDict(proto_utils.to_stream_response(event))) + @validate_version(constants.PROTOCOL_VERSION_1_0) + @validate( + lambda _: self.agent_card.capabilities.streaming, + 'Streaming is not supported by the agent', + ) + async def _handler( + context: ServerCallContext, + ) -> AsyncIterator[dict[str, Any]]: + async for event in self.request_handler.on_subscribe_to_task( + SubscribeToTaskRequest(id=task_id), context + ): + response = proto_utils.to_stream_response(event) + yield MessageToDict(response) - return EventSourceResponse(event_generator()) + return await self._handle_streaming(request, _handler) @rest_error_handler - @validate_version(constants.PROTOCOL_VERSION_1_0) async def on_get_task(self, request: Request) -> Response: """Handles the 'tasks/{id}' REST method.""" - context = self._build_call_context(request) - params = a2a_pb2.GetTaskRequest() - proto_utils.parse_params(request.query_params, params) - params.id = request.path_params['id'] - task = await self.request_handler.on_get_task(params, context) - if task: - return JSONResponse(content=MessageToDict(task)) - raise TaskNotFoundError + + @validate_version(constants.PROTOCOL_VERSION_1_0) + async def _handler(context: ServerCallContext) -> a2a_pb2.Task: + params = a2a_pb2.GetTaskRequest() + proto_utils.parse_params(request.query_params, params) + params.id = request.path_params['id'] + task = await self.request_handler.on_get_task(params, context) + if task: + return task + raise TaskNotFoundError + + response = await self._handle_non_streaming(request, _handler) + return JSONResponse(content=MessageToDict(response)) @rest_error_handler - @validate_version(constants.PROTOCOL_VERSION_1_0) async def get_push_notification(self, request: Request) -> Response: """Handles the 'tasks/pushNotificationConfig/get' REST method.""" - context = self._build_call_context(request) - task_id = request.path_params['id'] - push_id = request.path_params['push_id'] - params = GetTaskPushNotificationConfigRequest(task_id=task_id, id=push_id) - config = await self.request_handler.on_get_task_push_notification_config(params, context) - return JSONResponse(content=MessageToDict(config)) + + @validate_version(constants.PROTOCOL_VERSION_1_0) + async def _handler( + context: ServerCallContext, + ) -> a2a_pb2.TaskPushNotificationConfig: + task_id = request.path_params['id'] + push_id = request.path_params['push_id'] + params = GetTaskPushNotificationConfigRequest( + task_id=task_id, id=push_id + ) + return ( + await self.request_handler.on_get_task_push_notification_config( + params, context + ) + ) + + response = await self._handle_non_streaming(request, _handler) + return JSONResponse(content=MessageToDict(response)) @rest_error_handler - @validate_version(constants.PROTOCOL_VERSION_1_0) async def delete_push_notification(self, request: Request) -> Response: """Handles the 'tasks/pushNotificationConfig/delete' REST method.""" - context = self._build_call_context(request) - task_id = request.path_params['id'] - push_id = request.path_params['push_id'] - params = a2a_pb2.DeleteTaskPushNotificationConfigRequest(task_id=task_id, id=push_id) - await self.request_handler.on_delete_task_push_notification_config(params, context) + + @validate_version(constants.PROTOCOL_VERSION_1_0) + async def _handler(context: ServerCallContext) -> None: + task_id = request.path_params['id'] + push_id = request.path_params['push_id'] + params = a2a_pb2.DeleteTaskPushNotificationConfigRequest( + task_id=task_id, id=push_id + ) + await self.request_handler.on_delete_task_push_notification_config( + params, context + ) + + await self._handle_non_streaming(request, _handler) return JSONResponse(content={}) @rest_error_handler - @validate_version(constants.PROTOCOL_VERSION_1_0) - @validate( - lambda self: self.agent_card.capabilities.push_notifications, - 'Push notifications are not supported by the agent', - ) async def set_push_notification(self, request: Request) -> Response: """Handles the 'tasks/pushNotificationConfig/set' REST method.""" - context = self._build_call_context(request) - body = await request.body() - params = a2a_pb2.TaskPushNotificationConfig() - Parse(body, params) - params.task_id = request.path_params['id'] - config = await self.request_handler.on_create_task_push_notification_config(params, context) - return JSONResponse(content=MessageToDict(config)) + + @validate_version(constants.PROTOCOL_VERSION_1_0) + @validate( + lambda _: self.agent_card.capabilities.push_notifications, + 'Push notifications are not supported by the agent', + ) + async def _handler( + context: ServerCallContext, + ) -> a2a_pb2.TaskPushNotificationConfig: + body = await request.body() + params = a2a_pb2.TaskPushNotificationConfig() + Parse(body, params) + params.task_id = request.path_params['id'] + return await self.request_handler.on_create_task_push_notification_config( + params, context + ) + + response = await self._handle_non_streaming(request, _handler) + return JSONResponse(content=MessageToDict(response)) @rest_error_handler - @validate_version(constants.PROTOCOL_VERSION_1_0) async def list_push_notifications(self, request: Request) -> Response: """Handles the 'tasks/pushNotificationConfig/list' REST method.""" - context = self._build_call_context(request) - params = a2a_pb2.ListTaskPushNotificationConfigsRequest() - proto_utils.parse_params(request.query_params, params) - params.task_id = request.path_params['id'] - result = await self.request_handler.on_list_task_push_notification_configs(params, context) - return JSONResponse(content=MessageToDict(result)) + + @validate_version(constants.PROTOCOL_VERSION_1_0) + async def _handler( + context: ServerCallContext, + ) -> a2a_pb2.ListTaskPushNotificationConfigsResponse: + params = a2a_pb2.ListTaskPushNotificationConfigsRequest() + proto_utils.parse_params(request.query_params, params) + params.task_id = request.path_params['id'] + return await self.request_handler.on_list_task_push_notification_configs( + params, context + ) + + response = await self._handle_non_streaming(request, _handler) + return JSONResponse(content=MessageToDict(response)) @rest_error_handler - @validate_version(constants.PROTOCOL_VERSION_1_0) async def list_tasks(self, request: Request) -> Response: """Handles the 'tasks/list' REST method.""" - context = self._build_call_context(request) - params = a2a_pb2.ListTasksRequest() - proto_utils.parse_params(request.query_params, params) - result = await self.request_handler.on_list_tasks(params, context) - return JSONResponse(content=MessageToDict(result, always_print_fields_with_no_presence=True)) + + @validate_version(constants.PROTOCOL_VERSION_1_0) + async def _handler( + context: ServerCallContext, + ) -> a2a_pb2.ListTasksResponse: + params = a2a_pb2.ListTasksRequest() + proto_utils.parse_params(request.query_params, params) + return await self.request_handler.on_list_tasks(params, context) + + response = await self._handle_non_streaming(request, _handler) + return JSONResponse( + content=MessageToDict( + response, always_print_fields_with_no_presence=True + ) + ) @rest_error_handler - async def handle_authenticated_agent_card(self, request: Request) -> Response: + async def handle_authenticated_agent_card( + self, request: Request + ) -> Response: """Handles the 'extendedAgentCard' REST method.""" if not self.agent_card.capabilities.extended_agent_card: raise ExtendedAgentCardNotConfiguredError( @@ -288,5 +382,7 @@ async def handle_authenticated_agent_card(self, request: Request) -> Response: card_to_serve = await maybe_await(self.card_modifier(card_to_serve)) return JSONResponse( - content=MessageToDict(card_to_serve, preserving_proto_field_name=True) + content=MessageToDict( + card_to_serve, preserving_proto_field_name=True + ) ) diff --git a/src/a2a/server/routes/rest_routes.py b/src/a2a/server/routes/rest_routes.py index bab33ed3e..85dd01ff4 100644 --- a/src/a2a/server/routes/rest_routes.py +++ b/src/a2a/server/routes/rest_routes.py @@ -1,34 +1,16 @@ -import functools -import json import logging -from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable +from collections.abc import Awaitable, Callable from typing import TYPE_CHECKING, Any -from google.protobuf.json_format import MessageToDict, Parse - from a2a.compat.v0_3.rest_adapter import REST03Adapter from a2a.server.context import ServerCallContext from a2a.server.request_handlers.request_handler import RequestHandler -from a2a.server.routes import CallContextBuilder, DefaultCallContextBuilder +from a2a.server.routes import CallContextBuilder from a2a.server.routes.rest_dispatcher import RestDispatcher -from a2a.types import a2a_pb2 from a2a.types.a2a_pb2 import ( AgentCard, - GetTaskPushNotificationConfigRequest, -) -from a2a.utils import constants, proto_utils -from a2a.utils.error_handlers import ( - rest_error_handler, - rest_stream_error_handler, -) -from a2a.utils.errors import ( - ExtendedAgentCardNotConfiguredError, - InvalidRequestError, - TaskNotFoundError, - UnsupportedOperationError, ) -from a2a.utils.helpers import maybe_await, validate_version if TYPE_CHECKING: @@ -145,10 +127,19 @@ def create_rest_routes( # noqa: PLR0913 '/tasks/{id}/pushNotificationConfigs/{push_id}', 'DELETE', ): dispatcher.delete_push_notification, - ('/tasks/{id}/pushNotificationConfigs', 'POST'): dispatcher.set_push_notification, - ('/tasks/{id}/pushNotificationConfigs', 'GET'): dispatcher.list_push_notifications, + ( + '/tasks/{id}/pushNotificationConfigs', + 'POST', + ): dispatcher.set_push_notification, + ( + '/tasks/{id}/pushNotificationConfigs', + 'GET', + ): dispatcher.list_push_notifications, ('/tasks', 'GET'): dispatcher.list_tasks, - ('/extendedAgentCard', 'GET'): dispatcher.handle_authenticated_agent_card, + ( + '/extendedAgentCard', + 'GET', + ): dispatcher.handle_authenticated_agent_card, } base_route_objects = [] @@ -164,4 +155,3 @@ def create_rest_routes( # noqa: PLR0913 routes.append(Mount(path='/{tenant}', routes=base_route_objects)) return routes - diff --git a/src/a2a/utils/helpers.py b/src/a2a/utils/helpers.py index 266fe0428..8d1dd315d 100644 --- a/src/a2a/utils/helpers.py +++ b/src/a2a/utils/helpers.py @@ -362,7 +362,7 @@ def _get_actual_version( actual_version = headers.get( constants.VERSION_HEADER ) or headers.get(constants.VERSION_HEADER.lower()) - + if not actual_version: return constants.PROTOCOL_VERSION_0_3 return str(actual_version) @@ -385,7 +385,7 @@ def _is_version_compatible(actual: str) -> bool: @functools.wraps(func) def async_gen_wrapper( - self: Any, *args: Any, **kwargs: Any + *args: Any, **kwargs: Any ) -> AsyncIterator[Any]: actual_version = _get_actual_version(args, kwargs) if not _is_version_compatible(actual_version): @@ -398,12 +398,12 @@ def async_gen_wrapper( message=f"A2A version '{actual_version}' is not supported by this handler. " f"Expected version '{expected_version}'." ) - return func(self, *args, **kwargs) + return func(*args, **kwargs) return cast('F', async_gen_wrapper) @functools.wraps(func) - async def async_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: + async def async_wrapper(*args: Any, **kwargs: Any) -> Any: actual_version = _get_actual_version(args, kwargs) if not _is_version_compatible(actual_version): logger.warning( @@ -415,7 +415,7 @@ async def async_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: message=f"A2A version '{actual_version}' is not supported by this handler. " f"Expected version '{expected_version}'." ) - return await func(self, *args, **kwargs) + return await func(*args, **kwargs) return cast('F', async_wrapper) diff --git a/tests/server/routes/test_rest_dispatcher.py b/tests/server/routes/test_rest_dispatcher.py new file mode 100644 index 000000000..f91105ac7 --- /dev/null +++ b/tests/server/routes/test_rest_dispatcher.py @@ -0,0 +1,288 @@ +import json +from collections.abc import AsyncIterator +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +import pytest +from starlette.requests import Request +from starlette.responses import JSONResponse + +from a2a.server.context import ServerCallContext +from a2a.server.request_handlers.request_handler import RequestHandler +from a2a.server.routes import rest_dispatcher +from a2a.server.routes.rest_dispatcher import ( + DefaultCallContextBuilder, + RestDispatcher, +) +from a2a.types.a2a_pb2 import ( + AgentCapabilities, + AgentCard, + Message, + SendMessageResponse, + Task, + TaskPushNotificationConfig, + ListTasksResponse, + ListTaskPushNotificationConfigsResponse, +) +from a2a.utils.errors import ( + ExtendedAgentCardNotConfiguredError, + TaskNotFoundError, + UnsupportedOperationError, +) + + +@pytest.fixture +def mock_handler(): + handler = AsyncMock(spec=RequestHandler) + # Default success cases + handler.on_message_send.return_value = Message(message_id='test_msg') + handler.on_cancel_task.return_value = Task(id='test_task') + handler.on_get_task.return_value = Task(id='test_task') + handler.on_list_tasks.return_value = ListTasksResponse() + handler.on_get_task_push_notification_config.return_value = ( + TaskPushNotificationConfig(url='http://test') + ) + handler.on_create_task_push_notification_config.return_value = ( + TaskPushNotificationConfig(url='http://test') + ) + handler.on_list_task_push_notification_configs.return_value = ( + ListTaskPushNotificationConfigsResponse() + ) + + # Streaming mocks + async def mock_stream(*args, **kwargs) -> AsyncIterator[Task]: + yield Task(id='chunk1') + yield Task(id='chunk2') + + handler.on_message_send_stream.side_effect = mock_stream + handler.on_subscribe_to_task.side_effect = mock_stream + return handler + + +@pytest.fixture +def agent_card(): + card = MagicMock(spec=AgentCard) + card.capabilities = AgentCapabilities( + streaming=True, + push_notifications=True, + extended_agent_card=True, + ) + return card + + +@pytest.fixture +def rest_dispatcher_instance(agent_card, mock_handler): + return RestDispatcher(agent_card=agent_card, request_handler=mock_handler) + + +from starlette.datastructures import Headers + + +def make_mock_request( + method: str = 'GET', + path_params: dict | None = None, + query_params: dict | None = None, + headers: dict | None = None, + body: bytes = b'{}', +) -> Request: + mock_req = MagicMock(spec=Request) + mock_req.method = method + mock_req.path_params = path_params or {} + mock_req.query_params = query_params or {} + + # Default valid headers for A2A + default_headers = {'a2a-version': '1.0'} + if headers: + default_headers.update(headers) + + mock_req.headers = Headers(default_headers) + mock_req.body = AsyncMock(return_value=body) + + # Needs to be able to build ServerCallContext, so provide .user and .auth etc. if needed + mock_req.user = MagicMock(is_authenticated=False) + mock_req.auth = None + return mock_req + + +class TestRestDispatcherInitialization: + @pytest.fixture(scope='class') + def mark_pkg_starlette_not_installed(self): + pkg_starlette_installed_flag = ( + rest_dispatcher._package_starlette_installed + ) + rest_dispatcher._package_starlette_installed = False + yield + rest_dispatcher._package_starlette_installed = ( + pkg_starlette_installed_flag + ) + + def test_missing_starlette_raises_importerror( + self, mark_pkg_starlette_not_installed, agent_card, mock_handler + ): + with pytest.raises( + ImportError, + match='Packages `starlette` and `sse-starlette` are required', + ): + RestDispatcher(agent_card=agent_card, request_handler=mock_handler) + + +@pytest.mark.asyncio +class TestRestDispatcherContextManagement: + async def test_build_call_context(self, rest_dispatcher_instance): + req = make_mock_request(path_params={'tenant': 'my-tenant'}) + context = rest_dispatcher_instance._build_call_context(req) + + assert isinstance(context, ServerCallContext) + assert context.tenant == 'my-tenant' + assert context.state['headers']['a2a-version'] == '1.0' + + +@pytest.mark.asyncio +class TestRestDispatcherEndpoints: + async def test_on_message_send_returns_message( + self, rest_dispatcher_instance, mock_handler + ): + req = make_mock_request(method='POST') + response = await rest_dispatcher_instance.on_message_send(req) + + assert isinstance(response, JSONResponse) + assert response.status_code == 200 + data = json.loads(response.body) + assert 'message' in data + + async def test_on_message_send_returns_task( + self, rest_dispatcher_instance, mock_handler + ): + mock_handler.on_message_send.return_value = Task(id='new_task') + req = make_mock_request(method='POST') + + response = await rest_dispatcher_instance.on_message_send(req) + assert response.status_code == 200 + data = json.loads(response.body) + assert 'task' in data + assert data['task']['id'] == 'new_task' + + async def test_on_cancel_task_success( + self, rest_dispatcher_instance, mock_handler + ): + req = make_mock_request(method='POST', path_params={'id': 'test_task'}) + response = await rest_dispatcher_instance.on_cancel_task(req) + + assert response.status_code == 200 + data = json.loads(response.body) + assert data['id'] == 'test_task' + + async def test_on_cancel_task_not_found( + self, rest_dispatcher_instance, mock_handler + ): + mock_handler.on_cancel_task.return_value = None + req = make_mock_request(method='POST', path_params={'id': 'test_task'}) + + response = await rest_dispatcher_instance.on_cancel_task(req) + assert response.status_code == 404 # TaskNotFoundError maps to 404 + + async def test_on_get_task_success( + self, rest_dispatcher_instance, mock_handler + ): + req = make_mock_request(method='GET', path_params={'id': 'test_task'}) + response = await rest_dispatcher_instance.on_get_task(req) + + assert response.status_code == 200 + data = json.loads(response.body) + assert data['id'] == 'test_task' + + async def test_on_get_task_not_found( + self, rest_dispatcher_instance, mock_handler + ): + mock_handler.on_get_task.return_value = None + req = make_mock_request( + method='GET', path_params={'id': 'missing_task'} + ) + + response = await rest_dispatcher_instance.on_get_task(req) + assert response.status_code == 404 + + async def test_list_tasks(self, rest_dispatcher_instance, mock_handler): + req = make_mock_request(method='GET') + response = await rest_dispatcher_instance.list_tasks(req) + assert response.status_code == 200 + + async def test_get_push_notification( + self, rest_dispatcher_instance, mock_handler + ): + req = make_mock_request( + method='GET', path_params={'id': 'task1', 'push_id': 'push1'} + ) + response = await rest_dispatcher_instance.get_push_notification(req) + assert response.status_code == 200 + data = json.loads(response.body) + assert data['url'] == 'http://test' + + async def test_delete_push_notification( + self, rest_dispatcher_instance, mock_handler + ): + req = make_mock_request( + method='DELETE', path_params={'id': 'task1', 'push_id': 'push1'} + ) + response = await rest_dispatcher_instance.delete_push_notification(req) + assert response.status_code == 200 + + async def test_set_push_notification_disabled_raises( + self, agent_card, mock_handler + ): + agent_card.capabilities.push_notifications = False + dispatcher = RestDispatcher( + agent_card=agent_card, request_handler=mock_handler + ) + req = make_mock_request(method='POST', path_params={'id': 'task1'}) + + response = await dispatcher.set_push_notification(req) + assert response.status_code == 400 # UnsupportedOperation maps to 400 + + async def test_handle_authenticated_agent_card( + self, rest_dispatcher_instance + ): + req = make_mock_request() + response = ( + await rest_dispatcher_instance.handle_authenticated_agent_card(req) + ) + assert response.status_code == 200 + + async def test_handle_authenticated_agent_card_unsupported( + self, agent_card, mock_handler + ): + agent_card.capabilities.extended_agent_card = False + dispatcher = RestDispatcher( + agent_card=agent_card, request_handler=mock_handler + ) + req = make_mock_request() + + response = await dispatcher.handle_authenticated_agent_card(req) + assert response.status_code == 400 + + +@pytest.mark.asyncio +class TestRestDispatcherStreaming: + async def test_on_message_send_stream_unsupported( + self, agent_card, mock_handler + ): + agent_card.capabilities.streaming = False + dispatcher = RestDispatcher( + agent_card=agent_card, request_handler=mock_handler + ) + req = make_mock_request(method='POST') + + response = await dispatcher.on_message_send_stream(req) + assert response.status_code == 400 + + async def test_on_subscribe_to_task_unsupported( + self, agent_card, mock_handler + ): + agent_card.capabilities.streaming = False + dispatcher = RestDispatcher( + agent_card=agent_card, request_handler=mock_handler + ) + req = make_mock_request(method='GET', path_params={'id': 't1'}) + + response = await dispatcher.on_subscribe_to_task(req) + assert response.status_code == 400 From 65d8e84683f69264e50feac38204ad3fac5acc61 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Wed, 25 Mar 2026 16:09:04 +0000 Subject: [PATCH 3/7] fix --- .../server/request_handlers/rest_handler.py | 334 ------------------ src/a2a/utils/helpers.py | 47 +-- 2 files changed, 17 insertions(+), 364 deletions(-) delete mode 100644 src/a2a/server/request_handlers/rest_handler.py diff --git a/src/a2a/server/request_handlers/rest_handler.py b/src/a2a/server/request_handlers/rest_handler.py deleted file mode 100644 index af889d9df..000000000 --- a/src/a2a/server/request_handlers/rest_handler.py +++ /dev/null @@ -1,334 +0,0 @@ -import logging - -from collections.abc import AsyncIterator -from typing import TYPE_CHECKING, Any - -from google.protobuf.json_format import ( - MessageToDict, - Parse, -) - - -if TYPE_CHECKING: - from starlette.requests import Request -else: - try: - from starlette.requests import Request - except ImportError: - Request = Any - - -from a2a.server.context import ServerCallContext -from a2a.server.request_handlers.request_handler import RequestHandler -from a2a.types import a2a_pb2 -from a2a.types.a2a_pb2 import ( - AgentCard, - CancelTaskRequest, - GetTaskPushNotificationConfigRequest, - SubscribeToTaskRequest, -) -from a2a.utils import constants, proto_utils -from a2a.utils.errors import TaskNotFoundError -from a2a.utils.helpers import ( - validate, - validate_version, -) -from a2a.utils.telemetry import SpanKind, trace_class - - -logger = logging.getLogger(__name__) - - -@trace_class(kind=SpanKind.SERVER) -class RESTHandler: - """Maps incoming REST-like (JSON+HTTP) requests to the appropriate request handler method and formats responses. - - This uses the protobuf definitions of the gRPC service as the source of truth. By - doing this, it ensures that this implementation and the gRPC transcoding - (via Envoy) are equivalent. This handler should be used if using the gRPC handler - with Envoy is not feasible for a given deployment solution. Use this handler - and a related application if you desire to ONLY server the RESTful API. - """ - - def __init__( - self, - agent_card: AgentCard, - request_handler: RequestHandler, - ): - """Initializes the RESTHandler. - - Args: - agent_card: The AgentCard describing the agent's capabilities. - request_handler: The underlying `RequestHandler` instance to delegate requests to. - """ - self.agent_card = agent_card - self.request_handler = request_handler - - @validate_version(constants.PROTOCOL_VERSION_1_0) - async def on_message_send( - self, - request: Request, - context: ServerCallContext, - ) -> dict[str, Any]: - """Handles the 'message/send' REST method. - - Args: - request: The incoming `Request` object. - context: Context provided by the server. - - Returns: - A `dict` containing the result (Task or Message) - """ - body = await request.body() - params = a2a_pb2.SendMessageRequest() - Parse(body, params) - task_or_message = await self.request_handler.on_message_send( - params, context - ) - if isinstance(task_or_message, a2a_pb2.Task): - response = a2a_pb2.SendMessageResponse(task=task_or_message) - else: - response = a2a_pb2.SendMessageResponse(message=task_or_message) - return MessageToDict(response) - - @validate_version(constants.PROTOCOL_VERSION_1_0) - @validate( - lambda self: self.agent_card.capabilities.streaming, - 'Streaming is not supported by the agent', - ) - async def on_message_send_stream( - self, - request: Request, - context: ServerCallContext, - ) -> AsyncIterator[dict[str, Any]]: - """Handles the 'message/stream' REST method. - - Yields response objects as they are produced by the underlying handler's stream. - - Args: - request: The incoming `Request` object. - context: Context provided by the server. - - Yields: - JSON serialized objects containing streaming events - (Task, Message, TaskStatusUpdateEvent, TaskArtifactUpdateEvent) as JSON - """ - body = await request.body() - params = a2a_pb2.SendMessageRequest() - Parse(body, params) - async for event in self.request_handler.on_message_send_stream( - params, context - ): - response = proto_utils.to_stream_response(event) - yield MessageToDict(response) - - @validate_version(constants.PROTOCOL_VERSION_1_0) - async def on_cancel_task( - self, - request: Request, - context: ServerCallContext, - ) -> dict[str, Any]: - """Handles the 'tasks/cancel' REST method. - - Args: - request: The incoming `Request` object. - context: Context provided by the server. - - Returns: - A `dict` containing the updated Task - """ - task_id = request.path_params['id'] - task = await self.request_handler.on_cancel_task( - CancelTaskRequest(id=task_id), context - ) - if task: - return MessageToDict(task) - raise TaskNotFoundError - - @validate_version(constants.PROTOCOL_VERSION_1_0) - @validate( - lambda self: self.agent_card.capabilities.streaming, - 'Streaming is not supported by the agent', - ) - async def on_subscribe_to_task( - self, - request: Request, - context: ServerCallContext, - ) -> AsyncIterator[dict[str, Any]]: - """Handles the 'SubscribeToTask' REST method. - - Yields response objects as they are produced by the underlying handler's stream. - - Args: - request: The incoming `Request` object. - context: Context provided by the server. - - Yields: - JSON serialized objects containing streaming events - """ - task_id = request.path_params['id'] - async for event in self.request_handler.on_subscribe_to_task( - SubscribeToTaskRequest(id=task_id), context - ): - yield MessageToDict(proto_utils.to_stream_response(event)) - - @validate_version(constants.PROTOCOL_VERSION_1_0) - async def get_push_notification( - self, - request: Request, - context: ServerCallContext, - ) -> dict[str, Any]: - """Handles the 'tasks/pushNotificationConfig/get' REST method. - - Args: - request: The incoming `Request` object. - context: Context provided by the server. - - Returns: - A `dict` containing the config - """ - task_id = request.path_params['id'] - push_id = request.path_params['push_id'] - params = GetTaskPushNotificationConfigRequest( - task_id=task_id, - id=push_id, - ) - config = ( - await self.request_handler.on_get_task_push_notification_config( - params, context - ) - ) - return MessageToDict(config) - - @validate_version(constants.PROTOCOL_VERSION_1_0) - @validate( - lambda self: self.agent_card.capabilities.push_notifications, - 'Push notifications are not supported by the agent', - ) - async def set_push_notification( - self, - request: Request, - context: ServerCallContext, - ) -> dict[str, Any]: - """Handles the 'tasks/pushNotificationConfig/set' REST method. - - Requires the agent to support push notifications. - - Args: - request: The incoming `TaskPushNotificationConfig` object. - context: Context provided by the server. - - Returns: - A `dict` containing the config object. - - Raises: - UnsupportedOperationError: If push notifications are not supported by the agent - (due to the `@validate` decorator), A2AError if processing error is - found. - """ - body = await request.body() - params = a2a_pb2.TaskPushNotificationConfig() - Parse(body, params) - # Set the parent to the task resource name format - params.task_id = request.path_params['id'] - config = ( - await self.request_handler.on_create_task_push_notification_config( - params, context - ) - ) - return MessageToDict(config) - - @validate_version(constants.PROTOCOL_VERSION_1_0) - async def on_get_task( - self, - request: Request, - context: ServerCallContext, - ) -> dict[str, Any]: - """Handles the 'tasks/{id}' REST method. - - Args: - request: The incoming `Request` object. - context: Context provided by the server. - - Returns: - A `Task` object containing the Task. - """ - params = a2a_pb2.GetTaskRequest() - proto_utils.parse_params(request.query_params, params) - params.id = request.path_params['id'] - task = await self.request_handler.on_get_task(params, context) - if task: - return MessageToDict(task) - raise TaskNotFoundError - - @validate_version(constants.PROTOCOL_VERSION_1_0) - async def delete_push_notification( - self, - request: Request, - context: ServerCallContext, - ) -> dict[str, Any]: - """Handles the 'tasks/pushNotificationConfig/delete' REST method. - - Args: - request: The incoming `Request` object. - context: Context provided by the server. - - Returns: - An empty `dict` representing the empty response. - """ - task_id = request.path_params['id'] - push_id = request.path_params['push_id'] - params = a2a_pb2.DeleteTaskPushNotificationConfigRequest( - task_id=task_id, id=push_id - ) - await self.request_handler.on_delete_task_push_notification_config( - params, context - ) - return {} - - @validate_version(constants.PROTOCOL_VERSION_1_0) - async def list_tasks( - self, - request: Request, - context: ServerCallContext, - ) -> dict[str, Any]: - """Handles the 'tasks/list' REST method. - - Args: - request: The incoming `Request` object. - context: Context provided by the server. - - Returns: - A list of `dict` representing the `Task` objects. - """ - params = a2a_pb2.ListTasksRequest() - proto_utils.parse_params(request.query_params, params) - - result = await self.request_handler.on_list_tasks(params, context) - return MessageToDict(result, always_print_fields_with_no_presence=True) - - @validate_version(constants.PROTOCOL_VERSION_1_0) - async def list_push_notifications( - self, - request: Request, - context: ServerCallContext, - ) -> dict[str, Any]: - """Handles the 'tasks/pushNotificationConfig/list' REST method. - - Args: - request: The incoming `Request` object. - context: Context provided by the server. - - Returns: - A list of `dict` representing the `TaskPushNotificationConfig` objects. - """ - params = a2a_pb2.ListTaskPushNotificationConfigsRequest() - proto_utils.parse_params(request.query_params, params) - params.task_id = request.path_params['id'] - - result = ( - await self.request_handler.on_list_task_push_notification_configs( - params, context - ) - ) - return MessageToDict(result) diff --git a/src/a2a/utils/helpers.py b/src/a2a/utils/helpers.py index 8d1dd315d..b1f23b405 100644 --- a/src/a2a/utils/helpers.py +++ b/src/a2a/utils/helpers.py @@ -338,36 +338,23 @@ def _get_actual_version( context = arg break - if context is not None: - headers = context.state.get('headers', {}) - actual_version = headers.get( - constants.VERSION_HEADER - ) or headers.get(constants.VERSION_HEADER.lower()) - - if not actual_version: - return constants.PROTOCOL_VERSION_0_3 - - return str(actual_version) - - # Fallback to Request - request = kwargs.get('request') - if request is None: - for arg in args: - if hasattr(arg, 'headers') and hasattr(arg, 'path_params'): - request = arg - break - - if request is not None: - headers = dict(request.headers) - actual_version = headers.get( - constants.VERSION_HEADER - ) or headers.get(constants.VERSION_HEADER.lower()) - - if not actual_version: - return constants.PROTOCOL_VERSION_0_3 - return str(actual_version) - - return expected_version + if context is None: + # If no context is found, we can't validate the version. + # In a real scenario, this shouldn't happen for properly routed requests. + # We default to the expected version to allow test call to proceed. + return expected_version + + headers = context.state.get('headers', {}) + # Header names are usually case-insensitive in most frameworks, but dict lookup is case-sensitive. + # We check both standard and lowercase versions. + actual_version = headers.get( + constants.VERSION_HEADER + ) or headers.get(constants.VERSION_HEADER.lower()) + + if not actual_version: + return constants.PROTOCOL_VERSION_0_3 + + return str(actual_version) def _is_version_compatible(actual: str) -> bool: if actual == expected_version: From f30eff8fd75f4e1cbb228b85f8c4ed4c7731eb0b Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Wed, 25 Mar 2026 16:18:03 +0000 Subject: [PATCH 4/7] fix --- src/a2a/server/request_handlers/__init__.py | 1 - tests/server/routes/test_rest_dispatcher.py | 43 +++++++++++++++++++++ 2 files changed, 43 insertions(+), 1 deletion(-) diff --git a/src/a2a/server/request_handlers/__init__.py b/src/a2a/server/request_handlers/__init__.py index 3c46fbc3e..ef4f0a74f 100644 --- a/src/a2a/server/request_handlers/__init__.py +++ b/src/a2a/server/request_handlers/__init__.py @@ -38,7 +38,6 @@ def __init__(self, *args, **kwargs): __all__ = [ 'DefaultRequestHandler', 'GrpcHandler', - 'JSONRPCHandler', 'RequestHandler', 'build_error_response', 'prepare_response_object', diff --git a/tests/server/routes/test_rest_dispatcher.py b/tests/server/routes/test_rest_dispatcher.py index f91105ac7..6627fb4bd 100644 --- a/tests/server/routes/test_rest_dispatcher.py +++ b/tests/server/routes/test_rest_dispatcher.py @@ -139,6 +139,16 @@ async def test_build_call_context(self, rest_dispatcher_instance): @pytest.mark.asyncio class TestRestDispatcherEndpoints: + async def test_on_message_send_throws_error_for_unsupported_version( + self, rest_dispatcher_instance, mock_handler + ): + # 0.3 is currently not supported for direct message sending on RestDispatcher + req = make_mock_request(method='POST', headers={'a2a-version': '0.3.0'}) + response = await rest_dispatcher_instance.on_message_send(req) + + # VersionNotSupportedError maps to 400 Bad Request + assert response.status_code == 400 + async def test_on_message_send_returns_message( self, rest_dispatcher_instance, mock_handler ): @@ -286,3 +296,36 @@ async def test_on_subscribe_to_task_unsupported( response = await dispatcher.on_subscribe_to_task(req) assert response.status_code == 400 + + async def test_on_message_send_stream_success( + self, rest_dispatcher_instance + ): + req = make_mock_request(method='POST') + response = await rest_dispatcher_instance.on_message_send_stream(req) + + assert response.status_code == 200 + + chunks = [] + async for chunk in response.body_iterator: + chunks.append(chunk) + + assert len(chunks) == 2 + # sse-starlette yields strings or bytes formatted as Server-Sent Events + assert 'chunk1' in str(chunks[0]) + assert 'chunk2' in str(chunks[1]) + + async def test_on_subscribe_to_task_success( + self, rest_dispatcher_instance + ): + req = make_mock_request(method='GET', path_params={'id': 'test_task'}) + response = await rest_dispatcher_instance.on_subscribe_to_task(req) + + assert response.status_code == 200 + + chunks = [] + async for chunk in response.body_iterator: + chunks.append(chunk) + + assert len(chunks) == 2 + assert 'chunk1' in str(chunks[0]) + assert 'chunk2' in str(chunks[1]) From 7a583230a25c4940c9eadb4fca0186859b4f0c8f Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Wed, 25 Mar 2026 16:26:12 +0000 Subject: [PATCH 5/7] linter --- tests/server/routes/test_rest_dispatcher.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/tests/server/routes/test_rest_dispatcher.py b/tests/server/routes/test_rest_dispatcher.py index 6627fb4bd..bee9424f0 100644 --- a/tests/server/routes/test_rest_dispatcher.py +++ b/tests/server/routes/test_rest_dispatcher.py @@ -145,7 +145,7 @@ async def test_on_message_send_throws_error_for_unsupported_version( # 0.3 is currently not supported for direct message sending on RestDispatcher req = make_mock_request(method='POST', headers={'a2a-version': '0.3.0'}) response = await rest_dispatcher_instance.on_message_send(req) - + # VersionNotSupportedError maps to 400 Bad Request assert response.status_code == 400 @@ -302,26 +302,24 @@ async def test_on_message_send_stream_success( ): req = make_mock_request(method='POST') response = await rest_dispatcher_instance.on_message_send_stream(req) - + assert response.status_code == 200 - + chunks = [] async for chunk in response.body_iterator: chunks.append(chunk) - + assert len(chunks) == 2 # sse-starlette yields strings or bytes formatted as Server-Sent Events assert 'chunk1' in str(chunks[0]) assert 'chunk2' in str(chunks[1]) - async def test_on_subscribe_to_task_success( - self, rest_dispatcher_instance - ): + async def test_on_subscribe_to_task_success(self, rest_dispatcher_instance): req = make_mock_request(method='GET', path_params={'id': 'test_task'}) response = await rest_dispatcher_instance.on_subscribe_to_task(req) - + assert response.status_code == 200 - + chunks = [] async for chunk in response.body_iterator: chunks.append(chunk) From 2cdb64a804fe061cf3ba69580126d04fb97247ae Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Wed, 1 Apr 2026 11:39:50 +0000 Subject: [PATCH 6/7] refactor: update default context builder behavior for request handlers in rest_routes --- src/a2a/server/routes/rest_routes.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/a2a/server/routes/rest_routes.py b/src/a2a/server/routes/rest_routes.py index 85dd01ff4..5d0cfcfc8 100644 --- a/src/a2a/server/routes/rest_routes.py +++ b/src/a2a/server/routes/rest_routes.py @@ -65,8 +65,8 @@ def create_rest_routes( # noqa: PLR0913 extended_agent_card: An optional, distinct AgentCard to be served at the authenticated extended card endpoint. context_builder: The CallContextBuilder used to construct the - ServerCallContext passed to the request_handler. If None, no - ServerCallContext is passed. + ServerCallContext passed to the request_handler. If None the + DefaultCallContextBuilder is used. card_modifier: An optional callback to dynamically modify the public agent card before it is served. extended_card_modifier: An optional callback to dynamically modify From 434c5442d4c0b7f899aa6411a460675e281a9abd Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Wed, 1 Apr 2026 11:51:17 +0000 Subject: [PATCH 7/7] test: add scope attribute to mock request in rest dispatcher tests --- tests/server/routes/test_rest_dispatcher.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/server/routes/test_rest_dispatcher.py b/tests/server/routes/test_rest_dispatcher.py index bee9424f0..b4233d0cd 100644 --- a/tests/server/routes/test_rest_dispatcher.py +++ b/tests/server/routes/test_rest_dispatcher.py @@ -101,6 +101,7 @@ def make_mock_request( # Needs to be able to build ServerCallContext, so provide .user and .auth etc. if needed mock_req.user = MagicMock(is_authenticated=False) mock_req.auth = None + mock_req.scope = {} return mock_req