Skip to content

Commit 92bb946

Browse files
committed
refactor: improve JSON-RPC request handling by using method names for routing
1 parent 418a433 commit 92bb946

1 file changed

Lines changed: 44 additions & 26 deletions

File tree

src/a2a/server/routes/jsonrpc_dispatcher.py

Lines changed: 44 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from abc import ABC, abstractmethod
88
from collections.abc import AsyncGenerator, Awaitable, Callable
9-
from typing import TYPE_CHECKING, Any
9+
from typing import TYPE_CHECKING, Any, cast
1010

1111
from google.protobuf.json_format import MessageToDict, ParseDict
1212
from jsonrpc.jsonrpc2 import JSONRPC20Request, JSONRPC20Response
@@ -400,7 +400,7 @@ async def handle_requests(self, request: Request) -> Response: # noqa: PLR0911,
400400
else:
401401
try:
402402
raw_result = await self._process_non_streaming_request(
403-
request_id, specific_request, call_context
403+
specific_request, call_context
404404
)
405405
handler_result = JSONRPC20Response(
406406
result=raw_result, _id=request_id
@@ -450,11 +450,17 @@ async def _process_streaming_request(
450450
An `AsyncGenerator` object to stream results to the client.
451451
"""
452452
stream: AsyncGenerator | None = None
453-
if isinstance(request_obj, SendMessageRequest):
453+
if (
454+
isinstance(request_obj, SendMessageRequest)
455+
and context.state.get('method') == 'SendStreamingMessage'
456+
):
454457
stream = self.request_handler.on_message_send_stream(
455458
request_obj, context
456459
)
457-
elif isinstance(request_obj, SubscribeToTaskRequest):
460+
elif (
461+
isinstance(request_obj, SubscribeToTaskRequest)
462+
and context.state.get('method') == 'SubscribeToTask'
463+
):
458464
stream = self.request_handler.on_subscribe_to_task(
459465
request_obj, context
460466
)
@@ -589,48 +595,60 @@ async def _handle_get_extended_agent_card(
589595
@validate_version(constants.PROTOCOL_VERSION_1_0)
590596
async def _process_non_streaming_request( # noqa: PLR0911
591597
self,
592-
request_id: str | int | None,
593598
request_obj: A2ARequest,
594599
context: ServerCallContext,
595600
) -> dict[str, Any] | None:
596601
"""Processes non-streaming requests (message/send, tasks/get, tasks/cancel, tasks/pushNotificationConfig/*).
597602
598603
Args:
599-
request_id: The ID of the request.
600604
request_obj: The proto request message.
601605
context: The ServerCallContext for the request.
602606
603607
Returns:
604608
A dict containing the result or error.
605609
"""
606-
match request_obj:
607-
case SendMessageRequest():
608-
return await self._handle_send_message(request_obj, context)
609-
case CancelTaskRequest():
610-
return await self._handle_cancel_task(request_obj, context)
611-
case GetTaskRequest():
612-
return await self._handle_get_task(request_obj, context)
613-
case ListTasksRequest():
614-
return await self._handle_list_tasks(request_obj, context)
615-
case TaskPushNotificationConfig():
610+
match context.state.get('method', None):
611+
case 'SendMessage':
612+
return await self._handle_send_message(
613+
cast('SendMessageRequest', request_obj), context
614+
)
615+
case 'CancelTask':
616+
return await self._handle_cancel_task(
617+
cast('CancelTaskRequest', request_obj), context
618+
)
619+
case 'GetTask':
620+
return await self._handle_get_task(
621+
cast('GetTaskRequest', request_obj), context
622+
)
623+
case 'ListTasks':
624+
return await self._handle_list_tasks(
625+
cast('ListTasksRequest', request_obj), context
626+
)
627+
case 'CreateTaskPushNotificationConfig':
616628
return await self._handle_create_task_push_notification_config(
617-
request_obj, context
629+
cast('TaskPushNotificationConfig', request_obj), context
618630
)
619-
case GetTaskPushNotificationConfigRequest():
631+
case 'GetTaskPushNotificationConfig':
620632
return await self._handle_get_task_push_notification_config(
621-
request_obj, context
633+
cast('GetTaskPushNotificationConfigRequest', request_obj),
634+
context,
622635
)
623-
case ListTaskPushNotificationConfigsRequest():
636+
case 'ListTaskPushNotificationConfigs':
624637
return await self._handle_list_task_push_notification_configs(
625-
request_obj, context
638+
cast('ListTaskPushNotificationConfigsRequest', request_obj),
639+
context,
626640
)
627-
case DeleteTaskPushNotificationConfigRequest():
628-
return await self._handle_delete_task_push_notification_config(
629-
request_obj, context
641+
case 'DeleteTaskPushNotificationConfig':
642+
await self._handle_delete_task_push_notification_config(
643+
cast(
644+
'DeleteTaskPushNotificationConfigRequest', request_obj
645+
),
646+
context,
630647
)
631-
case GetExtendedAgentCardRequest():
648+
return None
649+
case 'GetExtendedAgentCard':
632650
return await self._handle_get_extended_agent_card(
633-
request_obj, context
651+
cast('GetExtendedAgentCardRequest', request_obj), context
634652
)
635653
case _:
636654
logger.error(

0 commit comments

Comments
 (0)