Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ MINIO_PRESIGN_EXPIRY_SECONDS=3600
REDIS_URL=redis://redis:6379/0

# App
# Set to `production` to enforce fail-fast on default/`change-me` secrets at startup.
APP_ENV=development
SECRET_KEY=change-me-super-secret-must-be-at-least-32-characters-long
CORS_ALLOW_ORIGINS=http://localhost:5173,http://127.0.0.1:5173
SKIP_DB_MIGRATIONS=false
Expand All @@ -31,6 +33,9 @@ REFRESH_TOKEN_EXPIRE_DAYS=7

# Observability
PROMETHEUS_PORT=9090
# Optional: when set, GET /metrics requires `Authorization: Bearer <token>`.
# Leave empty to keep /metrics open (dev / trusted internal networks only).
METRICS_BEARER_TOKEN=

# Grafana
GRAFANA_ADMIN_USER=admin
Expand Down
151 changes: 151 additions & 0 deletions AUDIT.md

Large diffs are not rendered by default.

5 changes: 4 additions & 1 deletion agent/src/vf_agent/heartbeat.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
def _payload(ident: identity.Identity) -> dict[str, Any]:
snap = discover.discover()
return {
# Token also goes in the Authorization header (preferred); kept in the
# body for compatibility with platforms that predate header auth.
"register_token": ident.register_token,
"status": "online",
"cpu_usage_pct": snap["cpu_usage_pct"],
Expand All @@ -43,7 +45,8 @@ def send_once(ident: identity.Identity, *, client: httpx.Client | None = None) -
own = client is None
c = client or httpx.Client(timeout=HEARTBEAT_TIMEOUT_S)
try:
resp = c.post(url, json=_payload(ident))
headers = {"Authorization": f"Bearer {ident.register_token}"}
resp = c.post(url, json=_payload(ident), headers=headers)
return resp.status_code
except httpx.HTTPError as exc:
logger.warning("heartbeat transport error: %s", exc)
Expand Down
129 changes: 104 additions & 25 deletions agent/src/vf_agent/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import subprocess
import sys
import time
from collections.abc import Callable

from vf_agent import identity

Expand All @@ -19,6 +20,12 @@
HTTP_HOST = os.getenv("VF_AGENT_HOST", "0.0.0.0") # noqa: S104 - intended; agent serves on LAN
WAIT_FOR_IDENTITY_S = float(os.getenv("VF_AGENT_ADOPT_POLL", "2"))

# Respawn policy for child processes (HTTP server, heartbeat, celery worker).
RESPAWN_BACKOFF_INITIAL_S = 2.0
RESPAWN_BACKOFF_MAX_S = 60.0
RESPAWN_STABLE_AFTER_S = 300.0 # child stayed up this long -> reset its backoff
RESPAWN_MAX_CONSECUTIVE_FAILURES = 5


def _spawn_http() -> subprocess.Popen[bytes]:
cmd = [
Expand Down Expand Up @@ -65,14 +72,20 @@ def _spawn_celery(ident: identity.Identity) -> subprocess.Popen[bytes]:
return subprocess.Popen(cmd, env=env)


def _wait_for_identity() -> identity.Identity:
logger.info("waiting for adoption (POST /adopt on the HTTP server)...")
while True:
ident = identity.load()
if ident is not None:
logger.info("adopted as cluster %s", ident.cluster_id)
return ident
time.sleep(WAIT_FOR_IDENTITY_S)
class _Child:
"""A supervised child process with restart-with-backoff bookkeeping."""

def __init__(self, name: str, spawn: Callable[[], subprocess.Popen[bytes]]) -> None:
self.name = name
self.spawn = spawn
self.proc: subprocess.Popen[bytes] = spawn()
self.started_at = time.monotonic()
self.backoff_s = RESPAWN_BACKOFF_INITIAL_S
self.consecutive_failures = 0

def respawn(self) -> None:
self.proc = self.spawn()
self.started_at = time.monotonic()


def run() -> int:
Expand All @@ -85,38 +98,104 @@ def run() -> int:
logger.error("VF_AGENT_TOKEN is not set; refusing to start")
return 2

http = _spawn_http()
ident = _wait_for_identity()
heartbeat = _spawn_heartbeat()
celery_proc = _spawn_celery(ident)

children: list[subprocess.Popen[bytes]] = [http, heartbeat, celery_proc]
children: list[_Child] = []
shutting_down = False

def _shutdown(*_: object) -> None:
nonlocal shutting_down
shutting_down = True
logger.info("shutting down agent")
for child in children:
if child.poll() is None:
child.terminate()
if child.proc.poll() is None:
child.proc.terminate()
for child in children:
try:
child.wait(timeout=10)
child.proc.wait(timeout=10)
except subprocess.TimeoutExpired:
child.kill()
child.proc.kill()

signal.signal(signal.SIGTERM, _shutdown)
signal.signal(signal.SIGINT, _shutdown)

def _supervise(child: _Child) -> None:
"""Check one child; respawn with backoff if it died.

Raises SystemExit after RESPAWN_MAX_CONSECUTIVE_FAILURES consecutive
failed respawns of the same child.
"""
ret = child.proc.poll()
if ret is None:
# Child has stayed up long enough: consider it healthy again.
if (
child.consecutive_failures
and time.monotonic() - child.started_at >= RESPAWN_STABLE_AFTER_S
):
logger.info(
"child %s stable for %ss; resetting respawn backoff",
child.name,
int(RESPAWN_STABLE_AFTER_S),
)
child.consecutive_failures = 0
child.backoff_s = RESPAWN_BACKOFF_INITIAL_S
return

if shutting_down:
return

child.consecutive_failures += 1
if child.consecutive_failures > RESPAWN_MAX_CONSECUTIVE_FAILURES:
logger.error(
"child %s failed %s consecutive respawns (last exit code %s); "
"giving up and tearing down",
child.name,
RESPAWN_MAX_CONSECUTIVE_FAILURES,
ret,
)
raise SystemExit(ret or 1)

logger.error(
"child %s exited unexpectedly with code %s; respawning in %.0fs (attempt %s/%s)",
child.name,
ret,
child.backoff_s,
child.consecutive_failures,
RESPAWN_MAX_CONSECUTIVE_FAILURES,
)
time.sleep(child.backoff_s)
child.backoff_s = min(child.backoff_s * 2, RESPAWN_BACKOFF_MAX_S)
if not shutting_down:
child.respawn()

exit_code = 0
try:
while True:
http = _Child("http", _spawn_http)
children.append(http)

# Adoption gate: only the HTTP server runs until the backend POSTs
# /adopt. Keep supervising it (with respawn) while we wait.
logger.info("waiting for adoption (POST /adopt on the HTTP server)...")
ident: identity.Identity | None = None
while not shutting_down:
ident = identity.load()
if ident is not None:
logger.info("adopted as cluster %s", ident.cluster_id)
break
_supervise(http)
time.sleep(WAIT_FOR_IDENTITY_S)
if shutting_down or ident is None:
return exit_code

adopted = ident
children.append(_Child("heartbeat", _spawn_heartbeat))
children.append(_Child("celery", lambda: _spawn_celery(adopted)))

while not shutting_down:
for child in children:
ret = child.poll()
if ret is not None:
logger.error("child process exited with code %s; tearing down", ret)
exit_code = ret or 1
raise SystemExit(exit_code)
_supervise(child)
time.sleep(2)
except SystemExit:
return exit_code
except SystemExit as exc:
exit_code = int(exc.code) if isinstance(exc.code, int) else 1
_shutdown()
return exit_code

Expand Down
4 changes: 3 additions & 1 deletion agent/src/vf_agent/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from __future__ import annotations

import hmac
import os
from typing import Any

Expand All @@ -35,7 +36,8 @@ def _require_token(authorization: str | None = Header(default=None)) -> None:
if not authorization or not authorization.lower().startswith("bearer "):
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="missing bearer token")
presented = authorization.split(" ", 1)[1].strip()
if presented != expected:
# Constant-time comparison to avoid leaking token bytes via timing.
if not hmac.compare_digest(presented.encode("utf-8"), expected.encode("utf-8")):
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid agent token")


Expand Down
40 changes: 36 additions & 4 deletions backend/src/app/api/al.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import json
import os
import random
from datetime import datetime, timezone

from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel
Expand All @@ -14,8 +15,11 @@
from app.models.alrun import ALRun
from app.models.artifact import ModelArtifact
from app.models.asset import Asset
from app.models.dataset_version import DatasetVersion
from app.models.project import Project
from app.models.user import User
from app.services import inference_service
from app.models.workspace import Membership, Role
from app.services import authz, inference_service
from app.services.active_learning_service import select_diverse, select_uncertain
from app.services.asset_fetch import fetch_asset_bytes
from app.services.embeddings_service import EmbeddingsService
Expand All @@ -28,6 +32,15 @@
_INLINE_SCORE_CAP = int(os.getenv("VF_AL_INLINE_SCORE_CAP", "200"))


def _member_workspace_ids(db: Session, user: User) -> list[str]:
ws_ids = {
m.workspace_id
for m in db.scalars(select(Membership).where(Membership.user_id == user.id)).all()
}
ws_ids.add(authz.DEFAULT_WORKSPACE_ID)
return list(ws_ids)


def _uncertainty_scores(db: Session, assets: list, model_id: str | None) -> list[float]:
"""Compute uncertainty scores for ``assets`` using ``model_id`` if available.

Expand Down Expand Up @@ -149,13 +162,16 @@ def queue_uncertainty_scoring(
can read cached scores instead of running inference inline.
"""
from app.jobs.celery_app import celery_app
from app.models.dataset_version import DatasetVersion
from app.services.jobs_service import create_job, update_job_status

if not db.get(ModelArtifact, body.artifact_id):
artifact = db.get(ModelArtifact, body.artifact_id)
if not artifact:
raise HTTPException(status_code=400, detail="artifact not found")
if not db.get(DatasetVersion, body.dataset_version_id):
version = db.get(DatasetVersion, body.dataset_version_id)
if not version:
raise HTTPException(status_code=400, detail="dataset version not found")
authz.require_project_access(db, current_user, artifact.project_id, Role.DEVELOPER)
authz.require_dataset_access(db, current_user, version.dataset_id, Role.DEVELOPER)

payload = {
"artifactId": body.artifact_id,
Expand All @@ -182,6 +198,10 @@ def select_samples(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
authz.require_project_access(db, current_user, body.project_id, Role.DEVELOPER)
version = db.get(DatasetVersion, body.dataset_version_id)
if version is not None:
authz.require_dataset_access(db, current_user, version.dataset_id, Role.DEVELOPER)
# Get unlabeled assets for this version
assets = list(
db.scalars(
Expand Down Expand Up @@ -257,7 +277,12 @@ def list_runs(
):
base = select(ALRun)
if project_id:
authz.require_project_access(db, current_user, project_id, Role.VIEWER)
base = base.where(ALRun.project_id == project_id)
elif not authz.is_superuser(db, current_user):
base = base.join(Project, ALRun.project_id == Project.id).where(
Project.workspace_id.in_(_member_workspace_ids(db, current_user))
)
total = db.scalar(select(func.count()).select_from(base.subquery())) or 0
offset = (page - 1) * page_size
runs = list(
Expand Down Expand Up @@ -285,6 +310,9 @@ def get_al_items(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
run = db.get(ALRun, al_run_id)
if run is not None:
authz.require_project_access(db, current_user, run.project_id, Role.VIEWER)
items = list(db.scalars(select(ALItem).where(ALItem.al_run_id == al_run_id)).all())
return [
{
Expand All @@ -307,8 +335,12 @@ def resolve_item(
item = db.get(ALItem, item_id)
if not item or item.al_run_id != al_run_id:
raise HTTPException(status_code=404, detail="AL item not found")
run = db.get(ALRun, item.al_run_id)
if run is not None:
authz.require_project_access(db, current_user, run.project_id, Role.ANNOTATOR)
item.resolved_status = "resolved"
item.resolved_by = current_user.id
item.resolved_at = datetime.now(timezone.utc)
db.add(item)
db.commit()
db.refresh(item)
Expand Down
Loading
Loading