mtkresearch/highdiff

K-DCT DDPM

This repo contains the PyTorch implementation of the paper "Improved denoising diffusion probabilistic models with efficient non-diagonal covariance modeling".

The sampling process of Denoising Diffusion Probabilistic Models (DDPMs) can be accelerated by leveraging second-order information in the form of approximations to the denoising posterior covariance. Previous attempts at using such information have used drastic (e.g. diagonal) simplifications of the covariance. These do not do justice to the peculiar statistical structure of natural images, which exhibit strong non-diagonal correlations between pixels and color channels, and a slow-decaying power-law frequency spectrum. Here, we develop a novel covariance model that captures these features. Our Kronecker-DCT (K-DCT) model uses a Kronecker-factored decomposition of inter-color covariances and spatial covariances modeled in the frequency domain using the Discrete Cosine Transform (DCT). The use of the DCT reduces the computational complexity from quadratic to log-linear, resulting in negligible computational and memory overhead per denoising step. By learning K-DCT-structured amortizations of the denoising posterior covariance using pre-trained score models on CIFAR-10, CelebA, ImageNet and LSUN datasets, we show improved performance over previous state-of-the-art denoising samplers, in terms of both FID and likelihood, especially in the regime of few denoising steps.
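To make the K-DCT structure concrete, here is a minimal NumPy/SciPy sketch (not code from this repository) of a zero-mean Gaussian whose covariance is kron(C_color, S). It assumes C_color is a small inter-color covariance given by its Cholesky factor L, and that S = D^T diag(s) D is a spatial covariance that is diagonal in the orthonormal 2-D DCT basis D. The function names, the (C, H, W) layout, and this exact parameterization are illustrative assumptions; the repository's `blockcirc` / `blockcirc_complex` modes may parameterize the covariance differently. Multiplying by the covariance square root costs one inverse DCT per channel plus a C x C channel mix, i.e. O(HW log HW), instead of a dense O((CHW)^2) product.

```python
import numpy as np
from scipy.fft import dctn, idctn


def kdct_apply_sqrt(z, chol_c, s):
    """Map white noise z of shape (C, H, W) to a sample from N(0, kron(C_color, S)).

    chol_c: (C, C) lower-triangular Cholesky factor L, with C_color = L @ L.T.
    s:      (H, W) positive variances in the DCT domain, so S = D.T @ diag(s) @ D,
            where D is the orthonormal 2-D DCT-II applied to each channel.
    """
    # Spatial factor D.T @ diag(sqrt(s)): scale each DCT coefficient, then inverse DCT.
    spatial = idctn(np.sqrt(s) * z, axes=(1, 2), norm="ortho")
    # Color factor L: mix channels; this commutes with the spatial step (different axes).
    return np.einsum("ij,jhw->ihw", chol_c, spatial)


def kdct_logpdf(x, chol_c, s):
    """Log-density of x (C, H, W) under N(0, kron(C_color, S)) without a dense matrix."""
    c, h, w = x.shape
    # Invert the color factor: solve L y = x for all pixels at once.
    y = np.linalg.solve(chol_c, x.reshape(c, -1)).reshape(c, h, w)
    # Invert the spatial factor: forward orthonormal DCT, then divide by sqrt(s).
    white = dctn(y, axes=(1, 2), norm="ortho") / np.sqrt(s)
    # det(kron(A, B)) = det(A)**m * det(B)**n for A (n x n), B (m x m); D is orthogonal.
    logdet = h * w * 2.0 * np.sum(np.log(np.diag(chol_c))) + c * np.sum(np.log(s))
    return -0.5 * (np.sum(white**2) + logdet + c * h * w * np.log(2.0 * np.pi))
```

For scale: a dense covariance for a 64x64 RGB image has (3 * 4096)^2, roughly 1.5e8, entries, whereas this factorization stores a 3x3 color matrix plus 4096 DCT-domain variances.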

Installation

Our implementation builds on the Extended Analytic-DPM and OCM-DPM repositories. To set up the environment, please follow the installation instructions provided in those repositories. The main functionality of our code closely mirrors the original code, and we provide detailed usage instructions below.

Training

To train the model, you can use the following command:

python run_train.py --pretrained_path path/to/pretrained_dpm --dataset dataset --workspace path/to/working_directory $train_hparams
  • pretrained_path is the path to a pretrained diffusion probabilistic model (DPM); pretrained models can be found in the Extended Analytic-DPM repository.
  • dataset is the training dataset, one of <cifar10|celeba64|imagenet64>.
  • workspace is the directory for training outputs, e.g., logs and checkpoints.
  • train_hparams specifies the other hyperparameters used in training.

We provide the train_hparams used in training for our models on each dataset:

  • CIFAR10 (LS): --method pred_eps_<obj>_blockcirc_pretrained --mode blockcirc
  • CIFAR10 (CS): --method pred_eps_<obj>_blockcirc_pretrained --mode blockcirc --schedule cosine_1000
  • CelebA64: --method pred_eps_<obj>_blockcirc_pretrained --mode blockcirc_complex
  • ImageNet64: --method pred_eps_<obj>_blockcirc_pretrained --mode blockcirc_complex

where <obj> is either hes, to train with the OCM objective, or epsc, to train with the NPR objective. For example, to train the CIFAR10 (LS) model with the NPR objective, run:

python run_train.py \
  --pretrained_path path/to/pretrained_dpm \
  --dataset cifar10 \
  --workspace path/to/working_directory \
  --method pred_eps_epsc_blockcirc_pretrained \
  --mode blockcirc
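Because the settings differ only in a few flags, a small launcher can help avoid typos such as a misspelled --mode value. The sketch below is a hypothetical helper (train_command, TRAIN_HPARAMS and OBJ_TOKENS are not part of the repository); it only assembles the run_train.py command documented above from the train_hparams list and the <obj> mapping, and leaves launching it to you.

```python
import shlex

# Setting -> (value of --dataset, extra train_hparams), transcribed from the list above.
TRAIN_HPARAMS = {
    "cifar10_ls": ("cifar10", ["--mode", "blockcirc"]),
    "cifar10_cs": ("cifar10", ["--mode", "blockcirc", "--schedule", "cosine_1000"]),
    "celeba64": ("celeba64", ["--mode", "blockcirc_complex"]),
    "imagenet64": ("imagenet64", ["--mode", "blockcirc_complex"]),
}

# Training objective -> the <obj> token used inside --method.
OBJ_TOKENS = {"ocm": "hes", "npr": "epsc"}


def train_command(setting, objective, pretrained_path, workspace):
    """Build the argv list for run_train.py for one of the documented settings."""
    dataset, extra = TRAIN_HPARAMS[setting]
    method = f"pred_eps_{OBJ_TOKENS[objective]}_blockcirc_pretrained"
    return [
        "python", "run_train.py",
        "--pretrained_path", pretrained_path,
        "--dataset", dataset,
        "--workspace", workspace,
        "--method", method,
        *extra,
    ]


if __name__ == "__main__":
    # Prints the CIFAR10 (LS) / NPR command as a single shell line.
    cmd = train_command("cifar10_ls", "npr", "path/to/pretrained_dpm", "path/to/working_directory")
    print(shlex.join(cmd))
```

Running the script prints the same command as the CIFAR10 (LS) example; to launch training directly, pass the list to subprocess.run(cmd, check=True).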

Evaluation

To evaluate the model, you can use the following command:

python run_eval.py --pretrained_path path/to/evaluated_model --dataset dataset --workspace path/to/working_directory --phase phase --sample_steps sample_steps --batch_size batch_size --method pred_eps_hes_pretrained $eval_hparams
  • pretrained_path is the path to the trained model to be evaluated.
  • dataset is the dataset the model was trained on, one of <cifar10|celeba64|imagenet64>.
  • workspace is the place to put evaluation outputs, e.g., logs, samples and bpd values.
  • phase specifies running FID or likelihood evaluation, one of <sample4test|nll4test>.
  • sample_steps is the number of denoising steps used during sampling; the smaller this value, the faster the inference.
  • batch_size is the batch size, e.g., 500.
  • eval_hparams specifies other optional hyperparameters used in evaluation.
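The likelihood phase writes bpd (bits per dimension) values to the workspace. As a reminder of what that unit means, rather than a description of how run_eval.py computes it, a per-image negative log-likelihood in nats converts to bpd by dividing by D ln 2, where D = C * H * W is the number of pixel values. The helper name nats_to_bpd is hypothetical, and the sketch assumes the NLL already accounts for data discretization.

```python
import math


def nats_to_bpd(nll_nats, image_shape=(3, 32, 32)):
    """Convert a per-image negative log-likelihood in nats to bits per dimension."""
    num_dims = math.prod(image_shape)  # e.g. 3 * 32 * 32 = 3072 for CIFAR-10
    return nll_nats / (num_dims * math.log(2.0))
```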

We provide the eval_hparams used for the FID and NLL results reported in the paper:

  • FID Evaluation (DDPM)
    • CIFAR10 (LS): --mode blockcirc --rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 2
    • CIFAR10 (CS): --mode blockcirc --rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --schedule cosine_1000
    • CelebA64: --rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 2 --mode blockcirc_complex
    • ImageNet64: --rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --mode blockcirc_complex
  • NLL Evaluation
    • CIFAR10 (LS): --mode blockcirc --rev_var_type optimal
    • CIFAR10 (CS): --mode blockcirc --rev_var_type optimal --schedule cosine_1000
    • CelebA64: --rev_var_type optimal --mode blockcirc_complex
    • ImageNet64: --rev_var_type optimal --mode blockcirc_complex
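The same approach works for evaluation. The sketch below is a hypothetical helper (eval_command, EVAL_HPARAMS, DATASETS and PHASES are not part of the repository) that fills the run_eval.py template with the eval_hparams listed above, assuming --mode blockcirc for every CIFAR10 setting and keeping --method pred_eps_hes_pretrained exactly as in the template command.

```python
import shlex

# eval_hparams transcribed from the FID and NLL lists above, keyed by (setting, metric).
EVAL_HPARAMS = {
    ("cifar10_ls", "fid"): "--mode blockcirc --rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 2",
    ("cifar10_cs", "fid"): "--mode blockcirc --rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --schedule cosine_1000",
    ("celeba64", "fid"): "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 2 --mode blockcirc_complex",
    ("imagenet64", "fid"): "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --mode blockcirc_complex",
    ("cifar10_ls", "nll"): "--mode blockcirc --rev_var_type optimal",
    ("cifar10_cs", "nll"): "--mode blockcirc --rev_var_type optimal --schedule cosine_1000",
    ("celeba64", "nll"): "--rev_var_type optimal --mode blockcirc_complex",
    ("imagenet64", "nll"): "--rev_var_type optimal --mode blockcirc_complex",
}
# Setting -> value of --dataset; metric -> value of --phase.
DATASETS = {"cifar10_ls": "cifar10", "cifar10_cs": "cifar10", "celeba64": "celeba64", "imagenet64": "imagenet64"}
PHASES = {"fid": "sample4test", "nll": "nll4test"}


def eval_command(setting, metric, model_path, workspace, sample_steps, batch_size=500):
    """Build the argv list for run_eval.py, following the template command above."""
    return [
        "python", "run_eval.py",
        "--pretrained_path", model_path,
        "--dataset", DATASETS[setting],
        "--workspace", workspace,
        "--phase", PHASES[metric],
        "--sample_steps", str(sample_steps),
        "--batch_size", str(batch_size),
        "--method", "pred_eps_hes_pretrained",
        *shlex.split(EVAL_HPARAMS[(setting, metric)]),
    ]
```

For example, eval_command("celeba64", "fid", "path/to/evaluated_model", "path/to/working_directory", sample_steps) returns the argv for an FID run at the chosen number of steps, which can be passed to subprocess.run(cmd, check=True).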

About

Higher order Diffusion Models
