package m.ann.training;

import com.mongodb.BasicDBObject;
import csip.Config;
import utils.MongoAccess;
import csip.ModelDataService;
import csip.ServiceException;
import csip.annotations.Options;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.stream.Collectors;
import javax.ws.rs.Path;
import m.ann.training.scale.DataSetIndices;
import m.ann.training.scale.ScalingMechanism;
import oms3.annotations.Description;
import oms3.annotations.Name;
import oms3.util.Statistics;
import org.apache.commons.lang.SerializationUtils;
import org.bson.Document;
import org.bson.types.ObjectId;
import org.codehaus.jettison.json.JSONException;
import org.encog.ml.CalculateScore;
import org.encog.ml.data.MLData;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.ea.train.EvolutionaryAlgorithm;
import org.encog.neural.neat.NEATLink;
import org.encog.neural.neat.NEATNetwork;
import org.encog.neural.neat.NEATPopulation;
import org.encog.neural.neat.NEATUtil;
import org.encog.neural.networks.training.TrainingSetScore;
import org.graphstream.graph.Graph;
import org.graphstream.stream.file.FileSink;
import org.graphstream.stream.file.FileSinkGEXF;
import org.graphstream.ui.graphicGraph.GraphicGraph;
import utils.CSIPMLDataSet;
import utils.Calc;
import utils.ErrorStatAnalysis;
import utils.Metadata;
import utils.MongoUtils.Sorting;

/**
 * Train ANN service.
 *
 * @author sidereus, od
 */
@Name("Train the network")
@Description("Train the network.")
@Path("m/train/1.0")
// Timeout after 365 days.
@Options(timeout = "P365D")
public class V1_0 extends ModelDataService {

  protected String ann_in;
  protected String ann_out;
  protected double trainingError;
  protected int maxEpochs;
  protected double trainingPerc;
  protected String splitMechanism;
  protected int populations;
  protected double connectionDensity;
  protected int recEpochs;
  protected double NSE_KGE;

  @Override
  public void preProcess() throws ServiceException, JSONException {

    ann_in = parameter().getString("annName");
    ann_out = parameter().getString("annName_out", ann_in);
    trainingError = parameter().getDouble(MongoAccess.TRAINING_ERROR, 0.1);
    checkDoubleParam(MongoAccess.TRAINING_ERROR, 0.0001, 0.1, trainingError);
    maxEpochs = parameter().getInt(MongoAccess.MAX_EPOCHS, 499);
    checkIntParam(MongoAccess.MAX_EPOCHS, 1, 150000, maxEpochs);
    trainingPerc = parameter().getDouble(MongoAccess.TRAINING_PERC, 0.8);
    checkDoubleParam(MongoAccess.TRAINING_PERC, 0.1, 0.99, trainingPerc);

    splitMechanism = parameter().getString(MongoAccess.SCALE_MECHANISM, "samedistribution");
    populations = parameter().getInt(MongoAccess.POPULATION, 1000);
    checkIntParam(MongoAccess.POPULATION, 50, 10000, populations);
    connectionDensity = parameter().getDouble(MongoAccess.CONNECTION_DENSITY, 1); // a density of 1 is suggested: it speeds up the computation
    checkDoubleParam(MongoAccess.CONNECTION_DENSITY, 0.1, 1, connectionDensity);
    recEpochs = parameter().getInt(MongoAccess.RECOVERY_EPOCHS, 1000);
    checkIntParam(MongoAccess.RECOVERY_EPOCHS, 1, maxEpochs, recEpochs);
    NSE_KGE = parameter().getDouble("Min accuracy", 0.95);

  }

  /**
   * Assembles the hyperparameter document and exit strategy, then calls the
   * train method and reports the exit reason.
   *
   * @throws ServiceException
   * @throws IOException
   */
  @Override
  public void doProcess() throws ServiceException, IOException {

    Document hyperParams = new Document(MongoAccess.TRAINING_ERROR, trainingError)
        .append(MongoAccess.MAX_EPOCHS, maxEpochs)
        .append(MongoAccess.RECOVERY_EPOCHS, recEpochs)
        .append(MongoAccess.TRAINING_PERC, trainingPerc)
        .append(MongoAccess.SCALE_MECHANISM, splitMechanism)
        .append(MongoAccess.POPULATION, populations)
        .append(MongoAccess.CONNECTION_DENSITY, connectionDensity);
    ExitStrategy exitStrategy = new ExitStrategy(trainingError, maxEpochs);

    train(ann_in, ann_out, exitStrategy, trainingPerc, splitMechanism, populations, connectionDensity, recEpochs, hyperParams);

    results().put("exit_reason", exitStrategy.getStrategy());
  }

  private void checkDoubleParam(String param, double minVal, double maxVal, double actualVal) {
    if (actualVal < minVal || actualVal > maxVal) {
      String msg = param + " is: " + actualVal;
      msg += ". Must be [" + minVal + ", " + maxVal + "]";
      throw new IllegalArgumentException(msg);
    }
  }

  private void checkIntParam(String param, int minVal, int maxVal, int actualVal) {
    if (actualVal < minVal || actualVal > maxVal) {
      String msg = param + " is: " + actualVal;
      msg += ". Must be [" + minVal + ", " + maxVal + "]";
      throw new IllegalArgumentException(msg);
    }
  }

  /*
    Example of a normalized variable document as stored in MongoDB:
{
    "_id" : ObjectId("5a1309aa3e90b86a0804a5ee"),
    "name" : "slope",
    "type" : "in",
    "count" : 13,
    "min" : 0.4,
    "max" : 0.9,
    "min_index": 0,
    "max_index": 12,
    "date" : ISODate("2017-11-20T09:58:18.864-07:00"),
    "values" : [ 
        0.0, 
        0.0, 
        0.0, 
        0.0, 
        0.0, 
        0.0, 
        0.0, 
        0.6, 
        0.6, 
        0.0, 
        0.0, 
        0.108, 
        1.0
    ]
}
   */
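
  /*
   * A minimal sketch (not used by the service) of how a normalized variable
   * document like the one above could be turned into a primitive array. The
   * helper name and the direct cast of the "values" field are illustrative
   * assumptions, not part of the training pipeline.
   */
  @SuppressWarnings("unchecked")
  private static double[] valuesOf(Document variable) {
    // the sample document above stores its normalized samples in the "values" array
    List<Double> values = (List<Double>) variable.get("values");
    return values.stream().mapToDouble(Double::doubleValue).toArray();
  }
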
  /**
   * Trains a NEAT network on the normalized data set and stores the resulting
   * network together with its statistics and metadata.
   *
   * @TODO: compute several statistics on neural network accuracy
   * @TODO: manage multiple outputs
   *
   * @param ann_in name of the source ANN data set
   * @param ann_out name under which the trained network is stored
   * @param exitStrategy strategy deciding when to stop training
   * @param trainingPerc fraction of the data used for training
   * @param splitMechanism mechanism used to split training and validation data
   * @param populations NEAT population size
   * @param connectionDensity initial connection density of the population
   * @param recEpochs number of epochs between recovery snapshots
   * @param hyperParams hyperparameter document stored with the network
   * @throws ServiceException
   * @throws IOException
   */
  private void train(String ann_in, String ann_out, ExitStrategy exitStrategy,
      double trainingPerc, String splitMechanism, int populations, double connectionDensity,
      int recEpochs, Document hyperParams)
      throws ServiceException, IOException {

    // get normalized data sorted by name in ascending order
    Iterable<Document> d = MongoAccess.getSortedNormalizedData(ann_in, MongoAccess.NORMALIZED, MongoAccess.NAME, Sorting.ASCENDING);

    // split data into training and validation sets according to trainingPerc
    ScalingMechanism scaleMechanism = ScalingMechanism.split(splitMechanism);
    DataSetIndices res = scaleMechanism.compute(d, trainingPerc);

    String nn_id = ObjectId.get().toString();

    // @TODO: convert to float
    List<Double> errorList = new ArrayList<>();
    Document netDocList = new Document();

    // create and train the ANN
    EvolutionaryAlgorithm trainingAlgorithm = trainNeuralNetwork(exitStrategy, populations, connectionDensity, d,
        res.getTraining(), res.getValidation(), ann_in, ann_out, nn_id, recEpochs, hyperParams, errorList, netDocList);
    int epochs = trainingAlgorithm.getIteration();
    double error = trainingAlgorithm.getError();

    LOG.info("Finished training:");
    LOG.info("   epochs: " + epochs);
    LOG.info("   error : " + error);

    // get the network.
    NEATNetwork nn = (NEATNetwork) trainingAlgorithm.getCODEC().decode(trainingAlgorithm.getBestGenome());

    // compute statistics
    List<Double> errorCheck = new ArrayList<>();
    Document errorStat = validate(nn, d, res.getValidation(), errorCheck);
    Document history = createHistory(epochs, errorList, exitStrategy.getStrategy());
    Document netDocStructure = createNetworkDocument(nn, nn_id, netDocList);
    Document metadata = createDataDocument(d, res);
    save(populations, epochs, error, errorStat, scaleMechanism.getStrategy(), exitStrategy.getStrategy(),
        connectionDensity, nn, metadata, ann_in, ann_out, nn_id, hyperParams, history, netDocStructure);

  }

  private Document createDataDocument(Iterable<Document> d, DataSetIndices res) {
    Document metadata = createDataDocument(d);
    return appendTrainValidIndices(metadata, res);
  }

  private Document appendTrainValidIndices(Document metadata, DataSetIndices res) {
    BasicDBObject bdb = new BasicDBObject();

    List<Integer> training = Arrays.stream(res.getTraining()).boxed().collect(Collectors.toList());
    bdb.append("traning", training);

    List<Integer> validation = Arrays.stream(res.getValidation()).boxed().collect(Collectors.toList());
    bdb.append("validation", validation);

    return metadata.append("indices", bdb);
  }

  private Document createDataDocument(Iterable<Document> d) {
    return new Document("metadata", MongoAccess.extractMetadata(d));
  }

  private void save(int populations, int epochs, double error, Document errorStat, String scaleStrategy,
      String exitStrategy, double connectionDensity, NEATNetwork nn, Document metadata,
      String ann_in, String ann_out, String nn_id, Document hyperParams,
      Document history, Document netDocStructure) throws IOException {

    String valid_id = (String) ((Document) (metadata.get("metadata", ArrayList.class).get(0))).get(MongoAccess.VAL_ID);
    // save the network, change the metadata
    byte[] network = SerializationUtils.serialize(nn);
    Document dd = new Document("nn_id", nn_id)
        .append(MongoAccess.POPULATION, populations)
        .append(MongoAccess.EPOCHS, epochs)
        .append(MongoAccess.SCORE, error)
        .append(MongoAccess.SCALE_MECHANISM, scaleStrategy)
        .append(MongoAccess.EXIT_STRATEGY, exitStrategy)
        .append(MongoAccess.CONNECTION_DENSITY, connectionDensity)
        .append(MongoAccess.VAL_ID, valid_id)
        .append(MongoAccess.VARIABLES, metadata)
        .append(MongoAccess.SUID, getSUID())
        .append(MongoAccess.HYPERPARAMS, hyperParams)
        .append(MongoAccess.PERFORMANCE, errorStat)
        .append(MongoAccess.HISTORY, history)
        .append(MongoAccess.BEST_NET_STRUCTURE, netDocStructure);
    MongoAccess.storeANN(ann_in, network, ann_out, dd, nn_id);
//    MongoAccess.validatePipeline(ann_in, MongoAccess.TRAINED);
  }

  private Document validate(NEATNetwork nn, Iterable<Document> d, int[] validIndices, List<Double> errorCheck) {

    MLDataSet validation = new CSIPMLDataSet(d, validIndices);
    int outputNodes = validation.getIdealSize();

    // @TODO: improve performance using double[] instead of List
    //        so to have fixed structure (Double -> double)
    ValidationResults vr = new ValidationResults(d);
    validation.forEach(data -> {
      MLData result = nn.compute(data.getInput());
      MLData observed = data.getIdeal();
      for (int i = 0; i < outputNodes; i++) {
        vr.add(i, observed.getData(i), result.getData(i));
      }
    });

    String[] outputVars = new String[outputNodes];
    int count = 0;
    for (Document document : d) {
      Document metadata = document.get(MongoAccess.METADATA, Document.class);
      String type = metadata.getString(MongoAccess.TYPE);
      if (type.equals(OUT)) {
        outputVars[count] = metadata.getString(MongoAccess.NAME);
        count++;
      }
    }

    Document listError = new Document();
    for (int i = 0; i < outputNodes; i++) {
      double[] obs = vr.getObserved(i);
      double[] sim = vr.getSimulated(i);
      listError.append(outputVars[i], ErrorStatAnalysis.compute(obs, sim, 2.0, -9999.0, errorCheck));
    }
    return listError;
  }
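
  /*
   * Illustrative sketch only: the "Min accuracy" threshold (NSE_KGE) is
   * compared against values that ErrorStatAnalysis.compute() appends to
   * errorCheck. For reference, a plain Nash-Sutcliffe efficiency is
   * NSE = 1 - sum((obs - sim)^2) / sum((obs - mean(obs))^2); this helper
   * restates that formula under that assumption and is not called by the
   * service.
   */
  private static double nashSutcliffeSketch(double[] obs, double[] sim) {
    double mean = Arrays.stream(obs).average().orElse(Double.NaN);
    double sse = 0.0, sst = 0.0;
    for (int i = 0; i < obs.length; i++) {
      sse += (obs[i] - sim[i]) * (obs[i] - sim[i]);
      sst += (obs[i] - mean) * (obs[i] - mean);
    }
    return 1.0 - sse / sst;
  }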
  
  private double validation(NEATNetwork nn, Iterable<Document> d, int[] validIndices) {
    MLDataSet validation = new CSIPMLDataSet(d, validIndices);

    List<Double> obs = new LinkedList<>();
    List<Double> sim = new LinkedList<>();
    validation.forEach(data -> {
      MLData result = nn.compute(data.getInput());
      MLData observed = data.getIdeal();
      double tmp = observed.getData(0);
      obs.add(tmp);
      tmp = result.getData(0);
      sim.add(tmp);
    });

    double[] ob = obs.stream().mapToDouble(Double::doubleValue).toArray();
    double[] si = sim.stream().mapToDouble(Double::doubleValue).toArray();
    return Statistics.mse(ob, si, -999.0);    
  }
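
  /*
   * For reference only: the call above uses oms3.util.Statistics.mse with
   * -999.0, which appears to act as a missing-value marker. Under that
   * assumption the computation corresponds to the sketch below, which is
   * not used by the service:
   *   MSE = (1/n) * sum((obs_i - sim_i)^2) over pairs where obs_i != missing
   */
  private static double mseSketch(double[] obs, double[] sim, double missing) {
    double sum = 0.0;
    int n = 0;
    for (int i = 0; i < obs.length; i++) {
      if (obs[i] != missing) {
        sum += (obs[i] - sim[i]) * (obs[i] - sim[i]);
        n++;
      }
    }
    return n > 0 ? sum / n : Double.NaN;
  }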

  private EvolutionaryAlgorithm trainNeuralNetwork(ExitStrategy exitStrategy, int populations, double connectionDensity, Iterable<Document> d,
      int[] trainingIndices, int[] validIndices, String ann_in, String ann_out, String nn_id, int recEpochs, Document hyperParams,
      List<Double> errorList, Document netDocList) throws ServiceException, IOException {
    MLDataSet trainingSet = new CSIPMLDataSet(d, trainingIndices);
    if (trainingSet.getIdealSize() == 0) {
      String msg = "Training dataset has no output nodes. Check the raw db.";
      throw new UnsupportedOperationException(msg);
    }
    NEATPopulation population = new NEATPopulation(trainingSet.getInputSize(), trainingSet.getIdealSize(), populations);
    population.setInitialConnectionDensity(connectionDensity);
    population.reset();

    CalculateScore score = new TrainingSetScore(trainingSet);
    EvolutionaryAlgorithm trainingAlgorithm = NEATUtil.constructNEATTrainer(population, score);

    int epoch_progress = Config.getInt("ann.train.epoch_progress", 1);

    // train the network
    double previous = 0;
    boolean reachedAccuracy = false;
    List<Double> errorCheck;
    do {
      trainingAlgorithm.iteration();
      if ((trainingAlgorithm.getIteration() % epoch_progress) == 0) {
        double te = trainingAlgorithm.getError();
        String status = "DB: " + ann_in + " ANN Id: " + nn_id.substring(nn_id.length() - 5);
        status += " CPUs: " + Runtime.getRuntime().availableProcessors();
        status += " Epoch #" + trainingAlgorithm.getIteration();
        status += " Error: " + te + "  delta: " + (te - previous);
        previous = te;
        setProgress(status);
        LOG.info(status);
      }
      if ((trainingAlgorithm.getIteration() % recEpochs) == 0) {
        errorCheck = new LinkedList<>();
        NEATNetwork nn = (NEATNetwork) trainingAlgorithm.getCODEC().decode(trainingAlgorithm.getBestGenome());
        int epochs = trainingAlgorithm.getIteration();
        double error = trainingAlgorithm.getError();
        errorList.add(error);
        Document history = createHistory(epochs, errorList);
        Document netDocStructure = createNetworkDocument(nn, nn_id + "_" + epochs, netDocList);
        Document metadata = createDataDocument(d);
        save(populations, epochs, error, validate(nn, d, validIndices, errorCheck), "rec", "rec", connectionDensity,
            nn, metadata, ann_in, ann_out, nn_id, hyperParams, history, netDocStructure);
        reachedAccuracy = errorCheck.get(0) > NSE_KGE && errorCheck.get(1) > NSE_KGE;
      }
    } while (exitStrategy.keepTraining(
        trainingAlgorithm,
        validation((NEATNetwork) trainingAlgorithm.getCODEC().decode(trainingAlgorithm.getBestGenome()), d, validIndices),
        reachedAccuracy));
    trainingAlgorithm.finishTraining();
    return trainingAlgorithm;
  }

  private Document createNetworkDocument(NEATNetwork nn, String nn_id, Document netDocList) {
    Document actual = createNetStructure(nn, nn_id);
    netDocList.append(nn_id, actual);
    return netDocList;
  }

  private Document createNetStructure(NEATNetwork nn, String nn_id) {
    NEATLink[] nl = nn.getLinks();

    Graph g = new GraphicGraph(nn_id);
    ArrayList<Integer> check = new ArrayList<>();
    for (NEATLink tmpLink : nl) {
      Integer from = tmpLink.getFromNeuron();
      Integer to = tmpLink.getToNeuron();
      // use a separator so edge ids like (1, 23) and (12, 3) don't collide
      String edge = from + "-" + to;

      if (!check.contains(from)) {
        check.add(from);
        g.addNode(Integer.toString(from));
      }
      if (!check.contains(to)) {
        check.add(to);
        g.addNode(Integer.toString(to));
      }
      g.addEdge(edge, Integer.toString(from), Integer.toString(to), true);
    }

    FileSink fs = new FileSinkGEXF();
    Document d = null;
    try (ByteArrayOutputStream f = new ByteArrayOutputStream()) {
      // write the GEXF directly into the byte buffer; wrapping it in an
      // ObjectOutputStream would prepend Java serialization headers and
      // corrupt the stored XML string
      fs.writeAll(g, f);
      int inNodes = nn.getInputCount();
      int outNodes = nn.getOutputCount();
      int hiddenNodes = Math.max(0, check.size() - inNodes - outNodes);

      d = new Document(MongoAccess.LINK_NUM, nl.length)
          .append(MongoAccess.IN_NODE_NUM, inNodes)
          .append(MongoAccess.OUT_NODE_NUM, outNodes)
          .append(MongoAccess.HIDE_NODE_NUM, hiddenNodes)
          .append(MongoAccess.STRUCTURE, f.toString());
    } catch (IOException ex) {
      Logger.getLogger(V1_0.class.getName()).log(Level.SEVERE, null, ex);
    }

    return d;
  }
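
  /*
   * Hypothetical usage sketch (not called anywhere): the GEXF string stored
   * under MongoAccess.STRUCTURE above can be dumped to a file so the network
   * topology can be inspected in a graph viewer such as Gephi.
   */
  private static void dumpStructureSketch(Document netDoc, java.nio.file.Path target) throws IOException {
    String gexf = netDoc.getString(MongoAccess.STRUCTURE);
    java.nio.file.Files.write(target, gexf.getBytes(java.nio.charset.StandardCharsets.UTF_8));
  }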

  private Document createHistory(int epochs, List<Double> errorList) {
    return createHistory(epochs, errorList, "recovery");
  }

  private Document createHistory(int epochs, List<Double> errorList, String exitStrategy) {
    return new Document(MongoAccess.EPOCHS, epochs)
        .append(MongoAccess.EXIT_STRATEGY, exitStrategy)
        .append(MongoAccess.SCORES, errorList);
  }

  static class ValidationResults {

    private final Map<Integer, List<Double>> observed = new HashMap<>();
    private final Map<Integer, List<Double>> simulated = new HashMap<>();
    private final List<Metadata> metadata = new ArrayList<>();

    ValidationResults(Iterable<Document> d) {
      for (Document doc : d) {
        Document meta = doc.get(MongoAccess.METADATA, Document.class);
        String type = meta.getString(MongoAccess.TYPE);

        if (type.equals(OUT)) {
          double max = meta.getDouble(MongoAccess.MAX);
          double min = meta.getDouble(MongoAccess.MIN);
          boolean normal = meta.getBoolean(MongoAccess.NORM);
          double norm_max = meta.getDouble(MongoAccess.NORM_MAX);
          double norm_min = meta.getDouble(MongoAccess.NORM_MIN);
          metadata.add(new Metadata(normal, min, max, norm_min, norm_max));
        }
      }
    }

    public void add(int index, double obsVal, double simVal) {
      Metadata meta = metadata.get(index);
      if (!meta.getNorm()) {
        obsVal = Calc.denormalize(obsVal, meta.getMin(), meta.getMax(), meta.getNormMin(), meta.getNormMax());
        simVal = Calc.denormalize(simVal, meta.getMin(), meta.getMax(), meta.getNormMin(), meta.getNormMax());
      }
      if (observed.containsKey(index)) {
        observed.get(index).add(obsVal);
        simulated.get(index).add(simVal);
      } else {
        List<Double> obsList = new ArrayList<>();
        obsList.add(obsVal);
        observed.put(index, obsList);
        List<Double> simList = new ArrayList<>();
        simList.add(simVal);
        simulated.put(index, simList);
      }
    }
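
    /*
     * Sketch of the linear mapping assumed for Calc.denormalize above; the
     * argument order mirrors the call in add() and is an assumption about
     * that utility, kept here only as documentation:
     *   x = min + (xNorm - normMin) * (max - min) / (normMax - normMin)
     */
    static double denormalizeSketch(double xNorm, double min, double max,
        double normMin, double normMax) {
      return min + (xNorm - normMin) * (max - min) / (normMax - normMin);
    }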

    double[] getObserved(int index) {
      return Calc.toDoubleArray(observed.get(index));
    }

    double[] getSimulated(int index) {
      return Calc.toDoubleArray(simulated.get(index));
    }
  }
}