1010"""
1111
1212
13- import copy
1413import math
1514import torch
1615import logging
@@ -111,9 +110,6 @@ def __init__(
111110 self .device = torch .device ("cuda" if cuda else "cpu" )
112111 self .logger = logging .getLogger ()
113112
114- # Used to generate snapshots
115- self .dummy_estimator_ = self ._make_estimator ()
116-
117113 self .estimators_ = nn .ModuleList ()
118114
119115 def _validate_parameters (self , lr_clip , epochs , log_interval ):
@@ -255,9 +251,11 @@ def fit(
255251 train_loader , self .is_classification
256252 )
257253
254+ estimator = self ._make_estimator ()
255+
258256 # Set the optimizer and scheduler
259257 optimizer = set_module .set_optimizer (
260- self . dummy_estimator_ , self .optimizer_name , ** self .optimizer_args
258+ estimator , self .optimizer_name , ** self .optimizer_args
261259 )
262260
263261 scheduler = self ._set_scheduler (optimizer , epochs * len (train_loader ))
@@ -269,7 +267,7 @@ def fit(
269267 n_iters_per_estimator = epochs * len (train_loader ) // self .n_estimators
270268
271269 # Training loop
272- self . dummy_estimator_ .train ()
270+ estimator .train ()
273271 for epoch in range (epochs ):
274272 for batch_idx , (data , target ) in enumerate (train_loader ):
275273
@@ -280,7 +278,7 @@ def fit(
280278 optimizer = self ._clip_lr (optimizer , lr_clip )
281279
282280 optimizer .zero_grad ()
283- output = self . dummy_estimator_ (data )
281+ output = estimator (data )
284282 loss = criterion (output , target )
285283 loss .backward ()
286284 optimizer .step ()
@@ -314,7 +312,8 @@ def fit(
314312 if counter % n_iters_per_estimator == 0 :
315313
316314 # Generate and save the snapshot
317- snapshot = copy .deepcopy (self .dummy_estimator_ )
315+ snapshot = self ._make_estimator ()
316+ snapshot .load_state_dict (estimator .state_dict ())
318317 self .estimators_ .append (snapshot )
319318
320319 msg = "Save the snapshot model with index: {}"
@@ -403,9 +402,11 @@ def fit(
403402 train_loader , self .is_classification
404403 )
405404
405+ estimator = self ._make_estimator ()
406+
406407 # Set the optimizer and scheduler
407408 optimizer = set_module .set_optimizer (
408- self . dummy_estimator_ , self .optimizer_name , ** self .optimizer_args
409+ estimator , self .optimizer_name , ** self .optimizer_args
409410 )
410411
411412 scheduler = self ._set_scheduler (optimizer , epochs * len (train_loader ))
@@ -417,7 +418,7 @@ def fit(
417418 n_iters_per_estimator = epochs * len (train_loader ) // self .n_estimators
418419
419420 # Training loop
420- self . dummy_estimator_ .train ()
421+ estimator .train ()
421422 for epoch in range (epochs ):
422423 for batch_idx , (data , target ) in enumerate (train_loader ):
423424
@@ -427,7 +428,7 @@ def fit(
427428 optimizer = self ._clip_lr (optimizer , lr_clip )
428429
429430 optimizer .zero_grad ()
430- output = self . dummy_estimator_ (data )
431+ output = estimator (data )
431432 loss = criterion (output , target )
432433 loss .backward ()
433434 optimizer .step ()
@@ -455,7 +456,8 @@ def fit(
455456
456457 if counter % n_iters_per_estimator == 0 :
457458 # Generate and save the snapshot
458- snapshot = copy .deepcopy (self .dummy_estimator_ )
459+ snapshot = self ._make_estimator ()
460+ snapshot .load_state_dict (estimator .state_dict ())
459461 self .estimators_ .append (snapshot )
460462
461463 msg = "Save the snapshot model with index: {}"
0 commit comments