diff --git a/workers/queue_backend/__init__.py b/workers/queue_backend/__init__.py index 5fe4ed4510..a8a0d681b6 100644 --- a/workers/queue_backend/__init__.py +++ b/workers/queue_backend/__init__.py @@ -1,22 +1,11 @@ -"""Queue-backend seam for workers. +"""Queue-backend seam. -This module is the single place where the choice of queue substrate -(Celery+RabbitMQ today; PG Queue in the future) lives. - -Today both entry points are no-op aliases over Celery primitives: - -* ``dispatch(task_name, args, kwargs, queue)`` -> ``current_app.send_task(...)`` -* ``@worker_task`` -> ``@shared_task`` - -A later phase will route specific tasks through a non-Celery substrate -(PG Queue) based on configuration; until then everything goes to Celery. -The exact routing mechanism is intentionally not pinned here. - -Call sites should migrate to this module so the eventual substrate switch -is a single-flag operation rather than a codebase-wide rewrite. +Single place where the substrate choice (Celery today; PG Queue later) +lives. Both entry points are transparent passthroughs to Celery today. """ from .decorator import worker_task from .dispatch import dispatch +from .fairness import FairnessKey -__all__ = ["dispatch", "worker_task"] +__all__ = ["FairnessKey", "dispatch", "worker_task"] diff --git a/workers/queue_backend/decorator.py b/workers/queue_backend/decorator.py index 9f8b666fe0..d7755ba06f 100644 --- a/workers/queue_backend/decorator.py +++ b/workers/queue_backend/decorator.py @@ -1,18 +1,7 @@ -"""Task registration decorator. +"""Transparent wrapper over ``celery.shared_task``. -Today: a transparent wrapper over ``celery.shared_task``. -Future: registers the task body with whichever substrates are enabled -(Celery + optionally PG Queue), so a single ``@worker_task`` definition -can be served by either consumer. - -Accepts both Celery decorator forms — ``shared_task`` handles them -internally, so a pass-through ``*args, **kwargs`` is enough: - - @worker_task - def healthcheck(self): ... - - @worker_task(bind=True, name="my.task") - def my_task(self, payload): ... +Accepts both decorator forms (bare and parameterised); a later phase +may register the task body with non-Celery substrates from here too. """ from __future__ import annotations @@ -23,17 +12,10 @@ def my_task(self, payload): ... def worker_task(*args: Any, **kwargs: Any) -> Any: - """Register a function as a worker task via the queue_backend seam. - - Today this is a one-line passthrough to ``celery.shared_task``. The - indirection is the seam: when a later phase adds PG Queue routing, - the consumer-registration logic lands here without touching call - sites. + """Register a function as a worker task. - The return type is ``Any`` because ``shared_task`` returns different - objects depending on call form — a ``PromiseProxy`` for the bare - ``@worker_task`` form and a decorator factory for the parameterised - ``@worker_task(name=...)`` form. Pinning a tighter type would lock - out future routing variants without buying real safety today. + ``Any`` return type because ``shared_task`` produces a + ``PromiseProxy`` for the bare form and a decorator factory for the + parameterised form. """ return shared_task(*args, **kwargs) diff --git a/workers/queue_backend/dispatch.py b/workers/queue_backend/dispatch.py index f1c21ca110..274922bf5d 100644 --- a/workers/queue_backend/dispatch.py +++ b/workers/queue_backend/dispatch.py @@ -1,12 +1,8 @@ """Transport-agnostic task dispatch. -Today: thin pass-through to ``celery.current_app.send_task``. -A later phase will introduce per-task routing through a non-Celery -substrate (PG Queue); call sites stay untouched. - -The signature intentionally exposes only what the current call sites -actually use (args, kwargs, queue). More Celery options can be added -when a real call site needs them — not before. +Thin pass-through to ``celery.current_app.send_task``; the indirection +is the seam — a future per-task router can land here without touching +call sites. """ from __future__ import annotations @@ -16,14 +12,14 @@ from celery import current_app +from .fairness import FAIRNESS_HEADER_NAME, FairnessKey + class DispatchHandle(Protocol): - """The minimum contract every dispatch substrate must satisfy. + """Minimum contract every dispatch substrate must satisfy. - Today this is satisfied by Celery's ``AsyncResult`` (which exposes - ``.id``). A future PG Queue handle will need to expose the same - attribute so existing callers — e.g. ``scheduler/tasks.py`` — keep - working unchanged. + Celery's ``AsyncResult`` satisfies this via ``.id``; any future + substrate handle must expose the same attribute. """ id: str @@ -35,24 +31,18 @@ def dispatch( args: Sequence[Any] | None = None, kwargs: Mapping[str, Any] | None = None, queue: str | None = None, + fairness: FairnessKey | None = None, ) -> DispatchHandle: """Enqueue a task by name. - Args: - task_name: Registered task name (e.g. "send_webhook_notification"). - args: Positional task args. Forwarded verbatim; Celery normalises - ``None`` internally. - kwargs: Keyword task args. Forwarded verbatim; Celery normalises - ``None`` internally. - queue: Target queue name. Defaults to the task's bound queue. - - Returns: - A handle to the enqueued task. ``.id`` is guaranteed; everything - else is substrate-specific and callers must not rely on it. + ``fairness`` is attached as the ``x-fairness-key`` header (not in + kwargs). Pass ``None`` for non-workflow worker tasks. """ + headers = {FAIRNESS_HEADER_NAME: fairness.to_dict()} if fairness is not None else None return current_app.send_task( task_name, args=args, kwargs=kwargs, queue=queue, + headers=headers, ) diff --git a/workers/queue_backend/fairness.py b/workers/queue_backend/fairness.py new file mode 100644 index 0000000000..30ef048497 --- /dev/null +++ b/workers/queue_backend/fairness.py @@ -0,0 +1,55 @@ +"""Workflow-execution fairness key. + +Attached to dispatches that start a workflow execution. Non-workflow +worker tasks (notifications, callbacks, healthchecks) pass +``fairness=None``. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from enum import StrEnum +from typing import Final + + +class WorkloadType(StrEnum): + """Workflow execution type. Labs L2 check is binary api-vs-not.""" + + API = "api" + NON_API = "non_api" + + +# pipeline_priority bounds per labs schema (1..10, higher = sooner). +MIN_PRIORITY: Final[int] = 1 +MAX_PRIORITY: Final[int] = 10 +DEFAULT_PRIORITY: Final[int] = 5 + +# Header (not kwarg) so task-body signatures without **kwargs aren't broken. +FAIRNESS_HEADER_NAME: Final[str] = "x-fairness-key" + + +@dataclass(frozen=True) +class FairnessKey: + """Routing metadata for a workflow-execution dispatch. + + ``org_id=None`` is valid for cross-org tasks — the scheduler's + ``org_config`` JOIN simply doesn't match. + """ + + org_id: str | None + workload_type: WorkloadType + pipeline_priority: int = DEFAULT_PRIORITY + + def __post_init__(self) -> None: + if not MIN_PRIORITY <= self.pipeline_priority <= MAX_PRIORITY: + raise ValueError( + "pipeline_priority out of range " + f"[{MIN_PRIORITY}, {MAX_PRIORITY}]: {self.pipeline_priority}" + ) + + def to_dict(self) -> dict[str, str | int | None]: + return { + "org_id": self.org_id, + "workload_type": self.workload_type.value, + "pipeline_priority": self.pipeline_priority, + } diff --git a/workers/scheduler/tasks.py b/workers/scheduler/tasks.py index 3215ecc104..a152b98aed 100644 --- a/workers/scheduler/tasks.py +++ b/workers/scheduler/tasks.py @@ -7,7 +7,8 @@ import traceback from typing import Any -from queue_backend import dispatch, worker_task +from queue_backend import FairnessKey, dispatch, worker_task +from queue_backend.fairness import WorkloadType from shared.enums.status_enums import PipelineStatus from shared.enums.worker_enums import QueueName from shared.infrastructure.config import WorkerConfig @@ -150,21 +151,24 @@ def _execute_scheduled_workflow( ) try: - # Dispatch through the queue_backend seam (Celery underneath today). async_result = dispatch( "async_execute_bin", args=[ - context.organization_id, # schema_name (organization_id) - context.workflow_id, # workflow_id - execution_id, # execution_id - {}, # hash_values_of_files (empty for scheduled) - True, # scheduled (THIS IS A SCHEDULED EXECUTION) + context.organization_id, + context.workflow_id, + execution_id, + {}, + True, # scheduled ], kwargs={ - "use_file_history": context.use_file_history, # Pass as kwarg - "pipeline_id": context.pipeline_id, # CRITICAL FIX: Pass pipeline_id for direct status updates + "use_file_history": context.use_file_history, + "pipeline_id": context.pipeline_id, }, - queue=QueueName.GENERAL, # Route to General queue for proper separation + queue=QueueName.GENERAL, + fairness=FairnessKey( + org_id=context.organization_id, + workload_type=WorkloadType.NON_API, + ), ) task_id = async_result.id diff --git a/workers/shared/patterns/notification/helper.py b/workers/shared/patterns/notification/helper.py index cd79d6bd24..199a5b8de6 100644 --- a/workers/shared/patterns/notification/helper.py +++ b/workers/shared/patterns/notification/helper.py @@ -87,6 +87,7 @@ def send_notification_to_worker( "platform": platform, }, queue="notifications", + fairness=None, # not a workflow-execution dispatch ) logger.info( diff --git a/workers/tests/test_fairness_key.py b/workers/tests/test_fairness_key.py new file mode 100644 index 0000000000..74e8352c34 --- /dev/null +++ b/workers/tests/test_fairness_key.py @@ -0,0 +1,311 @@ +"""Tests for the fairness-key plumbing (PG Queue Phase 5.1).""" + +from __future__ import annotations + +import ast +import json +import pathlib +from dataclasses import FrozenInstanceError +from unittest.mock import patch + +import pytest +from celery import Celery + +from queue_backend import FairnessKey, dispatch +from queue_backend.fairness import ( + DEFAULT_PRIORITY, + FAIRNESS_HEADER_NAME, + MAX_PRIORITY, + MIN_PRIORITY, + WorkloadType, +) + +# Helpers extracted so the audit tests below stay flat (SonarCloud +# S3776 cognitive-complexity threshold). + +_WORKERS_ROOT = pathlib.Path(__file__).parent.parent +_SKIP_TOP_DIRS = frozenset( + {"tests", "__pycache__", "htmlcov", ".venv", "queue_backend"} +) + + +def _iter_production_trees() -> list[tuple[pathlib.Path, ast.AST]]: + out: list[tuple[pathlib.Path, ast.AST]] = [] + for py in _WORKERS_ROOT.rglob("*.py"): + rel = py.relative_to(_WORKERS_ROOT) + if rel.parts and rel.parts[0] in _SKIP_TOP_DIRS: + continue + try: + tree = ast.parse(py.read_text(), filename=str(py)) + except SyntaxError: + continue + out.append((rel, tree)) + return out + + +def _aliased_dispatch_imports(tree: ast.AST) -> list[tuple[int, str]]: + hits: list[tuple[int, str]] = [] + for node in ast.walk(tree): + if not (isinstance(node, ast.ImportFrom) and node.module == "queue_backend"): + continue + for alias in node.names: + if alias.name == "dispatch" and alias.asname not in (None, "dispatch"): + hits.append((node.lineno, alias.asname)) + return hits + + +def _dispatch_calls_missing_fairness(tree: ast.AST) -> list[int]: + # Only matches the bare name ``dispatch`` — ``dispatcher.dispatch(...)`` + # is ExecutionDispatcher (executor RPC), a different concept. + hits: list[int] = [] + for node in ast.walk(tree): + if not isinstance(node, ast.Call): + continue + callee = node.func + if not (isinstance(callee, ast.Name) and callee.id == "dispatch"): + continue + if not any(kw.arg == "fairness" for kw in node.keywords): + hits.append(node.lineno) + return hits + + +class TestFairnessKey: + def test_minimal_construction(self): + key = FairnessKey(org_id="org-123", workload_type=WorkloadType.API) + assert key.org_id == "org-123" + assert key.workload_type == "api" + assert key.pipeline_priority == DEFAULT_PRIORITY # default 5 + + def test_org_id_can_be_none(self): + key = FairnessKey(org_id=None, workload_type=WorkloadType.API) + assert key.org_id is None + + def test_workload_type_non_api(self): + key = FairnessKey(org_id="x", workload_type=WorkloadType.NON_API) + assert key.workload_type == "non_api" + assert key.workload_type == WorkloadType.NON_API + + def test_pipeline_priority_override(self): + key = FairnessKey(org_id="x", workload_type=WorkloadType.API, pipeline_priority=9) + assert key.pipeline_priority == 9 + + def test_is_frozen(self): + key = FairnessKey(org_id="x", workload_type=WorkloadType.API) + with pytest.raises(FrozenInstanceError): + key.org_id = "y" # type: ignore[misc] + + def test_priority_below_range_rejected(self): + with pytest.raises(ValueError, match="pipeline_priority out of range"): + FairnessKey( + org_id="x", workload_type=WorkloadType.API, pipeline_priority=MIN_PRIORITY - 1 + ) + + def test_priority_above_range_rejected(self): + with pytest.raises(ValueError, match="pipeline_priority out of range"): + FairnessKey( + org_id="x", workload_type=WorkloadType.API, pipeline_priority=MAX_PRIORITY + 1 + ) + + def test_priority_boundaries_accepted(self): + FairnessKey(org_id="x", workload_type=WorkloadType.API, pipeline_priority=MIN_PRIORITY) + FairnessKey(org_id="x", workload_type=WorkloadType.API, pipeline_priority=MAX_PRIORITY) + + def test_typo_in_field_name_raises(self): + with pytest.raises(TypeError, match="pipeline_prio"): + FairnessKey( + org_id="x", + workload_type=WorkloadType.API, + pipeline_prio=9, # type: ignore[call-arg] + ) + + def test_to_dict_shape(self): + key = FairnessKey( + org_id="org-1", workload_type=WorkloadType.NON_API, pipeline_priority=9 + ) + assert key.to_dict() == { + "org_id": "org-1", + "workload_type": "non_api", + "pipeline_priority": 9, + } + + def test_to_dict_uses_plain_string_not_enum_member(self): + # Downstream consumers shouldn't need to import WorkloadType. + key = FairnessKey(org_id="x", workload_type=WorkloadType.API) + wt = key.to_dict()["workload_type"] + assert type(wt) is str + assert wt == "api" + + def test_to_dict_is_json_safe(self): + key = FairnessKey( + org_id="org-1", workload_type=WorkloadType.API, pipeline_priority=7 + ) + round_tripped = json.loads(json.dumps(key.to_dict())) + assert round_tripped == key.to_dict() + + def test_orgless_key_round_trips(self): + key = FairnessKey(org_id=None, workload_type=WorkloadType.API) + round_tripped = json.loads(json.dumps(key.to_dict())) + assert round_tripped == { + "org_id": None, + "workload_type": "api", + "pipeline_priority": DEFAULT_PRIORITY, + } + + +# --- dispatch() integration --- + + +class TestDispatchAttachesFairness: + def test_omitted_fairness_no_header_sent(self): + with patch("queue_backend.dispatch.current_app") as mock_app: + dispatch("any_task", kwargs={"foo": "bar"}) + + call_kwargs = mock_app.send_task.call_args.kwargs + assert call_kwargs["headers"] is None + assert call_kwargs["kwargs"] == {"foo": "bar"} + + def test_explicit_fairness_none_no_header_sent(self): + # Documented opt-out for non-workflow dispatches. + with patch("queue_backend.dispatch.current_app") as mock_app: + dispatch("send_webhook_notification", kwargs={"x": 1}, fairness=None) + + call_kwargs = mock_app.send_task.call_args.kwargs + assert call_kwargs["headers"] is None + assert call_kwargs["kwargs"] == {"x": 1} + + def test_provided_fairness_attached_as_message_header(self): + with patch("queue_backend.dispatch.current_app") as mock_app: + dispatch( + "any_task", + kwargs={"foo": "bar"}, + fairness=FairnessKey( + org_id="org-1", workload_type=WorkloadType.API, pipeline_priority=9 + ), + ) + + call_kwargs = mock_app.send_task.call_args.kwargs + assert call_kwargs["headers"] == { + FAIRNESS_HEADER_NAME: { + "org_id": "org-1", + "workload_type": "api", + "pipeline_priority": 9, + } + } + # Business kwargs must NOT contain the fairness slot — tasks + # without **kwargs would break. + sent_kwargs = call_kwargs["kwargs"] + assert sent_kwargs == {"foo": "bar"} + assert FAIRNESS_HEADER_NAME not in sent_kwargs + + def test_fairness_with_no_business_kwargs(self): + with patch("queue_backend.dispatch.current_app") as mock_app: + dispatch( + "any_task", + fairness=FairnessKey(org_id=None, workload_type=WorkloadType.NON_API), + ) + + call_kwargs = mock_app.send_task.call_args.kwargs + assert call_kwargs["kwargs"] is None + assert call_kwargs["headers"] == { + FAIRNESS_HEADER_NAME: { + "org_id": None, + "workload_type": "non_api", + "pipeline_priority": DEFAULT_PRIORITY, + } + } + + def test_caller_kwargs_not_mutated_in_place(self): + caller_kwargs = {"foo": "bar"} + with patch("queue_backend.dispatch.current_app"): + dispatch( + "any_task", + kwargs=caller_kwargs, + fairness=FairnessKey(org_id="org-1", workload_type=WorkloadType.API), + ) + + assert caller_kwargs == {"foo": "bar"} + assert FAIRNESS_HEADER_NAME not in caller_kwargs + + +class TestDispatchCallSitesPassFairness: + """AST audit: every production ``dispatch(...)`` declares fairness.""" + + def test_dispatch_must_be_imported_unaliased(self): + # Alias imports would defeat the bare-name canary below. + aliased = [ + f"{rel}:{lineno} (as {alias})" + for rel, tree in _iter_production_trees() + for lineno, alias in _aliased_dispatch_imports(tree) + ] + assert aliased == [], ( + "``queue_backend.dispatch`` must be imported under its real " + "name — alias imports defeat the fairness inventory canary. " + "Found:\n " + "\n ".join(aliased) + ) + + def test_every_production_dispatch_passes_fairness(self): + offenders = [ + f"{rel}:{lineno}" + for rel, tree in _iter_production_trees() + for lineno in _dispatch_calls_missing_fairness(tree) + ] + assert offenders == [], ( + "Production dispatch(...) call site(s) missing fairness=. " + "Every production dispatch must declare its fairness — pass " + "``fairness=FairnessKey(org_id=..., workload_type=WorkloadType...)`` " + "for a workflow-execution dispatch, or ``fairness=None`` " + "for a worker-internal task that doesn't start a workflow " + "execution. Found:\n " + "\n ".join(offenders) + ) + + +class TestNoConsumerYet: + """Additive-only invariant — no production code reads the slot yet.""" + + def test_no_consumer_reads_fairness_header(self): + forbidden_tokens = ("x-fairness-key", "FAIRNESS_HEADER_NAME") + + readers: list[str] = [] + for py in _WORKERS_ROOT.rglob("*.py"): + rel = py.relative_to(_WORKERS_ROOT) + if rel.parts and rel.parts[0] in _SKIP_TOP_DIRS: + continue + for line_no, line in enumerate(py.read_text().splitlines(), start=1): + if any(token in line for token in forbidden_tokens): + readers.append(f"{rel}:{line_no}") + + assert readers == [], ( + "Found reader(s) of the fairness slot before Phase 8. " + "Phase 5.1 is additive-only — no consumer should exist yet. " + "Found:\n " + "\n ".join(readers) + ) + + +class TestHeaderSurvivesCeleryPipeline: + """End-to-end: header survives Celery's real send_task code path.""" + + def test_header_present_on_outbound_message(self): + app = Celery( + "test_fairness_e2e", broker="memory://", backend="cache+memory://" + ) + + with patch("queue_backend.dispatch.current_app", app), patch.object( + app, "send_task", wraps=app.send_task + ) as wrapped_send: + dispatch( + "qb.e2e.echo", + fairness=FairnessKey( + org_id="org-1", + workload_type=WorkloadType.NON_API, + pipeline_priority=9, + ), + ) + + call_headers = wrapped_send.call_args.kwargs["headers"] + assert call_headers == { + FAIRNESS_HEADER_NAME: { + "org_id": "org-1", + "workload_type": "non_api", + "pipeline_priority": 9, + } + } diff --git a/workers/tests/test_queue_backend_seam.py b/workers/tests/test_queue_backend_seam.py index bf42aaf7ab..5a6ac88b59 100644 --- a/workers/tests/test_queue_backend_seam.py +++ b/workers/tests/test_queue_backend_seam.py @@ -275,7 +275,7 @@ def test_exports_worker_task(self): def test_all_exports(self): import queue_backend - assert set(queue_backend.__all__) == {"dispatch", "worker_task"} + assert set(queue_backend.__all__) == {"FairnessKey", "dispatch", "worker_task"} if __name__ == "__main__":