diff --git a/sdks/python/hatchet_sdk/context/context.py b/sdks/python/hatchet_sdk/context/context.py index a01c89ba17..35439ff288 100644 --- a/sdks/python/hatchet_sdk/context/context.py +++ b/sdks/python/hatchet_sdk/context/context.py @@ -6,7 +6,7 @@ from collections.abc import Awaitable, Callable from dataclasses import dataclass from datetime import datetime, timedelta, timezone -from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, cast, overload +from typing import TYPE_CHECKING, Any, Generic, ParamSpec, TypeVar, cast, overload from warnings import warn from pydantic import BaseModel, TypeAdapter @@ -98,6 +98,52 @@ def _compute_memo_key(task_run_external_id: str, *args: Any, **kwargs: Any) -> b return h.digest() +TSagaOperationResult = TypeVar("TSagaOperationResult") + + +class SagaOperation(Generic[TSagaOperationResult]): + def __init__( + self, + operation_fn: Callable[[], TSagaOperationResult], + compensation_fn: Callable[[], None], + ): + self._compensation_fn = compensation_fn + self._operation_fn = operation_fn + + def apply(self) -> TSagaOperationResult: + return self._operation_fn() + + def rollback(self) -> None: + return self._compensation_fn() + + +class Saga: + def __init__(self) -> None: + self._stack: list[SagaOperation[Any]] = [] + + def add( + self, + operation_fn: Callable[[], TSagaOperationResult], + compensation_fn: Callable[[], None], + ) -> TSagaOperationResult: + operation = SagaOperation(operation_fn, compensation_fn) + self._stack.append(operation) + + try: + return operation.apply() + except Exception as e: + self._rollback() + raise e + + def _rollback(self) -> None: + while self._stack: + operation = self._stack.pop() + try: + operation.rollback() + except Exception: + logger.exception("Error during compensation") + + class Context: def __init__( self, @@ -721,6 +767,9 @@ def get_task_run_error( return TaskRunError.deserialize(error) + def begin_compensation_chain(self) -> Saga: + return Saga() + @dataclass class DurableSpawnResult: