Skip to content

Commit ac87f05

Browse files
authored
Merge pull request #20 from kevinykuo/feature/sampling-freq
Expose `log_frequency` parameter for conditional sampling
2 parents 8fdb36b + 06a7033 commit ac87f05

3 files changed

Lines changed: 35 additions & 4 deletions

File tree

ctgan/conditional.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33

44
class ConditionalGenerator(object):
5-
def __init__(self, data, output_info):
5+
def __init__(self, data, output_info, log_frequency):
66
self.model = []
77

88
start = 0
@@ -50,7 +50,8 @@ def __init__(self, data, output_info):
5050
continue
5151
end = start + item[0]
5252
tmp = np.sum(data[:, start:end], axis=0)
53-
tmp = np.log(tmp + 1)
53+
if log_frequency:
54+
tmp = np.log(tmp + 1)
5455
tmp = tmp / np.sum(tmp)
5556
self.p[self.n_col, :item[0]] = tmp
5657
self.interval.append((self.n_opt, item[0]))

ctgan/synthesizer.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def _cond_loss(self, data, c, m):
9595

9696
return (loss * m).sum() / data.size()[0]
9797

98-
def fit(self, train_data, discrete_columns=tuple(), epochs=300):
98+
def fit(self, train_data, discrete_columns=tuple(), epochs=300, log_frequency=True):
9999
"""Fit the CTGAN Synthesizer models to the training data.
100100
101101
Args:
@@ -109,6 +109,9 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=300):
109109
a ``pandas.DataFrame``, this list should contain the column names.
110110
epochs (int):
111111
Number of training epochs. Defaults to 300.
112+
log_frequency (boolean):
113+
Whether to use log frequency of categorical levels in conditional
114+
sampling. Defaults to ``True``.
112115
"""
113116

114117
self.transformer = DataTransformer()
@@ -118,7 +121,11 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=300):
118121
data_sampler = Sampler(train_data, self.transformer.output_info)
119122

120123
data_dim = self.transformer.output_dimensions
121-
self.cond_generator = ConditionalGenerator(train_data, self.transformer.output_info)
124+
self.cond_generator = ConditionalGenerator(
125+
train_data,
126+
self.transformer.output_info,
127+
log_frequency
128+
)
122129

123130
self.generator = Generator(
124131
self.embedding_dim + self.cond_generator.n_opt,

tests/integration/test_ctgan.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,26 @@ def test_ctgan_numpy():
4848
assert sampled.shape == (100, 2)
4949
assert isinstance(sampled, np.ndarray)
5050
assert set(np.unique(sampled[:, 1])) == {'a', 'b', 'c'}
51+
52+
53+
def test_log_frequency():
54+
data = pd.DataFrame({
55+
'continuous': np.random.random(1000),
56+
'discrete': np.random.choice(['a', 'b', 'c'], 1000, p=[0.95, 0.025, 0.025])
57+
})
58+
59+
discrete_columns = ['discrete']
60+
61+
ctgan = CTGANSynthesizer()
62+
ctgan.fit(data, discrete_columns, epochs=100)
63+
64+
sampled = ctgan.sample(1000)
65+
counts = sampled['discrete'].value_counts()
66+
assert counts['a'] < 650
67+
68+
ctgan = CTGANSynthesizer()
69+
ctgan.fit(data, discrete_columns, epochs=100, log_frequency=False)
70+
71+
sampled = ctgan.sample(1000)
72+
counts = sampled['discrete'].value_counts()
73+
assert counts['a'] > 900

0 commit comments

Comments
 (0)