Commit 16836ad

Issue #230 make TrainEA.getMethod() return the best member of population.
1 parent 0266b18 commit 16836ad

3 files changed: +66 −42 lines


src/main/java/org/encog/Test.java

Lines changed: 58 additions & 39 deletions
@@ -1,58 +1,77 @@
-/*
- * Encog(tm) Core v3.4 - Java Version
- * http://www.heatonresearch.com/encog/
- * https://github.com/encog/encog-java-core
-
- * Copyright 2008-2017 Heaton Research, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- *
- * For more information on Heaton Research copyrights, licenses
- * and trademarks visit:
- * http://www.heatonresearch.com/copyright
- */
 package org.encog;
-
-import org.encog.Encog;
+import org.encog.engine.network.activation.ActivationLinear;
+import org.encog.engine.network.activation.ActivationReLU;
 import org.encog.engine.network.activation.ActivationSigmoid;
 import org.encog.ml.data.MLData;
 import org.encog.ml.data.MLDataPair;
 import org.encog.ml.data.MLDataSet;
 import org.encog.ml.data.basic.BasicMLDataSet;
-import org.encog.ml.importance.PerturbationFeatureImportanceCalc;
 import org.encog.neural.networks.BasicNetwork;
 import org.encog.neural.networks.layers.BasicLayer;
 import org.encog.neural.networks.training.propagation.resilient.ResilientPropagation;
-import org.encog.neural.networks.training.propagation.sgd.StochasticGradientDescent;
-import org.encog.neural.networks.training.propagation.sgd.update.AdaGradUpdate;
-import org.encog.neural.networks.training.propagation.sgd.update.NesterovUpdate;
-import org.encog.neural.networks.training.propagation.sgd.update.RMSPropUpdate;
-import org.encog.neural.pattern.ElmanPattern;
-
 
+/**
+ * XOR: This example is essentially the "Hello World" of neural network
+ * programming. This example shows how to construct an Encog neural
+ * network to predict the output from the XOR operator. This example
+ * uses backpropagation to train the neural network.
+ *
+ * This example attempts to use a minimum of Encog features to create and
+ * train the neural network. This allows you to see exactly what is going
+ * on. For a more advanced example, that uses Encog factories, refer to
+ * the XORFactory example.
+ *
+ */
 public class Test {
 
+    /**
+     * The input necessary for XOR.
+     */
+    public static double XOR_INPUT[][] = { { 0.0, 0.0 }, { 1.0, 0.0 },
+            { 0.0, 1.0 }, { 1.0, 1.0 } };
+
+    /**
+     * The ideal data necessary for XOR.
+     */
+    public static double XOR_IDEAL[][] = { { 0.0 }, { 1.0 }, { 1.0 }, { 0.0 } };
+
     /**
      * The main method.
      * @param args No arguments are used.
      */
     public static void main(final String args[]) {
 
-        ElmanPattern elmanPat = new ElmanPattern();
-        elmanPat.setInputNeurons(5);
-        elmanPat.addHiddenLayer(5);
-        elmanPat.setOutputNeurons(1);
-        BasicNetwork network = (BasicNetwork) elmanPat.generate();
-        System.out.println(network.toString());
+        // create a neural network, without using a factory
+        BasicNetwork network = new BasicNetwork();
+        network.addLayer(new BasicLayer(new ActivationReLU(), true, 2));
+        network.addLayer(new BasicLayer(new ActivationSigmoid(), true, 3));
+        network.addLayer(new BasicLayer(new ActivationLinear(), false, 1));
+        network.getStructure().finalizeStructure();
+        network.reset();
+
+        // create training data
+        MLDataSet trainingSet = new BasicMLDataSet(XOR_INPUT, XOR_IDEAL);
+
+        // train the neural network
+        final ResilientPropagation train = new ResilientPropagation(network, trainingSet);
+
+        int epoch = 1;
+
+        do {
+            train.iteration();
+            System.out.println("Epoch #" + epoch + " Error:" + train.getError());
+            epoch++;
+        } while (train.getError() > 0.01);
+        train.finishTraining();
+
+        // test the neural network
+        System.out.println("Neural Network Results:");
+        for (MLDataPair pair : trainingSet) {
+            final MLData output = network.compute(pair.getInput());
+            System.out.println(pair.getInput().getData(0) + "," + pair.getInput().getData(1)
+                    + ", actual=" + output.getData(0) + ",ideal=" + pair.getIdeal().getData(0));
+        }
+
+        Encog.getInstance().shutdown();
     }
-}
+}
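
The rewritten Test.java is Encog's classic XOR example: a 2-3-1 feed-forward network trained with resilient propagation until the error drops below 0.01. One caveat: network.reset() randomizes the weights, so the open-ended do/while loop can run for a long time on an unlucky initialization. Below is a minimal sketch of a bounded variant, reusing the train object from the example above; the MAX_EPOCHS cap is a hypothetical safeguard, not part of this commit.

    // Sketch only, not part of this commit: the same loop with an epoch cap.
    final int MAX_EPOCHS = 5000; // hypothetical limit
    int epoch = 1;
    do {
        train.iteration();
        System.out.println("Epoch #" + epoch + " Error:" + train.getError());
        epoch++;
    } while (train.getError() > 0.01 && epoch <= MAX_EPOCHS);
    train.finishTraining();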

src/main/java/org/encog/ml/ea/train/basic/TrainEA.java

Lines changed: 6 additions & 1 deletion
@@ -30,6 +30,7 @@
 import org.encog.ml.MLMethod;
 import org.encog.ml.TrainingImplementationType;
 import org.encog.ml.data.MLDataSet;
+import org.encog.ml.ea.genome.Genome;
 import org.encog.ml.ea.population.Population;
 import org.encog.ml.train.MLTrain;
 import org.encog.ml.train.strategy.Strategy;
@@ -170,7 +171,11 @@ public void finishTraining() {
     */
    @Override
    public MLMethod getMethod() {
-       return this.getPopulation();
+       Genome g = this.getPopulation().getBestGenome();
+       if (g == null || getCODEC() == null) {
+           return null;
+       }
+       return getCODEC().decode(g);
    }
 
    /**
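
This is the change named in the commit message. Previously getMethod() handed back the entire Population; it now decodes the population's best genome through the training CODEC, and returns null when either the best genome or the CODEC is missing. A hedged usage sketch follows; the setup of the trainer is assumed, not shown in this commit.

    // Assumes 'train' is a fully configured TrainEA, with its
    // population, scoring function, and CODEC attached elsewhere.
    train.iteration();
    // Before this commit: returned the whole Population.
    // After: the decoded best genome (e.g. a BasicNetwork when a
    // neural-network CODEC is in use), or null if nothing can be decoded.
    MLMethod best = train.getMethod();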

src/main/java/org/encog/neural/networks/training/propagation/GradientWorker.java

Lines changed: 2 additions & 2 deletions
@@ -216,7 +216,7 @@ public void process(final MLDataPair pair) {
        // Calculate error for the output layer.
        this.errorFunction.calculateError(
                this.network.getActivationFunctions()[0], this.layerSums, this.layerOutput,
-               pair.getIdeal().getData(), this.actual, this.layerDelta, this.flatSpot[0],
+               pair.getIdeal().getData(), this.actual, this.layerDelta, this.flatSpot[0],
                pair.getSignificance());
 
        // Apply regularization, if requested.
@@ -255,7 +255,7 @@ private void processLevel(final int currentLevel) {
 
        final int index = this.weightIndex[currentLevel];
        final ActivationFunction activation = this.network
-               .getActivationFunctions()[currentLevel];
+               .getActivationFunctions()[currentLevel + 1];
        final double currentFlatSpot = this.flatSpot[currentLevel + 1];
 
        // handle weights
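
The first hunk appears to be whitespace-only (the removed and added lines are identical as extracted). The substantive fix is in processLevel(): the activation-function lookup moves from index currentLevel to currentLevel + 1, so that it refers to the same layer as the flat-spot constant on the following line. A tiny illustrative sketch of the indexing convention this relies on, assumed from the diff rather than taken from Encog source: getActivationFunctions()[0] is used above for the output-layer error, so the arrays run output-first, and the layer whose derivative is needed at level i sits at index i + 1.

    // Illustrative sketch (assumption, not Encog source): output-first
    // layer ordering, as implied by index 0 being the output layer in
    // the error calculation above.
    public class LayerOrderSketch {
        public static void main(String[] args) {
            String[] activations = { "output", "hidden", "input" }; // index 0 = output
            int currentLevel = 0; // processing the level below the output layer
            // The fix aligns this index with flatSpot[currentLevel + 1]:
            System.out.println("derivative taken from: " + activations[currentLevel + 1]);
        }
    }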
