Skip to content

Commit 02703c2

Browse files
authored
feat: add soft gradient boosting (#95)
* update code * update doc * update test * update test * update example * update example
1 parent 3a23e56 commit 02703c2

9 files changed

Lines changed: 576 additions & 1 deletion

CHANGELOG.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ Changelog
1717

1818
Ver 0.1.*
1919
---------
20-
20+
* |Feature| |API| Add :class:`SoftGradientBoostingClassifier` and :class:`SoftGradientBoostingRegressor` | `@xuyxu <https://github.com/xuyxu>`__
2121
* |Feature| |API| Support using dataloader with multiple input | `@xuyxu <https://github.com/xuyxu>`__
2222
* |Fix| Fix missing functionality of ``use_reduction_sum`` for :meth:`fit` of Gradient Boosting | `@xuyxu <https://github.com/xuyxu>`__
2323
* |Enhancement| Relax :mod:`tensorboard` as a soft dependency | `@xuyxu <https://github.com/xuyxu>`__

README.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,8 @@ Supported Ensemble
100100
+------------------------------+------------+---------------------------+
101101
| Fast Geometric Ensemble [6]_ | Sequential | fast_geometric.py |
102102
+------------------------------+------------+---------------------------+
103+
| Soft Gradient Boosting [7]_ | Parallel | soft_gradient_boosting.py |
104+
+------------------------------+------------+---------------------------+
103105

104106
Dependencies
105107
------------
@@ -123,6 +125,8 @@ Reference
123125
124126
.. [6] Garipov, Timur, et al. Loss Surfaces, Mode Connectivity, and Fast Ensembling of DNNs. NeurIPS, 2018.
125127
128+
.. [7] Feng, Ji, et al. Soft Gradient Boosting Machine. ArXiv, 2020.
129+
126130
.. _pytorch: https://pytorch.org/
127131

128132
.. _pypi: https://pypi.org/project/torchensemble/

docs/parameters.rst

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,3 +165,24 @@ FastGeometricRegressor
165165

166166
.. autoclass:: torchensemble.fast_geometric.FastGeometricRegressor
167167
:members:
168+
169+
Soft Gradient Boosting
170+
----------------------
171+
172+
In soft gradient boosting, all base estimators could be simultaneously fitted,
173+
while achieving the similar boosting improvements as in gradient boosting.
174+
175+
Reference:
176+
J. Feng, Y.-X. Xu et al., Soft Gradient Boosting Machine, ArXiv, 2020.
177+
178+
SoftGradientBoostingClassifier
179+
******************************
180+
181+
.. autoclass:: torchensemble.soft_gradient_boosting.SoftGradientBoostingClassifier
182+
:members:
183+
184+
SoftGradientBoostingRegressor
185+
*****************************
186+
187+
.. autoclass:: torchensemble.soft_gradient_boosting.SoftGradientBoostingRegressor
188+
:members:

examples/classification_cifar10_cnn.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from torchensemble.bagging import BaggingClassifier
1313
from torchensemble.gradient_boosting import GradientBoostingClassifier
1414
from torchensemble.snapshot_ensemble import SnapshotEnsembleClassifier
15+
from torchensemble.soft_gradient_boosting import SoftGradientBoostingClassifier
1516

1617
from torchensemble.utils.logging import set_logger
1718

@@ -228,5 +229,34 @@ def forward(self, x):
228229
)
229230
)
230231

232+
# SoftGradientBoostingClassifier
233+
model = SoftGradientBoostingClassifier(
234+
estimator=LeNet5, n_estimators=n_estimators, cuda=True
235+
)
236+
237+
# Set the optimizer
238+
model.set_optimizer("Adam", lr=lr, weight_decay=weight_decay)
239+
240+
# Training
241+
tic = time.time()
242+
model.fit(train_loader, epochs=epochs)
243+
toc = time.time()
244+
training_time = toc - tic
245+
246+
# Evaluating
247+
tic = time.time()
248+
testing_acc = model.evaluate(test_loader)
249+
toc = time.time()
250+
evaluating_time = toc - tic
251+
252+
records.append(
253+
(
254+
"SoftGradientBoostingClassifier",
255+
training_time,
256+
evaluating_time,
257+
testing_acc,
258+
)
259+
)
260+
231261
# Print results on different ensemble methods
232262
display_records(records, logger)

torchensemble/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
from .adversarial_training import AdversarialTrainingRegressor
1313
from .fast_geometric import FastGeometricClassifier
1414
from .fast_geometric import FastGeometricRegressor
15+
from .soft_gradient_boosting import SoftGradientBoostingClassifier
16+
from .soft_gradient_boosting import SoftGradientBoostingRegressor
1517

1618

1719
__all__ = [
@@ -29,4 +31,6 @@
2931
"AdversarialTrainingRegressor",
3032
"FastGeometricClassifier",
3133
"FastGeometricRegressor",
34+
"SoftGradientBoostingClassifier",
35+
"SoftGradientBoostingRegressor",
3236
]

0 commit comments

Comments
 (0)