@@ -29,13 +29,20 @@ class CTGANSynthesizer(object):
             Size of the output samples for each one of the Discriminator Layers. A Linear Layer
             will be created for each one of the values provided. Defaults to (256, 256).
         l2scale (float):
-            Wheight Decay for the Adam Optimizer. Defaults to 1e-6.
+            Weight Decay for the Adam Optimizer. Defaults to 1e-6.
         batch_size (int):
             Number of data samples to process in each step.
+        discriminator_steps (int):
+            Number of discriminator updates to do for each generator update.
+            From the WGAN paper: https://arxiv.org/abs/1701.07875. The WGAN paper
+            default is 5; the default here is 1 to match the original CTGAN implementation.
+        log_frequency (boolean):
+            Whether to use log frequency of categorical levels in conditional
+            sampling. Defaults to ``True``.
     """

     def __init__(self, embedding_dim=128, gen_dim=(256, 256), dis_dim=(256, 256),
-                 l2scale=1e-6, batch_size=500, log_frequency=True):
+                 l2scale=1e-6, batch_size=500, discriminator_steps=1, log_frequency=True):

         self.embedding_dim = embedding_dim
         self.gen_dim = gen_dim
@@ -46,6 +53,7 @@ def __init__(self, embedding_dim=128, gen_dim=(256, 256), dis_dim=(256, 256),
         self.log_frequency = log_frequency
         self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
         self.trained_epoches = 0
+        self.discriminator_steps = discriminator_steps

     @staticmethod
     def _gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1):
@@ -64,9 +72,6 @@ def _gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1):
                 but will be differentiated as if it is the soft sample in autograd
             dim (int):
                 a dimension along which softmax will be computed. Default: -1.
-            log_frequency (boolean):
-                Whether to use log frequency of categorical levels in conditional
-                sampling. Defaults to ``True``.

         Returns:
             Sampled tensor of same shape as logits from the Gumbel-Softmax distribution.
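
The signature mirrors torch.nn.functional.gumbel_softmax; a minimal sketch of that
call (the tensor shape below is illustrative):

    import torch
    from torch.nn import functional

    logits = torch.randn(4, 3)  # [batch, num_categories] unnormalized log probabilities
    soft = functional.gumbel_softmax(logits, tau=1, dim=-1)     # differentiable; rows sum to 1
    hard = functional.gumbel_softmax(logits, tau=1, hard=True)  # one-hot forward, soft gradient
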
@@ -197,46 +202,48 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=300):
         for i in range(epochs):
             self.trained_epoches += 1
             for id_ in range(steps_per_epoch):
-                fakez = torch.normal(mean=mean, std=std)
-
-                condvec = self.cond_generator.sample(self.batch_size)
-                if condvec is None:
-                    c1, m1, col, opt = None, None, None, None
-                    real = data_sampler.sample(self.batch_size, col, opt)
-                else:
-                    c1, m1, col, opt = condvec
-                    c1 = torch.from_numpy(c1).to(self.device)
-                    m1 = torch.from_numpy(m1).to(self.device)
-                    fakez = torch.cat([fakez, c1], dim=1)
-
-                    perm = np.arange(self.batch_size)
-                    np.random.shuffle(perm)
-                    real = data_sampler.sample(self.batch_size, col[perm], opt[perm])
-                    c2 = c1[perm]
-
-                fake = self.generator(fakez)
-                fakeact = self._apply_activate(fake)
-
-                real = torch.from_numpy(real.astype('float32')).to(self.device)
-
-                if c1 is not None:
-                    fake_cat = torch.cat([fakeact, c1], dim=1)
-                    real_cat = torch.cat([real, c2], dim=1)
-                else:
-                    real_cat = real
-                    fake_cat = fake
-
-                y_fake = self.discriminator(fake_cat)
-                y_real = self.discriminator(real_cat)
-
-                pen = self.discriminator.calc_gradient_penalty(
-                    real_cat, fake_cat, self.device)
-                loss_d = -(torch.mean(y_real) - torch.mean(y_fake))

-                self.optimizerD.zero_grad()
-                pen.backward(retain_graph=True)
-                loss_d.backward()
-                self.optimizerD.step()
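+                # Update the discriminator several times per generator update,
+                # as in the WGAN paper (n_critic); the default of 1 keeps the
+                # original CTGAN behavior.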
+                for n in range(self.discriminator_steps):
+                    fakez = torch.normal(mean=mean, std=std)
+
+                    condvec = self.cond_generator.sample(self.batch_size)
+                    if condvec is None:
+                        c1, m1, col, opt = None, None, None, None
+                        real = data_sampler.sample(self.batch_size, col, opt)
+                    else:
+                        c1, m1, col, opt = condvec
+                        c1 = torch.from_numpy(c1).to(self.device)
+                        m1 = torch.from_numpy(m1).to(self.device)
+                        fakez = torch.cat([fakez, c1], dim=1)
+
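+                        # Shuffle the conditions and sample real rows matching the
+                        # shuffled conditions, so the real and fake batches follow
+                        # the same conditional distribution.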
+                        perm = np.arange(self.batch_size)
+                        np.random.shuffle(perm)
+                        real = data_sampler.sample(self.batch_size, col[perm], opt[perm])
+                        c2 = c1[perm]
+
+                    fake = self.generator(fakez)
+                    fakeact = self._apply_activate(fake)
+
+                    real = torch.from_numpy(real.astype('float32')).to(self.device)
+
+                    if c1 is not None:
+                        fake_cat = torch.cat([fakeact, c1], dim=1)
+                        real_cat = torch.cat([real, c2], dim=1)
+                    else:
+                        real_cat = real
+                        fake_cat = fake
+
+                    y_fake = self.discriminator(fake_cat)
+                    y_real = self.discriminator(real_cat)
+
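+                    # WGAN-GP critic objective: gradient penalty plus the negated
+                    # Wasserstein estimate, E[D(real)] - E[D(fake)].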
+                    pen = self.discriminator.calc_gradient_penalty(
+                        real_cat, fake_cat, self.device)
+                    loss_d = -(torch.mean(y_real) - torch.mean(y_fake))
+
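+                    # Backpropagate the penalty first, retaining the graph so the
+                    # critic loss can reuse it, then take a single optimizer step.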
+                    self.optimizerD.zero_grad()
+                    pen.backward(retain_graph=True)
+                    loss_d.backward()
+                    self.optimizerD.step()

                 fakez = torch.normal(mean=mean, std=std)
                 condvec = self.cond_generator.sample(self.batch_size)
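
A minimal usage sketch of the new discriminator_steps argument (the import path,
CSV file, and column name are assumptions, and sample() comes from the rest of the
class, which this diff does not show):

    import pandas as pd
    from ctgan import CTGANSynthesizer  # import path assumed

    data = pd.read_csv('adult.csv')  # hypothetical dataset
    # Five discriminator updates per generator update, the WGAN paper default.
    model = CTGANSynthesizer(discriminator_steps=5)
    model.fit(data, discrete_columns=['workclass'], epochs=300)
    samples = model.sample(1000)  # assumed API; not part of this diff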