Skip to content

Commit a692b5f

Browse files
committed
make release-tag: Merge branch 'master' into stable
2 parents fc5ab8c + e6d3628 commit a692b5f

8 files changed

Lines changed: 92 additions & 37 deletions

File tree

HISTORY.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,16 @@
11
# History
22

3+
## v0.4.1 - 2021-03-30
4+
5+
This release exposes all the hyperparameters which the user may find useful for both `CTGAN`
6+
and `TVAE`. Also `TVAE` can now be fitted on datasets that are shorter than the batch
7+
size and drops the last batch only if the data size is not divisible by the batch size.
8+
9+
### Issues closed
10+
11+
* `TVAE`: Adapt `batch_size` to data size - Issue [#135](https://github.com/sdv-dev/CTGAN/issues/135) by @fealho and @csala
12+
* `ValueError` from `validate_discrete_columns` with `UniqueCombinationConstraint` - Issue [#133](https://github.com/sdv-dev/CTGAN/issues/133) by @fealho and @MLjungg
13+
314
## v0.4.0 - 2021-02-24
415

516
Maintenance release to upgrade dependencies to ensure compatibility with the rest

conda/meta.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
{% set name = 'ctgan' %}
2-
{% set version = '0.4.0' %}
2+
{% set version = '0.4.1.dev2' %}
33

44
package:
55
name: "{{ name|lower }}"

ctgan/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
__author__ = 'MIT Data To AI Lab'
66
__email__ = 'dailabmit@gmail.com'
7-
__version__ = '0.4.0'
7+
__version__ = '0.4.1.dev2'
88

99
from ctgan.demo import load_demo
1010
from ctgan.synthesizers.ctgan import CTGANSynthesizer

ctgan/synthesizers/ctgan.py

Lines changed: 45 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,11 @@
1414

1515
class Discriminator(Module):
1616

17-
def __init__(self, input_dim, discriminator_dim, pack=10):
17+
def __init__(self, input_dim, discriminator_dim, pac=10):
1818
super(Discriminator, self).__init__()
19-
dim = input_dim * pack
20-
self.pack = pack
21-
self.packdim = dim
19+
dim = input_dim * pac
20+
self.pac = pac
21+
self.pacdim = dim
2222
seq = []
2323
for item in list(discriminator_dim):
2424
seq += [Linear(dim, item), LeakyReLU(0.2), Dropout(0.5)]
@@ -49,8 +49,8 @@ def calc_gradient_penalty(self, real_data, fake_data, device='cpu', pac=10, lamb
4949
return gradient_penalty
5050

5151
def forward(self, input):
52-
assert input.size()[0] % self.pack == 0
53-
return self.seq(input.view(-1, self.packdim))
52+
assert input.size()[0] % self.pac == 0
53+
return self.seq(input.view(-1, self.pacdim))
5454

5555

5656
class Residual(Module):
@@ -122,12 +122,19 @@ class CTGANSynthesizer(BaseSynthesizer):
122122
Whether to have print statements for progress results. Defaults to ``False``.
123123
epochs (int):
124124
Number of training epochs. Defaults to 300.
125+
pac (int):
126+
Number of samples to group together when applying the discriminator.
127+
Defaults to 10.
128+
cuda (bool):
129+
Whether to attempt to use cuda for GPU computation.
130+
If this is False or CUDA is not available, CPU will be used.
131+
Defaults to ``True``.
125132
"""
126133

127134
def __init__(self, embedding_dim=128, generator_dim=(256, 256), discriminator_dim=(256, 256),
128135
generator_lr=2e-4, generator_decay=1e-6, discriminator_lr=2e-4,
129136
discriminator_decay=0, batch_size=500, discriminator_steps=1, log_frequency=True,
130-
verbose=False, epochs=300):
137+
verbose=False, epochs=300, pac=10, cuda=True):
131138

132139
assert batch_size % 2 == 0
133140

@@ -145,8 +152,20 @@ def __init__(self, embedding_dim=128, generator_dim=(256, 256), discriminator_di
145152
self._log_frequency = log_frequency
146153
self._verbose = verbose
147154
self._epochs = epochs
148-
self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
149-
self.trained_epochs = 0
155+
self.pac = pac
156+
157+
if not cuda or not torch.cuda.is_available():
158+
device = 'cpu'
159+
elif isinstance(cuda, str):
160+
device = cuda
161+
else:
162+
device = 'cuda'
163+
164+
self._device = torch.device(device)
165+
166+
self._transformer = None
167+
self._data_sampler = None
168+
self._generator = None
150169

151170
@staticmethod
152171
def _gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1):
@@ -289,18 +308,19 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=None):
289308
data_dim
290309
).to(self._device)
291310

292-
self._discriminator = Discriminator(
311+
discriminator = Discriminator(
293312
data_dim + self._data_sampler.dim_cond_vec(),
294-
self._discriminator_dim
313+
self._discriminator_dim,
314+
pac=self.pac
295315
).to(self._device)
296316

297-
self._optimizerG = optim.Adam(
317+
optimizerG = optim.Adam(
298318
self._generator.parameters(), lr=self._generator_lr, betas=(0.5, 0.9),
299319
weight_decay=self._generator_decay
300320
)
301321

302-
self._optimizerD = optim.Adam(
303-
self._discriminator.parameters(), lr=self._discriminator_lr,
322+
optimizerD = optim.Adam(
323+
discriminator.parameters(), lr=self._discriminator_lr,
304324
betas=(0.5, 0.9), weight_decay=self._discriminator_decay
305325
)
306326

@@ -309,7 +329,6 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=None):
309329

310330
steps_per_epoch = max(len(train_data) // self._batch_size, 1)
311331
for i in range(epochs):
312-
self.trained_epochs += 1
313332
for id_ in range(steps_per_epoch):
314333

315334
for n in range(self._discriminator_steps):
@@ -343,17 +362,17 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=None):
343362
real_cat = real
344363
fake_cat = fake
345364

346-
y_fake = self._discriminator(fake_cat)
347-
y_real = self._discriminator(real_cat)
365+
y_fake = discriminator(fake_cat)
366+
y_real = discriminator(real_cat)
348367

349-
pen = self._discriminator.calc_gradient_penalty(
350-
real_cat, fake_cat, self._device)
368+
pen = discriminator.calc_gradient_penalty(
369+
real_cat, fake_cat, self._device, self.pac)
351370
loss_d = -(torch.mean(y_real) - torch.mean(y_fake))
352371

353-
self._optimizerD.zero_grad()
372+
optimizerD.zero_grad()
354373
pen.backward(retain_graph=True)
355374
loss_d.backward()
356-
self._optimizerD.step()
375+
optimizerD.step()
357376

358377
fakez = torch.normal(mean=mean, std=std)
359378
condvec = self._data_sampler.sample_condvec(self._batch_size)
@@ -370,9 +389,9 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=None):
370389
fakeact = self._apply_activate(fake)
371390

372391
if c1 is not None:
373-
y_fake = self._discriminator(torch.cat([fakeact, c1], dim=1))
392+
y_fake = discriminator(torch.cat([fakeact, c1], dim=1))
374393
else:
375-
y_fake = self._discriminator(fakeact)
394+
y_fake = discriminator(fakeact)
376395

377396
if condvec is None:
378397
cross_entropy = 0
@@ -381,9 +400,9 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=None):
381400

382401
loss_g = -torch.mean(y_fake) + cross_entropy
383402

384-
self._optimizerG.zero_grad()
403+
optimizerG.zero_grad()
385404
loss_g.backward()
386-
self._optimizerG.step()
405+
optimizerG.step()
387406

388407
if self._verbose:
389408
print(f"Epoch {i+1}, Loss G: {loss_g.detach().cpu(): .4f},"
@@ -444,7 +463,5 @@ def sample(self, n, condition_column=None, condition_value=None):
444463

445464
def set_device(self, device):
446465
self._device = device
447-
if hasattr(self, '_generator'):
466+
if self._generator is not None:
448467
self._generator.to(self._device)
449-
if hasattr(self, '_discriminator'):
450-
self._discriminator.to(self._device)

ctgan/synthesizers/tvae.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,9 @@ def __init__(
8282
decompress_dims=(128, 128),
8383
l2scale=1e-5,
8484
batch_size=500,
85-
epochs=300
85+
epochs=300,
86+
loss_factor=2,
87+
cuda=True
8688
):
8789

8890
self.embedding_dim = embedding_dim
@@ -91,17 +93,24 @@ def __init__(
9193

9294
self.l2scale = l2scale
9395
self.batch_size = batch_size
94-
self.loss_factor = 2
96+
self.loss_factor = loss_factor
9597
self.epochs = epochs
9698

97-
self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
99+
if not cuda or not torch.cuda.is_available():
100+
device = 'cpu'
101+
elif isinstance(cuda, str):
102+
device = cuda
103+
else:
104+
device = 'cuda'
105+
106+
self._device = torch.device(device)
98107

99108
def fit(self, train_data, discrete_columns=tuple()):
100109
self.transformer = DataTransformer()
101110
self.transformer.fit(train_data, discrete_columns)
102111
train_data = self.transformer.transform(train_data)
103112
dataset = TensorDataset(torch.from_numpy(train_data.astype('float32')).to(self._device))
104-
loader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True, drop_last=True)
113+
loader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True, drop_last=False)
105114

106115
data_dim = self.transformer.output_dimensions
107116
encoder = Encoder(data_dim, self.compress_dims, self.embedding_dim).to(self._device)

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
[bumpversion]
2-
current_version = 0.4.0
2+
current_version = 0.4.1.dev2
33
commit = True
44
tag = True
55
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\.(?P<release>[a-z]+)(?P<candidate>\d+))?

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
'scikit-learn>=0.23,<1',
1919
'torch>=1.4,<2',
2020
'torchvision>=0.5.0,<1',
21-
'rdt>=0.4.0,<0.5',
21+
'rdt>=0.4.1,<0.5',
2222
]
2323

2424
setup_requires = [
@@ -99,6 +99,6 @@
9999
test_suite='tests',
100100
tests_require=tests_require,
101101
url='https://github.com/sdv-dev/CTGAN',
102-
version='0.4.0',
102+
version='0.4.1.dev2',
103103
zip_safe=False,
104104
)

tests/integration/test_tvae.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,21 @@ def test_tvae(tmpdir):
3333
assert isinstance(sampled, pd.DataFrame)
3434
assert set(sampled.columns) == set(data.columns)
3535
assert set(sampled.dtypes) == set(data.dtypes)
36+
37+
38+
def test_drop_last_false():
39+
data = pd.DataFrame({
40+
'1': ['a', 'b', 'c'] * 150,
41+
'2': ['a', 'b', 'c'] * 150
42+
})
43+
44+
tvae = TVAESynthesizer(epochs=300)
45+
tvae.fit(data, ['1', '2'])
46+
47+
sampled = tvae.sample(100)
48+
correct = 0
49+
for _, row in sampled.iterrows():
50+
if row['1'] == row['2']:
51+
correct += 1
52+
53+
assert correct >= 95

0 commit comments

Comments
 (0)