M3FS-Net

A dual-channel convolutional network for crop disease identification under wild field conditions. Combines a pretrained MobileNetV2 backbone with a custom multi-scale feature extractor and a hybrid statistical-attention channel selection mechanism.


Problem

Crop disease classifiers trained on lab-curated datasets (uniform backgrounds, controlled lighting) degrade sharply when deployed on field imagery. Background clutter, variable illumination, occlusion, and inter-class visual similarity combine to make wild-condition classification a fundamentally harder problem.

Traditional ML methods (SVM, Random Forest, LightGBM) operating on hand-engineered features (HOG, PCA) collapse on multi-class datasets -- PCA+SVM achieves only 36.6% on 10-class tomato disease, above the 10% chance level but far too low to be useful. Standard pretrained CNNs with frozen-backbone transfer learning fare better but still lose 8-20 percentage points when moving from 4-class to 10-class wild-condition datasets. No single pretrained architecture dominates across all datasets.

M3FS-Net addresses both problems: it extracts complementary features through independent channels optimized for different representational scales, and it aggressively prunes redundant channels via a differentiable selection mechanism that gates on both statistical quality and classification relevance.


Architecture

M3FS-Net processes each input image through two independent parallel branches, then selectively fuses their outputs before classification.

Input (224 x 224 x 3)
    |
    +-- Branch 1: MobileNetV2 Backbone (layers 0-13 frozen, 14-18 fine-tuned)
    |       |
    |       7 x 7 x 1280 feature maps
    |       |
    |       HSAFS (alpha=128)
    |       |
    |       7 x 7 x 128  (10% of channels retained)
    |
    +-- Branch 2: Enhanced Three Deep Block (E3DB)
            |
            +-- Block-LOCAL:   AdaptiveAvgPool(14x14) -> Conv1x1 -> Conv3x3(d=1)
            +-- Block-MEDIUM:  AdaptiveAvgPool(28x28) -> Conv1x1 -> Conv3x3(d=2)
            +-- Block-GLOBAL:  AdaptiveAvgPool(56x56) -> Conv1x1 -> Conv3x3(d=4)
            |
            Concat -> Conv1x1 fuse -> SE attention
            |
            7 x 7 x 256
            |
            HSAFS (beta=96)
            |
            7 x 7 x 96  (37.5% of channels retained)

    Concat(Branch1, Branch2) -> 7 x 7 x 224
        |
        Conv3x3(d=2, 128) -> BatchNorm -> ReLU
        |
        GlobalAveragePooling -> 128
        |
        Dense(256) -> ReLU -> Dropout(0.3)
        |
        Dense(num_classes) -> Softmax
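
The fusion and classification stages at the bottom of the diagram can be sketched in PyTorch as follows. This is a minimal sketch, not the repository's code: the class name FusionHead and the padding choice are assumptions made for illustration. The sketch returns raw logits, because nn.CrossEntropyLoss applies log-softmax internally; softmax is applied afterwards only to obtain probabilities.

```python
import torch
import torch.nn as nn


class FusionHead(nn.Module):
    """Hypothetical fusion + classifier head following the diagram above."""

    def __init__(self, num_classes: int, c1: int = 128, c2: int = 96):
        super().__init__()
        self.fuse = nn.Sequential(
            # Assumption: padding=2 with dilation=2 keeps the 7x7 spatial size.
            nn.Conv2d(c1 + c2, 128, kernel_size=3, padding=2, dilation=2),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
        )
        self.classifier = nn.Sequential(
            nn.Linear(128, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes),
        )

    def forward(self, b1: torch.Tensor, b2: torch.Tensor) -> torch.Tensor:
        x = torch.cat([b1, b2], dim=1)   # (B, 224, 7, 7)
        x = self.fuse(x)                 # (B, 128, 7, 7)
        x = x.mean(dim=(2, 3))           # global average pooling -> (B, 128)
        return self.classifier(x)        # raw logits, (B, num_classes)


head = FusionHead(num_classes=10).eval()
logits = head(torch.randn(2, 128, 7, 7), torch.randn(2, 96, 7, 7))
probs = logits.softmax(dim=1)            # probabilities, only when needed
```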

Branch 1: MobileNetV2 Pretrained Backbone

MobileNetV2 pretrained on ImageNet provides robust general-purpose visual features. The 19-layer inverted residual backbone outputs 7x7x1280 feature maps at 224x224 input resolution.

Transfer learning strategy. Layers 0-13 are frozen -- these capture universal primitives (edges, textures, shapes) that transfer across domains. Layers 14-18 are fine-tuned at lr=1e-4 to adapt high-level features to disease-specific patterns.

Why MobileNetV2. It offers lightweight inference and strong ImageNet transfer performance, and its depthwise separable convolutions keep the parameter count low without sacrificing representational capacity.

Branch 2: Enhanced Three Deep Block (E3DB)

Branch 2 operates on raw RGB pixels -- not MobileNetV2 features -- ensuring the two channels capture genuinely complementary information. It runs three sub-branches in parallel, each at a different spatial resolution and dilation rate:

Sub-branch  Input resolution  Dilation  What it captures
Local       14x14             1         Lesion edges, spot boundaries, color transitions
Medium      28x28             2         Lesion texture, vein discoloration, fungal growth patterns
Global      56x56             4         Leaf shape deformation, discoloration distribution, disease spread

Sub-branch outputs are concatenated (7x7x384), projected to 7x7x256 via a 1x1 convolution, then weighted by a Squeeze-and-Excitation block that learns to emphasize the most informative scale per input.
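
A rough PyTorch sketch of the three-scale block is shown below. Several details are not stated in this README and are assumptions here: the class names ScaleBlock and E3DBSketch, the 64-channel bottleneck, 128 output channels per sub-branch (chosen so the concatenation has 384 channels), the final adaptive pooling that brings every scale to 7x7, and the SE reduction ratio of 4.

```python
import torch
import torch.nn as nn


class ScaleBlock(nn.Module):
    """One sub-branch: pool raw RGB to a fixed grid, then 1x1 and dilated 3x3 convs."""

    def __init__(self, pool_size: int, dilation: int, mid: int = 64, out: int = 128):
        super().__init__()
        self.body = nn.Sequential(
            nn.AdaptiveAvgPool2d(pool_size),
            nn.Conv2d(3, mid, kernel_size=1),
            nn.ReLU(inplace=True),
            # padding == dilation keeps the spatial size for a 3x3 kernel
            nn.Conv2d(mid, out, kernel_size=3, padding=dilation, dilation=dilation),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(7),  # assumption: align all scales at 7x7
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.body(x)


class E3DBSketch(nn.Module):
    def __init__(self, out_channels: int = 256, reduction: int = 4):
        super().__init__()
        self.blocks = nn.ModuleList(
            [ScaleBlock(14, 1), ScaleBlock(28, 2), ScaleBlock(56, 4)]
        )
        self.fuse = nn.Conv2d(3 * 128, out_channels, kernel_size=1)
        self.se = nn.Sequential(
            nn.Linear(out_channels, out_channels // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(out_channels // reduction, out_channels),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        feats = torch.cat([b(x) for b in self.blocks], dim=1)  # (B, 384, 7, 7)
        fused = self.fuse(feats)                               # (B, 256, 7, 7)
        gate = self.se(fused.mean(dim=(2, 3)))                 # (B, 256) channel weights
        return fused * gate[:, :, None, None]


e3db = E3DBSketch().eval()
out = e3db(torch.randn(2, 3, 224, 224))
```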

Hybrid Statistical-Attention Feature Selection (HSAFS)

HSAFS is applied independently to each branch's output before fusion. It selects the k most discriminative channels through a product of two scores:

Statistical score (unsupervised). Four metrics computed per-channel, batch-averaged, then min-max normalized and averaged:

  • Coefficient of Variation (sigma / |mu|) -- scale-invariant dispersion
  • Kurtosis (E[(x-mu)^4] / sigma^4 - 3) -- tail heaviness, clamped to [-10, 10]
  • Shannon Entropy (-sum p log p) -- information content
  • Inter-Quartile Range (Q75 - Q25) -- robust spread, outlier-immune
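
The four metrics above could be combined into a per-channel statistical score roughly as follows. This is a sketch under stated assumptions, not the repository's hsafs.py: the function name statistical_score is hypothetical, and the probability distribution used for the entropy (a softmax over spatial positions) is a guess, since the README does not say how p is obtained.

```python
import torch


def statistical_score(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """x: (B, C, H, W) feature maps -> (C,) score in [0, 1] per channel."""
    b, c, h, w = x.shape
    flat = x.reshape(b, c, h * w)

    mu = flat.mean(dim=2)
    sigma = flat.std(dim=2, unbiased=False)
    cv = sigma / (mu.abs() + eps)                         # coefficient of variation

    centered = flat - mu.unsqueeze(2)
    kurt = centered.pow(4).mean(dim=2) / (sigma.pow(4) + eps) - 3.0
    kurt = kurt.clamp(-10.0, 10.0)                        # excess kurtosis, clamped

    p = torch.softmax(flat, dim=2)                        # assumption: spatial softmax
    entropy = -(p * torch.log(p + eps)).sum(dim=2)

    q = torch.tensor([0.25, 0.75], dtype=flat.dtype, device=flat.device)
    quartiles = torch.quantile(flat, q, dim=2)            # (2, B, C)
    iqr = quartiles[1] - quartiles[0]

    metrics = torch.stack([cv, kurt, entropy, iqr]).mean(dim=1)   # batch-averaged, (4, C)
    lo = metrics.min(dim=1, keepdim=True).values
    hi = metrics.max(dim=1, keepdim=True).values
    normed = (metrics - lo) / (hi - lo + eps)             # min-max per metric
    return normed.mean(dim=0)                             # average of the four metrics


stat = statistical_score(torch.randn(4, 16, 7, 7))
```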

Learned attention (supervised). A standard SE-style bottleneck: GAP -> Linear(C, C/4) -> ReLU -> Linear(C/4, C) -> Sigmoid, producing per-channel weights through backpropagation.
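
The bottleneck described above maps directly onto a few PyTorch layers. A minimal sketch, with the class name ChannelAttention chosen here for illustration:

```python
import torch
import torch.nn as nn


class ChannelAttention(nn.Module):
    """GAP -> Linear(C, C/4) -> ReLU -> Linear(C/4, C) -> Sigmoid."""

    def __init__(self, channels: int, reduction: int = 4):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(channels, channels // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        pooled = x.mean(dim=(2, 3))   # global average pooling, (B, C)
        return self.net(pooled)       # per-channel weights in [0, 1], (B, C)


attn = ChannelAttention(256)
weights = attn(torch.randn(2, 256, 7, 7))
```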

Fusion. final_score = stat_score * learned_attention. The product ensures selected channels are both statistically rich and classification-relevant.

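Selection. Soft top-k: the k highest-scoring channels are kept (argsort(score, descending=True)[:k]), and their outputs are weighted by a softmax over the selected scores, scaled by k to preserve magnitude. Gradients flow through these soft weights, so the selection stage remains trainable end-to-end.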

Channel reduction ratios: alpha=128 (from 1280, 10% retained) for Branch 1; beta=96 (from 256, 37.5% retained) for Branch 2.
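
The fusion and selection steps can be sketched as one function. The name soft_topk_select and the per-sample selection are assumptions for illustration; the repository may select a single channel set per batch instead.

```python
import torch


def soft_topk_select(x, stat_score, attention, k):
    """x: (B, C, H, W); stat_score: (C,); attention: (B, C) -> (B, k, H, W)."""
    score = stat_score.unsqueeze(0) * attention        # final_score, (B, C)
    top = score.topk(k, dim=1)                         # k highest scores per sample
    weights = torch.softmax(top.values, dim=1) * k     # soft weights, sum to k
    idx = top.indices[:, :, None, None].expand(-1, -1, x.shape[2], x.shape[3])
    selected = x.gather(1, idx)                        # pick the k channels
    return selected * weights[:, :, None, None], top.indices, weights


x = torch.randn(2, 256, 7, 7)
selected, indices, weights = soft_topk_select(
    x, stat_score=torch.rand(256), attention=torch.rand(2, 256), k=96
)
```

Scaling the softmax weights by k makes them average to 1, so the selected feature maps keep roughly their original magnitude.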

Test-Time Augmentation

Each test image is processed through 5 forward passes: 1 clean (standard Resize+CenterCrop) and 4 augmented (RandomResizedCrop + RandomFlip). Logits are averaged before argmax. In the ablation summary, TTA leaves MDS accuracy unchanged and adds 2.70 (CDS) and 5.70 (TDDS) percentage points, without architectural changes.
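
Logit averaging over views can be sketched as below. The function name tta_predict and the toy model are placeholders so the example runs without data; horizontal flips stand in for the random augmented views.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def tta_predict(model: nn.Module, views: list):
    """views: list of (B, 3, H, W) tensors, e.g. 1 clean + 4 augmented."""
    model.eval()
    mean_logits = torch.stack([model(v) for v in views]).mean(dim=0)
    return mean_logits.argmax(dim=1), mean_logits


# Toy stand-in for the network (hypothetical, for illustration only).
toy = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 4))
clean = torch.randn(2, 3, 8, 8)
views = [clean] + [clean.flip(dims=[3]) for _ in range(4)]
pred, mean_logits = tta_predict(toy, views)
```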


Results

Primary

All metrics are percentages computed under a macro-averaged one-vs-rest scheme, for M3FS-Net with 5-view TTA.

Dataset        Classes  Accuracy  TPR    FPR   TNR    F1
MDS (Rice)     4        99.81     99.81  0.06  99.94  99.81
CDS (Corn)     4        96.08     96.08  1.31  98.69  96.06
TDDS (Tomato)  10       97.10     97.10  0.32  99.68  97.10
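
Macro-averaged one-vs-rest metrics can be derived from a confusion matrix as sketched below. The function name macro_ovr_metrics is hypothetical and the 2x2 matrix is made-up toy data, not a result from this project. With equally sized test classes, macro TPR equals accuracy and FPR + TNR = 100%, which is the pattern visible in the table.

```python
def macro_ovr_metrics(cm):
    """cm[i][j] = number of samples with true class i predicted as class j."""
    k = len(cm)
    total = sum(sum(row) for row in cm)

    def ratio(num, den):
        return num / den if den else 0.0   # guard against empty classes

    sums = {"tpr": 0.0, "fpr": 0.0, "tnr": 0.0, "f1": 0.0}
    for c in range(k):
        tp = cm[c][c]
        fn = sum(cm[c]) - tp
        fp = sum(cm[r][c] for r in range(k)) - tp
        tn = total - tp - fn - fp
        sums["tpr"] += ratio(tp, tp + fn)
        sums["fpr"] += ratio(fp, fp + tn)
        sums["tnr"] += ratio(tn, fp + tn)
        sums["f1"] += ratio(2 * tp, 2 * tp + fp + fn)

    metrics = {name: value / k for name, value in sums.items()}
    metrics["accuracy"] = ratio(sum(cm[c][c] for c in range(k)), total)
    return metrics


toy_cm = [[9, 1],
          [2, 8]]   # toy data for illustration
m = macro_ovr_metrics(toy_cm)
```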

Comparative

M3FS-Net ranks #1 on all three datasets against 10 baselines (5 traditional ML, 5 pretrained CNNs).

Dataset          M3FS-Net  Best DL baseline     Best trad-ML
MDS (4-class)    99.81     99.13 (MobileNetV2)  94.13 (HOG+LGBM)
CDS (4-class)    96.08     94.61 (VGG16)        82.84 (HOG+LGBM)
TDDS (10-class)  97.10     92.20 (DenseNet121)  61.40 (HOG+LGBM)

Degradation analysis

How much accuracy each model loses when moving from simple (MDS) to complex (TDDS):

Model        MDS -> TDDS drop  Retention
M3FS-Net     -2.71             97.3%
DenseNet121  -6.26             93.6%
MobileNetV2  -8.13             91.8%
VGG16        -16.26            83.5%
HOG+LGBM     -32.73            65.2%
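
The two columns appear to be the accuracy difference in percentage points and the ratio of TDDS to MDS accuracy. A short check using the accuracies from the comparative table reproduces the first and last rows (the function name degradation is chosen here for illustration):

```python
# Accuracy (%) on MDS and TDDS, taken from the comparative table above.
scores = {
    "M3FS-Net": (99.81, 97.10),
    "HOG+LGBM": (94.13, 61.40),
}


def degradation(mds: float, tdds: float):
    drop = round(tdds - mds, 2)              # percentage points
    retention = round(100 * tdds / mds, 1)   # share of MDS accuracy kept, in %
    return drop, retention


report = {name: degradation(*pair) for name, pair in scores.items()}
print(report)
```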

Ablation summary

Config                   Frozen layers  TTA           MDS    CDS    TDDS
Baseline                 0-13           No            99.81  93.38  91.40
More trainable params    0-9            No            --     93.38  90.80
AdamW + label smoothing  0-13           No            --     88.73  --
Final                    0-13           Yes (5-view)  99.81  96.08  97.10

Key finding: TTA was the only tested change that improved accuracy (+2.70 points on CDS, +5.70 on TDDS, no change on MDS). Unfreezing more backbone layers left CDS unchanged and lowered TDDS by 0.60 points, and AdamW with label smoothing lowered CDS by 4.65 points.


Datasets

All images captured under uncontrolled field conditions with natural backgrounds, variable lighting, and real-world noise. Class balancing applied via random downsampling to the minimum class count per dataset.

Property       MDS                  CDS                  TDDS
Crop           Rice                 Corn                 Tomato
Classes        4                    4                    10
Samples/class  1,300                510                  500
Total images   5,200                2,040                5,000
Split          70/10/20 stratified  70/10/20 stratified  70/10/20 stratified
Seed           42                   42                   42
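
Class balancing and the stratified 70/10/20 split could look like the following standard-library sketch. The function name balance_and_split, the file names, and the order of operations are assumptions; the repository's data/utils.py may differ.

```python
import random


def balance_and_split(by_class, seed=42, train_pct=70, val_pct=10):
    """by_class: dict mapping class name -> list of image paths."""
    rng = random.Random(seed)
    n = min(len(items) for items in by_class.values())   # minimum class count
    splits = {"train": [], "val": [], "test": []}
    for label in sorted(by_class):
        items = rng.sample(by_class[label], n)            # random downsampling
        n_train = n * train_pct // 100                    # integer math avoids float drift
        n_val = n * val_pct // 100
        splits["train"] += [(p, label) for p in items[:n_train]]
        splits["val"] += [(p, label) for p in items[n_train:n_train + n_val]]
        splits["test"] += [(p, label) for p in items[n_train + n_val:]]
    return splits


# Made-up file names; two classes of unequal size.
fake = {
    "Blight": [f"b{i}.jpg" for i in range(600)],
    "Healthy": [f"h{i}.jpg" for i in range(510)],
}
splits = balance_and_split(fake)
print({name: len(rows) for name, rows in splits.items()})
```

Splitting each class separately keeps every split balanced, which is what "stratified" requires once the classes have equal size.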

MDS classes: Bacterial Leaf Blight, Blast, Brown Spot, Tungro.

CDS classes: Blight, Common Rust, Gray Leaf Spot, Healthy.

TDDS classes: Bacterial Spot, Early Blight, Healthy, Late Blight, Leaf Mold, Mosaic Virus, Septoria Leaf Spot, Spider Mites, Target Spot, Yellow Leaf Curl Virus.


Training configuration

Parameter                          Value
Framework                          PyTorch 2.x
Input size                         224 x 224 x 3
Batch size                         32
Max epochs                         30
Early stopping                     Patience 5 (validation loss)
Optimizer                          Adam
LR (backbone layers 14-18)         1e-4
LR (E3DB, HSAFS, classifier head)  1e-3
LR schedule                        CosineAnnealingWarmRestarts (T_0=10)
Loss                               CrossEntropyLoss
Dropout                            0.3
TTA views                          5 (1 clean + 4 augmented)
Hardware                           NVIDIA Tesla T4 (Google Colab)
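
The two learning rates and the restart schedule can be expressed with optimizer parameter groups. In this sketch the two nn.Linear modules are placeholders for the real model parts, and stepping the scheduler once per epoch is an assumption about the training loop.

```python
import torch
import torch.nn as nn

# Placeholders for the real modules (hypothetical names and sizes).
backbone_tail = nn.Linear(8, 8)   # stands in for MobileNetV2 layers 14-18
new_modules = nn.Linear(8, 4)     # stands in for E3DB + HSAFS + classifier head

optimizer = torch.optim.Adam([
    {"params": backbone_tail.parameters(), "lr": 1e-4},
    {"params": new_modules.parameters(), "lr": 1e-3},
])
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10)
criterion = nn.CrossEntropyLoss()

lrs = []
for epoch in range(12):
    lrs.append([group["lr"] for group in optimizer.param_groups])
    logits = new_modules(backbone_tail(torch.randn(4, 8)))
    loss = criterion(logits, torch.randint(0, 4, (4,)))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()              # assumption: one scheduler step per epoch
```

With T_0=10 both rates follow a cosine decay for ten epochs and then jump back to their initial values, so lrs[10] equals lrs[0].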

Data augmentation (training only): RandomResizedCrop (scale 0.8-1.0), RandomHorizontalFlip, RandomVerticalFlip, RandomRotation (±15 degrees), ColorJitter (brightness/contrast/saturation 0.2, hue 0.05), ImageNet normalization.

Evaluation preprocessing: Resize(256) -> CenterCrop(224) -> Normalize.


Installation

Requires Python 3.10+ and PyTorch 2.x.

git clone <repo-url>
cd ChloraScan
pip install -e .

For GPU training, install the CUDA-compatible PyTorch build before the package:

pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118
pip install -e .

Registered CLI commands

After pip install -e ., the following commands are available:

train-m3fs-net        --dataset {MDS,CDS,TDDS,all} [--epochs N] [--no-tta] [--force]
train-dl-baselines    --dataset {MDS,CDS,TDDS,all} --model {DenseNet121,MobileNetV2,...}
train-traditional-ml  --dataset {MDS,CDS,TDDS,all} --model {PCA_SVM,HOG_RF,...}

Usage

Data preparation

Organize datasets under a single root directory (default: ./data/):

data/
  MDS/
    Bacterial_leaf_blight/  (1300+ images)
    Blast/                  (1300+ images)
    Brown_spot/             (1300+ images)
    Tungro/                 (1300+ images)
  CDS/
    Blight/                 (510+ images)
    ...
  TDDS/
    Bacterial_Spot/         (500+ images)
    ...

Set CHLORASCAN_DATA_ROOT to override the default path:

export CHLORASCAN_DATA_ROOT=/path/to/datasets

Train M3FS-Net

# All datasets
train-m3fs-net --dataset all

# Single dataset, custom epochs, with TTA
train-m3fs-net --dataset TDDS --epochs 30

# Without TTA (single-pass evaluation)
train-m3fs-net --dataset MDS --no-tta

Train DL baselines

# All models on all datasets
train-dl-baselines --dataset all --model all

# Single model on single dataset
train-dl-baselines --dataset TDDS --model DenseNet121

Train traditional ML baselines

# All models on all datasets
train-traditional-ml --dataset all

# Single model on single dataset
train-traditional-ml --dataset MDS --model HOG_LGBM

Quick verification (no data needed)

python -c "
from chlorascan.models import M3FSNet
import torch
m = M3FSNet(4)
o = m(torch.randn(2, 3, 224, 224))
assert o.shape == (2, 4)
print('Forward pass OK:', o.shape)
"

Google Colab

The notebooks in notebooks/ are designed for Colab. They mount Google Drive for data access and install the package in development mode. Use the _Simplified.ipynb versions for a cleaner workflow -- they import from the package rather than inlining all definitions.


Project structure

ChloraScan/
  README.md
  pyproject.toml
  requirements.txt
  paper/
    M3FS_Net_Paper.md
  notebooks/
    M3FS_Net.ipynb                     # Original with cell outputs
    M3FS_Net_Simplified.ipynb          # Clean version, imports package
    DL_Benchmark.ipynb
    DL_Benchmark_Simplified.ipynb
    Traditional_ML_Benchmark.ipynb
    Traditional_ML_Benchmark_Simplified.ipynb
  src/chlorascan/
    __init__.py
    config.py                          # All constants, paths, hyperparameters
    data/
      dataset.py                       # DiseaseDataset (torch.utils.data.Dataset)
      transforms.py                    # Train/eval/TTA augmentation pipelines
      utils.py                         # Split, cache, verify, DataLoader builders
    models/
      se_block.py                      # Squeeze-and-Excitation channel attention
      enhanced_3db.py                  # Enhanced Three Deep Block (multi-scale)
      hsafs.py                         # Hybrid Statistical-Attention Feature Selection
      m3fs_net.py                      # M3FS-Net full model + factory function
      baselines.py                     # DenseNet121, MobileNetV2, VGG16, InceptionV3, EfficientNetB4
    training/
      trainer.py                       # Generic Trainer with early stopping + resume
      metrics.py                       # Macro-averaged one-vs-rest metrics
      tta.py                           # 5-view test-time augmentation evaluator
    traditional_ml/
      features.py                      # HOG and PCA feature extraction
      models.py                        # SVM, RF, LGBM, MLP, LR model factories
    utils/
      seed.py                          # Deterministic seed setting
      checkpoint.py                    # Path helpers for checkpoints and results
    scripts/
      train_m3fs_net.py                # CLI: train-m3fs-net
      train_dl_baselines.py            # CLI: train-dl-baselines
      train_traditional_ml.py          # CLI: train-traditional-ml
  scripts/                             # Standalone runner scripts (same content)
  results/                             # Output directory (checkpoints, JSON results, cache)

Configuration

All settings live in src/chlorascan/config.py. Override via environment variables:

Variable                Default     Purpose
CHLORASCAN_DATA_ROOT    ./data/     Root path for dataset directories
CHLORASCAN_RESULTS_DIR  ./results/  Output directory for checkpoints and results
CHLORASCAN_SEED         42          Random seed for reproducibility
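
Environment overrides of this kind are typically read with a default fallback. The sketch below illustrates the pattern; load_settings and the dictionary keys are hypothetical and not the actual contents of config.py.

```python
import os
from pathlib import Path


def load_settings(env=os.environ):
    """Read the three documented variables, falling back to the documented defaults."""
    return {
        "data_root": Path(env.get("CHLORASCAN_DATA_ROOT", "./data/")),
        "results_dir": Path(env.get("CHLORASCAN_RESULTS_DIR", "./results/")),
        "seed": int(env.get("CHLORASCAN_SEED", "42")),
    }


defaults = load_settings(env={})
custom = load_settings(env={
    "CHLORASCAN_DATA_ROOT": "/path/to/datasets",
    "CHLORASCAN_SEED": "7",
})
```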

Limitations

  • TTA inference cost. Five forward passes per image during evaluation. Single-pass accuracy without TTA is up to 5.7 percentage points lower depending on dataset (0 on MDS, 2.70 on CDS, 5.70 on TDDS in the ablation summary).
  • Small dataset sensitivity. CDS (510 samples/class) showed the highest training variance across hyperparameter configurations.
  • Independent channels. The two branches operate without cross-channel attention. Feature-level communication before fusion could further improve complementarity.
  • No explainability. The current pipeline does not provide GradCAM or attention-map visualizations for practitioners.

Model parameters

Variant             Total      Trainable  Frozen
4-class (MDS, CDS)  3,948,228  3,405,700  542,528
10-class (TDDS)     3,949,770  3,407,242  542,528

The 542,528 frozen parameters belong to MobileNetV2 layers 0-13. The ~3.4M trainable parameters include MobileNetV2 layers 14-18, the entire E3DB, both HSAFS selectors, the fusion convolution, and the classifier head.


License

MIT. See LICENSE file.
