Skip to content

Commit 317e2e5

Browse files
authored
feat: add internal unsqueeze operation in forward of all classifiers (#136)
1 parent 365690a commit 317e2e5

12 files changed

Lines changed: 46 additions & 17 deletions

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ Changelog
1818
Ver 0.1.*
1919
---------
2020

21+
* |Feature| Add internal :meth:`unsqueeze` operation in :meth:`forward` of all classifiers | `@xuyxu <https://github.com/xuyxu>`__
2122
* |Feature| |API| Add ``voting_strategy`` parameter for :class:`VotingClassifer`, :class:`NeuralForestClassifier`, and :class:`SnapshotEnsembleClassifier` | `@LukasGardberg <https://github.com/LukasGardberg>`__
2223
* |Fix| Fix the sampling issue in :class:`BaggingClassifier` and :class:`BaggingRegressor` | `@SunHaozhe <https://github.com/SunHaozhe>`__
2324
* |Feature| |API| Add :class:`NeuralForestClassifier` and :class:`NeuralForestRegressor` | `@xuyxu <https://github.com/xuyxu>`__

torchensemble/_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def get_doc(item):
2929
__doc = {
3030
"model": const.__model_doc,
3131
"seq_model": const.__seq_model_doc,
32-
"tree_ensmeble_model": const.__tree_ensemble_doc,
32+
"tree_ensemble_model": const.__tree_ensemble_doc,
3333
"fit": const.__fit_doc,
3434
"predict": const.__predict_doc,
3535
"set_optimizer": const.__set_optimizer_doc,

torchensemble/adversarial_training.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,8 @@ class AdversarialTrainingClassifier(_BaseAdversarialTraining, BaseClassifier):
226226
def forward(self, *x):
227227
# Take the average over class distributions from all base estimators.
228228
outputs = [
229-
F.softmax(estimator(*x), dim=1) for estimator in self.estimators_
229+
F.softmax(op.unsqueeze_tensor(estimator(*x)), dim=1)
230+
for estimator in self.estimators_
230231
]
231232
proba = op.average(outputs)
232233

torchensemble/bagging.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,8 @@ class BaggingClassifier(BaseClassifier):
9494
def forward(self, *x):
9595
# Average over class distributions from all base estimators.
9696
outputs = [
97-
F.softmax(estimator(*x), dim=1) for estimator in self.estimators_
97+
F.softmax(op.unsqueeze_tensor(estimator(*x)), dim=1)
98+
for estimator in self.estimators_
9899
]
99100
proba = op.average(outputs)
100101

torchensemble/fast_geometric.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ class FastGeometricClassifier(_BaseFastGeometric, BaseClassifier):
176176
"classifier_forward",
177177
)
178178
def forward(self, *x):
179-
proba = self._forward(*x)
179+
proba = op.unsqueeze_tensor(self._forward(*x))
180180

181181
return F.softmax(proba, dim=1)
182182

torchensemble/fusion.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def _forward(self, *x):
3939
"classifier_forward",
4040
)
4141
def forward(self, *x):
42-
output = self._forward(*x)
42+
output = op.unsqueeze_tensor(self._forward(*x))
4343
proba = F.softmax(output, dim=1)
4444

4545
return proba

torchensemble/gradient_boosting.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -420,7 +420,10 @@ def fit(
420420
"classifier_forward",
421421
)
422422
def forward(self, *x):
423-
output = [estimator(*x) for estimator in self.estimators_]
423+
output = [
424+
op.unsqueeze_tensor(estimator(*x))
425+
for estimator in self.estimators_
426+
]
424427
output = op.sum_with_multiplicative(output, self.shrinkage_rate)
425428
proba = F.softmax(output, dim=1)
426429

torchensemble/snapshot_ensemble.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,16 @@ class SnapshotEnsembleClassifier(_BaseSnapshotEnsemble, BaseClassifier):
212212
def __init__(self, voting_strategy="soft", **kwargs):
213213
super().__init__(**kwargs)
214214

215+
implemented_strategies = {"soft", "hard"}
216+
if voting_strategy not in implemented_strategies:
217+
msg = (
218+
"Voting strategy {} is not implemented, "
219+
"please choose from {}."
220+
)
221+
raise ValueError(
222+
msg.format(voting_strategy, implemented_strategies)
223+
)
224+
215225
self.voting_strategy = voting_strategy
216226

217227
@torchensemble_model_doc(
@@ -221,13 +231,13 @@ def __init__(self, voting_strategy="soft", **kwargs):
221231
def forward(self, *x):
222232

223233
outputs = [
224-
F.softmax(estimator(*x), dim=1) for estimator in self.estimators_
234+
F.softmax(op.unsqueeze_tensor(estimator(*x)), dim=1)
235+
for estimator in self.estimators_
225236
]
226237

227238
if self.voting_strategy == "soft":
228239
proba = op.average(outputs)
229-
230-
elif self.voting_strategy == "hard":
240+
else:
231241
proba = op.majority_vote(outputs)
232242

233243
return proba

torchensemble/soft_gradient_boosting.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -406,7 +406,10 @@ def fit(
406406
"classifier_forward",
407407
)
408408
def forward(self, *x):
409-
output = [estimator(*x) for estimator in self.estimators_]
409+
output = [
410+
op.unsqueeze_tensor(estimator(*x))
411+
for estimator in self.estimators_
412+
]
410413
output = op.sum_with_multiplicative(output, self.shrinkage_rate)
411414
proba = F.softmax(output, dim=1)
412415

torchensemble/tests/test_all_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def __init__(self):
4949
self.linear2 = nn.Linear(2, 2)
5050

5151
def forward(self, X):
52-
X = X.view(X.size()[0], -1)
52+
X = X.view(X.size(0), -1)
5353
output = self.linear1(X)
5454
output = self.linear2(output)
5555
return output

0 commit comments

Comments
 (0)