Skip to content

Commit e540f3d

Browse files
authored
feat(model): add evaluate and refactor predict (#59)
* feat(model): add `evaluate` and refactor `predict`
* doc: update CHANGELOG.rst
* test(predict): add unit tests on `predict`
* feat: add no_grad decorator for `evaluate`
* style: minor fix for mse
* chore: minor improvement on mse
* fix: update example code
* doc: improve docstrings
* test: improve unit tests
1 parent 939db83 commit e540f3d

15 files changed

Lines changed: 372 additions & 309 deletions

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Changelog
44
Ver 0.1.*
55
---------
66

7+
* |Feature| |API| Improve the functionality of :meth:`evaluate` and :meth:`predict` | `@xuyxu <https://github.com/xuyxu>`__
78
* |Feature| |API| Add :class:`FastGeometricClassifier` and :class:`FastGeometricRegressor` | `@xuyxu <https://github.com/xuyxu>`__
89
* |Enhancement| Add flexible instantiation of optimizers and schedulers | `@cspsampedro <https://github.com/cspsampedro>`__
910
* |Feature| |API| Add support on accepting instantiated base estimators as valid input | `@xuyxu <https://github.com/xuyxu>`__

examples/classification_cifar10_cnn.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from torchensemble.bagging import BaggingClassifier
1313
from torchensemble.gradient_boosting import GradientBoostingClassifier
1414
from torchensemble.snapshot_ensemble import SnapshotEnsembleClassifier
15-
from torchensemble.fast_geometric import FastGeometricClassifier
1615

1716
from torchensemble.utils.logging import set_logger
1817

@@ -115,7 +114,7 @@ def forward(self, x):
115114

116115
# Evaluating
117116
tic = time.time()
118-
testing_acc = model.predict(test_loader)
117+
testing_acc = model.evaluate(test_loader)
119118
toc = time.time()
120119
evaluating_time = toc - tic
121120

@@ -131,13 +130,15 @@ def forward(self, x):
131130
# Set the optimizer
132131
model.set_optimizer("Adam", lr=lr, weight_decay=weight_decay)
133132

133+
# Training
134134
tic = time.time()
135135
model.fit(train_loader, epochs=epochs)
136136
toc = time.time()
137137
training_time = toc - tic
138138

139+
# Evaluating
139140
tic = time.time()
140-
testing_acc = model.predict(test_loader)
141+
testing_acc = model.evaluate(test_loader)
141142
toc = time.time()
142143
evaluating_time = toc - tic
143144

@@ -153,13 +154,15 @@ def forward(self, x):
153154
# Set the optimizer
154155
model.set_optimizer("Adam", lr=lr, weight_decay=weight_decay)
155156

157+
# Training
156158
tic = time.time()
157159
model.fit(train_loader, epochs=epochs)
158160
toc = time.time()
159161
training_time = toc - tic
160162

163+
# Evaluating
161164
tic = time.time()
162-
testing_acc = model.predict(test_loader)
165+
testing_acc = model.evaluate(test_loader)
163166
toc = time.time()
164167
evaluating_time = toc - tic
165168

@@ -175,13 +178,15 @@ def forward(self, x):
175178
# Set the optimizer
176179
model.set_optimizer("Adam", lr=lr, weight_decay=weight_decay)
177180

181+
# Training
178182
tic = time.time()
179183
model.fit(train_loader, epochs=epochs)
180184
toc = time.time()
181185
training_time = toc - tic
182186

187+
# Evaluating
183188
tic = time.time()
184-
testing_acc = model.predict(test_loader)
189+
testing_acc = model.evaluate(test_loader)
185190
toc = time.time()
186191
evaluating_time = toc - tic
187192

@@ -202,13 +207,15 @@ def forward(self, x):
202207
# Set the optimizer
203208
model.set_optimizer("Adam", lr=lr, weight_decay=weight_decay)
204209

210+
# Training
205211
tic = time.time()
206212
model.fit(train_loader, epochs=epochs)
207213
toc = time.time()
208214
training_time = toc - tic
209215

216+
# Evaluating
210217
tic = time.time()
211-
testing_acc = model.predict(test_loader)
218+
testing_acc = model.evaluate(test_loader)
212219
toc = time.time()
213220
evaluating_time = toc - tic
214221

examples/fast_geometric_ensemble_cifar10_resnet18.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,5 +168,5 @@ def forward(self, x):
168168
)
169169

170170
# Evaluate
171-
acc = model.predict(test_loader)
171+
acc = model.evaluate(test_loader)
172172
print("Testing Acc: {:.3f}".format(acc))

examples/regression_YearPredictionMSD_mlp.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def forward(self, x):
114114
training_time = toc - tic
115115

116116
tic = time.time()
117-
testing_mse = model.predict(test_loader)
117+
testing_mse = model.evaluate(test_loader)
118118
toc = time.time()
119119
evaluating_time = toc - tic
120120

@@ -136,7 +136,7 @@ def forward(self, x):
136136
training_time = toc - tic
137137

138138
tic = time.time()
139-
testing_mse = model.predict(test_loader)
139+
testing_mse = model.evaluate(test_loader)
140140
toc = time.time()
141141
evaluating_time = toc - tic
142142

@@ -158,7 +158,7 @@ def forward(self, x):
158158
training_time = toc - tic
159159

160160
tic = time.time()
161-
testing_mse = model.predict(test_loader)
161+
testing_mse = model.evaluate(test_loader)
162162
toc = time.time()
163163
evaluating_time = toc - tic
164164

@@ -180,7 +180,7 @@ def forward(self, x):
180180
training_time = toc - tic
181181

182182
tic = time.time()
183-
testing_mse = model.predict(test_loader)
183+
testing_mse = model.evaluate(test_loader)
184184
toc = time.time()
185185
evaluating_time = toc - tic
186186

@@ -207,7 +207,7 @@ def forward(self, x):
207207
training_time = toc - tic
208208

209209
tic = time.time()
210-
testing_acc = model.predict(test_loader)
210+
testing_acc = model.evaluate(test_loader)
211211
toc = time.time()
212212
evaluating_time = toc - tic
213213

torchensemble/_base.py

Lines changed: 86 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,13 @@
33
import torch
44
import logging
55
import warnings
6+
import numpy as np
67
import torch.nn as nn
78

89
from . import _constants as const
910

1011

11-
def torchensemble_model_doc(header, item):
12+
def torchensemble_model_doc(header="", item="model"):
1213
"""
1314
A decorator on obtaining documentation for different methods in the
1415
ensemble. This decorator is modified from `sklearn.py` in XGBoost.
@@ -27,12 +28,13 @@ def get_doc(item):
2728
"model": const.__model_doc,
2829
"seq_model": const.__seq_model_doc,
2930
"fit": const.__fit_doc,
31+
"predict": const.__predict_doc,
3032
"set_optimizer": const.__set_optimizer_doc,
3133
"set_scheduler": const.__set_scheduler_doc,
3234
"classifier_forward": const.__classification_forward_doc,
33-
"classifier_predict": const.__classification_predict_doc,
35+
"classifier_evaluate": const.__classification_evaluate_doc,
3436
"regressor_forward": const.__regression_forward_doc,
35-
"regressor_predict": const.__regression_predict_doc,
37+
"regressor_evaluate": const.__regression_evaluate_doc,
3638
}
3739
return __doc[item]
3840

@@ -45,7 +47,7 @@ def adddoc(cls):
4547
return adddoc
4648

4749

48-
class BaseModule(abc.ABC, nn.Module):
50+
class BaseModule(nn.Module):
4951
"""Base class for all ensembles.
5052
5153
WARNING: This class cannot be used directly.
@@ -160,13 +162,13 @@ def _validate_parameters(self, epochs, log_interval):
160162
@abc.abstractmethod
161163
def set_optimizer(self, optimizer_name, **kwargs):
162164
"""
163-
Implementation on the process of setting the optimizer.
165+
Implementation on setting the parameter optimizer.
164166
"""
165167

166168
@abc.abstractmethod
167169
def set_scheduler(self, scheduler_name, **kwargs):
168170
"""
169-
Implementation on the process of setting the scheduler.
171+
Implementation on setting the learning rate scheduler.
170172
"""
171173

172174
@abc.abstractmethod
@@ -191,8 +193,81 @@ def fit(
191193
Implementation on the training stage of the ensemble.
192194
"""
193195

194-
@abc.abstractmethod
195-
def predict(self, test_loader):
196-
"""
197-
Implementation on the evaluating stage of the ensemble.
198-
"""
196+
@torch.no_grad()
197+
def predict(self, X, return_numpy=True):
198+
"""Docstrings decorated by downstream models."""
199+
self.eval()
200+
pred = None
201+
202+
if isinstance(X, torch.Tensor):
203+
pred = self.forward(X.to(self.device))
204+
elif isinstance(X, np.ndarray):
205+
X = torch.Tensor(X).to(self.device)
206+
pred = self.forward(X)
207+
else:
208+
msg = (
209+
"The type of input X should be one of {{torch.Tensor,"
210+
" np.ndarray}}."
211+
)
212+
raise ValueError(msg)
213+
214+
pred = pred.cpu()
215+
if return_numpy:
216+
return pred.numpy()
217+
218+
return pred
219+
220+
221+
class BaseClassifier(BaseModule):
222+
"""Base class for all ensemble classifiers.
223+
224+
WARNING: This class cannot be used directly.
225+
Please use the derived classes instead.
226+
"""
227+
228+
@torch.no_grad()
229+
def evaluate(self, test_loader, return_loss=False):
230+
"""Docstrings decorated by downstream models."""
231+
self.eval()
232+
correct = 0
233+
total = 0
234+
criterion = nn.CrossEntropyLoss()
235+
loss = 0.0
236+
237+
for _, (data, target) in enumerate(test_loader):
238+
data, target = data.to(self.device), target.to(self.device)
239+
output = self.forward(data)
240+
_, predicted = torch.max(output.data, 1)
241+
correct += (predicted == target).sum().item()
242+
total += target.size(0)
243+
loss += criterion(output, target)
244+
245+
acc = 100 * correct / total
246+
loss /= len(test_loader)
247+
248+
if return_loss:
249+
return acc, float(loss)
250+
251+
return acc
252+
253+
254+
class BaseRegressor(BaseModule):
255+
"""Base class for all ensemble regressors.
256+
257+
WARNING: This class cannot be used directly.
258+
Please use the derived classes instead.
259+
"""
260+
261+
@torch.no_grad()
262+
def evaluate(self, test_loader):
263+
"""Docstrings decorated by downstream models."""
264+
self.eval()
265+
mse = 0.0
266+
criterion = nn.MSELoss()
267+
268+
for _, (data, target) in enumerate(test_loader):
269+
data, target = data.to(self.device), target.to(self.device)
270+
output = self.forward(data)
271+
mse += criterion(output, target)
272+
273+
return float(mse) / len(test_loader)

torchensemble/_constants.py

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,27 @@
119119
"""
120120

121121

122+
__predict_doc = """
123+
Return the predictions of the ensemble given the testing data.
124+
125+
Parameters
126+
----------
127+
X : {Tensor, ndarray}
128+
A data batch in the form of tensor or Numpy array.
129+
return_numpy : bool, default=True
130+
Whether to convert the predictions into a Numpy array.
131+
132+
Returns
133+
-------
134+
pred : Array of shape (n_samples, n_outputs)
135+
For classifiers, ``n_outputs`` is the number of distinct classes. For
136+
regressors, ``n_outputs`` is the number of target variables.
137+
138+
- If ``return_numpy`` is ``False``, the result is a tensor.
139+
- If ``return_numpy`` is ``True``, the result is a Numpy array.
140+
"""
141+
142+
122143
__classification_forward_doc = """
123144
Parameters
124145
----------
@@ -133,17 +154,25 @@
133154
"""
134155

135156

136-
__classification_predict_doc = """
157+
__classification_evaluate_doc = """
158+
Compute the classification accuracy of the ensemble given the testing
159+
dataloader and optionally the average cross-entropy loss.
160+
137161
Parameters
138162
----------
139163
test_loader : torch.utils.data.DataLoader
140-
A :mod:`torch.utils.data.DataLoader` container that contains the
141-
evaluating data.
164+
A data loader that contains the testing data.
165+
return_loss : bool, default=False
166+
Whether to return the average cross-entropy loss over all batches
167+
in the ``test_loader``.
142168
143169
Returns
144170
-------
145171
accuracy : float
146-
The testing accuracy of the fitted ensemble on ``test_loader``.
172+
The classification accuracy of the fitted ensemble on ``test_loader``.
173+
loss : float
174+
The average cross-entropy loss of the fitted ensemble on
175+
``test_loader``, only available when ``return_loss`` is True.
147176
"""
148177

149178

@@ -161,16 +190,18 @@
161190
"""
162191

163192

164-
__regression_predict_doc = """
193+
__regression_evaluate_doc = """
194+
Compute the mean squared error (MSE) of the ensemble given the testing
195+
dataloader.
196+
165197
Parameters
166198
----------
167199
test_loader : torch.utils.data.DataLoader
168-
A :mod:`torch.utils.data.DataLoader` container that contains the
169-
evaluating data.
200+
A data loader that contains the testing data.
170201
171202
Returns
172203
-------
173204
mse : float
174-
The testing mean squared error (MSE) of the fitted ensemble on
205+
The testing mean squared error of the fitted ensemble on
175206
``test_loader``.
176207
"""

0 commit comments

Comments
 (0)