Skip to content

Commit f53e17a

Browse files
authored
[FIX] Fix the binding during parallelization (#33)
* [FIX] Fix the binding problem on scheduler * flake8 formatting * Update CHANGELOG.rst
1 parent 8ec861f commit f53e17a

6 files changed

Lines changed: 157 additions & 49 deletions

File tree

CHANGELOG.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ Changelog
1616
* |Enhancement| Improve the logging module | @zzzzwj
1717
* |API| Remove the input argument ``output_dim`` from all methods | @xuyxu
1818
* |Fix| Fix the bug in logging module when using multi-processing | @zzzzwj
19-
19+
* |Fix| Fix the binding problem on scheduler and optimizer when using parallelization | @Alex-Medium and @xuyxu
2020

2121
.. role:: raw-html(raw)
2222
:format: html

torchensemble/adversarial_training.py

Lines changed: 38 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
import torch
1313
import torch.nn as nn
1414
import torch.nn.functional as F
15+
16+
import warnings
1517
from joblib import Parallel, delayed
1618

1719
from ._base import BaseModule, torchensemble_model_doc
@@ -81,6 +83,7 @@ def adddoc(cls):
8183
def _parallel_fit_per_epoch(train_loader,
8284
epsilon,
8385
estimator,
86+
cur_lr,
8487
optimizer,
8588
criterion,
8689
idx,
@@ -95,6 +98,10 @@ def _parallel_fit_per_epoch(train_loader,
9598
out-of-memory error.
9699
"""
97100

101+
if cur_lr:
102+
# Parallelization corrupts the binding between optimizer and scheduler
103+
set_module.update_lr(optimizer, cur_lr)
104+
98105
for batch_idx, (data, target) in enumerate(train_loader):
99106

100107
batch_size = data.size()[0]
@@ -246,11 +253,9 @@ def fit(self,
246253
**self.optimizer_args))
247254

248255
if self.use_scheduler_:
249-
schedulers = []
250-
for i in range(self.n_estimators):
251-
schedulers.append(set_module.set_scheduler(optimizers[i],
252-
self.scheduler_name,
253-
**self.scheduler_args)) # noqa: E501
256+
scheduler_ = set_module.set_scheduler(optimizers[0],
257+
self.scheduler_name,
258+
**self.scheduler_args)
254259

255260
# Utils
256261
criterion = nn.CrossEntropyLoss()
@@ -271,6 +276,11 @@ def _forward(estimators, data):
271276
for epoch in range(epochs):
272277
self.train()
273278

279+
if self.use_scheduler_:
280+
cur_lr = scheduler_.get_last_lr()[0]
281+
else:
282+
cur_lr = None
283+
274284
if self.n_jobs and self.n_jobs > 1:
275285
msg = "Parallelization on the training epoch: {:03d}"
276286
self.logger.info(msg.format(epoch))
@@ -279,6 +289,7 @@ def _forward(estimators, data):
279289
train_loader,
280290
epsilon,
281291
estimator,
292+
cur_lr,
282293
optimizer,
283294
criterion,
284295
idx,
@@ -323,9 +334,14 @@ def _forward(estimators, data):
323334
self.logger.info(msg.format(epoch, acc, best_acc))
324335

325336
# Update the scheduler
326-
if self.use_scheduler_:
327-
for i in range(self.n_estimators):
328-
schedulers[i].step()
337+
with warnings.catch_warnings():
338+
339+
# UserWarning raised by PyTorch is ignored because
340+
# scheduler does not have a real effect on the optimizer.
341+
warnings.simplefilter("ignore", UserWarning)
342+
343+
if self.use_scheduler_:
344+
scheduler_.step()
329345

330346
self.estimators_ = nn.ModuleList()
331347
self.estimators_.extend(estimators)
@@ -413,11 +429,9 @@ def fit(self,
413429
**self.optimizer_args))
414430

415431
if self.use_scheduler_:
416-
schedulers = []
417-
for i in range(self.n_estimators):
418-
schedulers.append(set_module.set_scheduler(optimizers[i],
419-
self.scheduler_name,
420-
**self.scheduler_args)) # noqa: E501
432+
scheduler_ = set_module.set_scheduler(optimizers[0],
433+
self.scheduler_name,
434+
**self.scheduler_args)
421435

422436
# Utils
423437
criterion = nn.MSELoss()
@@ -437,6 +451,11 @@ def _forward(estimators, data):
437451
for epoch in range(epochs):
438452
self.train()
439453

454+
if self.use_scheduler_:
455+
cur_lr = scheduler_.get_last_lr()[0]
456+
else:
457+
cur_lr = None
458+
440459
if self.n_jobs and self.n_jobs > 1:
441460
msg = "Parallelization on the training epoch: {:03d}"
442461
self.logger.info(msg.format(epoch))
@@ -445,6 +464,7 @@ def _forward(estimators, data):
445464
train_loader,
446465
epsilon,
447466
estimator,
467+
cur_lr,
448468
optimizer,
449469
criterion,
450470
idx,
@@ -486,9 +506,11 @@ def _forward(estimators, data):
486506
self.logger.info(msg.format(epoch, mse, best_mse))
487507

488508
# Update the scheduler
489-
if self.use_scheduler_:
490-
for i in range(self.n_estimators):
491-
schedulers[i].step()
509+
with warnings.catch_warnings():
510+
warnings.simplefilter("ignore", UserWarning)
511+
512+
if self.use_scheduler_:
513+
scheduler_.step()
492514

493515
self.estimators_ = nn.ModuleList()
494516
self.estimators_.extend(estimators)

torchensemble/bagging.py

Lines changed: 37 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import torch.nn as nn
1111
import torch.nn.functional as F
1212

13+
import warnings
1314
from joblib import Parallel, delayed
1415

1516
from ._base import BaseModule, torchensemble_model_doc
@@ -24,6 +25,7 @@
2425

2526
def _parallel_fit_per_epoch(train_loader,
2627
estimator,
28+
cur_lr,
2729
optimizer,
2830
criterion,
2931
idx,
@@ -38,6 +40,10 @@ def _parallel_fit_per_epoch(train_loader,
3840
out-of-memory error.
3941
"""
4042

43+
if cur_lr:
44+
# Parallelization corrupts the binding between optimizer and scheduler
45+
set_module.update_lr(optimizer, cur_lr)
46+
4147
for batch_idx, (data, target) in enumerate(train_loader):
4248

4349
batch_size = data.size(0)
@@ -134,11 +140,9 @@ def fit(self,
134140
**self.optimizer_args))
135141

136142
if self.use_scheduler_:
137-
schedulers = []
138-
for i in range(self.n_estimators):
139-
schedulers.append(set_module.set_scheduler(optimizers[i],
140-
self.scheduler_name,
141-
**self.scheduler_args)) # noqa: E501
143+
scheduler_ = set_module.set_scheduler(optimizers[0],
144+
self.scheduler_name,
145+
**self.scheduler_args)
142146

143147
# Utils
144148
criterion = nn.CrossEntropyLoss()
@@ -159,13 +163,19 @@ def _forward(estimators, data):
159163
for epoch in range(epochs):
160164
self.train()
161165

166+
if self.use_scheduler_:
167+
cur_lr = scheduler_.get_last_lr()[0]
168+
else:
169+
cur_lr = None
170+
162171
if self.n_jobs and self.n_jobs > 1:
163172
msg = "Parallelization on the training epoch: {:03d}"
164173
self.logger.info(msg.format(epoch))
165174

166175
rets = parallel(delayed(_parallel_fit_per_epoch)(
167176
train_loader,
168177
estimator,
178+
cur_lr,
169179
optimizer,
170180
criterion,
171181
idx,
@@ -210,9 +220,14 @@ def _forward(estimators, data):
210220
self.logger.info(msg.format(epoch, acc, best_acc))
211221

212222
# Update the scheduler
213-
if self.use_scheduler_:
214-
for i in range(self.n_estimators):
215-
schedulers[i].step()
223+
with warnings.catch_warnings():
224+
225+
# UserWarning raised by PyTorch is ignored because
226+
# scheduler does not have a real effect on the optimizer.
227+
warnings.simplefilter("ignore", UserWarning)
228+
229+
if self.use_scheduler_:
230+
scheduler_.step()
216231

217232
self.estimators_ = nn.ModuleList()
218233
self.estimators_.extend(estimators)
@@ -294,11 +309,9 @@ def fit(self,
294309
**self.optimizer_args))
295310

296311
if self.use_scheduler_:
297-
schedulers = []
298-
for i in range(self.n_estimators):
299-
schedulers.append(set_module.set_scheduler(optimizers[i],
300-
self.scheduler_name,
301-
**self.scheduler_args)) # noqa: E501
312+
scheduler_ = set_module.set_scheduler(optimizers[0],
313+
self.scheduler_name,
314+
**self.scheduler_args)
302315

303316
# Utils
304317
criterion = nn.MSELoss()
@@ -318,13 +331,19 @@ def _forward(estimators, data):
318331
for epoch in range(epochs):
319332
self.train()
320333

334+
if self.use_scheduler_:
335+
cur_lr = scheduler_.get_last_lr()[0]
336+
else:
337+
cur_lr = None
338+
321339
if self.n_jobs and self.n_jobs > 1:
322340
msg = "Parallelization on the training epoch: {:03d}"
323341
self.logger.info(msg.format(epoch))
324342

325343
rets = parallel(delayed(_parallel_fit_per_epoch)(
326344
train_loader,
327345
estimator,
346+
cur_lr,
328347
optimizer,
329348
criterion,
330349
idx,
@@ -366,9 +385,11 @@ def _forward(estimators, data):
366385
self.logger.info(msg.format(epoch, mse, best_mse))
367386

368387
# Update the scheduler
369-
if self.use_scheduler_:
370-
for i in range(self.n_estimators):
371-
schedulers[i].step()
388+
with warnings.catch_warnings():
389+
warnings.simplefilter("ignore", UserWarning)
390+
391+
if self.use_scheduler_:
392+
scheduler_.step()
372393

373394
self.estimators_ = nn.ModuleList()
374395
self.estimators_.extend(estimators)

torchensemble/tests/test_set_optimizer.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,29 @@ def test_set_optimizer_Unknown():
4242
with pytest.raises(NotImplementedError) as excinfo:
4343
torchensemble.utils.set_module.set_optimizer(model, "Unknown")
4444
assert "Unknown name of the optimizer" in str(excinfo.value)
45+
46+
47+
def test_update_lr():
48+
cur_lr = 1e-4
49+
model = MLP()
50+
optimizer = torchensemble.utils.set_module.set_optimizer(model,
51+
"Adam",
52+
lr=1e-3)
53+
54+
optimizer = torchensemble.utils.set_module.update_lr(optimizer, cur_lr)
55+
56+
for group in optimizer.param_groups:
57+
assert group["lr"] == cur_lr
58+
59+
60+
def test_update_lr_invalid():
61+
cur_lr = 0
62+
model = MLP()
63+
optimizer = torchensemble.utils.set_module.set_optimizer(model,
64+
"Adam",
65+
lr=1e-3)
66+
67+
err_msg = ("The learning rate should be strictly positive, but got"
68+
" {} instead.").format(cur_lr)
69+
with pytest.raises(ValueError, match=err_msg):
70+
torchensemble.utils.set_module.update_lr(optimizer, cur_lr)

torchensemble/utils/set_module.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,24 @@ def set_optimizer(model, optimizer_name, **kwargs):
3636
return optimizer
3737

3838

39+
def update_lr(optimizer, lr):
40+
"""
41+
Manually update the learning rate of the optimizer. This function is used
42+
when the parallelization corrupts the bindings between the optimizer and
43+
the scheduler.
44+
"""
45+
46+
if not lr > 0:
47+
msg = ("The learning rate should be strictly positive, but got"
48+
" {} instead.")
49+
raise ValueError(msg.format(lr))
50+
51+
for group in optimizer.param_groups:
52+
group["lr"] = lr
53+
54+
return optimizer
55+
56+
3957
def set_scheduler(optimizer, scheduler_name, **kwargs):
4058
"""
4159
Set the scheduler on learning rate for the optimizer.

0 commit comments

Comments
 (0)