* Download lpips weights from https://github.com/richzhang/PerceptualSimilarity/blob/master/lpips/weights/v0.1/vgg.pth and put it in ```models/weights/v0.1/vgg.pth```

___

## Data Preparation
### Mnist
For setting up the mnist dataset, follow the instructions at https://github.com/explainingai-code/Pytorch-VAE#data-preparation

Ensure the directory structure is the following:
```
StableDiffusion-PyTorch
    -> data
        -> mnist
            -> train
                -> images
                    -> *.png
            -> test
                -> images
                    -> *.png
```
### CelebHQ
#### Unconditional
For setting up CelebHQ for unconditional training, simply download the images from the official CelebAMask-HQ repo [here](https://github.com/switchablenorms/CelebAMask-HQ?tab=readme-ov-file).

Ensure the directory structure is the following:
```
StableDiffusion-PyTorch
    -> data
        -> CelebAMask-HQ
            -> CelebA-HQ-img
                -> *.jpg
```
#### Mask Conditional
For mask conditional LDM on CelebHQ, additionally do the following:

Ensure the directory structure is the following:
```
StableDiffusion-PyTorch
    -> data
        -> CelebAMask-HQ
            -> CelebA-HQ-img
                -> *.jpg
            -> CelebAMask-HQ-mask-anno
                -> 0/1/2/3.../14
                    -> *.png
```
* Run `python -m utils.create_celeb_mask` from the repo root to create the mask images from the mask annotations

Ensure the directory structure is the following:
```
StableDiffusion-PyTorch
    -> data
        -> CelebAMask-HQ
            -> CelebA-HQ-img
                -> *.jpg
            -> CelebAMask-HQ-mask-anno
                -> 0/1/2/3.../14
                    -> *.png
            -> CelebAMask-HQ-mask
                -> *.png
```
#### Text Conditional
For text conditional LDM on CelebHQ, additionally do the following:
* The repo uses captions collected as part of the MM-CelebA-HQ-Dataset repo - https://github.com/IIGROUP/MM-CelebA-HQ-Dataset?tab=readme-ov-file
* Download the captions from the `text` link provided in the repo - https://github.com/IIGROUP/MM-CelebA-HQ-Dataset?tab=readme-ov-file#overview
* This will download a `celeba-captions` folder; simply move it inside the `data/CelebAMask-HQ` folder, as that is where the dataset class expects it to be.

Ensure the directory structure is the following:
```
StableDiffusion-PyTorch
    -> data
        -> CelebAMask-HQ
            -> CelebA-HQ-img
                -> *.jpg
            -> CelebAMask-HQ-mask-anno
                -> 0/1/2/3.../14
                    -> *.png
            -> CelebAMask-HQ-mask
                -> *.png
            -> celeba-caption
                -> *.txt
```
---
## Configuration
The config files allow you to play with different components of the ddpm and autoencoder training:
* ```config/mnist.yaml``` - Small autoencoder and ldm can even be trained on CPU

Most parameters are self explanatory but below I mention a couple which are specific to this repo:
* ```autoencoder_acc_steps``` : For accumulating gradients when the image size is too large to allow larger batch sizes (see the sketch after this list)
* ```save_latents``` : Enable this to save the latents during inference of the autoencoder, so that ddpm training will be faster
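
As a rough illustration of what ```autoencoder_acc_steps``` buys you, here is a minimal, self-contained sketch of the gradient accumulation idea; the model and data below are stand-ins, not the repo's actual training loop:
```
import torch
import torch.nn as nn

# Stand-ins for the autoencoder and its data; only the accumulation pattern matters
model = nn.Conv2d(3, 3, kernel_size=3, padding=1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
acc_steps = 4  # plays the role of autoencoder_acc_steps

optimizer.zero_grad()
for step in range(16):
    im = torch.randn(2, 3, 64, 64)  # small per-step batch that fits in memory
    recon = model(im)
    # Scale the loss so the accumulated gradient matches one batch of 2 * acc_steps images
    loss = nn.functional.mse_loss(recon, im) / acc_steps
    loss.backward()  # gradients keep accumulating until we step
    if (step + 1) % acc_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```
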
___
## Training
The repo provides training and inference for Mnist (Unconditional and Class Conditional) and CelebHQ (Unconditional, Text and/or Mask Conditional).

For working on your own dataset:
* Create your own config and have the path in the config point to your images (look at `celebhq.yaml` for guidance)
* Create your own dataset class which will just collect all the filenames and return the image in its `__getitem__` method. Look at `mnist_dataset.py` or `celeb_dataset.py` for guidance; a minimal sketch follows below
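
A minimal sketch of such a dataset class, assuming the conventions of `celeb_dataset.py`; the class name, constructor arguments and the [-1, 1] scaling here are assumptions, not the repo's exact code:
```
import glob
import os

import torchvision
from PIL import Image
from torch.utils.data import Dataset


class MyDataset(Dataset):
    def __init__(self, split, im_path, im_size=256):
        self.split = split
        self.im_size = im_size
        # Collect all the image filenames up front
        self.images = sorted(glob.glob(os.path.join(im_path, '*.png')))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        im = Image.open(self.images[index]).convert('RGB')
        im_tensor = torchvision.transforms.Compose([
            torchvision.transforms.Resize(self.im_size),
            torchvision.transforms.CenterCrop(self.im_size),
            torchvision.transforms.ToTensor(),
        ])(im)
        im.close()
        # Scale from [0, 1] to [-1, 1], the range diffusion models typically train on
        return (2 * im_tensor) - 1
```
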
Once the config and dataset are set up:
* Train the autoencoder on your dataset using [this section](#training-autoencoder-for-ldm)
* For training Unconditional LDM follow [this section](#training-unconditional-ldm)
* For class conditional ldm go through [this section](#training-class-conditional-ldm)
* For text conditional ldm go through [this section](#training-text-conditional-ldm)
* For text and mask conditional ldm go through [this section](#training-text-and-mask-conditional-ldm)
## Training AutoEncoder for LDM
* For training the autoencoder on mnist, ensure the right path is mentioned in `mnist.yaml`
* For training the autoencoder on celebhq, ensure the right path is mentioned in `celebhq.yaml`
* For training the autoencoder on your own dataset
  * Create your own config and have the path point to images (look at `celebhq.yaml` for guidance)
  * Create your own dataset class, similar to `celeb_dataset.py` without the conditioning parts
  * Map the dataset name to the right class in the training code [here](https://github.com/explainingai-code/StableDiffusion-PyTorch/blob/main/tools/train_ddpm_vqvae.py#L40)
* For training the autoencoder, run ```python -m tools.train_vqvae --config config/mnist.yaml``` with the desired config file
* For inference with the trained autoencoder, run ```python -m tools.infer_vqvae --config config/mnist.yaml``` with the right config file to generate reconstructions. Use `save_latents` in the config to save the latent files
## Training Unconditional LDM
Train the autoencoder first and set up the dataset accordingly.

For training an unconditional LDM, map the dataset to the right class in `train_ddpm_vqvae.py` and run:
* ```python -m tools.train_ddpm_vqvae --config config/mnist.yaml``` for training the unconditional ddpm using the right config
* ```python -m tools.sample_ddpm_vqvae --config config/mnist.yaml``` for generating images using the trained ddpm
## Training Conditional LDM
For training conditional models we need two changes:
* Dataset classes must provide the additional conditional inputs (see below)
* The config must be changed, with the additional conditioning config added

Specifically, the dataset `__getitem__` will return the following:
* `image_tensor` for unconditional training
* A tuple of `(image_tensor, cond_input)` for conditional training, where `cond_input` is a dictionary consisting of keys from ```{class/text/image}``` (sketched below)
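
A hedged sketch of those two return forms, extending the hypothetical `MyDataset` class sketched in the Training section; `self.labels`, `self.captions` and `self.get_mask` are assumed lookups, not the repo's actual attributes:
```
class MyCondDataset(MyDataset):
    def __init__(self, split, im_path, condition_types=(), **kwargs):
        super().__init__(split, im_path, **kwargs)
        self.condition_types = list(condition_types)

    def __getitem__(self, index):
        im_tensor = super().__getitem__(index)
        if not self.condition_types:
            return im_tensor  # unconditional: just the image tensor
        cond_input = {}
        if 'class' in self.condition_types:
            cond_input['class'] = self.labels[index]    # integer class id
        if 'text' in self.condition_types:
            cond_input['text'] = self.captions[index]   # caption string
        if 'image' in self.condition_types:
            cond_input['image'] = self.get_mask(index)  # NUM_CLASSES x MASK_H x MASK_W tensor
        return im_tensor, cond_input
```
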
### Training Class Conditional LDM
The repo provides class conditional latent diffusion model training code for the mnist dataset, so one can use that as a reference for their own dataset.

* Use the `mnist_class_cond.yaml` config file as a guide to create your class conditional config file. Specifically, the following new keys need to be modified according to your dataset within `ldm_params`:
  ```
  condition_config:
    condition_types: ['class']
    class_condition_config:
      num_classes: <number of classes: 10 for mnist>
      cond_drop_prob: <probability of dropping class labels>
  ```
* Create a dataset class similar to mnist where the `__getitem__` method now returns a tuple of image_tensor and a dictionary of conditional inputs.
* For class, the conditional input will ONLY be the integer class:
  ```
  (image_tensor, {
      'class' : {0/1/.../num_classes}
  })
  ```

For training a class conditional LDM, map the dataset to the right class in `train_ddpm_cond` and run the below commands using the desired config:
* ```python -m tools.train_ddpm_cond --config config/mnist_class_cond.yaml``` for training the class conditional ldm on mnist
* ```python -m tools.sample_ddpm_class_cond --config config/mnist_class_cond.yaml``` for generating images using the class conditional trained ddpm
### Training Text Conditional LDM
The repo provides text conditional latent diffusion model training code for the celebhq dataset, so one can use that as a reference for their own dataset.

* Use the `celebhq_text_cond.yaml` config file as a guide to create your config file. Specifically, the following new keys need to be modified according to your dataset within `ldm_params` (a sketch of the two text embedding options follows this list):
  ```
  condition_config:
    condition_types: [ 'text' ]
    text_condition_config:
      text_embed_model: 'clip' or 'bert'
      text_embed_dim: 512 or 768
      cond_drop_prob: 0.1
  ```
* Create a dataset class similar to celebhq where the `__getitem__` method now returns a tuple of image_tensor and a dictionary of conditional inputs.
* For text, the conditional input will ONLY be the caption:
  ```
  (image_tensor, {
      'text' : 'a sample caption for image_tensor'
  })
  ```
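
For reference, a hedged sketch of what the two `text_embed_model` options in the config correspond to, using Hugging Face `transformers`; the checkpoint names are assumptions and may differ from what the repo actually loads:
```
from transformers import CLIPTextModel, CLIPTokenizer

# 'clip' gives 512-dim token embeddings for this checkpoint; a 'bert' model
# such as bert-base-uncased would give 768-dim ones (hence text_embed_dim above)
tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-base-patch16')
text_model = CLIPTextModel.from_pretrained('openai/clip-vit-base-patch16').eval()

tokens = tokenizer(['a sample caption'], padding='max_length',
                   truncation=True, return_tensors='pt')
text_embed = text_model(**tokens).last_hidden_state  # shape (1, 77, 512)
```
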
For training a text conditional LDM, map the dataset to the right class in `train_ddpm_cond` and run the below commands using the desired config:
* ```python -m tools.train_ddpm_cond --config config/celebhq_text_cond.yaml``` for training text conditioned ldm on celebhq
* ```python -m tools.sample_ddpm_text_cond --config config/celebhq_text_cond.yaml``` for generating images using text conditional trained ddpm
### Training Text and Mask Conditional LDM
The repo provides text and mask conditional latent diffusion model training code for the celebhq dataset, so one can use that as a reference for their own dataset, and can even use it to train a mask-only conditional ldm.

* Use the `celebhq_text_image_cond.yaml` config file as a guide to create your config file. Specifically, the following new keys need to be modified according to your dataset within `ldm_params`:
  ```
  condition_config:
    condition_types: [ 'text', 'image' ]
    text_condition_config:
      text_embed_model: 'clip' or 'bert'
      text_embed_dim: 512 or 768
      cond_drop_prob: 0.1
    image_condition_config:
      image_condition_input_channels: 18
      image_condition_output_channels: 3
      image_condition_h: 512
      image_condition_w: 512
      cond_drop_prob: 0.1
  ```
* Create a dataset class similar to celebhq where the `__getitem__` method now returns a tuple of image_tensor and a dictionary of conditional inputs.
* For text and mask, the conditional inputs will be the caption and the mask image (see the mask tensor sketch after this list):
  ```
  (image_tensor, {
      'text' : 'a sample caption for image_tensor',
      'image' : NUM_CLASSES x MASK_H x MASK_W
  })
  ```
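
A hedged sketch of how that `NUM_CLASSES x MASK_H x MASK_W` tensor can be built from a single-channel mask png whose pixel values are class ids; the file path is illustrative, and the sizes follow the `image_condition_*` keys above:
```
import numpy as np
import torch
from PIL import Image

NUM_CLASSES, MASK_H, MASK_W = 18, 512, 512

# Nearest-neighbour resize keeps pixel values as valid class ids
mask_im = Image.open('data/CelebAMask-HQ/CelebAMask-HQ-mask/0.png')  # illustrative path
mask_im = mask_im.resize((MASK_W, MASK_H), Image.NEAREST)
mask = torch.from_numpy(np.array(mask_im)).long()         # (MASK_H, MASK_W) of class ids
one_hot = torch.nn.functional.one_hot(mask, NUM_CLASSES)  # (MASK_H, MASK_W, NUM_CLASSES)
mask_tensor = one_hot.permute(2, 0, 1).float()            # (NUM_CLASSES, MASK_H, MASK_W)
```
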
For training a text and mask conditional LDM, map the dataset to the right class in `train_ddpm_cond` and run the below commands using the desired config:
* ```python -m tools.train_ddpm_cond --config config/celebhq_text_image_cond.yaml``` for training text and mask conditioned ldm on celebhq
* ```python -m tools.sample_ddpm_text_image_cond --config config/celebhq_text_image_cond.yaml``` for generating images using text and mask conditional trained ddpm
## Output
Outputs will be saved according to the configuration present in yaml files.
During inference of the autoencoder the following output will be saved:
* Reconstructions for random images in ```task_name```
* Latents will be saved in ```task_name/vqvae_latent_dir_name``` if mentioned in the config

During training and inference of the ddpm the following output will be saved:
* During training of an unconditional or conditional DDPM we will save the latest checkpoint in the ```task_name``` directory
* During sampling, unconditionally sampled image grid for all timesteps in ```task_name/samples/*.png```
* During sampling, class conditionally sampled image grid for all timesteps in ```task_name/cond_class_samples/*.png```
* During sampling, text only conditionally sampled image grid for all timesteps in ```task_name/cond_text_samples/*.png```
* During sampling, text and mask conditionally sampled image grid for all timesteps in ```task_name/cond_text_image_samples/*.png```