Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 8 additions & 9 deletions src/a2a/server/request_handlers/default_request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,15 +344,6 @@ async def push_notification_callback() -> None:
blocking=blocking,
event_callback=push_notification_callback,
)
if not result:
raise ServerError(error=InternalError()) # noqa: TRY301

if isinstance(result, Task):
self._validate_task_id_match(task_id, result.id)

await self._send_push_notification_if_needed(
task_id, result_aggregator
)

except Exception:
logger.exception('Agent execution failed')
Expand All @@ -367,6 +358,14 @@ async def push_notification_callback() -> None:
else:
await self._cleanup_producer(producer_task, task_id)

if not result:
raise ServerError(error=InternalError())

if isinstance(result, Task):
self._validate_task_id_match(task_id, result.id)

await self._send_push_notification_if_needed(task_id, result_aggregator)

return result

async def on_message_send_stream(
Expand Down
44 changes: 26 additions & 18 deletions src/a2a/server/request_handlers/jsonrpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,22 +183,26 @@ async def on_cancel_task(
task = await self.request_handler.on_cancel_task(
request.params, context
)
if task:
return prepare_response_object(
request.id,
task,
(Task,),
CancelTaskSuccessResponse,
CancelTaskResponse,
)
raise ServerError(error=TaskNotFoundError()) # noqa: TRY301
except ServerError as e:
return CancelTaskResponse(
root=JSONRPCErrorResponse(
id=request.id, error=e.error if e.error else InternalError()
)
)

if task:
return prepare_response_object(
request.id,
task,
(Task,),
CancelTaskSuccessResponse,
CancelTaskResponse,
)

return CancelTaskResponse(
root=JSONRPCErrorResponse(id=request.id, error=TaskNotFoundError())
)

async def on_resubscribe_to_task(
self,
request: TaskResubscriptionRequest,
Expand Down Expand Up @@ -335,22 +339,26 @@ async def on_get_task(
task = await self.request_handler.on_get_task(
request.params, context
)
if task:
return prepare_response_object(
request.id,
task,
(Task,),
GetTaskSuccessResponse,
GetTaskResponse,
)
raise ServerError(error=TaskNotFoundError()) # noqa: TRY301
except ServerError as e:
return GetTaskResponse(
root=JSONRPCErrorResponse(
id=request.id, error=e.error if e.error else InternalError()
)
)

if task:
return prepare_response_object(
request.id,
task,
(Task,),
GetTaskSuccessResponse,
GetTaskResponse,
)

return GetTaskResponse(
root=JSONRPCErrorResponse(id=request.id, error=TaskNotFoundError())
)

async def list_push_notification_config(
self,
request: ListTaskPushNotificationConfigRequest,
Expand Down
Loading