6060 Parameters
6161 ----------
6262 train_loader : torch.utils.data.DataLoader
63- A :mod:`torch.utils.data.DataLoader` container that contains the
64- training data.
63+ A data loader that contains the training data.
6564 epochs : int, default=100
6665 The number of training epochs per base estimator.
66+ use_reduction_sum : bool, default=True
67+ Whether to set ``reduction="sum"`` for the internal mean squared
68+ error used to fit each base estimator.
6769 log_interval : int, default=100
6870 The number of batches to wait before logging the training status.
6971 test_loader : torch.utils.data.DataLoader, default=None
70- A :mod:`torch.utils.data.DataLoader` container that contains the
71- evaluating data.
72+ A data loader that contains the evaluating data.
7273
7374 - If ``None``, no validation is conducted after each base
7475 estimator being trained.
8081 adding the base estimator fitted in current iteration, the internal
8182 counter on early stopping will increase by one. When the value of
8283 the internal counter reaches ``early_stopping_rounds``, the
83- training stage will terminate instantly.
84+ training stage will terminate instantly.
8485 save_model : bool, default=True
8586 Specify whether to save the model parameters.
8687
@@ -227,6 +228,7 @@ def fit(
227228 self ,
228229 train_loader ,
229230 epochs = 100 ,
231+ use_reduction_sum = True ,
230232 log_interval = 100 ,
231233 test_loader = None ,
232234 early_stopping_rounds = 2 ,
@@ -243,7 +245,9 @@ def fit(
243245 )
244246
245247 # Utils
246- criterion = nn .MSELoss (reduction = "sum" )
248+ criterion = (
249+ nn .MSELoss (reduction = "sum" ) if use_reduction_sum else nn .MSELoss ()
250+ )
247251 n_counter = 0 # a counter on early stopping
248252
249253 for est_idx , estimator in enumerate (self .estimators_ ):
@@ -397,6 +401,7 @@ def fit(
397401 self ,
398402 train_loader ,
399403 epochs = 100 ,
404+ use_reduction_sum = True ,
400405 log_interval = 100 ,
401406 test_loader = None ,
402407 early_stopping_rounds = 2 ,
@@ -503,6 +508,7 @@ def fit(
503508 self ,
504509 train_loader ,
505510 epochs = 100 ,
511+ use_reduction_sum = True ,
506512 log_interval = 100 ,
507513 test_loader = None ,
508514 early_stopping_rounds = 2 ,
0 commit comments