-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdiffusion.py
More file actions
159 lines (129 loc) · 5.79 KB
/
diffusion.py
File metadata and controls
159 lines (129 loc) · 5.79 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import torch
from utils import *
import torch.nn.functional as F
import math
from utils import fourier_nscales, _nested_map, interpolate_nscales
import numpy as np
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Discretize a continuous alpha_bar(t) curve on [0, 1] into a beta schedule.

    For step i the beta is ``1 - alpha_bar((i + 1) / N) / alpha_bar(i / N)``,
    capped at ``max_beta``.

    :param num_diffusion_timesteps: N, the number of betas to produce.
    :param alpha_bar: callable mapping t in [0, 1] to the cumulative product
                      of (1 - beta) up to that point of the diffusion process.
    :param max_beta: upper bound applied to every beta; values below 1 avoid
                     singularities where alpha_bar reaches 0.
    :return: a float32 tensor of N betas.
    Taken from https://github.com/SmartTURB/diffusion-lagr/blob/master/guided_diffusion/gaussian_diffusion.py
    """
    n = num_diffusion_timesteps
    return torch.tensor(
        [min(1 - alpha_bar((i + 1) / n) / alpha_bar(i / n), max_beta) for i in range(n)],
        dtype=torch.float32,
    )
def linear_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.1):
    """
    Build a beta schedule that increases linearly from ``beta_start`` to ``beta_end``.

    Args:
        timesteps (int): Number of betas to produce.
        beta_start (float): Beta of the first step.
        beta_end (float): Beta of the last step.
    Returns:
        torch.Tensor: ``timesteps`` evenly spaced beta values (endpoints included).
    """
    schedule = torch.linspace(start=beta_start, end=beta_end, steps=timesteps)
    return schedule
def cosine_beta_schedule(timesteps, s=0.008):
    """
    Cosine beta schedule as proposed in https://arxiv.org/abs/2102.09672.

    alpha_bar(t) = cos^2(((t / T) + s) / (1 + s) * pi / 2), normalized so that
    alpha_bar(0) = 1; beta_t = 1 - alpha_bar(t) / alpha_bar(t - 1), clipped to
    [0.0001, 0.9999].

    Args:
        timesteps (int): T, the number of betas to produce.
        s (float): Small offset that keeps the betas near t = 0 from being too small.
    Returns:
        torch.Tensor: T clipped beta values.
    """
    grid = torch.linspace(0, timesteps, timesteps + 1)
    phase = ((grid / timesteps) + s) / (1 + s) * torch.pi * 0.5
    alpha_bar = torch.cos(phase) ** 2
    alpha_bar = alpha_bar / alpha_bar[0]
    step_ratios = alpha_bar[1:] / alpha_bar[:-1]
    return torch.clip(1 - step_ratios, 0.0001, 0.9999)
def tanh61_beta_schedule(timesteps, t0=6, t1=1):
    """
    tanh6-1 schedule.

    Discretizes alpha_bar(t) = tanh(t1) - tanh((t0 + t1) * t - t0) for t in
    [0, 1] via ``betas_for_alpha_bar`` (default ``max_beta``). With the default
    t0 = 6, t1 = 1, alpha_bar is 0 at t = 1, so the last beta hits that cap.
    """
    def alpha_bar(t):
        return math.tanh(t1) - math.tanh((t0 + t1) * t - t0)

    return betas_for_alpha_bar(timesteps, alpha_bar)
def create_beta_schedule(steps, scheduler="linear", **kwargs):
    """
    Build a beta schedule for the diffusion process by name.

    Args:
        steps (int): Number of timesteps (forwarded as ``timesteps``).
        scheduler (str): ``"cosine"``, ``"tanh61"`` or ``"linear"``. Any other
            value silently falls back to the linear schedule.
        kwargs: Extra keyword arguments forwarded to the chosen schedule function.
    Returns:
        torch.Tensor: Beta values for the diffusion process.
    """
    if scheduler == "cosine":
        return cosine_beta_schedule(timesteps=steps, **kwargs)
    if scheduler == "tanh61":
        return tanh61_beta_schedule(timesteps=steps, **kwargs)
    # Default branch: "linear" and every unrecognized name end up here.
    return linear_beta_schedule(timesteps=steps, **kwargs)
class GaussianDiffusion:
    """
    Class for Gaussian diffusion process.

    Precomputes, from a beta schedule, the per-timestep quantities used to noise
    data with q(x_t | x_0) and the variance of the posterior q(x_{t-1} | x_t, x_0).

    Args:
        betas (torch.Tensor): Beta values for the diffusion process, one per
            timestep. Anything ``torch.as_tensor`` accepts (e.g. a list of
            floats) is converted; a tensor is used as-is.
    """

    def __init__(self, betas) -> None:
        # A tensor passes through unchanged (same object, dtype and device);
        # lists / numpy arrays are converted so the torch ops below work.
        betas = torch.as_tensor(betas)
        self.betas = betas
        self.alphas = 1. - betas
        # accumulated product of alphas for each time step (1, .., T)
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        # and until time t-1 (1, .., t-1): shift right by one, front-padded with 1.0
        self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0)
        # terms with square root of all the accumulated alphas
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - self.alphas_cumprod)
        # variance of the posterior q(x_{t-1} | x_t, x_0)
        # according to eq. 7 in Ho et al. (2020)
        self.posterior_variance = (
            betas * (1. - self.alphas_cumprod_prev) / (1. - self.alphas_cumprod)
        )
        # 1 / sqrt(alpha_t)
        self.sqrt_recip_alphas = torch.sqrt(1.0 / self.alphas)

    def forward_diffusion_process_dict(self, x_0, t, fourier=False, smoother=None, device="cpu"):
        """
        Noise every level of a dictionary of batches at the same timesteps.

        One standard-normal draw shaped like ``x_0[0]`` is passed to
        ``interpolate_nscales``. That helper is not defined in this file; it
        presumably yields one noise tensor per level (confirm in utils). Each
        level is then noised with its own entry via
        :meth:`forward_diffusion_process`.

        The input mapping is not modified; a new dict holding the noisy tensors
        is returned.

        Args:
            x_0 (dict): Dictionary of batches keyed by level. Key ``0`` must
                exist because its tensor sets the base noise shape. Presumably
                the keys are ``0 .. len(x_0) - 1`` (TODO: confirm against callers).
            t (torch.Tensor): Timesteps.
            fourier (bool): Unused; kept for interface compatibility.
            smoother: Unused; kept for interface compatibility.
            device: Device every returned tensor is moved to. The default is
                ``"cpu"``, so outputs land on CPU unless a device is given.
        Returns:
            dict(torch.Tensor): Noisy data, keyed like ``x_0``.
            Noise added to each level: the object returned by
                ``interpolate_nscales``, with each ``level`` entry replaced by
                the (device-moved) noise that was actually used.
        """
        levels = len(x_0)
        base_noise = torch.randn_like(x_0[0])
        noises = interpolate_nscales(base_noise, scales=levels)
        # Work on a copy so the caller's clean batch is not overwritten with noisy data.
        noisy = dict(x_0)
        for level, trajectory in x_0.items():
            noisy_traj, level_noise = self.forward_diffusion_process(
                trajectory, t, noise=noises[level]
            )
            noisy[level] = noisy_traj.to(device)
            noises[level] = level_noise.to(device)
        return noisy, noises

    def forward_diffusion_process(self, x_0, t, device=None, noise=None):
        """
        Noise a data point (or a batch) at timestep(s) ``t``.

        Computes ``sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise``.

        Args:
            x_0 (torch.Tensor): Clean data.
            t (torch.Tensor): Timestep(s) used to index the schedule.
            device: If given, both returned tensors are moved to it.
            noise (torch.Tensor, optional): Noise to add. When omitted, a
                standard-normal tensor shaped like ``x_0`` is drawn.
        Returns:
            torch.Tensor: Noisy data.
            torch.Tensor: Noise added to the data.
        """
        if noise is None:
            noise = torch.randn_like(x_0)
        # get_index_from_list is not defined in this file (presumably it comes
        # from `from utils import *`). It presumably gathers the schedule values
        # at t, shaped to broadcast against x_0. Confirm.
        sqrt_alphas_cumprod_t = get_index_from_list(self.sqrt_alphas_cumprod, t, x_0.shape)
        sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
            self.sqrt_one_minus_alphas_cumprod, t, x_0.shape
        )
        x_t = sqrt_alphas_cumprod_t * x_0 + sqrt_one_minus_alphas_cumprod_t * noise
        if device is not None:
            return x_t.to(device), noise.to(device)
        return x_t, noise