forked from TheoEst/abdominal_registration
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
131 lines (98 loc) · 3.2 KB
/
utils.py
File metadata and controls
131 lines (98 loc) · 3.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
# -*- coding: utf-8 -*-
"""
Created on Fri Jan 3 16:43:25 2020
@author: T_ESTIENNE
"""
import torch
import numpy as np
import os
import math
# Folder name of the repository (the trailing slash is part of the value).
repo = 'abdominal_registration/'
# Repository root, expressed relative to the current working directory.
main_path = ''.join(('./', repo))
def to_var(args, x):
    """Wrap ``x`` in a ``torch.autograd.Variable``, optionally on the GPU.

    A NumPy array is first converted with ``torch.from_numpy``. The tensor
    is then moved to CUDA device ``args.gpu`` when ``args.cuda`` is truthy,
    or to the default CUDA device when only ``args.data_parallel`` is truthy.

    Args:
        args: object exposing ``cuda``, ``gpu`` and ``data_parallel``.
        x: a NumPy array or a torch tensor.

    Returns:
        The result of ``torch.autograd.Variable`` applied to the tensor.
    """
    tensor = torch.from_numpy(x) if isinstance(x, np.ndarray) else x

    if args.cuda:
        tensor = tensor.cuda(args.gpu)
    elif args.data_parallel:
        tensor = tensor.cuda()

    return torch.autograd.Variable(tensor)
def to_numpy(args, x):
    """Convert a torch tensor to a NumPy array.

    ``None`` and values that are already NumPy arrays are returned
    unchanged. Anything else is treated as a torch tensor: it is detached
    from the autograd graph, brought to host memory and converted.

    Args:
        args: run configuration; kept for interface compatibility.
        x: a torch tensor, a NumPy array or ``None``.

    Returns:
        A NumPy array, or ``None`` when ``x`` is ``None``.
    """
    if x is None or isinstance(x, np.ndarray):
        return x

    # Always go through ``.cpu()`` instead of only doing so when
    # ``args.cuda`` is set: ``to_var`` also places tensors on the GPU when
    # just ``args.data_parallel`` is enabled, and ``.numpy()`` fails on a
    # CUDA tensor. ``.cpu()`` returns the tensor itself when it already
    # lives in host memory, so CPU inputs behave as before.
    return x.detach().cpu().numpy()
def save_checkpoint(args, state, is_best):
    """Write the training state to ``args.model_path``.

    The target directory is created when missing. A periodic checkpoint
    named ``model.<epoch>--<val_loss>.pth.tar`` is written when ``args.save``
    is truthy and the epoch is a multiple of ``args.save_frequency``. When
    ``is_best`` is truthy the state is also written to
    ``model_best.pth.tar``.

    Args:
        args: object exposing ``model_path``, ``save`` and ``save_frequency``.
        state: dict handed to ``torch.save``; must contain ``'epoch'`` and,
            for periodic checkpoints, ``'val_loss'``.
        is_best: whether this state should be stored as the best model.
    """
    save_path = args.model_path
    # ``exist_ok`` removes the check-then-create race of a separate
    # ``isdir`` test (e.g. several processes starting at the same time).
    os.makedirs(save_path, exist_ok=True)

    epoch = state['epoch']

    if args.save and epoch % args.save_frequency == 0:
        filename = 'model.{:02d}--{:.3f}.pth.tar'.format(
            epoch, state['val_loss'])
        torch.save(state, os.path.join(save_path, filename))

    if is_best:
        torch.save(state, os.path.join(save_path, 'model_best.pth.tar'))
class AverageMeter(object):
    """Computes and stores the average and current value."""

    def __init__(self):
        self.val = 0    # last value passed to update()
        self.avg = 0    # weighted mean of all values seen so far
        self.sum = 0    # running sum of val * n
        self.count = 0  # running sum of n

    def update(self, val, n=1):
        """Record ``val`` with weight ``n``; NaN values are ignored.

        Args:
            val: the new measurement.
            n: weight of the measurement (e.g. a batch size).
        """
        if math.isnan(val):
            return

        self.val = val
        self.sum += val * n
        self.count += n
        # A zero total weight (e.g. a first update with n=0) would divide
        # by zero; keep the previous average in that case.
        if self.count:
            self.avg = self.sum / self.count

    def return_string(self):
        """Return ``'<val> (<avg>)\\t'`` with four decimals each."""
        return '{loss.val:.4f} ({loss.avg:.4f})\t'.format(loss=self)
class MultiAverageMeter(object):
    """Keeps one ``AverageMeter`` per metric name."""

    def __init__(self, names=None):
        """Create a meter for every entry of ``names`` (none by default).

        Args:
            names: optional iterable of metric names to pre-register.
        """
        # Copy the names instead of storing the argument: ``update`` appends
        # to this list, so sharing it (or a mutable default) would leak
        # metric names between instances and mutate the caller's list.
        self.names = list(names) if names is not None else []
        self.average_dict = {name: AverageMeter() for name in self.names}

    def get(self, name):
        """Return the ``AverageMeter`` registered under ``name``."""
        return self.average_dict[name]

    def update(self, name, val, n=1):
        """Update the meter ``name`` with ``val``, registering it if new."""
        if name not in self.average_dict:
            self.names.append(name)
            self.average_dict[name] = AverageMeter()

        self.get(name).update(val, n)

    def return_string(self):
        """Return ``'<name> <meter string>\\t'`` for every meter, concatenated."""
        return ''.join(
            str(name) + ' ' + self.get(name).return_string() + '\t'
            for name in self.names
        )

    def update_Logger(self, Logger, epoch):
        """Send every average to ``Logger.log_value`` and return ``Logger``."""
        for name in self.names:
            Logger.log_value(name, self.get(name).avg, epoch)

        return Logger

    def return_all_avg(self):
        """Return a dict mapping each metric name to its current average."""
        return {name: self.average_dict[name].avg for name in self.names}
def print_summary(epoch, i, nb_batch, loss_dict, logging, mode):
    """Log a one-line progress summary for the current batch.

    Args:
        epoch: current epoch.
        i: index of the current batch.
        nb_batch: total number of batches.
        loss_dict: either a ``MultiAverageMeter`` or a mapping from metric
            name to a value that can be formatted with ``'{:.4f}'``.
        logging: object exposing ``info(message)``.
        mode: label placed in front of the line (Train or Test).
    """
    header = '[' + str(mode) + '] Epoch: [{0}][{1}/{2}]\t'.format(
        epoch, i, nb_batch)

    if isinstance(loss_dict, MultiAverageMeter):
        details = loss_dict.return_string()
    else:
        # Pass the name as a format argument rather than splicing it into
        # the template: names containing '{' or '}' would otherwise be
        # parsed as format fields, and non-string names could not be
        # concatenated at all.
        details = ''.join(
            '{} {:.4f} \t'.format(loss_name, loss)
            for loss_name, loss in loss_dict.items()
        )

    logging.info(header + details)