@@ -42,7 +42,7 @@ public class ScatterView extends View {
4242
4343 private final int nPointsPerAxis = 100 ;
4444 private INDArray xyGrid ; //x,y grid to calculate the output image. Needs to be calculated once, then re-used.
45- private MultiLayerNetwork model ;
45+ INDArray modelOut = null ;
4646
4747 public ScatterView (Context context , @ Nullable AttributeSet attrs ) {
4848 super (context , attrs );
@@ -74,30 +74,25 @@ public void onDraw(Canvas canvas) {
7474 int h = this .getHeight ();
7575 int w = this .getWidth ();
7676
77- if (null == data ) {
78- canvas .drawColor (Color .rgb (32 , 32 , 32 ));
79- canvas .drawCircle (800 , 500 , 200 , redPaint );
80- canvas .drawCircle (325 , 900 , 300 , greenPaint );
81- } else {
82-
83- //draw the nn predictions:
77+ //draw the nn predictions:
78+ if ((modelOut != null ) && (null != xyGrid )){
8479 int halfRectHeight = h / nPointsPerAxis ;
8580 int halfRectWidth = w / nPointsPerAxis ;
86- INDArray modelOut = model .output (xyGrid );
87-
8881 int nRows = xyGrid .rows ();
8982
9083 for (int i = 0 ; i < nRows ; i ++){
91- int x = (int )(xyGrid .getFloat (i , 0 ) * w );
92- int y = (int ) (xyGrid .getFloat (i , 1 ) * h );
93- float z = modelOut .getFloat (i , 0 );
94- Paint p = (z >= 0.5f ) ? lightGreenPaint : lightRedPaint ;
95- canvas .drawRect (x -halfRectWidth , y -halfRectHeight , x +halfRectWidth , y +halfRectHeight , p );
96- // }
84+ int x = (int )(xyGrid .getFloat (i , 0 ) * w );
85+ int y = (int ) (xyGrid .getFloat (i , 1 ) * h );
86+ float z = modelOut .getFloat (i , 0 );
87+ Paint p = (z >= 0.5f ) ? lightGreenPaint : lightRedPaint ;
88+ canvas .drawRect (x -halfRectWidth , y -halfRectHeight , x +halfRectWidth , y +halfRectHeight , p );
89+ // }
9790 }
91+ }
9892
93+ //draw the data set if we have one.
94+ if (null != data ) {
9995
100- //draw the data set
10196 for (float [] datum : data ) {
10297 int x = (int ) (datum [1 ] * w );
10398 int y = (int ) (datum [2 ] * h );
@@ -173,11 +168,11 @@ private void normalizeColumn(int c, float[][] tmpData){
173168
174169 private void BuildNN (){
175170 int seed = 123 ;
176- double learningRate = 0.01 ;
171+ double learningRate = 0.005 ;
177172 int numInputs = 2 ;
178173 int numOutputs = 2 ;
179174 int numHiddenNodes = 20 ;
180- int nEpochs = 200 ;
175+ int nEpochs = 2000 ;
181176
182177 MultiLayerConfiguration conf = new NeuralNetConfiguration .Builder ()
183178 .seed (seed )
@@ -192,12 +187,18 @@ private void BuildNN(){
192187 .nIn (numHiddenNodes ).nOut (numOutputs ).build ())
193188 .build ();
194189
195- model = new MultiLayerNetwork (conf );
190+ MultiLayerNetwork model = new MultiLayerNetwork (conf );
196191 model .init ();
197192 model .setListeners (new ScoreIterationListener (10 ));
198193
199194 for (int i = 0 ; i <nEpochs ; i ++){
200195 model .fit (ds );
196+ INDArray tmp = model .output (xyGrid );
197+
198+ this .post (() -> {
199+ this .modelOut = tmp ; // update from within the UI thread.
200+ this .invalidate (); // have the UI thread redraw at its own convenience.
201+ });
201202 }
202203
203204 Evaluation eval = new Evaluation (numOutputs );
0 commit comments