-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel.py
More file actions
219 lines (200 loc) · 7.3 KB
/
Copy pathmodel.py
File metadata and controls
219 lines (200 loc) · 7.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
"""
Model Architecture for an improved UNet.
@author Tompnyx
@email tompnyx@outlook.com
"""
import tensorflow as tf
from tensorflow.keras.layers import concatenate, Conv2D, BatchNormalization, Dropout, Input, LeakyReLU, MaxPooling2D,\
UpSampling2D
def improved_unet(height, width, channels):
    """
    Build the improved UNet model's architecture. Given a height, width, and number of channels it returns an
    improved UNet model ready for training.

    The encoder has five levels separated by four max-pooling steps; the decoder up-samples four times and
    concatenates each result with the matching encoder output. The outputs of the last three decoder levels are
    summed (after up-sampling and a 1x1 convolution) before the final single-channel sigmoid layer.

    :param height: The height of the image. If fixed (not None) it must be a multiple of 16, otherwise the
        up-sampled feature maps cannot be concatenated with the encoder outputs
    :param width: The width of the image. Same constraint as height
    :param channels: The number of channels the image has
    :return: A keras model of an improved UNet
    :raises ValueError: If a fixed height or width is not a multiple of 16
    """
    # --- Constants ---
    # The number of filters of the first level; doubled at every deeper level
    fil = 16
    # The kernel size to use
    kern = (3, 3)
    # The padding argument used for each convolutional layer
    pad = 'same'
    # The dropout rate used by each Dropout layer
    drop = 0.3
    # The alpha rate used by each LeakyReLU layer
    alp = 0.01
    # The activation type for the convolutional layers
    actv = 'relu'
    # The number of down-sampling (and matching up-sampling) steps
    depth = 4

    # Fail early with a clear message: every pooling step halves the spatial size (rounding down) and every
    # up-sampling step doubles it, so the skip connections only line up when the size divides by 2 ** depth.
    # None (variable-sized input) is passed through untouched.
    for dim_name, dim in (('height', height), ('width', width)):
        if dim is not None and dim % 2 ** depth != 0:
            raise ValueError(
                f"{dim_name} must be a multiple of {2 ** depth} so the up-sampled feature maps match the "
                f"skip connections, got {dim}"
            )

    def norm_act(tensor):
        # Batch normalisation followed by a LeakyReLU
        tensor = BatchNormalization()(tensor)
        return LeakyReLU(alpha=alp)(tensor)

    def encoder_level(tensor, filters):
        # 3x3 convolution whose output is also the residual branch of the element-wise sum below
        conv = Conv2D(filters, kern, padding=pad, activation=actv)(tensor)
        context = norm_act(conv)
        # Context module: two 3x3 convolutions with dropout in between
        context = Conv2D(filters, kern, padding=pad, activation=actv)(context)
        context = Dropout(drop)(context)
        context = Conv2D(filters, kern, padding=pad, activation=actv)(context)
        context = norm_act(context)
        # Element-wise sum
        return conv + context

    def upsample_and_merge(tensor, skip, filters, kernel):
        # Up-sampling, convolution, normalisation and concatenation with the encoder output of the same size
        tensor = UpSampling2D(size=(2, 2))(tensor)
        tensor = Conv2D(filters, kernel, padding=pad, activation=actv)(tensor)
        tensor = norm_act(tensor)
        return concatenate([tensor, skip])

    def localisation(tensor, filters):
        # Localisation module: a 3x3 convolution followed by a 1x1 convolution
        tensor = Conv2D(filters, kern, padding=pad, activation=actv)(tensor)
        tensor = Conv2D(filters, (1, 1), padding=pad, activation=actv)(tensor)
        return norm_act(tensor)

    # Building the model
    inputs = Input((height, width, channels))

    # Contracting path: keep every level's output for the skip connections
    level_outputs = []
    tensor = inputs
    for level in range(depth + 1):
        if level > 0:
            # Down-sampling
            tensor = MaxPooling2D()(tensor)
        tensor = encoder_level(tensor, fil * 2 ** level)
        level_outputs.append(tensor)
    out1, out2, out3, out4, out5 = level_outputs

    # Expansive path
    # NOTE: the first up-convolution uses the 3x3 kernel while the later ones use 2x2. This is kept as-is so
    # that the parameter shapes of the model do not change.
    merged = upsample_and_merge(out5, out4, fil * 8, kern)
    decoded = localisation(merged, fil * 8)

    merged = upsample_and_merge(decoded, out3, fil * 4, (2, 2))
    # Segmentation
    segment7 = localisation(merged, fil * 4)

    merged = upsample_and_merge(segment7, out2, fil * 2, (2, 2))
    # Segmentation
    segment8 = localisation(merged, fil * 2)

    merged = upsample_and_merge(segment8, out1, fil, (2, 2))
    # The last level uses a single 3x3 convolution instead of a localisation module
    conv9 = Conv2D(fil, kern, padding=pad, activation=actv)(merged)
    # Segmentation
    segment9 = norm_act(conv9)

    # Upscale the segmentation layers and add each one to the next finer level
    segment7 = UpSampling2D(size=(2, 2))(segment7)
    segment7 = Conv2D(fil * 2, (1, 1))(segment7)
    segment8 = segment8 + segment7
    segment8 = UpSampling2D(size=(2, 2))(segment8)
    segment8 = Conv2D(fil, (1, 1))(segment8)
    segment9 = segment9 + segment8

    # Single-channel sigmoid output
    outputs = Conv2D(1, (1, 1), activation='sigmoid')(segment9)
    return tf.keras.Model(inputs=[inputs], outputs=[outputs])