Skip to content

Commit a890f3f

Browse files
committed
add tests
1 parent 92bb946 commit a890f3f

2 files changed

Lines changed: 375 additions & 16 deletions

File tree

src/a2a/server/routes/jsonrpc_dispatcher.py

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -450,19 +450,13 @@ async def _process_streaming_request(
450450
An `AsyncGenerator` object to stream results to the client.
451451
"""
452452
stream: AsyncGenerator | None = None
453-
if (
454-
isinstance(request_obj, SendMessageRequest)
455-
and context.state.get('method') == 'SendStreamingMessage'
456-
):
453+
if context.state.get('method') == 'SendStreamingMessage':
457454
stream = self.request_handler.on_message_send_stream(
458-
request_obj, context
455+
cast(SendMessageRequest, request_obj), context
459456
)
460-
elif (
461-
isinstance(request_obj, SubscribeToTaskRequest)
462-
and context.state.get('method') == 'SubscribeToTask'
463-
):
457+
elif context.state.get('method') == 'SubscribeToTask':
464458
stream = self.request_handler.on_subscribe_to_task(
465-
request_obj, context
459+
cast(SubscribeToTaskRequest, request_obj), context
466460
)
467461

468462
if stream is None:
@@ -598,7 +592,7 @@ async def _process_non_streaming_request( # noqa: PLR0911
598592
request_obj: A2ARequest,
599593
context: ServerCallContext,
600594
) -> dict[str, Any] | None:
601-
"""Processes non-streaming requests (message/send, tasks/get, tasks/cancel, tasks/pushNotificationConfig/*).
595+
"""Processes non-streaming requests.
602596
603597
Args:
604598
request_obj: The proto request message.
@@ -607,7 +601,7 @@ async def _process_non_streaming_request( # noqa: PLR0911
607601
Returns:
608602
A dict containing the result or error.
609603
"""
610-
match context.state.get('method', None):
604+
match context.state.get('method'):
611605
case 'SendMessage':
612606
return await self._handle_send_message(
613607
cast('SendMessageRequest', request_obj), context
@@ -651,11 +645,10 @@ async def _process_non_streaming_request( # noqa: PLR0911
651645
cast('GetExtendedAgentCardRequest', request_obj), context
652646
)
653647
case _:
654-
logger.error(
655-
'Unhandled validated request type: %s', type(request_obj)
656-
)
648+
method = context.state.get('method')
649+
logger.error('Unhandled method: %s', method)
657650
raise UnsupportedOperationError(
658-
message=f'Request type {type(request_obj).__name__} is unknown.'
651+
message=f'Method {method} is not supported.'
659652
)
660653

661654
def _create_response(

0 commit comments

Comments
 (0)