Skip to content

Commit cd56f03

Browse files
authored
Issue 115 (#117)
* Fixes max() of an empty list and correctly samples from all rows of _data
* Fix lint
1 parent 85e062e commit cd56f03

2 files changed

Lines changed: 17 additions & 2 deletions

File tree

ctgan/data_sampler.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -41,7 +41,7 @@ def is_discrete_column(column_info):
4141
# Prepare an interval matrix for efficiently sample conditional vector
4242
max_category = max(
4343
[column_info[0].dim for column_info in output_info
44-
if is_discrete_column(column_info)])
44+
if is_discrete_column(column_info)], default=0)
4545

4646
self._discrete_column_cond_st = np.zeros(n_discrete_columns, dtype='int32')
4747
self._discrete_column_n_category = np.zeros(
@@ -133,7 +133,7 @@ def sample_data(self, n, col, opt):
133133
n rows of matrix data.
134134
"""
135135
if col is None:
136-
idx = np.random.randint(len(self._data), n)
136+
idx = np.random.randint(len(self._data), size=n)
137137
return self._data[idx]
138138

139139
idx = []

tests/integration/test_ctgan.py

Lines changed: 15 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,21 @@
1717
from ctgan.synthesizers.ctgan import CTGANSynthesizer
1818

1919

20+
def test_ctgan_no_categoricals():
21+
data = pd.DataFrame({
22+
'continuous': np.random.random(1000)
23+
})
24+
25+
ctgan = CTGANSynthesizer(epochs=1)
26+
ctgan.fit(data, [])
27+
28+
sampled = ctgan.sample(100)
29+
30+
assert sampled.shape == (100, 1)
31+
assert isinstance(sampled, pd.DataFrame)
32+
assert set(sampled.columns) == {'continuous'}
33+
34+
2035
def test_ctgan_dataframe():
2136
data = pd.DataFrame({
2237
'continuous': np.random.random(100),

0 commit comments

Comments
 (0)