Skip to content

Commit e1c09e1

Browse files
authored
Issue 102 (#104)
* Moved `log_frequency` to `__init__`.
* Fixed test.
1 parent 821bb36 commit e1c09e1

2 files changed

Lines changed: 9 additions & 8 deletions

File tree

ctgan/synthesizer.py

Lines changed: 7 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -35,14 +35,15 @@ class CTGANSynthesizer(object):
3535
"""
3636

3737
def __init__(self, embedding_dim=128, gen_dim=(256, 256), dis_dim=(256, 256),
38-
l2scale=1e-6, batch_size=500):
38+
l2scale=1e-6, batch_size=500, log_frequency=True):
3939

4040
self.embedding_dim = embedding_dim
4141
self.gen_dim = gen_dim
4242
self.dis_dim = dis_dim
4343

4444
self.l2scale = l2scale
4545
self.batch_size = batch_size
46+
self.log_frequency = log_frequency
4647
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
4748
self.trained_epoches = 0
4849

@@ -63,6 +64,9 @@ def _gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1):
6364
but will be differentiated as if it is the soft sample in autograd
6465
dim (int):
6566
a dimension along which softmax will be computed. Default: -1.
67+
log_frequency (boolean):
68+
Whether to use log frequency of categorical levels in conditional
69+
sampling. Defaults to ``True``.
6670
6771
Returns:
6872
Sampled tensor of same shape as logits from the Gumbel-Softmax distribution.
@@ -130,7 +134,7 @@ def _cond_loss(self, data, c, m):
130134

131135
return (loss * m).sum() / data.size()[0]
132136

133-
def fit(self, train_data, discrete_columns=tuple(), epochs=300, log_frequency=True):
137+
def fit(self, train_data, discrete_columns=tuple(), epochs=300):
134138
"""Fit the CTGAN Synthesizer models to the training data.
135139
136140
Args:
@@ -144,9 +148,6 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=300, log_frequency=Tr
144148
a ``pandas.DataFrame``, this list should contain the column names.
145149
epochs (int):
146150
Number of training epochs. Defaults to 300.
147-
log_frequency (boolean):
148-
Whether to use log frequency of categorical levels in conditional
149-
sampling. Defaults to ``True``.
150151
"""
151152

152153
if not hasattr(self, "transformer"):
@@ -162,7 +163,7 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=300, log_frequency=Tr
162163
self.cond_generator = ConditionalGenerator(
163164
train_data,
164165
self.transformer.output_info,
165-
log_frequency
166+
self.log_frequency
166167
)
167168

168169
if not hasattr(self, "generator"):

tests/integration/test_ctgan.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -66,8 +66,8 @@ def test_log_frequency():
6666
counts = sampled['discrete'].value_counts()
6767
assert counts['a'] < 6500
6868

69-
ctgan = CTGANSynthesizer()
70-
ctgan.fit(data, discrete_columns, epochs=100, log_frequency=False)
69+
ctgan = CTGANSynthesizer(log_frequency=False)
70+
ctgan.fit(data, discrete_columns, epochs=100)
7171

7272
sampled = ctgan.sample(10000)
7373
counts = sampled['discrete'].value_counts()

0 commit comments

Comments (0)