Skip to content

Commit 2901e03

Browse files
committed
expose log_frequency param
1 parent c31d3ef commit 2901e03

2 files changed

Lines changed: 12 additions & 4 deletions

File tree

ctgan/conditional.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33

44
class ConditionalGenerator(object):
5-
def __init__(self, data, output_info):
5+
def __init__(self, data, output_info, log_frequency):
66
self.model = []
77

88
start = 0
@@ -50,7 +50,8 @@ def __init__(self, data, output_info):
5050
continue
5151
end = start + item[0]
5252
tmp = np.sum(data[:, start:end], axis=0)
53-
tmp = np.log(tmp + 1)
53+
if log_frequency:
54+
tmp = np.log(tmp + 1)
5455
tmp = tmp / np.sum(tmp)
5556
self.p[self.n_col, :item[0]] = tmp
5657
self.interval.append((self.n_opt, item[0]))

ctgan/synthesizer.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def _cond_loss(self, data, c, m):
9595

9696
return (loss * m).sum() / data.size()[0]
9797

98-
def fit(self, train_data, discrete_columns=tuple(), epochs=300):
98+
def fit(self, train_data, discrete_columns=tuple(), epochs=300, log_frequency=True):
9999
"""Fit the CTGAN Synthesizer models to the training data.
100100
101101
Args:
@@ -109,6 +109,9 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=300):
109109
a ``pandas.DataFrame``, this list should contain the column names.
110110
epochs (int):
111111
Number of training epochs. Defaults to 300.
112+
log_frequency (boolean):
113+
Whether to use log frequency of categorical levels in conditional
114+
sampling. Defaults to True.
112115
"""
113116

114117
self.transformer = DataTransformer()
@@ -118,7 +121,11 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=300):
118121
data_sampler = Sampler(train_data, self.transformer.output_info)
119122

120123
data_dim = self.transformer.output_dimensions
121-
self.cond_generator = ConditionalGenerator(train_data, self.transformer.output_info)
124+
self.cond_generator = ConditionalGenerator(
125+
train_data,
126+
self.transformer.output_info,
127+
log_frequency
128+
)
122129

123130
self.generator = Generator(
124131
self.embedding_dim + self.cond_generator.n_opt,

0 commit comments

Comments
 (0)