-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathcreateUnet3d.m
More file actions
122 lines (101 loc) · 4.34 KB
/
Copy pathcreateUnet3d.m
File metadata and controls
122 lines (101 loc) · 4.34 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
function lgraph = createUnet3d(inputSize)
% createUnet3d  Create a 3-D U-Net layer graph.
%
%   lgraph = createUnet3d(inputSize) builds a 3-D U-Net made of:
%     - three encoder stages with [32 64], [64 128] and [128 256] filters,
%       each ending in a stride-2 max pooling;
%     - a bottleneck stage with [256 512] filters;
%     - three decoder stages with [256 256], [128 128] and [64 64] filters.
%   Each decoder stage is fed by a channel-wise concatenation of the
%   upsampled output of the stage below and the last ReLU of the matching
%   encoder stage (the skip connection).
%
%   inputSize - input size forwarded to image3dInputLayer, e.g. [h w d c].
%               Each encoder stage halves the spatial size (2x2x2 max
%               pooling, stride 2, 'same' padding) and each decoder stage
%               doubles it again (stride-2 transposed convolution). The
%               first three elements must therefore be divisible by 8, or
%               the tensors joined by the skip connections differ in size.
%
% Copyright 2018 The MathWorks, Inc.

% Fail early with a clear message. Otherwise the size mismatch only shows
% up later, as a concatenation error when the network is analyzed or trained.
if numel(inputSize) < 3
    error('createUnet3d:invalidInputSize', ...
        'inputSize must contain at least three (spatial) elements.');
end
if any(mod(inputSize(1:3),8) ~= 0)
    error('createUnet3d:invalidInputSize', ...
        'The first three elements of inputSize must be divisible by 8, got %s.', ...
        mat2str(inputSize));
end

inputL = image3dInputLayer(inputSize,'Normalization','none','Name','input');
% Create the contracting path of the 3-D U-Net
encoder_d1 = createUnet3dEncoderModule(1,[32 64 ]);
encoder_d2 = createUnet3dEncoderModule(2,[64 128 ]);
encoder_d3 = createUnet3dEncoderModule(3,[128 256 ]);
% Create the expanding path of the 3-D U-Net.
% decoder_l4 is the bottleneck: it follows encoder_d3 directly in the
% sequential chain below.
decoder_l4 = createUnet3dDecoderModule(4,[256 512]);
decoder_l3 = createUnet3dDecoderModule(3,[256 256]);
decoder_l2 = createUnet3dDecoderModule(2,[128 128]);
decoder_l1 = createUnet3dFinalDecoderModule(1,[64 64]);
% The input, encoders and bottleneck form one sequential chain. The other
% decoders are added as separate (internally chained) branches and wired
% in through the concatenation layers below.
layers = [inputL; encoder_d1; encoder_d2; encoder_d3; decoder_l4];
lgraph = layerGraph(layers);
lgraph = addLayers(lgraph,decoder_l3);
lgraph = addLayers(lgraph,decoder_l2);
lgraph = addLayers(lgraph,decoder_l1);
% Create the skip level connections between encoder and decoder sections.
% Concatenate along dimension 4, the channel dimension of h-by-w-by-d-by-c data.
concat1 = concatenationLayer(4,2,'Name','concat1');
lgraph = addLayers(lgraph,concat1);
concat2 = concatenationLayer(4,2,'Name','concat2');
lgraph = addLayers(lgraph,concat2);
concat3 = concatenationLayer(4,2,'Name','concat3');
lgraph = addLayers(lgraph,concat3);
% Connect the encoder and decoder sections through the concatenation layers.
% encoder_dK(end-1) is the last ReLU before that stage's max pooling, i.e.
% the full-resolution features of level K. decoder_lK(end) is the stride-2
% transposed convolution that upsamples back to level K-1.
lgraph = connectLayers(lgraph,encoder_d1(end-1).Name,[concat1.Name '/' 'in1']);
lgraph = connectLayers(lgraph,decoder_l2(end).Name,[concat1.Name '/' 'in2']);
lgraph = connectLayers(lgraph,encoder_d2(end-1).Name,[concat2.Name '/' 'in1']);
lgraph = connectLayers(lgraph,decoder_l3(end).Name,[concat2.Name '/' 'in2']);
lgraph = connectLayers(lgraph,encoder_d3(end-1).Name,[concat3.Name '/' 'in1']);
lgraph = connectLayers(lgraph,decoder_l4(end).Name,[concat3.Name '/' 'in2']);
% Connect output of concatenationLayer to next decoder section
lgraph = connectLayers(lgraph,[concat3.Name '/' 'out'],decoder_l3(1).Name);
lgraph = connectLayers(lgraph,[concat2.Name '/' 'out'],decoder_l2(1).Name);
lgraph = connectLayers(lgraph,[concat1.Name '/' 'out'],decoder_l1(1).Name);
end
function layers = createUnet3dEncoderModule(ModuleNum,NumFilters)
% createUnet3dEncoderModule  Build one contracting stage of the 3-D U-Net.
%
%   For every entry k of NumFilters the stage holds a 3x3x3 'same'
%   convolution with NumFilters(k) filters ('narrow-normal' weights)
%   followed by a ReLU. Only the first convolution also gets a batch
%   normalization layer, placed between it and its ReLU. The stage ends
%   with a 2x2x2 max pooling of stride 2.
%
%   Layer names: 'en<ModuleNum>_conv<k>', 'en<ModuleNum>_bn1',
%   'en<ModuleNum>_relu<k>', 'en<ModuleNum>_maxpool'.
nConv = length(NumFilters);
layers = [];
for k = 1:nConv
    convL = convolution3dLayer(3,NumFilters(k),'Padding','same', ...
        'WeightsInitializer','narrow-normal', ...
        'Name',iGetName('en','conv',ModuleNum,k));
    reluL = reluLayer('Name',iGetName('en','relu',ModuleNum,k));
    if k ~= 1
        layers = [layers; convL; reluL]; %#ok<AGROW>
        continue
    end
    bnL = batchNormalizationLayer('Name',iGetName('en','bn',ModuleNum,k));
    layers = [layers; convL; bnL; reluL]; %#ok<AGROW>
end
poolL = maxPooling3dLayer(2,'stride',2,'Padding','same', ...
    'Name',iGetName('en','maxpool',ModuleNum));
layers = [layers; poolL];
end
function layers = createUnet3dDecoderModule(ModuleNum,NumFilters)
% createUnet3dDecoderModule  Build one expanding (or bottleneck) stage.
%
%   For every entry k of NumFilters the stage holds a 3x3x3 'same'
%   convolution with NumFilters(k) filters ('narrow-normal' weights)
%   followed by a ReLU. A 2x2x2 transposed convolution with stride 2 and
%   NumFilters(end) filters then upsamples the result.
%
%   Layer names: 'de<ModuleNum>_conv<k>', 'de<ModuleNum>_relu<k>',
%   'de<ModuleNum>_transconv'.
nConv = length(NumFilters);
layers = [];
for k = 1:nConv
    convL = convolution3dLayer(3,NumFilters(k),'Padding','same', ...
        'WeightsInitializer','narrow-normal', ...
        'Name',iGetName('de','conv',ModuleNum,k));
    reluL = reluLayer('Name',iGetName('de','relu',ModuleNum,k));
    layers = [layers; convL; reluL]; %#ok<AGROW>
end
upL = transposedConv3dLayer(2,NumFilters(end),'stride',2, ...
    'Name',iGetName('de','transconv',ModuleNum));
layers = [layers; upL];
end
function layers = createUnet3dFinalDecoderModule(ModuleNum,NumFilters,numLabels)
% createUnet3dFinalDecoderModule  Build the last decoder stage and output head.
%
%   layers = createUnet3dFinalDecoderModule(ModuleNum,NumFilters) emits a
%   3x3x3 'same' convolution + ReLU pair for every entry of NumFilters
%   (named 'de<ModuleNum>_conv<k>' / 'de<ModuleNum>_relu<k>'). These are
%   followed by a 1x1x1 convolution 'convLast' with numLabels filters, a
%   softmax layer 'softmax', and a dicePixelClassification3dLayer named
%   'output'.
%
%   layers = createUnet3dFinalDecoderModule(ModuleNum,NumFilters,numLabels)
%   sets the number of output classes. It defaults to 2, the value that
%   was previously hard-coded, so existing two-argument calls behave as
%   before.
%
%   NOTE(review): unlike the other stages, these convolutions do not set
%   'WeightsInitializer','narrow-normal' and so use the default
%   initializer. Confirm this is intentional.
%   NOTE(review): dicePixelClassification3dLayer is not defined in this
%   file. It is presumably a custom Dice-loss classification layer; verify.
if nargin < 3 || isempty(numLabels)
    numLabels = 2;
end
layers = [];
for id=1:length(NumFilters)
    sublayers = [
        convolution3dLayer(3,NumFilters(id),'Padding','same', ...
        'Name',iGetName('de','conv',ModuleNum,id));
        reluLayer('Name',iGetName('de','relu',ModuleNum,id));
        ];
    layers = [layers; sublayers]; %#ok<AGROW>
end
% 1x1x1 convolution maps the features to one channel per class.
convLast = convolution3dLayer(1,numLabels,'Name','convLast');
softmaxL = softmaxLayer('Name','softmax');
pixelCL = dicePixelClassification3dLayer('output');
layers = [layers; convLast; softmaxL; pixelCL];
end
function myName = iGetName(moduleType,layerType,varargin)
% iGetName  Build a layer name from module type, module number and layer type.
%
%   myName = iGetName(moduleType,layerType,moduleNum) returns
%   [moduleType num2str(moduleNum) '_' layerType],
%   e.g. iGetName('en','maxpool',3) -> 'en3_maxpool'.
%
%   myName = iGetName(moduleType,layerType,moduleNum,layerIdx) also appends
%   num2str(layerIdx), e.g. iGetName('en','conv',3,2) -> 'en3_conv2'.
%
%   Any other number of trailing inputs raises an explicit error. Without
%   it, the output would be left unassigned and MATLAB would only report
%   an unassigned-output error.
switch numel(varargin)
    case 1
        myName = [moduleType num2str(varargin{1}) '_' layerType];
    case 2
        myName = [moduleType num2str(varargin{1}) '_' layerType num2str(varargin{2})];
    otherwise
        error('createUnet3d:invalidNameInputs', ...
            'iGetName expects a module number and an optional layer index, got %d extra input(s).', ...
            numel(varargin));
end
end