@@ -12,6 +12,8 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+
+import warnings
 from joblib import Parallel, delayed
 
 from ._base import BaseModule, torchensemble_model_doc
@@ -81,6 +83,7 @@ def adddoc(cls):
 def _parallel_fit_per_epoch(train_loader,
                             epsilon,
                             estimator,
+                            cur_lr,
                             optimizer,
                             criterion,
                             idx,
@@ -95,6 +98,10 @@ def _parallel_fit_per_epoch(train_loader,
     out-of-memory error.
     """
 
+    if cur_lr:
+        # Parallelization corrupts the binding between optimizer and scheduler
+        set_module.update_lr(optimizer, cur_lr)
+
     for batch_idx, (data, target) in enumerate(train_loader):
 
         batch_size = data.size()[0]
@@ -246,11 +253,9 @@ def fit(self,
                                                        **self.optimizer_args))
 
         if self.use_scheduler_:
-            schedulers = []
-            for i in range(self.n_estimators):
-                schedulers.append(set_module.set_scheduler(optimizers[i],
-                                                           self.scheduler_name,
-                                                           **self.scheduler_args))  # noqa: E501
+            scheduler_ = set_module.set_scheduler(optimizers[0],
+                                                  self.scheduler_name,
+                                                  **self.scheduler_args)
 
         # Utils
         criterion = nn.CrossEntropyLoss()
@@ -271,6 +276,11 @@ def _forward(estimators, data):
         for epoch in range(epochs):
             self.train()
 
+            if self.use_scheduler_:
+                cur_lr = scheduler_.get_last_lr()[0]
+            else:
+                cur_lr = None
+
             if self.n_jobs and self.n_jobs > 1:
                 msg = "Parallelization on the training epoch: {:03d}"
                 self.logger.info(msg.format(epoch))
@@ -279,6 +289,7 @@ def _forward(estimators, data):
                         train_loader,
                         epsilon,
                         estimator,
+                        cur_lr,
                         optimizer,
                         criterion,
                         idx,
@@ -323,9 +334,14 @@ def _forward(estimators, data):
                     self.logger.info(msg.format(epoch, acc, best_acc))
 
             # Update the scheduler
-            if self.use_scheduler_:
-                for i in range(self.n_estimators):
-                    schedulers[i].step()
+            with warnings.catch_warnings():
+
+                # UserWarning raised by PyTorch is ignored because the
+                # scheduler does not have a real effect on the optimizer.
+                warnings.simplefilter("ignore", UserWarning)
+
+                if self.use_scheduler_:
+                    scheduler_.step()
 
         self.estimators_ = nn.ModuleList()
         self.estimators_.extend(estimators)
@@ -413,11 +429,9 @@ def fit(self,
                                                        **self.optimizer_args))
 
         if self.use_scheduler_:
-            schedulers = []
-            for i in range(self.n_estimators):
-                schedulers.append(set_module.set_scheduler(optimizers[i],
-                                                           self.scheduler_name,
-                                                           **self.scheduler_args))  # noqa: E501
+            scheduler_ = set_module.set_scheduler(optimizers[0],
+                                                  self.scheduler_name,
+                                                  **self.scheduler_args)
 
         # Utils
         criterion = nn.MSELoss()
@@ -437,6 +451,11 @@ def _forward(estimators, data):
         for epoch in range(epochs):
             self.train()
 
+            if self.use_scheduler_:
+                cur_lr = scheduler_.get_last_lr()[0]
+            else:
+                cur_lr = None
+
             if self.n_jobs and self.n_jobs > 1:
                 msg = "Parallelization on the training epoch: {:03d}"
                 self.logger.info(msg.format(epoch))
@@ -445,6 +464,7 @@ def _forward(estimators, data):
                         train_loader,
                         epsilon,
                         estimator,
+                        cur_lr,
                         optimizer,
                         criterion,
                         idx,
@@ -486,9 +506,11 @@ def _forward(estimators, data):
                     self.logger.info(msg.format(epoch, mse, best_mse))
 
             # Update the scheduler
-            if self.use_scheduler_:
-                for i in range(self.n_estimators):
-                    schedulers[i].step()
+            with warnings.catch_warnings():
+                warnings.simplefilter("ignore", UserWarning)
+
+                if self.use_scheduler_:
+                    scheduler_.step()
 
         self.estimators_ = nn.ModuleList()
         self.estimators_.extend(estimators)
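
Note on the new cur_lr argument: each joblib worker re-applies it through set_module.update_lr(optimizer, cur_lr), whose body is not part of this diff. A minimal sketch of what such a helper presumably does, assuming it only overwrites the learning rate of every parameter group (a hypothetical stand-in, not the library's actual implementation):

    import torch

    def update_lr(optimizer: torch.optim.Optimizer, lr: float) -> None:
        # Hypothetical stand-in for set_module.update_lr: push the learning
        # rate tracked by the parent process's scheduler into the pickled
        # optimizer copy that lives inside a joblib worker.
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr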
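Why a single scheduler_ bound to optimizers[0] is enough: joblib pickles each (estimator, optimizer) pair into a separate worker process, so schedulers created in the parent no longer point at the optimizers that are actually stepped. The diff therefore reads get_last_lr()[0] once per epoch, ships that float to every worker, and silences the UserWarning PyTorch raises when a scheduler steps ahead of its never-stepped tracking optimizer. A standalone sketch of the same pattern with a hypothetical toy model and data, not the library code:

    import warnings

    import torch
    import torch.nn as nn
    from joblib import Parallel, delayed

    def fit_one_epoch(model, lr, data, target):
        # Rebuild the optimizer inside the worker and apply the broadcast
        # learning rate; the parent's scheduler is not bound to this copy.
        optimizer = torch.optim.SGD(model.parameters(), lr=lr)
        optimizer.zero_grad()
        loss = nn.MSELoss()(model(data), target)
        loss.backward()
        optimizer.step()
        return model

    models = [nn.Linear(4, 1) for _ in range(3)]
    tracker = torch.optim.SGD(models[0].parameters(), lr=0.1)  # drives the scheduler only
    scheduler = torch.optim.lr_scheduler.StepLR(tracker, step_size=2, gamma=0.5)

    data, target = torch.randn(8, 4), torch.randn(8, 1)
    for epoch in range(4):
        cur_lr = scheduler.get_last_lr()[0]
        models = Parallel(n_jobs=2)(
            delayed(fit_one_epoch)(m, cur_lr, data, target) for m in models
        )
        with warnings.catch_warnings():
            # The tracker optimizer is never stepped, so ignore PyTorch's
            # "scheduler.step() before optimizer.step()" ordering warning.
            warnings.simplefilter("ignore", UserWarning)
            scheduler.step()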