Skip to content

Commit 86fcd23

Browse files
authored
Fix TVAE loss_function (#144)
* Fixes the issue * Fix lint
1 parent b5900f2 commit 86fcd23

2 files changed

Lines changed: 21 additions & 1 deletion

File tree

ctgan/synthesizers/tvae.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def loss_function(recon_x, x, sigmas, mu, logvar, output_info, factor):
5454
loss = []
5555
for column_info in output_info:
5656
for span_info in column_info:
57-
if len(column_info) != 1 or span_info.activation_fn != "softmax":
57+
if span_info.activation_fn != "softmax":
5858
ed = st + span_info.dim
5959
std = sigmas[st]
6060
loss.append(((x[:, st] - torch.tanh(recon_x[:, st])) ** 2 / 2 / (std ** 2)).sum())

tests/integration/test_tvae.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,3 +51,23 @@ def test_drop_last_false():
5151
correct += 1
5252

5353
assert correct >= 95
54+
55+
56+
def test_loss_function():
57+
data = pd.DataFrame({
58+
'1': [float(i) for i in range(1000)],
59+
'2': [float(2 * i) for i in range(1000)]
60+
})
61+
62+
tvae = TVAESynthesizer(epochs=300)
63+
tvae.fit(data)
64+
65+
num_samples = 1000
66+
sampled = tvae.sample(num_samples)
67+
error = 0
68+
for _, row in sampled.iterrows():
69+
error += abs(2 * row['1'] - row['2'])
70+
71+
avg_error = error / num_samples
72+
73+
assert avg_error < 400

0 commit comments

Comments
 (0)