Skip to content

Commit 5ac5161

Browse files
authored
Validate discrete column (#118)
* Bump version: 0.3.1.dev0 → 0.3.1.dev1 * Validates discrete columns * Fix lint
1 parent cd56f03 commit 5ac5161

6 files changed

Lines changed: 55 additions & 4 deletions

File tree

conda/meta.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
{% set name = 'ctgan' %}
2-
{% set version = '0.3.1.dev0' %}
2+
{% set version = '0.3.1.dev1' %}
33

44
package:
55
name: "{{ name|lower }}"

ctgan/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
__author__ = 'MIT Data To AI Lab'
66
__email__ = 'dailabmit@gmail.com'
7-
__version__ = '0.3.1.dev0'
7+
__version__ = '0.3.1.dev1'
88

99
from ctgan.demo import load_demo
1010
from ctgan.synthesizers.ctgan import CTGANSynthesizer

ctgan/synthesizers/ctgan.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import warnings
22

33
import numpy as np
4+
import pandas as pd
45
import torch
56
from packaging import version
67
from torch import optim
@@ -222,6 +223,31 @@ def _cond_loss(self, data, c, m):
222223

223224
return (loss * m).sum() / data.size()[0]
224225

226+
def _validate_discrete_columns(self, train_data, discrete_columns):
    """Check whether every entry of ``discrete_columns`` exists in ``train_data``.

    Args:
        train_data (numpy.ndarray or pandas.DataFrame):
            Training Data. It must be a 2-dimensional numpy array or a pandas.DataFrame.
        discrete_columns (list-like):
            List of discrete columns to be used to generate the Conditional
            Vector. If ``train_data`` is a Numpy array, this list should
            contain the integer indices of the columns. Otherwise, if it is
            a ``pandas.DataFrame``, this list should contain the column names.

    Raises:
        TypeError:
            If ``train_data`` is neither a ``pandas.DataFrame`` nor a
            ``numpy.ndarray``.
        ValueError:
            If any entry of ``discrete_columns`` is not a column of
            ``train_data``. The message lists the offending entries in the
            order in which they were passed.
    """
    if isinstance(train_data, pd.DataFrame):
        known_columns = set(train_data.columns)
        # Build a list instead of a set difference so the reported columns
        # keep the caller's order and the message matches the ndarray branch.
        invalid_columns = [
            column for column in discrete_columns if column not in known_columns
        ]
    elif isinstance(train_data, np.ndarray):
        invalid_columns = []
        for column in discrete_columns:
            try:
                out_of_range = column < 0 or column >= train_data.shape[1]
            except TypeError:
                # Entries that cannot be compared with an integer (e.g. a
                # column name given for an array) can never be valid
                # indices; report them instead of leaking the TypeError.
                out_of_range = True

            if out_of_range:
                invalid_columns.append(column)
    else:
        raise TypeError('``train_data`` should be either pd.DataFrame or np.array.')

    if invalid_columns:
        raise ValueError('Invalid columns found: {}'.format(invalid_columns))
250+
225251
def fit(self, train_data, discrete_columns=tuple(), epochs=None):
226252
"""Fit the CTGAN Synthesizer models to the training data.
227253
@@ -234,6 +260,8 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=None):
234260
contain the integer indices of the columns. Otherwise, if it is
235261
a ``pandas.DataFrame``, this list should contain the column names.
236262
"""
263+
self._validate_discrete_columns(train_data, discrete_columns)
264+
237265
if epochs is None:
238266
epochs = self._epochs
239267
else:

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
[bumpversion]
2-
current_version = 0.3.1.dev0
2+
current_version = 0.3.1.dev1
33
commit = True
44
tag = True
55
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\.(?P<release>[a-z]+)(?P<candidate>\d+))?

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,6 @@
9999
test_suite='tests',
100100
tests_require=tests_require,
101101
url='https://github.com/sdv-dev/CTGAN',
102-
version='0.3.1.dev0',
102+
version='0.3.1.dev1',
103103
zip_safe=False,
104104
)

tests/integration/test_ctgan.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
import numpy as np
1515
import pandas as pd
16+
import pytest
1617

1718
from ctgan.synthesizers.ctgan import CTGANSynthesizer
1819

@@ -145,3 +146,25 @@ def test_save_load():
145146
sampled = ctgan.sample(1000)
146147
assert set(sampled.columns) == {'continuous', 'discrete'}
147148
assert set(sampled['discrete'].unique()) == {'a', 'b', 'c'}
149+
150+
151+
def test_wrong_discrete_columns_dataframe():
    """``fit`` raises ``ValueError`` for discrete column names absent from the frame."""
    # The frame only has a 'discrete' column, so neither 'b' nor 'c' exists.
    train_frame = pd.DataFrame({'discrete': ['a', 'b']})
    unknown_names = ['b', 'c']
    model = CTGANSynthesizer(epochs=1)

    with pytest.raises(ValueError):
        model.fit(train_frame, unknown_names)
160+
161+
162+
def test_wrong_discrete_columns_numpy():
    """``fit`` raises ``ValueError`` for discrete column indices outside the array."""
    # The array has a single column, so index 1 is out of range.
    train_array = pd.DataFrame({'discrete': ['a', 'b']}).to_numpy()
    column_indices = [0, 1]
    model = CTGANSynthesizer(epochs=1)

    with pytest.raises(ValueError):
        model.fit(train_array, column_indices)

0 commit comments

Comments
 (0)