diff --git a/src/a2a/server/request_handlers/default_request_handler.py b/src/a2a/server/request_handlers/default_request_handler.py index 2c71a6e51..ee406d6bc 100644 --- a/src/a2a/server/request_handlers/default_request_handler.py +++ b/src/a2a/server/request_handlers/default_request_handler.py @@ -344,15 +344,6 @@ async def push_notification_callback() -> None: blocking=blocking, event_callback=push_notification_callback, ) - if not result: - raise ServerError(error=InternalError()) # noqa: TRY301 - - if isinstance(result, Task): - self._validate_task_id_match(task_id, result.id) - - await self._send_push_notification_if_needed( - task_id, result_aggregator - ) except Exception: logger.exception('Agent execution failed') @@ -367,6 +358,14 @@ async def push_notification_callback() -> None: else: await self._cleanup_producer(producer_task, task_id) + if not result: + raise ServerError(error=InternalError()) + + if isinstance(result, Task): + self._validate_task_id_match(task_id, result.id) + + await self._send_push_notification_if_needed(task_id, result_aggregator) + return result async def on_message_send_stream( diff --git a/src/a2a/server/request_handlers/jsonrpc_handler.py b/src/a2a/server/request_handlers/jsonrpc_handler.py index 2cee937f4..567c61484 100644 --- a/src/a2a/server/request_handlers/jsonrpc_handler.py +++ b/src/a2a/server/request_handlers/jsonrpc_handler.py @@ -183,15 +183,6 @@ async def on_cancel_task( task = await self.request_handler.on_cancel_task( request.params, context ) - if task: - return prepare_response_object( - request.id, - task, - (Task,), - CancelTaskSuccessResponse, - CancelTaskResponse, - ) - raise ServerError(error=TaskNotFoundError()) # noqa: TRY301 except ServerError as e: return CancelTaskResponse( root=JSONRPCErrorResponse( @@ -199,6 +190,19 @@ async def on_cancel_task( ) ) + if task: + return prepare_response_object( + request.id, + task, + (Task,), + CancelTaskSuccessResponse, + CancelTaskResponse, + ) + + return CancelTaskResponse( + root=JSONRPCErrorResponse(id=request.id, error=TaskNotFoundError()) + ) + async def on_resubscribe_to_task( self, request: TaskResubscriptionRequest, @@ -335,15 +339,6 @@ async def on_get_task( task = await self.request_handler.on_get_task( request.params, context ) - if task: - return prepare_response_object( - request.id, - task, - (Task,), - GetTaskSuccessResponse, - GetTaskResponse, - ) - raise ServerError(error=TaskNotFoundError()) # noqa: TRY301 except ServerError as e: return GetTaskResponse( root=JSONRPCErrorResponse( @@ -351,6 +346,19 @@ async def on_get_task( ) ) + if task: + return prepare_response_object( + request.id, + task, + (Task,), + GetTaskSuccessResponse, + GetTaskResponse, + ) + + return GetTaskResponse( + root=JSONRPCErrorResponse(id=request.id, error=TaskNotFoundError()) + ) + async def list_push_notification_config( self, request: ListTaskPushNotificationConfigRequest,