diff --git a/.gitignore b/.gitignore
index 7a3bef7..84d5282 100644
--- a/.gitignore
+++ b/.gitignore
@@ -25,3 +25,4 @@ site/
 archive/
 install.sh
 *.code-workspace
+.claude/
diff --git a/src/megadetector_ai/__init__.py b/src/megadetector_ai/__init__.py
new file mode 100755
index 0000000..33dfa80
--- /dev/null
+++ b/src/megadetector_ai/__init__.py
@@ -0,0 +1,20 @@
+"""
+MegaDetector: AI-powered wildlife detection for camera trap images.
+
+This package provides a simplified interface to MegaDetector models
+via PyTorch Wildlife. MegaDetector detects animals, people, and vehicles
+in camera trap images.
+
+Quick start:
+    >>> from megadetector_ai import MegaDetectorV6
+    >>> model = MegaDetectorV6()
+    >>> results = model.single_image_detection("image.jpg")
+
+For more information, visit https://github.com/microsoft/MegaDetector
+"""
+
+__version__ = "0.1.0"
+
+from megadetector_ai.detector import MegaDetectorV6, MegaDetectorV5
+
+__all__ = ["MegaDetectorV6", "MegaDetectorV5"]
diff --git a/src/megadetector_ai/cli.py b/src/megadetector_ai/cli.py
new file mode 100755
index 0000000..09a389c
--- /dev/null
+++ b/src/megadetector_ai/cli.py
@@ -0,0 +1,232 @@
+"""
+MegaDetector command-line interface.
+
+Usage:
+    megadetector detect --input ./images/
+    megadetector detect --input photo.jpg --output results.json
+    megadetector detect --input ./images/ --model MDV6-apa-rtdetr-e --threshold 0.2
+    megadetector detect --input ./images/ --device cpu
+    megadetector train --config ./config.yaml
+    megadetector validate --config ./config.yaml
+    megadetector inference --config ./config.yaml
+"""
+
+import argparse
+import json
+import sys
+from pathlib import Path
+
+
+def detect(args):
+    """Run MegaDetector on a single image or on every image under a directory.
+
+    Results are written as JSON to ``args.output`` when given, otherwise
+    printed to stdout, followed by a one-line summary.
+    """
+    from megadetector_ai import MegaDetectorV6
+
+    input_path = Path(args.input)
+    if not input_path.exists():
+        print(f"Error: {input_path} does not exist", file=sys.stderr)
+        sys.exit(1)
+
+    device = args.device
+    if device is None:
+        import torch
+        device = "cuda:0" if torch.cuda.is_available() else "cpu"
+
+    print(f"Loading MegaDetector ({args.model}) on {device}...")
+    model = MegaDetectorV6(device=device, pretrained=True, version=args.model)
+
+    if input_path.is_file():
+        print(f"Processing {input_path}...")
+        results = model.single_image_detection(str(input_path))
+        detections = _format_detections(str(input_path), results, args.threshold)
+        all_results = [detections]
+    elif input_path.is_dir():
+        extensions = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"}
+        # Require regular files so a directory named like "x.jpg" is skipped.
+        image_files = sorted(
+            p for p in input_path.rglob("*")
+            if p.is_file() and p.suffix.lower() in extensions
+        )
+        if not image_files:
+            print(f"No image files found in {input_path}", file=sys.stderr)
+            sys.exit(1)
+
+        print(f"Found {len(image_files)} images in {input_path}")
+        all_results = []
+        for i, img_path in enumerate(image_files):
+            results = model.single_image_detection(str(img_path))
+            detections = _format_detections(str(img_path), results, args.threshold)
+            all_results.append(detections)
+            if (i + 1) % 100 == 0 or (i + 1) == len(image_files):
+                print(f" Processed {i + 1}/{len(image_files)}")
+    else:
+        print(f"Error: {input_path} is not a file or directory", file=sys.stderr)
+        sys.exit(1)
+
+    total_detections = sum(len(r["detections"]) for r in all_results)
+    images_with_animals = sum(
+        1 for r in all_results
+        if any(d["category"] == "animal" for d in r["detections"])
+    )
+
+    if args.output:
+        output_path = Path(args.output)
+        with open(output_path, "w", encoding="utf-8") as f:
+            json.dump(all_results, f, indent=2)
+        print(f"\nResults saved to {output_path}")
+    else:
+        print(json.dumps(all_results, indent=2))
+
+    print(f"\nSummary: {len(all_results)} images, {total_detections} detections, "
+          f"{images_with_animals} images with animals")
+
+
+def _format_detections(image_path, results, threshold):
+    """Format one image's detection results as a JSON-serializable dict.
+
+    Keeps only detections whose confidence is at least *threshold*; class ids
+    other than 0, 1 and 2 are reported as "unknown".
+    """
+    CLASS_NAMES = {0: "animal", 1: "person", 2: "vehicle"}
+    detections = []
+
+    if results["detections"] is not None:
+        for xyxy, conf, cls_id in zip(
+            results["detections"].xyxy,
+            results["detections"].confidence,
+            results["detections"].class_id,
+        ):
+            if conf >= threshold:
+                detections.append({
+                    "category": CLASS_NAMES.get(int(cls_id), "unknown"),
+                    "confidence": round(float(conf), 4),
+                    "bbox": [round(float(x), 1) for x in xyxy],
+                })
+
+    return {
+        "file": image_path,
+        "detections": detections,
+    }
+
+
+def train(args):
+    """Train a detection model."""
+    from megadetector_ai.training import train as run_training
+
+    config_path = args.config
+    if not Path(config_path).exists():
+        print(f"Error: Config file {config_path} does not exist", file=sys.stderr)
+        sys.exit(1)
+
+    print(f"Starting training with config: {config_path}")
+    results = run_training(config_path)
+    print("Training completed successfully")
+    return results
+
+
+def validate(args):
+    """Validate a detection model."""
+    from megadetector_ai.training import validate as run_validation
+
+    config_path = args.config
+    if not Path(config_path).exists():
+        print(f"Error: Config file {config_path} does not exist", file=sys.stderr)
+        sys.exit(1)
+
+    print(f"Starting validation with config: {config_path}")
+    metrics = run_validation(config_path)
+    print("Validation completed successfully")
+    return metrics
+
+
+def inference(args):
+    """Run inference on test data."""
+    from megadetector_ai.training import inference as run_inference
+
+    config_path = args.config
+    if not Path(config_path).exists():
+        print(f"Error: Config file {config_path} does not exist", file=sys.stderr)
+        sys.exit(1)
+
+    print(f"Starting inference with config: {config_path}")
+    results = run_inference(config_path)
+    print("Inference completed successfully")
+    return results
+
+
+def main():
+    """Parse command-line arguments and dispatch to the chosen subcommand."""
+    parser = argparse.ArgumentParser(
+        prog="megadetector",
+        description="MegaDetector: AI-powered wildlife detection for camera trap images",
+    )
+    subparsers = parser.add_subparsers(dest="command")
+
+    detect_parser = subparsers.add_parser(
+        "detect", help="Run MegaDetector on images"
+    )
+    detect_parser.add_argument(
+        "--input", "-i", required=True,
+        help="Path to an image file or directory of images",
+    )
+    detect_parser.add_argument(
+        "--output", "-o", default=None,
+        help="Path to save JSON results (prints to stdout if omitted)",
+    )
+    detect_parser.add_argument(
+        "--model", "-m", default="MDV6-yolov9-c",
+        help="Model variant (default: MDV6-yolov9-c)",
+    )
+    detect_parser.add_argument(
+        "--threshold", "-t", type=float, default=0.2,
+        help="Confidence threshold (default: 0.2)",
+    )
+    detect_parser.add_argument(
+        "--device", "-d", default=None,
+        help="Device: cuda:0, cpu, mps (default: auto-detect)",
+    )
+
+    train_parser = subparsers.add_parser(
+        "train", help="Train a detection model"
+    )
+    train_parser.add_argument(
+        "--config", "-c", default="./config.yaml",
+        help="Path to training config file (default: ./config.yaml)",
+    )
+
+    validate_parser = subparsers.add_parser(
+        "validate", help="Validate a detection model"
+    )
+    validate_parser.add_argument(
+        "--config", "-c", default="./config.yaml",
+        help="Path to validation config file (default: ./config.yaml)",
+    )
+
+    inference_parser = subparsers.add_parser(
+        "inference", help="Run inference on test data"
+    )
+    inference_parser.add_argument(
+        "--config", "-c", default="./config.yaml",
+        help="Path to inference config file (default: ./config.yaml)",
+    )
+
+    args = parser.parse_args()
+    if args.command is None:
+        parser.print_help()
+        sys.exit(0)
+
+    if args.command == "detect":
+        detect(args)
+    elif args.command == "train":
+        train(args)
+    elif args.command == "validate":
+        validate(args)
+    elif args.command == "inference":
+        inference(args)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/megadetector_ai/detector.py b/src/megadetector_ai/detector.py
new file mode 100755
index 0000000..9892629
--- /dev/null
+++ b/src/megadetector_ai/detector.py
@@ -0,0 +1,59 @@
+"""
+Convenience wrappers around PyTorch Wildlife's MegaDetector models.
+
+These classes provide a simplified import path and sensible defaults
+for common MegaDetector workflows.
+"""
+
+from PytorchWildlife.models import detection as pw_detection
+
+
+class MegaDetectorV6(pw_detection.MegaDetectorV6):
+    """MegaDetector V6 — the latest generation of MegaDetector.
+
+    Detects animals, people, and vehicles in camera trap images using
+    modern architectures (YOLOv9, YOLOv10, RT-DETR). Multiple model
+    variants are available, ranging from 2.3M to 76M parameters.
+
+    Args:
+        device: Device to run on. "cuda:0" for GPU, "cpu" for CPU,
+            "mps" for Apple Silicon. Defaults to CUDA if available.
+        pretrained: Whether to download pretrained weights. Default True.
+        version: Model variant to load. Options:
+            - "MDV6-yolov9-c" (default) — compact YOLOv9
+            - "MDV6-yolov9-e" — extra-large YOLOv9
+            - "MDV6-yolov10-c" — compact YOLOv10 (2.3M params)
+            - "MDV6-yolov10-e" — extra-large YOLOv10
+            - "MDV6-rtdetr-c" — compact RT-DETR
+            - "MDV6-mit-yolov9-c" — MIT-licensed compact
+            - "MDV6-mit-yolov9-e" — MIT-licensed extra
+            - "MDV6-apa-rtdetr-c" — Apache-licensed compact
+            - "MDV6-apa-rtdetr-e" — Apache-licensed extra (best accuracy)
+
+    Example:
+        >>> from megadetector_ai import MegaDetectorV6
+        >>> model = MegaDetectorV6()
+        >>> results = model.single_image_detection("photo.jpg")
+        >>> print(results["detections"])
+    """
+    pass
+
+
+class MegaDetectorV5(pw_detection.MegaDetectorV5):
+    """MegaDetector V5 — the previous generation, based on YOLOv5.
+
+    Still available for backward compatibility. We recommend V6 for
+    new projects — it is smaller, faster, and offers permissive
+    license options.
+
+    Args:
+        device: Device to run on. Default: CUDA if available.
+        pretrained: Whether to download pretrained weights. Default True.
+        version: "a" (default, recommended) or "b".
+
+    Example:
+        >>> from megadetector_ai import MegaDetectorV5
+        >>> model = MegaDetectorV5(version="a")
+        >>> results = model.single_image_detection("photo.jpg")
+    """
+    pass
diff --git a/src/megadetector_ai/training.py b/src/megadetector_ai/training.py
new file mode 100755
index 0000000..e94bd4d
--- /dev/null
+++ b/src/megadetector_ai/training.py
@@ -0,0 +1,125 @@
+"""
+Fine-tuning module for MegaDetector models.
+
+Provides functions to train, validate, and run inference with detection models
+using the ultralytics framework.
+"""
+
+from ultralytics import YOLO, RTDETR
+from megadetector_ai.training_utils import get_model_path
+from munch import Munch
+import yaml
+import os
+
+
+def load_config(config_path: str = './config.yaml') -> Munch:
+    """Load configuration from YAML file."""
+    with open(config_path) as f:
+        # safe_load, matching _prepare_data_config: only plain YAML types.
+        cfg = Munch(yaml.safe_load(f))
+    return cfg
+
+
+def _load_model(cfg: Munch):
+    """Load the appropriate model based on config."""
+    if cfg.resume:
+        model_path = cfg.weights
+    else:
+        model_path = get_model_path(cfg.model_name)
+
+    if cfg.model == "YOLO":
+        model = YOLO(model_path)
+    elif cfg.model == "RTDETR":
+        model = RTDETR(model_path)
+    else:
+        raise ValueError("Model not supported")
+
+    return model
+
+
+def _prepare_data_config(cfg: Munch):
+    """Make the dataset root in the data config absolute.
+
+    When the ``path`` entry of the YAML file at ``cfg.data`` is relative, it
+    is resolved against the current working directory and the file is
+    rewritten in place. Files without a ``path`` entry are left untouched.
+
+    Returns:
+        The parsed (and possibly updated) data config.
+    """
+    with open(cfg.data) as f:
+        data = yaml.safe_load(f)
+
+    # "path" may be absent or null; indexing it directly would raise.
+    dataset_root = data.get("path")
+    if dataset_root is not None and not os.path.isabs(dataset_root):
+        data["path"] = os.path.abspath(dataset_root)
+        with open(cfg.data, 'w') as f:
+            yaml.dump(data, f)
+
+    return data
+
+
+def train(config_path: str = './config.yaml'):
+    """Train a detection model with the specified configuration."""
+    cfg = load_config(config_path)
+    model = _load_model(cfg)
+    _prepare_data_config(cfg)
+
+    model.info()
+
+    results = model.train(
+        data=cfg.data,
+        epochs=cfg.epochs,
+        imgsz=cfg.imgsz,
+        device=cfg.device_train,
+        save_period=cfg.save_period,
+        workers=cfg.workers,
+        batch=cfg.batch_size_train,
+        val=cfg.val,
+        project=f"runs/train_{cfg.exp_name}",
+        name="exp",
+        patience=cfg.patience,
+        resume=cfg.resume
+    )
+
+    return results
+
+
+def validate(config_path: str = './config.yaml'):
+    """Validate a detection model with the specified configuration."""
+    cfg = load_config(config_path)
+    model = _load_model(cfg)
+    _prepare_data_config(cfg)
+
+    model.info()
+
+    metrics = model.val(
+        data=cfg.data,
+        save_json=cfg.save_json,
+        plots=cfg.plots,
+        device=cfg.device_val,
+        project=f'runs/val_{cfg.exp_name}',
+        name="exp",
+        batch=cfg.batch_size_val
+    )
+
+    return metrics
+
+
+def inference(config_path: str = './config.yaml'):
+    """Run inference on test data with the specified configuration."""
+    cfg = load_config(config_path)
+    model = _load_model(cfg)
+    _prepare_data_config(cfg)
+
+    model.info()
+
+    results = model(cfg.test_data)
+    save_path = os.path.join("inference_results", cfg.exp_name)
+    os.makedirs(save_path, exist_ok=True)
+
+    for i, result in enumerate(results):
+        result.save(filename=os.path.join(save_path, f"inference_{i}.jpg"))
+
+    return results
diff --git a/src/megadetector_ai/training_utils.py b/src/megadetector_ai/training_utils.py
new file mode 100755
index 0000000..26c04a5
--- /dev/null
+++ b/src/megadetector_ai/training_utils.py
@@ -0,0 +1,52 @@
+"""Helpers for locating and downloading pretrained MegaDetector weights."""
+
+import os
+import wget
+import torch
+
+# Supported model versions -> (download URL, checkpoint file name).
+_MODEL_SOURCES = {
+    "MDV6-yolov9-c": (
+        "https://zenodo.org/records/14567879/files/MDV6b-yolov9c.pt?download=1",
+        "MDV6b-yolov9c.pt",
+    ),
+    "MDV6-yolov9-e": (
+        "https://zenodo.org/records/14567879/files/MDV6-yolov9e.pt?download=1",
+        "MDV6-yolov9e.pt",
+    ),
+    "MDV6-yolov10-c": (
+        "https://zenodo.org/records/14567879/files/MDV6-yolov10n.pt?download=1",
+        "MDV6-yolov10n.pt",
+    ),
+    "MDV6-yolov10-e": (
+        "https://zenodo.org/records/14567879/files/MDV6-yolov10x.pt?download=1",
+        "MDV6-yolov10x.pt",
+    ),
+    "MDV6-rtdetr-c": (
+        "https://zenodo.org/records/14567879/files/MDV6b-rtdetrl.pt?download=1",
+        "MDV6b-rtdetrl.pt",
+    ),
+}
+
+
+def get_model_path(model):
+    """Return the local checkpoint path for *model*, downloading it if needed.
+
+    Checkpoints are cached in the ``checkpoints`` directory under
+    ``torch.hub.get_dir()``.
+
+    Raises:
+        ValueError: If *model* is not one of the supported version names.
+    """
+    try:
+        url, model_name = _MODEL_SOURCES[model]
+    except (KeyError, TypeError):
+        raise ValueError('Select a valid model version: MDV6-yolov9-c, MDV6-yolov9-e, MDV6-yolov10-c, MDV6-yolov10-e or MDV6-rtdetr-c') from None
+
+    checkpoint_dir = os.path.join(torch.hub.get_dir(), "checkpoints")
+    cached_path = os.path.join(checkpoint_dir, model_name)
+    if os.path.exists(cached_path):
+        return cached_path
+
+    os.makedirs(checkpoint_dir, exist_ok=True)
+    return wget.download(url, out=checkpoint_dir)