ExitStrategy.java [src/java/m/ann/training] Revision: default  Date:
/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package m.ann.training;

import csip.Config;
import org.encog.ml.ea.train.EvolutionaryAlgorithm;

/**
 * Decides when the training loop of an {@link EvolutionaryAlgorithm} should stop.
 *
 * <p>Stop criteria, checked in this order (the first one that fires wins):
 * <ol>
 *   <li>the training error is strictly below the configured threshold;</li>
 *   <li>the algorithm's iteration count exceeds the configured maximum number of epochs;</li>
 *   <li>the training error has stopped improving (see {@link ConstantError});</li>
 *   <li>(validation overload only) the validation error keeps increasing
 *       (see {@link ValidationError});</li>
 *   <li>(validation overload only) the caller reports that the target accuracy was reached.</li>
 * </ol>
 * The reason for stopping is recorded and can be read with {@link #getStrategy()}.
 *
 * <p>Instances keep mutable error history without any synchronization, so one instance is
 * meant to be driven by a single training loop.
 *
 * @author sidereus
 */
public class ExitStrategy {

  private final double trainingError;
  private final int maxEpochs;
  private final ConstantError errorCheck = new ConstantError();
  private final ValidationError validErrorCheck = new ValidationError();
  private String strategy;


  /**
   * @param trainingError training stops once the error falls strictly below this value
   * @param maxEpochs training stops once the iteration count exceeds this value
   */
  ExitStrategy(double trainingError, int maxEpochs) {
    this.trainingError = trainingError;
    this.maxEpochs = maxEpochs;
  }


  /**
   * Evaluates the training-only stop criteria (threshold, max epochs, constant error).
   *
   * @param trainingAlgorithm the running algorithm; its current error and iteration are read
   * @return {@code true} to continue training, {@code false} to stop (the reason is then
   *     available from {@link #getStrategy()})
   * @throws IllegalArgumentException propagated from {@link ConstantError#isErrorConstant}
   *     if the training error is larger than on the previous evaluated call
   */
  boolean keepTraining(EvolutionaryAlgorithm trainingAlgorithm) {
    double error = trainingAlgorithm.getError();

    if (error < trainingError) {
      strategy = "threshold_error reached: " + trainingError;
      return false;
    }
    if (trainingAlgorithm.getIteration() > maxEpochs) {
      strategy = "max_epochs reached: " + maxEpochs;
      return false;
    }
    // Checked last among these: it mutates the error history, so it is only fed errors from
    // calls that did not already stop on the checks above.
    if (errorCheck.isErrorConstant(error)) {
      strategy = "constant_error: " + errorCheck.getReason();
      return false;
    }
    return true;
  }


  /**
   * Evaluates all training-only stop criteria, then the validation-based ones.
   *
   * @param trainingAlgorithm the running algorithm; its current error and iteration are read
   * @param validError the current validation error, fed to the overtraining detector
   * @param reachedAccuracy {@code true} if the caller determined the target accuracy was reached
   * @return {@code true} to continue training, {@code false} to stop (the reason is then
   *     available from {@link #getStrategy()})
   * @throws IllegalArgumentException propagated from {@link ConstantError#isErrorConstant}
   *     if the training error is larger than on the previous evaluated call
   */
  boolean keepTraining(EvolutionaryAlgorithm trainingAlgorithm, double validError,
      boolean reachedAccuracy) {
    // Shared criteria first. Returning early here means the validation history is not updated
    // when one of them fires.
    if (!keepTraining(trainingAlgorithm)) {
      return false;
    }
    if (validErrorCheck.isErrorIncreasing(validError)) {
      strategy = "overtraining";
      return false;
    }
    if (reachedAccuracy) {
      strategy = "reachedAccuracy";
      return false;
    }
    return true;
  }


  /**
   * Returns the reason recorded by the most recent {@code keepTraining} call that returned
   * {@code false}.
   *
   * <p>The value is not cleared by calls that return {@code true}. It is {@code null} if no call
   * has returned {@code false} yet.
   */
  public String getStrategy() {
    return strategy;
  }

  /** Detects overtraining: the validation error increasing on many consecutive calls. */
  static class ValidationError {

    /** Number of consecutive increases tolerated before overtraining is reported. */
    private static final int MAX_CONSECUTIVE_INCREASES = 10;

    double previous = Double.MAX_VALUE;
    int count = 0;


    /**
     * Records {@code actual} as the latest validation error.
     *
     * <p>Reports whether the error has increased on more than 10 consecutive calls.
     */
    boolean isErrorIncreasing(double actual) {
      double val = previous - actual;
      previous = actual;
      count = (val < 0) ? count + 1 : 0;
      return count > MAX_CONSECUTIVE_INCREASES;
    }

  }

  /** Detects stagnation: the training error improving by less than a threshold too many times in a row. */
  static class ConstantError {

    // Tunables read from csip.Config. The second argument is presumably the fallback used
    // when the key is unset; confirm against csip.Config.
    final double threshold = Config.getDouble("ann.seqthresh", 0.000001);
    final int seqError = Config.getInt("ann.seqmax", 99);
    double previous = Double.MAX_VALUE;
    int count = 0;


    /**
     * Records {@code actual} as the latest training error.
     *
     * <p>Reports whether the improvement over the previous error has been below
     * {@code threshold} on more than {@code seqError} consecutive calls.
     *
     * @throws IllegalArgumentException if {@code actual} is larger than the previously recorded
     *     error. In that case the history is left unchanged.
     */
    boolean isErrorConstant(double actual) {
      double error = previous - actual;
      // NOTE(review): this assumes the training error never increases between calls
      // (e.g. an elitist EA that keeps its best genome); confirm that holds for all callers.
      if (error < 0) {
        throw new IllegalArgumentException("Soaring error: " + error);
      }
      // Must be `count + 1`, not `count++`: `count = count++` assigns the old value back,
      // so the counter would never grow and this check could never fire.
      count = (error < threshold) ? count + 1 : 0;
      previous = actual;
      return count > seqError;
    }


    /** Human-readable explanation used when this criterion stops training. */
    String getReason() {
      return "Error was " + seqError + " times below " + threshold + ".";
    }

  }
}