From 99f61284ff5093aabff18e7987ca357303865c66 Mon Sep 17 00:00:00 2001 From: Luke Mainwaring Date: Tue, 26 May 2026 12:09:17 -0400 Subject: [PATCH] Review codebase after reading NeuralBench paper --- backend/src/cortexdj/ml/contrastive.py | 10 ++++++---- backend/src/cortexdj/ml/contrastive_dataset.py | 10 ++++++++-- backend/src/cortexdj/ml/dataset.py | 13 ++++++++++--- backend/src/cortexdj/ml/predict.py | 4 ++-- docs/pretrained-models-analysis.md | 12 +++++++++++- 5 files changed, 37 insertions(+), 12 deletions(-) diff --git a/backend/src/cortexdj/ml/contrastive.py b/backend/src/cortexdj/ml/contrastive.py index e44787d..cda88b6 100644 --- a/backend/src/cortexdj/ml/contrastive.py +++ b/backend/src/cortexdj/ml/contrastive.py @@ -85,7 +85,7 @@ def __init__(self, *, projection_dim: int = EMBEDDING_DIM, dropout: float = 0.3) ) def forward(self, x: torch.Tensor) -> torch.Tensor: - """Take (B, 32, 800) raw EEG at 200Hz → (B, 512) unit vectors.""" + """Take (B, 32, 800) EEG at 200Hz, pre-scaled by `CBRAMOD_SCALE_FACTOR` → (B, 512) unit vectors.""" out = self.backbone(x, return_features=True) features = out["features"] # (B, ..., backbone_embed_dim) pooled = features.reshape(features.shape[0], -1, features.shape[-1]).mean(dim=1) @@ -225,9 +225,11 @@ def retrieval_metrics( def encode_session(model: EegCLAPEncoder, segments: np.ndarray, device: torch.device) -> np.ndarray: """Aggregate a session's EEG windows into a single 512-d query vector. - `segments` is a `(n_segments, 32, 800)` float32 array of already-resampled - 4-second EEG windows. Returns an L2-normalized numpy vector suitable for - pgvector cosine similarity search. + `segments` is a `(n_segments, 32, 800)` float32 array of 4-second EEG + windows at 200Hz, pre-scaled to CBraMod's input range (callers should go + through `trial_to_eeg_windows`, which applies `CBRAMOD_SCALE_FACTOR`). + Returns an L2-normalized numpy vector suitable for pgvector cosine + similarity search. 
""" model.eval() tensor = torch.from_numpy(segments.astype(np.float32)).to(device) diff --git a/backend/src/cortexdj/ml/contrastive_dataset.py b/backend/src/cortexdj/ml/contrastive_dataset.py index 7fee9a8..27cab81 100644 --- a/backend/src/cortexdj/ml/contrastive_dataset.py +++ b/backend/src/cortexdj/ml/contrastive_dataset.py @@ -32,6 +32,7 @@ from cortexdj.ml.contrastive import CLAP_MODEL_ID, ClapAudioEncoder, load_audio_waveform from cortexdj.ml.dataset import ( CBRAMOD_SAMPLING_RATE, + CBRAMOD_SCALE_FACTOR, SEGMENT_SAMPLES, _cache_dir, _extract_participant_id, @@ -40,6 +41,10 @@ logger = logging.getLogger(__name__) +# Gates the CLAP audio embedding cache only (see `_audio_cache_key`). EEG-side +# preprocessing has no persistent cache in this file — `DeapClapPairDataset` +# reads `.dat` fresh each construction — so bump this only when the audio +# embedding pipeline changes. _CACHE_VERSION = "v2" STIMULI_RESOLVED_PATH = DATA_DIR / "deap_stimuli_resolved.json" @@ -57,7 +62,8 @@ def trial_to_eeg_windows( Shared by the training dataset and the runtime retrieval service so both sides produce windows with identical shape and preprocessing. `trial_data` is expected as (n_channels, n_samples) at DEAP's 128Hz sampling rate; - returns (n_windows, n_channels, target_segment_samples) at 200Hz. + returns (n_windows, n_channels, target_segment_samples) at 200Hz, scaled by + `CBRAMOD_SCALE_FACTOR` to match the pretrained input distribution. 
""" n_samples = trial_data.shape[1] n_segments = n_samples // source_segment_samples @@ -67,7 +73,7 @@ def trial_to_eeg_windows( end = start + source_segment_samples segment = trial_data[:, start:end] # (n_channels, source_segment_samples) resampled = resample(segment, target_segment_samples, axis=1).astype(np.float32) - windows.append(resampled) + windows.append(resampled * CBRAMOD_SCALE_FACTOR) if not windows: return np.zeros((0, trial_data.shape[0], target_segment_samples), dtype=np.float32) return np.stack(windows, axis=0) diff --git a/backend/src/cortexdj/ml/dataset.py b/backend/src/cortexdj/ml/dataset.py index 9e1976e..18ac008 100644 --- a/backend/src/cortexdj/ml/dataset.py +++ b/backend/src/cortexdj/ml/dataset.py @@ -26,7 +26,7 @@ logger = logging.getLogger(__name__) # Bump when feature extraction or labeling logic changes. -_CACHE_VERSION = "v3" +_CACHE_VERSION = "v4" NUM_EEG_CHANNELS = 32 SEGMENT_SAMPLES = DEFAULT_SAMPLING_RATE * 4 # 4-second segments (512 samples at 128Hz) @@ -34,6 +34,13 @@ AROUSAL_THRESHOLD = 5.0 VALENCE_THRESHOLD = 5.0 +# Multiply DEAP µV signals by this before feeding CBraMod. Without it, CBraMod +# sees signals ~100× too large and its pretrained features operate +# off-distribution. Derivation: NeuralBench's CBraMod recipe (neuralbench-repo +# `neuralbench/models/cbramod.yaml`: scaler=null, scale_factor=10000) applies +# 1e4 to MNE volts; the µV equivalent is µV × 1e-6 V/µV × 1e4 = µV × 1e-2. +CBRAMOD_SCALE_FACTOR = 0.01 + # Label binarization strategy for DEAP's 1-9 Likert self-reports. 
# # `median_per_subject` (default): splits each axis at that subject's own @@ -379,8 +386,8 @@ def _load_participant(self, file_path: Path, thresholds: dict[int, tuple[float, end = start + self.source_segment_samples segment = trial_data[:, start:end] # (32, 512) - # Resample 128Hz -> target (200Hz for CBraMod) - resampled = resample(segment, self.target_segment_samples, axis=1) + # Resample 128Hz -> target (200Hz for CBraMod), scale to match pretrained input range. + resampled = resample(segment, self.target_segment_samples, axis=1) * CBRAMOD_SCALE_FACTOR self.samples.append((resampled, arousal_label, valence_label)) self.participant_ids.append(participant_id) diff --git a/backend/src/cortexdj/ml/predict.py b/backend/src/cortexdj/ml/predict.py index 6db64ef..c111292 100644 --- a/backend/src/cortexdj/ml/predict.py +++ b/backend/src/cortexdj/ml/predict.py @@ -9,7 +9,7 @@ from scipy.signal import resample from cortexdj.core.paths import CHECKPOINTS_DIR -from cortexdj.ml.dataset import CBRAMOD_SEGMENT_SAMPLES, scores_to_quadrant +from cortexdj.ml.dataset import CBRAMOD_SCALE_FACTOR, CBRAMOD_SEGMENT_SAMPLES, scores_to_quadrant from cortexdj.ml.model import EEGNetClassifier from cortexdj.ml.preprocessing import compute_band_powers, extract_features from cortexdj.ml.pretrained import PretrainedDualHead, load_pretrained_dual_head @@ -85,7 +85,7 @@ def predict_segment( ) -> EEGPredictionResult: """Run inference on a single EEG segment (n_channels x n_samples).""" if isinstance(model, PretrainedDualHead): - resampled = resample(eeg_data, CBRAMOD_SEGMENT_SAMPLES, axis=1) + resampled = resample(eeg_data, CBRAMOD_SEGMENT_SAMPLES, axis=1) * CBRAMOD_SCALE_FACTOR input_tensor = torch.tensor(resampled, dtype=torch.float32).unsqueeze(0) else: features = extract_features(eeg_data) diff --git a/docs/pretrained-models-analysis.md b/docs/pretrained-models-analysis.md index 31d03ac..6454fae 100644 --- a/docs/pretrained-models-analysis.md +++ b/docs/pretrained-models-analysis.md @@ -88,9 
+88,19 @@ model.push_to_hub("username/cortexdj-emotion-cbramod") model = CBraMod.from_pretrained("username/cortexdj-emotion-cbramod") ``` +## External validation: NeuralBench-EEG (Meta FAIR, 2026) + +[Banville et al., 2026](https://github.com/facebookresearch/neuroai/tree/main/neuralbench-repo) benchmarked 14 EEG architectures across 36 tasks / 94 datasets under a single standardized recipe. Findings that bear on the recommendations above: + +- **Foundation models only marginally beat task-specific from-scratch models.** REVE/LaBraM/LUNA lead the ranking, but **CTNet (150K params, task-specific)** ranks 4th — beating CBraMod (4.9M) and three other foundation models. The gap is narrow enough that adding more datasets per task flips the order. This validates the "Emotion-specific baselines" framing: TSception/DGCNN/EEGConformer aren't just sanity checks; they're plausibly competitive. +- **CBraMod sits mid-pack (5th of 14)** in the Core ranking, just behind CTNet and SimpleConvTimeAgg. REVE (#1, 69M params, pretrained on 60K hours) outperforms it by a real but narrow margin — which supports the Tier 1 ordering and suggests the upside of swapping CBraMod → REVE is bounded. +- **Cross-subject is where everything gets hard.** Motor imagery, P300, and N2pc collapse to near-dummy performance under cross-subject splits. Our LOSO CV sits in the same subject-disjoint family — actually stricter than NeuralBench's 20%-subject holdout — so DEAP accuracies in the literature that used within-subject splits are not apples-to-apples comparisons. +- **Each foundation model expects a specific input distribution.** CBraMod's recipe specifies 200 Hz / 0.3–75 Hz / `scale_factor=10000` against MNE volts (≈ ×0.01 against DEAP µV). The cortexdj data path now applies this scale via `CBRAMOD_SCALE_FACTOR` in `ml/dataset.py`; the bandpass mismatch (DEAP is pre-filtered to 4–45 Hz) is not recoverable without raw `.bdf` files. 
+- **Standardized downstream recipe** for reference: AdamW lr=1e-4, wd=0.05, cosine + 10% warmup, ≤50 epochs, end-to-end finetune, linear probe on mean-pooled tokens. Diverges from cortexdj's current MLP-head + two-phase freeze→unfreeze schedule; worth running as a baseline comparator before committing to either. + ## Open Questions -1. **Sampling rate mismatch** — DEAP preprocessed data is 128 Hz; most pretrained models expect 200–256 Hz. Resampling is straightforward but may affect pretrained representations. Needs benchmark with and without resampling. +1. **Sampling rate mismatch** — DEAP preprocessed data is 128 Hz; most pretrained models expect 200–256 Hz. Resampling is straightforward but may affect pretrained representations. Needs benchmark with and without resampling. *Partial answer from NeuralBench: each foundation model has a model-specific preprocessing recipe that resampling alone won't satisfy — see "Each foundation model expects a specific input distribution" above.* 2. **Channel mapping** — DEAP uses 32 channels in a specific montage. Flexible-channel models (CBraMod, REVE, LUNA) handle this via positional encoding, but accuracy impact needs measurement. Concrete API: braindecode's [`plot_channel_interpolation`](https://braindecode.org/stable/auto_examples/model_building/plot_channel_interpolation.html) example. Practical extremes to ablate against: DREAMER (14 ch, Emotiv) and MUSIN-G (128 ch, HGSN) — see [datasets-analysis.md](datasets-analysis.md).