Skip to content

Commit cc04ff2

Browse files
committed
add test
1 parent 2901e03 commit cc04ff2

1 file changed

Lines changed: 23 additions & 0 deletions

File tree

tests/integration/test_ctgan.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,26 @@ def test_ctgan_numpy():
4848
assert sampled.shape == (100, 2)
4949
assert isinstance(sampled, np.ndarray)
5050
assert set(np.unique(sampled[:, 1])) == {'a', 'b', 'c'}
51+
52+
53+
def test_log_frequency():
54+
data = pd.DataFrame({
55+
'continuous': np.random.random(1000),
56+
'discrete': np.random.choice(['a', 'b', 'c'], 1000, p=[0.95, 0.025, 0.025])
57+
})
58+
59+
discrete_columns = ['discrete']
60+
61+
ctgan = CTGANSynthesizer()
62+
ctgan.fit(data, discrete_columns, epochs=100)
63+
64+
sampled = ctgan.sample(1000)
65+
counts = sampled['discrete'].value_counts()
66+
assert counts['a'] < 650
67+
68+
ctgan = CTGANSynthesizer()
69+
ctgan.fit(data, discrete_columns, epochs=100, log_frequency=False)
70+
71+
sampled = ctgan.sample(1000)
72+
counts = sampled['discrete'].value_counts()
73+
assert counts['a'] > 900

0 commit comments

Comments
 (0)