This document outlines the architectural upgrades and future development plans for the torch_ecg library.
Objective: Eliminate lengthy `if`-`elif` branches in downstream models like `ECG_CRNN` to enhance code maintainability and extensibility.
- Status: Done. `MODELS`, `BACKBONES`, `ATTN_LAYERS` registries are implemented in `models/registry.py`; `OPTIMIZERS`, `SCHEDULERS`, `LOSSES` in `components/registry.py`; `PREPROCESSORS` in `preprocessors/registry.py`. All CNN backbones and downstream models use `@BACKBONES.register()` / `@MODELS.register()` decorators. `Registry.build(name, **kwargs)` is the unified construction interface. Adding a new backbone no longer requires modifying `models/ecg_crnn.py` or any other existing file.
- Strategy: Implement a registry mechanism similar to the one used in the `fl-sim` library. Establish `BACKBONES`, `MODELS`, and `SSL` registries (a minimal sketch follows this list).
  - Use decorators like `@register_backbone("resnet")` to register modules.
  - Use a unified `BACKBONES.build(name, **kwargs)` method for module instantiation.
- Benefits: Decouples model definition from construction logic, making it easier for both maintainers and users to inject custom backbones.
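For reference, a minimal sketch of such a registry; the class body is illustrative, not the exact `torch_ecg` implementation:

```python
# Illustrative sketch of the registry pattern; not the exact torch_ecg code.
from typing import Any, Callable, Dict, Optional, Type

class Registry:
    """Maps string names to classes so modules can be built from configs."""

    def __init__(self, name: str) -> None:
        self.name = name
        self._modules: Dict[str, Type] = {}

    def register(self, key: Optional[str] = None) -> Callable[[Type], Type]:
        """Decorator that records a class under ``key`` (default: class name)."""
        def decorator(cls: Type) -> Type:
            self._modules[key or cls.__name__] = cls
            return cls
        return decorator

    def build(self, name: str, **kwargs: Any) -> Any:
        """Unified construction interface: look up ``name`` and instantiate."""
        if name not in self._modules:
            raise KeyError(f"{name!r} not found in registry {self.name!r}")
        return self._modules[name](**kwargs)

BACKBONES = Registry("backbones")

@BACKBONES.register("resnet")
class ResNet:
    def __init__(self, in_channels: int = 12) -> None:
        self.in_channels = in_channels

# Downstream code needs no if-elif branches:
model = BACKBONES.build("resnet", in_channels=12)
```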
Objective: Provide a unified feature extraction interface for Self-Supervised Learning (SSL) and multi-task learning.
- Status: Done. All CNN backbones (`ResNet`, `VGG16`, `DenseNet`, `MobileNetV1/V2/V3`, `MultiScopicCNN`, `RegNet`, `Xception`) and the `Transformer` now implement `forward_features(x)` (returns feature maps before the classifier head) and `compute_features_output_shape(seq_len, batch_size)` for shape inference without running a forward pass.
- Strategy: Follow the convention used in the `timm` library by providing a `forward_features(x)` method for all CNN and Transformer backbones (a toy example follows this list).
- Features:
  - Unified return of feature maps instead of classification logits.
  - Support for accessing intermediate activations for feature fusion or saliency analysis (e.g., Grad-CAM).
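A toy backbone illustrating this contract (the layer choices are illustrative, not an actual `torch_ecg` backbone):

```python
import torch
from torch import nn

class TinyBackbone(nn.Module):
    """Toy 1D CNN illustrating the feature-extraction contract."""

    def __init__(self, in_channels: int = 12, num_classes: int = 5) -> None:
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv1d(in_channels, 32, kernel_size=7, stride=2, padding=3),
            nn.ReLU(),
        )
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.classifier = nn.Linear(32, num_classes)

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        # feature maps before the classifier head, usable for SSL or fusion
        return self.features(x)

    def compute_features_output_shape(self, seq_len: int, batch_size: int = 1) -> tuple:
        # shape inference without running a forward pass
        out_len = (seq_len + 2 * 3 - 7) // 2 + 1  # standard Conv1d length formula
        return (batch_size, 32, out_len)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        feat = self.pool(self.forward_features(x)).flatten(1)
        return self.classifier(feat)

x = torch.randn(2, 12, 5000)  # (batch, n_leads, siglen)
backbone = TinyBackbone()
assert backbone.forward_features(x).shape == backbone.compute_features_output_shape(5000, 2)
```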
Objective: Reduce the burden of manually calculating and specifying `in_channels` in configuration files.
- Strategy: Introduce `nn.LazyLinear` or `nn.LazyConv1d` in complex SSL modules (demonstrated below).
- Benefits: Simplifies `model_configs` by allowing the model to automatically infer input dimensions during the first forward pass, reducing boilerplate code.
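This is standard PyTorch behavior; a minimal demonstration (the head below is illustrative):

```python
import torch
from torch import nn

head = nn.Sequential(
    nn.LazyConv1d(64, kernel_size=3, padding=1),  # in_channels inferred
    nn.ReLU(),
    nn.AdaptiveAvgPool1d(1),
    nn.Flatten(),
    nn.LazyLinear(5),  # in_features inferred
)

x = torch.randn(2, 32, 1000)  # channel count need not be specified in a config
logits = head(x)              # first forward pass materializes the weights
print(logits.shape)           # torch.Size([2, 5])
```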
Objective: Eliminate redundancy between NumPy and PyTorch implementations and optimize performance by keeping computations on the GPU.
- Pure PyTorch Filtering ✅: `BandPass` now uses a zero-phase FFT-based filter and `BaselineRemove` uses dual `avg_pool1d`, both implemented in `utils/utils_signal_t.py`; there are no more CPU-GPU data transfers (a rough sketch of the dual-pooling idea follows this list).
- Unification of Managers ⬜: Refactor `PreprocManager` and `AugmenterManager` to share a common base or registry, as their logic for managing sequences of transforms is very similar.
- Dimension Agnostic Transforms ⬜ (augmenters pending): Preprocessors now handle arbitrary leading batch dimensions (`(..., n_leads, siglen)`). Augmenters still need to be updated.
- NumPy Version Maintenance ✅: `_preprocessors` (NumPy) is kept only for offline data preparation or deployment environments without PyTorch; the PyTorch `preprocessors` are the primary choice for training pipelines.
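A rough sketch of the dual `avg_pool1d` idea; the window lengths and function below are illustrative, not `torch_ecg`'s actual `BaselineRemove`:

```python
# Baseline removal via two cascaded moving averages (dual avg_pool1d);
# window lengths here are illustrative, not torch_ecg's defaults.
import torch
import torch.nn.functional as F

def remove_baseline(sig: torch.Tensor, fs: int = 500) -> torch.Tensor:
    """sig: (batch, n_leads, siglen); subtract a moving-average baseline estimate."""
    baseline = sig
    for window_s in (0.2, 0.6):  # two smoothing passes estimate the slow drift
        k = int(window_s * fs) // 2 * 2 + 1  # odd kernel keeps the length fixed
        baseline = F.avg_pool1d(baseline, kernel_size=k, stride=1, padding=k // 2)
    return sig - baseline  # stays on the signal's device (CPU or GPU)

x = torch.randn(2, 12, 5000, device="cuda" if torch.cuda.is_available() else "cpu")
y = remove_baseline(x, fs=500)
assert y.shape == x.shape
```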
Objective: Ensure compatibility with Pandas 3.0+, particularly concerning Arrow-backed strings and stricter type checking.
- Explicit Object Dtypes: Explicitly set `dtype=object` for DataFrame columns intended to hold list-like objects (e.g., diagnoses, available signals) to prevent errors when Arrow-backed strings are used.
- Initialization Refactoring: Replace patterns that initialize columns with `None` or `""` and then populate them with lists via `.at[]` with more robust initializations such as `[[] for _ in range(len(df))]` or `.apply()`.
- Vectorized Operations: Prefer `.apply()` or other vectorized pandas operations over `iterrows()` loops for better performance and type consistency (all three patterns are illustrated below).
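The three patterns above in a small self-contained example (column names are illustrative):

```python
import pandas as pd

df = pd.DataFrame({"record": ["rec1", "rec2", "rec3"]})

# Fragile: initializing with "" infers a string dtype, so later list
# assignment via .at[] can raise under Arrow-backed strings.
# df["diagnoses"] = ""
# df.at[0, "diagnoses"] = ["AF"]

# Robust: declare object dtype up front and initialize with real lists.
df["diagnoses"] = pd.Series([[] for _ in range(len(df))], dtype=object)
df.at[0, "diagnoses"] = ["AF", "PVC"]

# Prefer .apply() over iterrows() for derived columns.
df["n_diagnoses"] = df["diagnoses"].apply(len)
```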
Several model files in `torch_ecg/models/` are currently stubs (raising `NotImplementedError` throughout). All stubs already inherit the correct mixins (`SizeMixin`, `CitationMixin`) and have the backbone API signatures (`forward_features`, `compute_features_output_shape`) scaffolded. These should be completed before the SSL phase.
| File | Classes | Notes |
|---|---|---|
| `models/cnn/darknet.py` | `DarkNet` | Backbone for YOLO-style detection |
| `models/cnn/efficientnet.py` | `EfficientNet`, `EfficientNetV2` | Mobile-efficient compound scaling |
| `models/cnn/ho_resnet.py` | `MidPointResNet`, `RK4ResNet`, `RK8ResNet` | Higher-order ODE ResNets |
| `models/grad_cam.py` | `GradCam` | Saliency analysis (logic partially drafted, not wired up) |
| `models/ecg_fcn.py` | `ECG_FCN` | Fully-convolutional segmentation (fully commented out) |
The `torch_ecg/ssl/` module is an empty shell. `ssl/README.md` contains a survey of target architectures (CLOCS, ST-MEM, MAE-ECG, SimCLR, TF-C, 3M-ECG, CMSC, ECG-BERT).
Prerequisites: Item 3 (Lazy Modules) and the model stubs above should be completed first.
- Define base classes: `BaseContrastiveLearner`, `BaseMaskedAutoencoder` (in `ssl/base.py`).
- Implement a base contrastive learning framework (supporting paradigms like SimCLR, MoCo).
- Implement Masked Autoencoder (MAE) logic optimized for 1D physiological signals (see the sketch after this list).
- Provide reference implementations of classic models: CLOCS, ST-MEM, MAE-ECG.
- Provide a unified fine-tuning API for seamless integration with downstream tasks in `torch_ecg.models`.
- Add an `SSL` registry (analogous to `MODELS`, `BACKBONES`).
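As a rough sketch of what the planned `BaseMaskedAutoencoder` could look like: the class below follows the naming in this plan but is otherwise a hypothetical design, not implemented code.

```python
# Hypothetical MAE base class for 1D signals; not part of torch_ecg yet.
import torch
from torch import nn
from torch.nn import functional as F
from typing import Tuple

class BaseMaskedAutoencoder(nn.Module):
    """Masks random patches of a 1D signal and reconstructs them."""

    def __init__(self, encoder: nn.Module, decoder: nn.Module,
                 patch_len: int = 50, mask_ratio: float = 0.6) -> None:
        super().__init__()
        self.encoder, self.decoder = encoder, decoder
        self.patch_len, self.mask_ratio = patch_len, mask_ratio

    def random_mask(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # x: (batch, n_leads, siglen); zero out a random subset of patches
        b, _, siglen = x.shape
        n_patches = siglen // self.patch_len
        keep = (torch.rand(b, n_patches, device=x.device) > self.mask_ratio).float()
        keep = keep.repeat_interleave(self.patch_len, dim=1)
        keep = F.pad(keep, (0, siglen - keep.shape[1]), value=1.0)  # keep the tail
        return x * keep.unsqueeze(1), 1.0 - keep  # masked signal, mask

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        masked, mask = self.random_mask(x)
        recon = self.decoder(self.encoder(masked))  # must preserve (b, n_leads, siglen)
        mask = mask.unsqueeze(1)
        # reconstruction loss computed only on the masked positions
        return ((recon - x) ** 2 * mask).sum() / mask.sum().clamp(min=1.0)

# length-preserving toy encoder/decoder for a smoke test
mae = BaseMaskedAutoencoder(
    encoder=nn.Conv1d(12, 16, kernel_size=7, padding=3),
    decoder=nn.Conv1d(16, 12, kernel_size=7, padding=3),
)
loss = mae(torch.randn(2, 12, 5000))
loss.backward()
```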
- Improve test coverage, especially for the newly introduced SSL and Transformer modules.
- Optimize automated documentation generation.
- Release more pre-trained weights for SOTA models from PhysioNet/CinC challenges.