Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backend/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
/__pycache__/
/.pytest_cache/
/.mypy_cache/
/.hypothesis/
/dist/
/build/
/*.sqlite*
Expand Down
2 changes: 2 additions & 0 deletions backend/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ pytest>=8.2
pytest-cov>=5.0
schemathesis>=3.27
ultralytics>=8.3
timm>=1.0
matplotlib>=3.8
onnx>=1.16
onnxruntime>=1.18
open-clip-torch>=2.24
Expand Down
1 change: 1 addition & 0 deletions backend/src/app/api/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def train(
base_model=payload.baseModel,
owner_id=current_user.id,
cluster_id=payload.clusterId,
framework=payload.framework,
)
except cluster_service.ClusterNotAvailableError as exc:
raise HTTPException(status_code=409, detail=str(exc)) from exc
Expand Down
38 changes: 38 additions & 0 deletions backend/src/app/api/training.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
"""Training capabilities API.

Exposes the registered training backends and their form schemas so the frontend
can render each framework's training form dynamically — no library-specific UI
code. Also provides a model-search endpoint for backends with large catalogues
(e.g. Timm's ~1300 architectures).
"""

from __future__ import annotations

from fastapi import APIRouter, Depends, HTTPException, Query

from app.db.deps import get_current_user
from app.models.user import User
from app.services import training as training_pkg

# Router for the training-capability endpoints; every route declared on it is
# served under the /api/training prefix and tagged "training".
router = APIRouter(prefix="/api/training", tags=["training"])


@router.get("/frameworks")
def list_frameworks(current_user: User = Depends(get_current_user)) -> dict:
    """List the registered training backends.

    Returns a dict with a single ``"items"`` key. Each item is the
    ``to_dict()`` form of one backend's ``capabilities()`` descriptor, in the
    order ``training_pkg.list_frameworks()`` yields the backends.

    ``current_user`` is resolved through the ``get_current_user`` dependency
    and is not otherwise used in the body.
    """
    descriptors = []
    for backend in training_pkg.list_frameworks():
        descriptors.append(backend.capabilities().to_dict())
    return {"items": descriptors}


@router.get("/frameworks/{key}/models")
def search_models(
    key: str,
    task: str = Query("classify"),
    query: str = Query("", alias="query"),
    current_user: User = Depends(get_current_user),
) -> dict:
    """Search the model catalogue of one training backend.

    Args:
        key: Registry key of the framework, taken from the URL path.
        task: Task to search models for; defaults to ``"classify"``.
        query: Search string passed to the backend; defaults to ``""``.
        current_user: Resolved through ``get_current_user``; not otherwise
            used in the body.

    Returns:
        ``{"items": ...}`` wrapping whatever the backend's ``list_models``
        returns for ``(task, query)``.

    Raises:
        HTTPException: 404 when ``key`` does not name a registered framework.
    """
    try:
        backend = training_pkg.get_trainer(key)
    except training_pkg.registry.UnknownFrameworkError as exc:
        # Unknown framework key -> "not found", with the registry's message.
        raise HTTPException(status_code=404, detail=str(exc)) from exc

    matches = backend.list_models(task, query)
    return {"items": matches}
34 changes: 34 additions & 0 deletions backend/src/app/db/migrations/versions/0008_training_framework.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""Add training-framework columns; merge the cluster-discovery + phase2 heads.

The migration tree had two open heads (``0004_cluster_discovery`` and
``0007_phase2_task_type_and_review``). This revision merges them into a single
head so ``alembic upgrade head`` is unambiguous, and adds the nullable
``framework`` column to ``experiment_runs`` and ``model_artifacts`` used by the
pluggable training backends (Ultralytics / Timm / …).

Revision ID: 0008_training_framework
Revises: 0004_cluster_discovery, 0007_phase2_task_type_and_review
Create Date: 2026-05-30
"""

import sqlalchemy as sa
from alembic import op

# Alembic revision identifiers.
revision = "0008_training_framework"
# Two parent revisions: this migration merges both previously open heads.
down_revision = ("0004_cluster_discovery", "0007_phase2_task_type_and_review")
branch_labels = None
depends_on = None


def upgrade() -> None:
    """Add a nullable ``framework`` column (``String(32)``) to both tables.

    ``experiment_runs`` is altered first, then ``model_artifacts``; each
    table gets its own batch-alter context.
    """
    for table_name in ("experiment_runs", "model_artifacts"):
        with op.batch_alter_table(table_name) as batch_op:
            batch_op.add_column(
                sa.Column("framework", sa.String(length=32), nullable=True)
            )


def downgrade() -> None:
    """Drop the ``framework`` column from both tables.

    Tables are processed in the reverse of the upgrade order:
    ``model_artifacts`` first, then ``experiment_runs``.
    """
    for table_name in ("model_artifacts", "experiment_runs"):
        with op.batch_alter_table(table_name) as batch_op:
            batch_op.drop_column("framework")
59 changes: 26 additions & 33 deletions backend/src/app/jobs/tasks/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,7 +669,12 @@ def evaluate_task(payload: dict) -> dict:


def _load_model(artifact, scratch_dir: Path) -> Any:
"""Lazy model loader. Downloads weights into `scratch_dir` (caller-managed)."""
"""Lazy model loader. Downloads weights into `scratch_dir` (caller-managed).

Returns ``(kind, predictor)`` where ``kind`` is ``"onnx"`` or the framework
key (e.g. ``"ultralytics"``, ``"timm"``). Framework dispatch goes through the
trainer registry so no library name is hard-coded here.
"""
storage_path = artifact.storage_path
if not storage_path:
raise RuntimeError("artifact has no storage_path")
Expand All @@ -684,54 +689,42 @@ def _load_model(artifact, scratch_dir: Path) -> Any:
"onnx",
ort.InferenceSession(str(local), providers=["CPUExecutionProvider"]),
)
# Default: YOLO weights
from ultralytics import YOLO # type: ignore
from app.services import training as training_pkg

return ("yolo", YOLO(str(local)))
framework = getattr(artifact, "framework", None)
trainer = training_pkg.get_trainer(framework)
return (trainer.key, trainer.load_predictor(local))


def _predict_detections(model, asset, score_threshold: float) -> list[dict[str, Any]]:
kind, m = model
if kind == "onnx":
return []
with tempfile.TemporaryDirectory() as tmp:
img_path = Path(tmp) / "img"
if not _download_to(asset.uri, img_path):
return []
if kind != "yolo":
from app.services import training as training_pkg

try:
return training_pkg.get_trainer(kind).predict_detections(
m, str(img_path), score_threshold
)
except NotImplementedError:
return []
results = m.predict(str(img_path), conf=score_threshold, verbose=False)
out: list[dict[str, Any]] = []
for r in results:
names = getattr(r, "names", {}) or {}
for box in getattr(r, "boxes", []) or []:
cls_idx = int(box.cls.item()) if hasattr(box, "cls") else 0
xy = box.xywh[0].tolist() if hasattr(box, "xywh") else [0, 0, 0, 0]
cx, cy, w, h = xy
out.append(
{
"class": names.get(cls_idx, str(cls_idx)),
"bbox": (cx - w / 2, cy - h / 2, w, h),
"score": (float(box.conf.item()) if hasattr(box, "conf") else 0.0),
}
)
return out


def _predict_classification(model, asset, classes: list[str]) -> tuple[str, float]:
kind, m = model
if kind == "onnx":
return ("unknown", 0.0)
with tempfile.TemporaryDirectory() as tmp:
img_path = Path(tmp) / "img"
if not _download_to(asset.uri, img_path):
return ("unknown", 0.0)
if kind != "yolo":
from app.services import training as training_pkg

try:
return training_pkg.get_trainer(kind).predict_classification(m, str(img_path))
except NotImplementedError:
return ("unknown", 0.0)
results = m.predict(str(img_path), verbose=False)
for r in results:
probs = getattr(r, "probs", None)
if probs is not None:
top = int(probs.top1)
names = getattr(r, "names", {}) or {}
return (
names.get(top, str(top)),
float(probs.top1conf.item()),
)
return ("unknown", 0.0)
34 changes: 16 additions & 18 deletions backend/src/app/jobs/tasks/onnx_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,16 @@ def export_task(payload: dict) -> dict:
if run is None:
raise ValueError(f"ExperimentRun {experiment_id!r} not found")

# Determine MinIO key for the .pt model
# Determine MinIO key for the .pt model and which framework produced it
# so we can dispatch to the right exporter (Ultralytics .export() vs
# timm torch.onnx.export vs …).
artifacts_data: list[dict] = json.loads(run.artifacts or "[]")
pt_key: str | None = None
framework: str | None = getattr(run, "framework", None)
for art in artifacts_data:
if art.get("type") == "pytorch":
pt_key = art.get("key")
framework = art.get("framework") or framework
break

# Allow override via payload
Expand Down Expand Up @@ -99,30 +103,24 @@ def export_task(payload: dict) -> dict:
if job_id:
update_job_status(db, job_id, status="running", progress=0.3)

# Export to ONNX using Ultralytics
# Export to ONNX via the framework's trainer (no hard-coded library).
onnx_path: Path | None = None
try:
from ultralytics import YOLO # type: ignore
from app.services import training as training_pkg

if pt_path.exists() and pt_path.stat().st_size > 0:
model = YOLO(str(pt_path))
export_kwargs: dict = {"format": "onnx", "simplify": True}
if dynamic_axes is not None:
export_kwargs["dynamic"] = bool(dynamic_axes)
else:
export_kwargs["dynamic"] = True
if opset is not None:
export_kwargs["opset"] = int(opset)
export_result = model.export(**export_kwargs)
# export() returns the path to the exported file
if export_result:
onnx_path = Path(str(export_result))
else:
# Fall back to searching for .onnx next to .pt
trainer = training_pkg.get_trainer(framework)
onnx_path = trainer.export_onnx(
pt_path,
tmp_path,
opset=opset,
dynamic=bool(dynamic_axes) if dynamic_axes is not None else True,
)
if onnx_path is None:
candidate = pt_path.with_suffix(".onnx")
if candidate.exists():
onnx_path = candidate
except ImportError:
except (ImportError, NotImplementedError):
pass

if job_id:
Expand Down
Loading
Loading