forked from TheoEst/abdominal_registration
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel_loader.py
More file actions
140 lines (105 loc) · 3.61 KB
/
model_loader.py
File metadata and controls
140 lines (105 loc) · 3.61 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
# -*- coding: utf-8 -*-
"""
Created on Fri Jan 3 16:39:16 2020
@author: T_ESTIENNE
"""
import torch
from collections import OrderedDict
# Name of the repository / top-level Python package holding the network
# definitions; used by create_model() to build the dynamic import path.
repo = 'abdominal_registration'
main_path = './' + repo + '/'
networks_path = main_path
# NOTE(review): main_path already ends with '/', so this yields a doubled
# slash ('./abdominal_registration//save/models/'). Harmless on POSIX
# filesystems, but worth confirming before reuse.
load_model_path = main_path + '/save/models/'
# %% Creation of model
def create_model(args, kwargs):
    '''
    Dynamically create a network from its architecture name.

    The architecture name ``args.arch`` is used twice: as the submodule
    name (``<repo>.<arch>``) and as the class name inside that submodule.

    Parameters
    ----------
    args : namespace with an ``arch`` attribute (str) naming the network.
    kwargs : dict of keyword arguments forwarded to the network constructor.

    Returns
    -------
    The instantiated model.
    '''
    import importlib

    model_type = args.arch
    # importlib.import_module returns the submodule itself, unlike the
    # original __import__ call which returned the top-level package and
    # required an extra getattr to reach the submodule.
    network_package = importlib.import_module(repo + '.' + model_type)
    model = getattr(network_package, model_type)(**kwargs)
    print('Create {} model'.format(model_type))
    return model
# %%
def handleDataParallel(checkpoint):
    '''
    Strip the 'module.' prefix that torch.nn.DataParallel adds to every
    key of a saved state dict.

    Parameters
    ----------
    checkpoint : mapping of parameter names to values (a state dict).

    Returns
    -------
    A new OrderedDict with the 'module.' prefix removed from every key,
    or the original checkpoint unchanged if it was not saved with
    DataParallel (or is empty).
    '''
    # Default '' makes an empty checkpoint fall through unchanged instead
    # of raising StopIteration.
    first_key = next(iter(checkpoint), '')
    # Check the full 'module.' prefix: a plain startswith('module') would
    # also (wrongly) match parameters legitimately named 'module_*'.
    if first_key.startswith('module.'):
        print('handleDataParallel')
        new_checkpoint = OrderedDict()
        prefix_len = len('module.')
        for key, value in checkpoint.items():
            new_checkpoint[key[prefix_len:]] = value
        return new_checkpoint
    return checkpoint
def load_model(args, kwargs):
    '''
    Build a model and restore its weights from a checkpoint file.

    The checkpoint may be either a full training checkpoint (a dict with
    'epoch', 'val_loss' and 'state_dict' keys) or a bare state dict.

    Parameters
    ----------
    args : namespace with ``arch`` (consumed by create_model) and
        ``model_abspath`` (path of the checkpoint file on disk).
    kwargs : dict of keyword arguments forwarded to the network constructor.

    Returns
    -------
    model : the network with the checkpoint weights loaded (strict=False,
        so missing / unexpected keys are tolerated silently).
    name : str, the checkpoint file name (last path component).
    '''
    file = args.model_abspath
    name = args.model_abspath.split('/')[-1]

    model = create_model(args, kwargs)
    print("=> loading model '{}'".format(name))

    # map_location keeps tensors on CPU so a GPU-trained checkpoint can
    # be loaded on a CPU-only machine.
    checkpoint = torch.load(file, map_location=lambda storage, loc: storage)

    # Full training checkpoint vs. bare state dict.
    is_full_checkpoint = 'epoch' in checkpoint
    if is_full_checkpoint:
        epoch = checkpoint['epoch']
        best_loss = checkpoint['val_loss']
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint

    state_dict = handleDataParallel(state_dict)
    model.load_state_dict(state_dict, strict=False)

    if is_full_checkpoint:
        # Fixed: the closing parenthesis was missing from this message.
        print("=> loaded model '{}' (epoch {} / val_loss {})"
              .format(name, epoch, best_loss))

    return model, name
def load_pretrained_model(args, kwargs):
    '''
    Build a model and initialise it with the overlapping weights of a
    pretrained checkpoint (transfer-learning style partial loading).

    Only the parameters whose names appear in both the pretrained state
    dict and the freshly created model are copied; every other parameter
    keeps its freshly initialised value.

    Parameters
    ----------
    args : namespace with ``arch`` (consumed by create_model) and
        ``model_abspath`` (path of the checkpoint file on disk).
    kwargs : dict of keyword arguments forwarded to the network constructor.

    Returns
    -------
    model : the partially initialised network.
    name : str, the checkpoint file name (last path component).
    '''
    file = args.model_abspath
    name = args.model_abspath.split('/')[-1]
    # Debug switch: dump parameter names and shapes at each step.
    verbose = True

    model = create_model(args, kwargs)
    print("=> loading model '{}'".format(name))

    # map_location keeps tensors on CPU so a GPU-trained checkpoint can
    # be loaded on a CPU-only machine.
    checkpoint = torch.load(file, map_location=lambda storage, loc: storage)
    epoch = checkpoint['epoch']
    best_loss = checkpoint['val_loss']
    pretrained_dict = handleDataParallel(checkpoint['state_dict'])
    model_dict = model.state_dict()

    if verbose:
        print()
        print('Model')
        for param_tensor in model_dict:
            print(param_tensor, "\t", model_dict[param_tensor].size())
        print()
        print(model)
        print('Pretrained dict')
        for param_tensor in pretrained_dict:
            print(param_tensor, "\t", pretrained_dict[param_tensor].size())
        print()

    # 1. Keep only the keys the target model also has.
    pretrained_dict = {k: v for k,
                       v in pretrained_dict.items() if k in model_dict}
    if verbose:
        print()
        print('Intersection dict')
        for param_tensor in pretrained_dict:
            print(param_tensor, "\t", pretrained_dict[param_tensor].size())

    # 2. Overwrite the matching entries in the fresh state dict.
    model_dict.update(pretrained_dict)
    # 3. Load the merged state dict (strict by design: model_dict always
    #    covers every key of the model).
    model.load_state_dict(model_dict)

    # Fixed: the closing parenthesis was missing from this message.
    print("=> loaded model '{}' (epoch {} / val_loss {})"
          .format(name, epoch, best_loss))

    return model, name