Skip to content

Commit a84b5d2

Browse files
author
Matt Gadd
authored
[ENH] Add model deserialization method (#41)
* loading an ensemble from checkpoint * some code quality checks * add contribution in CHANGELOG.rst
1 parent d1d9472 commit a84b5d2

2 files changed

Lines changed: 22 additions & 0 deletions

File tree

CHANGELOG.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
Changelog
22
=========
33

4+
[Ver 0.1.*]
5+
-----------
6+
7+
* |MajorFeature| Add methods on model serialization :meth:`load()` for all ensembles | @mttgdd
8+
49
[Beta]
510
------
611

torchensemble/utils/io.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,20 @@ def save(model, save_dir, logger):
2323
torch.save(state, save_dir)
2424

2525
return
26+
27+
28+
def load(model, save_dir="./", logger=None):
29+
"""Implement model deserialization from the specified directory."""
30+
if not os.path.exists(save_dir):
31+
raise FileExistsError("`{}` does not exist".format(save_dir))
32+
33+
# {Ensemble_Method_Name}_{Base_Estimator_Name}_{n_estimators}
34+
filename = "{}_{}_{}_ckpt.pth".format(type(model).__name__,
35+
model.base_estimator_.__name__,
36+
model.n_estimators)
37+
save_dir = os.path.join(save_dir, filename)
38+
39+
if logger:
40+
logger.info("Loading the model from `{}`".format(save_dir))
41+
42+
model.load_state_dict(torch.load(save_dir)["model"])

0 commit comments

Comments
 (0)