Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -209,5 +209,6 @@ API Reference
models/pyhealth.models.VisionEmbeddingModel
models/pyhealth.models.TextEmbedding
models/pyhealth.models.BIOT
models/pyhealth.models.CBraMod_Wrapper
models/pyhealth.models.unified_multimodal_embedding_docs
models/pyhealth.models.califorest
67 changes: 67 additions & 0 deletions docs/api/models/pyhealth.models.cbramod.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
pyhealth.models.CBraMod_Wrapper
===================================

CBraMod model for EEG signal classification.

Overview
--------

CBraMod is a criss-cross attention transformer tailored for EEG decoding. The
wrapper integrates the model into the PyHealth ``BaseModel`` pipeline so it can
be trained with the standard ``Trainer`` APIs.

Input/Output
------------

- **Input:** ``signal`` tensor shaped ``(batch, channels, timesteps)`` where
``timesteps`` is a multiple of 200 (the patch size used by CBraMod).
- **Output (classifier_head=True):** dict with ``loss``, ``y_prob``, ``y_true``,
``logit``, and ``embeddings``.
- **Output (classifier_head=False):** dict with ``logit`` and ``embeddings``.

Example Usage
-------------

.. code-block:: python

import torch
from pyhealth.datasets import create_sample_dataset, get_dataloader
from pyhealth.models import CBraMod_Wrapper

n_channels = 16
patch_size = 200
n_patches = 10
n_samples = patch_size * n_patches

samples = [
{
"patient_id": f"patient-{i}",
"visit_id": "visit-0",
"signal": torch.randn(n_channels, n_samples).numpy().tolist(),
"label": i % 6,
}
for i in range(8)
]

dataset = create_sample_dataset(
samples=samples,
input_schema={"signal": "tensor"},
output_schema={"label": "multiclass"},
dataset_name="test_cbramod",
)

model = CBraMod_Wrapper(
dataset=dataset,
seq_len=n_patches,
n_classes=6,
classifier_head=True,
)

batch = next(iter(get_dataloader(dataset, batch_size=2, shuffle=True)))
output = model(**batch)
print(output["logit"].shape)

.. autoclass:: pyhealth.models.CBraMod_Wrapper
:members:
:undoc-members:
:show-inheritance:
6 changes: 6 additions & 0 deletions docs/tutorials.rst
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,13 @@ EEG and Sleep Analysis
* - ``EEG_events_SparcNet.py``
- SparcNet for EEG event detection
* - ``EEG_isAbnormal_SparcNet.py``
<<<<<<< HEAD
- SparcNet for EEG abnormality detection
* - ``CBraMod_tuab_eeg_abnormal_classification.py``
- CBraMod for EEG abnormality detection on TUAB
=======
- SparcNet for EEG abnormality detection
>>>>>>> origin/master
* - ``cardiology_detection_isAR_SparcNet.py``
- SparcNet for cardiology arrhythmia detection

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from pyhealth.datasets import TUABDataset, split_by_visit, get_dataloader
from pyhealth.tasks import EEGAbnormalTUAB
from pyhealth.models import CBraMod_Wrapper
from pyhealth.trainer import Trainer

# step 1: load signal data
dataset = TUABDataset(
root="/srv/local/data/TUH/tuh_eeg_abnormal/v3.0.0/edf/",
dev=True,
refresh_cache=True,
)
print(dataset.stats())

# step 2: set task (disable STFT for CBraMod)
TUAB_ds = dataset.set_task(
EEGAbnormalTUAB(
resample_rate=200,
bandpass_filter=(0.1, 75.0),
notch_filter=50.0,
compute_stft=False,
)
)

print(f"Total task samples: {len(TUAB_ds)}")
print(f"Input schema: {TUAB_ds.input_schema}")
print(f"Output schema: {TUAB_ds.output_schema}")

# Inspect a sample to infer sequence length
sample = TUAB_ds[0]
print(f"\nSample keys: {sample.keys()}")
print(f"Signal shape: {sample['signal'].shape}")
print(f"Label: {sample['label']}")

seq_len = sample["signal"].shape[-1] // 200

# split dataset
train_dataset, val_dataset, test_dataset = split_by_visit(
TUAB_ds, [0.6, 0.2, 0.2]
)
train_dataloader = get_dataloader(train_dataset, batch_size=16, shuffle=True)
val_dataloader = get_dataloader(val_dataset, batch_size=16, shuffle=False)
test_dataloader = get_dataloader(test_dataset, batch_size=16, shuffle=False)
print(
"loader size: train/val/test",
len(train_dataset),
len(val_dataset),
len(test_dataset),
)

# step 3: define model
model = CBraMod_Wrapper(
dataset=TUAB_ds,
seq_len=seq_len,
n_classes=2,
classifier_head=True,
)

# step 4: define trainer
trainer = Trainer(model=model, device="cuda:0")
trainer.train(
train_dataloader=train_dataloader,
val_dataloader=val_dataloader,
epochs=10,
optimizer_params={"lr": 1e-4},
)

# step 5: evaluate
print(trainer.evaluate(test_dataloader))
86 changes: 43 additions & 43 deletions pyhealth/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,43 @@
from .binary import binary_metrics_fn
from .drug_recommendation import ddi_rate_score
from .generative import (
calc_membership_inference,
calc_nnaar,
compute_discriminator_privacy,
compute_mle,
compute_prevalence_metrics,
evaluate_synthetic_ehr,
)
from .interpretability import (
ComprehensivenessMetric,
Evaluator,
RemovalBasedMetric,
SufficiencyMetric,
evaluate_attribution,
)
from .multiclass import multiclass_metrics_fn
from .multilabel import multilabel_metrics_fn
# from .fairness import fairness_metrics_fn
from .ranking import ranking_metrics_fn
from .regression import regression_metrics_fn
__all__ = [
"binary_metrics_fn",
"ddi_rate_score",
"calc_nnaar",
"calc_membership_inference",
"compute_discriminator_privacy",
"compute_mle",
"compute_prevalence_metrics",
"evaluate_synthetic_ehr",
"ComprehensivenessMetric",
"SufficiencyMetric",
"RemovalBasedMetric",
"Evaluator",
"evaluate_attribution",
"multiclass_metrics_fn",
"multilabel_metrics_fn",
"ranking_metrics_fn",
"regression_metrics_fn",
]
from .binary import binary_metrics_fn
from .drug_recommendation import ddi_rate_score
from .generative import (
calc_membership_inference,
calc_nnaar,
compute_discriminator_privacy,
compute_mle,
compute_prevalence_metrics,
evaluate_synthetic_ehr,
)
from .interpretability import (
ComprehensivenessMetric,
Evaluator,
RemovalBasedMetric,
SufficiencyMetric,
evaluate_attribution,
)
from .multiclass import multiclass_metrics_fn
from .multilabel import multilabel_metrics_fn

# from .fairness import fairness_metrics_fn
from .ranking import ranking_metrics_fn
from .regression import regression_metrics_fn

__all__ = [
"binary_metrics_fn",
"ddi_rate_score",
"calc_nnaar",
"calc_membership_inference",
"compute_discriminator_privacy",
"compute_mle",
"compute_prevalence_metrics",
"evaluate_synthetic_ehr",
"ComprehensivenessMetric",
"SufficiencyMetric",
"RemovalBasedMetric",
"Evaluator",
"evaluate_attribution",
"multiclass_metrics_fn",
"multilabel_metrics_fn",
"ranking_metrics_fn",
"regression_metrics_fn",
]
3 changes: 2 additions & 1 deletion pyhealth/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .base_model import BaseModel
from .transformer_deid import TransformerDeID
from .biot import BIOT
from .cbramod import CBraMod_Wrapper
from .cnn import CNN, CNNLayer
from .concare import ConCare, ConCareLayer
from .contrawr import ContraWR, ResBlock2D
Expand Down Expand Up @@ -50,4 +51,4 @@
from .generators.gpt2 import GPT2
from .generators.promptehr import PromptEHR
from .generators.medgan import MedGAN
from .generators.corgan import CorGAN
from .generators.corgan import CorGAN
Loading