
TravisHaa/gripvector


GripVector

GripVector is a Python toolkit for EMG-based gesture recognition using the Myo armband. It covers the full pipeline: raw data capture, preprocessing, model training (MLP and U-Net–style encoder), and live inference with optional 8-direction “compass” interpretation for directional gestures.

Overview

  • Hardware: Myo armband (8-channel EMG, 200 Hz raw mode). Uses a USB Bluetooth dongle (e.g. PID=2458).
  • Labels: Six base classes — neutral, north, south, east, west, fist — with optional 8-direction (N/NE/E/SE/S/SW/W/NW) interpretation for live “compass” output.
  • Models: MLP baseline (flattened windows) and UnetEncoder (1D conv stem + encoder blocks + classification head).
  • Training: Session-based CSVs, per-session or global normalization, windowing, optional augmentation (channel roll, gain, dropout). Checkpoints include model config and normalization stats for deployment.
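The three augmentations named above (channel roll, gain, dropout) can be sketched in plain Python. This is a minimal illustration under stated assumptions, not GripVector's implementation. The gain range (0.8–1.2), the dropout probability (0.1), and the roll → gain → dropout order are all assumptions, and windows are nested lists indexed [time][channel] rather than tensors. Rolling the channel axis suits a ring of 8 electrodes, because rotating the band on the forearm roughly shifts the channels cyclically.

```python
import random
from typing import List

Window = List[List[float]]  # one window, indexed [time][channel]


def channel_roll(win: Window, shift: int) -> Window:
    """Circularly shift channel order (like numpy.roll on the channel axis)."""
    c = len(win[0])
    return [[row[(j - shift) % c] for j in range(c)] for row in win]


def random_gain(win: Window, rng: random.Random, low: float = 0.8, high: float = 1.2) -> Window:
    """Scale every channel by its own gain drawn uniformly from [low, high] (assumed range)."""
    gains = [rng.uniform(low, high) for _ in win[0]]
    return [[x * g for x, g in zip(row, gains)] for row in win]


def channel_dropout(win: Window, rng: random.Random, p: float = 0.1) -> Window:
    """Zero each whole channel independently with probability p (assumed default)."""
    keep = [rng.random() >= p for _ in win[0]]
    return [[x if k else 0.0 for x, k in zip(row, keep)] for row in win]


def augment(win: Window, rng: random.Random) -> Window:
    """Apply roll -> gain -> dropout; intended for training windows only."""
    win = channel_roll(win, rng.randrange(len(win[0])))
    win = random_gain(win, rng)
    return channel_dropout(win, rng)
```

Passing a seeded `random.Random` instead of using the module-level RNG keeps augmentation reproducible across runs with the same `seed`.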

Requirements

  • Python: ≥ 3.10
  • Hardware: Myo armband + compatible Bluetooth dongle (e.g. Nordic USB dongle, PID 2458)
  • OS: Serial port access (tested on macOS/Linux; Windows may need driver/permission handling for multithreading)

Installation

Clone the repo and install in development mode (e.g. with uv):

git clone <repo-url>
cd gripvector
uv sync

Or with pip:

pip install -e .

Dependencies (see pyproject.toml): pandas, pyserial, torch, tqdm, wandb. Dev: matplotlib, scikit-learn.


Project Structure

gripvector/
├── src/gripvector/           # Main package
│   ├── __init__.py          # Public API (Myo, record_*, LABELS, classifier components)
│   ├── hardware.py          # Myo BLE/serial protocol (RAW/PREPROCESSED/FILTERED modes)
│   ├── data_worker.py       # Raw CSV recording with countdown (ch1–ch8, label)
│   ├── labeled_data_worker.py # Labeled recording with session/segment metadata
│   └── classifier/
│       ├── config.py        # TrainingConfig, UnetEncoderConfig, MLPConfig
│       ├── data.py          # EMGDataset, concat_datasets (session grouping)
│       ├── preprocess.py    # Envelope, windowing, channel standardization
│       ├── train.py         # train(), WandbLogger, evaluate
│       ├── mlp.py           # MLP
│       ├── unet_encoder.py  # UnetEncoder (stem + encoders + head)
│       └── utils.py         # create_df, df_xy_split, split_by_class, COMPASS
├── scripts/
│   ├── record_raw_data.py  # CLI: record static holds (RAW) to CSV
│   ├── live_classify.py    # Live 6-way classification (no compass)
│   ├── compass.py          # Live 6-way + 8-direction compass (aggregated)
│   ├── infer.py            # Same as compass (live compass UI)
│   ├── check_data_quality.py # Report per-file, per-class counts in data/
│   ├── pca_plot.py         # 2D PCA of EMG (expects ch1–ch8, label)
│   └── pca_plot_3d.py      # 3D PCA (expects Channel_1–Channel_8, label)
├── train/
│   ├── mlp_baseline.py     # Single MLP training run
│   ├── unet_encoder.py     # Single UnetEncoder training run
│   └── sweep_unet_encoder.py # W&B sweep over hyperparameters
└── pyproject.toml

Quick Start

  1. Record data (Myo in RAW mode, 200 Hz):

    uv run scripts/record_raw_data.py data/my_session.csv -l 8 -r 2

    Records each of the six classes for 8 seconds with 2 s rest between classes. Output CSV has columns timestamp, session_id, segment_id, sample_index, label, ch1–ch8.

  2. Check data:

    uv run scripts/check_data_quality.py

    Scans data/*.csv and prints a table of rows per file and per label.

  3. Train (edit paths in train/unet_encoder.py or train/mlp_baseline.py, then):

    uv run train/unet_encoder.py

    Checkpoints and W&B logs go to runs/<run_name>/.

  4. Live inference with compass UI:

    uv run scripts/compass.py --ckpt runs/<run_name>/BEST_acc99.pt --device cuda

Data Collection

Recording formats

  • Raw recorder (record_static_holds_raw_with_countdown / scripts/record_raw_data.py):
    Writes RAW 200 Hz EMG with columns ch1–ch8 and label. This is the format the classifier pipeline expects (EMG_COLS = ch1..ch8).

  • Labeled recorder (record_labeled_dataset in labeled_data_worker):
    Adds session_id, segment_id, segment_elapsed_s, emg_index and uses column names Channel_1–Channel_8. For training, either rename the columns to ch1–ch8 or point EMG_COLS at the correct names.
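Renaming the labeled recorder's columns can be done with the standard-library csv module. This is a minimal sketch that assumes the only columns needing changes are Channel_1..Channel_8. `RENAME` and `rename_emg_columns` are hypothetical helpers, not part of the package.

```python
import csv
from typing import TextIO

# Assumed mapping: Channel_k (labeled recorder) -> chk (the classifier's EMG_COLS).
RENAME = {f"Channel_{k}": f"ch{k}" for k in range(1, 9)}


def rename_emg_columns(src: TextIO, dst: TextIO) -> None:
    """Copy CSV text from src to dst, renaming Channel_1..Channel_8 to ch1..ch8.

    All other columns (session_id, segment_id, label, ...) pass through unchanged.
    """
    reader = csv.DictReader(src)
    fieldnames = [RENAME.get(name, name) for name in reader.fieldnames]
    writer = csv.DictWriter(dst, fieldnames=fieldnames)
    writer.writeheader()
    for row in reader:
        writer.writerow({RENAME.get(k, k): v for k, v in row.items()})
```

With real files, open both with `newline=""` as the csv docs recommend. If the data is already in a pandas DataFrame, `df.rename(columns=RENAME)` does the same job.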

Session naming for training

For multi-session train/eval with concat_datasets, CSVs should follow:

  • Train: {name}_session{sid}_train{N}.csv
  • Eval: {name}_session{sid}_test.csv

Example: andy_session1_train1.csv, andy_session1_test.csv. Sessions are grouped by {name}_session{sid}; eval uses that session’s (or global) normalization stats.
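The naming convention can be checked, and files grouped by session key, with a regular expression. This is a sketch that assumes `{sid}` is numeric (as in `session1`) and that `{name}` may contain any characters. `group_sessions` is hypothetical and only mirrors the grouping described above. It is not `concat_datasets` itself.

```python
import re
from collections import defaultdict
from typing import Dict, List

# Assumed pattern for {name}_session{sid}_train{N}.csv and {name}_session{sid}_test.csv
PATTERN = re.compile(r"^(?P<key>.+_session\d+)_(?P<split>train\d+|test)\.csv$")


def group_sessions(filenames: List[str]) -> Dict[str, Dict[str, List[str]]]:
    """Group CSV names by '{name}_session{sid}' into train/test lists."""
    groups: Dict[str, Dict[str, List[str]]] = defaultdict(lambda: {"train": [], "test": []})
    for fname in filenames:
        m = PATTERN.match(fname)
        if m is None:
            raise ValueError(f"file name does not follow the session convention: {fname}")
        split = "test" if m["split"] == "test" else "train"
        groups[m["key"]][split].append(fname)
    return dict(groups)
```

Failing loudly on a misnamed file is deliberate: a silently skipped test CSV would leave its session without held-out data.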

Scripts

  • scripts/record_raw_data.py
    • Positional: csv (output path).
    • -l/--seconds_per_class (default 5), -r/--rest_seconds (default 2).
    • Calls record_static_holds_raw_with_countdown(output_csv, seconds_per_class, rest_seconds).
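The mapping from these flags to the recorder call can be sketched with argparse. Only the flag names and defaults come from the list above. The `float` type is an assumption, and the call into `record_static_holds_raw_with_countdown` is left as a comment so the sketch runs without the package or a Myo attached.

```python
import argparse
from typing import List, Optional


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse the record_raw_data.py flags described above (types assumed)."""
    p = argparse.ArgumentParser(description="Record static holds (RAW, 200 Hz) to CSV.")
    p.add_argument("csv", help="output CSV path")
    p.add_argument("-l", "--seconds_per_class", type=float, default=5.0,
                   help="hold duration per class in seconds")
    p.add_argument("-r", "--rest_seconds", type=float, default=2.0,
                   help="rest between classes in seconds")
    return p.parse_args(argv)


def main(argv: Optional[List[str]] = None) -> argparse.Namespace:
    args = parse_args(argv)
    # The real script hands these straight to the package:
    # record_static_holds_raw_with_countdown(args.csv, args.seconds_per_class, args.rest_seconds)
    return args
```

For example, `main(["data/my_session.csv", "-l", "8", "-r", "2"])` reproduces the Quick Start invocation.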

Preprocessing

Pipeline (see classifier/preprocess.py and classifier/data.py):

  1. Load CSVs → one DataFrame (e.g. via create_df).
  2. Split into EMG columns and label (df_xy_split).
  3. Envelope: RAW int8 → float → rectify → EMA low-pass (e.g. 4 Hz) → optional soft clip (preprocess_envelope).
  4. Per-channel standardization: fit_channel_standardizer on train data (returns mean, std, and mergeable stats); apply_channel_standardizer on train/eval.
  5. Windowing: make_windows(x, window_duration_s, step_duration_s, sampling_rate) → (Nw, Tw, C).
  6. Dataset: EMGDataset splits by class, windows, and optionally applies augmentation (channel roll, gain, dropout) when training_stats is None.
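Steps 3 and 4 can be sketched in plain Python. Several details here are assumptions, not taken from `preprocess_envelope` or `fit_channel_standardizer`:

* The EMA factor is α = 1 − exp(−2π·f_c/f_s), i.e. a first-order low-pass at cutoff f_c.
* The soft clip is omitted.
* The standard deviation is the population one.
* The "mergeable stats" are modeled as (n, mean, M2) combined with Chan et al.'s pairwise update.

How the package actually stores its mergeable stats is not documented here.

```python
import math
from dataclasses import dataclass
from typing import List, Sequence


def ema_alpha(cutoff_hz: float, fs: float) -> float:
    """Smoothing factor of a one-pole low-pass with the given cutoff (assumed mapping)."""
    return 1.0 - math.exp(-2.0 * math.pi * cutoff_hz / fs)


def envelope(samples: Sequence[Sequence[int]], fs: float = 200.0,
             cutoff_hz: float = 4.0) -> List[List[float]]:
    """RAW int8 -> float -> rectify -> EMA low-pass, per channel."""
    alpha = ema_alpha(cutoff_hz, fs)
    state = [0.0] * len(samples[0])
    out = []
    for row in samples:
        state = [s + alpha * (abs(float(x)) - s) for s, x in zip(state, row)]
        out.append(state)
    return out


@dataclass
class ChannelStats:
    n: int
    mean: List[float]
    m2: List[float]  # per-channel sum of squared deviations from the mean

    @classmethod
    def fit(cls, x: Sequence[Sequence[float]]) -> "ChannelStats":
        n, c = len(x), len(x[0])
        mean = [sum(r[j] for r in x) / n for j in range(c)]
        m2 = [sum((r[j] - mean[j]) ** 2 for r in x) for j in range(c)]
        return cls(n, mean, m2)

    def merge(self, other: "ChannelStats") -> "ChannelStats":
        """Combine two fits as if fitted on the concatenated data (Chan et al.)."""
        n = self.n + other.n
        mean, m2 = [], []
        for ma, mb, qa, qb in zip(self.mean, other.mean, self.m2, other.m2):
            d = mb - ma
            mean.append(ma + d * other.n / n)
            m2.append(qa + qb + d * d * self.n * other.n / n)
        return ChannelStats(n, mean, m2)

    @property
    def std(self) -> List[float]:
        return [math.sqrt(q / self.n) for q in self.m2]


def standardize(x: Sequence[Sequence[float]], stats: ChannelStats,
                eps: float = 1e-8) -> List[List[float]]:
    """Per-channel z-score using stats fitted on training data only."""
    std = stats.std
    return [[(v - m) / (s + eps) for v, m, s in zip(row, stats.mean, std)] for row in x]
```

Mergeable stats let per-session fits be combined into a global normalization without rereading every CSV. That is one plausible reason a checkpoint can carry both per-session and `norm_stats["global"]` values.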

Typical values: 200 Hz, window 0.2–0.4 s, step 0.05–0.1 s, envelope cutoff 4 Hz.
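With these values, a 0.2 s window at 200 Hz spans Tw = 40 samples and a 0.05 s step is 10 samples. A recording of N samples then yields Nw = ⌊(N − Tw)/step⌋ + 1 windows, so 5 s (1000 samples) gives 97 windows. The sliding-window step can be sketched as below. `make_windows_sketch` is hypothetical, and the assumption that a trailing partial window is dropped may not match `make_windows`.

```python
from typing import List, Sequence


def make_windows_sketch(x: Sequence[Sequence[float]], window_duration_s: float,
                        step_duration_s: float, sampling_rate: float) -> List[List[List[float]]]:
    """Slice (N, C) samples into (Nw, Tw, C) overlapping windows."""
    tw = int(round(window_duration_s * sampling_rate))   # samples per window
    step = int(round(step_duration_s * sampling_rate))   # hop between window starts
    if tw <= 0 or step <= 0:
        raise ValueError("window and step must each cover at least one sample")
    if len(x) < tw:
        return []
    # Window starts 0, step, 2*step, ...; any trailing partial window is dropped.
    return [[list(row) for row in x[i:i + tw]] for i in range(0, len(x) - tw + 1, step)]
```

Rounding the durations to whole samples avoids off-by-one surprises from float products such as `0.05 * 200`.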


Training

Single runs

  • UnetEncoder: Edit train/unet_encoder.py (paths, UnetEncoderConfig, TrainingConfig), then:

    uv run train/unet_encoder.py
  • MLP: Edit train/mlp_baseline.py (paths, MLPConfig, TrainingConfig), then:

    uv run train/mlp_baseline.py

concat_datasets(model_config) expects session-named train/eval CSVs and returns (train_set, eval_set, norm_stats). Checkpoints are written under run_dir and include model_state, model_config, training_config, norm_stats, and eval_acc.

Hyperparameter sweep

  • train/sweep_unet_encoder.py runs a W&B sweep over window size, hop ratio, base, mlp_hidden_dims, dropout, batch_size, weight_decay, grad_clip.

  • Create a sweep in the W&B project, then:

    uv run train/sweep_unet_encoder.py -s <sweep_id>
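A sweep over these hyperparameters can be described with W&B's standard sweep configuration (`method`, `metric`, `parameters`). The parameter keys, ranges, and the metric name `eval_acc` below are illustrative assumptions chosen to match the list above and the typical windowing values. They are not the keys `train/sweep_unet_encoder.py` is known to read.

```python
# Hypothetical sweep configuration in W&B's dict format.
sweep_config = {
    "method": "bayes",
    "metric": {"name": "eval_acc", "goal": "maximize"},  # assumed metric name
    "parameters": {
        "window_duration_s": {"values": [0.2, 0.3, 0.4]},
        "hop_ratio": {"values": [0.25, 0.5]},            # step = window * hop_ratio
        "base": {"values": [16, 32, 64]},
        "mlp_hidden_dims": {"values": [[128], [256, 128]]},
        "dropout": {"min": 0.0, "max": 0.5},
        "batch_size": {"values": [64, 128, 256]},
        "weight_decay": {"distribution": "log_uniform_values", "min": 1e-6, "max": 1e-2},
        "grad_clip": {"values": [0.5, 1.0, 5.0]},
    },
}


def step_from_hop(window_duration_s: float, hop_ratio: float) -> float:
    """Convert a sampled hop ratio into step_duration_s (assumed convention)."""
    return window_duration_s * hop_ratio

# With wandb installed, the sweep ID passed to -s could come from:
#   sweep_id = wandb.sweep(sweep_config, project="<your-project>")
```

Sampling a hop ratio rather than an absolute step keeps the overlap fraction constant as the window length varies.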

Inference

All live scripts expect a UnetEncoder checkpoint (.pt) that contains model_config, model_state, and norm_stats (with at least mean and std; compass/infer can use norm_stats["global"]).

6-way only

  • scripts/live_classify.py
    • --ckpt, --device (default cuda), --print_hz, --timeout.
    • Optional: --window, --step, --fs, --cutoff.
    • Streams Myo RAW → envelope → z-score with checkpoint stats → model → 6-way probs and top label.
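The streaming loop can be sketched as a ring buffer that holds the last Tw samples and emits a window every `step` new samples. `StreamingWindower` is hypothetical. In the real script each pushed sample would already be enveloped and z-scored with the checkpoint stats, and every emitted window would go to the model. With this emission rule, live windows line up with the offline windows from the same stream.

```python
from collections import deque
from typing import Deque, List, Optional, Sequence


class StreamingWindower:
    """Keep the most recent `window` samples; emit a copy every `step` new samples."""

    def __init__(self, window: int, step: int):
        self.buf: Deque[List[float]] = deque(maxlen=window)
        self.step = step
        self.since_last = 0

    def push(self, sample: Sequence[float]) -> Optional[List[List[float]]]:
        """Add one (already preprocessed) sample; return a full window when one is due."""
        self.buf.append(list(sample))
        self.since_last += 1
        if len(self.buf) == self.buf.maxlen and self.since_last >= self.step:
            self.since_last = 0
            return list(self.buf)  # snapshot, so later pushes do not mutate it
        return None
```

At 200 Hz with a 10-sample step, this yields 20 model calls per second, independent of how often `--print_hz` refreshes the display.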

6-way + 8-direction compass

  • scripts/compass.py (or scripts/infer.py)
    • Same as above, plus:
    • --dir_temp (direction softmax temperature), --aggregate_n (running mode / mean over last N predictions).
    • Builds a 2D direction from N/S/E/W logits, maps to N/NE/E/SE/S/SW/W/NW and shows an emoji compass plus 6-way bars.
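One plausible direction mapping is sketched below, and every part of it is an assumption, since the construction in `compass.py` is not documented:

* The N/S/E/W logits go through a softmax at temperature `--dir_temp`.
* The 2D vector is (p_E − p_W, p_N − p_S).
* Its angle is snapped to the nearest of eight 45° sectors.
* A running mode over the last `--aggregate_n` outputs smooths the result.

The helpers are hypothetical.

```python
import math
from collections import Counter, deque
from typing import Deque, Dict, List, Sequence

# Sectors listed counter-clockwise from east, 45 degrees each.
SECTORS = ["E", "NE", "N", "NW", "W", "SW", "S", "SE"]


def softmax(logits: Sequence[float], temp: float = 1.0) -> List[float]:
    """Numerically stable softmax; higher temp flattens the distribution."""
    m = max(logits)
    exps = [math.exp((z - m) / temp) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def direction_from_logits(logits: Dict[str, float], temp: float = 1.0) -> str:
    """Map N/S/E/W logits (other classes ignored) to one of 8 compass sectors."""
    p_n, p_s, p_e, p_w = softmax([logits[k] for k in ("north", "south", "east", "west")], temp)
    x, y = p_e - p_w, p_n - p_s
    angle = math.degrees(math.atan2(y, x)) % 360.0
    return SECTORS[int((angle + 22.5) // 45) % 8]


class ModeAggregator:
    """Running mode over the last n labels (ties go to the label seen first)."""

    def __init__(self, n: int):
        self.recent: Deque[str] = deque(maxlen=n)

    def update(self, label: str) -> str:
        self.recent.append(label)
        return Counter(self.recent).most_common(1)[0][0]
```

Diagonals fall out naturally: two equally strong adjacent classes (e.g. north and east) give a 45° vector, which maps to NE. When all four probabilities tie, the vector is zero and `atan2` returns 0 (E), so a real UI would likely gate on the neutral class first.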

Offline / batch

Use the same preprocessing as in EMGDataset / live_classify: load CSV → envelope → normalize with checkpoint norm_stats → window → model(x). No dedicated batch script is provided; you can call UnetEncoder and apply_channel_standardizer from the package.
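A batch pass over one CSV can be sketched as below. `classify_offline` is hypothetical. `envelope_fn` stands in for `preprocess_envelope` and defaults to identity only so the sketch runs on its own. `model` is any callable that returns class scores for one (Tw, C) window. A loaded `UnetEncoder` would instead need a torch tensor of the shape it was trained on.

```python
import csv
from typing import Callable, Dict, List, Sequence, TextIO

EMG_COLS = [f"ch{k}" for k in range(1, 9)]


def load_emg(f: TextIO) -> List[List[float]]:
    """Read ch1..ch8 from a CSV; other columns (label, timestamp, ...) are ignored."""
    return [[float(row[c]) for c in EMG_COLS] for row in csv.DictReader(f)]


def classify_offline(
    f: TextIO,
    norm_stats: Dict[str, List[float]],
    model: Callable[[List[List[float]]], Sequence[float]],
    window: int,
    step: int,
    envelope_fn: Callable[[List[List[float]]], List[List[float]]] = lambda x: x,
) -> List[int]:
    """CSV -> envelope -> z-score with checkpoint stats -> windows -> argmax per window."""
    x = envelope_fn(load_emg(f))
    mean, std = norm_stats["mean"], norm_stats["std"]
    z = [[(v - m) / (s + 1e-8) for v, m, s in zip(row, mean, std)] for row in x]
    preds = []
    for start in range(0, len(z) - window + 1, step):
        scores = model(z[start:start + window])
        preds.append(max(range(len(scores)), key=scores.__getitem__))
    return preds
```

Normalizing with the checkpoint's `norm_stats` rather than stats fitted on the new file keeps offline results comparable to live inference.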


API Summary

High-level imports:

from gripvector import (
    Myo,
    emg_mode,
    record_static_holds_raw_with_countdown,
    record_labeled_dataset,
    LABELS,
    EMG_COLS,
    preprocess_envelope,
    make_windows,
    fit_channel_standardizer,
    apply_channel_standardizer,
    create_df,
    df_xy_split,
    split_by_class,
    EMGDataset,
    concat_datasets,
    train,
    TrainingConfig,
    UnetEncoderConfig,
    UnetEncoder,
    MLPConfig,
    MLP,
    COMPASS,
)
  • Myo: Myo(tty=None, mode=emg_mode.RAW), .connect(), .add_emg_handler(fn), .run(), .disconnect().
  • Recording: record_static_holds_raw_with_countdown(output_csv, seconds_per_class=8, rest_seconds=2, ...); record_labeled_dataset(mode, output_csv, seconds_per_class=10, ...).
  • Data: create_df(csvs), df_xy_split(df), split_by_class(X, y), EMGDataset(...), concat_datasets(model_config).
  • Training: train(model, train_set, eval_set, norm_stats, training_config).

Configuration Reference

  • TrainingConfig: num_epochs, batch_size, seed, lr, weight_decay, grad_clip, device, run_dir, wandb, wandb_project, log_every.
  • UnetEncoderConfig: train_csvs, eval_csvs, window_duration_s, step_duration_s, sampling_rate, num_classes, Cin (8), base, mlp_hidden_dims, dropout, activation.
  • MLPConfig: layer_dims, train_csvs, eval_csvs, window_duration_s, step_duration_s, sampling_rate, num_classes, activation, dropout, activate_last.

Checkpoints store model_config as a dict (e.g. asdict(UnetEncoderConfig)), so loading can use UnetEncoder(**model_cfg) or UnetEncoder(UnetEncoderConfig(**model_cfg)).
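The dict round trip works because `dataclasses.asdict` yields plain Python types that serialize cleanly and can be splatted back into the constructor. The sketch uses a hypothetical stand-in, `UnetEncoderConfigSketch`. Its field names follow the Configuration Reference, but the types and defaults are invented for illustration.

```python
from dataclasses import asdict, dataclass, field
from typing import List


@dataclass
class UnetEncoderConfigSketch:
    # Field names from the Configuration Reference; types and defaults are assumptions.
    train_csvs: List[str] = field(default_factory=list)
    eval_csvs: List[str] = field(default_factory=list)
    window_duration_s: float = 0.3
    step_duration_s: float = 0.05
    sampling_rate: int = 200
    num_classes: int = 6
    Cin: int = 8
    base: int = 32
    mlp_hidden_dims: List[int] = field(default_factory=lambda: [128])
    dropout: float = 0.1
    activation: str = "relu"


# Saving side: store the config as a plain dict next to the weights.
ckpt = {"model_config": asdict(UnetEncoderConfigSketch(base=64)), "eval_acc": None}

# Loading side: rebuild the typed config from the stored dict.
restored = UnetEncoderConfigSketch(**ckpt["model_config"])
```

One consequence of this design: renaming or removing a config field makes `Config(**model_cfg)` raise `TypeError` on older checkpoints, so field names are effectively part of the checkpoint format.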


License & Attribution

  • Myo / hardware: src/gripvector/hardware.py is derived from dzhu/myo-raw, with edits by Fernando Cosentino, Alvaro Villoslada (Alvipe), and PerlinWarp (pyomyo). See the MIT license and attribution block inside hardware.py.
  • GripVector (data pipeline, models, scripts, README): add your own license as needed.

Tips

  • Use RAW mode (200 Hz) for training and live inference so preprocessing is consistent.
  • Keep session_id and naming consistent so concat_datasets and eval normalization work as intended.
  • For compass/vector-8, aggregation (--aggregate_n) and --dir_temp help smooth direction output.
  • The recorder and live scripts set myo.bt.ser.timeout so that a blocked serial port cannot hang them indefinitely.

About

NECL research
