Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions backend/src/app/api/assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

from app.db.deps import get_current_user, get_db
from app.models.user import User
from app.schemas.split import SplitConfig, SplitSummary
from app.services import split_service
from app.services.annotation_service import get_asset_annotations
from app.services.asset_service import (
confirm_upload,
Expand Down Expand Up @@ -102,6 +104,7 @@ def list_dataset_assets(
dataset_id: str = Path(...),
version_id: str | None = Query(None),
label_status: str | None = Query(None),
split: str | None = Query(None, description="Filter by train/val/test split"),
limit: int = Query(100, ge=1, le=500),
offset: int = Query(0, ge=0),
db: Session = Depends(get_db),
Expand All @@ -112,6 +115,7 @@ def list_dataset_assets(
dataset_id,
version_id=version_id,
label_status=label_status,
split=split,
limit=limit,
offset=offset,
)
Expand All @@ -120,10 +124,12 @@ def list_dataset_assets(
{
"id": a.id,
"uri": a.uri,
"download_url": _presign_download(a.uri),
"mime_type": a.mime_type,
"width": a.width,
"height": a.height,
"label_status": a.label_status,
"split": split_service.asset_split(a),
"created_at": a.created_at.isoformat() if a.created_at else None,
}
for a in assets
Expand All @@ -134,6 +140,40 @@ def list_dataset_assets(
}


@router.get("/datasets/{dataset_id}/versions/{version_id}/split", response_model=SplitSummary)
def get_version_split(
dataset_id: str = Path(...),
version_id: str = Path(...),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""Return persisted train/val/test counts and per-class breakdown for a version."""
return split_service.get_split_summary(db, version_id)


@router.post("/datasets/{dataset_id}/versions/{version_id}/split", response_model=SplitSummary)
def assign_version_split(
body: SplitConfig,
dataset_id: str = Path(...),
version_id: str = Path(...),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""Deterministically (re)assign and persist the split for every asset in a version."""
try:
return split_service.assign_splits(
db,
version_id,
train=body.train,
val=body.val,
test=body.test,
seed=body.seed,
stratify=body.stratify,
)
except split_service.SplitConfigError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc


@router.get("/datasets/{dataset_id}/stats")
def dataset_stats(
dataset_id: str = Path(...),
Expand Down
81 changes: 79 additions & 2 deletions backend/src/app/api/experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def _run_to_schema(e: ExperimentModel) -> ExperimentSchema:
params_json=e.params_json,
dataset_version_id=e.dataset_version_id,
metrics_json=e.metrics_json,
artifacts=e.artifacts,
status=e.status,
code_hash=e.code_hash,
started_at=e.started_at,
Expand Down Expand Up @@ -100,12 +101,15 @@ def get_metrics(
if not e:
raise HTTPException(status_code=404, detail="Run not found")
metrics: list = []
summary: dict | None = None
plots: list = []
split: dict | None = None
if e.metrics_json:
try:
data = json.loads(e.metrics_json)
# metrics_json may be:
# - a list of epoch dicts: [{epoch, mAP50, ...}, ...]
# - {"epochs": [{epoch, mAP50, ...}, ...]} as written by train_task
# - {"epochs": [...], "summary": {...}, "plots": [...], "split": {...}}
# - {"error": "..."} on failure
if isinstance(data, list):
metrics = data
Expand All @@ -114,6 +118,79 @@ def get_metrics(
metrics = data["epochs"]
elif "error" not in data:
metrics = [data]
summary = data.get("summary")
plots = data.get("plots") or []
split = data.get("split")
except Exception:
pass
return {"run_id": runId, "status": e.status, "metrics": metrics}

# Attach presigned GET urls so the frontend can render plots in <img> tags
# without forwarding the auth header (mirrors asset download_url).
if plots:
plots = [dict(p) for p in plots]
try:
import os
from datetime import timedelta

from app.services import storage

client = storage.get_minio_client()
bucket = os.getenv("MINIO_BUCKET", os.getenv("S3_BUCKET", "visionforge"))
for p in plots:
if p.get("key"):
try:
p["url"] = client.presigned_get_object(
bucket, p["key"], expires=timedelta(hours=1)
)
except Exception:
p["url"] = None
except Exception:
pass
return {
"run_id": runId,
"status": e.status,
"metrics": metrics,
"summary": summary,
"plots": plots,
"split": split,
}


@router.get("/runs/{runId}/plots/{name}")
def get_plot(
runId: str = Path(...),
name: str = Path(...),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""Stream a training plot PNG/JPEG that was generated by the run."""
import os

from fastapi.responses import StreamingResponse

e = db.get(ExperimentModel, runId)
if not e or not e.metrics_json:
raise HTTPException(status_code=404, detail="Run or plots not found")
try:
plots = json.loads(e.metrics_json).get("plots") or []
except Exception:
plots = []
record = next((p for p in plots if p.get("name") == name or p.get("file") == name), None)
if not record:
raise HTTPException(status_code=404, detail="Plot not found")

key = record["key"]
try:
from app.services import storage

client = storage.get_minio_client()
bucket = os.getenv("MINIO_BUCKET", os.getenv("S3_BUCKET", "visionforge"))
data = storage.get_bytes(client, key, bucket=bucket)
except Exception as exc: # pragma: no cover - storage failure path
raise HTTPException(status_code=502, detail=f"could not fetch plot: {exc}") from exc

ext = key.rsplit(".", 1)[-1].lower()
media = "image/jpeg" if ext in ("jpg", "jpeg") else "image/png"
import io as _io

return StreamingResponse(_io.BytesIO(data), media_type=media)
Loading
Loading