def on_failure(
    self, exc: BaseException, task_id: str, args: tuple, kwargs: dict, einfo: Any
) -> None:
    """Record an error execution log when a privacy request task fails.

    This is a safety net for failures that the task's own exception handling
    did not record (the cases named by the author are hard time limit
    exceeded, worker killed, and broker disconnect).
    NOTE(review): confirm against the Celery version in use that
    ``on_failure`` is actually invoked for hard time limits and lost workers.

    The handler does nothing when:

    * ``kwargs`` has no truthy ``privacy_request_id`` (other kinds of task);
    * no privacy request with that id is found;
    * the privacy request status is already ``error``, i.e. the failure was
      already handled elsewhere.

    Otherwise it logs the failure, adds an error execution log to the privacy
    request, calls ``error_processing`` on it, and commits.

    The handler is best-effort: any ``Exception`` raised while recording the
    failure is logged and swallowed so that it never propagates to the caller.

    Args:
        exc: The exception that caused the task to fail.
        task_id: Id of the failed task (unused here).
        args: Positional arguments the task was called with (unused here).
        kwargs: Keyword arguments the task was called with; only
            ``privacy_request_id`` is read.
        einfo: Exception info supplied by Celery (unused here).
    """
    privacy_request_id = kwargs.get("privacy_request_id")
    if not privacy_request_id:
        return

    try:
        session = self.get_new_session()
        try:
            # Imported at call time rather than module level -- presumably to
            # avoid a circular import; verify before hoisting.
            from fides.api.models.privacy_request import PrivacyRequest
            from fides.api.schemas.privacy_request import PrivacyRequestStatus

            privacy_request = (
                session.query(PrivacyRequest)
                .filter(PrivacyRequest.id == privacy_request_id)
                .first()
            )
            if not privacy_request:
                return

            # Already marked as errored: nothing left to record.
            if privacy_request.status == PrivacyRequestStatus.error:
                return

            logger.error(
                "Privacy request '{}' failed at worker level: {}",
                privacy_request_id,
                str(exc),
            )
            privacy_request.add_error_execution_log(
                session,
                connection_key=None,
                dataset_name="Worker task failure",
                collection_name=None,
                message=f"Task failed at worker level: {type(exc).__name__}: {exc}",
                action_type=privacy_request.policy.get_action_type(),  # type: ignore[arg-type]
            )
            privacy_request.error_processing(db=session)
            session.commit()
        finally:
            session.close()
    except Exception as log_exc:  # pylint: disable=broad-except
        # Report why the bookkeeping itself failed. Logging only the original
        # task error here would hide the actual cause (e.g. database down).
        logger.error(
            "Failed to log worker-level failure for privacy request '{}': {} "
            "(original task error: {})",
            privacy_request_id,
            str(log_exc),
            str(exc),
        )

# This retry will attempt to connect 5 times with an exponential backoff (2, 4, 8, 16 seconds between each attempt).
# The original error will be re-raised if the retries are successful. All attempts are shown in the logs.
class TestDatabaseTaskOnFailure:
    """Tests for the on_failure handler that logs worker-level task deaths.

    Every test replaces ``DatabaseTask.get_new_session`` so no real database
    connection is opened; the session and the privacy request are mocks.
    """

    @patch.object(DatabaseTask, "get_new_session")
    def test_on_failure_skips_non_privacy_request_tasks(self, mock_get_session):
        """Tasks without privacy_request_id in kwargs are ignored."""
        task = DatabaseTask()
        task.on_failure(
            exc=RuntimeError("boom"),
            task_id="test-task-id",
            args=(),
            kwargs={"some_other_param": "value"},
            einfo=None,
        )

        # The handler must return before a session is ever requested.
        mock_get_session.assert_not_called()

    @patch.object(DatabaseTask, "get_new_session")
    def test_on_failure_creates_error_log_for_worker_death(self, mock_get_session):
        """When a privacy request task dies at the worker level, an error
        execution log is created and the request is marked as errored."""
        mock_session = MagicMock()
        mock_get_session.return_value = mock_session

        mock_privacy_request = MagicMock()
        # Any value that does not compare equal to PrivacyRequestStatus.error
        # stands in for "not already errored".
        mock_privacy_request.status = "in_processing"
        mock_privacy_request.policy.get_action_type.return_value = "access"

        mock_session.query.return_value.filter.return_value.first.return_value = (
            mock_privacy_request
        )

        task = DatabaseTask()
        task.on_failure(
            exc=RuntimeError("Worker killed by OOM"),
            task_id="test-task-id",
            args=(),
            kwargs={"privacy_request_id": "test-pr-id"},
            einfo=None,
        )

        mock_privacy_request.add_error_execution_log.assert_called_once()
        log_args, log_kwargs = mock_privacy_request.add_error_execution_log.call_args
        assert log_args == (mock_session,)
        # Check the message keyword itself so the assertion cannot be
        # satisfied by the text appearing elsewhere in the call.
        assert "RuntimeError" in log_kwargs["message"]
        assert "Worker killed by OOM" in log_kwargs["message"]
        assert log_kwargs["dataset_name"] == "Worker task failure"
        assert log_kwargs["action_type"] == "access"

        mock_privacy_request.error_processing.assert_called_once_with(db=mock_session)
        mock_session.commit.assert_called_once()
        mock_session.close.assert_called_once()

    @patch.object(DatabaseTask, "get_new_session")
    def test_on_failure_skips_already_errored_request(self, mock_get_session):
        """If the in-task exception handler already handled the error, on_failure is a no-op."""
        mock_session = MagicMock()
        mock_get_session.return_value = mock_session

        from fides.api.schemas.privacy_request import PrivacyRequestStatus

        mock_privacy_request = MagicMock()
        mock_privacy_request.status = PrivacyRequestStatus.error

        mock_session.query.return_value.filter.return_value.first.return_value = (
            mock_privacy_request
        )

        task = DatabaseTask()
        task.on_failure(
            exc=RuntimeError("boom"),
            task_id="test-task-id",
            args=(),
            kwargs={"privacy_request_id": "test-pr-id"},
            einfo=None,
        )

        mock_privacy_request.add_error_execution_log.assert_not_called()
        mock_privacy_request.error_processing.assert_not_called()
        mock_session.commit.assert_not_called()
        mock_session.close.assert_called_once()

    @patch.object(DatabaseTask, "get_new_session")
    def test_on_failure_skips_missing_privacy_request(self, mock_get_session):
        """If no privacy request matches the id, nothing is written and the session is closed."""
        mock_session = MagicMock()
        mock_get_session.return_value = mock_session
        mock_session.query.return_value.filter.return_value.first.return_value = None

        task = DatabaseTask()
        task.on_failure(
            exc=RuntimeError("boom"),
            task_id="test-task-id",
            args=(),
            kwargs={"privacy_request_id": "missing-pr-id"},
            einfo=None,
        )

        mock_session.commit.assert_not_called()
        mock_session.close.assert_called_once()

    @patch("fides.api.tasks.logger")
    @patch.object(DatabaseTask, "get_new_session")
    def test_on_failure_handles_db_errors_gracefully(
        self, mock_get_session, mock_logger
    ):
        """If the DB is unavailable during on_failure, the error is logged but not raised."""
        mock_get_session.side_effect = OperationalError("DB down", None, None)

        task = DatabaseTask()
        # Should not raise
        task.on_failure(
            exc=RuntimeError("original error"),
            task_id="test-task-id",
            args=(),
            kwargs={"privacy_request_id": "test-pr-id"},
            einfo=None,
        )

        mock_get_session.assert_called_once()
        # The swallowed failure must still be reported through the logger.
        mock_logger.error.assert_called_once()