Commit 9b0f915

Changes for conditional ldm training

1 parent 9069801 · commit 9b0f915

21 files changed: 1837 additions & 63 deletions

README.md

Lines changed: 237 additions & 27 deletions
@@ -2,8 +2,12 @@ Stable Diffusion Implementation in PyTorch
========

This repository implements Stable Diffusion.
-As of now this only implements unconditional latent diffusion models and trains on mnist and celebhq dataset.
-Pretty soon it will also have code for conditional ldm.
+As of today the repo provides code to do the following:
+* Training and Inference on Unconditional Latent Diffusion Models
+* Training a Class Conditional Latent Diffusion Model
+* Training a Text Conditioned Latent Diffusion Model
+* Training a Semantic Mask Conditioned Latent Diffusion Model
+* Any combination of the above three conditioning types

For the autoencoder, I provide code for a VAE as well as a VQVAE.
But both stages of training use the VQVAE only. One can easily change that to the VAE if needed.
@@ -12,49 +16,127 @@ For the diffusion part, as of now it only implements DDPM with a linear schedule.


## Stable Diffusion Tutorial Video
+### Unconditional
<a href="https://www.youtube.com/watch?v=1BkzNb3ejK4">
   <img alt="Stable Diffusion Tutorial" src="https://github.com/explainingai-code/StableDiffusion-PyTorch/assets/144267687/7a24d114-38bd-43a8-9819-3afa112f39ab"
   width="400">
</a>

+### Conditional
+
+___
+
## Sample Output for Autoencoder on CelebHQ
Image - Top, Reconstructions - Below

<img src="https://github.com/explainingai-code/StableDiffusion-PyTorch/assets/144267687/2260d618-046e-411c-bea5-0c4cb7438560" width="300">

-## Sample Output for LDM on CelebHQ (not fully converged)
+## Sample Output for Unconditional LDM on CelebHQ (not fully converged)

<img src="https://github.com/explainingai-code/StableDiffusion-PyTorch/assets/144267687/212cd84a-9bd1-43f0-93b4-3b8ff9866571" width="300">

+## Sample Output for Conditional LDM
+### Sample Output for Class Conditioned on MNIST
+### Sample Output for Text Conditioned on CelebHQ (not converged)
+### Sample Output for Mask Conditioned on CelebHQ (not converged)
+### Sample Output for Text and Mask Conditioned on CelebHQ (not converged)

-## Data preparation
-For setting up the mnist dataset:
+___

-Follow - https://github.com/explainingai-code/Pytorch-VAE#data-preparation
+## Setup
+* Create a new conda environment with python 3.8, then run the commands below
+* ```git clone https://github.com/explainingai-code/StableDiffusion-PyTorch.git```
+* ```cd StableDiffusion-PyTorch```
+* ```pip install -r requirements.txt```
+* Download the LPIPS weights from https://github.com/richzhang/PerceptualSimilarity/blob/master/lpips/weights/v0.1/vgg.pth and put them at ```models/weights/v0.1/vgg.pth```

-For setting up on CelebHQ, simply download the images from the official site.
-And mention the right path in the configuration ```config/celebhq.yaml```.
+___

+## Data Preparation
+### Mnist

-For training on your own dataset
-* Create your own config and have the path point to images (look at celebhq.yaml for guidance)
-* Create your own dataset class, similar to celeb_dataset.py
-* Map the dataset name to the right class in the training code [here](https://github.com/explainingai-code/StableDiffusion-PyTorch/blob/main/tools/train_ddpm_vqvae.py#L40) & similarly in inference and ddpm training/inference code.
+For setting up the mnist dataset, follow https://github.com/explainingai-code/Pytorch-VAE#data-preparation

+Ensure the directory structure is the following:
+```
+StableDiffusion-PyTorch
+    -> data
+        -> mnist
+            -> train
+                -> images
+                    -> *.png
+            -> test
+                -> images
+                    -> *.png
+```

-# Quickstart
-* Create a new conda environment with python 3.8 then run below commands
-* ```git clone https://github.com/explainingai-code/StableDiffusion-PyTorch.git```
-* ```cd StableDiffusion-PyTorch```
-* ```pip install -r requirements.txt```
-* Download lpips from https://github.com/richzhang/PerceptualSimilarity/blob/master/lpips/weights/v0.1/vgg.pth and put it in ```models/weights/v0.1/vgg.pth```
-* For training autoencoder
-  * ```python -m tools.train_vqvae --config config/mnist.yaml``` for training vqvae
-  * ```python -m tools.infer_vqvae --config config/mnist.yaml``` for generating reconstructions
-* For training ldm
-  * ```python -m tools.train_ddpm_vqvae --config config/mnist.yaml``` for training ddpm
-  * ```python -m tools.sample_ddpm_vqvae --config config/mnist.yaml``` for generating images
+### CelebHQ
+#### Unconditional
+For setting up CelebHQ for unconditional training, simply download the images from the official CelebAMask-HQ repo [here](https://github.com/switchablenorms/CelebAMask-HQ?tab=readme-ov-file).

+Ensure the directory structure is the following:
+```
+StableDiffusion-PyTorch
+    -> data
+        -> CelebAMask-HQ
+            -> CelebA-HQ-img
+                -> *.jpg
+```
+#### Mask Conditional
+For CelebHQ mask conditional LDM, additionally do the following:
+
+Ensure the directory structure is the following:
+```
+StableDiffusion-PyTorch
+    -> data
+        -> CelebAMask-HQ
+            -> CelebA-HQ-img
+                -> *.jpg
+            -> CelebAMask-HQ-mask-anno
+                -> 0/1/2/3.../14
+                    -> *.png
+```
+
+* Run `python -m utils.create_celeb_mask` from the repo root to create the mask images from the mask annotations (see the sketch below for the idea)
+
+Ensure the directory structure is the following:
+```
+StableDiffusion-PyTorch
+    -> data
+        -> CelebAMask-HQ
+            -> CelebA-HQ-img
+                -> *.jpg
+            -> CelebAMask-HQ-mask-anno
+                -> 0/1/2/3.../14
+                    -> *.png
+            -> CelebAMask-HQ-mask
+                -> *.png
+```
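Conceptually, the mask-creation step merges the per-part binary annotation pngs of CelebAMask-HQ into one label mask per image. A minimal sketch of that idea (the part names, file naming, and 2000-images-per-folder sharding follow the CelebAMask-HQ convention; `utils/create_celeb_mask.py` is the authoritative version):
```python
import os
import numpy as np
from PIL import Image

# the 18 facial parts of CelebAMask-HQ; label 0 is kept for background
PARTS = ['skin', 'l_brow', 'r_brow', 'l_eye', 'r_eye', 'eye_g', 'l_ear',
         'r_ear', 'ear_r', 'nose', 'mouth', 'u_lip', 'l_lip', 'neck',
         'neck_l', 'cloth', 'hair', 'hat']

def merge_parts(anno_dir, im_idx):
    mask = np.zeros((512, 512), dtype=np.uint8)
    folder = im_idx // 2000  # annotations are sharded into folders 0..14
    for label, part in enumerate(PARTS):
        path = os.path.join(anno_dir, str(folder),
                            '{:05d}_{}.png'.format(im_idx, part))
        if os.path.exists(path):
            part_mask = np.array(Image.open(path).convert('L'))
            mask[part_mask > 0] = label + 1  # background stays 0
    return Image.fromarray(mask)
```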
+
+#### Text Conditional
+For CelebHQ text conditional LDM, additionally do the following:
+* The repo uses captions collected as part of this repo - https://github.com/IIGROUP/MM-CelebA-HQ-Dataset?tab=readme-ov-file
+* Download the captions from the `text` link provided in the repo - https://github.com/IIGROUP/MM-CelebA-HQ-Dataset?tab=readme-ov-file#overview
+* This will download a `celeba-caption` folder; simply move it inside the `data/CelebAMask-HQ` folder, as that is where the dataset class expects it to be.
+
+Ensure the directory structure is the following:
+```
+StableDiffusion-PyTorch
+    -> data
+        -> CelebAMask-HQ
+            -> CelebA-HQ-img
+                -> *.jpg
+            -> CelebAMask-HQ-mask-anno
+                -> 0/1/2/3.../14
+                    -> *.png
+            -> CelebAMask-HQ-mask
+                -> *.png
+            -> celeba-caption
+                -> *.txt
+```
+---
## Configuration
Allows you to play with different components of ddpm and autoencoder training
* ```config/mnist.yaml``` - Small autoencoder and ldm can even be trained on CPU
@@ -66,6 +148,131 @@ Most parameters are self explanatory but below I mention couple which are specif
* ```autoencoder_acc_steps``` : For accumulating gradients if the image size is too large to allow larger batch sizes (see the sketch below)
* ```save_latents``` : Enable this to save the latents during inference of the autoencoder. That way ddpm training will be faster

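For intuition, gradient accumulation with `autoencoder_acc_steps` amounts to something like the following self-contained fragment (illustrative only, with stand-in model and data; `tools/train_vqvae.py` is the actual loop):
```python
import torch

model = torch.nn.Conv2d(3, 3, 3, padding=1)   # stand-in for the autoencoder
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
data_loader = [torch.randn(2, 3, 64, 64) for _ in range(8)]  # dummy batches
acc_steps = 4  # plays the role of autoencoder_acc_steps

for step, im in enumerate(data_loader):
    loss = criterion(model(im), im) / acc_steps  # scale so grads average out
    loss.backward()                              # gradients accumulate
    if (step + 1) % acc_steps == 0:              # effective batch = 2 * 4
        optimizer.step()
        optimizer.zero_grad()
```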
+___
+## Training
+The repo provides training and inference for Mnist (Unconditional and Class Conditional) and CelebHQ (Unconditional, Text and/or Mask Conditional).
+
+For working on your own dataset:
+* Create your own config and have the path in the config point to your images (look at `celebhq.yaml` for guidance)
+* Create your own dataset class which will just collect all the filenames and return the image in its getitem method. Look at `mnist_dataset.py` or `celeb_dataset.py` for guidance; a minimal sketch follows below
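A minimal sketch of such a dataset class (illustrative only; the file layout, transforms, and [-1, 1] scaling are assumptions modelled on the repo's dataset classes):
```python
import glob
import os
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as T

class MyDataset(Dataset):
    def __init__(self, im_path, im_size=256):
        # collect all image filenames under im_path
        self.images = sorted(glob.glob(os.path.join(im_path, '*.png')))
        self.transform = T.Compose([
            T.Resize(im_size),
            T.CenterCrop(im_size),
            T.ToTensor(),
        ])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        im = Image.open(self.images[index]).convert('RGB')
        # diffusion code commonly expects inputs scaled to [-1, 1]
        return 2 * self.transform(im) - 1
```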
+
+Once the config and dataset are set up:
+* Train the auto encoder on your dataset using [this section](#training-autoencoder-for-ldm)
+* For training Unconditional LDM follow [this section](#training-unconditional-ldm)
+* For class conditional ldm go through [this section](#training-class-conditional-ldm)
+* For text conditional ldm go through [this section](#training-text-conditional-ldm)
+* For text and mask conditional ldm go through [this section](#training-text-and-mask-conditional-ldm)
+
+
+## Training AutoEncoder for LDM
+* For training the autoencoder on mnist, ensure the right path is mentioned in `mnist.yaml`
+* For training the autoencoder on celebhq, ensure the right path is mentioned in `celebhq.yaml`
+* For training the autoencoder on your own dataset
+  * Create your own config and have the path point to your images (look at celebhq.yaml for guidance)
+  * Create your own dataset class, similar to celeb_dataset.py without the conditioning parts
+  * Map the dataset name to the right class in the training code [here](https://github.com/explainingai-code/StableDiffusion-PyTorch/blob/main/tools/train_ddpm_vqvae.py#L40)
+* For training the autoencoder run ```python -m tools.train_vqvae --config config/mnist.yaml``` for training the vqvae with the desired config file
+* For inference using the trained autoencoder run ```python -m tools.infer_vqvae --config config/mnist.yaml``` for generating reconstructions with the right config file. Use save_latents in the config to save the latent files (see the sketch below)
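The idea behind `save_latents` is to encode the whole dataset once so that ddpm training reads cached latents instead of re-encoding images every epoch. A hedged sketch of that idea (the `encode` interface and file naming here are assumptions; `tools/infer_vqvae.py` is the real implementation):
```python
import os
import torch

def cache_latents(vqvae, data_loader, latent_dir):
    # vqvae.encode is an assumed interface returning the latent tensor
    os.makedirs(latent_dir, exist_ok=True)
    vqvae.eval()
    with torch.no_grad():
        for idx, im in enumerate(data_loader):
            latent = vqvae.encode(im)
            torch.save(latent.cpu(), os.path.join(latent_dir, '{}.pt'.format(idx)))
```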
+
+
+## Training Unconditional LDM
+Train the autoencoder first and set up the dataset accordingly.
+
+For training the unconditional LDM map the dataset to the right class in `train_ddpm_vqvae.py`
+* ```python -m tools.train_ddpm_vqvae --config config/mnist.yaml``` for training the unconditional ddpm using the right config
+* ```python -m tools.sample_ddpm_vqvae --config config/mnist.yaml``` for generating images using the trained ddpm
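For intuition, one DDPM training step on the (vqvae) latents boils down to noising a latent at a random timestep and regressing the noise. A sketch of the standard objective (illustrative; the `unet(x, t)` signature is an assumption and `tools/train_ddpm_vqvae.py` is the repo's actual loop):
```python
import torch

def ddpm_train_step(unet, latents, betas, optimizer):
    # sample a random timestep per example
    t = torch.randint(0, betas.shape[0], (latents.shape[0],))
    alpha_bar = torch.cumprod(1.0 - betas, dim=0)[t].view(-1, 1, 1, 1)
    noise = torch.randn_like(latents)
    # forward process: x_t = sqrt(a_bar) * x_0 + sqrt(1 - a_bar) * eps
    noisy = alpha_bar.sqrt() * latents + (1 - alpha_bar).sqrt() * noise
    loss = torch.nn.functional.mse_loss(unet(noisy, t), noise)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```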
+
+## Training Conditional LDM
+For training conditional models we need two changes:
+* Dataset classes must provide the additional conditional inputs (see below)
+* Config must be changed, with the additional conditioning config added
+
+Specifically the dataset `getitem` will return the following:
+* `image_tensor` for unconditional training
+* a tuple of `(image_tensor, cond_input)` for conditional training, where cond_input is a dictionary consisting of keys ```{class/text/image}```
+
+### Training Class Conditional LDM
+The repo provides class conditional latent diffusion model training code for the mnist dataset, so one can use that to follow the same for their own dataset
+
+* Use the `mnist_class_cond.yaml` config file as a guide to create your class conditional config file.
+Specifically the following new keys need to be modified according to your dataset within `ldm_params`.
+  ```
+  condition_config:
+    condition_types: ['class']
+    class_condition_config:
+      num_classes: <number of classes: 10 for mnist>
+      cond_drop_prob: <probability of dropping class labels>
+  ```
+* Create a dataset class similar to mnist where the getitem method now returns a tuple of image_tensor and a dictionary of conditional inputs.
+* For class, the conditional input will ONLY be the integer class (see the sketch after this list for how cond_drop_prob typically comes in)
+  ```
+  (image_tensor, {
+      'class' : {0/1/.../num_classes}
+  })
+  ```
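During training, `cond_drop_prob` is conventionally used for classifier-free guidance: with that probability the class conditioning is dropped so the model also learns the unconditional distribution. An illustrative helper (the repo's exact mechanism may differ):
```python
import torch

def maybe_drop_class(class_one_hot, cond_drop_prob):
    # class_one_hot: (B, num_classes); dropped rows become all-zero,
    # which acts as the "null" conditioning during training
    keep = (torch.rand(class_one_hot.shape[0], 1) > cond_drop_prob).float()
    return class_one_hot * keep
```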
+For training the class conditional LDM map the dataset to the right class in `train_ddpm_cond` and run the below commands using the desired config
+* ```python -m tools.train_ddpm_cond --config config/mnist_class_cond.yaml``` for training class conditional on mnist
+* ```python -m tools.sample_ddpm_class_cond --config config/mnist_class_cond.yaml``` for generating images using the class conditional trained ddpm
+
+### Training Text Conditional LDM
+The repo provides text conditional latent diffusion model training code for the celebhq dataset, so one can use that to follow the same for their own dataset
+
+* Use the `celebhq_text_cond.yaml` config file as a guide to create your config file.
+Specifically the following new keys need to be modified according to your dataset within `ldm_params`.
+  ```
+  condition_config:
+    condition_types: [ 'text' ]
+    text_condition_config:
+      text_embed_model: 'clip' or 'bert'
+      text_embed_dim: 512 or 768
+      cond_drop_prob: 0.1
+  ```
+* Create a dataset class similar to celebhq where the getitem method now returns a tuple of image_tensor and a dictionary of conditional inputs.
+* For text, the conditional input will ONLY be the caption (a sketch of turning captions into embeddings follows below)
+  ```
+  (image_tensor, {
+      'text' : 'a sample caption for image_tensor'
+  })
+  ```
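For reference, a self-contained sketch of producing the text conditioning with CLIP via Hugging Face transformers (the model id and the choice of per-token hidden states are assumptions; `text_embed_dim: 512` matches CLIP ViT-B/16, while BERT-base would give 768):
```python
from transformers import CLIPTextModel, CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-base-patch16')
text_model = CLIPTextModel.from_pretrained('openai/clip-vit-base-patch16').eval()

def get_text_embedding(captions):
    tokens = tokenizer(captions, padding='max_length', truncation=True,
                       max_length=77, return_tensors='pt')
    # per-token hidden states, shape (B, 77, 512), consumed via cross-attention
    return text_model(**tokens).last_hidden_state
```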
+For training the text conditional LDM map the dataset to the right class in `train_ddpm_cond` and run the below commands using the desired config
+* ```python -m tools.train_ddpm_cond --config config/celebhq_text_cond.yaml``` for training text conditioned ldm on celebhq
+* ```python -m tools.sample_ddpm_text_cond --config config/celebhq_text_cond.yaml``` for generating images using the text conditional trained ddpm
+
+### Training Text and Mask Conditional LDM
+The repo provides text and mask conditional latent diffusion model training code for the celebhq dataset, so one can use that to follow the same for their own dataset, and can even use it to train a mask-only conditional ldm
+
+* Use the `celebhq_text_image_cond.yaml` config file as a guide to create your config file.
+Specifically the following new keys need to be modified according to your dataset within `ldm_params`.
+  ```
+  condition_config:
+    condition_types: [ 'text', 'image' ]
+    text_condition_config:
+      text_embed_model: 'clip' or 'bert'
+      text_embed_dim: 512 or 768
+      cond_drop_prob: 0.1
+    image_condition_config:
+      image_condition_input_channels: 18
+      image_condition_output_channels: 3
+      image_condition_h: 512
+      image_condition_w: 512
+      cond_drop_prob: 0.1
+  ```
+* Create a dataset class similar to celebhq where the getitem method now returns a tuple of image_tensor and a dictionary of conditional inputs.
+* For text and mask, the conditional inputs will be the caption and the mask image (see the sketch after this list for building the mask tensor)
+  ```
+  (image_tensor, {
+      'text' : 'a sample caption for image_tensor',
+      'image' : NUM_CLASSES x MASK_H x MASK_W
+  })
+  ```
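A hypothetical helper for the `'image'` entry: loading a saved label mask and expanding it to the `(NUM_CLASSES, MASK_H, MASK_W)` tensor implied by `image_condition_input_channels: 18` (the label convention of 0 = background, 1..18 = parts is an assumption; see `celeb_dataset.py` for the repo's version):
```python
import numpy as np
import torch
from PIL import Image

NUM_CLASSES, MASK_H, MASK_W = 18, 512, 512

def load_mask(path):
    mask = Image.open(path).resize((MASK_W, MASK_H), Image.NEAREST)
    labels = torch.from_numpy(np.array(mask)).long()  # (H, W), 0 = background
    # one channel per foreground part; background pixels stay all-zero
    one_hot = torch.stack([(labels == i + 1).float() for i in range(NUM_CLASSES)])
    return one_hot  # (NUM_CLASSES, MASK_H, MASK_W)
```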
+For training the text and mask conditional LDM map the dataset to the right class in `train_ddpm_cond` and run the below commands using the desired config
+* ```python -m tools.train_ddpm_cond --config config/celebhq_text_image_cond.yaml``` for training text and mask conditioned ldm on celebhq
+* ```python -m tools.sample_ddpm_text_image_cond --config config/celebhq_text_image_cond.yaml``` for generating images using the text and mask conditional trained ddpm
+
+
## Output
Outputs will be saved according to the configuration present in the yaml files.

@@ -79,9 +286,12 @@ During inference of autoencoder the following output will be saved
* Reconstructions for random images in ```task_name```
* Latents will be saved in ```task_name/vqvae_latent_dir_name``` if mentioned in the config

-During training of DDPM we will save the latest checkpoint in ```task_name``` directory
-During sampling, sampled image grid for all timesteps in ```task_name/samples/*.png```
-
+During training and inference of the ddpm the following output will be saved
+* During training of the unconditional or conditional DDPM we will save the latest checkpoint in the ```task_name``` directory
+* During sampling, the unconditionally sampled image grid for all timesteps in ```task_name/samples/*.png```
+* During sampling, the class conditionally sampled image grid for all timesteps in ```task_name/cond_class_samples/*.png```
+* During sampling, the text only conditionally sampled image grid for all timesteps in ```task_name/cond_text_samples/*.png```
+* During sampling, the text and mask conditionally sampled image grid for all timesteps in ```task_name/cond_text_image_samples/*.png```

config/celebhq.yaml

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
dataset_params:
-  im_path: 'data/celeba_hq_256'
+  im_path: 'data/CelebAMask-HQ'
  im_channels : 3
  im_size : 256
  name: 'celebhq'

config/celebhq_text_cond.yaml

Lines changed: 73 additions & 0 deletions
@@ -0,0 +1,73 @@
+dataset_params:
+  im_path: '/Users/tusharkumar/Downloads/CelebAMask-HQ'
+  im_channels : 3
+  im_size : 256
+  name: 'celebhq'
+
+diffusion_params:
+  num_timesteps : 1000
+  beta_start : 0.00085
+  beta_end : 0.012
+
+ldm_params:
+  down_channels: [ 256, 384, 512, 768 ]
+  mid_channels: [ 768, 512 ]
+  down_sample: [ True, True, True ]
+  attn_down : [True, True, True]
+  time_emb_dim: 512
+  norm_channels: 32
+  num_heads: 16
+  conv_out_channels : 128
+  num_down_layers : 2
+  num_mid_layers : 2
+  num_up_layers : 2
+  condition_config:
+    condition_types: [ 'text' ]
+    text_condition_config:
+      text_embed_model: 'clip'
+      train_text_embed_model: False
+      text_embed_dim: 512
+      cond_drop_prob: 0.1
+
+autoencoder_params:
+  z_channels: 3
+  codebook_size : 8192
+  down_channels : [64, 128, 256, 256]
+  mid_channels : [256, 256]
+  down_sample : [True, True, True]
+  attn_down : [False, False, False]
+  norm_channels: 32
+  num_heads: 4
+  num_down_layers : 2
+  num_mid_layers : 2
+  num_up_layers : 2
+
+
+train_params:
+  seed : 1111
+  task_name: 'celebhq'
+  ldm_batch_size: 16
+  autoencoder_batch_size: 4
+  disc_start: 15000
+  disc_weight: 0.5
+  codebook_weight: 1
+  commitment_beta: 0.2
+  perceptual_weight: 1
+  kl_weight: 0.000005
+  ldm_epochs: 100
+  autoencoder_epochs: 20
+  num_samples: 1
+  num_grid_rows: 1
+  ldm_lr: 0.000005
+  autoencoder_lr: 0.00001
+  autoencoder_acc_steps: 4
+  autoencoder_img_save_steps: 64
+  save_latents : False
+  cf_guidance_scale : 1.0
+  vae_latent_dir_name: 'vae_latents'
+  vqvae_latent_dir_name: 'vqvae_latents'
+  ldm_ckpt_name: 'ddpm_ckpt_text_cond_clip.pth'
+  vqvae_autoencoder_ckpt_name: 'vqvae_autoencoder_ckpt.pth'
+  vae_autoencoder_ckpt_name: 'vae_autoencoder_ckpt.pth'
+  vqvae_discriminator_ckpt_name: 'vqvae_discriminator_ckpt.pth'
+  vae_discriminator_ckpt_name: 'vae_discriminator_ckpt.pth'
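Regarding `cf_guidance_scale` above: at sampling time the conditional and unconditional noise predictions are conventionally blended as below (a scale of 1.0, as in this config, reduces to the purely conditional prediction; this is the standard classifier-free guidance formula, not code lifted from the repo):
```python
import torch

def guided_noise_pred(model, x_t, t, cond, null_cond, scale):
    # classifier-free guidance: eps = eps_uncond + s * (eps_cond - eps_uncond)
    eps_uncond = model(x_t, t, null_cond)
    eps_cond = model(x_t, t, cond)
    return eps_uncond + scale * (eps_cond - eps_uncond)
```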
