diff --git a/.github/ISSUE_TEMPLATE/bug-report.yml b/.github/ISSUE_TEMPLATE/bug-report.yml new file mode 100644 index 0000000..02c4515 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug-report.yml @@ -0,0 +1,64 @@ +name: Bug Report +description: Problems with MegaDetector-Classifier +labels: [bug] +body: + - type: markdown + attributes: + value: | + Thank you for submitting a Bug Report! + + - type: checkboxes + attributes: + label: Search before asking + description: > + Please search the [issues](https://github.com/microsoft/MegaDetector-Classifier/issues) to see if a similar bug report already exists. + options: + - label: > + I have searched the MegaDetector-Classifier [issues](https://github.com/microsoft/MegaDetector-Classifier/issues) and found no similar bug report. + required: true + + - type: textarea + attributes: + label: Bug + description: Provide console output with error messages and/or screenshots of the bug. + placeholder: | + 💡 ProTip! Include as much information as possible (error messages, screenshots, logs, tracebacks, etc.) to receive the most helpful response. + validations: + required: true + + - type: textarea + attributes: + label: Environment + description: Please specify the software and hardware you used to produce the bug. + placeholder: | + - PytorchWildlife: 1.3.0 + - OS: Ubuntu 22.04 + - Python: 3.10.0 + - CUDA: 12.1 (or CPU) + validations: + required: false + + - type: textarea + attributes: + label: Minimal Reproducible Example + description: > + This is referred to by community members as creating a [minimal reproducible example](https://stackoverflow.com/help/minimal-reproducible-example). + placeholder: | + ``` + # Code to reproduce your issue here + ``` + validations: + required: false + + - type: textarea + attributes: + label: Additional + description: Anything else you would like to share? + + - type: checkboxes + attributes: + label: Are you willing to submit a PR? + description: > + (Optional) We encourage you to submit a [Pull Request](https://github.com/microsoft/MegaDetector-Classifier/pulls) (PR) to help contribute to MegaDetector-Classifier for everyone, especially if you have a good understanding of how to implement a fix or feature. + options: + - label: Yes I'd like to help by submitting a PR! diff --git a/.github/ISSUE_TEMPLATE/feature-request.yml b/.github/ISSUE_TEMPLATE/feature-request.yml new file mode 100644 index 0000000..2e650b6 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature-request.yml @@ -0,0 +1,48 @@ +name: Feature Request +description: Suggest an enhancement for MegaDetector-Classifier +labels: [enhancement] +body: + - type: markdown + attributes: + value: | + Thank you for submitting a MegaDetector-Classifier Feature Request! + + - type: checkboxes + attributes: + label: Search before asking + description: > + Please search the [issues](https://github.com/microsoft/MegaDetector-Classifier/issues) to see if a similar feature request already exists. + options: + - label: > + I have searched the MegaDetector-Classifier [issues](https://github.com/microsoft/MegaDetector-Classifier/issues) and found no similar feature request. + required: true + + - type: textarea + attributes: + label: Description + description: A short description of your feature. + placeholder: | + What new feature would you like to see in MegaDetector-Classifier? + validations: + required: true + + - type: textarea + attributes: + label: Use case + description: | + Describe the use case of your feature request. It will help us understand and prioritize the feature request. + placeholder: | + How would this feature be used, and who would use it? + + - type: textarea + attributes: + label: Additional + description: Anything else you would like to share? + + - type: checkboxes + attributes: + label: Are you willing to submit a PR? + description: > + (Optional) We encourage you to submit a [Pull Request](https://github.com/microsoft/MegaDetector-Classifier/pulls) (PR) to help contribute to MegaDetector-Classifier for everyone, especially if you have a good understanding of how to implement a fix or feature. + options: + - label: Yes I'd like to help by submitting a PR! diff --git a/.github/ISSUE_TEMPLATE/question.yml b/.github/ISSUE_TEMPLATE/question.yml new file mode 100644 index 0000000..59d3139 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/question.yml @@ -0,0 +1,32 @@ +name: Question +description: Ask a MegaDetector-Classifier question +labels: [question] +body: + - type: markdown + attributes: + value: | + Thank you for asking a general question! + + - type: checkboxes + attributes: + label: Search before asking + description: > + Please search the [issues](https://github.com/microsoft/MegaDetector-Classifier/issues) to see if a similar question already exists. + options: + - label: > + I have searched the MegaDetector-Classifier [issues](https://github.com/microsoft/MegaDetector-Classifier/issues) and found no similar question. + required: true + + - type: textarea + attributes: + label: Question + description: What is your question? + placeholder: | + 💡 ProTip! Include as much information as possible to receive the most helpful response. + validations: + required: true + + - type: textarea + attributes: + label: Additional + description: Anything else you would like to share? diff --git a/.github/workflows/deploy-docs.yml b/.github/workflows/deploy-docs.yml new file mode 100644 index 0000000..9a838ee --- /dev/null +++ b/.github/workflows/deploy-docs.yml @@ -0,0 +1,28 @@ +name: Deploy MkDocs site + +on: + push: + branches: + - main + paths: + - 'docs/**' + - 'mkdocs.yml' + - 'docs-requirements.txt' + +jobs: + deploy: + runs-on: ubuntu-latest + permissions: + contents: write + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: '3.10' + + - name: Install MkDocs dependencies + run: pip install -r docs-requirements.txt + + - name: Deploy to GitHub Pages + run: mkdocs gh-deploy --force diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..dec29f7 --- /dev/null +++ b/.gitignore @@ -0,0 +1,24 @@ +__pycache__/ +*.py[cod] +*$py.class +*.egg-info/ +dist/ +build/ +*.egg +.eggs/ +*.so +.env +.venv +env/ +venv/ +.tox/ +.coverage +htmlcov/ +.pytest_cache/ +*.log +.DS_Store +Thumbs.db +Brewfile +site/ +archive/ +*.code-workspace diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..9003ea9 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2026 Microsoft + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md index bc1e65b..8c8dafc 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,162 @@ -# Repository setup required :wave: - -Please visit the website URL :point_right: for this repository to complete the setup of this repository and configure access controls. \ No newline at end of file +# MegaDetector-Classifier + +Microsoft AI for Good Lab's open-source classification fine-tuning tool — train custom species classifiers on your own camera-trap datasets and deploy them through PyTorch-Wildlife. + +[![License](https://img.shields.io/github/license/microsoft/MegaDetector-Classifier)](https://github.com/microsoft/MegaDetector-Classifier/blob/main/LICENSE) +[![Python 3.9+](https://img.shields.io/badge/python-3.9%2B-blue.svg)](https://www.python.org/downloads/) +[![PyTorch-Wildlife](https://img.shields.io/badge/PyTorch--Wildlife-ecosystem-green.svg)](https://github.com/microsoft/Biodiversity) + +MegaDetector-Classifier is part of the [microsoft/Biodiversity](https://github.com/microsoft/Biodiversity) ecosystem and is powered by the [PyTorch-Wildlife](https://github.com/microsoft/PytorchWildlife) framework. It is free, open-source, and available under the MIT license. + +## Part of the Biodiversity Ecosystem + +MegaDetector-Classifier is one tool in a larger open-source ecosystem from the Microsoft AI for Good Lab. Each project lives in its own repository, with the [microsoft/Biodiversity](https://github.com/microsoft/Biodiversity) umbrella tying them together. + +| Repository | Description | +|---|---| +| [microsoft/Biodiversity](https://github.com/microsoft/Biodiversity) | The umbrella repository — documentation hub for the AI for Good Lab's biodiversity work | +| [microsoft/MegaDetector](https://github.com/microsoft/MegaDetector) | Animal, human, and vehicle detection for camera-trap images | +| [microsoft/PytorchWildlife](https://github.com/microsoft/PytorchWildlife) | The collaborative deep learning framework that hosts MegaDetector, species classifiers, and demo notebooks | +| [microsoft/MegaDetector-Acoustic](https://github.com/microsoft/MegaDetector-Acoustic) | Bioacoustic AI for audio-based wildlife detection and classification | +| [microsoft/MegaDetector-Classifier](https://github.com/microsoft/MegaDetector-Classifier) | **This repo** — classification fine-tuning for camera-trap species identification | +| [microsoft/SPARROW](https://github.com/microsoft/SPARROW) | Solar-Powered Acoustic and Remote Recording Observation Watch — AI-enabled edge device for field recording | +| [SPARROW-Studio](https://github.com/microsoft/Biodiversity/tree/main/SPARROW-Studio) | Desktop application wrapping all AI for Good Lab models in a graphical interface | + +## Overview + +MegaDetector-Classifier is a training toolkit for fine-tuning ResNet-based species classifiers on custom camera-trap image datasets. The output weights integrate directly with the [PyTorch-Wildlife](https://github.com/microsoft/PytorchWildlife) framework, making it straightforward to deploy a classifier trained on your own data. + +**Key capabilities:** +- ResNet-18 and ResNet-50 classifier training using PyTorch Lightning +- Three data-splitting strategies designed for camera-trap realities: random, location-based, and sequence-based +- YAML-based configuration — no code changes required for most use cases +- Demo data included for immediate testing without your own dataset + +**Designed for:** +- Conservation practitioners adapting existing classifiers to new geographic regions +- Researchers adding new species to the PyTorch-Wildlife model zoo +- Projects running MegaDetector detection upstream and needing a matched classifier downstream + +## Installation + +### Using pip + +```bash +git clone https://github.com/microsoft/MegaDetector-Classifier +cd MegaDetector-Classifier +pip install -r requirements.txt +``` + +### Using conda + +```bash +git clone https://github.com/microsoft/MegaDetector-Classifier +cd MegaDetector-Classifier +conda env create -f environment.yaml +conda activate PT_Finetuning +``` + +**Requirements:** Python 3.9+ + +## Quick Start + +1. Configure `configs/config.yaml` — set `dataset_root`, `annotation_dir`, `num_classes`, and `split_type` +2. Run training: + +```bash +python main.py +``` + +Output weights are saved to the `weights/` directory and can be loaded directly into PyTorch-Wildlife. + +## Data Preparation + +### Data Structure + +Images should be stored in a single flat directory (no nested subdirectories). An `annotations.csv` file — placed outside the images directory — maps each image to its class: + +```plaintext +MegaDetector-Classifier/ +├── data/ +│ ├── imgs/ # All images stored here (flat) +│ └── annotation_example.csv # Annotations file +└── configs/config.yaml +``` + +### Annotation File Format + +The CSV must contain three columns: + +| Column | Description | Example | +|---|---|---| +| `path` | Relative path to the image | `imgs/leopard_001.jpg` | +| `classification` | Integer class ID | `0` | +| `label` | Human-readable class name | `leopard` | + +### Data Splitting + +MegaDetector-Classifier supports three splitting strategies, selected via `split_type` in `config.yaml`: + +| Strategy | When to use | Extra column required | +|---|---|---| +| `random` | Balanced class distribution; not recommended for camera-trap bursts | None | +| `location` | Keeps all images from one camera location in the same split | `Location` | +| `sequence` | Groups burst images within 30-second windows before splitting | `Photo_time` (YYYY-MM-DD HH:MM:SS) | + +> **Camera-trap note:** Random splitting is not recommended because burst images of the same animal can appear in both training and validation sets, causing artificially high validation accuracy. Use `location` or `sequence` splitting instead. + +### Demo Data + +Download demo data to test the pipeline without your own dataset: + +```bash +# Download and extract +wget https://zenodo.org/records/15376499/files/demo_data_clf.zip +unzip demo_data_clf.zip -d data/ +``` + +Then set `dataset_root: ./data/imgs` in `configs/config.yaml` and run `python main.py`. + +## Repository Structure + +``` +MegaDetector-Classifier/ +├── main.py # Training entry point +├── requirements.txt # pip dependencies +├── environment.yaml # conda environment +├── configs/ +│ └── config.yaml # Training configuration +└── src/ + ├── algorithms/ + │ └── plain.py # Training algorithm (PyTorch Lightning) + ├── datasets/ + │ └── custom.py # Custom dataset loader + ├── models/ + │ └── plain_resnet.py # ResNet-18/50 classifier + └── utils/ + ├── batch_detection_cropping.py # MegaDetector crop integration + ├── data_splitting.py # Random / location / sequence splits + └── utils.py # Shared utilities +``` + +## Citation + +If you use MegaDetector-Classifier in your research, please cite the PyTorch-Wildlife paper: + +```bibtex +@misc{hernandez2024pytorchwildlife, + title={Pytorch-Wildlife: A Collaborative Deep Learning Framework for Conservation}, + author={Andres Hernandez and Zhongqi Miao and Luisa Vargas and Sara Beery and Rahul Dodhia and Juan Lavista}, + year={2024}, + eprint={2405.12930}, + archivePrefix={arXiv}, +} +``` + +You can also cite this software directly using the [`citation.cff`](citation.cff) file in this repository. + +## Contributing + +Issues, feature requests, and pull requests are welcome at [microsoft/MegaDetector-Classifier/issues](https://github.com/microsoft/MegaDetector-Classifier/issues). + +For framework-level changes (PyTorch-Wildlife API, models, datasets), see [microsoft/PytorchWildlife](https://github.com/microsoft/PytorchWildlife). For ecosystem-wide questions, see the [microsoft/Biodiversity](https://github.com/microsoft/Biodiversity) umbrella. diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 0000000..b3c89ef --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,41 @@ + + +## Security + +Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet) and [Xamarin](https://github.com/xamarin). + +If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below. + +## Reporting Security Issues + +**Please do not report security vulnerabilities through public GitHub issues.** + +Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report). + +If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/security.md/msrc/pgp). + +You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). + +Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: + + * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) + * Full paths of source file(s) related to the manifestation of the issue + * The location of the affected source code (tag/branch/commit or direct URL) + * Any special configuration required to reproduce the issue + * Step-by-step instructions to reproduce the issue + * Proof-of-concept or exploit code (if possible) + * Impact of the issue, including how an attacker might exploit the issue + +This information will help us triage your report more quickly. + +If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs. + +## Preferred Languages + +We prefer all communications to be in English. + +## Policy + +Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd). + + diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..c18128e --- /dev/null +++ b/__init__.py @@ -0,0 +1 @@ +from batch_detection_cropping import * \ No newline at end of file diff --git a/citation.cff b/citation.cff new file mode 100644 index 0000000..ffe67d2 --- /dev/null +++ b/citation.cff @@ -0,0 +1,37 @@ +cff-version: 1.2.0 +title: "MegaDetector-Classifier: Open-Source Classification Fine-Tuning for Wildlife Camera Traps" +message: "If you use MegaDetector-Classifier, please cite it using the metadata from this file." +type: software +version: "1.0.0" +date-released: "2026-05-15" +authors: + - given-names: Andres + family-names: Hernandez + - given-names: Zhongqi + family-names: Miao + - given-names: Luisa + family-names: Vargas + - given-names: Sara + family-names: Beery + - given-names: Rahul + family-names: Dodhia + - given-names: Juan + family-names: Lavista + - name: "Microsoft AI for Good Lab" +identifiers: + - type: url + value: "https://arxiv.org/abs/2405.12930" + description: "Pytorch-Wildlife: A Collaborative Deep Learning Framework for Conservation" +repository-code: "https://github.com/microsoft/MegaDetector-Classifier" +url: "https://microsoft.github.io/MegaDetector-Classifier/" +keywords: + - MegaDetector-Classifier + - camera-trap + - species classification + - fine-tuning + - transfer learning + - wildlife monitoring + - conservation + - PyTorch-Wildlife + - ai-for-good +license: MIT diff --git a/configs/config.yaml b/configs/config.yaml new file mode 100644 index 0000000..e0c4bb1 --- /dev/null +++ b/configs/config.yaml @@ -0,0 +1,41 @@ +# training +conf_id: Crop_Res18_plain_071824 +algorithm: Plain +log_dir: Crop +num_epochs: 30 +log_interval: 10 +parallel: 0 + +# data +dataset_root: ./data/imgs +dataset_name: Custom_Crop +# annotation directory (if you have train/val/test splits) +annotation_dir: ./data/imgs +# data splitting (if you don't have train/val/test splits) +split_path: ./data/imgs/annotation_example.csv +test_size: 0.2 +val_size: 0.2 +split_data: True +split_type: location # options are: random, location, sequence +# data loading +batch_size: 32 +num_workers: 4 #40 +# model +num_classes: 2 +model_name: PlainResNetClassifier +num_layers: 18 +weights_init: ImageNet + +# optim +## feature +lr_feature: 0.01 +momentum_feature: 0.9 +weight_decay_feature: 0.0005 +## classifier +lr_classifier: 0.01 +momentum_classifier: 0.9 +weight_decay_classifier: 0.0005 +## lr_scheduler +step_size: 10 +gamma: 0.1 + diff --git a/docs-requirements.txt b/docs-requirements.txt new file mode 100644 index 0000000..99207a8 --- /dev/null +++ b/docs-requirements.txt @@ -0,0 +1,7 @@ +mkdocs +mkdocs-get-deps +mkdocs-material +mkdocs-material-extensions +pymdown-extensions +mkdocstrings +mkdocstrings-python diff --git a/docs/build_mkdocs.md b/docs/build_mkdocs.md new file mode 100644 index 0000000..f218578 --- /dev/null +++ b/docs/build_mkdocs.md @@ -0,0 +1,63 @@ +--- +description: "How to build and deploy the MegaDetector-Classifier MkDocs documentation site locally and to GitHub Pages." +tags: + - MkDocs + - documentation + - developer guide + - MegaDetector-Classifier +--- + +# Developer Guide — Building the Docs + +This page explains how to build and preview the MegaDetector-Classifier documentation site locally. + +## Prerequisites + +Install the documentation dependencies (separate from the ML requirements): + +```bash +pip install -r docs-requirements.txt +``` + +## Preview locally + +```bash +mkdocs serve +``` + +Then open [http://127.0.0.1:8000](http://127.0.0.1:8000) in your browser. The server hot-reloads on file changes. + +## Build (offline check) + +```bash +mkdocs build --strict +``` + +`--strict` treats warnings as errors. Fix any broken links or missing pages before opening a PR. + +## Deploy to GitHub Pages + +Deployment is automatic. The [GitHub Actions workflow](https://github.com/microsoft/MegaDetector-Classifier/blob/main/.github/workflows/deploy-docs.yml) triggers on every push to `main` that touches `docs/**`, `mkdocs.yml`, or `docs-requirements.txt`. + +To deploy manually (maintainers only): + +```bash +mkdocs gh-deploy --force +``` + +This builds the site and force-pushes to the `gh-pages` branch. Do not commit the `site/` directory — it is generated and is in `.gitignore`. + +## Adding a new page + +1. Create a new `.md` file under `docs/` +2. Add SEO front matter at the top: + ```yaml + --- + description: "One sentence describing this page for search engines." + tags: + - relevant tag + - another tag + --- + ``` +3. Add the page to the `nav:` section in `mkdocs.yml` +4. Run `mkdocs build --strict` to verify no errors diff --git a/docs/cite.md b/docs/cite.md new file mode 100644 index 0000000..da0a573 --- /dev/null +++ b/docs/cite.md @@ -0,0 +1,33 @@ +--- +description: "How to cite MegaDetector-Classifier — BibTeX and citation.cff entries for the PyTorch-Wildlife camera-trap classification fine-tuning tool." +tags: + - cite MegaDetector-Classifier + - MegaDetector citation + - PyTorch-Wildlife citation + - camera-trap classification research + - BibTeX +--- + +# :fountain_pen: Cite Us + +If you use MegaDetector-Classifier in your research, please cite the PyTorch-Wildlife paper: + +```bibtex +@misc{hernandez2024pytorchwildlife, + title={Pytorch-Wildlife: A Collaborative Deep Learning Framework for Conservation}, + author={Andres Hernandez and Zhongqi Miao and Luisa Vargas and Sara Beery and Rahul Dodhia and Juan Lavista}, + year={2024}, + eprint={2405.12930}, + archivePrefix={arXiv}, +} +``` + +You can also cite MegaDetector-Classifier as software directly. The [`citation.cff`](https://github.com/microsoft/MegaDetector-Classifier/blob/main/citation.cff) file in the repository is machine-readable and is used by GitHub's "Cite this repository" widget. + +--- + +## Related Work + +If your research uses MegaDetector detection upstream of MegaDetector-Classifier, please also cite: + +- **Beery, Morris, Yang 2019** — *Efficient Pipeline for Camera Trap Image Review* — for any use of MegaDetector specifically diff --git a/docs/index.md b/docs/index.md new file mode 100644 index 0000000..9ff32fd --- /dev/null +++ b/docs/index.md @@ -0,0 +1,59 @@ +--- +description: "MegaDetector-Classifier — Microsoft AI for Good Lab's open-source classification fine-tuning tool for training custom species classifiers on camera-trap datasets." +tags: + - MegaDetector-Classifier + - camera-trap classification + - species identification + - fine-tuning + - transfer learning + - PyTorch-Wildlife + - conservation AI +--- + +# Microsoft MegaDetector-Classifier + +**Open-source classification fine-tuning for camera-trap species identification.** + +MegaDetector-Classifier is a training toolkit from the [Microsoft AI for Good Lab](https://www.microsoft.com/en-us/research/group/ai-for-good-research-lab/) for fine-tuning ResNet-based species classifiers on custom camera-trap image datasets. Output weights integrate directly with the [PyTorch-Wildlife](https://github.com/microsoft/PytorchWildlife) framework. It is part of the [microsoft/Biodiversity](https://github.com/microsoft/Biodiversity) ecosystem. + +--- + +## What It Does + +MegaDetector-Classifier takes your labeled camera-trap images and produces a trained classifier ready to deploy in PyTorch-Wildlife: + +1. **Data preparation** — flat image directory + a simple `annotations.csv`; three splitting strategies (random, location, sequence) designed for camera-trap realities +2. **Training** — ResNet-18 or ResNet-50 classifiers via PyTorch Lightning, configured entirely through `config.yaml` +3. **Output** — trained weights saved to `weights/`, ready to load into PyTorch-Wildlife for inference + +--- + +## Get Started + +See the [Installation](installation.md) page, then configure and run: + +```bash +git clone https://github.com/microsoft/MegaDetector-Classifier +cd MegaDetector-Classifier +pip install -r requirements.txt +# edit configs/config.yaml, then: +python main.py +``` + +Demo data is available on [Zenodo](https://zenodo.org/records/15376499/files/demo_data_clf.zip) for immediate testing without your own dataset. See the [Usage Guide](usage.md) for a full walkthrough. + +--- + +## Part of the Biodiversity Ecosystem + +MegaDetector-Classifier is one tool in a larger open-source ecosystem from the Microsoft AI for Good Lab. + +| Repository | Description | +|---|---| +| [microsoft/Biodiversity](https://github.com/microsoft/Biodiversity) | Umbrella hub — PyTorch-Wildlife, MegaDetector, ecosystem overview | +| [microsoft/MegaDetector](https://github.com/microsoft/MegaDetector) | Animal, human, and vehicle detection for camera-trap images | +| [microsoft/PytorchWildlife](https://github.com/microsoft/PytorchWildlife) | The collaborative deep learning framework for wildlife monitoring | +| [microsoft/MegaDetector-Acoustic](https://github.com/microsoft/MegaDetector-Acoustic) | Bioacoustic AI for audio-based wildlife detection and classification | +| [microsoft/MegaDetector-Classifier](https://github.com/microsoft/MegaDetector-Classifier) | **This repo** — classification fine-tuning for camera-trap species identification | +| [microsoft/SPARROW](https://github.com/microsoft/SPARROW) | Solar-Powered Acoustic and Remote Recording Observation Watch — AI-enabled edge device | +| [SPARROW-Studio](https://github.com/microsoft/Biodiversity/tree/main/SPARROW-Studio) | Desktop application for all AI for Good Lab models | diff --git a/docs/installation.md b/docs/installation.md new file mode 100644 index 0000000..921a159 --- /dev/null +++ b/docs/installation.md @@ -0,0 +1,67 @@ +--- +description: "How to install and set up MegaDetector-Classifier for camera-trap species classification fine-tuning." +tags: + - MegaDetector-Classifier installation + - camera-trap classification setup + - PyTorch-Wildlife + - fine-tuning setup + - wildlife monitoring +--- + +# Installation + +## Requirements + +- Python 3.9+ +- PyTorch 2.0+ +- CUDA (optional, but recommended for faster training) + +## Install with pip + +```bash +git clone https://github.com/microsoft/MegaDetector-Classifier +cd MegaDetector-Classifier +pip install -r requirements.txt +``` + +This installs the following dependencies: + +| Package | Purpose | +|---|---| +| `PytorchWildlife` | Core models and framework integration | +| `lightning` | Training loop and checkpointing (PyTorch Lightning) | +| `scikit_learn` | Data splitting utilities | +| `munch` | YAML config as dot-accessible object | +| `typer` | CLI argument handling | + +## Install with conda + +```bash +git clone https://github.com/microsoft/MegaDetector-Classifier +cd MegaDetector-Classifier +conda env create -f environment.yaml +conda activate PT_Finetuning +``` + +## Verify + +```python +from PytorchWildlife.models import classification as pw_classification +print("MegaDetector-Classifier is ready.") +``` + +## GPU Setup + +Training on GPU is recommended for faster iteration. Verify CUDA availability: + +```python +import torch +print(torch.cuda.is_available()) # should print True on a CUDA-enabled machine +``` + +CPU training is supported but will be slower for larger datasets. + +## Next Steps + +- Download [demo data](https://zenodo.org/records/15376499/files/demo_data_clf.zip) for an immediate end-to-end test +- Read the [Usage Guide](usage.md) for a full walkthrough of data preparation, configuration, and training diff --git a/docs/tags.md b/docs/tags.md new file mode 100644 index 0000000..5403590 --- /dev/null +++ b/docs/tags.md @@ -0,0 +1,9 @@ +--- +description: "Tag index for MegaDetector-Classifier documentation." +tags: + - MegaDetector-Classifier +--- + +# Tags + +[TAGS] diff --git a/docs/usage.md b/docs/usage.md new file mode 100644 index 0000000..74eedfc --- /dev/null +++ b/docs/usage.md @@ -0,0 +1,152 @@ +--- +description: "How to use MegaDetector-Classifier — data preparation, splitting strategies, configuration, training, and integrating output weights with PyTorch-Wildlife." +tags: + - MegaDetector-Classifier usage + - camera-trap classification training + - data splitting + - config.yaml + - species classification fine-tuning + - PyTorch-Wildlife integration +--- + +# Usage Guide + +This page walks through the full MegaDetector-Classifier workflow: preparing your data, configuring training, running the classifier, and using the output weights. + +--- + +## 1. Data Structure + +Images must be stored in a single **flat directory** (no subdirectories). An `annotations.csv` file placed alongside (not inside) the images directory maps each image to its class. + +```plaintext +MegaDetector-Classifier/ +├── data/ +│ ├── imgs/ # All images stored here (flat — no nested folders) +│ └── annotation_example.csv # Annotations file +└── configs/config.yaml +``` + +### Annotation File Format + +The CSV must contain exactly these three columns: + +| Column | Type | Description | Example | +|---|---|---|---| +| `path` | string | Relative path to the image from the CSV location | `imgs/leopard_001.jpg` | +| `classification` | integer | Unique integer ID for each class | `0` | +| `label` | string | Human-readable class name | `leopard` | + +--- + +## 2. Data Splitting + +Set `split_data: True` in `config.yaml` to have MegaDetector-Classifier split your annotations into train, validation, and test sets automatically. The splitting strategy is controlled by `split_type`. + +### Splitting Strategies + +| Strategy | When to use | Extra column required | +|---|---|---| +| `random` | General-purpose balanced split | None | +| `location` | Keeps all images from one camera location in the same split | `Location` | +| `sequence` | Groups burst images within 30-second windows before splitting | `Photo_time` (YYYY-MM-DD HH:MM:SS) | + +> **Important — camera-trap burst images:** Camera traps frequently capture bursts of images of the same animal within seconds. With random splitting, nearly identical frames can end up in both training and validation sets, causing artificially inflated validation accuracy (overfitting). Use `location` or `sequence` splitting to prevent this. + +If you already have pre-split CSV files for train/val/test, set `split_data: False` and point `annotation_dir` to the directory containing those files. + +--- + +## 3. Configuration + +All training parameters live in `configs/config.yaml`. Edit this file before running `python main.py`. + +### Training Parameters + +| Parameter | Description | Default | +|---|---|---| +| `conf_id` | Unique identifier for this training run | `Crop_Res18_plain_071824` | +| `algorithm` | Training algorithm | `Plain` | +| `log_dir` | Directory for training logs | `Crop` | +| `num_epochs` | Total training epochs | `30` | +| `log_interval` | How often to log training info (in steps) | `10` | +| `parallel` | Set to `1` to enable multi-GPU training | `0` | + +### Data Parameters + +| Parameter | Description | +|---|---| +| `dataset_root` | Root directory where images are stored | +| `dataset_name` | Dataset type (`Custom_Crop` for fine-tuning) | +| `annotation_dir` | Directory containing annotation CSV files | +| `split_path` | Path to single CSV for auto-splitting | +| `test_size` | Proportion of data for test set (e.g. `0.2`) | +| `val_size` | Proportion of data for validation set (e.g. `0.2`) | +| `split_data` | `True` to auto-split, `False` if splits already exist | +| `split_type` | `random`, `location`, or `sequence` | +| `batch_size` | Images per batch (default: `32`) | +| `num_workers` | Dataloader worker processes (default: `4`) | + +### Model Parameters + +| Parameter | Description | +|---|---| +| `num_classes` | Number of species classes in your dataset | +| `model_name` | Architecture (`PlainResNetClassifier`) | +| `num_layers` | ResNet depth — `18` or `50` | +| `weights_init` | Initial weights — `ImageNet` for transfer learning | + +### Optimization Parameters + +| Parameter | Description | +|---|---| +| `lr_feature` | Learning rate for the feature extractor | +| `momentum_feature` | Momentum for the feature extractor optimizer | +| `weight_decay_feature` | Weight decay for the feature extractor | +| `lr_classifier` | Learning rate for the classifier head | +| `momentum_classifier` | Momentum for the classifier head optimizer | +| `weight_decay_classifier` | Weight decay for the classifier head | +| `step_size` | LR scheduler step size (epochs) | +| `gamma` | LR scheduler decay factor | + +> **Architecture note:** The current version supports only `PlainResNetClassifier` with ResNet-18 or ResNet-50 backbones. The classifier head and feature extractor are trained with separate optimizers, which is required for compatibility with the PyTorch-Wildlife framework. + +--- + +## 4. Running Training + +After configuring `configs/config.yaml`: + +```bash +python main.py +``` + +Monitor the console output for per-epoch loss and accuracy. Logs are written to the directory specified in `log_dir`. + +### Quick test with demo data + +```bash +# Download demo data +wget https://zenodo.org/records/15376499/files/demo_data_clf.zip +unzip demo_data_clf.zip -d data/ + +# The demo config already points to ./data/imgs — just run: +python main.py +``` + +--- + +## 5. Output + +Trained weights are saved to the `weights/` directory at the end of training. These weights follow the PyTorch-Wildlife classifier interface and can be loaded directly into the framework for inference on new images. + +--- + +## 6. Integration with MegaDetector + +A common workflow pairs MegaDetector detection upstream with MegaDetector-Classifier downstream: + +1. Run [MegaDetector](https://github.com/microsoft/MegaDetector) on your camera-trap images to detect and crop animals +2. Use `src/utils/batch_detection_cropping.py` to generate cropped images from MegaDetector outputs +3. Train MegaDetector-Classifier on the cropped detections +4. Deploy the resulting classifier through [PyTorch-Wildlife](https://github.com/microsoft/PytorchWildlife) diff --git a/environment.yaml b/environment.yaml new file mode 100644 index 0000000..811cc9b --- /dev/null +++ b/environment.yaml @@ -0,0 +1,155 @@ +name: PT_Finetuning +channels: + - conda-forge + - defaults +dependencies: + - _libgcc_mutex=0.1=conda_forge + - _openmp_mutex=4.5=2_gnu + - bzip2=1.0.8=hd590300_5 + - ca-certificates=2023.11.17=hbcca054_0 + - ld_impl_linux-64=2.40=h41732ed_0 + - libffi=3.4.2=h7f98852_5 + - libgcc-ng=13.2.0=h807b86a_4 + - libgomp=13.2.0=h807b86a_4 + - libnsl=2.0.1=hd590300_0 + - libsqlite=3.44.2=h2797004_0 + - libuuid=2.38.1=h0b41bf4_0 + - libxcrypt=4.4.36=hd590300_1 + - libzlib=1.2.13=hd590300_5 + - ncurses=6.4=h59595ed_2 + - openssl=3.2.0=hd590300_1 + - pip=23.3.2=pyhd8ed1ab_0 + - python=3.8.18=hd12c33a_1_cpython + - readline=8.2=h8228510_1 + - setuptools=69.0.3=pyhd8ed1ab_0 + - tk=8.6.13=noxft_h4845f30_101 + - wheel=0.42.0=pyhd8ed1ab_0 + - xz=5.2.6=h166bdaf_0 + - pip: + - absl-py==2.1.0 + - aiofiles==23.2.1 + - aiohttp==3.9.3 + - aiosignal==1.3.1 + - altair==5.2.0 + - annotated-types==0.6.0 + - anyio==4.2.0 + - asttokens==2.4.1 + - async-timeout==4.0.3 + - attrs==23.2.0 + - backcall==0.2.0 + - cachetools==5.3.2 + - certifi==2023.11.17 + - charset-normalizer==3.3.2 + - click==8.1.7 + - colorama==0.4.6 + - contourpy==1.1.1 + - cycler==0.12.1 + - decorator==5.1.1 + - exceptiongroup==1.2.0 + - executing==2.0.1 + - fastapi==0.109.0 + - ffmpy==0.3.1 + - filelock==3.13.1 + - fire==0.5.0 + - fonttools==4.47.2 + - frozenlist==1.4.1 + - fsspec==2023.12.2 + - google-auth==2.27.0 + - google-auth-oauthlib==1.0.0 + - gradio==4.8.0 + - gradio-client==0.7.1 + - grpcio==1.60.0 + - h11==0.14.0 + - httpcore==1.0.2 + - httpx==0.26.0 + - huggingface-hub==0.20.3 + - idna==3.6 + - importlib-metadata==7.0.1 + - importlib-resources==6.1.1 + - ipython==8.12.3 + - jedi==0.19.1 + - jinja2==3.1.3 + - joblib==1.3.2 + - jsonschema==4.21.1 + - jsonschema-specifications==2023.12.1 + - kiwisolver==1.4.5 + - lightning-utilities==0.10.1 + - markdown==3.5.2 + - markdown-it-py==3.0.0 + - markupsafe==2.1.4 + - matplotlib==3.7.4 + - matplotlib-inline==0.1.6 + - mdurl==0.1.2 + - multidict==6.0.4 + - munch==2.5.0 + - numpy==1.24.4 + - oauthlib==3.2.2 + - opencv-python==4.9.0.80 + - opencv-python-headless==4.9.0.80 + - orjson==3.9.12 + - packaging==23.2 + - pandas==2.0.3 + - parso==0.8.3 + - pexpect==4.9.0 + - pickleshare==0.7.5 + - pillow==10.1.0 + - pkgutil-resolve-name==1.3.10 + - prompt-toolkit==3.0.43 + - protobuf==3.20.1 + - psutil==5.9.8 + - ptyprocess==0.7.0 + - pure-eval==0.2.2 + - pyasn1==0.5.1 + - pyasn1-modules==0.3.0 + - pydantic==2.6.0 + - pydantic-core==2.16.1 + - pydub==0.25.1 + - pygments==2.17.2 + - pyparsing==3.1.1 + - python-dateutil==2.8.2 + - python-multipart==0.0.6 + - pytorch-lightning==1.9.0 + - pytorchwildlife==1.0.1.1 + - pytz==2023.4 + - pyyaml==6.0.1 + - referencing==0.33.0 + - requests==2.31.0 + - requests-oauthlib==1.3.1 + - rich==13.7.0 + - rpds-py==0.17.1 + - rsa==4.9 + - scikit-learn==1.2.0 + - scipy==1.10.1 + - seaborn==0.13.2 + - semantic-version==2.10.0 + - shellingham==1.5.4 + - six==1.16.0 + - sniffio==1.3.0 + - stack-data==0.6.3 + - starlette==0.35.1 + - supervision==0.16.0 + - tensorboard==2.14.0 + - tensorboard-data-server==0.7.2 + - termcolor==2.4.0 + - thop==0.1.1-2209072238 + - threadpoolctl==3.2.0 + - tomlkit==0.12.0 + - toolz==0.12.1 + - torch==1.10.1 + - torchaudio==0.10.1 + - torchmetrics==1.3.0.post0 + - torchvision==0.11.2 + - tqdm==4.66.1 + - traitlets==5.14.1 + - typer==0.9.0 + - typing-extensions==4.9.0 + - tzdata==2023.4 + - ultralytics-yolov5==0.1.1 + - urllib3==2.2.0 + - uvicorn==0.27.0.post1 + - wcwidth==0.2.13 + - websockets==11.0.3 + - werkzeug==3.0.1 + - yarl==1.9.4 + - zipp==3.17.0 +prefix: /home/andreshernandezcelisadeccoc/.conda/envs/PT_Finetuning diff --git a/main.py b/main.py new file mode 100644 index 0000000..ec66333 --- /dev/null +++ b/main.py @@ -0,0 +1,182 @@ +# %% +# Importing libraries +import os +import yaml +import typer +from munch import Munch +# %% +import torch +import pytorch_lightning as pl +from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.callbacks import LearningRateMonitor +from pytorch_lightning.loggers import CSVLogger, CometLogger, TensorBoardLogger, WandbLogger +# %% +from src import algorithms +from src import datasets +# %% +from src.utils import batch_detection_cropping +from src.utils import data_splitting + +app = typer.Typer(pretty_exceptions_short=True, pretty_exceptions_show_locals=False) +# %% +@app.command() +def main( + config:str='./configs/config.yaml', + project:str='Custom-classification', + gpus:str='0', + logger_type:str='csv', + evaluate:str=None, + np_threads:str='32', + session:int=0, + seed:int=0, + dev:bool=False, + val:bool=False, + test:bool=False, + predict:bool=False, + predict_root:str="" + ): + """ + Main function for training or evaluating a ResNet model (50 or 18) using PyTorch Lightning. + It loads configurations, initializes the model, logger, and other components based on provided arguments. + + Args: + config (str): Path to the configuration file. + project (str): Name of the project for logging. + gpus (str): Comma-separated GPU ids for training. + logger_type (str): Type of logger to use (wandb, comet, tensorboard, csv). + evaluate (str): Path to the model checkpoint for evaluation. + np_threads (str): Number of numpy threads to use. + session (int): Session number for logging purposes. + seed (int): Random seed for reproducibility. + dev (bool): Development mode flag. + val (bool): Validation mode flag. + predict (bool): Prediction mode flag. + predict_root (str): Root directory for prediction outputs. + """ + + # GPU configuration: set up GPUs based on availability and user specification + gpus = gpus if torch.cuda.is_available() else None + gpus = [int(i) for i in gpus.split(',')] + + # Environment variable setup for numpy multi-threading. It is important to avoid cpu and ram issues. + os.environ["OMP_NUM_THREADS"] = str(np_threads) + os.environ["OPENBLAS_NUM_THREADS"] = str(np_threads) + os.environ["MKL_NUM_THREADS"] = str(np_threads) + os.environ["VECLIB_MAXIMUM_THREADS"] = str(np_threads) + os.environ["NUMEXPR_NUM_THREADS"] = str(np_threads) + # Load and set configurations from the YAML file + with open(config) as f: + conf = Munch(yaml.load(f, Loader=yaml.FullLoader)) + conf.evaluate = evaluate + conf.val = val + conf.test = test + conf.predict = predict + conf.predict_root = predict_root + + # Set a global seed for reproducibility + pl.seed_everything(seed) + + # If the annotation directory does not have a data split, split the data first + if conf.split_data: + # Replace annotation dir from config with the directory containing the split files + conf.annotation_dir = os.path.dirname(conf.split_path) + # Split the data according to the split type + if conf.split_type == 'location': + data_splitting.split_by_location(conf.split_path, conf.annotation_dir, conf.test_size, conf.val_size) + elif conf.split_type == 'sequence': + data_splitting.split_by_seq(conf.split_path, conf.annotation_dir, conf.test_size, conf.val_size) + elif conf.split_type == 'random': + data_splitting.create_splits(conf.split_path, conf.annotation_dir, conf.test_size, conf.val_size) + else: + raise ValueError('Invalid split type: {}. Available options: random, location, sequence.'.format(conf.split_type)) + + if not conf.predict: + # Get the path to the annotation files, and we only want to do this if we are not predicting + if conf.test: + test_annotations = os.path.join(conf.dataset_root, 'test_annotations.csv') + # Crop test data (most likely we don't need this) + batch_detection_cropping.batch_detection_cropping(conf.dataset_root, os.path.join(conf.dataset_root, "cropped_resized"), test_annotations) + else: + train_annotations = os.path.join(conf.dataset_root, 'train_annotations.csv') + val_annotations = os.path.join(conf.dataset_root, 'val_annotations.csv') + # Crop training data + batch_detection_cropping.batch_detection_cropping(conf.dataset_root, os.path.join(conf.dataset_root, "cropped_resized"), train_annotations) + # Crop validation data + batch_detection_cropping.batch_detection_cropping(conf.dataset_root, os.path.join(conf.dataset_root, "cropped_resized"), val_annotations) + + # Dataset and algorithm loading based on the configuration + dataset = datasets.__dict__[conf.dataset_name](conf=conf) + learner = algorithms.__dict__[conf.algorithm](conf=conf, + train_class_counts=dataset.train_class_counts, + id_to_labels=dataset.id_to_labels) + + # Logger setup based on the specified logger type + log_folder = 'log_dev' if dev else 'log' + logger = None + if logger_type == 'csv': + logger = CSVLogger( + save_dir='./{}/{}/{}'.format(log_folder, conf.log_dir, conf.algorithm), + prefix=project, + name='{}_{}'.format(conf.algorithm, conf.conf_id), + version=session + ) + elif logger_type == 'tensorboard': + logger = TensorBoardLogger( + save_dir='./{}/{}/{}'.format(log_folder, conf.log_dir, conf.algorithm), + prefix=project, + name='{}_{}'.format(conf.algorithm, conf.conf_id), + version=session + ) + elif logger_type == 'comet': + logger = CometLogger( + api_key=os.environ.get("COMET_API_KEY"), + save_dir='./{}/{}/{}'.format(log_folder, conf.log_dir, conf.algorithm), + project_name=project, + experiment_name='{}_{}_{}'.format(conf.algorithm, conf.conf_id, session), + ) + elif logger_type == 'wandb': + logger = WandbLogger( + save_dir='./{}/{}/{}'.format(log_folder, conf.log_dir, conf.algorithm), + project=project, + name='{}_{}_{}'.format(conf.algorithm, conf.conf_id, session), + ) + + # Callbacks for model checkpointing and learning rate monitoring + weights_folder = 'weights_dev' if dev else 'weights' + checkpoint_callback = ModelCheckpoint( + monitor='valid_mac_acc', mode='max', dirpath='./{}/{}/{}'.format(weights_folder, conf.log_dir, conf.algorithm), + save_top_k=1, filename='{}-{}'.format(conf.conf_id, session) + '-{epoch:02d}-{valid_mac_acc:.2f}', verbose=True + ) + + lr_monitor = LearningRateMonitor(logging_interval='step') + + # Trainer configuration in PyTorch Lightning + trainer = pl.Trainer( + max_epochs=conf.num_epochs, + check_val_every_n_epoch=1, + log_every_n_steps = conf.log_interval, + accelerator='gpu', + devices=gpus, + logger=None if evaluate is not None else logger, + callbacks=[lr_monitor, checkpoint_callback], + strategy='auto', + num_sanity_val_steps=0, + profiler=None + ) + # Training, validation, or evaluation execution based on the mode + if evaluate is not None: + if val: + trainer.validate(learner, dataloaders=[dataset.val_dataloader()], ckpt_path=evaluate) + elif predict: + trainer.predict(learner, dataloaders=[dataset.predict_dataloader()], ckpt_path=evaluate) + elif test: + trainer.test(learner, dataloaders=[dataset.test_dataloader()], ckpt_path=evaluate) + else: + print('Invalid mode for evaluation.') + else: + trainer.fit(learner, datamodule=dataset) +# %% +if __name__ == '__main__': + app() + +# %% diff --git a/mkdocs.yml b/mkdocs.yml new file mode 100644 index 0000000..40ff7d7 --- /dev/null +++ b/mkdocs.yml @@ -0,0 +1,104 @@ +site_name: MegaDetector-Classifier +site_url: https://microsoft.github.io/MegaDetector-Classifier/ +site_description: "MegaDetector-Classifier — Microsoft AI for Good Lab's open-source classification fine-tuning tool for camera-trap species identification. Part of the PyTorch-Wildlife ecosystem." +docs_dir: docs +site_dir: site +repo_url: https://github.com/microsoft/MegaDetector-Classifier +repo_name: microsoft/MegaDetector-Classifier +copyright: Copyright (c) 2023 Microsoft Corporation + +theme: + name: material + favicon: https://zenodo.org/records/15376499/files/cat.png + logo: https://zenodo.org/records/15376499/files/cat.png + icon: + menu: material/menu + alternate: material/translate + search: material/magnify + share: material/share-variant + close: material/close + top: material/arrow-up + edit: material/pencil + view: material/eye + repo: fontawesome/brands/git-alt + admonition: + note: material/note + abstract: material/lightbulb + info: material/information + tip: material/lightbulb-on + success: material/check-circle + question: material/help-circle + warning: material/alert + failure: material/alert-circle + danger: material/alert-octagon + bug: material/bug + example: material/format-list-bulleted + quote: material/format-quote-open + tag: + default: material/tag + info: material/information + warning: material/alert + danger: material/alert-octagon + previous: material/arrow-left + next: material/arrow-right + palette: + - media: "(prefers-color-scheme: light)" + scheme: default + primary: green + accent: deep orange + toggle: + icon: material/paw-off + name: Switch to dark mode + + - media: "(prefers-color-scheme: dark)" + scheme: slate + primary: teal + accent: deep orange + toggle: + icon: material/paw + name: Switch to light mode + features: + - navigation.tracking + - navigation.tabs + - navigation.sections + - navigation.path + - toc.follow + - navigation.top + - search.suggest + - search.share + - navigation.footer + +nav: + - MegaDetector-Classifier: + - Overview: index.md + - Installation: installation.md + - Usage Guide: usage.md + - Tags: tags.md + - Cite Us: cite.md + - Developer Guide: build_mkdocs.md + +markdown_extensions: + - admonition + - pymdownx.details + - pymdownx.superfences + - md_in_html + - attr_list + - sane_lists + - pymdownx.tabbed: + alternate_style: true + - toc: + permalink: true + - pymdownx.emoji: + emoji_index: !!python/name:material.extensions.emoji.twemoji + emoji_generator: !!python/name:material.extensions.emoji.to_svg + - pymdownx.snippets: + check_paths: true + - pymdownx.highlight: + anchor_linenums: true + line_spans: __span + pygments_lang_class: true + +plugins: + - search + - meta + - tags diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..c20d78a --- /dev/null +++ b/requirements.txt @@ -0,0 +1,5 @@ +PytorchWildlife +scikit_learn +lightning +munch +typer \ No newline at end of file diff --git a/src/algorithms/__init__.py b/src/algorithms/__init__.py new file mode 100644 index 0000000..fa84a4a --- /dev/null +++ b/src/algorithms/__init__.py @@ -0,0 +1,3 @@ +from . import utils +from .plain import * + diff --git a/src/algorithms/plain.py b/src/algorithms/plain.py new file mode 100644 index 0000000..be3212a --- /dev/null +++ b/src/algorithms/plain.py @@ -0,0 +1,285 @@ +import os +import numpy as np +import json +from datetime import datetime +from tqdm import tqdm +import random + +import torch +import torch.optim as optim +import pytorch_lightning as pl + +from .utils import acc +from src import models + + +__all__ = [ + 'Plain' +] + +class Plain(pl.LightningModule): + """ + Defines the architecture for training a model using PyTorch Lightning. + + This class inherits from PyTorch Lightning's LightningModule and sets up the model, optimizers, + and training/validation/testing steps for the training process. + """ + + name = 'Plain' + + def __init__(self, conf, train_class_counts, id_to_labels, **kwargs): + """ + Initializes the Plain model. + + Args: + conf: Configuration object with model parameters. + train_class_counts: Counts of training classes. + id_to_labels: Mapping from IDs to label names. + **kwargs: Additional keyword arguments. + """ + super().__init__() + self.hparams.update(conf.__dict__) + self.save_hyperparameters(ignore=['conf', 'train_class_counts']) + self.train_class_counts = train_class_counts + self.id_to_labels = id_to_labels + self.net = models.__dict__[self.hparams.model_name](num_cls=self.hparams.num_classes, + num_layers=self.hparams.num_layers) + + def configure_optimizers(self): + """ + Configures the optimizers and learning rate schedulers. + + Returns: + Tuple[List, List]: A tuple containing the list of optimizers and the list of learning rate schedulers. + """ + # Define parameters for the optimizer + net_optim_params_list = [ + # Optimizer parameters for feature extraction + {'params': self.net.feature.parameters(), + 'lr': self.hparams.lr_feature, + 'momentum': self.hparams.momentum_feature, + 'weight_decay': self.hparams.weight_decay_feature}, + # Optimizer parameters for the classifier + {'params': self.net.classifier.parameters(), + 'lr': self.hparams.lr_classifier, + 'momentum': self.hparams.momentum_classifier, + 'weight_decay': self.hparams.weight_decay_classifier} + ] + # Setup optimizer and optimizer scheduler + optimizer = torch.optim.SGD(net_optim_params_list) + scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=self.hparams.step_size, gamma=self.hparams.gamma) + return [optimizer], [scheduler] + + def on_train_start(self): + """ + Hook function called at the start of training. Initializes best accuracy and the network. + """ + self.best_acc = 0 + self.net.feat_init() + self.net.setup_criteria() + + def training_step(self, batch, batch_idx): + """ + Training step for each batch. + + Args: + batch: The current batch of data. + batch_idx: The index of the current batch. + + Returns: + Tensor: The loss for the current training step. + """ + data, label_ids = batch[0], batch[1] + + # Forward pass + feats = self.net.feature(data) + logits = self.net.classifier(feats) + # Calculate loss + loss = self.net.criterion_cls(logits, label_ids) + self.log("train_loss", loss) + + return loss + + def on_validation_start(self): + """ + Hook function called at the start of validation. Initializes storage for validation outputs. + """ + self.val_st_outs = [] + + def validation_step(self, batch, batch_idx): + """ + Validation step for each batch. + + Args: + batch: The current batch of data. + batch_idx: The index of the current batch. + """ + data, label_ids = batch[0], batch[1] + # Forward pass + feats = self.net.feature(data) + logits = self.net.classifier(feats) + preds = logits.argmax(dim=1) + + self.val_st_outs.append((preds.detach().cpu().numpy(), + label_ids.detach().cpu().numpy())) + + def on_validation_epoch_end(self): + """ + Hook function called at the end of the validation epoch. Aggregates and logs validation results. + """ + total_preds = np.concatenate([x[0] for x in self.val_st_outs], axis=0) + total_label_ids = np.concatenate([x[1] for x in self.val_st_outs], axis=0) + self.eval_logging(total_preds, total_label_ids) + + def on_test_start(self): + """ + Hook function called at the start of testing. Initializes storage for test outputs. + """ + self.te_st_outs = [] + + def test_step(self, batch, batch_idx): + """ + Test step for each batch. + + Args: + batch: The current batch of data, including metadata. + batch_idx: The index of the current batch. + """ + data, label_ids, labels, file_ids = batch + # Forward pass + feats = self.net.feature(data) + logits = self.net.classifier(feats) + preds = logits.argmax(dim=1) + + self.te_st_outs.append((preds.detach().cpu().numpy(), + label_ids.detach().cpu().numpy(), + feats.detach().cpu().numpy(), + logits.detach().cpu().numpy(), + labels, file_ids + )) + + + def on_test_epoch_end(self): + """ + Hook function called at the end of the test epoch. Aggregates and logs test results, and saves output. + """ + # Concatenate outputs from all test steps + total_preds = np.concatenate([x[0] for x in self.te_st_outs], axis=0) + total_label_ids = np.concatenate([x[1] for x in self.te_st_outs], axis=0) + total_feats = np.concatenate([x[2] for x in self.te_st_outs], axis=0) + total_logits = np.concatenate([x[3] for x in self.te_st_outs], axis=0) + total_labels = np.concatenate([x[4] for x in self.te_st_outs], axis=0) + total_file_ids = np.concatenate([x[5] for x in self.te_st_outs], axis=0) + + # Calculate the metrics and save the output + self.eval_logging(total_preds[total_label_ids != -1], + total_label_ids[total_label_ids != -1], + print_class_acc=False) + + output_path = self.hparams.evaluate.replace('.ckpt', 'eval.npz') + np.savez(output_path, preds=total_preds, label_ids=total_label_ids, feats=total_feats, + logits=total_logits, labels=total_labels, file_ids=total_file_ids) + print('Test output saved to {}.'.format(output_path)) + + def on_predict_start(self): + """ + Hook function called at the start of prediction. Initializes storage for prediction outputs. + """ + self.pr_st_outs = [] + + def predict_step(self, batch, batch_idx): + """ + Prediction step for each batch. + + Args: + batch: The current batch of data, including metadata. + batch_idx: The index of the current batch. + """ + data, file_ids = batch + # Forward pass + feats = self.net.feature(data) + logits = self.net.classifier(feats) + preds = logits.argmax(dim=1) + probs = torch.softmax(logits, dim=1).max(dim=1)[0] + + self.pr_st_outs.append((preds.detach().cpu().numpy(), + feats.detach().cpu().numpy(), + logits.detach().cpu().numpy(), + probs.detach().cpu().numpy(), + file_ids + )) + + + def on_predict_epoch_end(self): + """ + Hook function called at the end of the predict epoch. Aggregates and saves prediction outputs. + """ + # Concatenate outputs from all predict steps + total_preds = np.concatenate([x[0] for x in self.pr_st_outs], axis=0) + total_feats = np.concatenate([x[1] for x in self.pr_st_outs], axis=0) + total_logits = np.concatenate([x[2] for x in self.pr_st_outs], axis=0) + total_probs = np.concatenate([x[3] for x in self.pr_st_outs], axis=0) + total_file_ids = np.concatenate([x[4] for x in self.pr_st_outs], axis=0) + + json_output = [] + for i in range(len(total_preds)): + json_output.append({ + "marker_id": "", + "survey_pic_id": total_file_ids[i], + "marker_confidence": float(total_probs[i]), + "marker_gear_type": "ghostnet" if total_preds[i] == 1 else "neg", + "marker_bounding_polygon": "", + "marker_status": "unverified", + "marker_ai_model": "" + }) + + output_path_full = self.hparams.evaluate.replace('.ckpt', '_predict.npz') + np.savez(output_path_full, preds=total_preds, feats=total_feats, + logits=total_logits, file_ids=total_file_ids) + print('Predict output saved to {}.'.format(output_path_full)) + + output_path_json = self.hparams.evaluate.replace('.ckpt', '_predict.json') + json.dump(json_output, open(output_path_json, 'w')) + print('Predict output json saved to {}.'.format(output_path_json)) + + + def eval_logging(self, preds, labels, print_class_acc=False): + """ + Logs evaluation metrics such as accuracy. + + Args: + preds: Predictions from the model. + labels: Ground truth labels. + print_class_acc (bool): Flag to print class-wise accuracy. + """ + class_acc, mac_acc, mic_acc = acc(preds, labels) + unique_eval_labels = np.unique(labels) + + self.log("valid_mac_acc", mac_acc * 100) + self.log("valid_mic_acc", mic_acc * 100) + + if print_class_acc: + + if self.train_class_counts: + acc_list = [(class_acc[i], unique_eval_labels[i], + self.id_to_labels[unique_eval_labels[i]], + self.train_class_counts[unique_eval_labels[i]]) + for i in range(len(class_acc))] + + print('\n') + for i in range(len(class_acc)): + info = '{:>20} ({:<3}, tr {:>3}) Acc: '.format(acc_list[i][2], + acc_list[i][1], + acc_list[i][3]) + info += '{:.2f}'.format(acc_list[i][0] * 100) + print(info) + else: + acc_list = [(class_acc[i], unique_eval_labels[i], + self.id_to_labels[unique_eval_labels[i]]) + for i in range(len(class_acc))] + + print('\n') + for i in range(len(class_acc)): + info = '{:>20} ({:<3}) Acc: '.format(acc_list[i][2], acc_list[i][1]) + info += '{:.2f}'.format(acc_list[i][0] * 100) + print(info) diff --git a/src/algorithms/utils.py b/src/algorithms/utils.py new file mode 100644 index 0000000..11a14da --- /dev/null +++ b/src/algorithms/utils.py @@ -0,0 +1,32 @@ +from sklearn.metrics import confusion_matrix + +def acc(preds, labels): + """ + Calculate the accuracy metrics based on predictions and true labels. + + This function computes the confusion matrix and derives three types of accuracies: + class-wise accuracy (cls_acc), micro accuracy (mic_acc), and macro accuracy (mac_acc). + + Args: + preds (array-like): Predicted labels. + labels (array-like): True labels. + + Returns: + tuple: A tuple containing: + - cls_acc (ndarray): Class-wise accuracy. + - mac_acc (float): Macro accuracy (average of class-wise accuracies). + - mic_acc (float): Micro accuracy (overall accuracy). + """ + # Compute the confusion matrix from true labels and predictions + matrix = confusion_matrix(labels, preds) + + # Calculate class-wise accuracy (accuracy for each class) + cls_acc = matrix.diagonal() / matrix.sum(axis=1) + + # Calculate micro accuracy (overall accuracy) + mic_acc = matrix.diagonal().sum() / matrix.sum() + + # Calculate macro accuracy (mean of class-wise accuracies) + mac_acc = cls_acc.mean() + + return cls_acc, mac_acc, mic_acc diff --git a/src/datasets/__init__.py b/src/datasets/__init__.py new file mode 100644 index 0000000..a04dcce --- /dev/null +++ b/src/datasets/__init__.py @@ -0,0 +1 @@ +from .custom import * \ No newline at end of file diff --git a/src/datasets/custom.py b/src/datasets/custom.py new file mode 100644 index 0000000..cbac698 --- /dev/null +++ b/src/datasets/custom.py @@ -0,0 +1,255 @@ +# Import necessary libraries +import os +from glob import glob +import numpy as np +import pandas as pd +import torch +from PIL import Image +from torchvision import transforms +from torch.utils.data import Dataset, DataLoader +import pytorch_lightning as pl + +# Exportable class names for external use +__all__ = [ + 'Custom_Crop' +] + +# Define the allowed image extensions +IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp") + +def has_file_allowed_extension(filename: str, extensions: tuple) -> bool: + """Checks if a file is an allowed extension.""" + return filename.lower().endswith(extensions if isinstance(extensions, str) else tuple(extensions)) + +def is_image_file(filename: str) -> bool: + """Checks if a file is an allowed image extension.""" + return has_file_allowed_extension(filename, IMG_EXTENSIONS) + +# Define normalization mean and standard deviation for image preprocessing +mean = [0.485, 0.456, 0.406] +std = [0.229, 0.224, 0.225] + +# Define data transformations for training and validation datasets +data_transforms = { + 'train': transforms.Compose([ + transforms.RandomResizedCrop((224, 224), scale=(0.7, 1.0), ratio=(0.8, 1.2)), + transforms.RandomHorizontalFlip(p=0.5), + transforms.RandomVerticalFlip(p=0.5), + transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2), + transforms.ToTensor(), + transforms.Normalize(mean, std) + ]), + 'val': transforms.Compose([ + transforms.Resize((224, 224)), + transforms.ToTensor(), + transforms.Normalize(mean, std) + ]), +} + +class Custom_Base_DS(Dataset): + """ + Base dataset class for handling custom datasets. + + Attributes: + rootdir (str): Root directory containing the dataset. + transform (callable, optional): Transformations to be applied to each data sample. + predict (bool): Flag to indicate if the dataset is used for prediction. + """ + + def __init__(self, rootdir, transform=None, predict=False): + """ + Initialize the Custom_Base_DS with the directory, transformations, and mode. + + Args: + rootdir (str): Directory containing the dataset. + transform (callable, optional): Transformations to be applied to each data sample. + predict (bool): Flag to indicate if the dataset is used for prediction. + """ + self.rootdir = rootdir + self.transform = transform + self.predict = predict + self.data = [] + self.label_ids = [] + self.labels = [] + self.seq_ids = [] + + def load_data(self): + """ + Load data from the specified directory. Differentiates between prediction and training/validation mode. + """ + if self.predict: + # Load data for prediction + # self.data = glob(os.path.join(self.img_root,"*.{}".format(self.extension))) + self.data = [os.path.join(dp, f) for dp, dn, filenames in os.walk(self.img_root) for f in filenames if is_image_file(f)] # dp: directory path, dn: directory name, f: filename + else: + # Load data for training/validation + self.data = list(self.ann['path']) + self.label_ids = list(self.ann['classification']) + self.labels = list(self.ann['label']) + print('Number of images loaded: ', len(self.data)) + + def class_counts_cal(self): + """ + Calculate the count of each class in the dataset. + + Returns: + tuple: Unique label IDs and their respective counts. + """ + unique_label_ids, unique_counts = np.unique(self.label_ids, return_counts=True) + return unique_label_ids, unique_counts + + def __len__(self): + """ + Return the total number of items in the dataset. + + Returns: + int: Total number of items. + """ + return len(self.data) + + def __getitem__(self, index): + """ + Retrieve an item by its index. + + Args: + index (int): Index of the item to be retrieved. + + Returns: + tuple: Depending on the mode, returns different tuples containing the image and additional information. + """ + file_id = self.data[index] + file_dir = os.path.join(self.img_root, file_id) if not self.predict else file_id + + with open(file_dir, 'rb') as f: + sample = Image.open(f).convert('RGB') + + if self.transform is not None: + sample = self.transform(sample) + + if self.predict: + return sample, file_id + + label_id = self.label_ids[index] + label = self.labels[index] + + return sample, label_id, label, file_dir + + +class Custom_Crop_DS(Custom_Base_DS): + """ + Dataset class for handling custom cropped datasets. + + Inherits from Custom_Base_DS and includes specific handling for cropped data. + """ + + def __init__(self, rootdir, dset='train', transform=None): + """ + Initialize the Custom_Crop_DS with the dataset directory, type, and transformations. + + Args: + rootdir (str): Directory containing the dataset. + dset (str): Type of dataset (train, val, test, predict). + transform (callable, optional): Transformations to be applied to each data sample. + """ + self.predict = dset == 'predict' + super().__init__(rootdir=rootdir, transform=transform, predict=self.predict) + self.img_root = rootdir if self.predict else os.path.join(self.rootdir, 'cropped_resized') + if not self.predict: + self.ann = pd.read_csv(os.path.join(self.rootdir, 'cropped_resized', '{}_annotations_cropped.csv' + .format('test' if dset == 'test' else dset))) + self.load_data() + + +class Custom_Base(pl.LightningDataModule): + """ + Base data module for handling custom datasets in PyTorch Lightning. + + Manages the data loading pipeline for training, validation, testing, and prediction. + """ + + ds = None + + def __init__(self, conf): + """ + Initialize the Custom_Base data module with configuration. + + Args: + conf (object): Configuration object containing dataset paths and other settings. + """ + super().__init__() + self._log_hyperparams = True + self.id_to_labels = None # We don't need this for evaluations. We should save this in model weights in the future + self.train_class_counts = None + + self.conf = conf + + print('Loading datasets...') + # Load datasets for different modes (training, validation, testing, prediction) + if self.conf.predict: + self.dset_pr = self.ds(rootdir=self.conf.predict_root, dset='predict', transform=data_transforms['val']) + elif self.conf.test: + self.dset_te = self.ds(rootdir=self.conf.dataset_root, dset='test', transform=data_transforms['val']) + self.id_to_labels = {i: l for i, l in np.unique(pd.Series(zip(self.dset_te.label_ids, self.dset_te.labels)))} + else: + self.dset_tr = self.ds(rootdir=self.conf.dataset_root, dset='train', transform=data_transforms['train']) + self.dset_val = self.ds(rootdir=self.conf.dataset_root, dset='val', transform=data_transforms['val']) + + self.id_to_labels = {i: l for i, l in np.unique(pd.Series(zip(self.dset_tr.label_ids, self.dset_tr.labels)))} + # Calculate class counts and label mappings + self.unique_label_ids, self.train_class_counts = self.dset_tr.class_counts_cal() + + print('Datasets loaded.') + + def train_dataloader(self): + """ + Create a DataLoader for the training dataset. + + Returns: + DataLoader: DataLoader for the training dataset. + """ + return DataLoader(self.dset_tr, batch_size=self.conf.batch_size, shuffle=True, pin_memory=True, num_workers=self.conf.num_workers, drop_last=False) + + def val_dataloader(self): + """ + Create a DataLoader for the validation dataset. + + Returns: + DataLoader: DataLoader for the validation dataset. + """ + return DataLoader(self.dset_val, batch_size=self.conf.batch_size, shuffle=False, pin_memory=True, num_workers=self.conf.num_workers, drop_last=False) + + def test_dataloader(self): + """ + Create a DataLoader for the testing dataset. + + Returns: + DataLoader: DataLoader for the testing dataset. + """ + return DataLoader(self.dset_te, batch_size=256, shuffle=False, pin_memory=True, num_workers=self.conf.num_workers, drop_last=False) + + def predict_dataloader(self): + """ + Create a DataLoader for the prediction dataset. + + Returns: + DataLoader: DataLoader for the prediction dataset. + """ + return DataLoader(self.dset_pr, batch_size=64, shuffle=False, pin_memory=True, num_workers=self.conf.num_workers, drop_last=False) + + +class Custom_Crop(Custom_Base): + """ + Custom data module specifically for cropped datasets in PyTorch Lightning. + + Inherits from Custom_Base and specifies the dataset type as Custom_Crop_DS. + """ + + def __init__(self, conf): + """ + Initialize the Custom_Crop data module with configuration. + + Args: + conf (object): Configuration object containing dataset paths and other settings. + """ + self.ds = Custom_Crop_DS + super().__init__(conf=conf) diff --git a/src/models/__init__.py b/src/models/__init__.py new file mode 100644 index 0000000..70be09a --- /dev/null +++ b/src/models/__init__.py @@ -0,0 +1 @@ +from .plain_resnet import * \ No newline at end of file diff --git a/src/models/plain_resnet.py b/src/models/plain_resnet.py new file mode 100644 index 0000000..d664c91 --- /dev/null +++ b/src/models/plain_resnet.py @@ -0,0 +1,169 @@ +import os +import copy +from collections import OrderedDict +import torch +import torch.nn as nn +from torchvision.models.resnet import BasicBlock, Bottleneck +from torchvision.models.resnet import * + + +try: + from torch.hub import load_state_dict_from_url +except ImportError: + from torch.utils.model_zoo import load_state_dict_from_url +# Exportable class names for external use +__all__ = [ + 'PlainResNetClassifier' +] + +model_urls = { + 'resnet18': 'https://download.pytorch.org/models/resnet18-f37072fd.pth', + 'resnet50': 'https://download.pytorch.org/models/resnet50-0676ba61.pth' +} + +class ResNetBackbone(ResNet): + """ + Custom ResNet backbone class for feature extraction. + + Inherits from the torchvision ResNet class and allows customization of the architecture. + """ + + def __init__( + self, + block, + layers, + zero_init_residual=False, + groups=1, + width_per_group=64, + replace_stride_with_dilation=None, + norm_layer=None, + ): + """ + Initialize the ResNet backbone. + + Args: + block (nn.Module): Type of block to use (BasicBlock or Bottleneck). + layers (list of int): Number of layers in each block. + zero_init_residual (bool): Zero-initialize the last BN in each residual branch. + groups (int): Number of groups for group normalization. + width_per_group (int): Width per group. + replace_stride_with_dilation (list of bool or None): Use dilation instead of stride. + norm_layer (callable or None): Norm layer to use. + """ + super(ResNetBackbone, self).__init__( + block=block, + layers=layers, + zero_init_residual=zero_init_residual, + groups=groups, + width_per_group=width_per_group, + replace_stride_with_dilation=replace_stride_with_dilation, + norm_layer=norm_layer, + ) + + def _forward_impl(self, x): + """ + Forward pass implementation for the ResNet backbone. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor after applying the ResNet backbone. + """ + # Applying the ResNet layers and operations + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + x = torch.flatten(x, 1) + return x + + +class PlainResNetClassifier(nn.Module): + """ + Custom ResNet classifier class. + + Extends nn.Module and provides a complete ResNet-based classifier, including feature extraction and classification layers. + """ + + name = 'PlainResNetClassifier' + + def __init__(self, num_cls=10, num_layers=18): + """ + Initialize the PlainResNetClassifier. + + Args: + num_cls (int): Number of classes for the classifier. + num_layers (int): Number of layers in the ResNet model (e.g., 18, 50). + """ + super(PlainResNetClassifier, self).__init__() + self.num_cls = num_cls + self.num_layers = num_layers + self.feature = None + self.classifier = None + self.criterion_cls = None + + # Initialize the network with the specified settings + self.setup_net() + + def setup_net(self): + """ + Set up the ResNet network and initialize its weights. + """ + kwargs = {} + + # Selecting the appropriate ResNet architecture and pre-trained weights + if self.num_layers == 18: + block = BasicBlock + layers = [2, 2, 2, 2] + #self.pretrained_weights = ResNet18_Weights.IMAGENET1K_V1 + self.pretrained_weights = state_dict = load_state_dict_from_url(model_urls['resnet18'], + progress=True) + elif self.num_layers == 50: + block = Bottleneck + layers = [3, 4, 6, 3] + #self.pretrained_weights = ResNet50_Weights.IMAGENET1K_V1 + self.pretrained_weights = state_dict = load_state_dict_from_url(model_urls['resnet50'], + progress=True) + else: + raise Exception('ResNet Type not supported.') + + # Constructing the feature extractor and classifier + self.feature = ResNetBackbone(block, layers, **kwargs) + self.classifier = nn.Linear(512 * block.expansion, self.num_cls) + + def setup_criteria(self): + """ + Set up the criterion for the classifier. + """ + # Criterion for binary classification + self.criterion_cls = nn.CrossEntropyLoss() + + def feat_init(self): + """ + Initialize the feature extractor with pre-trained weights. + """ + # Load pre-trained weights and adjust for the current model + #init_weights = self.pretrained_weights.get_state_dict(progress=True) + init_weights = self.pretrained_weights + init_weights = OrderedDict({k.replace('module.', '').replace('feature.', ''): init_weights[k] + for k in init_weights}) + + # Load the weights into the feature extractor + self.feature.load_state_dict(init_weights, strict=False) + + # Identify missing and unused keys in the loaded weights + load_keys = set(init_weights.keys()) + self_keys = set(self.feature.state_dict().keys()) + + missing_keys = self_keys - load_keys + unused_keys = load_keys - self_keys + print('missing keys: {}'.format(sorted(list(missing_keys)))) + print('unused_keys: {}'.format(sorted(list(unused_keys)))) diff --git a/src/utils/__init__.py b/src/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/utils/batch_detection_cropping.py b/src/utils/batch_detection_cropping.py new file mode 100644 index 0000000..a893014 --- /dev/null +++ b/src/utils/batch_detection_cropping.py @@ -0,0 +1,33 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +""" Demo for batch detection, cropping and resizing""" + +#%% +# PyTorch imports +import torch +# Importing the model, dataset, transformations and utility functions from PytorchWildlife +from PytorchWildlife.models import detection as pw_detection +from PytorchWildlife.data import transforms as pw_trans +from PytorchWildlife.data import datasets as pw_data +# Importing the utility function for saving cropped images +from src.utils import utils + +def batch_detection_cropping(folder_path, output_path, annotation_file): + # Setting the device to use for computations ('cuda' indicates GPU) + DEVICE = "cuda" if torch.cuda.is_available() else "cpu" + + # Initializing the MegaDetectorV5 model for image detection + detection_model = pw_detection.MegaDetectorV5(device=DEVICE, pretrained=True) + + """ Batch-detection demo """ + # Performing batch detection on the images + results = detection_model.batch_image_detection(folder_path) + + # Saving the detected objects as cropped images + crop_annotation_path = utils.save_crop_images(results, output_path, annotation_file) + return crop_annotation_path + + + +# %% diff --git a/src/utils/data_splitting.py b/src/utils/data_splitting.py new file mode 100644 index 0000000..f667c62 --- /dev/null +++ b/src/utils/data_splitting.py @@ -0,0 +1,147 @@ +## DATA SPLITTING + +import pandas as pd +from sklearn.model_selection import train_test_split +import os +from tqdm import tqdm + +def create_splits(csv_path, output_folder, test_size=0.2, val_size=0.1): + """ + Create stratified training, validation, and testing splits. + + Args: + - csv_path (str): Path to the csv containing the annotations. + - output_folder (str): Destination directory to save the annotation split csv files. + - test_size (float): Proportion of the dataset to include in the test split. + - val_size (float): Proportion of the training dataset to include in the validation split. + + Returns: + - A tuple of DataFrames: (train_set, val_set, test_set) + - Saves the splits into separate csv files in the output_folder. + """ + # Load the data from the csv file + data = pd.read_csv(csv_path) + # Separate the features and the targets + X = data[['path','label']] + y = data['classification'] + + # First split to separate out the test set + X_temp, X_test, y_temp, y_test = train_test_split(X, y, test_size=test_size, stratify=y, random_state=42) + + # Adjust val_size to account for the initial split + val_size_adjusted = val_size / (1 - test_size) + + # Second split to separate out the validation set + X_train, X_val, y_train, y_val = train_test_split(X_temp, y_temp, test_size=val_size_adjusted, stratify=y_temp, random_state=42) + + # Combine features, labels, and classification back into dataframes + train_set = pd.concat([X_train.reset_index(drop=True), y_train.reset_index(drop=True)], axis=1) + val_set = pd.concat([X_val.reset_index(drop=True), y_val.reset_index(drop=True)], axis=1) + test_set = pd.concat([X_test.reset_index(drop=True), y_test.reset_index(drop=True)], axis=1) + + # Create the output directory in case that it does not exist + os.makedirs(output_folder, exist_ok=True) + + # Save the splits to new CSV files + train_set.to_csv(os.path.join(output_folder,'train_annotations.csv'), index=False) + val_set.to_csv(os.path.join(output_folder,'val_annotations.csv'), index=False) + test_set.to_csv(os.path.join(output_folder,'test_annotations.csv'), index=False) + + # Return the dataframes + return train_set, val_set, test_set + +def split_by_location(csv_path, output_folder, val_size=0.15, test_size=0.15, random_state=None): + """ + Splits the dataset into train, validation, and test sets based on location, ensuring that: + 1. All images from the same location are in the same split. + 2. The split is random among the locations. + 3. Saves the split datasets into CSV files. + + Parameters: + - csv_path: Path to the csv containing the annotations. + - train_size, val_size, test_size: float, proportions of the dataset to include in the train, validation, and test splits. + - random_state: int, random state for reproducibility. + """ + # Load the data from the csv file + data = pd.read_csv(csv_path) + + # Calculate train size based on val and test size + train_size = 1.0 - val_size - test_size + + # Get unique locations + unique_locations = data['Location'].unique() + + # Split locations into train and temp (temporary holding for val and test) + train_locs, temp_locs = train_test_split(unique_locations, train_size=train_size, random_state=random_state) + + # Adjust the proportions for val and test based on the remaining locations + temp_size = val_size / (val_size + test_size) + val_locs, test_locs = train_test_split(temp_locs, train_size=temp_size, random_state=random_state) + + # Allocate images to train, validation, and test sets based on their location + train_data = data[data['Location'].isin(train_locs)] + val_data = data[data['Location'].isin(val_locs)] + test_data = data[data['Location'].isin(test_locs)] + + # Save the datasets to CSV files + train_data.to_csv(os.path.join(output_folder,'train_annotations.csv'), index=False) + val_data.to_csv(os.path.join(output_folder,'val_annotations.csv'), index=False) + test_data.to_csv(os.path.join(output_folder,'test_annotations.csv'), index=False) + + # Return the split datasets + return train_data, val_data, test_data + + +def split_by_seq(csv_path, output_folder, val_size=0.15, test_size=0.15, random_state=None): + """ + Splits the dataset into train, validation, and test sets based on sequence ID, ensuring that: + 1. All images from the same sequence are in the same split. + 2. The split is random among the sequences. + 3. Saves the split datasets into CSV files. + + Parameters: + - csv_path: Path to the csv containing the annotations. + - train_size, val_size, test_size: float, proportions of the dataset to include in the train, validation, and test splits. + - random_state: int, random state for reproducibility. + """ + # Load the data from the csv file + data = pd.read_csv(csv_path) + + # Convert 'Photo_Time' from string to datetime + data['Photo_Time'] = pd.to_datetime(data['Photo_Time']) + + # Calculate train size based on val and test size + train_size = 1 - val_size - test_size + + # Sort by 'Photo_Time' to ensure chronological order + data = data.sort_values(by=['Photo_Time']).reset_index(drop=True) + + # Group photos into sequences based on a 30-second interval + time_groups = data.groupby(pd.Grouper(key='Photo_Time', freq='30S')) + + # Assign unique sequence IDs to each group + for s, i in tqdm(enumerate(time_groups.indices.values())): + data.loc[i, 'Seq_ID'] = int(s) + + # Get unique sequence IDs + unique_seq_ids = data['Seq_ID'].unique() + + # Split sequence IDs into train and temp (temporary holding for val and test) + train_seq_ids, temp_seq_ids = train_test_split(unique_seq_ids, train_size=train_size, random_state=random_state) + + # Adjust the proportions for val and test based on the remaining sequences + temp_size = val_size / (val_size + test_size) + val_seq_ids, test_seq_ids = train_test_split(temp_seq_ids, train_size=temp_size, random_state=random_state) + + # Allocate images to train, validation, and test sets based on their sequence ID + train_data = data[data['Seq_ID'].isin(train_seq_ids)] + val_data = data[data['Seq_ID'].isin(val_seq_ids)] + test_data = data[data['Seq_ID'].isin(test_seq_ids)] + + # Save the datasets to CSV files + train_data.to_csv(os.path.join(output_folder,'train_annotations.csv'), index=False) + val_data.to_csv(os.path.join(output_folder,'val_annotations.csv'), index=False) + test_data.to_csv(os.path.join(output_folder,'test_annotations.csv'), index=False) + + # Return the split datasets + return train_data, val_data, test_data diff --git a/src/utils/utils.py b/src/utils/utils.py new file mode 100644 index 0000000..0b359a5 --- /dev/null +++ b/src/utils/utils.py @@ -0,0 +1,72 @@ +import os +import pandas as pd +import cv2 +import supervision as sv +from PIL import Image +import numpy as np + +def save_crop_images(results, output_dir, original_csv_path, overwrite=False): + """ + Save cropped images based on the detection bounding boxes. + + Args: + results (list): + Detection results containing image ID and detections. + output_dir (str): + Directory to save the cropped images. + original_csv_path (str): + Path to the original CSV file. + overwrite (bool): + Whether overwriting existing image folders. Default to False. + Return: + new_csv_path (str): + Path to the new CSV file. + """ + assert isinstance(results, list) + + # Read the original CSV file + original_df = pd.read_csv(original_csv_path) + + # Prepare a list to store new records for the new CSV + new_records = [] + + os.makedirs(output_dir, exist_ok=True) + with sv.ImageSink(target_dir_path=output_dir, overwrite=overwrite) as sink: + for entry in results: + # Process the data if the name of the file is in the dataframe + if os.path.basename(entry["img_id"]) in original_df['path'].values: + for i, (xyxy, cat) in enumerate(zip(entry["detections"].xyxy, entry["detections"].class_id)): + cropped_img = sv.crop_image( + image=np.array(Image.open(entry["img_id"]).convert("RGB")), xyxy=xyxy + ) + new_img_name = "{}_{}_{}".format( + int(cat), i, entry["img_id"].rsplit(os.sep, 1)[1]) + sink.save_image( + image=cv2.cvtColor(cropped_img, cv2.COLOR_RGB2BGR), + image_name=new_img_name + ), + + # Save the crop into a new csv + image_name = entry['img_id'] + + classification_id = original_df[original_df['path'].str.endswith(image_name.split(os.sep)[-1])]['classification'].values[0] + classification_name = original_df[original_df['path'].str.endswith(image_name.split(os.sep)[-1])]['label'].values[0] + # Add record to the new CSV data + new_records.append({ + 'path': new_img_name, + 'classification': classification_id, + 'label': classification_name + }) + + # Create a DataFrame from the new records + new_df = pd.DataFrame(new_records) + + # Define the path for the new CSV file + new_file_name = "{}_cropped.csv".format(original_csv_path.split(os.sep)[-1].split('.')[0]) + new_csv_path = os.path.join(output_dir, new_file_name) + + # Save the new DataFrame to CSV + new_df.to_csv(new_csv_path, index=False) + + return new_csv_path +