@@ -35,14 +35,15 @@ class CTGANSynthesizer(object):
3535 """
3636
3737 def __init__ (self , embedding_dim = 128 , gen_dim = (256 , 256 ), dis_dim = (256 , 256 ),
38- l2scale = 1e-6 , batch_size = 500 ):
38+ l2scale = 1e-6 , batch_size = 500 , log_frequency = True ):
3939
4040 self .embedding_dim = embedding_dim
4141 self .gen_dim = gen_dim
4242 self .dis_dim = dis_dim
4343
4444 self .l2scale = l2scale
4545 self .batch_size = batch_size
46+ self .log_frequency = log_frequency
4647 self .device = torch .device ("cuda:0" if torch .cuda .is_available () else "cpu" )
4748 self .trained_epoches = 0
4849
@@ -63,6 +64,9 @@ def _gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1):
6364 but will be differentiated as if it is the soft sample in autograd
6465 dim (int):
6566 a dimension along which softmax will be computed. Default: -1.
67+ log_frequency (boolean):
68+ Whether to use log frequency of categorical levels in conditional
69+ sampling. Defaults to ``True``.
6670
6771 Returns:
6872 Sampled tensor of same shape as logits from the Gumbel-Softmax distribution.
@@ -130,7 +134,7 @@ def _cond_loss(self, data, c, m):
130134
131135 return (loss * m ).sum () / data .size ()[0 ]
132136
133- def fit (self , train_data , discrete_columns = tuple (), epochs = 300 , log_frequency = True ):
137+ def fit (self , train_data , discrete_columns = tuple (), epochs = 300 ):
134138 """Fit the CTGAN Synthesizer models to the training data.
135139
136140 Args:
@@ -144,9 +148,6 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=300, log_frequency=Tr
144148 a ``pandas.DataFrame``, this list should contain the column names.
145149 epochs (int):
146150 Number of training epochs. Defaults to 300.
147- log_frequency (boolean):
148- Whether to use log frequency of categorical levels in conditional
149- sampling. Defaults to ``True``.
150151 """
151152
152153 if not hasattr (self , "transformer" ):
@@ -162,7 +163,7 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=300, log_frequency=Tr
162163 self .cond_generator = ConditionalGenerator (
163164 train_data ,
164165 self .transformer .output_info ,
165- log_frequency
166+ self . log_frequency
166167 )
167168
168169 if not hasattr (self , "generator" ):
0 commit comments