1414
1515class Discriminator (Module ):
1616
17- def __init__ (self , input_dim , discriminator_dim , pack = 10 ):
17+ def __init__ (self , input_dim , discriminator_dim , pac = 10 ):
1818 super (Discriminator , self ).__init__ ()
19- dim = input_dim * pack
20- self .pack = pack
21- self .packdim = dim
19+ dim = input_dim * pac
20+ self .pac = pac
21+ self .pacdim = dim
2222 seq = []
2323 for item in list (discriminator_dim ):
2424 seq += [Linear (dim , item ), LeakyReLU (0.2 ), Dropout (0.5 )]
@@ -49,8 +49,8 @@ def calc_gradient_penalty(self, real_data, fake_data, device='cpu', pac=10, lamb
4949 return gradient_penalty
5050
def forward(self, input):
    """Apply the discriminator to grouped (packed) samples.

    The incoming batch is reshaped so that ``pac`` consecutive samples are
    concatenated into a single row before being fed through ``self.seq``.

    Args:
        input (torch.Tensor): Batch of shape ``(N, input_dim)`` where ``N``
            must be a multiple of ``self.pac``.

    Returns:
        torch.Tensor: Output of the discriminator network, one row per
        group of ``pac`` samples.

    Raises:
        ValueError: If the batch size is not divisible by ``self.pac``.
    """
    # ``assert`` is stripped when Python runs with ``-O``; validate
    # explicitly so the packing invariant is always enforced.
    if input.size()[0] % self.pac != 0:
        raise ValueError('Batch size must be divisible by pac.')
    return self.seq(input.view(-1, self.pacdim))
5454
5555
5656class Residual (Module ):
@@ -122,12 +122,19 @@ class CTGANSynthesizer(BaseSynthesizer):
122122 Whether to have print statements for progress results. Defaults to ``False``.
123123 epochs (int):
124124 Number of training epochs. Defaults to 300.
125+ pac (int):
126+ Number of samples to group together when applying the discriminator.
127+ Defaults to 10.
128+ cuda (bool or str):
129+ If ``True``, use CUDA when it is available; if a string, it is
130+ used as the torch device name. If ``False`` or CUDA is not
131+ available, CPU will be used. Defaults to ``True``.
125132 """
126133
127134 def __init__ (self , embedding_dim = 128 , generator_dim = (256 , 256 ), discriminator_dim = (256 , 256 ),
128135 generator_lr = 2e-4 , generator_decay = 1e-6 , discriminator_lr = 2e-4 ,
129136 discriminator_decay = 0 , batch_size = 500 , discriminator_steps = 1 , log_frequency = True ,
130- verbose = False , epochs = 300 ):
137+ verbose = False , epochs = 300 , pac = 10 , cuda = True ):
131138
132139 assert batch_size % 2 == 0
133140
@@ -145,8 +152,21 @@ def __init__(self, embedding_dim=128, generator_dim=(256, 256), discriminator_di
145152 self ._log_frequency = log_frequency
146153 self ._verbose = verbose
147154 self ._epochs = epochs
148- self ._device = torch .device ("cuda:0" if torch .cuda .is_available () else "cpu" )
149155 self .trained_epochs = 0
156+ self .pac = pac
157+
158+ if not cuda or not torch .cuda .is_available ():
159+ device = 'cpu'
160+ elif isinstance (cuda , str ):
161+ device = cuda
162+ else :
163+ device = 'cuda'
164+
165+ self ._device = torch .device (device )
166+
167+ self ._transformer = None
168+ self ._data_sampler = None
169+ self ._generator = None
150170
151171 @staticmethod
152172 def _gumbel_softmax (logits , tau = 1 , hard = False , eps = 1e-10 , dim = - 1 ):
@@ -289,18 +309,19 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=None):
289309 data_dim
290310 ).to (self ._device )
291311
292- self . _discriminator = Discriminator (
312+ discriminator = Discriminator (
293313 data_dim + self ._data_sampler .dim_cond_vec (),
294- self ._discriminator_dim
314+ self ._discriminator_dim ,
315+ pac = self .pac
295316 ).to (self ._device )
296317
297- self . _optimizerG = optim .Adam (
318+ optimizerG = optim .Adam (
298319 self ._generator .parameters (), lr = self ._generator_lr , betas = (0.5 , 0.9 ),
299320 weight_decay = self ._generator_decay
300321 )
301322
302- self . _optimizerD = optim .Adam (
303- self . _discriminator .parameters (), lr = self ._discriminator_lr ,
323+ optimizerD = optim .Adam (
324+ discriminator .parameters (), lr = self ._discriminator_lr ,
304325 betas = (0.5 , 0.9 ), weight_decay = self ._discriminator_decay
305326 )
306327
@@ -343,17 +364,17 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=None):
343364 real_cat = real
344365 fake_cat = fake
345366
346- y_fake = self . _discriminator (fake_cat )
347- y_real = self . _discriminator (real_cat )
367+ y_fake = discriminator (fake_cat )
368+ y_real = discriminator (real_cat )
348369
349- pen = self . _discriminator .calc_gradient_penalty (
350- real_cat , fake_cat , self ._device )
370+ pen = discriminator .calc_gradient_penalty (
371+ real_cat , fake_cat , self ._device , self . pac )
351372 loss_d = - (torch .mean (y_real ) - torch .mean (y_fake ))
352373
353- self . _optimizerD .zero_grad ()
374+ optimizerD .zero_grad ()
354375 pen .backward (retain_graph = True )
355376 loss_d .backward ()
356- self . _optimizerD .step ()
377+ optimizerD .step ()
357378
358379 fakez = torch .normal (mean = mean , std = std )
359380 condvec = self ._data_sampler .sample_condvec (self ._batch_size )
@@ -370,9 +391,9 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=None):
370391 fakeact = self ._apply_activate (fake )
371392
372393 if c1 is not None :
373- y_fake = self . _discriminator (torch .cat ([fakeact , c1 ], dim = 1 ))
394+ y_fake = discriminator (torch .cat ([fakeact , c1 ], dim = 1 ))
374395 else :
375- y_fake = self . _discriminator (fakeact )
396+ y_fake = discriminator (fakeact )
376397
377398 if condvec is None :
378399 cross_entropy = 0
@@ -381,9 +402,9 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=None):
381402
382403 loss_g = - torch .mean (y_fake ) + cross_entropy
383404
384- self . _optimizerG .zero_grad ()
405+ optimizerG .zero_grad ()
385406 loss_g .backward ()
386- self . _optimizerG .step ()
407+ optimizerG .step ()
387408
388409 if self ._verbose :
389410 print (f"Epoch { i + 1 } , Loss G: { loss_g .detach ().cpu (): .4f} ,"
@@ -444,7 +465,5 @@ def sample(self, n, condition_column=None, condition_value=None):
444465
def set_device(self, device):
    """Record *device* as the active torch device and move the generator onto it.

    Args:
        device (str or torch.device): Target device for subsequent computation.
    """
    self._device = device
    generator = self._generator
    # The generator only exists after ``fit`` has been called.
    if generator is not None:
        generator.to(device)