2323from .utils import operator as op
2424
2525
26- __all__ = [
27- "_BaseAdversarialTraining" ,
28- "AdversarialTrainingClassifier" ,
29- "AdversarialTrainingRegressor" ,
30- ]
26+ __all__ = ["AdversarialTrainingClassifier" , "AdversarialTrainingRegressor" ]
3127
3228
3329__fit_doc = """
@@ -218,10 +214,6 @@ def _validate_parameters(self, epochs, epsilon, log_interval):
218214 "model" ,
219215)
220216class AdversarialTrainingClassifier (_BaseAdversarialTraining , BaseClassifier ):
221- def __init__ (self , ** kwargs ):
222- super ().__init__ (** kwargs )
223- self .is_classification = True
224-
225217 @torchensemble_model_doc (
226218 """Implementation on the data forwarding in AdversarialTrainingClassifier.""" , # noqa: E501
227219 "classifier_forward" ,
@@ -240,17 +232,14 @@ def forward(self, x):
240232 "set_optimizer" ,
241233 )
242234 def set_optimizer (self , optimizer_name , ** kwargs ):
243- self .optimizer_name = optimizer_name
244- self .optimizer_args = kwargs
235+ super ().set_optimizer (optimizer_name , ** kwargs )
245236
246237 @torchensemble_model_doc (
247238 """Set the attributes on scheduler for AdversarialTrainingClassifier.""" , # noqa: E501
248239 "set_scheduler" ,
249240 )
250241 def set_scheduler (self , scheduler_name , ** kwargs ):
251- self .scheduler_name = scheduler_name
252- self .scheduler_args = kwargs
253- self .use_scheduler_ = True
242+ super ().set_scheduler (scheduler_name , ** kwargs )
254243
255244 @_adversarial_training_model_doc (
256245 """Implementation on the training stage of AdversarialTrainingClassifier.""" , # noqa: E501
@@ -268,7 +257,7 @@ def fit(
268257 ):
269258
270259 self ._validate_parameters (epochs , epsilon , log_interval )
271- self .n_outputs = self ._decide_n_outputs (train_loader , True )
260+ self .n_outputs = self ._decide_n_outputs (train_loader )
272261
273262 # Instantiate a pool of base estimators, optimizers, and schedulers.
274263 estimators = []
@@ -404,10 +393,6 @@ def predict(self, X, return_numpy=True):
404393 "model" ,
405394)
406395class AdversarialTrainingRegressor (_BaseAdversarialTraining , BaseRegressor ):
407- def __init__ (self , ** kwargs ):
408- super ().__init__ (** kwargs )
409- self .is_classification = False
410-
411396 @torchensemble_model_doc (
412397 """Implementation on the data forwarding in AdversarialTrainingRegressor.""" , # noqa: E501
413398 "regressor_forward" ,
@@ -424,17 +409,14 @@ def forward(self, x):
424409 "set_optimizer" ,
425410 )
426411 def set_optimizer (self , optimizer_name , ** kwargs ):
427- self .optimizer_name = optimizer_name
428- self .optimizer_args = kwargs
412+ super ().set_optimizer (optimizer_name , ** kwargs )
429413
430414 @torchensemble_model_doc (
431415 """Set the attributes on scheduler for AdversarialTrainingRegressor.""" , # noqa: E501
432416 "set_scheduler" ,
433417 )
434418 def set_scheduler (self , scheduler_name , ** kwargs ):
435- self .scheduler_name = scheduler_name
436- self .scheduler_args = kwargs
437- self .use_scheduler_ = True
419+ super ().set_scheduler (scheduler_name , ** kwargs )
438420
439421 @_adversarial_training_model_doc (
440422 """Implementation on the training stage of AdversarialTrainingRegressor.""" , # noqa: E501
@@ -452,7 +434,7 @@ def fit(
452434 ):
453435
454436 self ._validate_parameters (epochs , epsilon , log_interval )
455- self .n_outputs = self ._decide_n_outputs (train_loader , True )
437+ self .n_outputs = self ._decide_n_outputs (train_loader )
456438
457439 # Instantiate a pool of base estimators, optimizers, and schedulers.
458440 estimators = []
0 commit comments