Skip to content

Commit f24e689

Browse files
committed
fix(SE): training loss turns into NaN
1 parent 4081b57 commit f24e689

1 file changed

Lines changed: 14 additions & 12 deletions

File tree

torchensemble/snapshot_ensemble.py

Lines changed: 14 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,6 @@
1010
"""
1111

1212

13-
import copy
1413
import math
1514
import torch
1615
import logging
@@ -111,9 +110,6 @@ def __init__(
111110
self.device = torch.device("cuda" if cuda else "cpu")
112111
self.logger = logging.getLogger()
113112

114-
# Used to generate snapshots
115-
self.dummy_estimator_ = self._make_estimator()
116-
117113
self.estimators_ = nn.ModuleList()
118114

119115
def _validate_parameters(self, lr_clip, epochs, log_interval):
@@ -255,9 +251,11 @@ def fit(
255251
train_loader, self.is_classification
256252
)
257253

254+
estimator = self._make_estimator()
255+
258256
# Set the optimizer and scheduler
259257
optimizer = set_module.set_optimizer(
260-
self.dummy_estimator_, self.optimizer_name, **self.optimizer_args
258+
estimator, self.optimizer_name, **self.optimizer_args
261259
)
262260

263261
scheduler = self._set_scheduler(optimizer, epochs * len(train_loader))
@@ -269,7 +267,7 @@ def fit(
269267
n_iters_per_estimator = epochs * len(train_loader) // self.n_estimators
270268

271269
# Training loop
272-
self.dummy_estimator_.train()
270+
estimator.train()
273271
for epoch in range(epochs):
274272
for batch_idx, (data, target) in enumerate(train_loader):
275273

@@ -280,7 +278,7 @@ def fit(
280278
optimizer = self._clip_lr(optimizer, lr_clip)
281279

282280
optimizer.zero_grad()
283-
output = self.dummy_estimator_(data)
281+
output = estimator(data)
284282
loss = criterion(output, target)
285283
loss.backward()
286284
optimizer.step()
@@ -314,7 +312,8 @@ def fit(
314312
if counter % n_iters_per_estimator == 0:
315313

316314
# Generate and save the snapshot
317-
snapshot = copy.deepcopy(self.dummy_estimator_)
315+
snapshot = self._make_estimator()
316+
snapshot.load_state_dict(estimator.state_dict())
318317
self.estimators_.append(snapshot)
319318

320319
msg = "Save the snapshot model with index: {}"
@@ -403,9 +402,11 @@ def fit(
403402
train_loader, self.is_classification
404403
)
405404

405+
estimator = self._make_estimator()
406+
406407
# Set the optimizer and scheduler
407408
optimizer = set_module.set_optimizer(
408-
self.dummy_estimator_, self.optimizer_name, **self.optimizer_args
409+
estimator, self.optimizer_name, **self.optimizer_args
409410
)
410411

411412
scheduler = self._set_scheduler(optimizer, epochs * len(train_loader))
@@ -417,7 +418,7 @@ def fit(
417418
n_iters_per_estimator = epochs * len(train_loader) // self.n_estimators
418419

419420
# Training loop
420-
self.dummy_estimator_.train()
421+
estimator.train()
421422
for epoch in range(epochs):
422423
for batch_idx, (data, target) in enumerate(train_loader):
423424

@@ -427,7 +428,7 @@ def fit(
427428
optimizer = self._clip_lr(optimizer, lr_clip)
428429

429430
optimizer.zero_grad()
430-
output = self.dummy_estimator_(data)
431+
output = estimator(data)
431432
loss = criterion(output, target)
432433
loss.backward()
433434
optimizer.step()
@@ -455,7 +456,8 @@ def fit(
455456

456457
if counter % n_iters_per_estimator == 0:
457458
# Generate and save the snapshot
458-
snapshot = copy.deepcopy(self.dummy_estimator_)
459+
snapshot = self._make_estimator()
460+
snapshot.load_state_dict(estimator.state_dict())
459461
self.estimators_.append(snapshot)
460462

461463
msg = "Save the snapshot model with index: {}"

0 commit comments

Comments (0)