@@ -14,11 +14,11 @@
 
 class Discriminator(Module):
 
-    def __init__(self, input_dim, discriminator_dim, pack=10):
+    def __init__(self, input_dim, discriminator_dim, pac=10):
         super(Discriminator, self).__init__()
-        dim = input_dim * pack
-        self.pack = pack
-        self.packdim = dim
+        dim = input_dim * pac
+        self.pac = pac
+        self.pacdim = dim
         seq = []
         for item in list(discriminator_dim):
             seq += [Linear(dim, item), LeakyReLU(0.2), Dropout(0.5)]
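With the PacGAN-style `pac` parameter, the first Linear layer consumes `input_dim * pac` features rather than `input_dim`. A minimal instantiation sketch, with hypothetical dimensions:

    disc = Discriminator(input_dim=30, discriminator_dim=(256, 256), pac=10)
    # The first Linear maps 30 * 10 = 300 packed features to 256 hidden units.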
@@ -49,8 +49,8 @@ def calc_gradient_penalty(self, real_data, fake_data, device='cpu', pac=10, lamb
         return gradient_penalty
 
     def forward(self, input):
-        assert input.size()[0] % self.pack == 0
-        return self.seq(input.view(-1, self.packdim))
+        assert input.size()[0] % self.pac == 0
+        return self.seq(input.view(-1, self.pacdim))
 
 
 class Residual(Module):
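The forward pass folds each group of `pac` consecutive rows into one discriminator input, which is why the batch size must divide evenly by `pac`. A quick shape sketch under assumed sizes (500 rows, 30 features, pac=10):

    import torch

    batch = torch.randn(500, 30)
    assert batch.size(0) % 10 == 0
    packed = batch.view(-1, 30 * 10)  # shape (50, 300): one row per pack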
@@ -122,12 +122,19 @@ class CTGANSynthesizer(BaseSynthesizer):
             Whether to have print statements for progress results. Defaults to ``False``.
         epochs (int):
             Number of training epochs. Defaults to 300.
+        pac (int):
+            Number of samples to group together when applying the discriminator.
+            Defaults to 10.
+        cuda (bool):
+            Whether to attempt to use cuda for GPU computation.
+            If this is False or CUDA is not available, CPU will be used.
+            Defaults to ``True``.
     """
 
     def __init__(self, embedding_dim=128, generator_dim=(256, 256), discriminator_dim=(256, 256),
                  generator_lr=2e-4, generator_decay=1e-6, discriminator_lr=2e-4,
                  discriminator_decay=0, batch_size=500, discriminator_steps=1, log_frequency=True,
-                 verbose=False, epochs=300):
+                 verbose=False, epochs=300, pac=10, cuda=True):
 
         assert batch_size % 2 == 0
 
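For reference, a usage sketch of the extended constructor together with the existing fit and sample API; the DataFrame and column names are hypothetical:

    model = CTGANSynthesizer(epochs=300, pac=10, cuda=True)
    model.fit(train_data, discrete_columns=['category', 'label'])
    synthetic = model.sample(1000)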
@@ -145,8 +152,20 @@ def __init__(self, embedding_dim=128, generator_dim=(256, 256), discriminator_di
         self._log_frequency = log_frequency
         self._verbose = verbose
         self._epochs = epochs
-        self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-        self.trained_epochs = 0
+        self.pac = pac
+
+        if not cuda or not torch.cuda.is_available():
+            device = 'cpu'
+        elif isinstance(cuda, str):
+            device = cuda
+        else:
+            device = 'cuda'
+
+        self._device = torch.device(device)
+
+        self._transformer = None
+        self._data_sampler = None
+        self._generator = None
 
     @staticmethod
     def _gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1):
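Although the docstring types `cuda` as a bool, the `isinstance(cuda, str)` branch also accepts a device string. The same resolution logic as a standalone sketch (the helper name is hypothetical):

    import torch

    def resolve_device(cuda):
        if not cuda or not torch.cuda.is_available():
            return torch.device('cpu')    # cuda=False, or no GPU available
        if isinstance(cuda, str):
            return torch.device(cuda)     # e.g. cuda='cuda:1' pins a specific GPU
        return torch.device('cuda')      # cuda=True: default CUDA device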
@@ -289,18 +308,19 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=None):
             data_dim
         ).to(self._device)
 
-        self._discriminator = Discriminator(
+        discriminator = Discriminator(
             data_dim + self._data_sampler.dim_cond_vec(),
-            self._discriminator_dim
+            self._discriminator_dim,
+            pac=self.pac
         ).to(self._device)
 
-        self._optimizerG = optim.Adam(
+        optimizerG = optim.Adam(
             self._generator.parameters(), lr=self._generator_lr, betas=(0.5, 0.9),
             weight_decay=self._generator_decay
         )
 
-        self._optimizerD = optim.Adam(
-            self._discriminator.parameters(), lr=self._discriminator_lr,
+        optimizerD = optim.Adam(
+            discriminator.parameters(), lr=self._discriminator_lr,
             betas=(0.5, 0.9), weight_decay=self._discriminator_decay
         )
 
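Since the discriminator is now constructed with `pac=self.pac`, its packed forward pass requires the training batch to split evenly into packs; the defaults (batch_size=500, pac=10) satisfy this. A sanity check worth keeping in mind when overriding either value:

    batch_size, pac = 500, 10
    assert batch_size % pac == 0, 'batch_size must be a multiple of pac'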
@@ -309,7 +329,6 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=None):
 
         steps_per_epoch = max(len(train_data) // self._batch_size, 1)
         for i in range(epochs):
-            self.trained_epochs += 1
             for id_ in range(steps_per_epoch):
 
                 for n in range(self._discriminator_steps):
@@ -343,17 +362,17 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=None):
                         real_cat = real
                         fake_cat = fake
 
-                    y_fake = self._discriminator(fake_cat)
-                    y_real = self._discriminator(real_cat)
+                    y_fake = discriminator(fake_cat)
+                    y_real = discriminator(real_cat)
 
-                    pen = self._discriminator.calc_gradient_penalty(
-                        real_cat, fake_cat, self._device)
+                    pen = discriminator.calc_gradient_penalty(
+                        real_cat, fake_cat, self._device, self.pac)
                     loss_d = -(torch.mean(y_real) - torch.mean(y_fake))
 
-                    self._optimizerD.zero_grad()
+                    optimizerD.zero_grad()
                     pen.backward(retain_graph=True)
                     loss_d.backward()
-                    self._optimizerD.step()
+                    optimizerD.step()
 
                 fakez = torch.normal(mean=mean, std=std)
                 condvec = self._data_sampler.sample_condvec(self._batch_size)
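`calc_gradient_penalty` itself is outside this hunk. For orientation, a generic WGAN-GP penalty consistent with this call site could look like the sketch below; the function and `lambda_` names are assumptions, not the module's actual code. One interpolation weight is drawn per pack rather than per row, so each packed discriminator input stays a coherent real/fake mixture:

    import torch

    def gradient_penalty(disc, real, fake, device='cpu', pac=10, lambda_=10):
        # One mixing weight per pack, broadcast across its pac rows.
        alpha = torch.rand(real.size(0) // pac, 1, device=device)
        alpha = alpha.repeat(1, pac).view(-1, 1)
        interpolates = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
        d_out = disc(interpolates)
        gradients = torch.autograd.grad(
            outputs=d_out, inputs=interpolates,
            grad_outputs=torch.ones_like(d_out),
            create_graph=True, retain_graph=True, only_inputs=True,
        )[0]
        # Take the gradient norm per packed row, matching Discriminator.forward's view.
        slopes = gradients.view(-1, pac * real.size(1)).norm(2, dim=1)
        return ((slopes - 1) ** 2).mean() * lambda_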
@@ -370,9 +389,9 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=None):
                 fakeact = self._apply_activate(fake)
 
                 if c1 is not None:
-                    y_fake = self._discriminator(torch.cat([fakeact, c1], dim=1))
+                    y_fake = discriminator(torch.cat([fakeact, c1], dim=1))
                 else:
-                    y_fake = self._discriminator(fakeact)
+                    y_fake = discriminator(fakeact)
 
                 if condvec is None:
                     cross_entropy = 0
@@ -381,9 +400,9 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=None):
 
                 loss_g = -torch.mean(y_fake) + cross_entropy
 
-                self._optimizerG.zero_grad()
+                optimizerG.zero_grad()
                 loss_g.backward()
-                self._optimizerG.step()
+                optimizerG.step()
 
             if self._verbose:
                 print(f"Epoch {i + 1}, Loss G: {loss_g.detach().cpu(): .4f},"
@@ -444,7 +463,5 @@ def sample(self, n, condition_column=None, condition_value=None):
 
     def set_device(self, device):
         self._device = device
-        if hasattr(self, '_generator'):
+        if self._generator is not None:
             self._generator.to(self._device)
-        if hasattr(self, '_discriminator'):
-            self._discriminator.to(self._device)
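Because `_generator` is initialized to `None` in the constructor, `set_device` is now safe to call before `fit`; it records the device and only moves the generator once one exists. A brief usage sketch:

    model = CTGANSynthesizer(cuda=False)
    model.set_device(torch.device('cuda:0'))  # generator moves only after fit() builds it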