77import org .deeplearning4j .datasets .fetchers .DataSetType ;
88import org .deeplearning4j .datasets .iterator .impl .Cifar10DataSetIterator ;
99import org .deeplearning4j .datasets .iterator .impl .MnistDataSetIterator ;
10+ import org .deeplearning4j .examples .samediff .training .SameDiffCustomListenerExample ;
1011import org .deeplearning4j .examples .samediff .training .SameDiffMNISTTrainingExample ;
1112import org .nd4j .autodiff .listeners .At ;
1213import org .nd4j .autodiff .listeners .BaseListener ;
4243/**
4344 * This is an example of doing transfer learning by importing a tensorflow model of mobilenet and replacing the last layer.
4445 *
45- * It turns the original imagenet model into a model for CIFAR 10.
46+ * It turns the original ImageNet model into a model for CIFAR 10.
4647 *
4748 * See {@link SameDiffTFImportMobileNetExample} for the model import example.
4849 * See {@link SameDiffMNISTTrainingExample} for the SameDiff training example.
50+ * See {@link SameDiffCustomListenerExample} for an example of how to use custom listeners (we use one here to find the shapes of an activation).
4951 *
5052 */
5153public class SameDiffTransferLearningExample {
5254
53- // Used to figure out the shapes of variables, needed to figure out how many channels are going into our added Conv layer
54- static class ShapeListener extends BaseListener {
55-
56- @ Override
57- public boolean isActive (Operation operation ) {
58- return true ;
59- }
60-
61- @ Override
62- public void activationAvailable (SameDiff sd , At at ,
63- MultiDataSet batch , SameDiffOp op ,
64- String varName , INDArray activation ) {
65- System .out .println (varName + ": \t \t \t " + Arrays .toString (activation .shape ()));
66-
67- if (varName .endsWith ("Shape" )){
68- System .out .println ("Shape value: " + activation );
69- }
70-
71- }
72- }
73-
74- /**
75- * Does inception preprocessing on a batch of images. Takes an image with shape [batchSize, c, h, w]
76- * and returns an image with shape [batchSize, height, width, c].
77- *
78- * @param height the height to resize to
79- * @param width the width to resize to
80- */
81- public static INDArray batchInceptionPreprocessing (INDArray img , int height , int width ){
82- // change to channels-last
83- img = img .permute (0 , 2 , 3 , 1 );
84-
85- // normalize to 0-1
86- img = img .div (256 );
87-
88- // resize
89- INDArray preprocessedImage = Nd4j .createUninitialized (img .size (0 ), height , width , img .size (3 ));
90-
91- DynamicCustomOp op = DynamicCustomOp .builder ("resize_bilinear" )
92- .addInputs (img )
93- .addOutputs (preprocessedImage )
94- .addIntegerArguments (height , width ).build ();
95- Nd4j .exec (op );
96-
97- // finish preprocessing
98- preprocessedImage = preprocessedImage .sub (0.5 );
99- preprocessedImage = preprocessedImage .mul (2 );
100- return preprocessedImage ;
101- }
102-
10355 public static void main (String [] args ) throws Exception {
10456 File modelFile = SameDiffTFImportMobileNetExample .downloadModel ();
10557
@@ -112,7 +64,14 @@ public static void main(String[] args) throws Exception {
11264
11365 System .out .println ("\n \n " );
11466
115- // Print shapes for each activation
67+
68+ // We want to replace the last convolution layer and the output layer with our own ops, so we can fine tune the network
69+ // These are the MobilenetV2/Logits and MobilenetV2/Predictions sections, respectively. See the printed summary.
70+
71+
72+ // Print shapes for each activation.
73+ // We need to know the shape (especially the channels) of the convolution op's input, so we know what shape to make the weight.
74+ // We use a custom listener for this, see SameDiffCustomListenerExample
11675
11776// INDArray test = new Cifar10DataSetIterator(10).next().getFeatures();
11877// test = batchInceptionPreprocessing(test, 224, 224);
@@ -123,23 +82,55 @@ public static void main(String[] args) throws Exception {
12382// .listeners(new ShapeListener())
12483// .execSingle();
12584
126- // get info for the last convolution layer (MobilenetV2/Logits)
85+ // get info for the last convolution layer (MobilenetV2/Logits). We want to use an equivalent config.
12786 Conv2D convOp = (Conv2D ) sd .getOpById ("MobilenetV2/Logits/Conv2d_1c_1x1/Conv2D" );
12887 System .out .println ("Conv config: " + convOp .getConfig ());
12988
130- // replace last convolution layer (MobilenetV2/Logits)
131- sd = GraphTransformUtil .replaceSubgraphsMatching (sd ,
89+ /*
90+ The MobilenetV2/Logits section looks like:
91+ MobilenetV2/Logits/AvgPool
92+ MobilenetV2/Logits/Conv2d_1c_1x1/Conv2D
93+ MobilenetV2/Logits/Conv2d_1c_1x1/BiasAdd
94+ MobilenetV2/Logits/Squeeze
95+
96+ We want to replace the convolution layer (Conv2D and BiasAdd) with our own, so we can fine tune it.
97+
98+
99+ The SubGraphPredicate will select a subset of the graph by starting at the root node,
100+ and then optionally applying SubGraphPredicate's for inputs.
101+ Those SubGraphPredicate's can also add their inputs, etc.
102+
103+ The predicate will only accept a subgraph if it passes all the filters.
104+ */
105+
106+ // Create a predicate for selecting the BiasAdd and Conv2D ops we want
107+ SubGraphPredicate pred1 =
108+ // Select the subgraph with root MobilenetV2/Logits/Conv2d_1c_1x1/BiasAdd
132109 SubGraphPredicate .withRoot (OpPredicate .nameMatches ("MobilenetV2/Logits/Conv2d_1c_1x1/BiasAdd" ))
133- .withInputSubgraph (0 , OpPredicate .nameMatches ("MobilenetV2/Logits/Conv2d_1c_1x1/Conv2D" )),
134- (sd1 , subGraph ) -> {
110+ // Select (and require) the BiasAdd's 0th input to be MobilenetV2/Logits/Conv2d_1c_1x1/Conv2D
111+ .withInputSubgraph (0 , OpPredicate .nameMatches ("MobilenetV2/Logits/Conv2d_1c_1x1/Conv2D" ));
112+
113+
114+ /*
115+ Replace any subgraphs matching the predicate with our own subgraph
116+ There will only be one match, but you can use SubGraphPredicate and GraphTransformUtil to replace many occurrences of the same subgraph.
135117
118+ The number of outputs from the replacement subgraph must match the number of outputs of the subgraph it is replacing.
119+
120+ Note that the graph isn't actually modified, a copy is made, modified, and then returned.
121+ */
122+ sd = GraphTransformUtil .replaceSubgraphsMatching (sd ,
123+ pred1 ,
124+ (sd1 , subGraph ) -> {
136125 NameScope logits = sd1 .withNameScope ("Logits/Conv2D" );
137126
138127 // get the output of the AveragePooling op
139128 SDVariable input = subGraph .inputs ().get (1 );
140129
141130 // we know the sizes from using the ShapeListener earlier
142131
132+ // We know what shape the weight needs to be from the input's channels and the config's kernel height and width.
133+ // This is why we printed the shapes.
143134 SDVariable w = sd1 .var ("W" , new XavierInitScheme ('c' , 5 * 5 * 8 , 10 ), DataType .FLOAT ,
144135 1 , 1 , 1280 , 10 );
145136
@@ -155,17 +146,38 @@ public static void main(String[] args) throws Exception {
155146 return Collections .singletonList (output );
156147 });
157148
158- // create SubGraphPredicate for selecting the MobilenetV2/Predictions ops
159- SubGraphPredicate graphPred = SubGraphPredicate .withRoot (OpPredicate .nameEquals ("MobilenetV2/Predictions/Reshape_1" ))
160- .withInputSubgraph (0 , SubGraphPredicate .withRoot (OpPredicate .nameEquals ("MobilenetV2/Predictions/Softmax" ))
161- .withInputSubgraph (0 , SubGraphPredicate .withRoot (OpPredicate .nameEquals ("MobilenetV2/Predictions/Reshape" ))))
162- .withInputSubgraph (1 , SubGraphPredicate .withRoot (OpPredicate .nameEquals ("MobilenetV2/Predictions/Shape" )));
149+ /*
150+ The MobilenetV2/Predictions section looks like:
151+ MobilenetV2/Predictions/Reshape/shape
152+ MobilenetV2/Predictions/Reshape
153+ MobilenetV2/Predictions/Softmax
154+ MobilenetV2/Predictions/Shape
155+ MobilenetV2/Predictions/Reshape_1
163156
164- // replace the MobilenetV2/Predictions with our own softmax and loss
157+ We want to replace the reshapes (unneeded and the wrong shape) and the softmax (we need a loss function and an output function).
158+ You could keep the softmax, but there is no reason to.
159+
160+ We also need to add a labels input.
161+
162+ Note that this subgraph has no outputs, so neither should the replacement subgraph.
163+ */
164+
165+ // create SubGraphPredicate for selecting the MobilenetV2/Predictions ops
166+ SubGraphPredicate pred2 =
167+ // Select a subgraph starting with the Reshape_1 op
168+ SubGraphPredicate .withRoot (OpPredicate .nameEquals ("MobilenetV2/Predictions/Reshape_1" ))
169+ // Add the 0th input to the subgraph if it is the specified Softmax Op
170+ .withInputSubgraph (0 , SubGraphPredicate .withRoot (OpPredicate .nameEquals ("MobilenetV2/Predictions/Softmax" ))
171+ // Add the 0th input of the Softmax op to the subgraph, as long as it is the specified Reshape op
172+ .withInputSubgraph (0 , SubGraphPredicate .withRoot (OpPredicate .nameEquals ("MobilenetV2/Predictions/Reshape" ))))
173+ // Add the 1st input to the subgraph if it is the specified Shape Op
174+ .withInputSubgraph (1 , SubGraphPredicate .withRoot (OpPredicate .nameEquals ("MobilenetV2/Predictions/Shape" )));
175+
176+ // Replace any subgraphs matching the predicate with our own subgraph
177+ // There will only be one match, but you can use SubGraphPredicate and GraphTransformUtil to replace many occurrences of the same subgraph
165178 sd = GraphTransformUtil .replaceSubgraphsMatching (sd ,
166- graphPred ,
179+ pred2 ,
167180 (sd1 , subGraph ) -> {
168-
169181 // placeholder for labels (needed for training)
170182 SDVariable labels = sd1 .placeHolder ("label" , DataType .FLOAT , -1 , 10 );
171183
@@ -186,8 +198,8 @@ public static void main(String[] args) throws Exception {
186198 });
187199
188200
189- // replace the input with input and inception preprocessing (except for resizing, which is done as part of the record reader)
190- // can 't do this with GraphTransformUtil as it can't replace variables or re-use ops
201+ // Add inception preprocessing to the input (except for resizing, which is done as part of the record reader)
202+ // Can 't do this with GraphTransformUtil as it can't replace variables or re-use ops
191203
192204 SDVariable input = sd .getVariable ("input" );
193205
@@ -200,6 +212,7 @@ public static void main(String[] args) throws Exception {
200212 // change range to -1 - 1
201213 SDVariable processed = normalized .sub (0.5 ).mul (2 );
202214
215+ // The 0th arg was input, replace it with the preprocessed input
203216 sd .getOpById ("MobilenetV2/Conv/Conv2D" ).replaceArg (0 , processed );
204217
205218
@@ -254,4 +267,58 @@ public static void main(String[] args) throws Exception {
254267
255268 System .out .println ("Accuracy: " + acc );
256269 }
270+
271+ /**
272+ * Used to figure out the shapes of variables, needed to figure out how many channels are going into our added Conv layer
273+ *
274+ * See {@link SameDiffCustomListenerExample}
275+ */
276+ static class ShapeListener extends BaseListener {
277+
278+ @ Override
279+ public boolean isActive (Operation operation ) {
280+ return true ;
281+ }
282+
283+ @ Override
284+ public void activationAvailable (SameDiff sd , At at ,
285+ MultiDataSet batch , SameDiffOp op ,
286+ String varName , INDArray activation ) {
287+ System .out .println (varName + ": \t \t \t " + Arrays .toString (activation .shape ()));
288+
289+ if (varName .endsWith ("Shape" )){
290+ System .out .println ("Shape value: " + activation );
291+ }
292+
293+ }
294+ }
295+
296+ /**
297+ * Does inception preprocessing on a batch of images. Takes an image with shape [batchSize, c, h, w]
298+ * and returns an image with shape [batchSize, height, width, c].
299+ *
300+ * @param height the height to resize to
301+ * @param width the width to resize to
302+ */
303+ public static INDArray batchInceptionPreprocessing (INDArray img , int height , int width ){
304+ // change to channels-last
305+ img = img .permute (0 , 2 , 3 , 1 );
306+
307+ // normalize to 0-1
308+ img = img .div (256 );
309+
310+ // resize
311+ INDArray preprocessedImage = Nd4j .createUninitialized (img .size (0 ), height , width , img .size (3 ));
312+
313+ DynamicCustomOp op = DynamicCustomOp .builder ("resize_bilinear" )
314+ .addInputs (img )
315+ .addOutputs (preprocessedImage )
316+ .addIntegerArguments (height , width ).build ();
317+ Nd4j .exec (op );
318+
319+ // finish preprocessing
320+ preprocessedImage = preprocessedImage .sub (0.5 );
321+ preprocessedImage = preprocessedImage .mul (2 );
322+ return preprocessedImage ;
323+ }
257324}
0 commit comments