
Commit 0c078af

Add variable number of discriminator updates for each generator update (#103)
* Add n_discriminator steps
* Move parameter to init
* Update synthesizer.py
* Remove whitespace
* Add extra information in docstring and change variable name to discriminator_steps

Co-authored-by: Carles Sala <carles@pythiac.com>
1 parent e1c09e1 commit 0c078af

1 file changed: ctgan/synthesizer.py (51 additions, 44 deletions)

@@ -29,13 +29,20 @@ class CTGANSynthesizer(object):
             Size of the output samples for each one of the Discriminator Layers. A Linear Layer
             will be created for each one of the values provided. Defaults to (256, 256).
         l2scale (float):
-            Wheight Decay for the Adam Optimizer. Defaults to 1e-6.
+            Weight Decay for the Adam Optimizer. Defaults to 1e-6.
         batch_size (int):
             Number of data samples to process in each step.
+        discriminator_steps (int):
+            Number of discriminator updates to do for each generator update.
+            From the WGAN paper: https://arxiv.org/abs/1701.07875. WGAN paper
+            default is 5. Default used is 1 to match original CTGAN implementation.
+        log_frequency (boolean):
+            Whether to use log frequency of categorical levels in conditional
+            sampling. Defaults to ``True``.
     """
 
     def __init__(self, embedding_dim=128, gen_dim=(256, 256), dis_dim=(256, 256),
-                 l2scale=1e-6, batch_size=500, log_frequency=True):
+                 l2scale=1e-6, batch_size=500, discriminator_steps=1, log_frequency=True):
 
         self.embedding_dim = embedding_dim
         self.gen_dim = gen_dim
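
As a usage sketch (assuming the package-level `CTGANSynthesizer` import that the ctgan library exposes; `train_data` and the column name below are placeholders), the new argument is passed at construction time:

    from ctgan import CTGANSynthesizer

    # discriminator_steps=5 follows the WGAN paper's critic schedule;
    # the default of 1 keeps the original CTGAN behaviour.
    model = CTGANSynthesizer(batch_size=500, discriminator_steps=5)

    # train_data: a training table (placeholder); discrete_columns names
    # its categorical columns, matching the fit() signature in this diff.
    model.fit(train_data, discrete_columns=('category_column',), epochs=300)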
@@ -46,6 +53,7 @@ def __init__(self, embedding_dim=128, gen_dim=(256, 256), dis_dim=(256, 256),
         self.log_frequency = log_frequency
         self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
         self.trained_epoches = 0
+        self.discriminator_steps = discriminator_steps
 
     @staticmethod
     def _gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1):
@@ -64,9 +72,6 @@ def _gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1):
                 but will be differentiated as if it is the soft sample in autograd
             dim (int):
                 a dimension along which softmax will be computed. Default: -1.
-            log_frequency (boolean):
-                Whether to use log frequency of categorical levels in conditional
-                sampling. Defaults to ``True``.
 
         Returns:
             Sampled tensor of same shape as logits from the Gumbel-Softmax distribution.
@@ -197,46 +202,48 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=300):
         for i in range(epochs):
             self.trained_epoches += 1
             for id_ in range(steps_per_epoch):
-                fakez = torch.normal(mean=mean, std=std)
-
-                condvec = self.cond_generator.sample(self.batch_size)
-                if condvec is None:
-                    c1, m1, col, opt = None, None, None, None
-                    real = data_sampler.sample(self.batch_size, col, opt)
-                else:
-                    c1, m1, col, opt = condvec
-                    c1 = torch.from_numpy(c1).to(self.device)
-                    m1 = torch.from_numpy(m1).to(self.device)
-                    fakez = torch.cat([fakez, c1], dim=1)
-
-                    perm = np.arange(self.batch_size)
-                    np.random.shuffle(perm)
-                    real = data_sampler.sample(self.batch_size, col[perm], opt[perm])
-                    c2 = c1[perm]
-
-                fake = self.generator(fakez)
-                fakeact = self._apply_activate(fake)
-
-                real = torch.from_numpy(real.astype('float32')).to(self.device)
-
-                if c1 is not None:
-                    fake_cat = torch.cat([fakeact, c1], dim=1)
-                    real_cat = torch.cat([real, c2], dim=1)
-                else:
-                    real_cat = real
-                    fake_cat = fake
-
-                y_fake = self.discriminator(fake_cat)
-                y_real = self.discriminator(real_cat)
-
-                pen = self.discriminator.calc_gradient_penalty(
-                    real_cat, fake_cat, self.device)
-                loss_d = -(torch.mean(y_real) - torch.mean(y_fake))
 
-                self.optimizerD.zero_grad()
-                pen.backward(retain_graph=True)
-                loss_d.backward()
-                self.optimizerD.step()
+                for n in range(self.discriminator_steps):
+                    fakez = torch.normal(mean=mean, std=std)
+
+                    condvec = self.cond_generator.sample(self.batch_size)
+                    if condvec is None:
+                        c1, m1, col, opt = None, None, None, None
+                        real = data_sampler.sample(self.batch_size, col, opt)
+                    else:
+                        c1, m1, col, opt = condvec
+                        c1 = torch.from_numpy(c1).to(self.device)
+                        m1 = torch.from_numpy(m1).to(self.device)
+                        fakez = torch.cat([fakez, c1], dim=1)
+
+                        perm = np.arange(self.batch_size)
+                        np.random.shuffle(perm)
+                        real = data_sampler.sample(self.batch_size, col[perm], opt[perm])
+                        c2 = c1[perm]
+
+                    fake = self.generator(fakez)
+                    fakeact = self._apply_activate(fake)
+
+                    real = torch.from_numpy(real.astype('float32')).to(self.device)
+
+                    if c1 is not None:
+                        fake_cat = torch.cat([fakeact, c1], dim=1)
+                        real_cat = torch.cat([real, c2], dim=1)
+                    else:
+                        real_cat = real
+                        fake_cat = fake
+
+                    y_fake = self.discriminator(fake_cat)
+                    y_real = self.discriminator(real_cat)
+
+                    pen = self.discriminator.calc_gradient_penalty(
+                        real_cat, fake_cat, self.device)
+                    loss_d = -(torch.mean(y_real) - torch.mean(y_fake))
+
+                    self.optimizerD.zero_grad()
+                    pen.backward(retain_graph=True)
+                    loss_d.backward()
+                    self.optimizerD.step()
 
                 fakez = torch.normal(mean=mean, std=std)
                 condvec = self.cond_generator.sample(self.batch_size)
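
The inner loop introduced above is the standard WGAN "n critic" schedule: several discriminator (critic) updates, each on fresh noise and fresh real samples, before every single generator update. Below is a minimal standalone sketch of that schedule with toy networks; the `gradient_penalty` helper is a hypothetical stand-in for CTGAN's `calc_gradient_penalty`, not its actual implementation:

    import torch
    from torch import nn, optim

    # Toy stand-ins for CTGAN's generator and discriminator (assumptions).
    generator = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 8))
    discriminator = nn.Sequential(nn.Linear(8, 32), nn.ReLU(), nn.Linear(32, 1))
    optimizerG = optim.Adam(generator.parameters(), lr=2e-4)
    optimizerD = optim.Adam(discriminator.parameters(), lr=2e-4)

    def gradient_penalty(disc, real, fake, lam=10.0):
        # WGAN-GP penalty on random interpolates (hypothetical helper).
        alpha = torch.rand(real.size(0), 1)
        interp = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
        grad, = torch.autograd.grad(disc(interp).sum(), interp, create_graph=True)
        return lam * ((grad.norm(2, dim=1) - 1) ** 2).mean()

    discriminator_steps = 5   # WGAN paper default; this commit defaults to 1
    batch_size, noise_dim = 64, 16

    for step in range(100):
        # n critic updates per generator update, each on a fresh batch,
        # mirroring the `for n in range(self.discriminator_steps)` loop above.
        for _ in range(discriminator_steps):
            real = torch.randn(batch_size, 8)  # placeholder for sampled real rows
            fake = generator(torch.randn(batch_size, noise_dim)).detach()
            loss_d = -(discriminator(real).mean() - discriminator(fake).mean())
            loss_d = loss_d + gradient_penalty(discriminator, real, fake)
            optimizerD.zero_grad()
            loss_d.backward()
            optimizerD.step()

        # one generator update
        fake = generator(torch.randn(batch_size, noise_dim))
        loss_g = -discriminator(fake).mean()
        optimizerG.zero_grad()
        loss_g.backward()
        optimizerG.step()

Summing the penalty into the loss before a single backward pass, as done here, accumulates the same gradients as the commit's separate `pen.backward(retain_graph=True)` followed by `loss_d.backward()`.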
