diff --git a/backend/src/app/api/annotations.py b/backend/src/app/api/annotations.py index 0d9575d..e96bb6c 100644 --- a/backend/src/app/api/annotations.py +++ b/backend/src/app/api/annotations.py @@ -7,6 +7,7 @@ from app.db.deps import get_current_user, get_db from app.models.annotation import REVIEW_STATUSES from app.models.user import User +from app.services import inference_service, suggestion_service from app.services.annotation_service import ( AnnotationError, VersionConflictError, @@ -247,6 +248,78 @@ def review_queue_summary( return review_summary(db, dataset_id=dataset_id, version_id=version_id) +class SuggestRequest(BaseModel): + asset_id: str + artifact_id: str | None = None + score_threshold: float = Field(0.25, ge=0.0, le=1.0) + + +@router.post("/suggest") +def suggest( + body: SuggestRequest, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user), +): + """Run a trained model on an asset and return proposed annotations. + + Suggestions are not persisted; the annotator overlays them and saves any + accepted ones through the normal bulk path. 404 when no model is available, + 502 when inference fails. + """ + try: + artifact, suggestions = suggestion_service.suggest_annotations( + db, + asset_id=body.asset_id, + artifact_id=body.artifact_id, + score_threshold=body.score_threshold, + ) + except suggestion_service.NoModelError as exc: + raise HTTPException(status_code=404, detail=str(exc)) from exc + except suggestion_service.SuggestionError as exc: + raise HTTPException(status_code=400, detail=str(exc)) from exc + except inference_service.InferenceError as exc: + raise HTTPException(status_code=502, detail=f"inference failed: {exc}") from exc + + return { + "artifact": { + "id": artifact.id, + "name": artifact.name, + "version": artifact.version, + }, + "suggestions": [ + { + "type": s.type, + "geometry": s.geometry, + "class_name": s.class_name, + "score": s.score, + } + for s in suggestions + ], + } + + +@router.get("/suggest/artifacts") +def suggest_artifacts( + dataset_id: str = Query(...), + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user), +): + """List models trained on this dataset, newest first (override dropdown).""" + arts = suggestion_service.candidate_artifacts_for_dataset(db, dataset_id) + return { + "items": [ + { + "id": a.id, + "name": a.name, + "version": a.version, + "format": a.format, + "created_at": a.created_at.isoformat() if a.created_at else None, + } + for a in arts + ] + } + + class ErrorMineRequest(BaseModel): artifact_id: str dataset_version_id: str diff --git a/backend/src/app/api/assets.py b/backend/src/app/api/assets.py index a753a6e..d0b55c9 100644 --- a/backend/src/app/api/assets.py +++ b/backend/src/app/api/assets.py @@ -10,7 +10,13 @@ from app.db.deps import get_current_user, get_db from app.models.user import User from app.services.annotation_service import get_asset_annotations -from app.services.asset_service import confirm_upload, get_asset, get_dataset_stats, list_assets +from app.services.asset_service import ( + confirm_upload, + get_asset, + get_dataset_metrics, + get_dataset_stats, + list_assets, +) router = APIRouter(prefix="/api", tags=["assets"]) @@ -138,6 +144,17 @@ def dataset_stats( return get_dataset_stats(db, dataset_id, version_id=version_id) +@router.get("/datasets/{dataset_id}/metrics") +def dataset_metrics( + dataset_id: str = Path(...), + version_id: str | None = Query(None), + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user), +): + """Detailed dataset health metrics for the metrics dashboard.""" + return get_dataset_metrics(db, dataset_id, version_id=version_id) + + @router.get("/assets/{asset_id}/neighbors") def get_asset_neighbors( asset_id: str = Path(...), diff --git a/backend/src/app/api/datasets.py b/backend/src/app/api/datasets.py index d0c1834..0f5ccc8 100644 --- a/backend/src/app/api/datasets.py +++ b/backend/src/app/api/datasets.py @@ -162,6 +162,10 @@ def get_dataset( classes = json.loads(class_map.classes) except Exception: classes = [] + # The newest version (latest) and the newest editable/unlocked version + # (open) — the UI writes new imagery/imports into the open version. + latest_v = versions[0] if versions else None + open_v = next((v for v in versions if not v.locked), None) return { "id": d.id, "project_id": d.project_id, @@ -169,6 +173,8 @@ def get_dataset( "description": d.description, "task_type": d.task_type, "classes": classes, + "latest_version_id": latest_v.id if latest_v else None, + "open_version_id": open_v.id if open_v else None, "versions": [ { "id": v.id, diff --git a/backend/src/app/services/asset_service.py b/backend/src/app/services/asset_service.py index fd89f8d..14ed9d9 100644 --- a/backend/src/app/services/asset_service.py +++ b/backend/src/app/services/asset_service.py @@ -1,14 +1,20 @@ from __future__ import annotations import json +from datetime import datetime, timedelta, timezone from sqlalchemy import func, select from sqlalchemy.orm import Session from app.models.annotation import Annotation from app.models.asset import Asset +from app.models.dataset import ClassMap, Dataset from app.models.dataset_version import DatasetVersion +# Cap on how many annotation geometries we parse in Python for the +# size/aspect-ratio histograms. Above this we sample and flag the result. +_GEOMETRY_SAMPLE_CAP = 5000 + def get_asset(db: Session, asset_id: str) -> Asset | None: return db.get(Asset, asset_id) @@ -100,3 +106,236 @@ def get_dataset_stats(db: Session, dataset_id: str, version_id: str | None = Non "class_distribution": class_counts, "annotation_count": sum(class_counts.values()), } + + +def _box_wh(geometry_json: str) -> tuple[float, float] | None: + """Parse a box annotation's geometry JSON into (width, height) in pixels.""" + try: + g = json.loads(geometry_json) + except Exception: + return None + w = g.get("w") + h = g.get("h") + if isinstance(w, (int, float)) and isinstance(h, (int, float)) and w > 0 and h > 0: + return float(w), float(h) + return None + + +def get_dataset_metrics(db: Session, dataset_id: str, version_id: str | None = None) -> dict: + """Gold-standard dataset health metrics. + + Aggregates in SQL where possible; geometry histograms parse a capped sample + of annotation rows in Python. Scopes to a single version when given. + """ + _LABELED = ("labeled", "prelabeled") + + def _scope_assets(q): + q = q.where(Asset.dataset_id == dataset_id) + if version_id: + q = q.where(Asset.version_id == version_id) + return q + + def _scope_anns(q): + # Explicit select_from(Annotation): several callers select only column + # expressions (counts), so SQLAlchemy can't infer the join's left side. + q = ( + q.select_from(Annotation) + .join(Asset, Annotation.asset_id == Asset.id) + .where(Asset.dataset_id == dataset_id) + ) + if version_id: + q = q.where(Asset.version_id == version_id) + return q + + # --- Asset / workflow counts ------------------------------------------ + status_rows = db.execute( + _scope_assets(select(Asset.label_status, func.count())).group_by(Asset.label_status) + ).all() + status_counts = {s: c for s, c in status_rows} + total_assets = sum(status_counts.values()) + labeled = sum(status_counts.get(s, 0) for s in _LABELED) + + annotated_assets = ( + db.scalar(_scope_anns(select(func.count(func.distinct(Annotation.asset_id))))) or 0 + ) + empty_images = max(total_assets - annotated_assets, 0) + + # --- Review workflow --------------------------------------------------- + review_rows = db.execute( + _scope_anns(select(Annotation.review_status, func.count())).group_by( + Annotation.review_status + ) + ).all() + review_counts = {(s or "unreviewed"): c for s, c in review_rows} + flagged = db.scalar(_scope_anns(select(func.count())).where(Annotation.flagged.is_(True))) or 0 + + # --- Class balance ----------------------------------------------------- + instance_rows = db.execute( + _scope_anns(select(Annotation.class_name, func.count())).group_by(Annotation.class_name) + ).all() + image_rows = db.execute( + _scope_anns( + select(Annotation.class_name, func.count(func.distinct(Annotation.asset_id))) + ).group_by(Annotation.class_name) + ).all() + instance_counts = {(c or "(none)"): n for c, n in instance_rows} + image_counts = {(c or "(none)"): n for c, n in image_rows} + total_annotations = sum(instance_counts.values()) + + nonzero = [n for c, n in instance_counts.items() if c != "(none)" and n > 0] + imbalance_ratio = round(max(nonzero) / min(nonzero), 2) if len(nonzero) >= 1 else None + + # Defined-but-unused classes (declared in the ClassMap, never annotated). + defined_classes: list[str] = [] + ds = db.get(Dataset, dataset_id) + if ds and ds.class_map_id: + cm = db.get(ClassMap, ds.class_map_id) + if cm: + try: + for c in json.loads(cm.classes): + name = c if isinstance(c, str) else c.get("name") + if name: + defined_classes.append(name) + except Exception: + pass + used = {c for c in instance_counts if c != "(none)"} + unused_classes = [c for c in defined_classes if c not in used] + + # --- Annotation type breakdown ---------------------------------------- + type_rows = db.execute( + _scope_anns(select(Annotation.type, func.count())).group_by(Annotation.type) + ).all() + type_counts = {t: n for t, n in type_rows} + + # --- Annotations per image -------------------------------------------- + per_asset_rows = db.execute( + _scope_anns(select(Annotation.asset_id, func.count())).group_by(Annotation.asset_id) + ).all() + per_image_counts = [n for _, n in per_asset_rows] + per_image_hist = {"0": empty_images, "1": 0, "2-5": 0, "6-10": 0, "10+": 0} + for n in per_image_counts: + if n == 1: + per_image_hist["1"] += 1 + elif n <= 5: + per_image_hist["2-5"] += 1 + elif n <= 10: + per_image_hist["6-10"] += 1 + else: + per_image_hist["10+"] += 1 + per_image_mean = round(total_annotations / total_assets, 2) if total_assets else 0.0 + per_image_max = max(per_image_counts) if per_image_counts else 0 + + # --- Box geometry (area + aspect ratio), sampled ---------------------- + geo_rows = db.execute( + _scope_anns(select(Annotation.geometry)) + .where(Annotation.type == "box") + .limit(_GEOMETRY_SAMPLE_CAP + 1) + ).all() + geometry_sampled = len(geo_rows) > _GEOMETRY_SAMPLE_CAP + area_hist = {"small (<32²)": 0, "medium (<96²)": 0, "large (≥96²)": 0} + aspect_hist = {"tall (<0.5)": 0, "square (0.5-2)": 0, "wide (>2)": 0} + for (geom,) in geo_rows[:_GEOMETRY_SAMPLE_CAP]: + wh = _box_wh(geom) + if not wh: + continue + w, h = wh + area = w * h + if area < 32 * 32: + area_hist["small (<32²)"] += 1 + elif area < 96 * 96: + area_hist["medium (<96²)"] += 1 + else: + area_hist["large (≥96²)"] += 1 + ar = w / h + if ar < 0.5: + aspect_hist["tall (<0.5)"] += 1 + elif ar <= 2.0: + aspect_hist["square (0.5-2)"] += 1 + else: + aspect_hist["wide (>2)"] += 1 + + # --- Image resolution -------------------------------------------------- + res_rows = db.execute( + _scope_assets(select(Asset.width, Asset.height)).where(Asset.width.is_not(None)) + ).all() + res_hist = {"<640": 0, "640-1280": 0, "1280-1920": 0, "≥1920": 0} + areas: list[int] = [] + for w, h in res_rows: + if not w or not h: + continue + areas.append(w * h) + m = max(w, h) + if m < 640: + res_hist["<640"] += 1 + elif m < 1280: + res_hist["640-1280"] += 1 + elif m < 1920: + res_hist["1280-1920"] += 1 + else: + res_hist["≥1920"] += 1 + areas.sort() + if areas: + median_area = areas[len(areas) // 2] + resolution = { + "min_pixels": areas[0], + "max_pixels": areas[-1], + "median_pixels": median_area, + "histogram": res_hist, + "with_dimensions": len(areas), + } + else: + resolution = { + "min_pixels": None, + "max_pixels": None, + "median_pixels": None, + "histogram": res_hist, + "with_dimensions": 0, + } + + # --- Labeling velocity (last 30 days) --------------------------------- + since = datetime.now(timezone.utc) - timedelta(days=30) + vel_rows = db.execute( + _scope_anns(select(func.date(Annotation.created_at), func.count())) + .where(Annotation.created_at >= since) + .group_by(func.date(Annotation.created_at)) + ).all() + velocity = [{"date": str(d), "count": n} for d, n in vel_rows if d is not None] + velocity.sort(key=lambda r: r["date"]) + + coverage_pct = round(labeled / total_assets * 100, 1) if total_assets else 0.0 + + return { + "total_assets": total_assets, + "total_annotations": total_annotations, + "coverage_pct": coverage_pct, + "labeled": labeled, + "empty_images": empty_images, + "label_status_distribution": status_counts, + "review": { + "unreviewed": review_counts.get("unreviewed", 0), + "approved": review_counts.get("approved", 0), + "rejected": review_counts.get("rejected", 0), + "flagged": flagged, + }, + "class_balance": { + "instances": instance_counts, + "images": image_counts, + "imbalance_ratio": imbalance_ratio, + "defined_classes": defined_classes, + "unused_classes": unused_classes, + }, + "annotation_types": type_counts, + "per_image": { + "histogram": per_image_hist, + "mean": per_image_mean, + "max": per_image_max, + }, + "box_geometry": { + "area_histogram": area_hist, + "aspect_histogram": aspect_hist, + "sampled": geometry_sampled, + }, + "resolution": resolution, + "velocity": velocity, + "split": None, # train/val/test split not modeled yet + } diff --git a/backend/src/app/services/dataset_service.py b/backend/src/app/services/dataset_service.py index 14ea236..cf681bf 100644 --- a/backend/src/app/services/dataset_service.py +++ b/backend/src/app/services/dataset_service.py @@ -62,3 +62,16 @@ def snapshot_version(db: Session, dataset_id: str, notes: str | None = None) -> db.commit() db.refresh(ver) return ver + + +def latest_open_version(db: Session, dataset_id: str) -> DatasetVersion | None: + """Return the newest editable (unlocked) version, or None if all are locked. + + New imagery and annotations must land in an unlocked version; locked + versions are immutable snapshots. + """ + return db.scalars( + select(DatasetVersion) + .where(DatasetVersion.dataset_id == dataset_id, DatasetVersion.locked.is_(False)) + .order_by(DatasetVersion.version.desc()) + ).first() diff --git a/backend/src/app/services/suggestion_service.py b/backend/src/app/services/suggestion_service.py new file mode 100644 index 0000000..71b954e --- /dev/null +++ b/backend/src/app/services/suggestion_service.py @@ -0,0 +1,161 @@ +"""AI-assisted annotation suggestions. + +Resolves the model trained on an asset's dataset, runs inference on the image, +and returns predicted annotations *without persisting them*. The annotator +overlays these as proposals the user can accept, edit, or reject; accepted ones +are saved through the normal bulk-annotation path. +""" + +from __future__ import annotations + +import os +from dataclasses import dataclass + +from sqlalchemy import select +from sqlalchemy.orm import Session + +from app.models.artifact import ModelArtifact +from app.models.asset import Asset +from app.models.dataset_version import DatasetVersion +from app.models.experiment import ExperimentRun +from app.services import inference_service +from app.services.asset_fetch import fetch_asset_bytes + + +class SuggestionError(Exception): + """Base error for the suggestion flow.""" + + +class NoModelError(SuggestionError): + """No successful model is available for the asset's dataset.""" + + +def _extract_minio_key(uri: str) -> str: + """Extract the object key from an asset URI (mirrors prelabels.py).""" + for prefix in ("s3://", "minio://"): + if uri.startswith(prefix): + parts = uri[len(prefix) :].split("/", 1) + return parts[1] if len(parts) > 1 else uri + if uri.startswith(("http://", "https://")): + path = uri.split("/", 3) + return path[3] if len(path) > 3 else uri + return uri + + +def _load_image_bytes(asset: Asset) -> bytes: + """Load an asset's image bytes from object storage (or an HTTP URL). + + Tries the SSRF-safe HTTP fetcher first for absolute URLs, then falls back to + MinIO using the object key. Raises SuggestionError when the image can't be + retrieved so the router can surface a clear error. + """ + uri = asset.uri or "" + if uri.startswith(("http://", "https://")): + data = fetch_asset_bytes(uri) + if data: + return data + try: + from app.services.storage import get_minio_client + + bucket = os.getenv("MINIO_BUCKET", os.getenv("S3_BUCKET", "visionforge")) + client = get_minio_client() + response = client.get_object(bucket, _extract_minio_key(uri)) + try: + return response.read() + finally: + response.close() + response.release_conn() + except Exception as exc: # pragma: no cover - exercised via monkeypatch in tests + raise SuggestionError(f"could not load image for asset {asset.id}: {exc}") from exc + + +@dataclass +class Suggestion: + type: str # "box" | "classification" + geometry: dict + class_name: str + score: float + + +def candidate_artifacts_for_dataset(db: Session, dataset_id: str) -> list[ModelArtifact]: + """All artifacts from successful runs on any version of this dataset, newest first. + + Drives the annotator's model-override dropdown. + """ + rows = ( + db.execute( + select(ModelArtifact) + .join(ExperimentRun, ModelArtifact.run_id == ExperimentRun.id) + .join(DatasetVersion, ExperimentRun.dataset_version_id == DatasetVersion.id) + .where( + DatasetVersion.dataset_id == dataset_id, + ExperimentRun.status == "succeeded", + ) + .order_by(ModelArtifact.created_at.desc()) + ) + .scalars() + .all() + ) + return list(rows) + + +def latest_artifact_for_dataset(db: Session, dataset_id: str) -> ModelArtifact | None: + arts = candidate_artifacts_for_dataset(db, dataset_id) + return arts[0] if arts else None + + +def suggest_annotations( + db: Session, + *, + asset_id: str, + artifact_id: str | None = None, + score_threshold: float = 0.25, +) -> tuple[ModelArtifact, list[Suggestion]]: + """Run a trained model on an asset and map detections to suggestions. + + Raises NoModelError when no usable artifact exists; SuggestionError for a + missing asset or an artifact that does not belong to the asset's dataset. + """ + asset = db.get(Asset, asset_id) + if not asset: + raise SuggestionError("asset not found") + + if artifact_id: + artifact = db.get(ModelArtifact, artifact_id) + if not artifact: + raise NoModelError("requested model not found") + # The override must be a model trained on this dataset. + allowed = {a.id for a in candidate_artifacts_for_dataset(db, asset.dataset_id)} + if artifact.id not in allowed: + raise SuggestionError("model was not trained on this dataset") + else: + artifact = latest_artifact_for_dataset(db, asset.dataset_id) + if not artifact: + raise NoModelError("no successful model trained on this dataset") + + image_bytes = _load_image_bytes(asset) + result = inference_service.predict(artifact, image_bytes, score_threshold=score_threshold) + + suggestions: list[Suggestion] = [] + for det in result.get("detections", []): + bbox = det.get("bbox") or [0, 0, 0, 0] + x, y, w, h = bbox + suggestions.append( + Suggestion( + type="box", + geometry={"x": x, "y": y, "w": w, "h": h}, + class_name=str(det.get("class", "")), + score=float(det.get("score", 0.0)), + ) + ) + cls = result.get("classification") + if cls: + suggestions.append( + Suggestion( + type="classification", + geometry={"class": cls.get("class", "")}, + class_name=str(cls.get("class", "")), + score=float(cls.get("score", 0.0)), + ) + ) + return artifact, suggestions diff --git a/backend/tests/unit/test_dataset_metrics.py b/backend/tests/unit/test_dataset_metrics.py new file mode 100644 index 0000000..cf49406 --- /dev/null +++ b/backend/tests/unit/test_dataset_metrics.py @@ -0,0 +1,159 @@ +"""Unit tests for asset_service.get_dataset_metrics using SQLite in-memory DB.""" + +from __future__ import annotations + +import json + +import pytest +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker + +from app.db.base import Base + + +def _register_models() -> None: + """Import the ORM models so their tables register on Base.metadata before + create_all (the app conftest does this transitively; we do it explicitly so + the suite also runs standalone).""" + import app.models.annotation # noqa: F401 + import app.models.asset # noqa: F401 + import app.models.dataset # noqa: F401 + import app.models.dataset_version # noqa: F401 + import app.models.project # noqa: F401 + import app.models.workspace # noqa: F401 + + +@pytest.fixture +def db(): + _register_models() + engine = create_engine("sqlite+pysqlite:///:memory:", connect_args={"check_same_thread": False}) + Base.metadata.create_all(engine) + Session = sessionmaker(bind=engine, autoflush=False, autocommit=False) + session = Session() + yield session + session.close() + Base.metadata.drop_all(engine) + + +def _seed_dataset(db, dataset_id="ds-1"): + from app.models.dataset import Dataset + from app.models.dataset_version import DatasetVersion + from app.models.project import Project + from app.models.workspace import Workspace + + db.add(Workspace(id="ws-1", name="WS", created_by="user-1")) + db.add(Project(id="proj-1", workspace_id="ws-1", name="P", slug="p")) + ds = Dataset(id=dataset_id, project_id="proj-1", name="DS") + db.add(ds) + ver = DatasetVersion(id="ver-1", dataset_id=dataset_id, version=1) + db.add(ver) + db.commit() + return ds, ver + + +_ASSET_SEQ = [0] + + +def _seed_asset(db, ds, ver, *, label_status="unlabeled", width=None, height=None): + from app.models.asset import Asset + + _ASSET_SEQ[0] += 1 + aid = f"asset-{_ASSET_SEQ[0]}" + a = Asset( + id=aid, + dataset_id=ds.id, + version_id=ver.id, + uri=f"datasets/v1/{aid}.jpg", + mime_type="image/jpeg", + label_status=label_status, + width=width, + height=height, + ) + db.add(a) + db.commit() + return a + + +_ANN_SEQ = [0] + + +def _seed_annotation(db, asset, *, class_name="cat", geometry=None, ann_type="box"): + from app.models.annotation import Annotation + + _ANN_SEQ[0] += 1 + geom = geometry or json.dumps({"x": 0, "y": 0, "w": 10, "h": 10}) + ann = Annotation( + id=f"ann-{_ANN_SEQ[0]}", + asset_id=asset.id, + type=ann_type, + geometry=geom, + class_name=class_name, + author_id="user-1", + ) + db.add(ann) + db.commit() + return ann + + +def test_get_dataset_metrics(db): + from app.services import asset_service + + ds, ver = _seed_dataset(db) + # a1: labeled, 3 small cat boxes, with dimensions + a1 = _seed_asset(db, ds, ver, label_status="labeled", width=800, height=600) + for _ in range(3): + _seed_annotation(db, a1, class_name="cat") + # a2: labeled, 1 large dog box + a2 = _seed_asset(db, ds, ver, label_status="labeled", width=1920, height=1080) + _seed_annotation( + db, a2, class_name="dog", geometry=json.dumps({"x": 0, "y": 0, "w": 200, "h": 200}) + ) + # a3: empty image (no annotations) + _seed_asset(db, ds, ver, label_status="unlabeled", width=400, height=300) + + m = asset_service.get_dataset_metrics(db, ds.id) + + assert m["total_assets"] == 3 + assert m["total_annotations"] == 4 + assert m["empty_images"] == 1 + assert m["class_balance"]["instances"]["cat"] == 3 + assert m["class_balance"]["instances"]["dog"] == 1 + assert m["class_balance"]["images"]["cat"] == 1 + assert m["class_balance"]["imbalance_ratio"] == 3.0 + assert m["per_image"]["histogram"]["0"] == 1 + assert m["per_image"]["histogram"]["1"] == 1 + assert m["per_image"]["histogram"]["2-5"] == 1 + assert m["per_image"]["max"] == 3 + assert m["box_geometry"]["area_histogram"]["small (<32²)"] == 3 + assert m["box_geometry"]["area_histogram"]["large (≥96²)"] == 1 + assert m["box_geometry"]["sampled"] is False + assert m["resolution"]["with_dimensions"] == 3 + assert m["resolution"]["histogram"]["≥1920"] == 1 + + +def test_get_dataset_metrics_empty(db): + from app.services import asset_service + + ds, _ = _seed_dataset(db) + m = asset_service.get_dataset_metrics(db, ds.id) + assert m["total_assets"] == 0 + assert m["total_annotations"] == 0 + assert m["coverage_pct"] == 0.0 + assert m["class_balance"]["imbalance_ratio"] is None + + +def test_get_dataset_metrics_unused_classes(db): + from app.models.dataset import ClassMap + from app.services import asset_service + + ds, ver = _seed_dataset(db) + cm = ClassMap(id="cm-1", project_id="proj-1", classes=json.dumps(["cat", "dog", "bird"])) + db.add(cm) + ds.class_map_id = "cm-1" + db.add(ds) + db.commit() + a1 = _seed_asset(db, ds, ver, label_status="labeled", width=640, height=480) + _seed_annotation(db, a1, class_name="cat") + + m = asset_service.get_dataset_metrics(db, ds.id) + assert set(m["class_balance"]["unused_classes"]) == {"dog", "bird"} diff --git a/backend/tests/unit/test_suggestion_service.py b/backend/tests/unit/test_suggestion_service.py new file mode 100644 index 0000000..8c616e8 --- /dev/null +++ b/backend/tests/unit/test_suggestion_service.py @@ -0,0 +1,195 @@ +"""Unit tests for suggestion_service. Inference + storage are monkeypatched so +the suite stays hermetic (no real models, no MinIO).""" + +from __future__ import annotations + +from datetime import datetime, timedelta, timezone + +import pytest +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker + +from app.db.base import Base + + +def _register_models() -> None: + """Import the ORM models so their tables register on Base.metadata before + create_all (the app conftest does this transitively; we do it explicitly so + the suite also runs standalone).""" + import app.models.annotation # noqa: F401 + import app.models.artifact # noqa: F401 + import app.models.asset # noqa: F401 + import app.models.dataset # noqa: F401 + import app.models.dataset_version # noqa: F401 + import app.models.experiment # noqa: F401 + import app.models.project # noqa: F401 + import app.models.user # noqa: F401 + import app.models.workspace # noqa: F401 + + +@pytest.fixture +def db(): + _register_models() + engine = create_engine("sqlite+pysqlite:///:memory:", connect_args={"check_same_thread": False}) + Base.metadata.create_all(engine) + Session = sessionmaker(bind=engine, autoflush=False, autocommit=False) + session = Session() + yield session + session.close() + Base.metadata.drop_all(engine) + + +def _seed(db): + """Dataset + two versions + asset; returns ids.""" + from app.models.asset import Asset + from app.models.dataset import Dataset + from app.models.dataset_version import DatasetVersion + from app.models.project import Project + from app.models.workspace import Workspace + + db.add(Workspace(id="ws-1", name="WS", created_by="user-1")) + db.add(Project(id="proj-1", workspace_id="ws-1", name="P", slug="p")) + db.add(Dataset(id="ds-1", project_id="proj-1", name="DS")) + db.add(DatasetVersion(id="ver-1", dataset_id="ds-1", version=1)) + db.add( + Asset( + id="asset-1", + dataset_id="ds-1", + version_id="ver-1", + uri="datasets/v1/a.jpg", + mime_type="image/jpeg", + label_status="unlabeled", + ) + ) + db.commit() + + +def _add_run_and_artifact(db, *, run_id, artifact_id, status, created_offset_days, name="m"): + from app.models.artifact import ModelArtifact + from app.models.experiment import ExperimentRun + + db.add( + ExperimentRun( + id=run_id, + project_id="proj-1", + dataset_version_id="ver-1", + owner_id="user-1", + status=status, + ) + ) + db.add( + ModelArtifact( + id=artifact_id, + project_id="proj-1", + run_id=run_id, + name=name, + version="1", + type="yolo", + format="pt", + storage_path=f"models/{artifact_id}.pt", + created_at=datetime.now(timezone.utc) + timedelta(days=created_offset_days), + ) + ) + db.commit() + + +def test_latest_artifact_picks_newest_succeeded(db): + from app.services import suggestion_service + + _seed(db) + _add_run_and_artifact( + db, run_id="r1", artifact_id="art-old", status="succeeded", created_offset_days=-2 + ) + _add_run_and_artifact( + db, run_id="r2", artifact_id="art-new", status="succeeded", created_offset_days=-1 + ) + # A failed run's artifact must be ignored. + _add_run_and_artifact( + db, run_id="r3", artifact_id="art-fail", status="failed", created_offset_days=0 + ) + + latest = suggestion_service.latest_artifact_for_dataset(db, "ds-1") + assert latest is not None + assert latest.id == "art-new" + + candidates = suggestion_service.candidate_artifacts_for_dataset(db, "ds-1") + assert [a.id for a in candidates] == ["art-new", "art-old"] + + +def test_no_model_raises(db): + from app.services import suggestion_service + + _seed(db) + with pytest.raises(suggestion_service.NoModelError): + suggestion_service.suggest_annotations(db, asset_id="asset-1") + + +def test_suggest_maps_detections(db, monkeypatch): + from app.services import inference_service, suggestion_service + + _seed(db) + _add_run_and_artifact( + db, run_id="r1", artifact_id="art-1", status="succeeded", created_offset_days=-1 + ) + + monkeypatch.setattr(suggestion_service, "_load_image_bytes", lambda asset: b"img") + monkeypatch.setattr( + inference_service, + "predict", + lambda artifact, image_bytes, **kw: { + "detections": [ + {"class": "cat", "bbox": [10, 20, 30, 40], "score": 0.9}, + {"class": "dog", "bbox": [1, 2, 3, 4], "score": 0.5}, + ], + "classification": {"class": "animal", "score": 0.8}, + }, + ) + + artifact, suggestions = suggestion_service.suggest_annotations(db, asset_id="asset-1") + assert artifact.id == "art-1" + boxes = [s for s in suggestions if s.type == "box"] + assert len(boxes) == 2 + assert boxes[0].geometry == {"x": 10, "y": 20, "w": 30, "h": 40} + assert boxes[0].class_name == "cat" + assert boxes[0].score == 0.9 + cls = [s for s in suggestions if s.type == "classification"] + assert len(cls) == 1 + assert cls[0].class_name == "animal" + + +def test_override_must_belong_to_dataset(db, monkeypatch): + from app.models.artifact import ModelArtifact + from app.models.experiment import ExperimentRun + from app.services import suggestion_service + + _seed(db) + _add_run_and_artifact( + db, run_id="r1", artifact_id="art-1", status="succeeded", created_offset_days=-1 + ) + # An artifact from an unrelated run/dataset version. + db.add( + ExperimentRun( + id="r-other", + project_id="proj-1", + dataset_version_id="ver-other", + owner_id="user-1", + status="succeeded", + ) + ) + db.add( + ModelArtifact( + id="art-other", + project_id="proj-1", + run_id="r-other", + name="other", + version="1", + type="yolo", + format="pt", + storage_path="models/other.pt", + ) + ) + db.commit() + + monkeypatch.setattr(suggestion_service, "_load_image_bytes", lambda asset: b"img") + with pytest.raises(suggestion_service.SuggestionError): + suggestion_service.suggest_annotations(db, asset_id="asset-1", artifact_id="art-other") diff --git a/frontend/src/App.jsx b/frontend/src/App.jsx index 339f330..1a063f7 100644 --- a/frontend/src/App.jsx +++ b/frontend/src/App.jsx @@ -12,7 +12,9 @@ import ProjectDashboard from "./pages/projects/[projectId]/index"; import DatasetUpload from "./pages/datasets/upload"; import DatasetVersion from "./pages/datasets/version"; import DatasetsIndex from "./pages/datasets/index"; +import DatasetNew from "./pages/datasets/new"; import DatasetDetail from "./pages/datasets/[datasetId]/index"; +import DatasetMetrics from "./pages/datasets/[datasetId]/metrics"; import DatasetAnnotateGateway from "./pages/datasets/[datasetId]/annotate"; import DatasetReviewQueue from "./pages/datasets/[datasetId]/review"; import ExperimentsIndex from "./pages/experiments/index"; @@ -228,9 +230,11 @@ export default function App() { } /> {/* Datasets */} } /> + } /> } /> } /> } /> + } /> } /> } /> {/* Experiments */} diff --git a/frontend/src/components/charts/BarChart.tsx b/frontend/src/components/charts/BarChart.tsx new file mode 100644 index 0000000..9eafc81 --- /dev/null +++ b/frontend/src/components/charts/BarChart.tsx @@ -0,0 +1,51 @@ +import React from 'react'; +import { colorAt } from '@/components/charts/colors'; + +export interface BarDatum { + label: string; + value: number; + /** Override the auto color (e.g. to flag an imbalanced/empty class). */ + color?: string; +} + +interface BarChartProps { + data: BarDatum[]; + /** Optional unit suffix for the value labels. */ + unit?: string; + maxBars?: number; +} + +/** Horizontal bar chart — used for per-class instance/image counts. */ +export default function BarChart({ data, unit = '', maxBars = 30 }: BarChartProps) { + const rows = [...data].sort((a, b) => b.value - a.value).slice(0, maxBars); + const max = rows.reduce((m, d) => Math.max(m, d.value), 0) || 1; + + if (rows.length === 0) { + return

No data

; + } + + return ( +
+ {rows.map((d, i) => ( +
+ + {d.label} + +
+
+
+ + {d.value} + {unit} + +
+ ))} +
+ ); +} diff --git a/frontend/src/components/charts/DonutChart.tsx b/frontend/src/components/charts/DonutChart.tsx new file mode 100644 index 0000000..475ce9c --- /dev/null +++ b/frontend/src/components/charts/DonutChart.tsx @@ -0,0 +1,101 @@ +import React from 'react'; + +export interface DonutSegment { + label: string; + value: number; + color: string; +} + +interface DonutChartProps { + segments: DonutSegment[]; + size?: number; + centerLabel?: string; + centerValue?: string | number; +} + +/** Donut chart built from stacked stroke-dasharray arcs. */ +export default function DonutChart({ + segments, + size = 140, + centerLabel, + centerValue, +}: DonutChartProps) { + const total = segments.reduce((s, seg) => s + seg.value, 0); + const stroke = 16; + const radius = (size - stroke) / 2; + const circ = 2 * Math.PI * radius; + + let offset = 0; + const arcs = segments + .filter((s) => s.value > 0) + .map((seg) => { + const frac = total > 0 ? seg.value / total : 0; + const dash = frac * circ; + const arc = ( + + ); + offset += dash; + return arc; + }); + + return ( +
+ + + {arcs} + {(centerValue !== undefined || centerLabel) && ( + <> + + {centerValue} + + {centerLabel && ( + + {centerLabel.toUpperCase()} + + )} + + )} + +
+ {segments.map((seg) => ( +
+ + {seg.label} + {seg.value} +
+ ))} +
+
+ ); +} diff --git a/frontend/src/components/charts/Histogram.tsx b/frontend/src/components/charts/Histogram.tsx new file mode 100644 index 0000000..6c5414f --- /dev/null +++ b/frontend/src/components/charts/Histogram.tsx @@ -0,0 +1,45 @@ +import React from 'react'; + +export interface HistBucket { + label: string; + value: number; +} + +interface HistogramProps { + data: HistBucket[]; + color?: string; + height?: number; +} + +/** Vertical bucketed histogram — per-image counts, bbox area, resolution. */ +export default function Histogram({ + data, + color = 'var(--hud-accent)', + height = 120, +}: HistogramProps) { + const max = data.reduce((m, d) => Math.max(m, d.value), 0) || 1; + if (data.length === 0) { + return

No data

; + } + + return ( +
+ {data.map((d) => ( +
+ {d.value} +
0 ? 2 : 0)}px`, + background: color, + minHeight: d.value > 0 ? 2 : 0, + }} + /> + + {d.label} + +
+ ))} +
+ ); +} diff --git a/frontend/src/components/charts/Sparkline.tsx b/frontend/src/components/charts/Sparkline.tsx new file mode 100644 index 0000000..cff2c92 --- /dev/null +++ b/frontend/src/components/charts/Sparkline.tsx @@ -0,0 +1,55 @@ +import React from 'react'; + +interface SparklinePoint { + date: string; + count: number; +} + +interface SparklineProps { + data: SparklinePoint[]; + height?: number; +} + +/** Compact line chart for labeling velocity over time. */ +export default function Sparkline({ data, height = 120 }: SparklineProps) { + if (data.length === 0) { + return

No recent activity

; + } + const w = 500; + const h = height; + const pad = 8; + const max = data.reduce((m, d) => Math.max(m, d.count), 0) || 1; + const n = data.length; + const x = (i: number) => (n === 1 ? w / 2 : pad + (i / (n - 1)) * (w - 2 * pad)); + const y = (v: number) => h - pad - (v / max) * (h - 2 * pad); + const points = data.map((d, i) => `${x(i)},${y(d.count)}`).join(' '); + const total = data.reduce((s, d) => s + d.count, 0); + + return ( +
+ + + {data.map((d, i) => ( + + ))} + +
+ {data[0].date} + {total} annotations / 30d + {data[data.length - 1].date} +
+
+ ); +} diff --git a/frontend/src/components/charts/StatTile.tsx b/frontend/src/components/charts/StatTile.tsx new file mode 100644 index 0000000..5b104b9 --- /dev/null +++ b/frontend/src/components/charts/StatTile.tsx @@ -0,0 +1,31 @@ +import React from 'react'; + +interface StatTileProps { + label: string; + value: string | number; + sub?: string; + tone?: 'default' | 'accent' | 'success' | 'warning' | 'danger'; +} + +const TONE: Record, string> = { + default: 'var(--hud-text-data)', + accent: 'var(--hud-text-accent)', + success: 'var(--hud-success-text)', + warning: 'var(--hud-warning-text)', + danger: 'var(--hud-danger-text)', +}; + +/** Big-number metric tile in the HUD summary-chip style. */ +export default function StatTile({ label, value, sub, tone = 'default' }: StatTileProps) { + return ( +
+
{label}
+
+ {value} +
+ {sub && ( +
{sub}
+ )} +
+ ); +} diff --git a/frontend/src/components/charts/colors.ts b/frontend/src/components/charts/colors.ts new file mode 100644 index 0000000..55b7ddb --- /dev/null +++ b/frontend/src/components/charts/colors.ts @@ -0,0 +1,16 @@ +// Shared HUD-muted chart palette (OKLCH), matching the annotator class colors +// and the experiments metric chart so visuals stay consistent across the app. +export const CHART_COLORS = [ + 'oklch(0.72 0.10 82)', // amber / accent + 'oklch(0.60 0.10 155)', // green + 'oklch(0.68 0.16 20)', // red + 'oklch(0.72 0.08 230)', // blue + 'oklch(0.70 0.10 75)', // gold + 'oklch(0.65 0.10 200)', // teal + 'oklch(0.62 0.10 100)', // olive + 'oklch(0.58 0.12 310)', // violet + 'oklch(0.68 0.10 40)', // orange + 'oklch(0.60 0.08 180)', // cyan +]; + +export const colorAt = (i: number): string => CHART_COLORS[i % CHART_COLORS.length]; diff --git a/frontend/src/components/datasets/ArchiveImporter.tsx b/frontend/src/components/datasets/ArchiveImporter.tsx new file mode 100644 index 0000000..6f1097b --- /dev/null +++ b/frontend/src/components/datasets/ArchiveImporter.tsx @@ -0,0 +1,160 @@ +import React, { useEffect, useState } from 'react'; +import Select from '@/components/ui/Select'; +import Input from '@/components/ui/Input'; +import Button from '@/components/ui/Button'; +import Alert from '@/components/ui/Alert'; +import Spinner from '@/components/ui/Spinner'; +import { apiGet, apiUrl } from '@/services/api'; +import { getStoredToken } from '@/services/token-store'; + +export interface ImportResult { + dataset_id: string; + version_id: string; + format: string; + asset_count: number; + annotation_count: number; + classes: string[]; + warnings: string[]; +} + +interface ArchiveImporterProps { + datasetId: string; + /** Target an existing open version; omit to let the backend create one. */ + versionId?: string; + onComplete?: (result: ImportResult) => void; +} + +const DEFAULT_FORMATS = ['coco', 'yolo', 'pascal_voc', 'cvat', 'labelme', 'datumaro']; + +/** + * Imports a labeled archive (zip) into a dataset via `/api/datasets/{id}/import`. + * Uses a raw multipart fetch (the api.ts helpers are JSON-only). Extracted from + * the legacy import page so it can be embedded in the wizard and detail tab. + */ +export default function ArchiveImporter({ + datasetId, + versionId, + onComplete, +}: ArchiveImporterProps) { + const [formats, setFormats] = useState(DEFAULT_FORMATS); + const [fmt, setFmt] = useState('coco'); + const [imageUriBase, setImageUriBase] = useState(''); + const [file, setFile] = useState(null); + const [result, setResult] = useState(null); + const [error, setError] = useState(null); + const [loading, setLoading] = useState(false); + + useEffect(() => { + apiGet<{ formats: string[] }>('/api/datasets/formats') + .then((r) => r.formats?.length && setFormats(r.formats)) + .catch(() => {}); + }, []); + + async function onSubmit(e: React.FormEvent) { + e.preventDefault(); + if (!datasetId || !file) { + setError('Pick an archive to import'); + return; + } + setLoading(true); + setError(null); + setResult(null); + try { + const fd = new FormData(); + fd.append('file', file); + fd.append('fmt', fmt); + if (versionId) fd.append('version_id', versionId); + if (imageUriBase) fd.append('image_uri_base', imageUriBase); + const token = getStoredToken(); + const res = await fetch(apiUrl(`/api/datasets/${datasetId}/import`), { + method: 'POST', + body: fd, + headers: token ? { Authorization: `Bearer ${token}` } : {}, + }); + if (!res.ok) { + const detail = await res.json().catch(() => null); + throw new Error(detail?.detail || `HTTP ${res.status}`); + } + const json = (await res.json()) as ImportResult; + setResult(json); + onComplete?.(json); + } catch (err) { + setError(err instanceof Error ? err.message : 'Import failed'); + } finally { + setLoading(false); + } + } + + if (result) { + return ( + + Imported {result.asset_count} assets and {result.annotation_count} annotations ( + {result.classes.length} classes) into version{' '} + {result.version_id.slice(0, 8)}. + {result.warnings.length > 0 && ( +
+ {result.warnings.length} warning{result.warnings.length !== 1 ? 's' : ''} +
+ )} +
+ +
+
+ ); + } + + return ( +
+

+ Upload a zip in COCO, YOLO, Pascal VOC, CVAT, LabelMe, or Datumaro format. Images and labels + are added to the open version. +

+
+ + +
+
+ + setImageUriBase(e.target.value)} + placeholder="e.g. datasets//imported/" + /> +
+
+ + setFile(e.target.files?.[0] || null)} + className="text-xs font-mono" + /> +
+ {error && {error}} + +
+ ); +} diff --git a/frontend/src/components/datasets/DataSourcePanel.tsx b/frontend/src/components/datasets/DataSourcePanel.tsx new file mode 100644 index 0000000..f1f835f --- /dev/null +++ b/frontend/src/components/datasets/DataSourcePanel.tsx @@ -0,0 +1,62 @@ +import React, { useState } from 'react'; +import ImageUploader from '@/components/datasets/ImageUploader'; +import ArchiveImporter, { type ImportResult } from '@/components/datasets/ArchiveImporter'; + +type Tab = 'upload' | 'import'; + +interface DataSourcePanelProps { + datasetId: string; + /** Open (unlocked) version that new data should land in. */ + versionId: string; + onUploaded?: (count: number) => void; + onImported?: (result: ImportResult) => void; +} + +/** + * Tabbed data-entry surface: upload new imagery or import a labeled archive. + * Shared by the dataset creation wizard (step 2) and the detail "Data" tab. + */ +export default function DataSourcePanel({ + datasetId, + versionId, + onUploaded, + onImported, +}: DataSourcePanelProps) { + const [tab, setTab] = useState('upload'); + + const tabBtn = (id: Tab, label: string) => { + const active = tab === id; + return ( + + ); + }; + + return ( +
+
+ {tabBtn('upload', 'UPLOAD IMAGES')} + {tabBtn('import', 'IMPORT LABELED ARCHIVE')} +
+
+ {tab === 'upload' ? ( + + ) : ( + + )} +
+
+ ); +} diff --git a/frontend/src/components/datasets/ImageUploader.tsx b/frontend/src/components/datasets/ImageUploader.tsx new file mode 100644 index 0000000..e7c3232 --- /dev/null +++ b/frontend/src/components/datasets/ImageUploader.tsx @@ -0,0 +1,230 @@ +import React, { useRef, useState } from 'react'; +import Button from '@/components/ui/Button'; +import Alert from '@/components/ui/Alert'; +import Spinner from '@/components/ui/Spinner'; +import { apiPost } from '@/services/api'; + +interface UploadEntry { + name: string; + progress: number; + status: 'pending' | 'uploading' | 'done' | 'error'; + error?: string; +} + +interface ImageUploaderProps { + datasetId: string; + versionId: string; + /** Called after a batch finishes with the count of successfully uploaded files. */ + onComplete?: (uploaded: number) => void; +} + +/** + * Drag-and-drop image uploader. Requests a presigned PUT URL per file + * (`/api/ingest/upload-url`), streams the bytes with progress, then registers + * the asset (`/api/ingest/confirm`). Extracted from the legacy upload page so + * the creation wizard and the dataset detail "Data" tab can share it. + */ +export default function ImageUploader({ datasetId, versionId, onComplete }: ImageUploaderProps) { + const [files, setFiles] = useState([]); + const [uploads, setUploads] = useState([]); + const [uploading, setUploading] = useState(false); + const [error, setError] = useState(null); + const [allDone, setAllDone] = useState(false); + const fileInputRef = useRef(null); + + function selectFiles(selected: File[]) { + setFiles(selected); + setUploads([]); + setAllDone(false); + setError(null); + } + + function onDrop(e: React.DragEvent) { + e.preventDefault(); + selectFiles(Array.from(e.dataTransfer.files)); + } + + async function onUpload() { + if (!files.length || !datasetId || !versionId) { + setError('Create a dataset version first before uploading.'); + return; + } + setUploading(true); + setError(null); + setAllDone(false); + + const initial: UploadEntry[] = files.map((f) => ({ + name: f.name, + progress: 0, + status: 'pending', + })); + setUploads(initial); + + let hasError = false; + let uploaded = 0; + for (let i = 0; i < files.length; i++) { + const file = files[i]; + setUploads((prev) => prev.map((u, idx) => (idx === i ? { ...u, status: 'uploading' } : u))); + try { + const { url, objectKey } = await apiPost<{ + url: string; + fields: Record; + objectKey: string; + }>('/api/ingest/upload-url', { + datasetVersionId: versionId, + filename: file.name, + contentType: file.type, + }); + + await new Promise((resolve, reject) => { + const xhr = new XMLHttpRequest(); + xhr.upload.onprogress = (e) => { + if (e.lengthComputable) { + const pct = Math.round((e.loaded / e.total) * 100); + setUploads((prev) => prev.map((u, idx) => (idx === i ? { ...u, progress: pct } : u))); + } + }; + xhr.onload = () => + xhr.status < 300 ? resolve() : reject(new Error(`Upload failed: ${xhr.status}`)); + xhr.onerror = () => reject(new Error('Network error during upload')); + xhr.open('PUT', url); + xhr.setRequestHeader('Content-Type', file.type); + xhr.send(file); + }); + + await apiPost('/api/ingest/confirm', { + dataset_id: datasetId, + version_id: versionId, + storage_key: objectKey, + filename: file.name, + content_type: file.type, + }); + + uploaded += 1; + setUploads((prev) => + prev.map((u, idx) => (idx === i ? { ...u, progress: 100, status: 'done' } : u)), + ); + } catch (err) { + hasError = true; + const msg = err instanceof Error ? err.message : 'Upload failed'; + setUploads((prev) => + prev.map((u, idx) => (idx === i ? { ...u, status: 'error', error: msg } : u)), + ); + setError(msg); + } + } + + setUploading(false); + if (!hasError) setAllDone(true); + onComplete?.(uploaded); + } + + const doneCount = uploads.filter((u) => u.status === 'done').length; + + return ( +
+
e.preventDefault()} + onClick={() => fileInputRef.current?.click()} + className="flex flex-col items-center justify-center border border-dashed border-[var(--hud-border-strong)] bg-[var(--hud-inset)] p-8 cursor-pointer hover:border-[var(--hud-accent)] hover:bg-[var(--hud-elevated)] transition-colors group" + > +
+ +
+

+ Drop images here or click to browse +

+

JPG · PNG · WEBP

+ selectFiles(Array.from(e.target.files || []))} + /> +
+ + {files.length > 0 && uploads.length === 0 && ( +
+ {files.length} file(s) selected +
+ )} + + {uploads.length > 0 && ( +
+ {uploads.map((u, i) => ( +
+
+ {u.name} + + {u.status === 'done' + ? '✓ DONE' + : u.status === 'error' + ? '✗ ERROR' + : u.status === 'uploading' + ? `${u.progress}%` + : 'PENDING'} + +
+
+
+
+ {u.error && ( +

+ {u.error} +

+ )} +
+ ))} +
+ )} + + {error && {error}} + +
+ + {allDone && ( + + ✓ {doneCount} file(s) uploaded + + )} +
+
+ ); +} diff --git a/frontend/src/pages/annotate/Annotator.jsx b/frontend/src/pages/annotate/Annotator.tsx similarity index 69% rename from frontend/src/pages/annotate/Annotator.jsx rename to frontend/src/pages/annotate/Annotator.tsx index edf6c94..2f6f43c 100644 --- a/frontend/src/pages/annotate/Annotator.jsx +++ b/frontend/src/pages/annotate/Annotator.tsx @@ -1,6 +1,16 @@ import React, { useCallback, useEffect, useRef, useState } from 'react'; import { useNavigate, useParams } from 'react-router-dom'; import { apiGet, apiPost, apiDelete, apiUrl } from '@/services/api'; +import { useSuggestions } from './useSuggestions'; +import type { + Annotation, + BoxGeometry, + Mode, + NeighborInfo, + Point, + PointsGeometry, + Suggestion, +} from './types'; // --------------------------------------------------------------------------- // Constants @@ -30,37 +40,37 @@ const HUD = { textMuted: 'oklch(0.42 0.008 240)', textData: 'oklch(0.96 0.003 240)', accent: 'oklch(0.72 0.10 82)', + suggest: 'oklch(0.72 0.08 230)', }; // --------------------------------------------------------------------------- // Utility // --------------------------------------------------------------------------- -function classColor(className, classes) { +function classColor(className: string, classes: string[]): string { const idx = classes.indexOf(className); return CLASS_COLORS[(idx === -1 ? 0 : idx) % CLASS_COLORS.length]; } -function deepClone(value) { +function deepClone(value: T): T { if (typeof structuredClone === 'function') return structuredClone(value); return JSON.parse(JSON.stringify(value)); } -function distSq(a, b) { +function distSq(a: Point, b: Point): number { const dx = a.x - b.x; const dy = a.y - b.y; return dx * dx + dy * dy; } -function pointInPolygon(px, py, points) { +function pointInPolygon(px: number, py: number, points: Point[]): boolean { let inside = false; for (let i = 0, j = points.length - 1; i < points.length; j = i++) { const xi = points[i].x; const yi = points[i].y; const xj = points[j].x; const yj = points[j].y; - const intersect = - yi > py !== yj > py && px < ((xj - xi) * (py - yi)) / (yj - yi + 1e-9) + xi; + const intersect = yi > py !== yj > py && px < ((xj - xi) * (py - yi)) / (yj - yi + 1e-9) + xi; if (intersect) inside = !inside; } return inside; @@ -70,43 +80,57 @@ function pointInPolygon(px, py, points) { // Main component // --------------------------------------------------------------------------- +interface AssetData { + id: string; + dataset_id?: string; + filename?: string; + uri?: string; + download_url?: string; +} + export default function AnnotatorPage() { - const { assetId } = useParams(); + const { assetId } = useParams<{ assetId: string }>(); const navigate = useNavigate(); - const [asset, setAsset] = useState(null); - const [annotations, setAnnotations] = useState([]); - const [classes, setClasses] = useState(['object']); + const [asset, setAsset] = useState(null); + const [annotations, setAnnotations] = useState([]); + const [classes, setClasses] = useState(['object']); const [selectedClass, setSelectedClass] = useState('object'); - const [selectedAnnotationIdx, setSelectedAnnotationIdx] = useState(null); - // mode: 'box' | 'polygon' | 'keypoint' | 'classify' | 'select' - const [mode, setMode] = useState('box'); + const [selectedAnnotationIdx, setSelectedAnnotationIdx] = useState(null); + const [mode, setMode] = useState('box'); const [dirty, setDirty] = useState(false); const [status, setStatus] = useState('Loading…'); const [newClassName, setNewClassName] = useState(''); const [imageError, setImageError] = useState(false); const [scaleFactor, setScaleFactor] = useState(1); - const [neighbors, setNeighbors] = useState({ prev: null, next: null, index: null, total: 0 }); + const [neighbors, setNeighbors] = useState({ + prev: null, + next: null, + index: null, + total: 0, + }); const [datasetName, setDatasetName] = useState(''); const [imgDims, setImgDims] = useState({ w: 0, h: 0 }); const [showAddClass, setShowAddClass] = useState(false); - // In-progress polygon: array of points (image coordinates) - const [polygonInProgress, setPolygonInProgress] = useState([]); + const [polygonInProgress, setPolygonInProgress] = useState([]); + + const suggest = useSuggestions(asset?.dataset_id); // Undo/redo stacks. Each entry is a snapshot of the annotations array. - const undoStack = useRef([]); - const redoStack = useRef([]); + const undoStack = useRef([]); + const redoStack = useRef([]); const drawingRef = useRef({ active: false, startX: 0, startY: 0, currentX: 0, currentY: 0 }); - const canvasRef = useRef(null); - const imageRef = useRef(new Image()); - const annotationsRef = useRef(annotations); - const selectedIdxRef = useRef(selectedAnnotationIdx); + const canvasRef = useRef(null); + const imageRef = useRef(new Image()); + const annotationsRef = useRef(annotations); + const selectedIdxRef = useRef(selectedAnnotationIdx); const scaleRef = useRef(scaleFactor); - const classesRef = useRef(classes); + const classesRef = useRef(classes); const selectedClassRef = useRef(selectedClass); - const modeRef = useRef(mode); - const polygonInProgressRef = useRef(polygonInProgress); + const modeRef = useRef(mode); + const polygonInProgressRef = useRef(polygonInProgress); + const suggestionsRef = useRef(suggest.suggestions); useEffect(() => { annotationsRef.current = annotations; @@ -129,6 +153,9 @@ export default function AnnotatorPage() { useEffect(() => { polygonInProgressRef.current = polygonInProgress; }, [polygonInProgress]); + useEffect(() => { + suggestionsRef.current = suggest.suggestions; + }, [suggest.suggestions]); // ------------------------------------------------------------------------- // History (undo/redo) @@ -140,7 +167,7 @@ export default function AnnotatorPage() { redoStack.current = []; } - function applySnapshot(snapshot) { + function applySnapshot(snapshot: Annotation[]) { annotationsRef.current = snapshot; setAnnotations(snapshot); setSelectedAnnotationIdx(null); @@ -150,7 +177,7 @@ export default function AnnotatorPage() { function undo() { if (undoStack.current.length === 0) return; redoStack.current.push(deepClone(annotationsRef.current)); - const prev = undoStack.current.pop(); + const prev = undoStack.current.pop()!; applySnapshot(prev); setStatus('Undid'); } @@ -158,7 +185,7 @@ export default function AnnotatorPage() { function redo() { if (redoStack.current.length === 0) return; undoStack.current.push(deepClone(annotationsRef.current)); - const next = redoStack.current.pop(); + const next = redoStack.current.pop()!; applySnapshot(next); setStatus('Redid'); } @@ -171,6 +198,7 @@ export default function AnnotatorPage() { const canvas = canvasRef.current; if (!canvas) return; const ctx = canvas.getContext('2d'); + if (!ctx) return; const img = imageRef.current; const sf = scaleRef.current; const anns = annotationsRef.current; @@ -178,6 +206,7 @@ export default function AnnotatorPage() { const cls = classesRef.current; const drawing = drawingRef.current; const polyInProgress = polygonInProgressRef.current; + const sugs = suggestionsRef.current; ctx.clearRect(0, 0, canvas.width, canvas.height); @@ -201,12 +230,12 @@ export default function AnnotatorPage() { ctx.setLineDash([]); if (ann.type === 'box') { - const { x, y, w, h } = ann.geometry; + const { x, y, w, h } = ann.geometry as BoxGeometry; ctx.strokeRect(x * sf, y * sf, w * sf, h * sf); labelTag(ctx, ann.class_name || '', x * sf, y * sf, color); if (isSel) drawHandlesForBox(ctx, x * sf, y * sf, w * sf, h * sf, color); } else if (ann.type === 'polygon') { - const pts = ann.geometry.points || []; + const pts = (ann.geometry as PointsGeometry).points || []; if (pts.length > 1) { ctx.beginPath(); ctx.moveTo(pts[0].x * sf, pts[0].y * sf); @@ -217,25 +246,30 @@ export default function AnnotatorPage() { ctx.fill(); ctx.globalAlpha = 1; ctx.stroke(); - labelTag( - ctx, - ann.class_name || '', - pts[0].x * sf, - pts[0].y * sf, - color - ); + labelTag(ctx, ann.class_name || '', pts[0].x * sf, pts[0].y * sf, color); if (isSel) pts.forEach((p) => drawPointHandle(ctx, p.x * sf, p.y * sf, color)); } } else if (ann.type === 'keypoint') { - const pts = ann.geometry.points || []; + const pts = (ann.geometry as PointsGeometry).points || []; pts.forEach((p, i) => { drawPointHandle(ctx, p.x * sf, p.y * sf, color, isSel ? 6 : 4); - if (i === 0) - labelTag(ctx, ann.class_name || '', p.x * sf - 2, p.y * sf, color); + if (i === 0) labelTag(ctx, ann.class_name || '', p.x * sf - 2, p.y * sf, color); }); } }); + // Suggestions — dashed overlays with a score label. + sugs.forEach((s) => { + if (s.type !== 'box') return; + const { x, y, w, h } = s.geometry as BoxGeometry; + ctx.strokeStyle = HUD.suggest; + ctx.lineWidth = 1.5; + ctx.setLineDash([6, 4]); + ctx.strokeRect(x * sf, y * sf, w * sf, h * sf); + ctx.setLineDash([]); + labelTag(ctx, `${s.class_name} ${s.score.toFixed(2)}`, x * sf, y * sf, HUD.suggest); + }); + // In-progress polygon if (modeRef.current === 'polygon' && polyInProgress.length > 0) { ctx.strokeStyle = HUD.accent; @@ -247,9 +281,7 @@ export default function AnnotatorPage() { ctx.lineTo(polyInProgress[i].x * sf, polyInProgress[i].y * sf); ctx.stroke(); ctx.setLineDash([]); - polyInProgress.forEach((p) => - drawPointHandle(ctx, p.x * sf, p.y * sf, HUD.accent, 4) - ); + polyInProgress.forEach((p) => drawPointHandle(ctx, p.x * sf, p.y * sf, HUD.accent, 4)); } // In-progress box drag @@ -263,7 +295,13 @@ export default function AnnotatorPage() { } }, []); - function labelTag(ctx, label, x, y, color) { + function labelTag( + ctx: CanvasRenderingContext2D, + label: string, + x: number, + y: number, + color: string, + ) { if (!label) return; ctx.font = '10px monospace'; const w = ctx.measureText(label).width + 8; @@ -275,7 +313,13 @@ export default function AnnotatorPage() { ctx.fillText(label, x + 4, y - 3); } - function drawPointHandle(ctx, x, y, color, size = 5) { + function drawPointHandle( + ctx: CanvasRenderingContext2D, + x: number, + y: number, + color: string, + size = 5, + ) { ctx.fillStyle = HUD.textData; ctx.strokeStyle = color; ctx.lineWidth = 1; @@ -285,7 +329,14 @@ export default function AnnotatorPage() { ctx.stroke(); } - function drawHandlesForBox(ctx, cx, cy, cw, ch, color) { + function drawHandlesForBox( + ctx: CanvasRenderingContext2D, + cx: number, + cy: number, + cw: number, + ch: number, + color: string, + ) { const handles = [ [cx, cy], [cx + cw / 2, cy], @@ -312,6 +363,7 @@ export default function AnnotatorPage() { setAnnotations([]); setSelectedAnnotationIdx(null); setPolygonInProgress([]); + suggest.clear(); undoStack.current = []; redoStack.current = []; setDirty(false); @@ -319,19 +371,20 @@ export default function AnnotatorPage() { async function load() { try { - const assetData = await apiGet(`/api/assets/${assetId}`); + const assetData = await apiGet(`/api/assets/${assetId}`); setAsset(assetData); - // Pre-populate the class sidebar from the dataset's ClassMap so - // annotators don't have to re-add classes every session. try { if (assetData.dataset_id) { - const dataset = await apiGet(`/api/datasets/${assetData.dataset_id}`); + const dataset = await apiGet<{ + name?: string; + classes?: Array; + }>(`/api/datasets/${assetData.dataset_id}`); if (dataset?.name) setDatasetName(dataset.name); const seeded = Array.isArray(dataset.classes) ? dataset.classes : []; const names = seeded .map((c) => (typeof c === 'string' ? c : c?.name)) - .filter(Boolean); + .filter((n): n is string => Boolean(n)); if (names.length > 0) { setClasses(names); setSelectedClass(names[0]); @@ -342,9 +395,13 @@ export default function AnnotatorPage() { } try { - const annsData = await apiGet(`/api/assets/${assetId}/annotations`); - const loaded = Array.isArray(annsData) ? annsData : annsData.items ?? []; - setAnnotations(loaded.map((a) => ({ ...a, isNew: false, dirty: false }))); + const annsData = await apiGet( + `/api/assets/${assetId}/annotations`, + ); + const loaded = Array.isArray(annsData) ? annsData : (annsData.items ?? []); + setAnnotations( + loaded.map((a) => ({ ...a, isNew: false, dirty: false, origin: 'manual' })), + ); const existingClasses = loaded.map((a) => a.class_name).filter(Boolean); if (existingClasses.length > 0) { setClasses((prev) => Array.from(new Set([...prev, ...existingClasses]))); @@ -354,7 +411,7 @@ export default function AnnotatorPage() { } try { - const n = await apiGet(`/api/assets/${assetId}/neighbors`); + const n = await apiGet(`/api/assets/${assetId}/neighbors`); setNeighbors(n); } catch { setNeighbors({ prev: null, next: null, index: null, total: 0 }); @@ -366,11 +423,7 @@ export default function AnnotatorPage() { img.onload = () => { const canvas = canvasRef.current; if (!canvas) return; - const sf = Math.min( - MAX_CANVAS_W / img.naturalWidth, - MAX_CANVAS_H / img.naturalHeight, - 1 - ); + const sf = Math.min(MAX_CANVAS_W / img.naturalWidth, MAX_CANVAS_H / img.naturalHeight, 1); canvas.width = Math.round(img.naturalWidth * sf); canvas.height = Math.round(img.naturalHeight * sf); setImgDims({ w: img.naturalWidth, h: img.naturalHeight }); @@ -387,47 +440,54 @@ export default function AnnotatorPage() { canvas.height = MAX_CANVAS_H; redraw(); }; - img.src = - assetData.download_url || assetData.uri || apiUrl(`/api/assets/${assetId}/file`); + img.src = assetData.download_url || assetData.uri || apiUrl(`/api/assets/${assetId}/file`); } catch (err) { - setStatus(`Error: ${err.message}`); + setStatus(`Error: ${err instanceof Error ? err.message : 'failed to load'}`); } } load(); + // eslint-disable-next-line react-hooks/exhaustive-deps }, [assetId, redraw]); useEffect(() => { redraw(); - }, [annotations, selectedAnnotationIdx, scaleFactor, polygonInProgress, redraw]); + }, [ + annotations, + selectedAnnotationIdx, + scaleFactor, + polygonInProgress, + suggest.suggestions, + redraw, + ]); // ------------------------------------------------------------------------- // Mouse handlers // ------------------------------------------------------------------------- - function canvasCoords(e) { - const canvas = canvasRef.current; + function canvasCoords(e: React.MouseEvent): Point { + const canvas = canvasRef.current!; const rect = canvas.getBoundingClientRect(); return { x: e.clientX - rect.left, y: e.clientY - rect.top }; } - function imageCoords(canvasX, canvasY) { + function imageCoords(canvasX: number, canvasY: number): Point { const sf = scaleRef.current; return { x: canvasX / sf, y: canvasY / sf }; } - function hitTestAnnotation(canvasX, canvasY) { + function hitTestAnnotation(canvasX: number, canvasY: number): number | null { const anns = annotationsRef.current; const sf = scaleRef.current; const { x: ix, y: iy } = imageCoords(canvasX, canvasY); for (let i = anns.length - 1; i >= 0; i--) { const ann = anns[i]; if (ann.type === 'box') { - const { x, y, w, h } = ann.geometry; + const { x, y, w, h } = ann.geometry as BoxGeometry; if (ix >= x && ix <= x + w && iy >= y && iy <= y + h) return i; } else if (ann.type === 'polygon') { - if (pointInPolygon(ix, iy, ann.geometry.points || [])) return i; + if (pointInPolygon(ix, iy, (ann.geometry as PointsGeometry).points || [])) return i; } else if (ann.type === 'keypoint') { - const pts = ann.geometry.points || []; + const pts = (ann.geometry as PointsGeometry).points || []; for (const p of pts) { if (Math.sqrt(distSq({ x: p.x * sf, y: p.y * sf }, { x: canvasX, y: canvasY })) < 8) return i; @@ -437,7 +497,7 @@ export default function AnnotatorPage() { return null; } - function handleMouseDown(e) { + function handleMouseDown(e: React.MouseEvent) { const { x, y } = canvasCoords(e); const m = modeRef.current; if (m === 'box') { @@ -446,10 +506,11 @@ export default function AnnotatorPage() { } else if (m === 'polygon') { const ip = imageCoords(x, y); const pts = polygonInProgressRef.current; - // double-click logic: finalize when clicking near the first point with >=3 pts if ( pts.length >= 3 && - Math.sqrt(distSq({ x: pts[0].x * scaleRef.current, y: pts[0].y * scaleRef.current }, { x, y })) < 10 + Math.sqrt( + distSq({ x: pts[0].x * scaleRef.current, y: pts[0].y * scaleRef.current }, { x, y }), + ) < 10 ) { finalizePolygon(); return; @@ -461,7 +522,7 @@ export default function AnnotatorPage() { } else if (m === 'keypoint') { const ip = imageCoords(x, y); pushHistory(); - const newAnn = { + const newAnn: Annotation = { id: null, type: 'keypoint', class_name: selectedClassRef.current, @@ -469,10 +530,11 @@ export default function AnnotatorPage() { isNew: true, dirty: true, version: 0, + origin: 'manual', }; let newIdx = 0; setAnnotations((prev) => { - newIdx = prev.length; // index of the appended item in the new array + newIdx = prev.length; const updated = [...prev, newAnn]; annotationsRef.current = updated; return updated; @@ -485,7 +547,7 @@ export default function AnnotatorPage() { } } - function handleMouseMove(e) { + function handleMouseMove(e: React.MouseEvent) { if (!drawingRef.current.active) return; const { x, y } = canvasCoords(e); drawingRef.current = { ...drawingRef.current, currentX: x, currentY: y }; @@ -506,13 +568,12 @@ export default function AnnotatorPage() { const imgW = rawW / sf; const imgH = rawH / sf; if (imgW < 5 || imgH < 5) { - // Treat a non-drag as a click: select whatever box is under the cursor. setSelectedAnnotationIdx(hitTestAnnotation(currentX, currentY)); redraw(); return; } pushHistory(); - const newAnn = { + const newAnn: Annotation = { id: null, type: 'box', class_name: selectedClassRef.current, @@ -520,6 +581,7 @@ export default function AnnotatorPage() { isNew: true, dirty: true, version: 0, + origin: 'manual', }; let newIdx = 0; setAnnotations((prev) => { @@ -540,7 +602,7 @@ export default function AnnotatorPage() { return; } pushHistory(); - const newAnn = { + const newAnn: Annotation = { id: null, type: 'polygon', class_name: selectedClassRef.current, @@ -548,6 +610,7 @@ export default function AnnotatorPage() { isNew: true, dirty: true, version: 0, + origin: 'manual', }; let newIdx = 0; setAnnotations((prev) => { @@ -568,8 +631,9 @@ export default function AnnotatorPage() { // ------------------------------------------------------------------------- useEffect(() => { - function onKeyDown(e) { - if (e.target.tagName === 'INPUT' || e.target.tagName === 'TEXTAREA') return; + function onKeyDown(e: KeyboardEvent) { + const target = e.target as HTMLElement; + if (target.tagName === 'INPUT' || target.tagName === 'TEXTAREA') return; if ((e.ctrlKey || e.metaKey) && e.key.toLowerCase() === 'z') { e.preventDefault(); if (e.shiftKey) redo(); @@ -601,19 +665,20 @@ export default function AnnotatorPage() { else if (e.key === 'p' || e.key === 'P') setMode('polygon'); else if (e.key === 'k' || e.key === 'K') setMode('keypoint'); else if (e.key === 'v' || e.key === 'V') setMode('select'); - else if (e.key === 'c' || e.key === 'C') setMode('classify'); + else if (e.key === 'c' || e.key === 'C') setMode('classification'); else if (e.key === 'ArrowLeft' && neighbors.prev) navigateAsset(-1); else if (e.key === 'ArrowRight' && neighbors.next) navigateAsset(1); } window.addEventListener('keydown', onKeyDown); return () => window.removeEventListener('keydown', onKeyDown); + // eslint-disable-next-line react-hooks/exhaustive-deps }, [neighbors.prev, neighbors.next]); // ------------------------------------------------------------------------- // Annotation CRUD // ------------------------------------------------------------------------- - function deleteAnnotation(idx) { + function deleteAnnotation(idx: number) { pushHistory(); setAnnotations((prev) => { const ann = prev[idx]; @@ -627,11 +692,11 @@ export default function AnnotatorPage() { setDirty(true); } - function setAnnotationClass(idx, className) { + function setAnnotationClass(idx: number, className: string) { pushHistory(); setAnnotations((prev) => { const updated = prev.map((a, i) => - i === idx ? { ...a, class_name: className, dirty: true } : a + i === idx ? { ...a, class_name: className, dirty: true } : a, ); annotationsRef.current = updated; return updated; @@ -639,11 +704,11 @@ export default function AnnotatorPage() { setDirty(true); } - function applyClassification(className) { + function applyClassification(className: string) { pushHistory(); setAnnotations((prev) => { const withoutClassify = prev.filter((a) => a.type !== 'classification'); - const newAnn = { + const newAnn: Annotation = { id: null, type: 'classification', class_name: className, @@ -651,6 +716,7 @@ export default function AnnotatorPage() { isNew: true, dirty: true, version: 0, + origin: 'manual', }; const updated = [...withoutClassify, newAnn]; annotationsRef.current = updated; @@ -661,12 +727,102 @@ export default function AnnotatorPage() { setStatus(`Classification → "${className}"`); } - async function handleSave() { + // ------------------------------------------------------------------------- + // Suggestions + // ------------------------------------------------------------------------- + + async function runSuggest() { + if (!assetId) return; + setStatus('Requesting suggestions…'); + try { + const count = await suggest.fetchSuggestions(assetId); + setStatus(count > 0 ? `${count} suggestion(s) ready` : 'No suggestions returned'); + } catch (err) { + setStatus(`Suggest failed: ${err instanceof Error ? err.message : 'error'}`); + } + } + + function acceptSuggestion(tempId: string, alsoSelect = false) { + const s = suggestionsRef.current.find((x) => x.tempId === tempId); + if (!s) return; + pushHistory(); + const newAnn: Annotation = { + id: null, + type: s.type, + class_name: s.class_name || selectedClassRef.current, + geometry: s.geometry, + isNew: true, + dirty: true, + version: 0, + origin: 'suggested', + }; + let newIdx = 0; + setAnnotations((prev) => { + newIdx = prev.length; + const updated = [...prev, newAnn]; + annotationsRef.current = updated; + return updated; + }); + if (s.class_name && !classesRef.current.includes(s.class_name)) { + setClasses((prev) => Array.from(new Set([...prev, s.class_name]))); + } + suggest.setSuggestions((prev) => prev.filter((x) => x.tempId !== tempId)); + setDirty(true); + if (alsoSelect) { + setMode('select'); + setSelectedAnnotationIdx(newIdx); + } + } + + function rejectSuggestion(tempId: string) { + suggest.setSuggestions((prev) => prev.filter((x) => x.tempId !== tempId)); + } + + function acceptAllSuggestions() { + const all = suggestionsRef.current; + if (all.length === 0) return; + pushHistory(); + const newAnns: Annotation[] = all.map((s) => ({ + id: null, + type: s.type, + class_name: s.class_name || selectedClassRef.current, + geometry: s.geometry, + isNew: true, + dirty: true, + version: 0, + origin: 'suggested', + })); + setAnnotations((prev) => { + const updated = [...prev, ...newAnns]; + annotationsRef.current = updated; + return updated; + }); + const newClasses = all.map((s) => s.class_name).filter(Boolean); + if (newClasses.length > 0) { + setClasses((prev) => Array.from(new Set([...prev, ...newClasses]))); + } + suggest.clear(); + setDirty(true); + setStatus(`Accepted ${newAnns.length} suggestion(s)`); + } + + const handleSave = useCallback(async () => { if (!assetId) return; setStatus('Saving…'); try { - const creates = []; - const updates = []; + const creates: Array<{ + client_id: string; + asset_id: string; + type: string; + geometry: unknown; + class_name: string; + }> = []; + const updates: Array<{ + id: string; + geometry: unknown; + class_name: string; + expected_version: number; + }> = []; const current = annotationsRef.current; current.forEach((ann, idx) => { if (ann.isNew || !ann.id) { @@ -691,19 +847,25 @@ export default function AnnotatorPage() { setStatus('Nothing to save'); return; } - const result = await apiPost('/api/annotations/bulk', { + const result = await apiPost<{ + created?: Array<{ client_id?: string; status: string; annotation?: Annotation }>; + updated?: Array<{ id: string; status: string; annotation?: Annotation }>; + }>('/api/annotations/bulk', { creates, updates, deletes: [], }); - const createdById = new Map(); + const createdById = new Map(); (result.created || []).forEach((row) => { if (row.client_id && row.status === 'ok' && row.annotation) { createdById.set(row.client_id, row.annotation); } }); - const updatedById = new Map(); + const updatedById = new Map< + string, + { id: string; status: string; annotation?: Annotation } + >(); (result.updated || []).forEach((row) => { if (row.id) updatedById.set(row.id, row); }); @@ -712,13 +874,7 @@ export default function AnnotatorPage() { if (ann.isNew || !ann.id) { const created = createdById.get(`idx-${idx}`); if (created) { - return { - ...ann, - id: created.id, - version: created.version, - isNew: false, - dirty: false, - }; + return { ...ann, id: created.id, version: created.version, isNew: false, dirty: false }; } return ann; } @@ -750,9 +906,10 @@ export default function AnnotatorPage() { setStatus(`Saved with ${errors.length} issue(s) — see status`); } } catch (err) { - setStatus(`Save failed: ${err.message}`); + setStatus(`Save failed: ${err instanceof Error ? err.message : 'error'}`); } - } + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [assetId]); function addClass() { const trimmed = newClassName.trim(); @@ -762,7 +919,7 @@ export default function AnnotatorPage() { setNewClassName(''); } - function navigateAsset(direction) { + function navigateAsset(direction: number) { if (dirty && !window.confirm('You have unsaved changes. Leave anyway?')) return; const target = direction === 1 ? neighbors.next : neighbors.prev; if (target) { @@ -783,16 +940,15 @@ export default function AnnotatorPage() { const classificationAnn = annotations.find((a) => a.type === 'classification'); const visibleAnns = annotations.filter((a) => a.type !== 'classification'); - // Normalized 0–1 area of a shape, as a percentage of the frame. - function annoArea(ann) { + function annoArea(ann: Annotation): number { const W = imgDims.w || 1; const H = imgDims.h || 1; if (ann.type === 'box') { - const { w, h } = ann.geometry; + const { w, h } = ann.geometry as BoxGeometry; return ((w * h) / (W * H)) * 100; } if (ann.type === 'polygon') { - const pts = ann.geometry.points || []; + const pts = (ann.geometry as PointsGeometry).points || []; let a = 0; for (let i = 0, j = pts.length - 1; i < pts.length; j = i++) { a += pts[j].x * pts[i].y - pts[i].x * pts[j].y; @@ -804,8 +960,7 @@ export default function AnnotatorPage() { const frameNo = (neighbors.index ?? 0) + 1; const frameTotal = neighbors.total || frameNo; - const fileName = - asset?.filename || asset?.uri?.split('/').pop() || `frame_${frameNo}`; + const fileName = asset?.filename || asset?.uri?.split('/').pop() || `frame_${frameNo}`; const frameLabel = imgDims.w ? `${fileName} · ${imgDims.w}×${imgDims.h}` : fileName; function handleExit() { @@ -814,11 +969,11 @@ export default function AnnotatorPage() { else navigate(-1); } - const TOOLS = [ + const TOOLS: { id: Mode; label: string; key: string }[] = [ { id: 'box', label: 'BOX', key: 'B' }, { id: 'polygon', label: 'POLYGON', key: 'P' }, { id: 'keypoint', label: 'KEYPOINT', key: 'K' }, - { id: 'classify', label: 'CLASSIFY', key: 'C' }, + { id: 'classification', label: 'CLASSIFY', key: 'C' }, ]; const navBtn = @@ -852,7 +1007,9 @@ export default function AnnotatorPage() { ← EXIT - {'//'} {datasetName || 'dataset'} + + {'//'} {datasetName || 'dataset'} +
{/* Center — tools */} @@ -934,19 +1091,19 @@ export default function AnnotatorPage() { {classes.map((cls, i) => { const color = CLASS_COLORS[i % CLASS_COLORS.length]; const active = - mode === 'classify' + mode === 'classification' ? classificationAnn?.class_name === cls : selectedClass === cls; return ( ); })} - {mode === 'classify' && classificationAnn && ( + {mode === 'classification' && classificationAnn && (

✓ {classificationAnn.class_name}

@@ -1014,6 +1167,59 @@ export default function AnnotatorPage() { {/* Center — canvas + frame nav */}
+ {/* Suggestion bar */} +
+ + + {suggest.suggestions.length > 0 && ( + <> + + {suggest.suggestions.length} pending + + + + + )} + {!suggest.hasModel && ( + + Train a model on this dataset to enable AI suggestions + + )} +
+
- {/* Right — Annotations */} + {/* Right — Annotations + Suggestions */}
+ {suggest.suggestions.length > 0 && ( +
+
+ Suggestions + + {suggest.suggestions.length} + +
+
+ {suggest.suggestions.map((s) => ( +
+ +
+
+ {s.class_name} +
+
+ {(s.score * 100).toFixed(0)}% conf +
+
+ + + +
+ ))} +
+
+ )}
Annotations @@ -1138,6 +1396,9 @@ export default function AnnotatorPage() {
{ann.class_name} + {ann.origin === 'suggested' && ( + + )}
{annoArea(ann).toFixed(1)}% area diff --git a/frontend/src/pages/annotate/types.ts b/frontend/src/pages/annotate/types.ts new file mode 100644 index 0000000..7c97f94 --- /dev/null +++ b/frontend/src/pages/annotate/types.ts @@ -0,0 +1,59 @@ +// Shared types for the annotator. + +export type AnnotationType = 'box' | 'polygon' | 'keypoint' | 'classification'; +export type Mode = AnnotationType | 'select'; + +export interface Point { + x: number; + y: number; +} + +export interface BoxGeometry { + x: number; + y: number; + w: number; + h: number; +} +export interface PointsGeometry { + points: Point[]; +} +export interface ClassGeometry { + class: string; +} +export type Geometry = BoxGeometry | PointsGeometry | ClassGeometry; + +export interface Annotation { + id: string | null; + type: AnnotationType; + class_name: string; + geometry: Geometry; + version: number; + isNew: boolean; + dirty: boolean; + /** Marks annotations that originated from an accepted model suggestion. */ + origin?: 'manual' | 'suggested'; +} + +export interface Suggestion { + /** Stable client-side id, e.g. `sug-3`. */ + tempId: string; + type: AnnotationType; + class_name: string; + geometry: Geometry; + score: number; +} + +export interface SuggestModel { + id: string; + name: string; + version: string; + format?: string | null; + created_at?: string | null; +} + +export interface NeighborInfo { + prev: string | null; + next: string | null; + index: number | null; + total: number; +} diff --git a/frontend/src/pages/annotate/useSuggestions.ts b/frontend/src/pages/annotate/useSuggestions.ts new file mode 100644 index 0000000..09694ca --- /dev/null +++ b/frontend/src/pages/annotate/useSuggestions.ts @@ -0,0 +1,101 @@ +import { useCallback, useEffect, useState } from 'react'; +import { apiGet, apiPost } from '@/services/api'; +import type { Suggestion, SuggestModel } from './types'; + +interface SuggestResponse { + artifact: { id: string; name: string; version: string }; + suggestions: { + type: Suggestion['type']; + geometry: Suggestion['geometry']; + class_name: string; + score: number; + }[]; +} + +interface UseSuggestionsResult { + models: SuggestModel[]; + selectedModelId: string; + setSelectedModelId: (id: string) => void; + suggestions: Suggestion[]; + setSuggestions: React.Dispatch>; + loading: boolean; + error: string | null; + hasModel: boolean; + fetchSuggestions: (assetId: string) => Promise; + clear: () => void; +} + +/** + * Loads the models trained on a dataset (newest first, auto-selecting the + * latest) and fetches one-click annotation suggestions for an asset. Returned + * suggestions are kept separate from committed annotations until accepted. + */ +export function useSuggestions(datasetId: string | undefined): UseSuggestionsResult { + const [models, setModels] = useState([]); + const [selectedModelId, setSelectedModelId] = useState(''); + const [suggestions, setSuggestions] = useState([]); + const [loading, setLoading] = useState(false); + const [error, setError] = useState(null); + + useEffect(() => { + if (!datasetId) return; + let cancelled = false; + apiGet<{ items: SuggestModel[] }>(`/api/annotations/suggest/artifacts?dataset_id=${datasetId}`) + .then((d) => { + if (cancelled) return; + const items = d.items || []; + setModels(items); + if (items.length > 0) setSelectedModelId(items[0].id); + }) + .catch(() => { + if (!cancelled) setModels([]); + }); + return () => { + cancelled = true; + }; + }, [datasetId]); + + const fetchSuggestions = useCallback( + async (assetId: string): Promise => { + setLoading(true); + setError(null); + try { + const res = await apiPost('/api/annotations/suggest', { + asset_id: assetId, + artifact_id: selectedModelId || undefined, + }); + const mapped: Suggestion[] = res.suggestions.map((s, i) => ({ + tempId: `sug-${Date.now()}-${i}`, + type: s.type, + class_name: s.class_name, + geometry: s.geometry, + score: s.score, + })); + setSuggestions(mapped); + return mapped.length; + } catch (err) { + const msg = err instanceof Error ? err.message : 'Suggestion failed'; + setError(msg); + throw new Error(msg); + } finally { + setLoading(false); + } + }, + [selectedModelId], + ); + + const clear = useCallback(() => setSuggestions([]), []); + + return { + models, + selectedModelId, + setSelectedModelId, + suggestions, + setSuggestions, + loading, + error, + hasModel: models.length > 0, + fetchSuggestions, + clear, + }; +} diff --git a/frontend/src/pages/datasets/[datasetId]/index.tsx b/frontend/src/pages/datasets/[datasetId]/index.tsx index aecb00a..7feac76 100644 --- a/frontend/src/pages/datasets/[datasetId]/index.tsx +++ b/frontend/src/pages/datasets/[datasetId]/index.tsx @@ -4,6 +4,7 @@ import Badge from '@/components/ui/Badge'; import Button from '@/components/ui/Button'; import Loading from '@/components/common/Loading'; import ErrorState from '@/components/common/ErrorState'; +import DataSourcePanel from '@/components/datasets/DataSourcePanel'; import { apiGet, apiPost } from '@/services/api'; interface DatasetVersion { @@ -22,6 +23,7 @@ interface DatasetDetail { project_id?: string; classes: Array; versions: DatasetVersion[]; + open_version_id?: string | null; created_at?: string; } @@ -32,6 +34,7 @@ export default function DatasetDetail() { const [error, setError] = useState(null); const [snapshotLoading, setSnapshotLoading] = useState(false); const [snapshotMsg, setSnapshotMsg] = useState(null); + const [showData, setShowData] = useState(false); function reload() { if (!datasetId) return; @@ -59,21 +62,26 @@ export default function DatasetDetail() { } } - if (loading) return
; + if (loading) + return ( +
+ +
+ ); if (error) return ; if (!dataset) return ; const latestVersion = dataset.versions[0]; - const classNames = dataset.classes.map((c) => - typeof c === 'string' ? c : c.name - ); + const classNames = dataset.classes.map((c) => (typeof c === 'string' ? c : c.name)); return (
{/* Header */}
@@ -81,15 +89,27 @@ export default function DatasetDetail() {

{dataset.name} - {latestVersion && ( - v{latestVersion.version} - )} + {latestVersion && v{latestVersion.version}}

{dataset.description && ( -

{dataset.description}

+

+ {dataset.description} +

)}
+ + + METRICS + {latestVersion && ( <> + {showData && ( +
+
Add Data
+ {dataset.open_version_id ? ( + reload()} + onImported={() => reload()} + /> + ) : ( +
+ No unlocked version available to add data to. +
+ )} +
+ )} +
{/* Version History */}
@@ -135,21 +173,22 @@ export default function DatasetDetail() { ) : (
{dataset.versions.map((v, idx) => ( -
+
v{v.version} - {idx === 0 && ( - LATEST - )} - {v.locked && ( - LOCKED - )} + {idx === 0 && LATEST} + {v.locked && LOCKED}
{v.notes && ( -

{v.notes}

+

+ {v.notes} +

)} {v.created_at && (

@@ -223,10 +262,17 @@ export default function DatasetDetail() { { label: 'ID', value: {dataset.id} }, { label: 'Versions', value: dataset.versions.length }, { label: 'Total Assets', value: latestVersion?.asset_count ?? 0 }, - ...(dataset.created_at ? [{ label: 'Created', value: new Date(dataset.created_at).toLocaleDateString() }] : []), + ...(dataset.created_at + ? [{ label: 'Created', value: new Date(dataset.created_at).toLocaleDateString() }] + : []), ].map(({ label, value }) => ( -

- {label} +
+ + {label} + {value}
))} @@ -277,13 +323,7 @@ interface AssetSummary { label_status: string; } -function FrameExtractionCard({ - datasetId, - versionId, -}: { - datasetId: string; - versionId: string; -}) { +function FrameExtractionCard({ datasetId, versionId }: { datasetId: string; versionId: string }) { const [videos, setVideos] = useState([]); const [running, setRunning] = useState(null); const [error, setError] = useState(null); @@ -338,8 +378,8 @@ function FrameExtractionCard({
Video → Frames

- Extract one frame every N seconds - and persist as new image assets. + Extract one frame every N seconds and + persist as new image assets.