This project continues the work on SwiFT and is the official code repository for 'Predicting task-related brain activity from resting-state brain dynamics with fMRI Transformer.' Feel free to contact the authors with any questions about this project.
Contact
- First author
  - Junbeom Kwon: kjb961013@snu.ac.kr
- Corresponding author
  - Professor Jiook Cha: connectome@snu.ac.kr
Using this repository effectively requires familiarity with PyTorch and PyTorch Lightning. Knowledge of an experiment logging framework such as Weights & Biases or Neptune is also recommended.
This repository implements SwiFUN.
Our code offers the following:
- Trainer based on PyTorch Lightning for running SwiFT and SwiFUN (same as Swin UNETR).
- Data preprocessing/loading pipelines for 4D fMRI datasets.
We highly recommend using our conda environment.
# clone project
git clone https://github.com/Transconnectome/SwiFUN.git
# install project
cd SwiFUN
conda env create -f envs/py39.yaml
conda activate py39
Our directory structure looks like this:
├── notebooks                    <- Useful Jupyter notebook examples are given (TBU)
├── output                       <- Experiment log and checkpoints will be saved here
├── project
│   ├── module                   <- Every module is given in this directory
│   │   ├── models               <- Models (Swin fMRI Transformer)
│   │   ├── utils
│   │   │   ├── data_module.py   <- Dataloader & codes for matching fMRI scans and target variables
│   │   │   └── data_preprocessing_and_load
│   │   │       ├── datasets.py       <- Dataset Class for each dataset
│   │   │       └── preprocessing.py  <- Preprocessing codes for step 6
│   │   └── pl_classifier.py     <- LightningModule
│   └── main.py                  <- Main code that trains and tests the 4DSwinTransformer model
│
├── test
│   ├── module_test_swin.py      <- Code for debugging SwinTransformer
│   └── module_test_swin4d.py    <- Code for debugging 4DSwinTransformer
│
├── sample_scripts               <- Example shell scripts for training
│
├── .gitignore                   <- List of files/folders ignored by git
├── export_DDP_vars.sh           <- setup file for running torch DistributedDataParallel (DDP)
└── README.md
- Single forward & backward pass for debugging SwinTransformer4D model.
cd SwiFUN/
python test/module_test_swin4d.py
You can check the argument list by using -h:
python project/main.py --data_module dummy --classifier_module default -h
4.2 Hidden Arguments for PyTorch Lightning
pytorch_lightning offers useful arguments for training. For example, we used --max_epochs and --default_root_dir in our experiments. We recommend referring to the PyTorch Lightning Trainer documentation for the full argument list.
- Training SwiFUN interactively
# interactive
cd SwiFUN
bash sample_scripts/sample_train_swifun.sh
This bash script was tested on a Linux server cluster with 8 RTX 3090 GPUs. You should adjust the following lines for your environment.
[to be updated]
cd {path to your 'SwiFUN' directory}
source /usr/anaconda3/etc/profile.d/conda.sh # init conda; the path might differ if you have your own conda installation.
conda activate {conda env name}
MAIN_ARGS='--loggername neptune --classifier_module v6 --dataset_name {dataset_name} --image_path {path to the image data}' # This script assumes that you have a preprocessed HCP dataset. You can still run the code with "--dataset_name Dummy"
DEFAULT_ARGS='--project_name {neptune project name}'
export NEPTUNE_API_TOKEN="{Neptune API token allocated to each user}"
export CUDA_VISIBLE_DEVICES={usable GPU number}
- Training SwiFUN with Slurm (if you run the code on a Slurm-based cluster). Please refer to the tutorial for Slurm commands.
cd SwiFUN
sbatch sample_scripts/sample_train_swifun.slurm
We offer two options for loggers.
- Tensorboard (https://www.tensorflow.org/tensorboard)
  - Logs and model checkpoints are saved in --default_root_dir.
  - Logging test code with Tensorboard is not available.
- Neptune AI (https://neptune.ai/)
  - Generate a new workspace and project on the Neptune website.
    - Academic workspace offers 200GB of storage and collaboration for free.
  - export NEPTUNE_API_TOKEN="YOUR API TOKEN" in your script.
  - Specify the "--project_name" argument with your Neptune project name, e.g. "--project_name user-id/project".
This preprocessing code is based on the initial TFF repository by GonyRosenman.
To make your own dataset, you should first run one of the following minimal preprocessing pipelines:
- fMRIprep: Preprocessing with fMRIprep
- FSL UKB preprocessing pipeline
- We ensure that each brain is registered to the MNI space, and the whole brain mask is applied to remove non-brain regions.
- We are investigating how additional preprocessing steps to remove confounding factors such as head movement impact performance.
After the minimal preprocessing steps, you should perform additional preprocessing to use SwiFT. (You can find the preprocessing code at 'project/module/utils/data_preprocessing_and_load/preprocessing.py'.)
- Normalization: voxel normalization (not used) and whole-brain z-normalization (mainly used).
- Convert fMRI volumes to 16-bit floating point to save storage and reduce the I/O bottleneck.
- Save each fMRI volume separately as a torch checkpoint to facilitate window-based training.
- Remove non-brain (background) voxels so that no spatial dimension exceeds 96 voxels.
  - You should open your fMRI scans to determine a cropping level that does not cut out brain regions.
  - You can use nilearn to visualize your fMRI data (see the official nilearn documentation):
    # fmri_filename is the path to your 4D fMRI image
    from nilearn import plotting
    from nilearn.image import mean_img
    plotting.view_img(mean_img(fmri_filename), threshold=None)
  - If a dimension is under 96, you can pad non-brain voxels in 'datasets.py'.
- Refer to the comments in 'preprocessing.py' to adjust it for your own datasets.
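
The whole-brain z-normalization and float16 conversion listed above can be sketched roughly as follows. This is a minimal NumPy illustration under stated assumptions, not the code in 'preprocessing.py': the function name `whole_brain_znorm`, the assumption that background voxels equal `background_value` after masking, and the choice to fill the background with the minimum normalized brain value are all made up for this example.

```python
import numpy as np


def whole_brain_znorm(scan, background_value=0.0):
    """Z-normalize a masked scan using statistics of brain voxels only.

    Assumption: after whole-brain masking, every non-brain voxel equals
    `background_value`, so the brain mask can be recovered by comparison.
    """
    data = scan.astype(np.float32)
    brain = data != background_value

    # Mean and standard deviation over all brain voxels (and all time points).
    mean = data[brain].mean()
    std = data[brain].std()

    normalized = np.zeros_like(data)
    normalized[brain] = (data[brain] - mean) / std

    # Assumption: keep the background at the lowest normalized intensity.
    normalized[~brain] = normalized[brain].min()

    # Store as float16 to save disk space and reduce I/O.
    return normalized.astype(np.float16)
```

Each time point of the returned array could then be written to its own file, which is what makes window-based loading cheap.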
The resulting data structure is as follows:
└── {Dataset name}_MNI_to_TRs
    ├── img                      <- Every normalized volume is located in this directory
    │   ├── sub-01               <- subject name
    │   │   ├── frame_0.pt       <- Each torch pt file contains one volume in an fMRI sequence (total number of pt files = length of fMRI sequence)
    │   │   ├── frame_1.pt
    │   │   │       :
    │   │   ├── frame_{T}.pt     <- the last volume in an fMRI sequence (length T)
    │   │   └── global_stats.pt  <- min, max, mean value of fMRI for the subject
    │   └── sub-02
    │       ├── frame_0.pt
    │       ├── frame_1.pt
    │       └──     :
    └── metadata
        └── metafile.csv         <- file containing target variable
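
Because each volume is stored in its own file, a loader can assemble a training window from consecutive frame files. The sketch below assumes the `frame_{i}.pt` naming shown in the layout above. The helper names `list_frames` and `window_starts` and the parameters `sequence_length` and `stride` are hypothetical and are not necessarily what the repository uses.

```python
import re
from pathlib import Path

# Matches frame_0.pt, frame_1.pt, ... and captures the frame index.
FRAME_PATTERN = re.compile(r"frame_(\d+)\.pt$")


def list_frames(subject_dir):
    """Return one subject's frame files sorted by frame index.

    Sorting by the parsed integer avoids the lexicographic order
    frame_0, frame_1, frame_10, frame_2, ...
    """
    frames = []
    for path in Path(subject_dir).iterdir():
        match = FRAME_PATTERN.search(path.name)
        if match:  # skips global_stats.pt and any unrelated file
            frames.append((int(match.group(1)), path))
    return [path for _, path in sorted(frames)]


def window_starts(num_frames, sequence_length, stride):
    """Start indices of every full window of consecutive frames."""
    return list(range(0, num_frames - sequence_length + 1, stride))
```

A window starting at index `s` would then cover the files at positions `s` to `s + sequence_length - 1` of the sorted list.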
- The data loading pipeline processes image and metadata in 'project/module/utils/data_module.py' and passes the paired image-label tuples to the Dataset classes in 'project/module/utils/data_preprocessing_and_load/datasets.py'.
- You should implement code for combining image path, subject_name, and target variables in 'project/module/utils/data_module.py'.
- You should define a Dataset class for your dataset in 'project/module/utils/data_preprocessing_and_load/datasets.py'. In the Dataset class (__getitem__), you should specify how many background voxels to add or remove so that the volumes are shaped 96 * 96 * 96.
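
The background padding or cropping mentioned in the last item can be sketched as below. This is a minimal NumPy illustration, assuming symmetric padding and center cropping along the three spatial axes. The function name `pad_or_crop_to_cube` and the `fill_value` parameter are hypothetical, and the Dataset classes in 'datasets.py' may instead use per-dataset, asymmetric offsets.

```python
import numpy as np


def pad_or_crop_to_cube(volume, target=96, fill_value=0.0):
    """Pad or center-crop the first three axes of `volume` to `target` voxels.

    Trailing axes (for example, time) are left untouched.
    Assumption: `fill_value` is the background intensity of the volume.
    """
    out = volume
    for axis in range(3):
        size = out.shape[axis]
        if size > target:
            # Center crop: drop the same amount from both ends (up to 1 voxel).
            start = (size - target) // 2
            index = [slice(None)] * out.ndim
            index[axis] = slice(start, start + target)
            out = out[tuple(index)]
        elif size < target:
            # Symmetric padding with the background value.
            before = (target - size) // 2
            after = target - size - before
            pad_width = [(0, 0)] * out.ndim
            pad_width[axis] = (before, after)
            out = np.pad(out, pad_width, mode="constant",
                         constant_values=fill_value)
    return out
```

For example, a 91 x 109 x 91 volume would be padded along the first and third axes and cropped along the second. Check visually that the crop does not remove brain tissue before relying on a center crop.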
@article{kwon2024predicting,
title={Predicting task-related brain activity from resting-state brain dynamics with fMRI Transformer},
author={Kwon, Junbeom and Seo, Jungwoo and Wang, Heehwan and Moon, Taesup and Yoo, Shinjae and Cha, Jiook},
journal={bioRxiv},
pages={2024--05},
year={2024},
publisher={Cold Spring Harbor Laboratory}
}