/*
* To change this template, choose Tools | Templates
* and open the template in the editor.
*/
package m.ann.training;
import com.mongodb.BasicDBObject;
import csip.Config;
import utils.MongoAccess;
import csip.ModelDataService;
import csip.ServiceException;
import csip.annotations.Options;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.stream.Collectors;
import javax.ws.rs.Path;
import m.ann.training.scale.DataSetIndices;
import m.ann.training.scale.ScalingMechanism;
import oms3.annotations.Description;
import oms3.annotations.Name;
import oms3.util.Statistics;
import org.apache.commons.lang.SerializationUtils;
import org.bson.Document;
import org.bson.types.ObjectId;
import org.codehaus.jettison.json.JSONException;
import org.encog.ml.CalculateScore;
import org.encog.ml.data.MLData;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.ea.train.EvolutionaryAlgorithm;
import org.encog.neural.neat.NEATLink;
import org.encog.neural.neat.NEATNetwork;
import org.encog.neural.neat.NEATPopulation;
import org.encog.neural.neat.NEATUtil;
import org.encog.neural.networks.training.TrainingSetScore;
import org.graphstream.graph.Graph;
import org.graphstream.stream.file.FileSink;
import org.graphstream.stream.file.FileSinkGEXF;
import org.graphstream.ui.graphicGraph.GraphicGraph;
import utils.CSIPMLDataSet;
import utils.Calc;
import utils.ErrorStatAnalysis;
import utils.Metadata;
import utils.MongoUtils.Sorting;
/**
* Train ANN service.
*
* @author sidereus, od
*/
@Name("Train the network")
@Description("Train the network.")
@Path("m/train/1.0")
// Timeout after 365 days.
@Options(timeout = "P365D")
public class V1_0 extends ModelDataService {
protected String ann_in;
protected String ann_out;
protected double trainingError;
protected int maxEpochs;
protected double trainingPerc;
protected String splitMechanism;
protected int populations;
protected double connectionDensity;
protected int recEpochs;
protected double NSE_KGE;
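// read the training hyper-parameters from the request and validate their ranges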
@Override
public void preProcess() throws ServiceException, JSONException {
ann_in = parameter().getString("annName");
ann_out = parameter().getString("annName_out", ann_in);
trainingError = parameter().getDouble(MongoAccess.TRAINING_ERROR, 0.1);
checkDoubleParam(MongoAccess.TRAINING_ERROR, 0.0001, 0.1, trainingError);
maxEpochs = parameter().getInt(MongoAccess.MAX_EPOCHS, 499);
checkIntParam(MongoAccess.MAX_EPOCHS, 1, 150000, maxEpochs);
trainingPerc = parameter().getDouble(MongoAccess.TRAINING_PERC, 0.8);
checkDoubleParam(MongoAccess.TRAINING_PERC, 0.1, 0.99, trainingPerc);
splitMechanism = parameter().getString(MongoAccess.SCALE_MECHANISM, "samedistribution");
populations = parameter().getInt(MongoAccess.POPULATION, 1000);
checkIntParam(MongoAccess.POPULATION, 50, 10000, populations);
connectionDensity = parameter().getDouble(MongoAccess.CONNECTION_DENSITY, 1); // values below 1.0 are suggested: a sparser initial network speeds up the computation
checkDoubleParam(MongoAccess.CONNECTION_DENSITY, 0.1, 1, connectionDensity);
recEpochs = parameter().getInt(MongoAccess.RECOVERY_EPOCHS, 1000);
checkIntParam(MongoAccess.RECOVERY_EPOCHS, 1, maxEpochs, recEpochs);
NSE_KGE = parameter().getDouble("Min accuracy", 0.95);
}
/**
* Build the hyper-parameter document and call the train method.
*
* @throws ServiceException
* @throws IOException
*/
@Override
public void doProcess() throws ServiceException, IOException {
Document hyperParams = new Document(MongoAccess.TRAINING_ERROR, trainingError)
.append(MongoAccess.MAX_EPOCHS, maxEpochs)
.append(MongoAccess.RECOVERY_EPOCHS, recEpochs)
.append(MongoAccess.TRAINING_PERC, trainingPerc)
.append(MongoAccess.SCALE_MECHANISM, splitMechanism)
.append(MongoAccess.POPULATION, populations)
.append(MongoAccess.CONNECTION_DENSITY, connectionDensity);
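// the exit strategy combines the target training error with the epoch limit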
ExitStrategy exitStrategy = new ExitStrategy(trainingError, maxEpochs);
train(ann_in, ann_out, exitStrategy, trainingPerc, splitMechanism, populations, connectionDensity, recEpochs, hyperParams);
results().put("exit_reason", exitStrategy.getStrategy());
}
private void checkDoubleParam(String param, double minVal, double maxVal, double actualVal) {
if (actualVal < minVal || actualVal > maxVal) {
String msg = param + " is: " + actualVal;
msg += ". Must be [" + minVal + ", " + maxVal + "]";
throw new IllegalArgumentException(msg);
}
}
private void checkIntParam(String param, int minVal, int maxVal, int actualVal) {
if (actualVal < minVal || actualVal > maxVal) {
String msg = param + " is: " + actualVal;
msg += ". Must be [" + minVal + ", " + maxVal + "]";
throw new IllegalArgumentException(msg);
}
}
/*
Example of a normalized data document as stored in MongoDB:
{
"_id" : ObjectId("5a1309aa3e90b86a0804a5ee"),
"name" : "slope",
"type" : "in",
"count" : 13,
"min" : 0.4,
"max" : 0.9,
"min_index": 0,
"max_index": 12,
"date" : ISODate("2017-11-20T09:58:18.864-07:00"),
"values" : [
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.6,
0.6,
0.0,
0.0,
0.108,
1.0
]
}
*/
/**
* Train a NEAT network on the normalized data set and store the result.
*
* @TODO: compute several statistics on neural network accuracy
* @TODO: manage multiple outputs
*
* @param ann_in name of the source ANN database
* @param ann_out name under which the trained network is stored
* @param exitStrategy exit condition combining target error and epoch limit
* @param trainingPerc fraction of the data used for training
* @param splitMechanism strategy for splitting training/validation data
* @param populations NEAT population size
* @param connectionDensity initial connection density of the population
* @param recEpochs checkpoint interval in epochs
* @param hyperParams hyper-parameter document stored with the network
* @throws ServiceException
* @throws IOException
*/
private void train(String ann_in, String ann_out, ExitStrategy exitStrategy,
double trainingPerc, String splitMechanism, int populations, double connectionDensity,
int recEpochs, Document hyperParams)
throws ServiceException, IOException {
// get ascending name sorted normalized data
Iterable<Document> d = MongoAccess.getSortedNormalizedData(ann_in, MongoAccess.NORMALIZED, MongoAccess.NAME, Sorting.ASCENDING);
// split the data into training and validation sets according to trainingPerc (default 80-20)
ScalingMechanism scaleMechanism = ScalingMechanism.split(splitMechanism);
DataSetIndices res = scaleMechanism.compute(d, trainingPerc);
String nn_id = ObjectId.get().toString();
// @TODO: convert to float
List<Double> errorList = new ArrayList<>();
Document netDocList = new Document();
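// netDocList accumulates the serialized structure of the best network at each checkpoint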
// create the ANN
EvolutionaryAlgorithm trainingAlgorithm = trainNeuralNetwork(exitStrategy, populations, connectionDensity, d,
res.getTraining(), res.getValidation(), ann_in, ann_out, nn_id, recEpochs, hyperParams, errorList, netDocList);
int epochs = trainingAlgorithm.getIteration();
double error = trainingAlgorithm.getError();
LOG.info("Finished training:");
LOG.info(" epochs: " + epochs);
LOG.info(" error : " + error);
// get the network.
NEATNetwork nn = (NEATNetwork) trainingAlgorithm.getCODEC().decode(trainingAlgorithm.getBestGenome());
// compute statistics
List<Double> errorCheck = new ArrayList<>();
Document errorStat = validate(nn, d, res.getValidation(), errorCheck);
Document history = createHistory(epochs, errorList, exitStrategy.getStrategy());
Document netDocStructure = createNetworkDocument(nn, nn_id, netDocList);
Document metadata = createDataDocument(d, res);
save(populations, epochs, error, errorStat, scaleMechanism.getStrategy(), exitStrategy.getStrategy(),
connectionDensity, nn, metadata, ann_in, ann_out, nn_id, hyperParams, history, netDocStructure);
}
private Document createDataDocument(Iterable<Document> d, DataSetIndices res) {
Document metadata = createDataDocument(d);
return appendTrainValidIndices(metadata, res);
}
private Document appendTrainValidIndices(Document metadata, DataSetIndices res) {
BasicDBObject bdb = new BasicDBObject();
List<Integer> training = Arrays.stream(res.getTraining()).boxed().collect(Collectors.toList());
bdb.append("traning", training);
List<Integer> validation = Arrays.stream(res.getValidation()).boxed().collect(Collectors.toList());
bdb.append("validation", validation);
return metadata.append("indices", bdb);
}
private Document createDataDocument(Iterable<Document> d) {
return new Document("metadata", MongoAccess.extractMetadata(d));
}
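/**
* Serialize the best network and store it in MongoDB together with its
* metadata, hyper-parameters, performance statistics, and training history.
*/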
private void save(int populations, int epochs, double error, Document errorStat, String scaleStrategy,
String exitStrategy, double connectionDensity, NEATNetwork nn, Document metadata,
String ann_in, String ann_out, String nn_id, Document hyperParams,
Document history, Document netDocStructure) throws IOException {
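// the validation id is taken from the first metadata entry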
String valid_id = (String) ((Document) (metadata.get("metadata", ArrayList.class).get(0))).get(MongoAccess.VAL_ID);
// save the network, change the metadata
byte[] network = SerializationUtils.serialize(nn);
Document dd = new Document("nn_id", nn_id)
.append(MongoAccess.POPULATION, populations)
.append(MongoAccess.EPOCHS, epochs)
.append(MongoAccess.SCORE, error)
.append(MongoAccess.SCALE_MECHANISM, scaleStrategy)
.append(MongoAccess.EXIT_STRATEGY, exitStrategy)
.append(MongoAccess.CONNECTION_DENSITY, connectionDensity)
.append(MongoAccess.VAL_ID, valid_id)
.append(MongoAccess.VARIABLES, metadata)
.append(MongoAccess.SUID, getSUID())
.append(MongoAccess.HYPERPARAMS, hyperParams)
.append(MongoAccess.PERFORMANCE, errorStat)
.append(MongoAccess.HISTORY, history)
.append(MongoAccess.BEST_NET_STRUCTURE, netDocStructure);
MongoAccess.storeANN(ann_in, network, ann_out, dd, nn_id);
// MongoAccess.validatePipeline(ann_in, MongoAccess.TRAINED);
}
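/**
* Run the trained network over the validation set and compute per-output
* error statistics on the (denormalized, where applicable) values.
*/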
private Document validate(NEATNetwork nn, Iterable<Document> d, int[] validIndices, List<Double> errorCheck) {
MLDataSet validation = new CSIPMLDataSet(d, validIndices);
int outputNodes = validation.getIdealSize();
// @TODO: improve performance using double[] instead of List
// so to have fixed structure (Double -> double)
ValidationResults vr = new ValidationResults(d);
validation.forEach(data -> {
MLData result = nn.compute(data.getInput());
MLData observed = data.getIdeal();
for (int i = 0; i < outputNodes; i++) {
vr.add(i, observed.getData(i), result.getData(i));
}
});
String[] outputVars = new String[outputNodes];
int count = 0;
for (Document document : d) {
Document metadata = document.get(MongoAccess.METADATA, Document.class);
String type = metadata.getString(MongoAccess.TYPE);
if (type.equals(OUT)) {
outputVars[count] = metadata.getString(MongoAccess.NAME);
count++;
}
}
Document listError = new Document();
for (int i = 0; i < outputNodes; i++) {
double[] obs = vr.getObserved(i);
double[] sim = vr.getSimulated(i);
listError.append(outputVars[i], ErrorStatAnalysis.compute(obs, sim, 2.0, -9999.0, errorCheck));
}
return listError;
}
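/**
* Mean squared error of the first output over the validation set; fed to the
* exit strategy to track generalization during training.
*/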
private double validation(NEATNetwork nn, Iterable<Document> d, int[] validIndices) {
MLDataSet validation = new CSIPMLDataSet(d, validIndices);
List<Double> obs = new LinkedList<>();
List<Double> sim = new LinkedList<>();
validation.forEach(data -> {
MLData result = nn.compute(data.getInput());
MLData observed = data.getIdeal();
double tmp = observed.getData(0);
obs.add(tmp);
tmp = result.getData(0);
sim.add(tmp);
});
double[] ob = obs.stream().mapToDouble(Double::doubleValue).toArray();
double[] si = sim.stream().mapToDouble(Double::doubleValue).toArray();
return Statistics.mse(ob, si, -999.0);
}
private EvolutionaryAlgorithm trainNeuralNetwork(ExitStrategy exitStrategy, int populations, double connectionDensity, Iterable<Document> d,
int[] trainingIndices, int[] validIndices, String ann_in, String ann_out, String nn_id, int recEpochs, Document hyperParams,
List<Double> errorList, Document netDocList) throws ServiceException, IOException {
MLDataSet trainingSet = new CSIPMLDataSet(d, trainingIndices);
if (trainingSet.getIdealSize() == 0) {
String msg = "Training dataset has no output nodes. Check the raw db.";
throw new UnsupportedOperationException(msg);
}
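// build a NEAT population sized from the data set; reset() seeds random genomes
// at the requested initial connection density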
NEATPopulation population = new NEATPopulation(trainingSet.getInputSize(), trainingSet.getIdealSize(), populations);
population.setInitialConnectionDensity(connectionDensity);
population.reset();
CalculateScore score = new TrainingSetScore(trainingSet);
EvolutionaryAlgorithm trainingAlgorithm = NEATUtil.constructNEATTrainer(population, score);
int epoch_progress = Config.getInt("ann.train.epoch_progress", 1);
// train the network
double previous = 0;
boolean reachedAccuracy = false;
List<Double> errorCheck;
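// training loop: log progress every epoch_progress epochs and checkpoint
// the best network every recEpochs epochs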
do {
trainingAlgorithm.iteration();
if ((trainingAlgorithm.getIteration() % epoch_progress) == 0) {
double te = trainingAlgorithm.getError();
String status = "DB: " + ann_in + " ANN Id: " + nn_id.substring(nn_id.length() - 5);
status += " CPUs: " + Runtime.getRuntime().availableProcessors();
status += " Epoch #" + trainingAlgorithm.getIteration();
status += " Error: " + te + " delta: " + (te - previous);
previous = te;
setProgress(status);
LOG.info(status);
}
if ((trainingAlgorithm.getIteration() % recEpochs) == 0) {
errorCheck = new LinkedList<>();
NEATNetwork nn = (NEATNetwork) trainingAlgorithm.getCODEC().decode(trainingAlgorithm.getBestGenome());
int epochs = trainingAlgorithm.getIteration();
double error = trainingAlgorithm.getError();
errorList.add(error);
Document history = createHistory(epochs, errorList);
Document netDocStructure = createNetworkDocument(nn, nn_id + "_" + epochs, netDocList);
Document metadata = createDataDocument(d);
save(populations, epochs, error, validate(nn, d, validIndices, errorCheck), "rec", "rec", connectionDensity,
nn, metadata, ann_in, ann_out, nn_id, hyperParams, history, netDocStructure);
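// early-exit check: the first two error statistics (NSE and KGE, going by the
// NSE_KGE threshold name) must both exceed the minimum accuracy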
reachedAccuracy = errorCheck.get(0) > NSE_KGE && errorCheck.get(1) > NSE_KGE;
}
} while (exitStrategy.keepTraining(
trainingAlgorithm,
validation((NEATNetwork) trainingAlgorithm.getCODEC().decode(trainingAlgorithm.getBestGenome()), d, validIndices),
reachedAccuracy));
trainingAlgorithm.finishTraining();
return trainingAlgorithm;
}
private Document createNetworkDocument(NEATNetwork nn, String nn_id, Document netDocList) {
Document actual = createNetStructure(nn, nn_id);
netDocList.append(nn_id, actual);
return netDocList;
}
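/**
* Export the network topology as a GEXF graph and summarize its link and
* node counts.
*/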
private Document createNetStructure(NEATNetwork nn, String nn_id) {
NEATLink[] nl = nn.getLinks();
Graph g = new GraphicGraph(nn_id);
ArrayList<Integer> check = new ArrayList<>();
for (NEATLink tmpLink : nl) {
Integer from = tmpLink.getFromNeuron();
Integer to = tmpLink.getToNeuron();
String edge = from + "-" + to; // a separator prevents id collisions (e.g. nodes 1,23 vs 12,3)
if (!check.contains(from)) {
check.add(from);
g.addNode(Integer.toString(from));
}
if (!check.contains(to)) {
check.add(to);
g.addNode(Integer.toString(to));
}
g.addEdge(edge, Integer.toString(from), Integer.toString(to), true);
}
FileSink fs = new FileSinkGEXF();
Document d = null;
try (ByteArrayOutputStream f = new ByteArrayOutputStream()) {
// write the GEXF XML straight into the byte buffer: wrapping the stream in an
// ObjectOutputStream would prepend a Java serialization header and corrupt the XML
fs.writeAll(g, f);
int inNodes = nn.getInputCount();
int outNodes = nn.getOutputCount();
int hiddenNodes = Math.max(0, check.size() - inNodes - outNodes);
d = new Document(MongoAccess.LINK_NUM, nl.length)
.append(MongoAccess.IN_NODE_NUM, inNodes)
.append(MongoAccess.OUT_NODE_NUM, outNodes)
.append(MongoAccess.HIDE_NODE_NUM, hiddenNodes)
.append(MongoAccess.STRUCTURE, f.toString("UTF-8"));
} catch (IOException ex) {
Logger.getLogger(V1_0.class.getName()).log(Level.SEVERE, null, ex);
}
return d;
}
private Document createHistory(int epochs, List<Double> errorList) {
return createHistory(epochs, errorList, "recovery");
}
private Document createHistory(int epochs, List<Double> errorList, String exitStrategy) {
return new Document(MongoAccess.EPOCHS, epochs)
.append(MongoAccess.EXIT_STRATEGY, exitStrategy)
.append(MongoAccess.SCORES, errorList);
}
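/**
* Collects observed and simulated values per output node, mapping them back
* to their original scale according to the stored normalization metadata.
*/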
static class ValidationResults {
private final Map<Integer, List<Double>> observed = new HashMap<>();
private final Map<Integer, List<Double>> simulated = new HashMap<>();
private final List<Metadata> metadata = new ArrayList<>();
ValidationResults(Iterable<Document> d) {
for (Document doc : d) {
Document meta = doc.get(MongoAccess.METADATA, Document.class);
String type = meta.getString(MongoAccess.TYPE);
if (type.equals(OUT)) {
double max = meta.getDouble(MongoAccess.MAX);
double min = meta.getDouble(MongoAccess.MIN);
boolean normal = meta.getBoolean(MongoAccess.NORM);
double norm_max = meta.getDouble(MongoAccess.NORM_MAX);
double norm_min = meta.getDouble(MongoAccess.NORM_MIN);
metadata.add(new Metadata(normal, min, max, norm_min, norm_max));
}
}
}
public void add(int index, double obsVal, double simVal) {
Metadata meta = metadata.get(index);
if (!meta.getNorm()) {
obsVal = Calc.denormalize(obsVal, meta.getMin(), meta.getMax(), meta.getNormMin(), meta.getNormMax());
simVal = Calc.denormalize(simVal, meta.getMin(), meta.getMax(), meta.getNormMin(), meta.getNormMax());
}
observed.computeIfAbsent(index, k -> new ArrayList<>()).add(obsVal);
simulated.computeIfAbsent(index, k -> new ArrayList<>()).add(simVal);
}
double[] getObserved(int index) {
return Calc.toDoubleArray(observed.get(index));
}
double[] getSimulated(int index) {
return Calc.toDoubleArray(simulated.get(index));
}
}
}