
Commit e93e00d

Expose hyperparameters (#130)
* Expose hyperparameters/change cuda logic
* Fix set_device/update documentation
* Remove self from discriminator
* Fix optimizers
* Remove self from discriminator
* Remove "_" from variables

1 parent 87ecaa3 · commit e93e00d

2 files changed: 57 additions & 29 deletions


ctgan/synthesizers/ctgan.py

Lines changed: 45 additions & 26 deletions
@@ -14,11 +14,11 @@

 class Discriminator(Module):

-    def __init__(self, input_dim, discriminator_dim, pack=10):
+    def __init__(self, input_dim, discriminator_dim, pac=10):
         super(Discriminator, self).__init__()
-        dim = input_dim * pack
-        self.pack = pack
-        self.packdim = dim
+        dim = input_dim * pac
+        self.pac = pac
+        self.pacdim = dim
         seq = []
         for item in list(discriminator_dim):
             seq += [Linear(dim, item), LeakyReLU(0.2), Dropout(0.5)]
@@ -49,8 +49,8 @@ def calc_gradient_penalty(self, real_data, fake_data, device='cpu', pac=10, lamb
         return gradient_penalty

     def forward(self, input):
-        assert input.size()[0] % self.pack == 0
-        return self.seq(input.view(-1, self.packdim))
+        assert input.size()[0] % self.pac == 0
+        return self.seq(input.view(-1, self.pacdim))


 class Residual(Module):
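The rename from `pack`/`packdim` to `pac`/`pacdim` aligns the code with the PacGAN terminology the discriminator implements: rather than scoring samples one at a time, it concatenates each group of `pac` samples into a single row (the PacGAN paper argues this packing mitigates mode collapse). A minimal sketch of the reshape that `forward` performs, with illustrative shapes:

import torch

batch_size, input_dim, pac = 500, 32, 10
batch = torch.randn(batch_size, input_dim)

# forward() asserts the batch size is divisible by pac...
assert batch.size(0) % pac == 0

# ...then flattens each group of `pac` samples into one packed row,
# so 500 samples become 50 rows of width input_dim * pac = 320.
packed = batch.view(-1, input_dim * pac)
print(packed.shape)  # torch.Size([50, 320])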
@@ -122,12 +122,19 @@ class CTGANSynthesizer(BaseSynthesizer):
             Whether to have print statements for progress results. Defaults to ``False``.
         epochs (int):
             Number of training epochs. Defaults to 300.
+        pac (int):
+            Number of samples to group together when applying the discriminator.
+            Defaults to 10.
+        cuda (bool):
+            Whether to attempt to use cuda for GPU computation.
+            If this is False or CUDA is not available, CPU will be used.
+            Defaults to ``True``.
     """

     def __init__(self, embedding_dim=128, generator_dim=(256, 256), discriminator_dim=(256, 256),
                  generator_lr=2e-4, generator_decay=1e-6, discriminator_lr=2e-4,
                  discriminator_decay=0, batch_size=500, discriminator_steps=1, log_frequency=True,
-                 verbose=False, epochs=300):
+                 verbose=False, epochs=300, pac=10, cuda=True):

         assert batch_size % 2 == 0
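With `pac` and `cuda` exposed, both can now be chosen at construction time instead of being hard-coded. A hypothetical usage sketch — the package-level import is assumed, and `data`/`discrete_columns` are placeholders:

from ctgan import CTGANSynthesizer

model = CTGANSynthesizer(
    epochs=300,
    batch_size=500,  # must stay even, and divisible by pac for the discriminator
    pac=10,
    cuda=True,       # falls back to CPU automatically when CUDA is unavailable
)
# model.fit(data, discrete_columns)  # placeholder inputs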

@@ -145,8 +152,21 @@ def __init__(self, embedding_dim=128, generator_dim=(256, 256), discriminator_di
         self._log_frequency = log_frequency
         self._verbose = verbose
         self._epochs = epochs
-        self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
         self.trained_epochs = 0
+        self.pac = pac
+
+        if not cuda or not torch.cuda.is_available():
+            device = 'cpu'
+        elif isinstance(cuda, str):
+            device = cuda
+        else:
+            device = 'cuda'
+
+        self._device = torch.device(device)
+
+        self._transformer = None
+        self._data_sampler = None
+        self._generator = None

     @staticmethod
     def _gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1):
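Note that although the docstring types `cuda` as a bool, the `isinstance(cuda, str)` branch also accepts a device string such as `'cuda:1'`. The same resolution logic as a standalone sketch (the helper name is invented for illustration):

import torch

def resolve_device(cuda):
    """Illustrative mirror of the constructor's device selection."""
    if not cuda or not torch.cuda.is_available():
        device = 'cpu'    # explicit opt-out, or no GPU present
    elif isinstance(cuda, str):
        device = cuda     # e.g. 'cuda:1' to target a specific GPU
    else:
        device = 'cuda'   # cuda=True: use the default CUDA device
    return torch.device(device)

print(resolve_device(False))     # cpu
print(resolve_device('cuda:1'))  # cuda:1 with a GPU available, cpu otherwise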
@@ -289,18 +309,19 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=None):
             data_dim
         ).to(self._device)

-        self._discriminator = Discriminator(
+        discriminator = Discriminator(
             data_dim + self._data_sampler.dim_cond_vec(),
-            self._discriminator_dim
+            self._discriminator_dim,
+            pac=self.pac
         ).to(self._device)

-        self._optimizerG = optim.Adam(
+        optimizerG = optim.Adam(
             self._generator.parameters(), lr=self._generator_lr, betas=(0.5, 0.9),
             weight_decay=self._generator_decay
         )

-        self._optimizerD = optim.Adam(
-            self._discriminator.parameters(), lr=self._discriminator_lr,
+        optimizerD = optim.Adam(
+            discriminator.parameters(), lr=self._discriminator_lr,
             betas=(0.5, 0.9), weight_decay=self._discriminator_decay
         )
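Per the commit bullets ("Remove self from discriminator", "Fix optimizers"), the discriminator and both Adam optimizers become locals of `fit` rather than instance attributes: they are transient training state, so keeping them off `self` means the fitted synthesizer no longer drags them along, for example when it is saved. The `betas=(0.5, 0.9)` setting is unchanged and is a common choice for WGAN-style training.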

@@ -343,17 +364,17 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=None):
                     real_cat = real
                     fake_cat = fake

-                    y_fake = self._discriminator(fake_cat)
-                    y_real = self._discriminator(real_cat)
+                    y_fake = discriminator(fake_cat)
+                    y_real = discriminator(real_cat)

-                    pen = self._discriminator.calc_gradient_penalty(
-                        real_cat, fake_cat, self._device)
+                    pen = discriminator.calc_gradient_penalty(
+                        real_cat, fake_cat, self._device, self.pac)
                     loss_d = -(torch.mean(y_real) - torch.mean(y_fake))

-                    self._optimizerD.zero_grad()
+                    optimizerD.zero_grad()
                     pen.backward(retain_graph=True)
                     loss_d.backward()
-                    self._optimizerD.step()
+                    optimizerD.step()

                 fakez = torch.normal(mean=mean, std=std)
                 condvec = self._data_sampler.sample_condvec(self._batch_size)
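Since `calc_gradient_penalty` defaults to `pac=10`, forwarding `self.pac` keeps the penalty's interpolation groups consistent with whatever packing size the user chose. For context, a hedged sketch of a WGAN-GP style penalty with PacGAN packing; this is a simplified stand-in with an invented name, not the library's exact implementation:

import torch
from torch import nn

def gradient_penalty_sketch(disc, real, fake, device='cpu', pac=10, lambda_=10):
    # One interpolation coefficient per pac group, broadcast over the
    # group members and the feature columns.
    alpha = torch.rand(real.size(0) // pac, 1, 1, device=device)
    alpha = alpha.repeat(1, pac, real.size(1)).view(-1, real.size(1))

    # Detach so the sketch is self-contained; in training, `fake` would
    # carry the generator's autograd graph instead.
    real, fake = real.detach(), fake.detach()
    interpolates = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    scores = disc(interpolates)

    gradients = torch.autograd.grad(
        outputs=scores, inputs=interpolates,
        grad_outputs=torch.ones_like(scores),
        create_graph=True, retain_graph=True,
    )[0]

    # Penalize deviation of each packed row's gradient norm from 1.
    norms = gradients.view(-1, pac * real.size(1)).norm(2, dim=1)
    return lambda_ * ((norms - 1) ** 2).mean()

# Toy critic that packs pac samples per row, mimicking Discriminator.forward.
pac, dim = 10, 5
critic = nn.Linear(pac * dim, 1)
disc = lambda x: critic(x.view(-1, pac * dim))
real, fake = torch.randn(500, dim), torch.randn(500, dim)
print(gradient_penalty_sketch(disc, real, fake, pac=pac))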
@@ -370,9 +391,9 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=None):
                 fakeact = self._apply_activate(fake)

                 if c1 is not None:
-                    y_fake = self._discriminator(torch.cat([fakeact, c1], dim=1))
+                    y_fake = discriminator(torch.cat([fakeact, c1], dim=1))
                 else:
-                    y_fake = self._discriminator(fakeact)
+                    y_fake = discriminator(fakeact)

                 if condvec is None:
                     cross_entropy = 0
@@ -381,9 +402,9 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=None):

                 loss_g = -torch.mean(y_fake) + cross_entropy

-                self._optimizerG.zero_grad()
+                optimizerG.zero_grad()
                 loss_g.backward()
-                self._optimizerG.step()
+                optimizerG.step()

             if self._verbose:
                 print(f"Epoch {i+1}, Loss G: {loss_g.detach().cpu(): .4f},"
@@ -444,7 +465,5 @@ def sample(self, n, condition_column=None, condition_value=None):

     def set_device(self, device):
         self._device = device
-        if hasattr(self, '_generator'):
+        if self._generator is not None:
             self._generator.to(self._device)
-        if hasattr(self, '_discriminator'):
-            self._discriminator.to(self._device)
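Because `_generator` is now initialized to `None` in `__init__`, the `hasattr` check becomes a plain `None` check, and the discriminator branch disappears since the discriminator only lives inside `fit`. A hypothetical usage sketch (`data` and `discrete_columns` are placeholders):

import torch
from ctgan import CTGANSynthesizer

model = CTGANSynthesizer(cuda=False)   # construct on CPU
# model.fit(data, discrete_columns)    # placeholder inputs

if torch.cuda.is_available():
    model.set_device(torch.device('cuda:0'))  # moves the generator, if built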

ctgan/synthesizers/tvae.py

Lines changed: 12 additions & 3 deletions
@@ -82,7 +82,9 @@ def __init__(
         decompress_dims=(128, 128),
         l2scale=1e-5,
         batch_size=500,
-        epochs=300
+        epochs=300,
+        loss_factor=2,
+        cuda=True
     ):

         self.embedding_dim = embedding_dim
@@ -91,10 +93,17 @@

         self.l2scale = l2scale
         self.batch_size = batch_size
-        self.loss_factor = 2
+        self.loss_factor = loss_factor
         self.epochs = epochs

-        self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+        if not cuda or not torch.cuda.is_available():
+            device = 'cpu'
+        elif isinstance(cuda, str):
+            device = cuda
+        else:
+            device = 'cuda'
+
+        self._device = torch.device(device)

     def fit(self, train_data, discrete_columns=tuple()):
         self.transformer = DataTransformer()
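TVAE gets the same treatment: `loss_factor` (previously hard-coded to 2) and `cuda` become constructor arguments, with the identical device-resolution logic. A hypothetical usage sketch — the package-level import is assumed, and reading `loss_factor` as a weight on the reconstruction term is an inference, since this diff does not show the loss function:

from ctgan import TVAESynthesizer

model = TVAESynthesizer(
    epochs=300,
    loss_factor=2,  # presumed weight on the reconstruction term of the VAE loss
    cuda=False,     # or True, or a device string such as 'cuda:1'
)
# model.fit(data, discrete_columns)  # placeholder inputs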
