Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions changelog/8252-celery-on-failure-handler.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
type: Changed
description: Added Celery on_failure handler to log worker-level DSR task deaths (OOM, hard timeout, broker disconnect) with error execution logs
pr: 8252
labels: []
55 changes: 55 additions & 0 deletions src/fides/api/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,61 @@ class DatabaseTask(Task): # pylint: disable=W0223
_task_engine = None
_sessionmaker = None

def on_failure(
    self, exc: BaseException, task_id: str, args: tuple, kwargs: dict, einfo: Any
) -> None:
    """Record an error execution log when a privacy request task fails at the worker level.

    Handles failures that bypass the task's own exception handling, such as a
    hard time limit being exceeded or a broker disconnect.

    NOTE: this only helps for "softer" failures where Celery is still alive
    to invoke the handler. It does not help if the worker process running
    this handler is killed outright by the OS.

    The handler is a no-op when:
      * the task has no ``privacy_request_id`` kwarg (not a privacy request task),
      * the privacy request cannot be found, or
      * the request status is already ``error`` (the in-task catch-all
        already handled the failure).

    Any error raised while recording the failure is logged and swallowed so
    that the handler itself never raises.

    Args:
        exc: The exception that caused the task to fail.
        task_id: Id of the failed task (unused here).
        args: Positional arguments the task was called with (unused here).
        kwargs: Keyword arguments the task was called with; only
            ``privacy_request_id`` is read.
        einfo: Exception info supplied by Celery (unused here).
    """
    privacy_request_id = kwargs.get("privacy_request_id")
    if not privacy_request_id:
        return

    try:
        session = self.get_new_session()
        try:
            # Imported inline to avoid circular dependencies.
            from fides.api.models.privacy_request import PrivacyRequest
            from fides.api.schemas.privacy_request import PrivacyRequestStatus

            privacy_request = (
                session.query(PrivacyRequest)
                .filter(PrivacyRequest.id == privacy_request_id)
                .first()
            )
            if not privacy_request:
                return

            # Requests already in the error state were handled by the
            # in-task handler; do not log them a second time.
            if privacy_request.status == PrivacyRequestStatus.error:
                return

            logger.error(
                "Privacy request '{}' failed at worker level: {}",
                privacy_request_id,
                str(exc),
            )
            # Example messages produced below (from review discussion):
            #   hard time limit / lost worker:
            #     Task failed at worker level: TimeLimitExceeded: TimeLimitExceeded(3600,)
            #     Task failed at worker level: WorkerLostError: Worker exited prematurely: signal 9 (SIGKILL) Job: 42.
            #   broker disconnect:
            #     Task failed at worker level: ConnectionError: Error while reading from socket: Connection reset by peer
            #   DB connection lost mid-task (if it escapes the catch-all):
            #     Task failed at worker level: OperationalError: (psycopg2.OperationalError) server closed the connection unexpectedly
            #   memory watchdog (if enabled):
            #     Task failed at worker level: MemoryLimitExceeded: Memory usage at 94.2% exceeds threshold of 90%
            privacy_request.add_error_execution_log(
                session,
                connection_key=None,
                dataset_name="Worker task failure",
                collection_name=None,
                message=f"Task failed at worker level: {type(exc).__name__}: {exc}",
                action_type=privacy_request.policy.get_action_type(),  # type: ignore[arg-type]
            )
            privacy_request.error_processing(db=session)
            session.commit()
        finally:
            session.close()
    except Exception as logging_exc:  # pylint: disable=broad-except
        # Report the error that broke this handler, not just the original
        # task failure; otherwise the reason the log could not be written
        # is lost.
        logger.error(
            "Failed to log worker-level failure for privacy request '{}': {} (original task failure: {})",
            privacy_request_id,
            str(logging_exc),
            str(exc),
        )

# This retry will attempt to connect 5 times with an exponential backoff (2, 4, 8, 16 seconds between each attempt).
# The original error will be re-raised if all retry attempts fail. All attempts are shown in the logs.
@retry(
Expand Down
99 changes: 99 additions & 0 deletions tests/fides/ops/tasks/test_database_task.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# pylint: disable=protected-access

from unittest import mock
from unittest.mock import MagicMock, Mock, patch

import pytest
from sqlalchemy.engine import Engine
Expand Down Expand Up @@ -73,3 +74,101 @@ def test_max_retries_exceeded(mock_db_task, always_failing_session_maker):
with task.get_new_session():
pass
assert always_failing_session_maker.call_count == NEW_SESSION_RETRIES


class TestDatabaseTaskOnFailure:
    """Tests for the on_failure handler that logs worker-level task failures."""

    @patch.object(DatabaseTask, "get_new_session")
    def test_on_failure_skips_non_privacy_request_tasks(self, mock_get_session):
        """Tasks without privacy_request_id in kwargs are ignored."""
        task = DatabaseTask()
        task.on_failure(
            exc=RuntimeError("boom"),
            task_id="test-task-id",
            args=(),
            kwargs={"some_other_param": "value"},
            einfo=None,
        )

        # The handler must return before opening a DB session.
        mock_get_session.assert_not_called()

    @patch.object(DatabaseTask, "get_new_session")
    def test_on_failure_creates_error_log_for_worker_death(self, mock_get_session):
        """When a privacy request task dies at the worker level, an error
        execution log is created and the request is marked as errored."""
        mock_session = MagicMock()
        mock_get_session.return_value = mock_session

        mock_privacy_request = MagicMock()
        mock_privacy_request.status = MagicMock()
        mock_privacy_request.status.__eq__ = (
            lambda self, other: False
        )  # not already errored
        mock_privacy_request.policy.get_action_type.return_value = "access"

        mock_session.query.return_value.filter.return_value.first.return_value = (
            mock_privacy_request
        )

        task = DatabaseTask()
        task.on_failure(
            exc=RuntimeError("Worker killed by OOM"),
            task_id="test-task-id",
            args=(),
            kwargs={"privacy_request_id": "test-pr-id"},
            einfo=None,
        )

        mock_privacy_request.add_error_execution_log.assert_called_once()
        _, log_kwargs = mock_privacy_request.add_error_execution_log.call_args
        # Check the message itself; a fallback on str(call_args) would let a
        # missing or misnamed "message" kwarg slip through.
        assert "RuntimeError" in log_kwargs["message"]
        assert "Worker killed by OOM" in log_kwargs["message"]
        assert log_kwargs["dataset_name"] == "Worker task failure"
        assert log_kwargs["action_type"] == "access"

        mock_privacy_request.error_processing.assert_called_once_with(db=mock_session)
        mock_session.commit.assert_called_once()
        mock_session.close.assert_called_once()

    @patch.object(DatabaseTask, "get_new_session")
    def test_on_failure_skips_already_errored_request(self, mock_get_session):
        """If the in-task exception handler already handled the error, on_failure is a no-op."""
        mock_session = MagicMock()
        mock_get_session.return_value = mock_session

        # Use the real enum member so the handler's status comparison matches.
        from fides.api.schemas.privacy_request import PrivacyRequestStatus

        mock_privacy_request = MagicMock()
        mock_privacy_request.status = PrivacyRequestStatus.error

        mock_session.query.return_value.filter.return_value.first.return_value = (
            mock_privacy_request
        )

        task = DatabaseTask()
        task.on_failure(
            exc=RuntimeError("boom"),
            task_id="test-task-id",
            args=(),
            kwargs={"privacy_request_id": "test-pr-id"},
            einfo=None,
        )

        mock_privacy_request.add_error_execution_log.assert_not_called()
        mock_privacy_request.error_processing.assert_not_called()
        # Nothing was written, so nothing should be committed; the session
        # must still be released.
        mock_session.commit.assert_not_called()
        mock_session.close.assert_called_once()

    @patch.object(DatabaseTask, "get_new_session")
    def test_on_failure_handles_db_errors_gracefully(self, mock_get_session):
        """If the DB is unavailable during on_failure, the error is logged but not raised."""
        mock_get_session.side_effect = OperationalError("DB down", None, None)

        task = DatabaseTask()
        # Should not raise
        task.on_failure(
            exc=RuntimeError("original error"),
            task_id="test-task-id",
            args=(),
            kwargs={"privacy_request_id": "test-pr-id"},
            einfo=None,
        )

        # Confirms the failing session call was actually reached and swallowed.
        mock_get_session.assert_called_once()
Loading