Skip to content

Commit 790c175

Browse files
authored
TVAE: Fix #135 (#136)
* Bump version: 0.4.1.dev0 → 0.4.1.dev1 * Set drop_last to False * Add drop_last variable * Fix lint * Changed drop_last to False
1 parent 38f0d30 commit 790c175

6 files changed

Lines changed: 23 additions & 5 deletions

File tree

conda/meta.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
{% set name = 'ctgan' %}
2-
{% set version = '0.4.1.dev0' %}
2+
{% set version = '0.4.1.dev1' %}
33

44
package:
55
name: "{{ name|lower }}"

ctgan/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
__author__ = 'MIT Data To AI Lab'
66
__email__ = 'dailabmit@gmail.com'
7-
__version__ = '0.4.1.dev0'
7+
__version__ = '0.4.1.dev1'
88

99
from ctgan.demo import load_demo
1010
from ctgan.synthesizers.ctgan import CTGANSynthesizer

ctgan/synthesizers/tvae.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def fit(self, train_data, discrete_columns=tuple()):
110110
self.transformer.fit(train_data, discrete_columns)
111111
train_data = self.transformer.transform(train_data)
112112
dataset = TensorDataset(torch.from_numpy(train_data.astype('float32')).to(self._device))
113-
loader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True, drop_last=True)
113+
loader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True, drop_last=False)
114114

115115
data_dim = self.transformer.output_dimensions
116116
encoder = Encoder(data_dim, self.compress_dims, self.embedding_dim).to(self._device)

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
[bumpversion]
2-
current_version = 0.4.1.dev0
2+
current_version = 0.4.1.dev1
33
commit = True
44
tag = True
55
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\.(?P<release>[a-z]+)(?P<candidate>\d+))?

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,6 @@
9999
test_suite='tests',
100100
tests_require=tests_require,
101101
url='https://github.com/sdv-dev/CTGAN',
102-
version='0.4.1.dev0',
102+
version='0.4.1.dev1',
103103
zip_safe=False,
104104
)

tests/integration/test_tvae.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,21 @@ def test_tvae(tmpdir):
3333
assert isinstance(sampled, pd.DataFrame)
3434
assert set(sampled.columns) == set(data.columns)
3535
assert set(sampled.dtypes) == set(data.dtypes)
36+
37+
38+
def test_drop_last_false():
39+
data = pd.DataFrame({
40+
'1': ['a', 'b', 'c'] * 150,
41+
'2': ['a', 'b', 'c'] * 150
42+
})
43+
44+
tvae = TVAESynthesizer(epochs=300)
45+
tvae.fit(data, ['1', '2'])
46+
47+
sampled = tvae.sample(100)
48+
correct = 0
49+
for _, row in sampled.iterrows():
50+
if row['1'] == row['2']:
51+
correct += 1
52+
53+
assert correct >= 95

0 commit comments

Comments
 (0)