ExitStrategy.java [src/java/m/ann/training] Revision: default Date:
/*
* To change this license header, choose License Headers in Project Properties.
* To change this template file, choose Tools | Templates
* and open the template in the editor.
*/
package m.ann.training;
import csip.Config;
import org.encog.ml.ea.train.EvolutionaryAlgorithm;
/**
 * Decides when the training loop of an Encog {@link EvolutionaryAlgorithm} should stop.
 *
 * <p>Training stops as soon as one of the following holds, checked in this order:
 * <ol>
 *   <li>the training error is strictly below the configured threshold;</li>
 *   <li>the iteration counter is strictly greater than the configured maximum epochs;</li>
 *   <li>the training error has improved by less than a small threshold for too many
 *       consecutive calls (see {@link ConstantError});</li>
 *   <li>(validation overload only) the validation error has increased for too many
 *       consecutive calls, which is treated as overtraining (see {@link ValidationError});</li>
 *   <li>(validation overload only) the caller reports that the target accuracy was reached.</li>
 * </ol>
 * The reason for stopping is recorded and can be read with {@link #getStrategy()}.
 *
 * <p>Instances keep mutable error history without synchronization, so they are not
 * thread-safe and should not be shared between training runs.
 *
 * @author sidereus
 */
public class ExitStrategy {

    private final double trainingError;
    private final int maxEpochs;
    private final ConstantError errorCheck = new ConstantError();
    private final ValidationError validErrorCheck = new ValidationError();
    /** Human-readable reason training stopped; {@code null} until a stop criterion fires. */
    private String strategy;

    /**
     * @param trainingError training stops once the algorithm's error is strictly below this value
     * @param maxEpochs     training stops once the algorithm's iteration counter is strictly
     *                      greater than this value
     */
    ExitStrategy(double trainingError, int maxEpochs) {
        this.trainingError = trainingError;
        this.maxEpochs = maxEpochs;
    }

    /**
     * Checks the training-error based stop criteria (threshold, max epochs, constant error).
     *
     * @param trainingAlgorithm the running algorithm; its current error and iteration are read
     * @return {@code true} to continue training, {@code false} to stop (in which case
     *         {@link #getStrategy()} describes why)
     * @throws IllegalArgumentException if the training error increased since the previous call
     *         that reached the constant-error check
     */
    boolean keepTraining(EvolutionaryAlgorithm trainingAlgorithm) {
        double error = trainingAlgorithm.getError();
        if (error < trainingError) {
            strategy = "threshold_error reached: " + trainingError;
            return false;
        }
        if (trainingAlgorithm.getIteration() > maxEpochs) {
            strategy = "max_epochs reached: " + maxEpochs;
            return false;
        }
        if (errorCheck.isErrorConstant(error)) {
            strategy = "constant_error: " + errorCheck.getReason();
            return false;
        }
        return true;
    }

    /**
     * Checks all criteria of {@link #keepTraining(EvolutionaryAlgorithm)} first, then the
     * validation-based ones (overtraining, reached accuracy).
     *
     * @param trainingAlgorithm the running algorithm; its current error and iteration are read
     * @param validError        the current validation error
     * @param reachedAccuracy   {@code true} if the caller considers the target accuracy reached
     * @return {@code true} to continue training, {@code false} to stop (in which case
     *         {@link #getStrategy()} describes why)
     * @throws IllegalArgumentException if the training error increased since the previous call
     *         that reached the constant-error check
     */
    boolean keepTraining(EvolutionaryAlgorithm trainingAlgorithm, double validError, boolean reachedAccuracy) {
        // Shared criteria first; the stateful validation check is only consulted when they pass.
        if (!keepTraining(trainingAlgorithm)) {
            return false;
        }
        if (validErrorCheck.isErrorIncreasing(validError)) {
            strategy = "overtraining";
            return false;
        }
        if (reachedAccuracy) {
            strategy = "reachedAccuracy";
            return false;
        }
        return true;
    }

    /**
     * @return the reason training stopped, or {@code null} if no stop criterion has fired yet
     */
    public String getStrategy() {
        return strategy;
    }

    /**
     * Detects a validation error that keeps increasing between consecutive calls.
     */
    static class ValidationError {
        /** Number of consecutive increases that must be exceeded to report overtraining. */
        private static final int MAX_CONSECUTIVE_INCREASES = 10;

        double previous = Double.MAX_VALUE;
        int count = 0;

        /**
         * Records {@code actual} and reports whether the validation error has increased on more
         * than {@link #MAX_CONSECUTIVE_INCREASES} consecutive calls.
         *
         * @param actual the current validation error
         * @return {@code true} once the increase streak exceeds the limit
         */
        boolean isErrorIncreasing(double actual) {
            double decrease = previous - actual;
            previous = actual;
            // A negative decrease means the error went up; any non-increase resets the streak.
            count = (decrease < 0) ? count + 1 : 0;
            return count > MAX_CONSECUTIVE_INCREASES;
        }
    }

    /**
     * Detects a training error that has (almost) stopped improving.
     *
     * <p>The minimum improvement ({@code ann.seqthresh}, default 1e-6) and the tolerated streak
     * length ({@code ann.seqmax}, default 99) are read from {@link Config}.
     */
    static class ConstantError {
        final double threshold = Config.getDouble("ann.seqthresh", 0.000001);
        final int seqError = Config.getInt("ann.seqmax", 99);
        double previous = Double.MAX_VALUE;
        int count = 0;

        /**
         * Records {@code actual} and reports whether the improvement over the previous value has
         * been below {@link #threshold} on more than {@link #seqError} consecutive calls.
         *
         * @param actual the current training error
         * @return {@code true} once the streak of negligible improvements exceeds the limit
         * @throws IllegalArgumentException if {@code actual} is greater than the previously
         *         recorded error (the recorded value is left unchanged in that case)
         */
        boolean isErrorConstant(double actual) {
            double improvement = previous - actual;
            if (improvement < 0) {
                throw new IllegalArgumentException("Soaring error: " + improvement);
            }
            // Must be "count + 1", not "count++": "count = count++" stores the old value back,
            // so the streak would never grow and this criterion could never trigger.
            count = (improvement < threshold) ? count + 1 : 0;
            previous = actual;
            return count > seqError;
        }

        /**
         * @return a human-readable description of the constant-error criterion
         */
        String getReason() {
            return "Error was " + seqError + " times below " + threshold + ".";
        }
    }
}