Skip to content

Commit eb98c35

Browse files
committed
fix
1 parent a663568 commit eb98c35

3 files changed

Lines changed: 65 additions & 39 deletions

File tree

src/a2a/server/routes/jsonrpc_dispatcher.py

Lines changed: 63 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ class JsonRpcDispatcher:
194194
def __init__( # noqa: PLR0913
195195
self,
196196
agent_card: AgentCard,
197-
http_handler: RequestHandler,
197+
request_handler: RequestHandler,
198198
extended_agent_card: AgentCard | None = None,
199199
context_builder: CallContextBuilder | None = None,
200200
card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard]
@@ -209,12 +209,12 @@ def __init__( # noqa: PLR0913
209209
210210
Args:
211211
agent_card: The AgentCard describing the agent's capabilities.
212-
http_handler: The handler instance responsible for processing A2A
212+
request_handler: The handler instance responsible for processing A2A
213213
requests via http.
214214
extended_agent_card: An optional, distinct AgentCard to be served
215215
at the authenticated extended card endpoint.
216216
context_builder: The CallContextBuilder used to construct the
217-
ServerCallContext passed to the http_handler. If None, no
217+
ServerCallContext passed to the request_handler. If None, no
218218
ServerCallContext is passed.
219219
card_modifier: An optional callback to dynamically modify the public
220220
agent card before it is served.
@@ -231,7 +231,7 @@ def __init__( # noqa: PLR0913
231231
)
232232

233233
self.agent_card = agent_card
234-
self.http_handler = http_handler
234+
self.request_handler = request_handler
235235
self.extended_agent_card = extended_agent_card
236236
self.card_modifier = card_modifier
237237
self.extended_card_modifier = extended_card_modifier
@@ -242,7 +242,7 @@ def __init__( # noqa: PLR0913
242242
if self.enable_v0_3_compat:
243243
self._v03_adapter = JSONRPC03Adapter(
244244
agent_card=agent_card,
245-
http_handler=http_handler,
245+
http_handler=request_handler,
246246
extended_agent_card=extended_agent_card,
247247
context_builder=self._context_builder,
248248
card_modifier=card_modifier,
@@ -433,7 +433,6 @@ async def handle_requests(self, request: Request) -> Response: # noqa: PLR0911,
433433
)
434434
async def _require_push_notifications(self) -> None:
435435
"""Helper to enforce push notifications capability."""
436-
pass
437436

438437
@validate_version(constants.PROTOCOL_VERSION_1_0)
439438
@validate(
@@ -445,7 +444,7 @@ async def _process_streaming_request(
445444
request_id: str | int | None,
446445
request_obj: A2ARequest,
447446
context: ServerCallContext,
448-
) -> Response:
447+
) -> AsyncGenerator[dict[str, Any], None]:
449448
"""Processes streaming requests (SendStreamingMessage or SubscribeToTask).
450449
451450
Args:
@@ -454,15 +453,15 @@ async def _process_streaming_request(
454453
context: The ServerCallContext for the request.
455454
456455
Returns:
457-
An `EventSourceResponse` object to stream results to the client.
456+
An `AsyncGenerator` object to stream results to the client.
458457
"""
459458
stream: AsyncGenerator | None = None
460459
if isinstance(request_obj, SendMessageRequest):
461-
stream = self.http_handler.on_message_send_stream(
460+
stream = self.request_handler.on_message_send_stream(
462461
request_obj, context
463462
)
464463
elif isinstance(request_obj, SubscribeToTaskRequest):
465-
stream = self.http_handler.on_subscribe_to_task(
464+
stream = self.request_handler.on_subscribe_to_task(
466465
request_obj, context
467466
)
468467

@@ -485,12 +484,12 @@ async def _wrap_stream(
485484
return _wrap_stream(stream)
486485

487486
@validate_version(constants.PROTOCOL_VERSION_1_0)
488-
async def _process_non_streaming_request(
487+
async def _process_non_streaming_request( # noqa: PLR0911, PLR0912
489488
self,
490489
request_id: str | int | None,
491490
request_obj: A2ARequest,
492491
context: ServerCallContext,
493-
) -> Response:
492+
) -> dict[str, Any] | None:
494493
"""Processes non-streaming requests (message/send, tasks/get, tasks/cancel, tasks/pushNotificationConfig/*).
495494
496495
Args:
@@ -499,48 +498,71 @@ async def _process_non_streaming_request(
499498
context: The ServerCallContext for the request.
500499
501500
Returns:
502-
A `JSONResponse` object containing the result or error.
501+
A dict containing the result or error.
503502
"""
504-
handler_result: Any = None
505503
match request_obj:
506504
case SendMessageRequest():
507-
task_or_message = await self.http_handler.on_message_send(request_obj, context)
505+
task_or_message = await self.request_handler.on_message_send(
506+
request_obj, context
507+
)
508508
if isinstance(task_or_message, Task):
509-
response = SendMessageResponse(task=task_or_message)
510-
else:
511-
response = SendMessageResponse(message=task_or_message)
512-
return MessageToDict(response)
509+
return MessageToDict(
510+
SendMessageResponse(task=task_or_message)
511+
)
512+
return MessageToDict(
513+
SendMessageResponse(message=task_or_message)
514+
)
513515
case CancelTaskRequest():
514-
task = await self.http_handler.on_cancel_task(request_obj, context)
516+
task = await self.request_handler.on_cancel_task(
517+
request_obj, context
518+
)
515519
if task:
516-
return MessageToDict(task, preserving_proto_field_name=False)
517-
else:
518-
raise TaskNotFoundError()
520+
return MessageToDict(
521+
task, preserving_proto_field_name=False
522+
)
523+
raise TaskNotFoundError
519524
case GetTaskRequest():
520-
task = await self.http_handler.on_get_task(request_obj, context)
525+
task = await self.request_handler.on_get_task(
526+
request_obj, context
527+
)
521528
if task:
522-
return MessageToDict(task, preserving_proto_field_name=False)
523-
else:
524-
raise TaskNotFoundError()
529+
return MessageToDict(
530+
task, preserving_proto_field_name=False
531+
)
532+
raise TaskNotFoundError
525533
case ListTasksRequest():
526-
response = await self.http_handler.on_list_tasks(request_obj, context)
534+
tasks_response = await self.request_handler.on_list_tasks(
535+
request_obj, context
536+
)
527537
return MessageToDict(
528-
response,
538+
tasks_response,
529539
preserving_proto_field_name=False,
530540
always_print_fields_with_no_presence=True,
531541
)
532542
case TaskPushNotificationConfig():
533543
await self._require_push_notifications()
534-
result_config = await self.http_handler.on_create_task_push_notification_config(request_obj, context)
535-
return MessageToDict(result_config, preserving_proto_field_name=False)
544+
result_config = await self.request_handler.on_create_task_push_notification_config(
545+
request_obj, context
546+
)
547+
return MessageToDict(
548+
result_config, preserving_proto_field_name=False
549+
)
536550
case GetTaskPushNotificationConfigRequest():
537-
config = await self.http_handler.on_get_task_push_notification_config(request_obj, context)
551+
config = await self.request_handler.on_get_task_push_notification_config(
552+
request_obj, context
553+
)
538554
return MessageToDict(config, preserving_proto_field_name=False)
539555
case ListTaskPushNotificationConfigsRequest():
540-
response = await self.http_handler.on_list_task_push_notification_configs(request_obj, context)
541-
return MessageToDict(response, preserving_proto_field_name=False)
556+
configs_response = await self.request_handler.on_list_task_push_notification_configs(
557+
request_obj, context
558+
)
559+
return MessageToDict(
560+
configs_response, preserving_proto_field_name=False
561+
)
542562
case DeleteTaskPushNotificationConfigRequest():
543-
await self.http_handler.on_delete_task_push_notification_config(request_obj, context)
563+
await self.request_handler.on_delete_task_push_notification_config(
564+
request_obj, context
565+
)
544566
return None
545567
case GetExtendedAgentCardRequest():
546568
if not self.agent_card.capabilities.extended_agent_card:
@@ -554,9 +576,13 @@ async def _process_non_streaming_request(
554576
self.extended_card_modifier(base_card, context)
555577
)
556578
elif self.card_modifier:
557-
card_to_serve = await maybe_await(self.card_modifier(base_card))
579+
card_to_serve = await maybe_await(
580+
self.card_modifier(base_card)
581+
)
558582

559-
return MessageToDict(card_to_serve, preserving_proto_field_name=False)
583+
return MessageToDict(
584+
card_to_serve, preserving_proto_field_name=False
585+
)
560586
case _:
561587
logger.error(
562588
'Unhandled validated request type: %s', type(request_obj)

src/a2a/server/routes/jsonrpc_routes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def create_jsonrpc_routes( # noqa: PLR0913
7272

7373
dispatcher = JsonRpcDispatcher(
7474
agent_card=agent_card,
75-
http_handler=request_handler,
75+
request_handler=request_handler,
7676
extended_agent_card=extended_agent_card,
7777
context_builder=context_builder,
7878
card_modifier=card_modifier,

tests/server/routes/test_jsonrpc_dispatcher.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def mock_app_params(self) -> dict:
126126
mock_handler = MagicMock(spec=RequestHandler)
127127
mock_agent_card = MagicMock(spec=AgentCard)
128128
mock_agent_card.url = 'http://example.com'
129-
return {'agent_card': mock_agent_card, 'http_handler': mock_handler}
129+
return {'agent_card': mock_agent_card, 'request_handler': mock_handler}
130130

131131
@pytest.fixture(scope='class')
132132
def mark_pkg_starlette_not_installed(self):

0 commit comments

Comments
 (0)