|
6 | 6 |
|
7 | 7 | from abc import ABC, abstractmethod |
8 | 8 | from collections.abc import AsyncGenerator, Awaitable, Callable |
9 | | -from typing import TYPE_CHECKING, Any |
| 9 | +from typing import TYPE_CHECKING, Any, cast |
10 | 10 |
|
11 | 11 | from google.protobuf.json_format import MessageToDict, ParseDict |
12 | 12 | from jsonrpc.jsonrpc2 import JSONRPC20Request, JSONRPC20Response |
@@ -400,7 +400,7 @@ async def handle_requests(self, request: Request) -> Response: # noqa: PLR0911, |
400 | 400 | else: |
401 | 401 | try: |
402 | 402 | raw_result = await self._process_non_streaming_request( |
403 | | - request_id, specific_request, call_context |
| 403 | + specific_request, call_context |
404 | 404 | ) |
405 | 405 | handler_result = JSONRPC20Response( |
406 | 406 | result=raw_result, _id=request_id |
@@ -450,11 +450,17 @@ async def _process_streaming_request( |
450 | 450 | An `AsyncGenerator` object to stream results to the client. |
451 | 451 | """ |
452 | 452 | stream: AsyncGenerator | None = None |
453 | | - if isinstance(request_obj, SendMessageRequest): |
| 453 | + if ( |
| 454 | + isinstance(request_obj, SendMessageRequest) |
| 455 | + and context.state.get('method') == 'SendStreamingMessage' |
| 456 | + ): |
454 | 457 | stream = self.request_handler.on_message_send_stream( |
455 | 458 | request_obj, context |
456 | 459 | ) |
457 | | - elif isinstance(request_obj, SubscribeToTaskRequest): |
| 460 | + elif ( |
| 461 | + isinstance(request_obj, SubscribeToTaskRequest) |
| 462 | + and context.state.get('method') == 'SubscribeToTask' |
| 463 | + ): |
458 | 464 | stream = self.request_handler.on_subscribe_to_task( |
459 | 465 | request_obj, context |
460 | 466 | ) |
@@ -589,48 +595,60 @@ async def _handle_get_extended_agent_card( |
589 | 595 | @validate_version(constants.PROTOCOL_VERSION_1_0) |
590 | 596 | async def _process_non_streaming_request( # noqa: PLR0911 |
591 | 597 | self, |
592 | | - request_id: str | int | None, |
593 | 598 | request_obj: A2ARequest, |
594 | 599 | context: ServerCallContext, |
595 | 600 | ) -> dict[str, Any] | None: |
596 | 601 | """Processes non-streaming requests (message/send, tasks/get, tasks/cancel, tasks/pushNotificationConfig/*). |
597 | 602 |
|
598 | 603 | Args: |
599 | | - request_id: The ID of the request. |
600 | 604 | request_obj: The proto request message. |
601 | 605 | context: The ServerCallContext for the request. |
602 | 606 |
|
603 | 607 | Returns: |
604 | 608 | A dict containing the result or error. |
605 | 609 | """ |
606 | | - match request_obj: |
607 | | - case SendMessageRequest(): |
608 | | - return await self._handle_send_message(request_obj, context) |
609 | | - case CancelTaskRequest(): |
610 | | - return await self._handle_cancel_task(request_obj, context) |
611 | | - case GetTaskRequest(): |
612 | | - return await self._handle_get_task(request_obj, context) |
613 | | - case ListTasksRequest(): |
614 | | - return await self._handle_list_tasks(request_obj, context) |
615 | | - case TaskPushNotificationConfig(): |
| 610 | + match context.state.get('method', None): |
| 611 | + case 'SendMessage': |
| 612 | + return await self._handle_send_message( |
| 613 | + cast('SendMessageRequest', request_obj), context |
| 614 | + ) |
| 615 | + case 'CancelTask': |
| 616 | + return await self._handle_cancel_task( |
| 617 | + cast('CancelTaskRequest', request_obj), context |
| 618 | + ) |
| 619 | + case 'GetTask': |
| 620 | + return await self._handle_get_task( |
| 621 | + cast('GetTaskRequest', request_obj), context |
| 622 | + ) |
| 623 | + case 'ListTasks': |
| 624 | + return await self._handle_list_tasks( |
| 625 | + cast('ListTasksRequest', request_obj), context |
| 626 | + ) |
| 627 | + case 'CreateTaskPushNotificationConfig': |
616 | 628 | return await self._handle_create_task_push_notification_config( |
617 | | - request_obj, context |
| 629 | + cast('TaskPushNotificationConfig', request_obj), context |
618 | 630 | ) |
619 | | - case GetTaskPushNotificationConfigRequest(): |
| 631 | + case 'GetTaskPushNotificationConfig': |
620 | 632 | return await self._handle_get_task_push_notification_config( |
621 | | - request_obj, context |
| 633 | + cast('GetTaskPushNotificationConfigRequest', request_obj), |
| 634 | + context, |
622 | 635 | ) |
623 | | - case ListTaskPushNotificationConfigsRequest(): |
| 636 | + case 'ListTaskPushNotificationConfigs': |
624 | 637 | return await self._handle_list_task_push_notification_configs( |
625 | | - request_obj, context |
| 638 | + cast('ListTaskPushNotificationConfigsRequest', request_obj), |
| 639 | + context, |
626 | 640 | ) |
627 | | - case DeleteTaskPushNotificationConfigRequest(): |
628 | | - return await self._handle_delete_task_push_notification_config( |
629 | | - request_obj, context |
| 641 | + case 'DeleteTaskPushNotificationConfig': |
| 642 | + await self._handle_delete_task_push_notification_config( |
| 643 | + cast( |
| 644 | + 'DeleteTaskPushNotificationConfigRequest', request_obj |
| 645 | + ), |
| 646 | + context, |
630 | 647 | ) |
631 | | - case GetExtendedAgentCardRequest(): |
| 648 | + return None |
| 649 | + case 'GetExtendedAgentCard': |
632 | 650 | return await self._handle_get_extended_agent_card( |
633 | | - request_obj, context |
| 651 | + cast('GetExtendedAgentCardRequest', request_obj), context |
634 | 652 | ) |
635 | 653 | case _: |
636 | 654 | logger.error( |
|
0 commit comments