Skip to content

Commit 821bb36

Browse files
authored
Fixed NaN != NaN counting bug. (#100)
* Fixed NaN != NaN counting bug. * Improved code.
1 parent 3759676 commit 821bb36

1 file changed

Lines changed: 3 additions & 3 deletions

File tree

ctgan/transformer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,13 +50,13 @@ def _fit_discrete(self, column, data):
5050
ohe = OneHotEncodingTransformer()
5151
data = data[:, 0]
5252
ohe.fit(data)
53-
categories = len(set(data))
53+
num_categories = len(ohe.dummies)
5454

5555
return {
5656
'name': column,
5757
'encoder': ohe,
58-
'output_info': [(categories, 'softmax')],
59-
'output_dimensions': categories
58+
'output_info': [(num_categories, 'softmax')],
59+
'output_dimensions': num_categories
6060
}
6161

6262
def fit(self, data, discrete_columns=tuple()):

0 commit comments

Comments
 (0)