
Commit cc79216

zzzzwj and xuyxu authored
[ENH] Fix the parallelization issue on logging
* Create the pytest script for the logging module
* Fix warnings raised by flake8
* Formalize the name of my function
* Replace python.logging with the multiprocessing-logging module
* Eliminate warnings produced by flake8
* Fix one small bug
* Remove the debug script
* Replace the multiprocessing logger with print
* Remove the test_logging script, which does not work on Windows
* Write log before parallelization
* Restore the example on classification

Co-authored-by: zzzzwj <zwj@nju.edu.cn>
Co-authored-by: Yi-Xuan Xu <xuyx@lamda.nju.edu.cn>
1 parent 671ee28 commit cc79216
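
The root cause, per the squashed history above, is that `logging` handlers configured in the parent process are generally not available inside joblib's worker processes (torchensemble dispatches workers via `parallel(delayed(...))`), so per-batch records emitted by `_parallel_fit_per_epoch` were lost or had to be shipped back via `msg_list`. The pattern the commit settles on — log once in the parent *before* dispatch, plain `print` inside workers — can be sketched standalone; the `worker` function and logger name here are illustrative, not torchensemble code:

```python
import logging
from joblib import Parallel, delayed  # the backend behind torchensemble's n_jobs

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("demo")  # hypothetical logger name


def worker(idx):
    # Handlers configured in the parent are typically not inherited by
    # joblib's spawned worker processes, so logger calls here can vanish;
    # print() writes to the worker's stdout instead.
    print("Estimator: {:03d} running".format(idx))
    return idx


if __name__ == "__main__":
    n_jobs = 2
    if n_jobs and n_jobs > 1:
        # Log in the parent, before parallelization, where handlers work.
        logger.info("Parallelization on %d jobs", n_jobs)
    rets = Parallel(n_jobs=n_jobs)(delayed(worker)(i) for i in range(4))
    logger.info("Collected results: %s", rets)
```

Worker `print` output still reaches the console through the inherited stdout, at the cost of not appearing in any file handler attached to the logger; that is the tradeoff this commit accepts.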

3 files changed: 42 additions & 28 deletions

torchensemble/adversarial_training.py

Lines changed: 10 additions & 0 deletions
```diff
@@ -270,6 +270,11 @@ def _forward(estimators, data):
             # Training loop
             for epoch in range(epochs):
                 self.train()
+
+                if self.n_jobs and self.n_jobs > 1:
+                    msg = "Parallelization on the training epoch: {:03d}"
+                    self.logger.info(msg.format(epoch))
+
                 rets = parallel(delayed(_parallel_fit_per_epoch)(
                     train_loader,
                     epsilon,
@@ -431,6 +436,11 @@ def _forward(estimators, data):
             # Training loop
             for epoch in range(epochs):
                 self.train()
+
+                if self.n_jobs and self.n_jobs > 1:
+                    msg = "Parallelization on the training epoch: {:03d}"
+                    self.logger.info(msg.format(epoch))
+
                 rets = parallel(delayed(_parallel_fit_per_epoch)(
                     train_loader,
                     epsilon,
```
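
Both hunks guard on `self.n_jobs and self.n_jobs > 1`, which treats `n_jobs=None` (not set) the same as `n_jobs=1`: no parallelization, so no extra log line. A quick standalone check of that truthiness idiom (not library code):

```python
for n_jobs in (None, 1, 2):
    print(n_jobs, "->", bool(n_jobs and n_jobs > 1))
# None -> False, 1 -> False, 2 -> True
```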

torchensemble/bagging.py

Lines changed: 16 additions & 14 deletions
```diff
@@ -38,8 +38,6 @@ def _parallel_fit_per_epoch(train_loader,
     out-of-memory error.
     """
 
-    msg_list = []
-
     for batch_idx, (data, target) in enumerate(train_loader):
 
         batch_size = data.size(0)
@@ -70,14 +68,14 @@ def _parallel_fit_per_epoch(train_loader,
 
                 msg = ("Estimator: {:03d} | Epoch: {:03d} | Batch: {:03d}"
                        " | Loss: {:.5f} | Correct: {:d}/{:d}")
-                msg_list.append(msg.format(idx, epoch, batch_idx, loss,
-                                           correct, subsample_size))
+                print(msg.format(idx, epoch, batch_idx, loss,
+                                 correct, subsample_size))
             else:
                 msg = ("Estimator: {:03d} | Epoch: {:03d} | Batch: {:03d}"
                        " | Loss: {:.5f}")
-                msg_list.append(msg.format(idx, epoch, batch_idx, loss))
+                print(msg.format(idx, epoch, batch_idx, loss))
 
-    return estimator, optimizer, msg_list
+    return estimator, optimizer
 
 
 @torchensemble_model_doc("""Implementation on the BaggingClassifier.""",
@@ -160,6 +158,11 @@ def _forward(estimators, data):
             # Training loop
             for epoch in range(epochs):
                 self.train()
+
+                if self.n_jobs and self.n_jobs > 1:
+                    msg = "Parallelization on the training epoch: {:03d}"
+                    self.logger.info(msg.format(epoch))
+
                 rets = parallel(delayed(_parallel_fit_per_epoch)(
                     train_loader,
                     estimator,
@@ -176,12 +179,9 @@ def _forward(estimators, data):
             )
 
             estimators, optimizers = [], []
-            for estimator, optimizer, msgs in rets:
+            for estimator, optimizer in rets:
                 estimators.append(estimator)
                 optimizers.append(optimizer)
-                # Write logging info
-                for msg in msgs:
-                    self.logger.info(msg)
 
             # Validation
             if test_loader:
@@ -317,6 +317,11 @@ def _forward(estimators, data):
             # Training loop
             for epoch in range(epochs):
                 self.train()
+
+                if self.n_jobs and self.n_jobs > 1:
+                    msg = "Parallelization on the training epoch: {:03d}"
+                    self.logger.info(msg.format(epoch))
+
                 rets = parallel(delayed(_parallel_fit_per_epoch)(
                     train_loader,
                     estimator,
@@ -333,12 +338,9 @@ def _forward(estimators, data):
             )
 
             estimators, optimizers = [], []
-            for estimator, optimizer, msgs in rets:
+            for estimator, optimizer in rets:
                 estimators.append(estimator)
                 optimizers.append(optimizer)
-                # Write logging info
-                for msg in msgs:
-                    self.logger.info(msg)
 
             # Validation
             if test_loader:
```
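
With `msg_list` gone, everything `_parallel_fit_per_epoch` sends back to the parent is the two-tuple `(estimator, optimizer)`; whatever a joblib worker returns is pickled across the process boundary, so a slimmer return value is also a slimmer payload. A minimal sketch of the new calling convention, with string stand-ins for the real objects:

```python
from joblib import Parallel, delayed


def _fit_one(idx):
    # Stand-ins for a trained estimator and its optimizer; the real code
    # returns the actual (picklable) objects.
    return "estimator-{:d}".format(idx), "optimizer-{:d}".format(idx)


rets = Parallel(n_jobs=2)(delayed(_fit_one)(i) for i in range(3))

estimators, optimizers = [], []
for estimator, optimizer in rets:  # two-tuple now, no msg_list
    estimators.append(estimator)
    optimizers.append(optimizer)
```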

torchensemble/voting.py

Lines changed: 16 additions & 14 deletions
```diff
@@ -37,8 +37,6 @@ def _parallel_fit_per_epoch(train_loader,
     out-of-memory error.
     """
 
-    msg_list = []
-
     for batch_idx, (data, target) in enumerate(train_loader):
 
         batch_size = data.size(0)
@@ -60,15 +58,15 @@ def _parallel_fit_per_epoch(train_loader,
 
                 msg = ("Estimator: {:03d} | Epoch: {:03d} | Batch: {:03d}"
                        " | Loss: {:.5f} | Correct: {:d}/{:d}")
-                msg_list.append(msg.format(idx, epoch, batch_idx, loss,
-                                           correct, batch_size))
+                print(msg.format(idx, epoch, batch_idx, loss,
+                                 correct, batch_size))
             # Regression
             else:
                 msg = ("Estimator: {:03d} | Epoch: {:03d} | Batch: {:03d}"
                        " | Loss: {:.5f}")
-                msg_list.append(msg.format(idx, epoch, batch_idx, loss))
+                print(msg.format(idx, epoch, batch_idx, loss))
 
-    return estimator, optimizer, msg_list
+    return estimator, optimizer
 
 
 @torchensemble_model_doc("""Implementation on the VotingClassifier.""",
@@ -151,6 +149,11 @@ def _forward(estimators, data):
             # Training loop
             for epoch in range(epochs):
                 self.train()
+
+                if self.n_jobs and self.n_jobs > 1:
+                    msg = "Parallelization on the training epoch: {:03d}"
+                    self.logger.info(msg.format(epoch))
+
                 rets = parallel(delayed(_parallel_fit_per_epoch)(
                     train_loader,
                     estimator,
@@ -167,12 +170,9 @@ def _forward(estimators, data):
             )
 
             estimators, optimizers = [], []
-            for estimator, optimizer, msgs in rets:
+            for estimator, optimizer in rets:
                 estimators.append(estimator)
                 optimizers.append(optimizer)
-                # Write logging info
-                for msg in msgs:
-                    self.logger.info(msg)
 
             # Validation
             if test_loader:
@@ -309,6 +309,11 @@ def _forward(estimators, data):
             # Training loop
             for epoch in range(epochs):
                 self.train()
+
+                if self.n_jobs and self.n_jobs > 1:
+                    msg = "Parallelization on the training epoch: {:03d}"
+                    self.logger.info(msg.format(epoch))
+
                 rets = parallel(delayed(_parallel_fit_per_epoch)(
                     train_loader,
                     estimator,
@@ -325,12 +330,9 @@ def _forward(estimators, data):
             )
 
             estimators, optimizers = [], []
-            for estimator, optimizer, msgs in rets:
+            for estimator, optimizer in rets:
                 estimators.append(estimator)
                 optimizers.append(optimizer)
-                # Write logging info
-                for msg in msgs:
-                    self.logger.info(msg)
 
             # Validation
             if test_loader:
```
