Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 44 additions & 33 deletions src/a2a/server/routes/jsonrpc_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator, Awaitable, Callable
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, cast

from google.protobuf.json_format import MessageToDict, ParseDict
from jsonrpc.jsonrpc2 import JSONRPC20Request, JSONRPC20Response
Expand Down Expand Up @@ -400,7 +400,7 @@ async def handle_requests(self, request: Request) -> Response: # noqa: PLR0911,
else:
try:
raw_result = await self._process_non_streaming_request(
request_id, specific_request, call_context
specific_request, call_context
)
handler_result = JSONRPC20Response(
result=raw_result, _id=request_id
Expand Down Expand Up @@ -450,13 +450,13 @@ async def _process_streaming_request(
An `AsyncGenerator` object to stream results to the client.
"""
stream: AsyncGenerator | None = None
if isinstance(request_obj, SendMessageRequest):
if context.state.get('method') == 'SendStreamingMessage':
Comment thread
guglielmo-san marked this conversation as resolved.
Outdated
stream = self.request_handler.on_message_send_stream(
request_obj, context
cast('SendMessageRequest', request_obj), context
)
elif isinstance(request_obj, SubscribeToTaskRequest):
elif context.state.get('method') == 'SubscribeToTask':
stream = self.request_handler.on_subscribe_to_task(
request_obj, context
cast('SubscribeToTaskRequest', request_obj), context
)
Comment thread
guglielmo-san marked this conversation as resolved.
Outdated

if stream is None:
Expand Down Expand Up @@ -589,55 +589,66 @@ async def _handle_get_extended_agent_card(
@validate_version(constants.PROTOCOL_VERSION_1_0)
async def _process_non_streaming_request( # noqa: PLR0911
self,
request_id: str | int | None,
request_obj: A2ARequest,
context: ServerCallContext,
) -> dict[str, Any] | None:
"""Processes non-streaming requests (message/send, tasks/get, tasks/cancel, tasks/pushNotificationConfig/*).
"""Processes non-streaming requests.

Args:
request_id: The ID of the request.
request_obj: The proto request message.
context: The ServerCallContext for the request.

Returns:
A dict containing the result or error.
"""
match request_obj:
case SendMessageRequest():
return await self._handle_send_message(request_obj, context)
case CancelTaskRequest():
return await self._handle_cancel_task(request_obj, context)
case GetTaskRequest():
return await self._handle_get_task(request_obj, context)
case ListTasksRequest():
return await self._handle_list_tasks(request_obj, context)
case TaskPushNotificationConfig():
match context.state.get('method'):
case 'SendMessage':
return await self._handle_send_message(
cast('SendMessageRequest', request_obj), context
)
case 'CancelTask':
return await self._handle_cancel_task(
cast('CancelTaskRequest', request_obj), context
)
case 'GetTask':
return await self._handle_get_task(
cast('GetTaskRequest', request_obj), context
)
case 'ListTasks':
return await self._handle_list_tasks(
cast('ListTasksRequest', request_obj), context
)
case 'CreateTaskPushNotificationConfig':
return await self._handle_create_task_push_notification_config(
request_obj, context
cast('TaskPushNotificationConfig', request_obj), context
)
case GetTaskPushNotificationConfigRequest():
case 'GetTaskPushNotificationConfig':
return await self._handle_get_task_push_notification_config(
request_obj, context
cast('GetTaskPushNotificationConfigRequest', request_obj),
context,
)
case ListTaskPushNotificationConfigsRequest():
case 'ListTaskPushNotificationConfigs':
return await self._handle_list_task_push_notification_configs(
request_obj, context
cast('ListTaskPushNotificationConfigsRequest', request_obj),
context,
)
case DeleteTaskPushNotificationConfigRequest():
return await self._handle_delete_task_push_notification_config(
request_obj, context
case 'DeleteTaskPushNotificationConfig':
await self._handle_delete_task_push_notification_config(
cast(
'DeleteTaskPushNotificationConfigRequest', request_obj
),
context,
)
case GetExtendedAgentCardRequest():
return None
case 'GetExtendedAgentCard':
return await self._handle_get_extended_agent_card(
request_obj, context
cast('GetExtendedAgentCardRequest', request_obj), context
)
Comment thread
guglielmo-san marked this conversation as resolved.
Outdated
case _:
logger.error(
'Unhandled validated request type: %s', type(request_obj)
)
method = context.state.get('method')
Comment thread
guglielmo-san marked this conversation as resolved.
Outdated
logger.error('Unhandled method: %s', method)
raise UnsupportedOperationError(
message=f'Request type {type(request_obj).__name__} is unknown.'
message=f'Method {method} is not supported.'
)

def _create_response(
Expand Down
Loading
Loading