@@ -95,7 +95,7 @@ def _cond_loss(self, data, c, m):
9595
9696 return (loss * m ).sum () / data .size ()[0 ]
9797
98- def fit (self , train_data , discrete_columns = tuple (), epochs = 300 ):
98+ def fit (self , train_data , discrete_columns = tuple (), epochs = 300 , log_frequency = True ):
9999 """Fit the CTGAN Synthesizer models to the training data.
100100
101101 Args:
@@ -109,6 +109,9 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=300):
109109 a ``pandas.DataFrame``, this list should contain the column names.
110110 epochs (int):
111111 Number of training epochs. Defaults to 300.
112+ log_frequency (boolean):
113+ Whether to use log frequency of categorical levels in conditional
114+ sampling. Defaults to ``True``.
112115 """
113116
114117 self .transformer = DataTransformer ()
@@ -118,7 +121,11 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=300):
118121 data_sampler = Sampler (train_data , self .transformer .output_info )
119122
120123 data_dim = self .transformer .output_dimensions
121- self .cond_generator = ConditionalGenerator (train_data , self .transformer .output_info )
124+ self .cond_generator = ConditionalGenerator (
125+ train_data ,
126+ self .transformer .output_info ,
127+ log_frequency
128+ )
122129
123130 self .generator = Generator (
124131 self .embedding_dim + self .cond_generator .n_opt ,
0 commit comments