Skip to content

Commit 71cd2a5

Browse files
committed
fix switch
1 parent 1ce4368 commit 71cd2a5

1 file changed

Lines changed: 14 additions & 28 deletions

File tree

src/a2a/server/routes/jsonrpc_dispatcher.py

Lines changed: 14 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from abc import ABC, abstractmethod
88
from collections.abc import AsyncGenerator, Awaitable, Callable
9-
from typing import TYPE_CHECKING, Any, cast
9+
from typing import TYPE_CHECKING, Any
1010

1111
from google.protobuf.json_format import MessageToDict, ParseDict
1212
from jsonrpc.jsonrpc2 import JSONRPC20Request, JSONRPC20Response
@@ -31,7 +31,6 @@
3131
from a2a.server.request_handlers.response_helpers import (
3232
build_error_response,
3333
)
34-
from a2a.types import A2ARequest
3534
from a2a.types.a2a_pb2 import (
3635
AgentCard,
3736
CancelTaskRequest,
@@ -436,7 +435,7 @@ async def handle_requests(self, request: Request) -> Response: # noqa: PLR0911,
436435
async def _process_streaming_request(
437436
self,
438437
request_id: str | int | None,
439-
request_obj: A2ARequest,
438+
request_obj: Any,
440439
context: ServerCallContext,
441440
) -> AsyncGenerator[dict[str, Any], None]:
442441
"""Processes streaming requests (SendStreamingMessage or SubscribeToTask).
@@ -452,11 +451,11 @@ async def _process_streaming_request(
452451
stream: AsyncGenerator | None = None
453452
if context.state.get('method') == 'SendStreamingMessage':
454453
stream = self.request_handler.on_message_send_stream(
455-
cast('SendMessageRequest', request_obj), context
454+
request_obj, context
456455
)
457456
elif context.state.get('method') == 'SubscribeToTask':
458457
stream = self.request_handler.on_subscribe_to_task(
459-
cast('SubscribeToTaskRequest', request_obj), context
458+
request_obj, context
460459
)
461460

462461
if stream is None:
@@ -589,7 +588,7 @@ async def _handle_get_extended_agent_card(
589588
@validate_version(constants.PROTOCOL_VERSION_1_0)
590589
async def _process_non_streaming_request( # noqa: PLR0911
591590
self,
592-
request_obj: A2ARequest,
591+
request_obj: Any,
593592
context: ServerCallContext,
594593
) -> dict[str, Any] | None:
595594
"""Processes non-streaming requests.
@@ -603,46 +602,33 @@ async def _process_non_streaming_request( # noqa: PLR0911
603602
"""
604603
match context.state.get('method'):
605604
case 'SendMessage':
606-
return await self._handle_send_message(
607-
cast('SendMessageRequest', request_obj), context
608-
)
605+
return await self._handle_send_message(request_obj, context)
609606
case 'CancelTask':
610-
return await self._handle_cancel_task(
611-
cast('CancelTaskRequest', request_obj), context
612-
)
607+
return await self._handle_cancel_task(request_obj, context)
613608
case 'GetTask':
614-
return await self._handle_get_task(
615-
cast('GetTaskRequest', request_obj), context
616-
)
609+
return await self._handle_get_task(request_obj, context)
617610
case 'ListTasks':
618-
return await self._handle_list_tasks(
619-
cast('ListTasksRequest', request_obj), context
620-
)
611+
return await self._handle_list_tasks(request_obj, context)
621612
case 'CreateTaskPushNotificationConfig':
622613
return await self._handle_create_task_push_notification_config(
623-
cast('TaskPushNotificationConfig', request_obj), context
614+
request_obj, context
624615
)
625616
case 'GetTaskPushNotificationConfig':
626617
return await self._handle_get_task_push_notification_config(
627-
cast('GetTaskPushNotificationConfigRequest', request_obj),
628-
context,
618+
request_obj, context
629619
)
630620
case 'ListTaskPushNotificationConfigs':
631621
return await self._handle_list_task_push_notification_configs(
632-
cast('ListTaskPushNotificationConfigsRequest', request_obj),
633-
context,
622+
request_obj, context
634623
)
635624
case 'DeleteTaskPushNotificationConfig':
636625
await self._handle_delete_task_push_notification_config(
637-
cast(
638-
'DeleteTaskPushNotificationConfigRequest', request_obj
639-
),
640-
context,
626+
request_obj, context
641627
)
642628
return None
643629
case 'GetExtendedAgentCard':
644630
return await self._handle_get_extended_agent_card(
645-
cast('GetExtendedAgentCardRequest', request_obj), context
631+
request_obj, context
646632
)
647633
case _:
648634
method = context.state.get('method')

0 commit comments

Comments
 (0)