Commit b1eb01b

authored
feat(GBM): Add use_reduction_sum parameter for fit (#60)
* configure mse reduction
* Update CHANGELOG.rst
1 parent e540f3d commit b1eb01b

2 files changed: 13 additions & 6 deletions

File tree

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@ Changelog
 Ver 0.1.*
 ---------
 
+* |Enhancement| |API| Add ``use_reduction_sum`` parameter for :meth:`fit` of Gradient Boosting | `@xuyxu <https://github.com/xuyxu>`__
 * |Feature| |API| Improve the functionality of :meth:`evaluate` and :meth:`predict` | `@xuyxu <https://github.com/xuyxu>`__
 * |Feature| |API| Add :class:`FastGeometricClassifier` and :class:`FastGeometricRegressor` | `@xuyxu <https://github.com/xuyxu>`__
 * |Enhancement| Add flexible instantiation of optimizers and schedulers | `@cspsampedro <https://github.com/cspsampedro>`__

torchensemble/gradient_boosting.py

Lines changed: 12 additions & 6 deletions
@@ -60,15 +60,16 @@
     Parameters
     ----------
     train_loader : torch.utils.data.DataLoader
-        A :mod:`torch.utils.data.DataLoader` container that contains the
-        training data.
+        A data loader that contains the training data.
     epochs : int, default=100
         The number of training epochs per base estimator.
+    use_reduction_sum : bool, default=True
+        Whether to set ``reduction="sum"`` for the internal mean squared
+        error used to fit each base estimator.
     log_interval : int, default=100
         The number of batches to wait before logging the training status.
     test_loader : torch.utils.data.DataLoader, default=None
-        A :mod:`torch.utils.data.DataLoader` container that contains the
-        evaluating data.
+        A data loader that contains the evaluating data.
 
         - If ``None``, no validation is conducted after each base
           estimator being trained.
@@ -80,7 +81,7 @@
         adding the base estimator fitted in current iteration, the internal
         counter on early stopping will increase by one. When the value of
         the internal counter reaches ``early_stopping_rounds``, the
-        training stage will terminate instantly.
+        training stage will terminate instantly.
     save_model : bool, default=True
         Specify whether to save the model parameters.
@@ -227,6 +228,7 @@ def fit(
         self,
         train_loader,
         epochs=100,
+        use_reduction_sum=True,
         log_interval=100,
         test_loader=None,
         early_stopping_rounds=2,
@@ -243,7 +245,9 @@ def fit(
         )
 
         # Utils
-        criterion = nn.MSELoss(reduction="sum")
+        criterion = (
+            nn.MSELoss(reduction="sum") if use_reduction_sum else nn.MSELoss()
+        )
         n_counter = 0  # a counter on early stopping
 
         for est_idx, estimator in enumerate(self.estimators_):
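For readers unfamiliar with the two reductions this flag switches between, here is a minimal pure-Python sketch of what ``reduction="sum"`` computes versus the default mean reduction. The helper ``mse`` below is a hypothetical stand-in for ``nn.MSELoss``, which operates on tensors; only the arithmetic is illustrated.

```python
def mse(pred, target, reduction="sum"):
    """Squared-error loss with the two reductions toggled by use_reduction_sum."""
    errors = [(p - t) ** 2 for p, t in zip(pred, target)]
    total = sum(errors)
    # "sum" returns the total squared error; "mean" divides by the batch size,
    # so the loss (and its gradients) no longer scale with the number of samples.
    return total if reduction == "sum" else total / len(errors)

pred = [2.0, 0.0, 1.0]
target = [1.0, 0.0, 3.0]

print(mse(pred, target, reduction="sum"))   # 1 + 0 + 4 = 5.0
print(mse(pred, target, reduction="mean"))  # 5 / 3 ≈ 1.667
```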
@@ -397,6 +401,7 @@ def fit(
         self,
         train_loader,
         epochs=100,
+        use_reduction_sum=True,
         log_interval=100,
         test_loader=None,
         early_stopping_rounds=2,
@@ -503,6 +508,7 @@ def fit(
         self,
         train_loader,
         epochs=100,
+        use_reduction_sum=True,
         log_interval=100,
         test_loader=None,
         early_stopping_rounds=2,
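The criterion-selection logic added above can be exercised on its own; this is a sketch assuming PyTorch is installed, with illustrative tensors and flag value that are not part of the commit.

```python
import torch
import torch.nn as nn

# Mirrors the selection added in fit(): sum reduction when the flag is set,
# otherwise nn.MSELoss's default reduction="mean".
use_reduction_sum = False
criterion = (
    nn.MSELoss(reduction="sum") if use_reduction_sum else nn.MSELoss()
)

loss = criterion(torch.tensor([2.0, 0.0, 1.0]), torch.tensor([1.0, 0.0, 3.0]))
print(loss.item())  # mean of squared errors (1 + 0 + 4) / 3
```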

0 commit comments
