From bc10672d167611c7d0e2045a78609802c2876b18 Mon Sep 17 00:00:00 2001 From: Claude Date: Sat, 30 May 2026 20:40:50 +0000 Subject: [PATCH 1/2] Add Timm image classification via pluggable training backends MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Introduce a generalized, registry-based training-backend abstraction so training libraries are no longer hard-coded into the orchestration code, and add Timm (PyTorch Image Models) as a first-class image-classification trainer alongside Ultralytics/YOLO. Backend - New app/services/training package: Trainer ABC + TrainContext/TrainResult, a registry (get_trainer/list_frameworks, defaults to ultralytics), shared dataset exporters (ImageFolder + YOLO-detect), and two trainers. - UltralyticsTrainer: behavior-preserving extraction of the existing YOLO training/export/predict logic behind the interface. - TimmTrainer: full classification loop with the complete Timm surface — any backbone (searchable), pooling head + drop rates, full optimizer/ scheduler matrix, mixup/cutmix, model EMA, label smoothing, AMP, channels-last, and the full augmentation pipeline; self-describing checkpoint; ONNX export and predictor. - Thin train_task / onnx_export / evaluation / inference now resolve the framework via the registry and delegate (no library names in orchestration). - framework column on ExperimentRun + ModelArtifact; migration 0008 merges the two open migration heads and adds the columns. - Capabilities API (/api/training/frameworks + model search) drives the form; TrainRequest.framework added and validated against the trainer's tasks. - requirements: add timm and matplotlib (propagate to agent images). Frontend - experiments/new.tsx is now schema-driven: a Framework selector, task/model/ hyperparameter sections rendered from the capabilities API, and a searchable Timm backbone picker. Ultralytics form is unchanged. 
Tests - Registry, shared dataset exporters, capabilities API, framework persistence/ validation, Timm capability schema, and a gated CPU smoke train. https://claude.ai/code/session_01CGFLcnxLvPEuF1y6W79nfA --- backend/requirements.txt | 2 + backend/src/app/api/ops.py | 1 + backend/src/app/api/training.py | 38 + .../versions/0008_training_framework.py | 34 + backend/src/app/jobs/tasks/evaluation.py | 59 +- backend/src/app/jobs/tasks/onnx_export.py | 34 +- backend/src/app/jobs/tasks/training.py | 453 +++-------- backend/src/app/main.py | 2 + backend/src/app/models/artifact.py | 3 + backend/src/app/models/experiment.py | 3 + backend/src/app/schemas/common.py | 3 +- backend/src/app/services/inference_service.py | 25 +- backend/src/app/services/training/__init__.py | 49 ++ backend/src/app/services/training/base.py | 170 +++++ backend/src/app/services/training/datasets.py | 153 ++++ backend/src/app/services/training/registry.py | 50 ++ .../src/app/services/training/timm_trainer.py | 710 ++++++++++++++++++ .../services/training/ultralytics_trainer.py | 365 +++++++++ backend/src/app/services/training_service.py | 17 +- backend/tests/unit/test_services_training.py | 57 +- backend/tests/unit/test_timm_trainer.py | 127 ++++ backend/tests/unit/test_training_datasets.py | 85 +++ .../unit/test_training_frameworks_api.py | 33 + backend/tests/unit/test_training_registry.py | 59 ++ frontend/src/pages/experiments/new.tsx | 562 +++++++------- specs/timm-image-classification/plan.md | 47 ++ specs/timm-image-classification/spec.md | 92 +++ specs/timm-image-classification/tasks.md | 37 + 28 files changed, 2532 insertions(+), 738 deletions(-) create mode 100644 backend/src/app/api/training.py create mode 100644 backend/src/app/db/migrations/versions/0008_training_framework.py create mode 100644 backend/src/app/services/training/__init__.py create mode 100644 backend/src/app/services/training/base.py create mode 100644 backend/src/app/services/training/datasets.py create mode 100644 
backend/src/app/services/training/registry.py create mode 100644 backend/src/app/services/training/timm_trainer.py create mode 100644 backend/src/app/services/training/ultralytics_trainer.py create mode 100644 backend/tests/unit/test_timm_trainer.py create mode 100644 backend/tests/unit/test_training_datasets.py create mode 100644 backend/tests/unit/test_training_frameworks_api.py create mode 100644 backend/tests/unit/test_training_registry.py create mode 100644 specs/timm-image-classification/plan.md create mode 100644 specs/timm-image-classification/spec.md create mode 100644 specs/timm-image-classification/tasks.md diff --git a/backend/requirements.txt b/backend/requirements.txt index e25c0d1..84b5d07 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -14,6 +14,8 @@ pytest>=8.2 pytest-cov>=5.0 schemathesis>=3.27 ultralytics>=8.3 +timm>=1.0 +matplotlib>=3.8 onnx>=1.16 onnxruntime>=1.18 open-clip-torch>=2.24 diff --git a/backend/src/app/api/ops.py b/backend/src/app/api/ops.py index b79450b..f343fbc 100644 --- a/backend/src/app/api/ops.py +++ b/backend/src/app/api/ops.py @@ -53,6 +53,7 @@ def train( base_model=payload.baseModel, owner_id=current_user.id, cluster_id=payload.clusterId, + framework=payload.framework, ) except cluster_service.ClusterNotAvailableError as exc: raise HTTPException(status_code=409, detail=str(exc)) from exc diff --git a/backend/src/app/api/training.py b/backend/src/app/api/training.py new file mode 100644 index 0000000..3c7922a --- /dev/null +++ b/backend/src/app/api/training.py @@ -0,0 +1,38 @@ +"""Training capabilities API. + +Exposes the registered training backends and their form schemas so the frontend +can render each framework's training form dynamically — no library-specific UI +code. Also provides a model-search endpoint for backends with large catalogues +(e.g. Timm's ~1300 architectures). 
+""" + +from __future__ import annotations + +from fastapi import APIRouter, Depends, HTTPException, Query + +from app.db.deps import get_current_user +from app.models.user import User +from app.services import training as training_pkg + +router = APIRouter(prefix="/api/training", tags=["training"]) + + +@router.get("/frameworks") +def list_frameworks(current_user: User = Depends(get_current_user)) -> dict: + """Return every registered training backend with its capability descriptor.""" + return {"items": [t.capabilities().to_dict() for t in training_pkg.list_frameworks()]} + + +@router.get("/frameworks/{key}/models") +def search_models( + key: str, + task: str = Query("classify"), + query: str = Query("", alias="query"), + current_user: User = Depends(get_current_user), +) -> dict: + """Search a framework's model catalogue (powers the searchable backbone picker).""" + try: + trainer = training_pkg.get_trainer(key) + except training_pkg.registry.UnknownFrameworkError as exc: + raise HTTPException(status_code=404, detail=str(exc)) from exc + return {"items": trainer.list_models(task, query)} diff --git a/backend/src/app/db/migrations/versions/0008_training_framework.py b/backend/src/app/db/migrations/versions/0008_training_framework.py new file mode 100644 index 0000000..eb37ec0 --- /dev/null +++ b/backend/src/app/db/migrations/versions/0008_training_framework.py @@ -0,0 +1,34 @@ +"""Add training-framework columns; merge the cluster-discovery + phase2 heads. + +The migration tree had two open heads (``0004_cluster_discovery`` and +``0007_phase2_task_type_and_review``). This revision merges them into a single +head so ``alembic upgrade head`` is unambiguous, and adds the nullable +``framework`` column to ``experiment_runs`` and ``model_artifacts`` used by the +pluggable training backends (Ultralytics / Timm / …). 
+ +Revision ID: 0008_training_framework +Revises: 0004_cluster_discovery, 0007_phase2_task_type_and_review +Create Date: 2026-05-30 +""" + +import sqlalchemy as sa +from alembic import op + +revision = "0008_training_framework" +down_revision = ("0004_cluster_discovery", "0007_phase2_task_type_and_review") +branch_labels = None +depends_on = None + + +def upgrade() -> None: + with op.batch_alter_table("experiment_runs") as batch: + batch.add_column(sa.Column("framework", sa.String(length=32), nullable=True)) + with op.batch_alter_table("model_artifacts") as batch: + batch.add_column(sa.Column("framework", sa.String(length=32), nullable=True)) + + +def downgrade() -> None: + with op.batch_alter_table("model_artifacts") as batch: + batch.drop_column("framework") + with op.batch_alter_table("experiment_runs") as batch: + batch.drop_column("framework") diff --git a/backend/src/app/jobs/tasks/evaluation.py b/backend/src/app/jobs/tasks/evaluation.py index 078d917..5e4ea6b 100644 --- a/backend/src/app/jobs/tasks/evaluation.py +++ b/backend/src/app/jobs/tasks/evaluation.py @@ -669,7 +669,12 @@ def evaluate_task(payload: dict) -> dict: def _load_model(artifact, scratch_dir: Path) -> Any: - """Lazy model loader. Downloads weights into `scratch_dir` (caller-managed).""" + """Lazy model loader. Downloads weights into `scratch_dir` (caller-managed). + + Returns ``(kind, predictor)`` where ``kind`` is ``"onnx"`` or the framework + key (e.g. ``"ultralytics"``, ``"timm"``). Framework dispatch goes through the + trainer registry so no library name is hard-coded here. 
+ """ storage_path = artifact.storage_path if not storage_path: raise RuntimeError("artifact has no storage_path") @@ -684,54 +689,42 @@ def _load_model(artifact, scratch_dir: Path) -> Any: "onnx", ort.InferenceSession(str(local), providers=["CPUExecutionProvider"]), ) - # Default: YOLO weights - from ultralytics import YOLO # type: ignore + from app.services import training as training_pkg - return ("yolo", YOLO(str(local))) + framework = getattr(artifact, "framework", None) + trainer = training_pkg.get_trainer(framework) + return (trainer.key, trainer.load_predictor(local)) def _predict_detections(model, asset, score_threshold: float) -> list[dict[str, Any]]: kind, m = model + if kind == "onnx": + return [] with tempfile.TemporaryDirectory() as tmp: img_path = Path(tmp) / "img" if not _download_to(asset.uri, img_path): return [] - if kind != "yolo": + from app.services import training as training_pkg + + try: + return training_pkg.get_trainer(kind).predict_detections( + m, str(img_path), score_threshold + ) + except NotImplementedError: return [] - results = m.predict(str(img_path), conf=score_threshold, verbose=False) - out: list[dict[str, Any]] = [] - for r in results: - names = getattr(r, "names", {}) or {} - for box in getattr(r, "boxes", []) or []: - cls_idx = int(box.cls.item()) if hasattr(box, "cls") else 0 - xy = box.xywh[0].tolist() if hasattr(box, "xywh") else [0, 0, 0, 0] - cx, cy, w, h = xy - out.append( - { - "class": names.get(cls_idx, str(cls_idx)), - "bbox": (cx - w / 2, cy - h / 2, w, h), - "score": (float(box.conf.item()) if hasattr(box, "conf") else 0.0), - } - ) - return out def _predict_classification(model, asset, classes: list[str]) -> tuple[str, float]: kind, m = model + if kind == "onnx": + return ("unknown", 0.0) with tempfile.TemporaryDirectory() as tmp: img_path = Path(tmp) / "img" if not _download_to(asset.uri, img_path): return ("unknown", 0.0) - if kind != "yolo": + from app.services import training as training_pkg + + try: + return 
training_pkg.get_trainer(kind).predict_classification(m, str(img_path)) + except NotImplementedError: return ("unknown", 0.0) - results = m.predict(str(img_path), verbose=False) - for r in results: - probs = getattr(r, "probs", None) - if probs is not None: - top = int(probs.top1) - names = getattr(r, "names", {}) or {} - return ( - names.get(top, str(top)), - float(probs.top1conf.item()), - ) - return ("unknown", 0.0) diff --git a/backend/src/app/jobs/tasks/onnx_export.py b/backend/src/app/jobs/tasks/onnx_export.py index db828a1..2370599 100644 --- a/backend/src/app/jobs/tasks/onnx_export.py +++ b/backend/src/app/jobs/tasks/onnx_export.py @@ -51,12 +51,16 @@ def export_task(payload: dict) -> dict: if run is None: raise ValueError(f"ExperimentRun {experiment_id!r} not found") - # Determine MinIO key for the .pt model + # Determine MinIO key for the .pt model and which framework produced it + # so we can dispatch to the right exporter (Ultralytics .export() vs + # timm torch.onnx.export vs …). artifacts_data: list[dict] = json.loads(run.artifacts or "[]") pt_key: str | None = None + framework: str | None = getattr(run, "framework", None) for art in artifacts_data: if art.get("type") == "pytorch": pt_key = art.get("key") + framework = art.get("framework") or framework break # Allow override via payload @@ -99,30 +103,24 @@ def export_task(payload: dict) -> dict: if job_id: update_job_status(db, job_id, status="running", progress=0.3) - # Export to ONNX using Ultralytics + # Export to ONNX via the framework's trainer (no hard-coded library). 
onnx_path: Path | None = None try: - from ultralytics import YOLO # type: ignore + from app.services import training as training_pkg if pt_path.exists() and pt_path.stat().st_size > 0: - model = YOLO(str(pt_path)) - export_kwargs: dict = {"format": "onnx", "simplify": True} - if dynamic_axes is not None: - export_kwargs["dynamic"] = bool(dynamic_axes) - else: - export_kwargs["dynamic"] = True - if opset is not None: - export_kwargs["opset"] = int(opset) - export_result = model.export(**export_kwargs) - # export() returns the path to the exported file - if export_result: - onnx_path = Path(str(export_result)) - else: - # Fall back to searching for .onnx next to .pt + trainer = training_pkg.get_trainer(framework) + onnx_path = trainer.export_onnx( + pt_path, + tmp_path, + opset=opset, + dynamic=bool(dynamic_axes) if dynamic_axes is not None else True, + ) + if onnx_path is None: candidate = pt_path.with_suffix(".onnx") if candidate.exists(): onnx_path = candidate - except ImportError: + except (ImportError, NotImplementedError): pass if job_id: diff --git a/backend/src/app/jobs/tasks/training.py b/backend/src/app/jobs/tasks/training.py index dcd7b42..1627662 100644 --- a/backend/src/app/jobs/tasks/training.py +++ b/backend/src/app/jobs/tasks/training.py @@ -1,5 +1,15 @@ +"""Celery training task — a thin orchestrator over the pluggable trainers. + +Framework-agnostic scaffolding (load the run, resolve the dataset + split, set +up MinIO, report progress/metrics, persist the model artifact and plots, mark +the job done) lives here. The library-specific training is delegated to a +:class:`~app.services.training.base.Trainer` resolved from the run's framework, +so adding a new library never touches this file. 
+""" + from __future__ import annotations +import hashlib import json import os import tempfile @@ -29,241 +39,6 @@ def _make_session(): return sessionmaker(bind=engine, autoflush=False, autocommit=False)() -def _extract_minio_key(uri: str) -> str: - """Extract the object key from an asset URI (strips bucket prefix if present).""" - # URI may be: s3://bucket/key, minio://bucket/key, http://host/bucket/key, or plain key - for prefix in ("s3://", "minio://"): - if uri.startswith(prefix): - parts = uri[len(prefix) :].split("/", 1) - return parts[1] if len(parts) > 1 else uri - if uri.startswith("http://") or uri.startswith("https://"): - # http://host/bucket/key → key is everything after second / - path = uri.split("/", 3) - return path[3] if len(path) > 3 else uri - return uri - - -# Ultralytics emits verbose metric keys (e.g. "metrics/mAP50(B)", -# "train/box_loss", "lr/pg0"). Map them to clean, stable keys the frontend -# charts read. Raw keys are preserved alongside these for completeness. -_METRIC_KEY_MAP = { - "metrics/mAP50(B)": "mAP50", - "metrics/mAP50-95(B)": "mAP50_95", - "metrics/precision(B)": "precision", - "metrics/recall(B)": "recall", - "train/box_loss": "train_box_loss", - "train/cls_loss": "train_cls_loss", - "train/dfl_loss": "train_dfl_loss", - "val/box_loss": "val_box_loss", - "val/cls_loss": "val_cls_loss", - "val/dfl_loss": "val_dfl_loss", - "lr/pg0": "lr", - "metrics/accuracy_top1": "top1", - "metrics/accuracy_top5": "top5", -} - -# Allow-list of Ultralytics ``model.train()`` arguments users may tune. Keys not -# present here are ignored so callers cannot inject unsafe/irrelevant kwargs. -# (data / project / name / device / plots are handled separately.) 
-ULTRALYTICS_TRAIN_ARGS: dict[str, Any] = { - # Core - "epochs": 50, - "imgsz": 640, - "batch": 16, - "patience": 100, - "rect": False, - "single_cls": False, - "seed": 0, - # Optimizer & schedule - "optimizer": "auto", - "lr0": 0.01, - "lrf": 0.01, - "momentum": 0.937, - "weight_decay": 0.0005, - "warmup_epochs": 3.0, - "warmup_momentum": 0.8, - "warmup_bias_lr": 0.1, - "cos_lr": False, - "close_mosaic": 10, - "nbs": 64, - "amp": True, - # Regularization / loss gains - "dropout": 0.0, - "label_smoothing": 0.0, - "box": 7.5, - "cls": 0.5, - "dfl": 1.5, - "overlap_mask": True, - "mask_ratio": 4, - "freeze": None, - # Augmentation - "hsv_h": 0.015, - "hsv_s": 0.7, - "hsv_v": 0.4, - "degrees": 0.0, - "translate": 0.1, - "scale": 0.5, - "shear": 0.0, - "perspective": 0.0, - "flipud": 0.0, - "fliplr": 0.5, - "bgr": 0.0, - "mosaic": 1.0, - "mixup": 0.0, - "copy_paste": 0.0, - "erasing": 0.4, - "crop_fraction": 1.0, - "auto_augment": "randaugment", -} - -# Ultralytics writes these plot images into the run directory when plots=True. -_PLOT_FILES = [ - "results.png", - "PR_curve.png", - "P_curve.png", - "R_curve.png", - "F1_curve.png", - "confusion_matrix.png", - "confusion_matrix_normalized.png", - "labels.jpg", - "BoxPR_curve.png", -] - - -def _normalize_metrics(raw: dict) -> dict: - """Return a metric dict with clean keys mapped in, raw keys preserved.""" - out: dict[str, Any] = {} - for k, v in raw.items(): - if k in _METRIC_KEY_MAP: - out[_METRIC_KEY_MAP[k]] = v - out[k] = v - return out - - -def _build_yolo_dataset( - assets: list[Any], - annotations_by_asset: dict[str, list[Any]], - class_names: list[str], - output_dir: Path, - minio_client: Any, - bucket: str, - task_type: str = "detect", - splits: dict[str, str] | None = None, -) -> Path: - """Export assets and annotations to YOLO dataset format and return path to data.yaml. - - ``splits`` maps asset id -> "train" | "val" | "test". 
Test assets are held - out entirely (never written) so they are a true unseen set for evaluation. - """ - splits = splits or {} - images_train = output_dir / "train" / "images" - images_val = output_dir / "val" / "images" - labels_train = output_dir / "train" / "labels" - labels_val = output_dir / "val" / "labels" - - if task_type == "classify": - for split in ("train", "val"): - for cls in class_names: - (output_dir / split / cls).mkdir(parents=True, exist_ok=True) - else: - for d in (images_train, images_val, labels_train, labels_val): - d.mkdir(parents=True, exist_ok=True) - - class_idx: dict[str, int] = {name: i for i, name in enumerate(class_names)} - - for asset in assets: - split = splits.get(asset.id, "train") - if split == "test": - # Held out — never exported into the training/val set. - continue - ext = Path(asset.uri).suffix or ".jpg" - img_name = f"{asset.id}{ext}" - - # Download image from MinIO - try: - key = _extract_minio_key(asset.uri) - response = minio_client.get_object(bucket, key) - img_data = response.read() - response.close() - response.release_conn() - except Exception: - img_data = None - - if task_type == "classify": - # Determine class from annotations - ann_list = annotations_by_asset.get(asset.id, []) - cls_name = ( - ann_list[0].class_name - if ann_list - else (class_names[0] if class_names else "unknown") - ) - dest_dir = output_dir / split / cls_name - dest_dir.mkdir(parents=True, exist_ok=True) - dest_img = dest_dir / img_name - if img_data: - dest_img.write_bytes(img_data) - else: - dest_img.write_bytes(b"") - else: - dest_img = (images_train if split == "train" else images_val) / img_name - if img_data: - dest_img.write_bytes(img_data) - else: - dest_img.write_bytes(b"") - - # Write label file - label_dir = labels_train if split == "train" else labels_val - label_file = label_dir / f"{asset.id}.txt" - ann_list = annotations_by_asset.get(asset.id, []) - lines: list[str] = [] - for ann in ann_list: - try: - geo = ( - 
json.loads(ann.geometry) if isinstance(ann.geometry, str) else ann.geometry - ) - except Exception: - continue - cls_name = ann.class_name or "" - idx = class_idx.get(cls_name, 0) - w_img = asset.width or 640 - h_img = asset.height or 640 - # Support both {x,y,w,h} and {x1,y1,x2,y2} geometry formats - if "x1" in geo: - bx = (geo["x1"] + geo["x2"]) / 2 / w_img - by = (geo["y1"] + geo["y2"]) / 2 / h_img - bw = (geo["x2"] - geo["x1"]) / w_img - bh = (geo["y2"] - geo["y1"]) / h_img - else: - bx = (geo.get("x", 0) + geo.get("w", 0) / 2) / w_img - by = (geo.get("y", 0) + geo.get("h", 0) / 2) / h_img - bw = geo.get("w", 0) / w_img - bh = geo.get("h", 0) / h_img - lines.append(f"{idx} {bx:.6f} {by:.6f} {bw:.6f} {bh:.6f}") - label_file.write_text("\n".join(lines)) - - # Write data.yaml - if task_type == "classify": - data_yaml = output_dir / "data.yaml" - yaml_content = ( - f"path: {output_dir}\n" - f"train: train\n" - f"val: val\n" - f"nc: {len(class_names)}\n" - f"names: {json.dumps(class_names)}\n" - ) - else: - data_yaml = output_dir / "data.yaml" - yaml_content = ( - f"path: {output_dir}\n" - f"train: train/images\n" - f"val: val/images\n" - f"nc: {len(class_names)}\n" - f"names: {json.dumps(class_names)}\n" - ) - data_yaml.write_text(yaml_content) - return data_yaml - - @shared_task(name="app.jobs.tasks.training.train_task") def train_task(payload: dict) -> dict: job_id = payload.get("jobId") @@ -279,14 +54,14 @@ def train_task(payload: dict) -> dict: from app.models.dataset import ClassMap, Dataset from app.models.dataset_version import DatasetVersion from app.models.experiment import ExperimentRun + from app.services import training as training_pkg from app.services.jobs_service import update_job_status from app.services.storage import ensure_bucket, get_minio_client + from app.services.training.base import TrainContext - # Mark job running if job_id: update_job_status(db, job_id, status="running", progress=0.05) - # Fetch ExperimentRun run: ExperimentRun | None = 
db.get(ExperimentRun, experiment_id) if experiment_id else None if run is None: raise ValueError(f"ExperimentRun {experiment_id!r} not found") @@ -297,21 +72,22 @@ def train_task(payload: dict) -> dict: db.commit() params: dict = json.loads(run.params_json or "{}") if run.params_json else {} - base_model = params.get("base_model", "yolov8n.pt") - task_type = params.get("task", "detect") # detect | classify | segment + task_type = params.get("task", "detect") + framework = ( + params.get("framework") + or getattr(run, "framework", None) + or training_pkg.registry.DEFAULT_FRAMEWORK + ) - # Fetch DatasetVersion and related assets/annotations + # Resolve assets / class names / annotations for the version. version_id = run.dataset_version_id assets: list[Asset] = [] class_names: list[str] = [] annotations_by_asset: dict[str, list[Annotation]] = {} - if version_id: assets = db.query(Asset).filter(Asset.version_id == version_id).all() version: DatasetVersion | None = db.get(DatasetVersion, version_id) dataset: Dataset | None = db.get(Dataset, version.dataset_id) if version else None - - # Build class list from ClassMap if dataset and dataset.class_map_id: cm: ClassMap | None = db.get(ClassMap, dataset.class_map_id) if cm: @@ -320,15 +96,11 @@ def train_task(payload: dict) -> dict: class_names = [c["name"] for c in raw] else: class_names = [str(c) for c in raw] - - # Fetch all annotations for these assets asset_ids = [a.id for a in assets] if asset_ids: ann_rows = db.query(Annotation).filter(Annotation.asset_id.in_(asset_ids)).all() for ann in ann_rows: annotations_by_asset.setdefault(ann.asset_id, []).append(ann) - - # If no ClassMap, derive class names from annotations if not class_names: seen: list[str] = [] for ann_list in annotations_by_asset.values(): @@ -337,10 +109,8 @@ def train_task(payload: dict) -> dict: seen.append(ann.class_name) class_names = seen or ["object"] - # Resolve the train/val/test split for every asset. 
Honor a split that - # was already persisted on the version (via the split endpoint); fall - # back to a deterministic, reproducible hash split for any asset that - # has none, using the ratios/seed carried in the run params. + # Resolve the train/val/test split for every asset (honor persisted + # split; fall back to deterministic hash split). from app.services.split_service import ( DEFAULT_SEED, asset_split, @@ -373,11 +143,11 @@ def train_task(payload: dict) -> dict: except Exception: minio_client = None + # Single metrics blob persisted throughout: per-epoch history, the + # resolved split, a final summary, and links to generated plots. epoch_metrics: list[dict] = [] - best_pt_path: Path | None = None - # Single metrics blob persisted throughout the run: per-epoch history, - # the resolved split, a final summary, and links to generated plots. metrics_blob: dict[str, Any] = { + "framework": framework, "epochs": epoch_metrics, "split": { "counts": split_counts, @@ -389,167 +159,120 @@ def train_task(payload: dict) -> dict: db.add(run) db.commit() - with tempfile.TemporaryDirectory() as tmp: - tmp_path = Path(tmp) - dataset_dir = tmp_path / "dataset" - output_dir = tmp_path / "output" - output_dir.mkdir(parents=True) + def report(progress: float | None = None, epoch: dict | None = None) -> None: + """Uniform progress/metrics callback handed to every trainer.""" + if epoch is not None: + epoch_metrics.append(epoch) + try: + run.metrics_json = json.dumps(metrics_blob) + db.add(run) + db.commit() + except Exception: + pass + if progress is not None and job_id: + try: + update_job_status(db, job_id, status="running", progress=progress) + except Exception: + pass - # Export dataset to YOLO format - data_yaml = _build_yolo_dataset( + minio_key: str | None = None + with tempfile.TemporaryDirectory() as tmp: + work_dir = Path(tmp) + ctx = TrainContext( + db=db, + run=run, + params=params, + task=task_type, + class_names=class_names, assets=assets, 
annotations_by_asset=annotations_by_asset, - class_names=class_names, - output_dir=dataset_dir, + splits=splits, minio_client=minio_client, bucket=bucket, - task_type=task_type, - splits=splits, + work_dir=work_dir, + report=report, ) - if job_id: - update_job_status(db, job_id, status="running", progress=0.2) - - # Run YOLO training try: - from ultralytics import YOLO # type: ignore - - model = YOLO(base_model) - total_epochs = params.get("epochs", 50) - - def on_train_epoch_end(trainer: Any) -> None: # noqa: ANN001 - epoch_num = getattr(trainer, "epoch", 0) - metrics_dict = {} - if hasattr(trainer, "metrics"): - raw_metrics = trainer.metrics - if hasattr(raw_metrics, "results_dict"): - metrics_dict = dict(raw_metrics.results_dict) - elif isinstance(raw_metrics, dict): - metrics_dict = raw_metrics - if hasattr(trainer, "loss"): - loss_val = trainer.loss - metrics_dict["loss"] = float(loss_val) if loss_val is not None else None - entry = {"epoch": epoch_num, **_normalize_metrics(metrics_dict)} - epoch_metrics.append(entry) - - # Persist metrics to DB (epoch history lives inside the blob) - try: - run.metrics_json = json.dumps(metrics_blob) - db.add(run) - db.commit() - except Exception: - pass - - # Update job progress - progress = 0.2 + 0.75 * (epoch_num / max(total_epochs, 1)) - try: - if job_id: - update_job_status(db, job_id, status="running", progress=progress) - except Exception: - pass - - model.add_callback("on_train_epoch_end", on_train_epoch_end) - - train_kwargs: dict = { - "data": str(data_yaml), - "device": params.get("device", "cpu"), - "project": str(output_dir), - "name": "train", - "plots": True, # emit PR/confusion/results plots for the UI - } - # Pull every tunable hyperparameter / augmentation knob from the - # allow-list, taking the user's value when present and the - # Ultralytics default otherwise. `epochs` stays bound to - # total_epochs so progress reporting matches. 
- for key, default in ULTRALYTICS_TRAIN_ARGS.items(): - val = params.get(key, default) - if val is not None: - train_kwargs[key] = val - train_kwargs["epochs"] = total_epochs - model.train(**train_kwargs) - - # Locate best.pt - candidate = output_dir / "train" / "weights" / "best.pt" - if candidate.exists(): - best_pt_path = candidate - else: - # Fallback: search for any .pt file - pt_files = list(output_dir.rglob("*.pt")) - best_pt_path = pt_files[0] if pt_files else None - - except ImportError: - # Ultralytics not installed — record that training was skipped - epoch_metrics = [{"epoch": 0, "note": "ultralytics not available"}] - run.metrics_json = json.dumps({"epochs": epoch_metrics}) + trainer = training_pkg.get_trainer(framework) + train_result = trainer.run(ctx) + except ImportError as imp_err: + # ML wheel for this framework missing in the worker image — + # record a clean skip rather than crashing (mirrors prior behavior). + epoch_metrics.append({"epoch": 0, "note": f"{framework} not available: {imp_err}"}) + run.metrics_json = json.dumps(metrics_blob) db.add(run) db.commit() + train_result = None + + best_pt_path = train_result.best_model_path if train_result else None - # Upload best.pt to MinIO and create ModelArtifact - minio_key: str | None = None + # Persist the trained model artifact. 
if best_pt_path and minio_client: try: minio_key = f"models/{experiment_id}/best.pt" minio_client.fput_object(bucket, minio_key, str(best_pt_path)) - size_bytes = best_pt_path.stat().st_size - import hashlib - checksum = hashlib.md5(best_pt_path.read_bytes()).hexdigest() - artifact = ModelArtifact( id=str(uuid.uuid4()), project_id=run.project_id, run_id=run.id, version=1, type="pytorch", + format="pytorch", + framework=framework, + storage_path=minio_key, checksum=checksum, size_bytes=size_bytes, ) - # Store storage_path in checksum field if no dedicated column exists - # (ModelArtifact has no storage_path column per the model definition) db.add(artifact) db.commit() db.refresh(artifact) - - # Record artifact reference in ExperimentRun.artifacts JSON artifacts_data = json.loads(run.artifacts or "[]") - artifacts_data.append({"id": artifact.id, "type": "pytorch", "key": minio_key}) + entry = { + "id": artifact.id, + "type": "pytorch", + "key": minio_key, + "framework": framework, + } + if train_result and train_result.artifact_meta: + entry["config"] = train_result.artifact_meta + artifacts_data.append(entry) run.artifacts = json.dumps(artifacts_data) db.add(run) db.commit() except Exception as upload_err: - # Log but don't fail the task epoch_metrics.append({"warning": f"model upload failed: {upload_err}"}) - # Upload Ultralytics-generated plots (PR curve, confusion matrix, - # results grid, label distribution) so the UI can render them, and - # record a final metric summary. Must run inside the temp-dir block - # so the plot files still exist on disk. + # Persist generated plots. 
plot_records: list[dict] = [] - results_dir = output_dir / "train" - if minio_client and results_dir.exists(): + if minio_client and train_result and train_result.plot_files: from app.services import storage as _storage - for fname in _PLOT_FILES: - fpath = results_dir / fname - if not fpath.exists(): + for fpath in train_result.plot_files: + if not Path(fpath).exists(): continue try: - ext = fpath.suffix.lstrip(".").lower() + ext = Path(fpath).suffix.lstrip(".").lower() ctype = "image/jpeg" if ext in ("jpg", "jpeg") else "image/png" - key = f"models/{experiment_id}/plots/{fname}" + key = f"models/{experiment_id}/plots/{Path(fpath).name}" _storage.put_bytes( minio_client, key, - fpath.read_bytes(), + Path(fpath).read_bytes(), content_type=ctype, bucket=bucket, ) - plot_records.append({"name": fpath.stem, "file": fname, "key": key}) + plot_records.append( + {"name": Path(fpath).stem, "file": Path(fpath).name, "key": key} + ) except Exception: continue metrics_blob["plots"] = plot_records - if epoch_metrics: + if train_result and train_result.final_metrics: + metrics_blob["summary"] = train_result.final_metrics + elif epoch_metrics: metrics_blob["summary"] = { k: v for k, v in epoch_metrics[-1].items() if k != "epoch" } @@ -563,7 +286,6 @@ def on_train_epoch_end(trainer: Any) -> None: # noqa: ANN001 if job_id: update_job_status(db, job_id, status="running", progress=0.97) - # Mark experiment run as succeeded run.status = "succeeded" run.completed_at = datetime.now(timezone.utc) if not run.metrics_json: @@ -577,6 +299,7 @@ def on_train_epoch_end(trainer: Any) -> None: # noqa: ANN001 result = { "status": "succeeded", "experiment_id": experiment_id, + "framework": framework, "epochs_completed": len(epoch_metrics), "model_key": minio_key, } diff --git a/backend/src/app/main.py b/backend/src/app/main.py index 00df54a..bfa4e2a 100644 --- a/backend/src/app/main.py +++ b/backend/src/app/main.py @@ -25,6 +25,7 @@ from app.api.middleware import auth_rate_limit_middleware, 
logging_middleware from app.api.ops import router as ops_router from app.api.projects import router as projects_router +from app.api.training import router as training_router from app.db.deps import get_db from app.models.job import Job as JobModel from app.observability.logging import configure_logging @@ -129,6 +130,7 @@ def _run_db_migrations(): app.include_router(clusters_router) app.include_router(agents_router) app.include_router(evaluations_router) +app.include_router(training_router) # New feature routers (conditionally registered based on availability) if _has_annotations: diff --git a/backend/src/app/models/artifact.py b/backend/src/app/models/artifact.py index baeac5a..df198a0 100644 --- a/backend/src/app/models/artifact.py +++ b/backend/src/app/models/artifact.py @@ -24,6 +24,9 @@ class ModelArtifact(Base): version: Mapped[str] = mapped_column(String(32), nullable=False, default="1") type: Mapped[str] = mapped_column(String(50), nullable=False) # pytorch, onnx format: Mapped[str] = mapped_column(String(32), nullable=False, default="pytorch") + # Training library that produced this artifact (ultralytics | timm | …). + # Lets inference/export route to the right loader without inspecting weights. 
+ framework: Mapped[str | None] = mapped_column(String(32), nullable=True) checksum: Mapped[str | None] = mapped_column(String(128), nullable=True) size_bytes: Mapped[int | None] = mapped_column(Integer, nullable=True) storage_path: Mapped[str | None] = mapped_column(String(1024), nullable=True) # MinIO key diff --git a/backend/src/app/models/experiment.py b/backend/src/app/models/experiment.py index 1ed368e..2a3985e 100644 --- a/backend/src/app/models/experiment.py +++ b/backend/src/app/models/experiment.py @@ -26,6 +26,9 @@ class ExperimentRun(Base): ) name: Mapped[str] = mapped_column(String(255), nullable=False, default="Unnamed Run") status: Mapped[str] = mapped_column(String(20), nullable=False, default="queued") + # Training library for this run (ultralytics | timm | …). Nullable for rows + # created before pluggable frameworks; treated as "ultralytics" downstream. + framework: Mapped[str | None] = mapped_column(String(32), nullable=True) params_json: Mapped[str | None] = mapped_column( "params", Text, nullable=True ) # DB column "params" mapped to params_json diff --git a/backend/src/app/schemas/common.py b/backend/src/app/schemas/common.py index 8ca65ab..8d80ef7 100644 --- a/backend/src/app/schemas/common.py +++ b/backend/src/app/schemas/common.py @@ -18,7 +18,8 @@ class UploadUrlResponse(BaseModel): class TrainRequest(BaseModel): projectId: str datasetVersionId: str - task: str # "detect" or "classify" + task: str # "detect" | "classify" | "segment" | "pose" + framework: str = "ultralytics" # which training backend to use baseModel: str = "yolov8n.pt" params: dict[str, Any] = {} name: str = "Training Run" diff --git a/backend/src/app/services/inference_service.py b/backend/src/app/services/inference_service.py index 08f0041..1375826 100644 --- a/backend/src/app/services/inference_service.py +++ b/backend/src/app/services/inference_service.py @@ -107,11 +107,18 @@ def predict( f.write(image_bytes) img_path = f.name try: - if kind == "yolo": - return 
_yolo_predict(model, img_path, score_threshold) if kind == "onnx": return _onnx_predict(model, img_path, score_threshold) - raise InferenceError(f"unsupported model kind: {kind}") + if kind == "ultralytics": + return _yolo_predict(model, img_path, score_threshold) + # Any other framework (e.g. timm) is a classifier: delegate to its trainer. + from app.services import training as training_pkg + + try: + cls_name, score = training_pkg.get_trainer(kind).predict_classification(model, img_path) + return {"detections": [], "classification": {"class": cls_name, "score": score}} + except NotImplementedError as exc: + raise InferenceError(f"unsupported model kind: {kind}") from exc finally: try: os.unlink(img_path) @@ -149,12 +156,16 @@ def _load_artifact(artifact: ModelArtifact) -> tuple[str, Any, Path]: scratch, ) + # PyTorch weights: resolve the producing framework via the trainer registry + # (no hard-coded library). ``kind`` becomes the framework key. + from app.services import training as training_pkg + try: - from ultralytics import YOLO # type: ignore - except Exception as exc: # pragma: no cover + trainer = training_pkg.get_trainer(getattr(artifact, "framework", None)) + return (trainer.key, trainer.load_predictor(local), scratch) + except Exception as exc: # pragma: no cover - depends on optional ML wheels _safe_rmtree(scratch) - raise InferenceError("ultralytics not installed") from exc - return ("yolo", YOLO(str(local)), scratch) + raise InferenceError(f"could not load model: {exc}") from exc def _download_to(uri: str, dest: Path) -> bool: diff --git a/backend/src/app/services/training/__init__.py b/backend/src/app/services/training/__init__.py new file mode 100644 index 0000000..eabf01c --- /dev/null +++ b/backend/src/app/services/training/__init__.py @@ -0,0 +1,49 @@ +"""Pluggable training-backend abstraction. 
+
+VisionForge supports more than one training library (Ultralytics/YOLO for
+detection-family tasks, Timm for image classification, and — by design — others
+in the future). Rather than hard-coding library names throughout the Celery
+tasks, evaluation, and inference code, each library is wrapped in a ``Trainer``
+implementation (see ``base.py``) and registered in ``registry.py``. Orchestration
+code resolves the right trainer by its ``key`` and delegates.
+
+Importing this package registers the built-in trainers as a side effect, so
+``from app.services import training`` is enough to populate the registry.
+"""
+
+from __future__ import annotations
+
+import logging
+
+from app.services.training.base import (
+    Capabilities,
+    FieldDef,
+    FieldGroup,
+    TrainContext,
+    Trainer,
+    TrainResult,
+)
+from app.services.training.registry import get_trainer, list_frameworks, register
+
+_logger = logging.getLogger(__name__)
+
+# Register built-in trainers. Each import self-registers via ``register(...)``.
+# Wrapped defensively so a missing optional dependency at import time never
+# prevents the rest of the app from starting. The failure is logged with its
+# traceback instead of being swallowed: otherwise a broken trainer module only
+# surfaces later as an unexplained "no trainer registered" error.
+try:  # pragma: no cover - exercised indirectly
+    from app.services.training import ultralytics_trainer  # noqa: F401
+except Exception:  # pragma: no cover
+    _logger.warning("ultralytics trainer failed to import; not registered", exc_info=True)
+
+try:  # pragma: no cover - exercised indirectly
+    from app.services.training import timm_trainer  # noqa: F401
+except Exception:  # pragma: no cover
+    _logger.warning("timm trainer failed to import; not registered", exc_info=True)
+
+__all__ = [
+    "Capabilities",
+    "FieldDef",
+    "FieldGroup",
+    "Trainer",
+    "TrainContext",
+    "TrainResult",
+    "get_trainer",
+    "list_frameworks",
+    "register",
+]
diff --git a/backend/src/app/services/training/base.py b/backend/src/app/services/training/base.py
new file mode 100644
index 0000000..92688af
--- /dev/null
+++ b/backend/src/app/services/training/base.py
@@ -0,0 +1,170 @@
+"""Core abstractions for pluggable training backends.
+
+A :class:`Trainer` wraps one ML library (Ultralytics, Timm, …).
It declares the
+tasks it supports and a *capability descriptor* (the model catalogues and the
+hyperparameter / augmentation form, expressed as serializable field
+definitions). The same descriptor drives the frontend form, so adding a library
+means implementing a ``Trainer`` and registering it — no frontend or
+orchestration edits.
+
+The Celery training task builds a :class:`TrainContext` (the minimal shared
+surface every trainer needs: the dataset, the resolved split, MinIO access, a
+scratch dir, and a uniform progress/metrics ``report`` callback) and calls
+:meth:`Trainer.run`. Downstream stages (ONNX export, evaluation, inference)
+resolve the trainer by ``key`` and call the relevant predict/export hook, so no
+stage hard-codes a library name.
+"""
+
+from __future__ import annotations
+
+import dataclasses
+from abc import ABC, abstractmethod
+from collections.abc import Callable
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Any, Protocol
+
+
+@dataclass
+class FieldDef:
+    """One tunable knob, mirroring the frontend ``FieldDef`` shape.
+
+    ``type`` is one of ``number | bool | select | text | model``. ``model`` is a
+    searchable architecture picker backed by ``Trainer.list_models``.
+    """
+
+    key: str
+    label: str
+    type: str
+    default: Any
+    min: float | None = None
+    max: float | None = None
+    step: float | None = None
+    options: list[str] | None = None
+    help: str | None = None
+
+
+@dataclass
+class FieldGroup:
+    """A titled, collapsible group of fields (e.g. "Optimizer & Schedule")."""
+
+    title: str
+    fields: list[FieldDef]
+
+
+@dataclass
+class Capabilities:
+    """Everything the frontend needs to render a framework's training form."""
+
+    key: str
+    label: str
+    supported_tasks: list[str]
+    models_by_task: dict[str, list[str]]
+    groups: list[FieldGroup]
+    device_options: list[str]
+
+    def to_dict(self) -> dict[str, Any]:
+        return dataclasses.asdict(self)
+
+
+# Progress / metrics reporting callback. ``progress`` is a 0..1 fraction (or
+# None to leave it unchanged); ``epoch`` is an optional per-epoch metrics entry
+# that gets appended to the run's epoch history and persisted.
+ReportFn = Callable[..., None]
+
+
+@dataclass
+class TrainContext:
+    """The shared surface a trainer needs to execute one run.
+
+    Everything framework-agnostic (DB session, the resolved dataset + split,
+    object storage, a scratch directory, and the progress callback) lives here
+    so each :class:`Trainer` only has to express the library-specific bits.
+    """
+
+    db: Any
+    run: Any
+    params: dict[str, Any]
+    task: str
+    class_names: list[str]
+    assets: list[Any]
+    annotations_by_asset: dict[str, list[Any]]
+    splits: dict[str, str]
+    minio_client: Any
+    bucket: str
+    work_dir: Path
+    report: ReportFn
+
+
+@dataclass
+class TrainResult:
+    """What a trainer hands back to the task for generic persistence."""
+
+    best_model_path: Path | None = None
+    final_metrics: dict[str, Any] = field(default_factory=dict)
+    plot_files: list[Path] = field(default_factory=list)
+    # Framework-specific reconstruction info (e.g. timm arch/config) stored
+    # alongside the artifact so export/inference can rebuild the model.
+    artifact_meta: dict[str, Any] = field(default_factory=dict)
+
+
+class Predictor(Protocol):
+    """Marker protocol; concrete predictors are framework-defined opaque objects."""
+
+
+class Trainer(ABC):
+    """Interface every training backend implements."""
+
+    #: Stable identifier persisted on runs/artifacts (e.g. "ultralytics", "timm").
+    key: str = ""
+    #: Human-readable label shown in the UI.
+    label: str = ""
+    #: Tasks this backend can train (subset of detect|classify|segment|pose).
+    supported_tasks: set[str] = set()
+
+    # -- capability surface ---------------------------------------------------
+    @abstractmethod
+    def capabilities(self) -> Capabilities:
+        """Return the model catalogues + form schema for this backend."""
+
+    def list_models(self, task: str, query: str = "") -> list[str]:
+        """Return model names for ``task`` filtered by ``query``.
+
+        Defaults to the static catalogue from :meth:`capabilities`. Backends with
+        huge catalogues (e.g. Timm) override this with a live, filtered search.
+        """
+        models = self.capabilities().models_by_task.get(task, [])
+        q = (query or "").strip().lower()
+        return [m for m in models if q in m.lower()]  # fresh list: never alias the catalogue
+
+    # -- training -------------------------------------------------------------
+    @abstractmethod
+    def run(self, ctx: TrainContext) -> TrainResult:
+        """Build the dataset, train, and return the best model + metrics."""
+
+    # -- export ---------------------------------------------------------------
+    def export_onnx(
+        self,
+        local_pt: Path,
+        out_dir: Path,
+        *,
+        opset: int | None = None,
+        dynamic: bool = True,
+    ) -> Path | None:
+        """Export a trained ``.pt`` to ONNX. Returns the path, or None on skip."""
+        raise NotImplementedError(f"{self.key} does not support ONNX export")
+
+    # -- inference / evaluation ----------------------------------------------
+    def load_predictor(self, local_pt: Path) -> Any:
+        """Load a trained model into a reusable predictor object."""
+        raise NotImplementedError(f"{self.key} does not support inference")
+
+    def predict_classification(self, predictor: Any, image_path: str) -> tuple[str, float]:
+        """Return ``(class_name, score)`` for a single image."""
+        raise NotImplementedError(f"{self.key} does not support classification inference")
+
+    def predict_detections(
+        self, predictor: Any, image_path: str, score_threshold: float
+    ) -> list[dict[str, Any]]:
+        """Return a list of ``{class, bbox, score}`` detections for one image."""
+        raise NotImplementedError(f"{self.key} does not support detection inference")
diff --git a/backend/src/app/services/training/datasets.py b/backend/src/app/services/training/datasets.py
new file mode 100644
index 0000000..ddead38
--- /dev/null
+++ b/backend/src/app/services/training/datasets.py
@@ -0,0 +1,153 @@
+"""Shared dataset exporters used by the training backends.
+
+Extracted from the original inline ``_build_yolo_dataset`` so multiple trainers
+can reuse them. ``build_imagefolder`` produces the standard ImageFolder layout
+used by both Ultralytics-classify and Timm; ``build_yolo_detect`` produces the
+YOLO detection layout (images + label txts + ``data.yaml``).
+
+In both, assets whose resolved split is ``"test"`` are held out entirely so the
+test set stays a true unseen set for evaluation.
+"""
+
+from __future__ import annotations
+
+import json
+from pathlib import Path
+from typing import Any
+
+
+def extract_minio_key(uri: str) -> str:
+    """Extract the object key from an asset URI.
+
+    * ``s3://bucket/key`` and ``minio://bucket/key`` -> ``key`` (bucket stripped).
+    * ``http(s)://host/path`` -> ``path`` (scheme and host stripped; a path-style
+      bucket segment, if any, is left in place).
+    * Anything else, or a URI with no ``/`` after the bucket/host, is returned
+      unchanged.
+    """
+    for prefix in ("s3://", "minio://"):
+        if uri.startswith(prefix):
+            parts = uri[len(prefix) :].split("/", 1)
+            return parts[1] if len(parts) > 1 else uri
+    if uri.startswith(("http://", "https://")):
+        path = uri.split("/", 3)
+        return path[3] if len(path) > 3 else uri
+    return uri
+
+
+def fetch_image_bytes(minio_client: Any, bucket: str, asset: Any) -> bytes | None:
+    """Download an asset's bytes from MinIO, or None if unavailable.
+
+    Best-effort: a missing client or any error while fetching yields ``None``
+    instead of raising.
+    """
+    if minio_client is None:
+        return None
+    try:
+        response = minio_client.get_object(bucket, extract_minio_key(asset.uri))
+        try:
+            return response.read()
+        finally:
+            # Hand the connection back even when ``read()`` fails; previously
+            # an error mid-read skipped both calls and leaked the response.
+            response.close()
+            response.release_conn()
+    except Exception:
+        return None
+
+
+def build_imagefolder(
+    assets: list[Any],
+    annotations_by_asset: dict[str, list[Any]],
+    class_names: list[str],
+    output_dir: Path,
+    minio_client: Any,
+    bucket: str,
+    splits: dict[str, str] | None = None,
+    *,
+    skip_missing: bool = False,
+) -> Path:
+    """Export a classification dataset as ImageFolder (``<split>/<class>/<img>``).
+
+    Each asset is filed under the class of its first annotation; an asset with
+    no usable annotation falls back to ``class_names[0]`` (or ``"unknown"`` when
+    ``class_names`` is empty). Assets whose split is ``"test"`` are skipped and
+    any split other than ``train``/``val`` is treated as ``train``.
+
+    By default an image whose bytes cannot be fetched is still written as an
+    empty placeholder file (the historical behaviour). Pass
+    ``skip_missing=True`` to leave such assets out so no zero-byte file ends up
+    in the dataset.
+
+    Returns the dataset *root* directory (containing ``train`` and ``val``).
+    """
+    splits = splits or {}
+    for split in ("train", "val"):
+        for cls in class_names:
+            (output_dir / split / cls).mkdir(parents=True, exist_ok=True)
+
+    for asset in assets:
+        split = splits.get(asset.id, "train")
+        if split == "test":
+            continue
+        if split not in ("train", "val"):
+            split = "train"
+        img_data = fetch_image_bytes(minio_client, bucket, asset)
+        if img_data is None and skip_missing:
+            continue
+        ext = Path(asset.uri).suffix or ".jpg"
+        img_name = f"{asset.id}{ext}"
+        ann_list = annotations_by_asset.get(asset.id, [])
+        cls_name = (
+            ann_list[0].class_name
+            if ann_list and ann_list[0].class_name
+            else (class_names[0] if class_names else "unknown")
+        )
+        dest_dir = output_dir / split / cls_name
+        dest_dir.mkdir(parents=True, exist_ok=True)
+        (dest_dir / img_name).write_bytes(img_data or b"")
+
+    return output_dir
+
+
+def build_yolo_detect(
+    assets: list[Any],
+    annotations_by_asset: dict[str, list[Any]],
+    class_names: list[str],
+    output_dir: Path,
+    minio_client: Any,
+    bucket: str,
+    splits: dict[str, str] | None = None,
+    *,
+    skip_missing: bool = False,
+) -> Path:
+    """Export a detection dataset in YOLO format. Returns the ``data.yaml`` path.
+
+    Writes ``<split>/images/<id><ext>`` and ``<split>/labels/<id>.txt`` for the
+    ``train`` and ``val`` splits (``"test"`` assets are skipped, unknown splits
+    count as ``train``). Annotation geometry may be a dict or a JSON string,
+    given either as ``x1/y1/x2/y2`` corners or as ``x/y/w/h``; it is normalised
+    by the asset's ``width``/``height`` (640 when unset). Boxes are clipped to
+    the image so every emitted value lies in ``[0, 1]``. A class name missing
+    from ``class_names`` is written as class index 0.
+
+    With ``skip_missing=True`` an asset whose image bytes cannot be fetched is
+    left out (no image, no label file) instead of being written as an empty
+    placeholder file.
+    """
+    splits = splits or {}
+    images_train = output_dir / "train" / "images"
+    images_val = output_dir / "val" / "images"
+    labels_train = output_dir / "train" / "labels"
+    labels_val = output_dir / "val" / "labels"
+    for d in (images_train, images_val, labels_train, labels_val):
+        d.mkdir(parents=True, exist_ok=True)
+
+    class_idx: dict[str, int] = {name: i for i, name in enumerate(class_names)}
+
+    for asset in assets:
+        split = splits.get(asset.id, "train")
+        if split == "test":
+            continue
+        if split not in ("train", "val"):
+            split = "train"
+        ext = Path(asset.uri).suffix or ".jpg"
+        img_name = f"{asset.id}{ext}"
+        img_data = fetch_image_bytes(minio_client, bucket, asset)
+        if img_data is None and skip_missing:
+            continue
+
+        dest_img = (images_train if split == "train" else images_val) / img_name
+        dest_img.write_bytes(img_data or b"")
+
+        label_dir = labels_train if split == "train" else labels_val
+        label_file = label_dir / f"{asset.id}.txt"
+        # Image size is per asset, so resolve it once rather than per annotation.
+        w_img = asset.width or 640
+        h_img = asset.height or 640
+        lines: list[str] = []
+        for ann in annotations_by_asset.get(asset.id, []):
+            try:
+                geo = json.loads(ann.geometry) if isinstance(ann.geometry, str) else ann.geometry
+            except Exception:
+                continue
+            if not isinstance(geo, dict):
+                continue
+            idx = class_idx.get(ann.class_name or "", 0)
+            if "x1" in geo:
+                x1, y1, x2, y2 = geo["x1"], geo["y1"], geo["x2"], geo["y2"]
+            else:
+                x1, y1 = geo.get("x", 0), geo.get("y", 0)
+                x2, y2 = x1 + geo.get("w", 0), y1 + geo.get("h", 0)
+            # YOLO labels must be normalised to [0, 1]: clip each edge to the
+            # image and order the corners so a box that overhangs the frame (or
+            # has reversed corners) cannot yield out-of-range or negative values.
+            x1, x2 = sorted(min(max(v, 0), w_img) for v in (x1, x2))
+            y1, y2 = sorted(min(max(v, 0), h_img) for v in (y1, y2))
+            bx = (x1 + x2) / 2 / w_img
+            by = (y1 + y2) / 2 / h_img
+            bw = (x2 - x1) / w_img
+            bh = (y2 - y1) / h_img
+            lines.append(f"{idx} {bx:.6f} {by:.6f} {bw:.6f} {bh:.6f}")
+        label_file.write_text("\n".join(lines))
+
+    data_yaml = output_dir / "data.yaml"
+    data_yaml.write_text(
+        f"path: {output_dir}\n"
+        f"train: train/images\n"
+        f"val: val/images\n"
+        f"nc: {len(class_names)}\n"
+        f"names: {json.dumps(class_names)}\n"
+    )
+    return data_yaml
diff
--git a/backend/src/app/services/training/registry.py b/backend/src/app/services/training/registry.py
new file mode 100644
index 0000000..3e30ed0
--- /dev/null
+++ b/backend/src/app/services/training/registry.py
@@ -0,0 +1,56 @@
+"""Registry mapping framework keys to :class:`Trainer` instances.
+
+Built-in trainers register themselves at import time (see the package
+``__init__``). New libraries are added by implementing :class:`Trainer` and
+calling :func:`register` — no edits to the Celery tasks, services, or frontend.
+"""
+
+from __future__ import annotations
+
+from app.services.training.base import Trainer
+
+# The default framework for runs/artifacts created before frameworks existed,
+# and whenever a caller omits the framework.
+DEFAULT_FRAMEWORK = "ultralytics"
+
+_REGISTRY: dict[str, Trainer] = {}
+
+
+class UnknownFrameworkError(Exception):
+    """Raised when a framework key has no registered trainer."""
+
+
+def register(trainer: Trainer) -> Trainer:
+    """Register a trainer instance under its ``key``. Idempotent (last wins)."""
+    if not trainer.key:
+        raise ValueError("Trainer.key must be a non-empty string")
+    _REGISTRY[trainer.key] = trainer
+    return trainer
+
+
+def get_trainer(key: str | None) -> Trainer:
+    """Resolve a trainer by key, defaulting to Ultralytics for back-compat.
+
+    A ``None``/empty key maps to :data:`DEFAULT_FRAMEWORK` so existing runs and
+    artifacts (which predate the ``framework`` column) keep working unchanged.
+    """
+    resolved = (key or DEFAULT_FRAMEWORK).strip() or DEFAULT_FRAMEWORK
+    trainer = _REGISTRY.get(resolved)
+    if trainer is None:
+        raise UnknownFrameworkError(f"no trainer registered for framework {resolved!r}")
+    return trainer
+
+
+def list_frameworks() -> list[Trainer]:
+    """Return all registered trainers (stable order by key)."""
+    return [_REGISTRY[k] for k in sorted(_REGISTRY)]
+
+
+def is_registered(key: str | None) -> bool:
+    """Return whether ``key`` names a registered trainer.
+
+    Surrounding whitespace is ignored, matching :func:`get_trainer`, so the two
+    agree on every non-empty key. Unlike :func:`get_trainer`, a ``None`` or
+    blank key is not defaulted and yields ``False``.
+    """
+    return bool(key) and key.strip() in _REGISTRY
diff --git a/backend/src/app/services/training/timm_trainer.py b/backend/src/app/services/training/timm_trainer.py
new file mode 100644
index 0000000..dc699c6
--- /dev/null
+++ b/backend/src/app/services/training/timm_trainer.py
@@ -0,0 +1,710 @@
+"""Timm (PyTorch Image Models) training backend for image classification.
+
+Gives users the full Timm surface: any of the ~1300 backbones, head/pooling
+choices (``global_pool``, drop rates), the complete optimizer/scheduler matrix
+(``create_optimizer_v2`` / ``create_scheduler_v2``), mixup/cutmix, model EMA,
+label smoothing, AMP, channels-last, and the full timm augmentation pipeline
+(RandAugment/AutoAugment via the ``aa`` string, color jitter, random erasing,
+flips, random-resized-crop scale/ratio, interpolation, crop pct).
+
+All torch/timm imports are lazy so this module loads in environments without
+the ML wheels (the capability schema is pure data). When the wheels are absent
+at run time, :meth:`run` raises and the task records a clean failure — mirroring
+the existing Ultralytics ``ImportError`` handling.
+""" + +from __future__ import annotations + +import copy +from pathlib import Path +from typing import Any + +from app.services.training import datasets +from app.services.training.base import ( + Capabilities, + FieldDef, + FieldGroup, + TrainContext, + Trainer, + TrainResult, +) +from app.services.training.registry import register + +# A small, popular default set surfaced before the user searches the full +# catalogue (the frontend uses the searchable picker against list_models). +_POPULAR = [ + "resnet18", + "resnet50", + "resnext50_32x4d", + "efficientnet_b0", + "efficientnet_b3", + "convnext_tiny", + "convnext_small", + "mobilenetv3_large_100", + "mobilenetv3_small_050", + "vit_tiny_patch16_224", + "vit_small_patch16_224", + "vit_base_patch16_224", + "swin_tiny_patch4_window7_224", + "deit3_small_patch16_224", + "resnet10t", +] + + +class TimmTrainer(Trainer): + key = "timm" + label = "Timm (PyTorch Image Models)" + supported_tasks = {"classify"} + + # ------------------------------------------------------------------ schema + def capabilities(self) -> Capabilities: + groups = [ + FieldGroup( + "Architecture", + [ + FieldDef( + "model", + "Backbone", + "model", + "resnet50", + help="Any timm architecture — type to search the full catalogue", + ), + FieldDef( + "pretrained", + "Pretrained weights", + "bool", + True, + help="Initialise from ImageNet-pretrained weights", + ), + FieldDef( + "global_pool", + "Pooling head", + "select", + "avg", + options=["avg", "max", "avgmax", "catavgmax", ""], + help="Classifier global pooling", + ), + FieldDef("drop_rate", "Dropout", "number", 0.0, 0, 1, 0.01), + FieldDef( + "drop_path_rate", + "Drop path (stochastic depth)", + "number", + 0.0, + 0, + 1, + 0.01, + ), + ], + ), + FieldGroup( + "Core", + [ + FieldDef("epochs", "Epochs", "number", 20, 1, 2000), + FieldDef("batch_size", "Batch Size", "number", 32, 1, 1024), + FieldDef("img_size", "Image Size", "number", 224, 32, 1024, 16), + FieldDef("num_workers", "Data 
Workers", "number", 4, 0, 32), + FieldDef("seed", "Seed", "number", 42, 0), + FieldDef( + "patience", + "Early-stop Patience", + "number", + 0, + 0, + 1000, + help="Stop after N epochs w/o top-1 improvement (0 = off)", + ), + FieldDef("amp", "AMP", "bool", True, help="Automatic mixed precision"), + FieldDef("channels_last", "Channels-last", "bool", False), + ], + ), + FieldGroup( + "Optimizer", + [ + FieldDef( + "opt", + "Optimizer", + "select", + "adamw", + options=[ + "sgd", + "momentum", + "adam", + "adamw", + "nadam", + "radam", + "rmsprop", + "lamb", + "adabelief", + ], + ), + FieldDef("lr", "Learning Rate", "number", 0.001, 0.0, 10, 0.0001), + FieldDef("weight_decay", "Weight Decay", "number", 0.0001, 0, 1, 0.0001), + FieldDef("momentum", "Momentum (SGD)", "number", 0.9, 0, 1, 0.001), + FieldDef("opt_eps", "Epsilon", "number", 1e-8, 0, 1, 1e-9), + FieldDef( + "layer_decay", + "Layer-wise LR decay", + "number", + 0.0, + 0, + 1, + 0.01, + help="0 disables; e.g. 0.75 for ViT fine-tuning", + ), + ], + ), + FieldGroup( + "Schedule", + [ + FieldDef( + "sched", + "Scheduler", + "select", + "cosine", + options=["cosine", "step", "multistep", "plateau", "poly", "tanh", "none"], + ), + FieldDef("warmup_epochs", "Warmup Epochs", "number", 3, 0, 100), + FieldDef("warmup_lr", "Warmup LR", "number", 1e-5, 0, 1, 1e-6), + FieldDef("min_lr", "Min LR", "number", 1e-6, 0, 1, 1e-7), + FieldDef("decay_epochs", "Decay Epochs (step)", "number", 30, 1, 1000), + FieldDef("decay_rate", "Decay Rate", "number", 0.1, 0, 1, 0.01), + FieldDef("cooldown_epochs", "Cooldown Epochs", "number", 0, 0, 100), + ], + ), + FieldGroup( + "Regularization", + [ + FieldDef("smoothing", "Label Smoothing", "number", 0.1, 0, 1, 0.01), + FieldDef("mixup", "Mixup alpha", "number", 0.0, 0, 5, 0.1), + FieldDef("cutmix", "CutMix alpha", "number", 0.0, 0, 5, 0.1), + FieldDef("mixup_prob", "Mix prob", "number", 1.0, 0, 1, 0.05), + FieldDef("mixup_switch_prob", "Mix switch prob", "number", 0.5, 0, 1, 0.05), + 
FieldDef("model_ema", "Model EMA", "bool", False), + FieldDef("model_ema_decay", "EMA decay", "number", 0.9998, 0.9, 1, 0.0001), + ], + ), + FieldGroup( + "Augmentation", + [ + FieldDef( + "aa", + "Auto-augment policy", + "text", + "rand-m9-mstd0.5", + help="timm AA string, e.g. rand-m9-mstd0.5 / original / v0 (blank = off)", + ), + FieldDef("color_jitter", "Color Jitter", "number", 0.4, 0, 1, 0.01), + FieldDef("reprob", "Random Erase prob", "number", 0.25, 0, 1, 0.01), + FieldDef( + "remode", + "Random Erase mode", + "select", + "pixel", + options=["const", "rand", "pixel"], + ), + FieldDef("recount", "Random Erase count", "number", 1, 1, 10), + FieldDef("hflip", "Horizontal flip prob", "number", 0.5, 0, 1, 0.01), + FieldDef("vflip", "Vertical flip prob", "number", 0.0, 0, 1, 0.01), + FieldDef("scale_min", "RRC scale min", "number", 0.08, 0, 1, 0.01), + FieldDef("scale_max", "RRC scale max", "number", 1.0, 0, 1, 0.01), + FieldDef("ratio_min", "RRC ratio min", "number", 0.75, 0.1, 2, 0.01), + FieldDef("ratio_max", "RRC ratio max", "number", 1.3333, 0.5, 4, 0.01), + FieldDef( + "train_interpolation", + "Interpolation", + "select", + "bicubic", + options=["bilinear", "bicubic", "random"], + ), + FieldDef("crop_pct", "Eval crop pct", "number", 0.875, 0.1, 1, 0.01), + ], + ), + ] + return Capabilities( + key=self.key, + label=self.label, + supported_tasks=sorted(self.supported_tasks), + models_by_task={"classify": _POPULAR}, + groups=groups, + device_options=["cpu", "cuda", "cuda:0", "cuda:1"], + ) + + def list_models(self, task: str, query: str = "") -> list[str]: + try: + import timm # type: ignore + except Exception: + return super().list_models(task, query) + pattern = f"*{query.strip()}*" if query.strip() else "" + try: + names = ( + timm.list_models(pattern, pretrained=True) + if pattern + else timm.list_models(pretrained=True) + ) + except Exception: + return super().list_models(task, query) + return list(names)[:500] + + # 
----------------------------------------------------------------- helpers + @staticmethod + def _resolve_device(params: dict[str, Any]) -> str: + import torch # type: ignore + + requested = str(params.get("device", "")).strip() + if requested in ("", "auto"): + return "cuda" if torch.cuda.is_available() else "cpu" + if requested.startswith("cuda") and not torch.cuda.is_available(): + return "cpu" + return requested + + def _build_transforms(self, params: dict[str, Any]) -> tuple[Any, Any]: + from timm.data import create_transform # type: ignore + + img_size = int(params.get("img_size", 224)) + aa = (params.get("aa") or "").strip() or None + scale = (float(params.get("scale_min", 0.08)), float(params.get("scale_max", 1.0))) + ratio = (float(params.get("ratio_min", 0.75)), float(params.get("ratio_max", 1.3333))) + train_tf = create_transform( + input_size=img_size, + is_training=True, + color_jitter=float(params.get("color_jitter", 0.4)), + auto_augment=aa, + interpolation=str(params.get("train_interpolation", "bicubic")), + re_prob=float(params.get("reprob", 0.25)), + re_mode=str(params.get("remode", "pixel")), + re_count=int(params.get("recount", 1)), + scale=scale, + ratio=ratio, + hflip=float(params.get("hflip", 0.5)), + vflip=float(params.get("vflip", 0.0)), + ) + val_tf = create_transform( + input_size=img_size, + is_training=False, + interpolation=str(params.get("train_interpolation", "bicubic")), + crop_pct=float(params.get("crop_pct", 0.875)), + ) + return train_tf, val_tf + + # -------------------------------------------------------------------- train + def run(self, ctx: TrainContext) -> TrainResult: # noqa: C901 - linear training loop + import timm # type: ignore + import torch # type: ignore + from torch.utils.data import DataLoader # type: ignore + from torchvision.datasets import ImageFolder # type: ignore + + params = ctx.params + device = self._resolve_device(params) + torch.manual_seed(int(params.get("seed", 42))) + + root = 
datasets.build_imagefolder( + ctx.assets, + ctx.annotations_by_asset, + ctx.class_names, + ctx.work_dir / "dataset", + ctx.minio_client, + ctx.bucket, + ctx.splits, + ) + ctx.report(progress=0.2) + + train_tf, val_tf = self._build_transforms(params) + train_ds = ImageFolder(str(root / "train"), transform=train_tf) + val_dir = root / "val" + has_val = val_dir.exists() and any(val_dir.iterdir()) + val_ds = ImageFolder(str(val_dir), transform=val_tf) if has_val else None + + # ImageFolder derives class order from sorted dir names; persist that so + # predictions map back to the right label. + classes = train_ds.classes + num_classes = len(classes) + + batch_size = int(params.get("batch_size", 32)) + num_workers = int(params.get("num_workers", 4)) + train_loader = DataLoader( + train_ds, + batch_size=batch_size, + shuffle=True, + num_workers=num_workers, + pin_memory=device.startswith("cuda"), + drop_last=False, + ) + val_loader = ( + DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers) + if val_ds is not None + else None + ) + + model = timm.create_model( + str(params.get("model", "resnet50")), + pretrained=bool(params.get("pretrained", True)), + num_classes=num_classes, + global_pool=str(params.get("global_pool", "avg")), + drop_rate=float(params.get("drop_rate", 0.0)), + drop_path_rate=float(params.get("drop_path_rate", 0.0)), + ) + memory_format = ( + torch.channels_last if params.get("channels_last") else torch.contiguous_format + ) + model = model.to(device, memory_format=memory_format) + + optimizer = self._make_optimizer(model, params) + epochs = int(params.get("epochs", 20)) + scheduler, epochs = self._make_scheduler(optimizer, params, epochs) + mixup_fn = self._make_mixup(params, num_classes) + train_loss_fn, val_loss_fn = self._make_losses(params, mixup_fn) + + ema = None + if params.get("model_ema"): + try: + from timm.utils import ModelEmaV2 # type: ignore + + ema = ModelEmaV2(model, decay=float(params.get("model_ema_decay", 
0.9998))) + except Exception: + ema = None + + use_amp = bool(params.get("amp", True)) and device.startswith("cuda") + scaler = torch.cuda.amp.GradScaler(enabled=use_amp) + + history: list[dict[str, Any]] = [] + best_top1 = -1.0 + best_state: dict[str, Any] | None = None + epochs_no_improve = 0 + patience = int(params.get("patience", 0)) + + for epoch in range(epochs): + model.train() + running_loss, seen = 0.0, 0 + for inputs, targets in train_loader: + inputs = inputs.to(device, memory_format=memory_format, non_blocking=True) + targets = targets.to(device, non_blocking=True) + if mixup_fn is not None and inputs.size(0) % 2 == 0: + inputs, targets = mixup_fn(inputs, targets) + optimizer.zero_grad() + with torch.cuda.amp.autocast(enabled=use_amp): + outputs = model(inputs) + loss = train_loss_fn(outputs, targets) + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + if ema is not None: + ema.update(model) + running_loss += float(loss.item()) * inputs.size(0) + seen += inputs.size(0) + + train_loss = running_loss / max(seen, 1) + eval_model = ema.module if ema is not None else model + metrics = self._validate(eval_model, val_loader, val_loss_fn, device, memory_format) + + entry = { + "epoch": epoch, + "train_loss": round(train_loss, 5), + "lr": round(optimizer.param_groups[0]["lr"], 8), + **metrics, + } + history.append(entry) + ctx.report(progress=0.2 + 0.75 * ((epoch + 1) / max(epochs, 1)), epoch=entry) + + top1 = metrics.get("top1", 0.0) + if top1 > best_top1: + best_top1 = top1 + best_state = copy.deepcopy(eval_model.state_dict()) + epochs_no_improve = 0 + else: + epochs_no_improve += 1 + if patience and epochs_no_improve >= patience: + break + + if scheduler is not None: + try: + scheduler.step(epoch + 1, metric=top1) + except TypeError: + scheduler.step(epoch + 1) + + if best_state is not None: + model.load_state_dict(best_state) + + img_size = int(params.get("img_size", 224)) + checkpoint = { + "vf_framework": "timm", + "arch": 
def _make_optimizer(self, model: Any, params: dict[str, Any]) -> Any:
    """Create the optimizer for ``model`` through timm's ``create_optimizer_v2``.

    Args:
        model: The network whose parameters are optimized.
        params: Run hyperparameters. Reads ``opt`` (default ``"adamw"``),
            ``lr`` (0.001), ``weight_decay`` (0.0001) and ``momentum`` (0.9).
            ``opt_eps`` and ``layer_decay`` are forwarded only when set (and,
            for ``layer_decay``, greater than zero). A value of ``None`` is
            treated the same as a missing key.

    Returns:
        The optimizer object returned by timm.

    Raises:
        TypeError: If timm rejects even the base arguments.
    """
    import logging

    from timm.optim import create_optimizer_v2  # type: ignore

    def _value(key: str, default: Any) -> Any:
        # JSON ``null`` arrives as None; fall back to the default instead of
        # crashing in float(None).
        found = params.get(key)
        return default if found is None else found

    base_kwargs: dict[str, Any] = {
        "opt": str(_value("opt", "adamw")),
        "lr": float(_value("lr", 0.001)),
        "weight_decay": float(_value("weight_decay", 0.0001)),
        "momentum": float(_value("momentum", 0.9)),
    }

    optional_kwargs: dict[str, Any] = {}
    eps = params.get("opt_eps")
    if eps is not None:
        optional_kwargs["eps"] = float(eps)
    layer_decay = float(params.get("layer_decay", 0.0) or 0.0)
    if layer_decay > 0:
        optional_kwargs["layer_decay"] = layer_decay

    if not optional_kwargs:
        # Nothing to drop on failure, so a retry would only repeat the error.
        return create_optimizer_v2(model, **base_kwargs)

    try:
        return create_optimizer_v2(model, **base_kwargs, **optional_kwargs)
    except TypeError:
        # Best-effort: the selected optimizer does not accept the optional
        # arguments. Retry without them, but leave a trace for the user.
        logging.getLogger(__name__).warning(
            "optimizer %r rejected optional args %s; retrying without them",
            base_kwargs["opt"],
            sorted(optional_kwargs),
        )
        return create_optimizer_v2(model, **base_kwargs)


def _make_scheduler(self, optimizer: Any, params: dict[str, Any], epochs: int) -> tuple:
    """Create the LR scheduler through timm's ``create_scheduler_v2``.

    Args:
        optimizer: Optimizer the scheduler drives.
        params: Run hyperparameters. Reads ``sched`` (default ``"cosine"``),
            ``decay_epochs``, ``decay_rate``, ``min_lr``, ``warmup_lr``,
            ``warmup_epochs`` and ``cooldown_epochs``. ``None`` values fall
            back to the defaults.
        epochs: Requested number of training epochs.

    Returns:
        ``(scheduler, num_epochs)``. ``scheduler`` is ``None`` when ``sched`` is
        empty or ``"none"`` (case-insensitive) or when construction fails;
        in those cases ``num_epochs`` is the unchanged ``epochs``. Otherwise
        ``num_epochs`` is the epoch count reported by timm.
    """
    import logging

    def _value(key: str, default: Any) -> Any:
        found = params.get(key)
        return default if found is None else found

    sched = str(_value("sched", "cosine")).strip().lower()
    if sched in ("", "none"):
        return None, epochs

    from timm.scheduler import create_scheduler_v2  # type: ignore

    try:
        scheduler, num_epochs = create_scheduler_v2(
            optimizer,
            sched=sched,
            num_epochs=epochs,
            decay_epochs=float(_value("decay_epochs", 30)),
            decay_rate=float(_value("decay_rate", 0.1)),
            min_lr=float(_value("min_lr", 1e-6)),
            warmup_lr=float(_value("warmup_lr", 1e-5)),
            warmup_epochs=int(_value("warmup_epochs", 3)),
            cooldown_epochs=int(_value("cooldown_epochs", 0)),
        )
        return scheduler, int(num_epochs)
    except Exception:
        # Deliberate best-effort: train with a constant LR rather than fail
        # the run, but do not hide that the requested schedule was dropped.
        logging.getLogger(__name__).warning(
            "could not build LR scheduler %r; training without a scheduler",
            sched,
            exc_info=True,
        )
        return None, epochs
def _make_mixup(self, params: dict[str, Any], num_classes: int) -> Any:
    """Build timm's ``Mixup`` batch transform, or ``None`` when disabled.

    Args:
        params: Run hyperparameters. Reads ``mixup``, ``cutmix``,
            ``mixup_prob`` (default 1.0), ``mixup_switch_prob`` (0.5) and
            ``smoothing`` (0.1). ``None`` values fall back to the defaults.
        num_classes: Number of target classes.

    Returns:
        A ``timm.data.Mixup`` instance, or ``None`` when both ``mixup`` and
        ``cutmix`` are non-positive or when construction fails.
    """
    import logging

    def _value(key: str, default: Any) -> Any:
        # An explicit None used to raise inside the try below, which silently
        # turned mixup off; treat it as "use the default".
        found = params.get(key)
        return default if found is None else found

    mixup = float(params.get("mixup", 0.0) or 0.0)
    cutmix = float(params.get("cutmix", 0.0) or 0.0)
    if mixup <= 0 and cutmix <= 0:
        return None
    try:
        from timm.data import Mixup  # type: ignore

        return Mixup(
            mixup_alpha=mixup,
            cutmix_alpha=cutmix,
            prob=float(_value("mixup_prob", 1.0)),
            switch_prob=float(_value("mixup_switch_prob", 0.5)),
            mode="batch",
            label_smoothing=float(_value("smoothing", 0.1)),
            num_classes=num_classes,
        )
    except Exception:
        # Best-effort: continue without mixup, but make the downgrade visible.
        logging.getLogger(__name__).warning(
            "could not enable mixup/cutmix; training without it", exc_info=True
        )
        return None


def _make_losses(self, params: dict[str, Any], mixup_fn: Any) -> tuple:
    """Return ``(train_loss, val_loss)`` criteria.

    The training criterion is ``SoftTargetCrossEntropy`` when ``mixup_fn`` is
    set, a label-smoothing cross entropy when ``smoothing`` > 0 (timm's
    implementation, falling back to ``nn.CrossEntropyLoss(label_smoothing=...)``),
    and plain cross entropy otherwise. The validation criterion is always
    plain ``nn.CrossEntropyLoss``.

    Args:
        params: Run hyperparameters; reads ``smoothing`` (default 0.1, also
            used when the value is ``None``).
        mixup_fn: The mixup transform, or ``None``.
    """
    import torch.nn as nn  # type: ignore

    raw_smoothing = params.get("smoothing")
    smoothing = 0.1 if raw_smoothing is None else float(raw_smoothing)
    val_loss = nn.CrossEntropyLoss()

    if mixup_fn is not None:
        from timm.loss import SoftTargetCrossEntropy  # type: ignore

        return SoftTargetCrossEntropy(), val_loss

    if smoothing > 0:
        try:
            from timm.loss import LabelSmoothingCrossEntropy  # type: ignore

            train_loss = LabelSmoothingCrossEntropy(smoothing=smoothing)
        except Exception:
            # Deliberate fallback to the torch equivalent.
            train_loss = nn.CrossEntropyLoss(label_smoothing=smoothing)
        return train_loss, val_loss

    return nn.CrossEntropyLoss(), val_loss


def _validate(
    self, model: Any, val_loader: Any, loss_fn: Any, device: str, memory_format: Any
) -> dict[str, Any]:
    """Run one pass over ``val_loader`` and return accuracy and loss.

    Args:
        model: Model to evaluate; it is switched to eval mode.
        val_loader: Iterable of ``(inputs, targets)`` batches, or ``None``.
        loss_fn: Criterion called as ``loss_fn(outputs, targets)``.
        device: Device the batches are moved to.
        memory_format: Memory format applied to the inputs.

    Returns:
        ``{"top1", "top5", "val_loss"}``, each rounded to 5 decimals. All
        three are ``0.0`` when ``val_loader`` is ``None``. ``top5`` uses
        ``k = min(5, number of output columns)``.
    """
    if val_loader is None:
        return {"top1": 0.0, "top5": 0.0, "val_loss": 0.0}
    import torch  # type: ignore

    model.eval()
    total = 0
    hits_top1 = 0
    hits_topk = 0
    loss_sum = 0.0
    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs = inputs.to(device, memory_format=memory_format)
            targets = targets.to(device)
            outputs = model(inputs)
            # Weight the batch-mean loss by batch size so the final value is
            # a per-sample mean.
            loss_sum += float(loss_fn(outputs, targets).item()) * inputs.size(0)
            k = min(5, outputs.size(1))
            top_indices = outputs.topk(k, dim=1, largest=True, sorted=True).indices
            matches = top_indices.eq(targets.view(-1, 1))
            hits_top1 += int(matches[:, 0].float().sum().item())
            hits_topk += int(matches.float().sum().item())
            total += targets.size(0)

    denom = max(total, 1)
    return {
        "top1": round(hits_top1 / denom, 5),
        "top5": round(hits_topk / denom, 5),
        "val_loss": round(loss_sum / denom, 5),
    }
def _make_plots(
    self,
    work_dir: Path,
    history: list[dict],
    model: Any,
    val_loader: Any,
    classes: list[str],
    device: str,
    memory_format: Any,
) -> list[Path]:
    """Write training plots into ``work_dir`` and return the files created.

    Two plots are attempted independently: ``results.png`` (train/val loss
    with top-1 on a secondary axis) and, when ``val_loader`` is not ``None``,
    ``confusion_matrix.png`` from one more pass over the validation data.
    Plotting is best-effort: a failure is logged and that plot is skipped.

    Args:
        work_dir: Directory the images are written to.
        history: Per-epoch entries; each needs an ``"epoch"`` key, other
            metrics default to 0 when missing.
        model: Model used for the confusion-matrix pass (set to eval mode).
        val_loader: Iterable of ``(inputs, targets)`` batches, or ``None``.
        classes: Class labels, in model output order.
        device: Device the inputs are moved to.
        memory_format: Memory format applied to the inputs.

    Returns:
        Paths of the plots that were written (empty if matplotlib is
        unavailable).
    """
    import logging

    log = logging.getLogger(__name__)
    files: list[Path] = []
    try:
        import matplotlib  # type: ignore

        matplotlib.use("Agg")
        import matplotlib.pyplot as plt  # type: ignore
    except Exception:
        log.info("matplotlib is unavailable; skipping training plots")
        return files

    # Loss / accuracy curves.
    fig = None
    try:
        epochs = [h["epoch"] for h in history]
        fig, ax1 = plt.subplots(figsize=(6, 4))
        ax1.plot(epochs, [h.get("train_loss", 0) for h in history], label="train_loss")
        ax1.plot(epochs, [h.get("val_loss", 0) for h in history], label="val_loss")
        ax1.set_xlabel("epoch")
        ax1.set_ylabel("loss")
        ax2 = ax1.twinx()
        ax2.plot(epochs, [h.get("top1", 0) for h in history], "g--", label="top1")
        ax2.set_ylabel("top1")
        ax1.legend(loc="upper left")
        fig.tight_layout()
        results = work_dir / "results.png"
        fig.savefig(results)
        files.append(results)
    except Exception:
        log.warning("could not write results.png", exc_info=True)
    finally:
        # Always release the figure, even when plotting or saving failed.
        if fig is not None:
            plt.close(fig)

    # Confusion matrix from a final val pass.
    if val_loader is not None:
        fig = None
        try:
            import numpy as np  # type: ignore
            import torch  # type: ignore

            n = len(classes)
            cm = np.zeros((n, n), dtype=int)
            model.eval()
            with torch.no_grad():
                for inputs, targets in val_loader:
                    inputs = inputs.to(device, memory_format=memory_format)
                    preds = model(inputs).argmax(1).cpu().numpy()
                    # Rows are true labels, columns are predictions.
                    for t, p in zip(targets.numpy(), preds, strict=False):
                        cm[int(t)][int(p)] += 1
            fig, ax = plt.subplots(figsize=(5, 5))
            ax.imshow(cm, cmap="Blues")
            ax.set_xlabel("predicted")
            ax.set_ylabel("true")
            ax.set_xticks(range(n))
            ax.set_yticks(range(n))
            ax.set_xticklabels(classes, rotation=90, fontsize=6)
            ax.set_yticklabels(classes, fontsize=6)
            fig.tight_layout()
            cm_path = work_dir / "confusion_matrix.png"
            fig.savefig(cm_path)
            files.append(cm_path)
        except Exception:
            log.warning("could not write confusion_matrix.png", exc_info=True)
        finally:
            if fig is not None:
                plt.close(fig)
    return files
def export_onnx(
    self, local_pt: Path, out_dir: Path, *, opset: int | None = None, dynamic: bool = True
) -> Path | None:
    """Export a saved Timm checkpoint to ``out_dir / "model.onnx"``.

    Args:
        local_pt: Path to the checkpoint file.
        out_dir: Directory for the ONNX file; created if it does not exist.
        opset: ONNX opset version; 17 is used when not given.
        dynamic: When true, the batch dimension of ``input`` and ``output``
            is exported as a dynamic axis.

    Returns:
        The path of the exported file, or ``None`` if no file was produced.
    """
    import timm  # type: ignore
    import torch  # type: ignore

    # NOTE(security): weights_only=False unpickles arbitrary objects, so the
    # checkpoint must come from a trusted source. TODO(review): confirm whether
    # these checkpoints can be loaded with weights_only=True instead.
    ckpt = torch.load(str(local_pt), map_location="cpu", weights_only=False)
    model = self._rebuild_model(timm, ckpt)
    model.eval()

    img = int(ckpt.get("img_size", 224))
    dummy = torch.zeros(1, 3, img, img)

    out_dir = Path(out_dir)
    # The exporter does not create missing parent directories.
    out_dir.mkdir(parents=True, exist_ok=True)
    out_path = out_dir / "model.onnx"

    dynamic_axes = {"input": {0: "batch"}, "output": {0: "batch"}} if dynamic else None
    torch.onnx.export(
        model,
        dummy,
        str(out_path),
        input_names=["input"],
        output_names=["output"],
        opset_version=int(opset) if opset else 17,
        dynamic_axes=dynamic_axes,
    )
    return out_path if out_path.exists() else None


def load_predictor(self, local_pt: Path) -> Any:
    """Load a checkpoint and return what is needed to run inference.

    Args:
        local_pt: Path to the checkpoint file.

    Returns:
        A dict with ``"model"`` (rebuilt, in eval mode, on CPU),
        ``"transform"`` (timm's non-training transform for the checkpoint's
        ``img_size``, default 224) and ``"classes"`` (label list from the
        checkpoint, empty if absent).
    """
    import timm  # type: ignore
    import torch  # type: ignore
    from timm.data import create_transform  # type: ignore

    # NOTE(security): see export_onnx; weights_only=False requires a trusted
    # checkpoint.
    ckpt = torch.load(str(local_pt), map_location="cpu", weights_only=False)
    model = self._rebuild_model(timm, ckpt)
    model.eval()

    img = int(ckpt.get("img_size", 224))
    # NOTE(review): uses create_transform defaults for everything except the
    # input size; confirm they match the validation transform used in training.
    transform = create_transform(input_size=img, is_training=False)
    return {
        "model": model,
        "transform": transform,
        "classes": ckpt.get("classes", []),
    }
def predict_classification(self, predictor: Any, image_path: str) -> tuple[str, float]:
    """Classify one image with a predictor returned by ``load_predictor``.

    Args:
        predictor: Dict with ``"model"``, ``"transform"`` and ``"classes"``.
        image_path: Path of the image file.

    Returns:
        ``(label, probability)`` for the highest-scoring class. The label is
        the stringified class index when it is outside ``classes``.
    """
    import torch  # type: ignore
    from PIL import Image  # type: ignore

    model = predictor["model"]
    transform = predictor["transform"]
    classes = predictor["classes"]

    # Close the file handle deterministically; convert() returns a new,
    # fully loaded image that stays valid after the source is closed.
    with Image.open(image_path) as raw:
        img = raw.convert("RGB")

    tensor = transform(img).unsqueeze(0)
    with torch.no_grad():
        probs = torch.softmax(model(tensor), dim=1)[0]
    idx = int(probs.argmax().item())
    name = classes[idx] if 0 <= idx < len(classes) else str(idx)
    return (name, float(probs[idx].item()))


@staticmethod
def _rebuild_model(timm_mod: Any, ckpt: dict) -> Any:
    """Recreate the network described by ``ckpt`` and load its weights.

    Args:
        timm_mod: The imported ``timm`` module (passed in by the caller).
        ckpt: Checkpoint dict. Reads ``arch`` (default ``"resnet50"``),
            ``num_classes`` (1000), ``global_pool`` (``"avg"``) and the
            required ``state_dict``.

    Returns:
        The model with the checkpoint weights loaded.

    Raises:
        KeyError: If ``ckpt`` has no ``"state_dict"`` entry.
    """
    arch = ckpt.get("arch", "resnet50")
    num_classes = int(ckpt.get("num_classes", 1000))
    global_pool = ckpt.get("global_pool", "avg")
    model = timm_mod.create_model(
        arch,
        pretrained=False,  # weights come from the checkpoint
        num_classes=num_classes,
        global_pool=global_pool,
    )
    model.load_state_dict(ckpt["state_dict"])
    return model
+""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +from app.services.training import datasets +from app.services.training.base import ( + Capabilities, + FieldDef, + FieldGroup, + TrainContext, + Trainer, + TrainResult, +) +from app.services.training.registry import register + +# Map Ultralytics' verbose metric keys to the clean keys the frontend charts. +_METRIC_KEY_MAP = { + "metrics/mAP50(B)": "mAP50", + "metrics/mAP50-95(B)": "mAP50_95", + "metrics/precision(B)": "precision", + "metrics/recall(B)": "recall", + "train/box_loss": "train_box_loss", + "train/cls_loss": "train_cls_loss", + "train/dfl_loss": "train_dfl_loss", + "val/box_loss": "val_box_loss", + "val/cls_loss": "val_cls_loss", + "val/dfl_loss": "val_dfl_loss", + "lr/pg0": "lr", + "metrics/accuracy_top1": "top1", + "metrics/accuracy_top5": "top5", +} + +# Allow-list of ``model.train()`` args users may tune. Keys absent here are +# ignored so callers cannot inject unsafe/irrelevant kwargs. +ULTRALYTICS_TRAIN_ARGS: dict[str, Any] = { + "epochs": 50, + "imgsz": 640, + "batch": 16, + "patience": 100, + "rect": False, + "single_cls": False, + "seed": 0, + "optimizer": "auto", + "lr0": 0.01, + "lrf": 0.01, + "momentum": 0.937, + "weight_decay": 0.0005, + "warmup_epochs": 3.0, + "warmup_momentum": 0.8, + "warmup_bias_lr": 0.1, + "cos_lr": False, + "close_mosaic": 10, + "nbs": 64, + "amp": True, + "dropout": 0.0, + "label_smoothing": 0.0, + "box": 7.5, + "cls": 0.5, + "dfl": 1.5, + "overlap_mask": True, + "mask_ratio": 4, + "freeze": None, + "hsv_h": 0.015, + "hsv_s": 0.7, + "hsv_v": 0.4, + "degrees": 0.0, + "translate": 0.1, + "scale": 0.5, + "shear": 0.0, + "perspective": 0.0, + "flipud": 0.0, + "fliplr": 0.5, + "bgr": 0.0, + "mosaic": 1.0, + "mixup": 0.0, + "copy_paste": 0.0, + "erasing": 0.4, + "crop_fraction": 1.0, + "auto_augment": "randaugment", +} + +# Plot images Ultralytics writes into the run dir when plots=True. 
_PLOT_FILES = [
    "results.png",
    "PR_curve.png",
    "P_curve.png",
    "R_curve.png",
    "F1_curve.png",
    "confusion_matrix.png",
    "confusion_matrix_normalized.png",
    "labels.jpg",
    "BoxPR_curve.png",
]

_SIZES = ["n", "s", "m", "l", "x"]


def _normalize_metrics(raw: dict) -> dict:
    """Return ``raw`` with a clean alias added for every known metric key.

    Each key found in ``_METRIC_KEY_MAP`` is emitted under its mapped name,
    and every original key is kept as well. Insertion order follows ``raw``,
    with an alias placed directly before its original key.
    """
    normalized: dict[str, Any] = {}
    for name, value in raw.items():
        alias = _METRIC_KEY_MAP.get(name)
        if alias is not None:
            normalized[alias] = value
        normalized[name] = value
    return normalized
FieldDef("warmup_momentum", "Warmup Momentum", "number", 0.8, 0, 1, 0.01), + FieldDef("warmup_bias_lr", "Warmup Bias LR", "number", 0.1, 0, 1, 0.01), + FieldDef("cos_lr", "Cosine LR", "bool", False), + FieldDef( + "close_mosaic", + "Close Mosaic", + "number", + 10, + 0, + 100, + help="Disable mosaic for last N epochs", + ), + FieldDef("nbs", "Nominal Batch", "number", 64, 1, 256), + FieldDef("amp", "AMP", "bool", True, help="Automatic mixed precision"), + ], + ), + FieldGroup( + "Regularization & Loss Gains", + [ + FieldDef("dropout", "Dropout", "number", 0.0, 0, 1, 0.01), + FieldDef("label_smoothing", "Label Smoothing", "number", 0.0, 0, 1, 0.01), + FieldDef("box", "Box Gain", "number", 7.5, 0, 20, 0.1), + FieldDef("cls", "Cls Gain", "number", 0.5, 0, 10, 0.1), + FieldDef("dfl", "DFL Gain", "number", 1.5, 0, 10, 0.1), + FieldDef("overlap_mask", "Overlap Mask", "bool", True), + FieldDef("mask_ratio", "Mask Ratio", "number", 4, 1, 16), + ], + ), + FieldGroup( + "Augmentation", + [ + FieldDef("hsv_h", "hsv_h", "number", 0.015, 0, 1, 0.001, help="Hue jitter"), + FieldDef("hsv_s", "hsv_s", "number", 0.7, 0, 1, 0.01, help="Saturation jitter"), + FieldDef("hsv_v", "hsv_v", "number", 0.4, 0, 1, 0.01, help="Value jitter"), + FieldDef("degrees", "degrees", "number", 0.0, 0, 180, 1, help="Rotation range"), + FieldDef("translate", "translate", "number", 0.1, 0, 1, 0.01), + FieldDef("scale", "scale", "number", 0.5, 0, 1, 0.01), + FieldDef("shear", "shear", "number", 0.0, 0, 10, 0.1), + FieldDef("perspective", "perspective", "number", 0.0, 0, 0.001, 0.0001), + FieldDef("flipud", "flipud", "number", 0.0, 0, 1, 0.01), + FieldDef("fliplr", "fliplr", "number", 0.5, 0, 1, 0.01), + FieldDef("bgr", "bgr", "number", 0.0, 0, 1, 0.01), + FieldDef("mosaic", "mosaic", "number", 1.0, 0, 1, 0.01), + FieldDef("mixup", "mixup", "number", 0.0, 0, 1, 0.01), + FieldDef("copy_paste", "copy_paste", "number", 0.0, 0, 1, 0.01), + FieldDef("erasing", "erasing", "number", 0.4, 0, 1, 0.01), + 
def run(self, ctx: TrainContext) -> TrainResult:
    """Export the dataset, train a YOLO model and collect its outputs.

    The dataset is written under ``ctx.work_dir / "dataset"`` (image-folder
    layout for ``classify``, YOLO detection layout otherwise) and training
    output goes to ``ctx.work_dir / "output"``. Progress is reported through
    ``ctx.report``: 0.2 after the dataset export, then once per epoch.

    Args:
        ctx: Training context (params, task, assets, storage handles,
            working directory and progress callback).

    Returns:
        A ``TrainResult`` with the best checkpoint path (``None`` when no
        ``.pt`` file was produced) and the plot files that exist.
    """
    from ultralytics import YOLO  # type: ignore

    params = ctx.params
    base_model = params.get("base_model", "yolov8n.pt")
    dataset_dir = ctx.work_dir / "dataset"

    if ctx.task == "classify":
        build_dataset = datasets.build_imagefolder
    else:
        build_dataset = datasets.build_yolo_detect
    data_arg = build_dataset(
        ctx.assets,
        ctx.annotations_by_asset,
        ctx.class_names,
        dataset_dir,
        ctx.minio_client,
        ctx.bucket,
        ctx.splits,
    )

    ctx.report(progress=0.2)

    output_dir = ctx.work_dir / "output"
    output_dir.mkdir(parents=True, exist_ok=True)

    # ``epochs`` may arrive as None or as a string from JSON params; normalize
    # once so both the progress math and the train call get an int.
    raw_epochs = params.get("epochs")
    total_epochs = int(raw_epochs) if raw_epochs is not None else 50
    model = YOLO(base_model)

    def on_train_epoch_end(trainer: Any) -> None:  # noqa: ANN001
        epoch_num = getattr(trainer, "epoch", 0)
        metrics_dict: dict = {}
        raw_metrics = getattr(trainer, "metrics", None)
        if hasattr(raw_metrics, "results_dict"):
            metrics_dict = dict(raw_metrics.results_dict)
        elif isinstance(raw_metrics, dict):
            # Copy: adding "loss" below must not mutate the trainer's dict.
            metrics_dict = dict(raw_metrics)
        if hasattr(trainer, "loss"):
            loss_val = trainer.loss
            try:
                metrics_dict["loss"] = float(loss_val) if loss_val is not None else None
            except (TypeError, ValueError, RuntimeError):
                # A non-scalar loss must not abort training from a callback.
                metrics_dict["loss"] = None
        entry = {"epoch": epoch_num, **_normalize_metrics(metrics_dict)}
        # trainer.epoch is treated as 0-based here (presumably matching
        # Ultralytics; verify), so count the epoch that just finished; the
        # final epoch then reports 0.95.
        done = min((epoch_num + 1) / max(total_epochs, 1), 1.0)
        ctx.report(progress=0.2 + 0.75 * done, epoch=entry)

    model.add_callback("on_train_epoch_end", on_train_epoch_end)

    train_kwargs: dict = {
        "data": str(data_arg),
        "device": params.get("device", "cpu"),
        "project": str(output_dir),
        "name": "train",
        "plots": True,
    }
    # Only allow-listed arguments are forwarded; None means "leave unset".
    for key, default in ULTRALYTICS_TRAIN_ARGS.items():
        val = params.get(key, default)
        if val is not None:
            train_kwargs[key] = val
    train_kwargs["epochs"] = total_epochs
    model.train(**train_kwargs)

    best = output_dir / "train" / "weights" / "best.pt"
    if not best.exists():
        # Deterministic fallback: prefer last.pt, else the first by path.
        pts = sorted(output_dir.rglob("*.pt"))
        best = next((p for p in pts if p.name == "last.pt"), pts[0] if pts else None)

    results_dir = output_dir / "train"
    plot_files = [results_dir / f for f in _PLOT_FILES if (results_dir / f).exists()]

    return TrainResult(best_model_path=best, plot_files=plot_files)


# -- export / inference ---------------------------------------------------
def export_onnx(
    self, local_pt: Path, out_dir: Path, *, opset: int | None = None, dynamic: bool = True
) -> Path | None:
    """Export a YOLO checkpoint to ONNX via Ultralytics.

    ``out_dir`` is not used: the returned path is whatever ``model.export``
    reports, or ``local_pt`` with an ``.onnx`` suffix if that file exists.

    Returns:
        Path of the exported model, or ``None`` if none was found.
    """
    from ultralytics import YOLO  # type: ignore

    model = YOLO(str(local_pt))
    export_kwargs: dict = {"format": "onnx", "simplify": True, "dynamic": bool(dynamic)}
    if opset is not None:
        export_kwargs["opset"] = int(opset)
    exported = model.export(**export_kwargs)
    if exported:
        return Path(str(exported))
    sibling = local_pt.with_suffix(".onnx")
    return sibling if sibling.exists() else None


def load_predictor(self, local_pt: Path) -> Any:
    """Return a ``YOLO`` model loaded from ``local_pt``."""
    from ultralytics import YOLO  # type: ignore

    return YOLO(str(local_pt))


def predict_classification(self, predictor: Any, image_path: str) -> tuple[str, float]:
    """Return ``(label, confidence)`` for the top-1 class of one image.

    The first result carrying ``probs`` decides; ``("unknown", 0.0)`` is
    returned when no result has classification probabilities.
    """
    for result in predictor.predict(image_path, verbose=False):
        probs = getattr(result, "probs", None)
        if probs is None:
            continue
        top = int(probs.top1)
        names = getattr(result, "names", {}) or {}
        return (names.get(top, str(top)), float(probs.top1conf.item()))
    return ("unknown", 0.0)


def predict_detections(
    self, predictor: Any, image_path: str, score_threshold: float
) -> list[dict[str, Any]]:
    """Return detections for one image as plain dicts.

    Each item has ``"class"`` (label, or the stringified index when unknown),
    ``"bbox"`` as ``(x, y, w, h)`` with the top-left corner derived from the
    centre-based ``xywh`` box, and ``"score"``.
    """
    detections: list[dict[str, Any]] = []
    for result in predictor.predict(image_path, conf=score_threshold, verbose=False):
        names = getattr(result, "names", {}) or {}
        for box in getattr(result, "boxes", []) or []:
            cls_idx = int(box.cls.item()) if hasattr(box, "cls") else 0
            cx, cy, w, h = box.xywh[0].tolist() if hasattr(box, "xywh") else [0, 0, 0, 0]
            score = float(box.conf.item()) if hasattr(box, "conf") else 0.0
            detections.append(
                {
                    "class": names.get(cls_idx, str(cls_idx)),
                    "bbox": (cx - w / 2, cy - h / 2, w, h),
                    "score": score,
                }
            )
    return detections
list[dict[str, Any]] = [] + for r in results: + names = getattr(r, "names", {}) or {} + for box in getattr(r, "boxes", []) or []: + cls_idx = int(box.cls.item()) if hasattr(box, "cls") else 0 + xy = box.xywh[0].tolist() if hasattr(box, "xywh") else [0, 0, 0, 0] + cx, cy, w, h = xy + out.append( + { + "class": names.get(cls_idx, str(cls_idx)), + "bbox": (cx - w / 2, cy - h / 2, w, h), + "score": float(box.conf.item()) if hasattr(box, "conf") else 0.0, + } + ) + return out + + +register(UltralyticsTrainer()) diff --git a/backend/src/app/services/training_service.py b/backend/src/app/services/training_service.py index 1749519..47e1972 100644 --- a/backend/src/app/services/training_service.py +++ b/backend/src/app/services/training_service.py @@ -26,13 +26,25 @@ def launch_training( base_model: str = "yolov8n.pt", owner_id: str | None = None, cluster_id: str | None = None, + framework: str = "ultralytics", ) -> dict[str, Any]: # Reject obvious task/dataset mismatches so we don't burn cluster time on # a run we already know will fail (e.g. classify on a detection dataset). from app.models.dataset import Dataset from app.models.dataset_version import DatasetVersion + from app.services import training as training_pkg requested = (task or "").lower() + + # The chosen framework must support the requested task (e.g. Timm is + # classification-only). Resolve through the registry so this stays generic. 
+ try: + trainer = training_pkg.get_trainer(framework) + except training_pkg.registry.UnknownFrameworkError as exc: + raise TaskTypeMismatch(str(exc)) from exc + if requested and requested not in trainer.supported_tasks: + raise TaskTypeMismatch(f"framework '{trainer.key}' does not support task '{requested}'") + if requested in ("detect", "classify"): version = db.get(DatasetVersion, dataset_version_id) dataset = db.get(Dataset, version.dataset_id) if version else None @@ -41,9 +53,11 @@ def launch_training( f"task '{requested}' does not match dataset task_type " f"'{dataset.task_type}'" ) - # Build full params including task and base_model so the worker can read them + # Build full params including task, framework and base_model so the worker + # can read them. full_params = dict(params or {}) full_params.setdefault("task", task) + full_params.setdefault("framework", trainer.key) full_params.setdefault("base_model", base_model) # If a cluster was selected, reserve it before creating the run so we fail @@ -62,6 +76,7 @@ def launch_training( cluster_id=cluster_id, name=name, status="queued", + framework=trainer.key, params_json=json.dumps(full_params), ) db.add(run) diff --git a/backend/tests/unit/test_services_training.py b/backend/tests/unit/test_services_training.py index 814a4de..99c2bb7 100644 --- a/backend/tests/unit/test_services_training.py +++ b/backend/tests/unit/test_services_training.py @@ -1,15 +1,22 @@ +import json + +import pytest from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker from app.db.base import Base -from app.services.training_service import launch_training +from app.models.experiment import ExperimentRun +from app.services.training_service import TaskTypeMismatch, launch_training -def test_launch_training_returns_job_id(tmp_path): +def _session(): engine = create_engine("sqlite+pysqlite:///:memory:", future=True) Base.metadata.create_all(bind=engine) - Session = sessionmaker(bind=engine, future=True) - db = 
Session() + return sessionmaker(bind=engine, future=True)() + + +def test_launch_training_returns_job_id(tmp_path): + db = _session() try: result = launch_training(db, project_id="p1", dataset_version_id="dv1", task="detect") assert result["status"] == "queued" @@ -17,3 +24,45 @@ def test_launch_training_returns_job_id(tmp_path): assert result["jobId"] finally: db.close() + + +def test_launch_training_defaults_to_ultralytics_framework(): + db = _session() + try: + result = launch_training(db, project_id="p1", dataset_version_id="dv1", task="detect") + run = db.get(ExperimentRun, result["experimentId"]) + assert run.framework == "ultralytics" + assert json.loads(run.params_json)["framework"] == "ultralytics" + finally: + db.close() + + +def test_launch_training_persists_timm_framework(): + db = _session() + try: + result = launch_training( + db, + project_id="p1", + dataset_version_id="dv1", + task="classify", + framework="timm", + params={"model": "resnet18"}, + ) + run = db.get(ExperimentRun, result["experimentId"]) + assert run.framework == "timm" + params = json.loads(run.params_json) + assert params["framework"] == "timm" and params["model"] == "resnet18" + finally: + db.close() + + +def test_launch_training_rejects_unsupported_task_for_framework(): + db = _session() + try: + # Timm is classification-only — detection must be rejected. + with pytest.raises(TaskTypeMismatch): + launch_training( + db, project_id="p1", dataset_version_id="dv1", task="detect", framework="timm" + ) + finally: + db.close() diff --git a/backend/tests/unit/test_timm_trainer.py b/backend/tests/unit/test_timm_trainer.py new file mode 100644 index 0000000..37ad3cc --- /dev/null +++ b/backend/tests/unit/test_timm_trainer.py @@ -0,0 +1,127 @@ +"""Tests for the Timm trainer. + +The capability-schema test is dependency-free. The training smoke test actually +runs a 1-epoch CPU fit and is skipped unless torch/timm/torchvision/PIL are +installed (e.g. on the GPU agent images / full backend env). 
+""" + +import io +from dataclasses import dataclass +from pathlib import Path + +import pytest + +from app.services import training as training_pkg +from app.services.training.base import TrainContext + + +def test_timm_capabilities_cover_full_surface(): + caps = training_pkg.get_trainer("timm").capabilities().to_dict() + titles = {g["title"] for g in caps["groups"]} + assert {"Architecture", "Core", "Optimizer", "Schedule", "Regularization", "Augmentation"} <= ( + titles + ) + all_keys = {f["key"] for g in caps["groups"] for f in g["fields"]} + # A representative slice of the full hyperparameter / augmentation surface. + for key in ("model", "global_pool", "opt", "sched", "mixup", "aa", "reprob", "model_ema"): + assert key in all_keys + + +@dataclass +class _Asset: + id: str + uri: str + width: int = 32 + height: int = 32 + + +@dataclass +class _Ann: + class_name: str + geometry: str = "{}" + type: str = "classification" + + +def _png_bytes(color): + from PIL import Image # type: ignore + + buf = io.BytesIO() + Image.new("RGB", (32, 32), color).save(buf, format="PNG") + return buf.getvalue() + + +def test_timm_cpu_smoke_train(tmp_path): + pytest.importorskip("torch") + pytest.importorskip("timm") + pytest.importorskip("torchvision") + pytest.importorskip("PIL") + + class _Resp: + def __init__(self, data): + self._d = data + + def read(self): + return self._d + + def close(self): + pass + + def release_conn(self): + pass + + class _Minio: + def get_object(self, bucket, key): + color = (255, 0, 0) if "cat" in key else (0, 0, 255) + return _Resp(_png_bytes(color)) + + assets, anns, splits = [], {}, {} + for i in range(6): + label = "cat" if i % 2 == 0 else "dog" + a = _Asset(f"a{i}", f"s3://b/{label}{i}.png") + assets.append(a) + anns[a.id] = [_Ann(label)] + splits[a.id] = "train" if i < 4 else "val" + + captured = [] + + def report(progress=None, epoch=None): + if epoch is not None: + captured.append(epoch) + + ctx = TrainContext( + db=None, + run=None, + params={ 
+ "model": "resnet10t", + "pretrained": False, + "epochs": 1, + "batch_size": 2, + "img_size": 32, + "num_workers": 0, + "amp": False, + "device": "cpu", + "sched": "none", + }, + task="classify", + class_names=["cat", "dog"], + assets=assets, + annotations_by_asset=anns, + splits=splits, + minio_client=_Minio(), + bucket="b", + work_dir=tmp_path, + report=report, + ) + + result = training_pkg.get_trainer("timm").run(ctx) + assert result.best_model_path and Path(result.best_model_path).exists() + assert result.artifact_meta["arch"] == "resnet10t" + assert result.artifact_meta["classes"] == ["cat", "dog"] + assert captured and "top1" in captured[0] + + # The saved checkpoint is self-describing so export/inference can rebuild it. + import torch # type: ignore + + ckpt = torch.load(str(result.best_model_path), map_location="cpu", weights_only=False) + assert ckpt["vf_framework"] == "timm" + assert ckpt["classes"] == ["cat", "dog"] diff --git a/backend/tests/unit/test_training_datasets.py b/backend/tests/unit/test_training_datasets.py new file mode 100644 index 0000000..c51b399 --- /dev/null +++ b/backend/tests/unit/test_training_datasets.py @@ -0,0 +1,85 @@ +"""Unit tests for the shared dataset exporters.""" + +from dataclasses import dataclass + +from app.services.training import datasets + + +@dataclass +class _Asset: + id: str + uri: str + width: int = 64 + height: int = 64 + + +@dataclass +class _Ann: + class_name: str + geometry: str = "{}" + type: str = "classification" + + +class _FakeResponse: + def __init__(self, data: bytes): + self._data = data + + def read(self): + return self._data + + def close(self): + pass + + def release_conn(self): + pass + + +class _FakeMinio: + """Returns a tiny fixed byte payload for any object.""" + + def get_object(self, bucket, key): + return _FakeResponse(b"\x89PNG\r\n") + + +def test_build_imagefolder_layout_and_test_holdout(tmp_path): + assets = [ + _Asset("a1", "s3://b/cat1.png"), + _Asset("a2", "s3://b/dog1.png"), + 
_Asset("a3", "s3://b/cat2.png"), + _Asset("a4", "s3://b/secret.png"), + ] + anns = { + "a1": [_Ann("cat")], + "a2": [_Ann("dog")], + "a3": [_Ann("cat")], + "a4": [_Ann("dog")], + } + splits = {"a1": "train", "a2": "train", "a3": "val", "a4": "test"} + + root = datasets.build_imagefolder( + assets, anns, ["cat", "dog"], tmp_path, _FakeMinio(), "b", splits + ) + + # train/cat + train/dog populated, val/cat populated. + assert (root / "train" / "cat" / "a1.png").exists() + assert (root / "train" / "dog" / "a2.png").exists() + assert (root / "val" / "cat" / "a3.png").exists() + # The test asset is held out entirely — never written anywhere. + assert not list(root.rglob("a4.png")) + + +def test_build_yolo_detect_writes_labels_and_yaml(tmp_path): + assets = [_Asset("a1", "s3://b/img.png", width=100, height=100)] + anns = {"a1": [_Ann("cat", geometry='{"x1": 10, "y1": 10, "x2": 50, "y2": 50}', type="box")]} + splits = {"a1": "train"} + + data_yaml = datasets.build_yolo_detect( + assets, anns, ["cat", "dog"], tmp_path, _FakeMinio(), "b", splits + ) + assert data_yaml.name == "data.yaml" + label = (tmp_path / "train" / "labels" / "a1.txt").read_text().strip() + # class index 0 (cat); center (0.3,0.3), size (0.4,0.4). + parts = label.split() + assert parts[0] == "0" + assert abs(float(parts[1]) - 0.3) < 1e-3 + assert abs(float(parts[3]) - 0.4) < 1e-3 diff --git a/backend/tests/unit/test_training_frameworks_api.py b/backend/tests/unit/test_training_frameworks_api.py new file mode 100644 index 0000000..f8d1d5e --- /dev/null +++ b/backend/tests/unit/test_training_frameworks_api.py @@ -0,0 +1,33 @@ +"""Contract tests for the training capabilities API.""" + +from fastapi.testclient import TestClient + +from app.db.deps import get_current_user +from app.main import app +from app.models.user import User + +# Auth is bearer-token based; bypass it with a stub user for these read-only tests. 
+app.dependency_overrides[get_current_user] = lambda: User(id="u1", email="t@t.io", name="t") +client = TestClient(app) + + +def test_list_frameworks_returns_capabilities(): + r = client.get("/api/training/frameworks") + assert r.status_code == 200 + items = r.json()["items"] + keys = {item["key"] for item in items} + assert {"ultralytics", "timm"} <= keys + timm = next(i for i in items if i["key"] == "timm") + assert timm["supported_tasks"] == ["classify"] + assert any(g["title"] == "Architecture" for g in timm["groups"]) + + +def test_search_models_unknown_framework_404(): + r = client.get("/api/training/frameworks/nope/models") + assert r.status_code == 404 + + +def test_search_models_ultralytics_filter(): + r = client.get("/api/training/frameworks/ultralytics/models?task=detect&query=yolov8s") + assert r.status_code == 200 + assert r.json()["items"] == ["yolov8s.pt"] diff --git a/backend/tests/unit/test_training_registry.py b/backend/tests/unit/test_training_registry.py new file mode 100644 index 0000000..b0d8859 --- /dev/null +++ b/backend/tests/unit/test_training_registry.py @@ -0,0 +1,59 @@ +"""Unit tests for the pluggable training-backend registry + capabilities. + +These exercise only the framework-agnostic surface, so they run without torch / +timm / ultralytics installed. +""" + +import pytest + +from app.services import training as training_pkg +from app.services.training.registry import UnknownFrameworkError + + +def test_builtin_frameworks_registered(): + keys = {t.key for t in training_pkg.list_frameworks()} + assert {"ultralytics", "timm"}.issubset(keys) + + +def test_get_trainer_defaults_to_ultralytics_for_backcompat(): + # None / empty framework maps to ultralytics so pre-framework runs keep working. 
+ assert training_pkg.get_trainer(None).key == "ultralytics" + assert training_pkg.get_trainer("").key == "ultralytics" + + +def test_get_trainer_unknown_raises(): + with pytest.raises(UnknownFrameworkError): + training_pkg.get_trainer("does-not-exist") + + +def test_supported_tasks(): + assert "classify" in training_pkg.get_trainer("timm").supported_tasks + assert training_pkg.get_trainer("timm").supported_tasks == {"classify"} + assert "detect" in training_pkg.get_trainer("ultralytics").supported_tasks + + +def test_capabilities_serialize_to_plain_dict(): + for trainer in training_pkg.list_frameworks(): + caps = trainer.capabilities().to_dict() + assert caps["key"] == trainer.key + assert isinstance(caps["supported_tasks"], list) + assert isinstance(caps["groups"], list) and caps["groups"] + # Every field is a serializable FieldDef-shaped dict. + for group in caps["groups"]: + assert "title" in group and "fields" in group + for field in group["fields"]: + assert {"key", "label", "type", "default"} <= set(field) + assert isinstance(caps["device_options"], list) + + +def test_timm_architecture_group_has_searchable_model_field(): + caps = training_pkg.get_trainer("timm").capabilities().to_dict() + arch = next(g for g in caps["groups"] if g["title"] == "Architecture") + model_field = next(f for f in arch["fields"] if f["key"] == "model") + assert model_field["type"] == "model" # searchable backbone picker + + +def test_list_models_filters_by_query_for_ultralytics(): + ul = training_pkg.get_trainer("ultralytics") + assert ul.list_models("detect", "yolov8n") == ["yolov8n.pt"] + assert set(ul.list_models("classify")) == {f"yolov8{s}-cls.pt" for s in "nsmlx"} diff --git a/frontend/src/pages/experiments/new.tsx b/frontend/src/pages/experiments/new.tsx index 29c2f92..1a8da29 100644 --- a/frontend/src/pages/experiments/new.tsx +++ b/frontend/src/pages/experiments/new.tsx @@ -1,4 +1,4 @@ -import React, { useState, useEffect } from 'react'; +import React, { useState, 
useEffect, useMemo } from 'react'; import { useNavigate, useSearchParams, Link } from 'react-router-dom'; import Input from '@/components/ui/Input'; import Select from '@/components/ui/Select'; @@ -19,9 +19,10 @@ interface Dataset { latest_version_id?: string; } -const BASE_MODELS = ['yolov8n.pt', 'yolov8s.pt', 'yolov8m.pt', 'yolov8l.pt', 'yolov8x.pt']; - -type FieldType = 'number' | 'bool' | 'select'; +// Field schema, mirroring the backend FieldDef shape. The training form is +// rendered entirely from the capabilities the backend reports, so no framework +// (Ultralytics, Timm, …) is hard-coded here. +type FieldType = 'number' | 'bool' | 'select' | 'text' | 'model'; interface FieldDef { key: string; label: string; @@ -33,273 +34,36 @@ interface FieldDef { options?: string[]; help?: string; } +interface FieldGroup { + title: string; + fields: FieldDef[]; +} +interface FrameworkCapabilities { + key: string; + label: string; + supported_tasks: string[]; + models_by_task: Record; + groups: FieldGroup[]; + device_options: string[]; +} -// Single source of truth for every tunable hyperparameter / augmentation knob. -// Adding a field here exposes it in the UI and forwards it to the backend. 
-const GROUPS: { title: string; fields: FieldDef[] }[] = [ - { - title: 'Core', - fields: [ - { key: 'epochs', label: 'Epochs', type: 'number', default: 50, min: 1, max: 2000 }, - { key: 'batch', label: 'Batch Size', type: 'number', default: 16, min: 1, max: 512 }, - { - key: 'imgsz', - label: 'Image Size', - type: 'number', - default: 640, - min: 32, - max: 1920, - step: 32, - }, - { - key: 'patience', - label: 'Patience', - type: 'number', - default: 100, - min: 0, - max: 1000, - help: 'Early-stop after N epochs w/o improvement', - }, - { key: 'seed', label: 'Seed', type: 'number', default: 0, min: 0 }, - { - key: 'rect', - label: 'Rectangular', - type: 'bool', - default: false, - help: 'Rectangular batches (min padding)', - }, - { key: 'single_cls', label: 'Single class', type: 'bool', default: false }, - ], - }, - { - title: 'Optimizer & Schedule', - fields: [ - { - key: 'optimizer', - label: 'Optimizer', - type: 'select', - default: 'auto', - options: ['auto', 'SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp'], - }, - { - key: 'lr0', - label: 'Initial LR (lr0)', - type: 'number', - default: 0.01, - min: 0.00001, - max: 1, - step: 0.0001, - }, - { - key: 'lrf', - label: 'Final LR (lrf)', - type: 'number', - default: 0.01, - min: 0.00001, - max: 1, - step: 0.0001, - }, - { - key: 'momentum', - label: 'Momentum', - type: 'number', - default: 0.937, - min: 0, - max: 1, - step: 0.001, - }, - { - key: 'weight_decay', - label: 'Weight Decay', - type: 'number', - default: 0.0005, - min: 0, - max: 0.1, - step: 0.0001, - }, - { - key: 'warmup_epochs', - label: 'Warmup Epochs', - type: 'number', - default: 3.0, - min: 0, - max: 20, - step: 0.5, - }, - { - key: 'warmup_momentum', - label: 'Warmup Momentum', - type: 'number', - default: 0.8, - min: 0, - max: 1, - step: 0.01, - }, - { - key: 'warmup_bias_lr', - label: 'Warmup Bias LR', - type: 'number', - default: 0.1, - min: 0, - max: 1, - step: 0.01, - }, - { key: 'cos_lr', label: 'Cosine LR', type: 'bool', default: 
false }, - { - key: 'close_mosaic', - label: 'Close Mosaic', - type: 'number', - default: 10, - min: 0, - max: 100, - help: 'Disable mosaic for last N epochs', - }, - { key: 'nbs', label: 'Nominal Batch', type: 'number', default: 64, min: 1, max: 256 }, - { key: 'amp', label: 'AMP', type: 'bool', default: true, help: 'Automatic mixed precision' }, - ], - }, - { - title: 'Regularization & Loss Gains', - fields: [ - { - key: 'dropout', - label: 'Dropout', - type: 'number', - default: 0.0, - min: 0, - max: 1, - step: 0.01, - }, - { - key: 'label_smoothing', - label: 'Label Smoothing', - type: 'number', - default: 0.0, - min: 0, - max: 1, - step: 0.01, - }, - { key: 'box', label: 'Box Gain', type: 'number', default: 7.5, min: 0, max: 20, step: 0.1 }, - { key: 'cls', label: 'Cls Gain', type: 'number', default: 0.5, min: 0, max: 10, step: 0.1 }, - { key: 'dfl', label: 'DFL Gain', type: 'number', default: 1.5, min: 0, max: 10, step: 0.1 }, - { key: 'overlap_mask', label: 'Overlap Mask', type: 'bool', default: true }, - { key: 'mask_ratio', label: 'Mask Ratio', type: 'number', default: 4, min: 1, max: 16 }, - ], - }, - { - title: 'Augmentation', - fields: [ - { - key: 'hsv_h', - label: 'hsv_h', - type: 'number', - default: 0.015, - min: 0, - max: 1, - step: 0.001, - help: 'Hue jitter fraction', - }, - { - key: 'hsv_s', - label: 'hsv_s', - type: 'number', - default: 0.7, - min: 0, - max: 1, - step: 0.01, - help: 'Saturation jitter', - }, - { - key: 'hsv_v', - label: 'hsv_v', - type: 'number', - default: 0.4, - min: 0, - max: 1, - step: 0.01, - help: 'Value jitter', - }, - { - key: 'degrees', - label: 'degrees', - type: 'number', - default: 0.0, - min: 0, - max: 180, - step: 1, - help: 'Rotation range', - }, - { - key: 'translate', - label: 'translate', - type: 'number', - default: 0.1, - min: 0, - max: 1, - step: 0.01, - }, - { key: 'scale', label: 'scale', type: 'number', default: 0.5, min: 0, max: 1, step: 0.01 }, - { key: 'shear', label: 'shear', type: 'number', default: 
0.0, min: 0, max: 10, step: 0.1 }, - { - key: 'perspective', - label: 'perspective', - type: 'number', - default: 0.0, - min: 0, - max: 0.001, - step: 0.0001, - }, - { key: 'flipud', label: 'flipud', type: 'number', default: 0.0, min: 0, max: 1, step: 0.01 }, - { key: 'fliplr', label: 'fliplr', type: 'number', default: 0.5, min: 0, max: 1, step: 0.01 }, - { key: 'bgr', label: 'bgr', type: 'number', default: 0.0, min: 0, max: 1, step: 0.01 }, - { key: 'mosaic', label: 'mosaic', type: 'number', default: 1.0, min: 0, max: 1, step: 0.01 }, - { key: 'mixup', label: 'mixup', type: 'number', default: 0.0, min: 0, max: 1, step: 0.01 }, - { - key: 'copy_paste', - label: 'copy_paste', - type: 'number', - default: 0.0, - min: 0, - max: 1, - step: 0.01, - }, - { - key: 'erasing', - label: 'erasing', - type: 'number', - default: 0.4, - min: 0, - max: 1, - step: 0.01, - }, - { - key: 'crop_fraction', - label: 'crop_fraction', - type: 'number', - default: 1.0, - min: 0, - max: 1, - step: 0.01, - }, - { - key: 'auto_augment', - label: 'auto_augment', - type: 'select', - default: 'randaugment', - options: ['randaugment', 'autoaugment', 'augmix'], - }, - ], - }, -]; - -const DEVICES = ['cpu', 'cuda', 'mps', '0', '0,1']; +const TASK_LABELS: Record = { + detect: 'Object Detection', + classify: 'Classification', + segment: 'Segmentation', + pose: 'Pose Estimation', +}; -function buildDefaults(): Record { +function buildDefaults(groups: FieldGroup[]): Record { const out: Record = {}; - for (const g of GROUPS) for (const f of g.fields) out[f.key] = f.default; + for (const g of groups) for (const f of g.fields) out[f.key] = f.default; return out; } +function hasModelField(groups: FieldGroup[]): boolean { + return groups.some((g) => g.fields.some((f) => f.type === 'model')); +} + function FieldLabel({ htmlFor, children }: { htmlFor: string; children: React.ReactNode }) { return (