File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -66,11 +66,11 @@ def main():
6666 if args .load :
6767 model = CTGANSynthesizer .load (args .load )
6868 else :
69- generator_dims = [int (x ) for x in args .generator_dims .split (',' )]
70- discriminator_dims = [int (x ) for x in args .discriminator_dims .split (',' )]
69+ generator_dim = [int (x ) for x in args .generator_dim .split (',' )]
70+ discriminator_dim = [int (x ) for x in args .discriminator_dim .split (',' )]
7171 model = CTGANSynthesizer (
72- embedding_dim = args .embedding_dim , generator_dims = generator_dims ,
73- discriminator_dims = discriminator_dims , generator_lr = args .generator_lr ,
72+ embedding_dim = args .embedding_dim , generator_dim = generator_dim ,
73+ discriminator_dim = discriminator_dim , generator_lr = args .generator_lr ,
7474 generator_decay = args .generator_decay , discriminator_lr = args .discriminator_lr ,
7575 discriminator_decay = args .discriminator_decay , batch_size = args .batch_size ,
7676 epochs = args .epochs )
Original file line number Diff line number Diff line change 1313
1414class Discriminator (Module ):
1515
16- def __init__ (self , input_dim , dis_dims , pack = 10 ):
16+ def __init__ (self , input_dim , discriminator_dim , pack = 10 ):
1717 super (Discriminator , self ).__init__ ()
1818 dim = input_dim * pack
1919 self .pack = pack
2020 self .packdim = dim
2121 seq = []
22- for item in list (dis_dims ):
22+ for item in list (discriminator_dim ):
2323 seq += [Linear (dim , item ), LeakyReLU (0.2 ), Dropout (0.5 )]
2424 dim = item
2525
You can’t perform that action at this time.
0 commit comments