import torch
import logging
import warnings
+ import numpy as np
import torch.nn as nn

from . import _constants as const


- def torchensemble_model_doc(header, item):
+ def torchensemble_model_doc(header="", item="model"):
    """
    A decorator on obtaining documentation for different methods in the
    ensemble. This decorator is modified from `sklearn.py` in XGBoost.
@@ -27,12 +28,13 @@ def get_doc(item):
2728 "model" : const .__model_doc ,
2829 "seq_model" : const .__seq_model_doc ,
2930 "fit" : const .__fit_doc ,
31+ "predict" : const .__predict_doc ,
3032 "set_optimizer" : const .__set_optimizer_doc ,
3133 "set_scheduler" : const .__set_scheduler_doc ,
3234 "classifier_forward" : const .__classification_forward_doc ,
33- "classifier_predict " : const .__classification_predict_doc ,
35+ "classifier_evaluate " : const .__classification_evaluate_doc ,
3436 "regressor_forward" : const .__regression_forward_doc ,
35- "regressor_predict " : const .__regression_predict_doc ,
37+ "regressor_evaluate " : const .__regression_evaluate_doc ,
3638 }
3739 return __doc [item ]
3840
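For reference, the decorator is meant to be applied to the concrete ensemble classes so that the shared docstring templates in `_constants` are attached at class-definition time. A minimal usage sketch; the class name and header text below are illustrative and not part of this diff:

@torchensemble_model_doc(
    header="Implementation of the ensemble model.", item="model"
)
class SomeEnsembleClassifier(BaseClassifier):
    # The returned `adddoc` closure presumably joins `header` with the
    # selected template from `_constants` and assigns it to cls.__doc__.
    ...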
@@ -45,7 +47,7 @@ def adddoc(cls):
    return adddoc


- class BaseModule(abc.ABC, nn.Module):
+ class BaseModule(nn.Module):
    """Base class for all ensembles.

    WARNING: This class cannot be used directly.
@@ -160,13 +162,13 @@ def _validate_parameters(self, epochs, log_interval):
    @abc.abstractmethod
    def set_optimizer(self, optimizer_name, **kwargs):
        """
-         Implementation on the process of setting the optimizer.
+         Implementation on setting the parameter optimizer.
        """

    @abc.abstractmethod
    def set_scheduler(self, scheduler_name, **kwargs):
        """
-         Implementation on the process of setting the scheduler.
+         Implementation on setting the learning rate scheduler.
        """

    @abc.abstractmethod
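Both setters are abstract here; a concrete subclass would typically validate the name against `torch.optim` (or `torch.optim.lr_scheduler`) and defer construction until the base estimators exist. A minimal sketch of one plausible `set_optimizer`; none of this is shown in the diff, and the attribute names are hypothetical:

def set_optimizer(self, optimizer_name, **kwargs):
    # Validate against torch.optim and stash the configuration; actual
    # optimizer objects are created later, one per base estimator.
    if not hasattr(torch.optim, optimizer_name):
        raise NotImplementedError(
            "Optimizer {} is not supported.".format(optimizer_name)
        )
    self.optimizer_name = optimizer_name
    self.optimizer_args = kwargs

# Later, inside fit(), each base estimator could get its own optimizer:
#   optimizer = getattr(torch.optim, self.optimizer_name)(
#       estimator.parameters(), **self.optimizer_args
#   )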
@@ -191,8 +193,81 @@ def fit(
        Implementation on the training stage of the ensemble.
        """

-     @abc.abstractmethod
-     def predict(self, test_loader):
-         """
-         Implementation on the evaluating stage of the ensemble.
-         """
+     @torch.no_grad()
+     def predict(self, X, return_numpy=True):
+         """Docstrings decorated by downstream models."""
+         self.eval()
+         pred = None
+
+         if isinstance(X, torch.Tensor):
+             pred = self.forward(X.to(self.device))
+         elif isinstance(X, np.ndarray):
+             X = torch.Tensor(X).to(self.device)
+             pred = self.forward(X)
+         else:
+             msg = (
+                 "The type of input X should be one of {torch.Tensor,"
+                 " np.ndarray}."
+             )
+             raise ValueError(msg)
+
+         pred = pred.cpu()
+         if return_numpy:
+             return pred.numpy()
+
+         return pred
+
+
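The new `predict` replaces the loader-based abstract method with a direct array/tensor API. A usage sketch, assuming `model` is an already-fitted ensemble whose base estimators accept a 2-D float input; the shapes and names are illustrative:

X_test = np.random.rand(8, 10).astype(np.float32)

y_np = model.predict(X_test)                    # np.ndarray in, np.ndarray out
y_t = model.predict(torch.from_numpy(X_test),   # torch.Tensor in,
                    return_numpy=False)         # result kept as a CPU tensor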
+ class BaseClassifier(BaseModule):
+     """Base class for all ensemble classifiers.
+
+     WARNING: This class cannot be used directly.
+     Please use the derived classes instead.
+     """
+
+     @torch.no_grad()
+     def evaluate(self, test_loader, return_loss=False):
+         """Docstrings decorated by downstream models."""
+         self.eval()
+         correct = 0
+         total = 0
+         criterion = nn.CrossEntropyLoss()
+         loss = 0.0
+
+         for _, (data, target) in enumerate(test_loader):
+             data, target = data.to(self.device), target.to(self.device)
+             output = self.forward(data)
+             _, predicted = torch.max(output.data, 1)
+             correct += (predicted == target).sum().item()
+             total += target.size(0)
+             loss += criterion(output, target)
+
+         acc = 100 * correct / total
+         loss /= len(test_loader)
+
+         if return_loss:
+             return acc, float(loss)
+
+         return acc
+
+
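Usage sketch for the classifier-side `evaluate`, assuming `clf` is a fitted classifier ensemble and `test_loader` is a torch.utils.data.DataLoader yielding (data, target) batches; the names are illustrative:

acc = clf.evaluate(test_loader)                             # top-1 accuracy in percent
acc, avg_ce = clf.evaluate(test_loader, return_loss=True)   # also the mean per-batch CrossEntropyLoss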
+ class BaseRegressor(BaseModule):
+     """Base class for all ensemble regressors.
+
+     WARNING: This class cannot be used directly.
+     Please use the derived classes instead.
+     """
+
+     @torch.no_grad()
+     def evaluate(self, test_loader):
+         """Docstrings decorated by downstream models."""
+         self.eval()
+         mse = 0.0
+         criterion = nn.MSELoss()
+
+         for _, (data, target) in enumerate(test_loader):
+             data, target = data.to(self.device), target.to(self.device)
+             output = self.forward(data)
+             mse += criterion(output, target)
+
+         return float(mse) / len(test_loader)
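And the regressor counterpart, which returns the MSELoss averaged over batches; note this equals the dataset-level MSE only when every batch has the same size. Illustrative call, assuming a fitted regressor `reg`:

test_mse = reg.evaluate(test_loader)   # mean of the per-batch nn.MSELoss values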