ANN_ModelDataService.java [src/java/crp/utils] Revision: default Date:
/*
* $Id$
*
* This file is part of the Cloud Services Integration Platform (CSIP),
* a Model-as-a-Service framework, API, and application suite.
*
* 2012-2019, OMSLab, Colorado State University.
*
* OMSLab licenses this file to you under the MIT license.
* See the LICENSE file in the project root for more information.
*/
package crp.utils;
import com.google.common.math.Quantiles;
import csip.ModelDataService;
import csip.ServiceException;
import csip.utils.JSONUtils;
import java.io.File;
import java.io.FilenameFilter;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import static javax.xml.bind.DatatypeConverter.parseBase64Binary;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.filefilter.WildcardFileFilter;
import org.apache.commons.lang3.SerializationUtils;
import org.codehaus.jettison.json.JSONArray;
import org.codehaus.jettison.json.JSONException;
import org.codehaus.jettison.json.JSONObject;
import org.encog.ml.data.MLData;
import org.encog.neural.neat.NEATNetwork;
/**
 * Base class for CSIP services that evaluate an ensemble of Encog NEAT
 * networks ("ANNs") and report summary statistics of the ensemble output.
 * <p>
 * Lifecycle:
 * <ul>
 * <li>{@link #preProcess()} reads the requested ensemble name.</li>
 * <li>{@link #doProcess()} obtains the ensemble via {@link #getAnnList()}, maps
 * the request parameters onto the network inputs and runs every network.</li>
 * <li>{@link #postProcess()} reports, for each output variable, the ensemble
 * median as the result. It also reports min, first/third quartile, max and the
 * individual member values as meta info.</li>
 * </ul>
 *
 * @author <a href="mailto:shaun.case@colostate.edu">Shaun Case</a>
 */
public abstract class ANN_ModelDataService extends ModelDataService {

  /** Request parameter holding the name of the ensemble to evaluate. */
  public static final String ANN_REQUEST_NAME = "ann_name";

  /** Ensemble name read from the request in {@link #preProcess()}. */
  protected String annName;

  /** Raw output vector of each ensemble member, in ensemble order. */
  protected List<double[]> outputData;

  /** Output variable definitions collected from the metadata, keyed by name. */
  protected ConcurrentHashMap<String, NetworkVariable> out;

  /** The ensemble evaluated by {@link #doProcess()}. */
  protected ANN_Network nn;

  /**
   * Loads the ANN ensemble stored in the resource identified by {@code name}.
   * <p>
   * The resource must resolve to a directory. Presumably this is the unpacked
   * form of a zipped ensemble; confirm against the resource configuration. The
   * directory must contain one or more Base64-encoded, Java-serialized
   * {@code *.ann} files and exactly one {@code <name>.json} metadata file. The
   * result is intended to be cached by the concrete subclass (for example, in a
   * static cache).
   *
   * @param name resource name of the ensemble
   * @return the loaded ensemble with its metadata and variable definitions
   * @throws IllegalArgumentException if the resource is not a directory,
   *         contains no {@code *.ann} files, or cannot be read or parsed
   */
  protected ANN_Network getANNs(String name) {
    try {
      File ens = resources().getFile(name);
      if (!ens.isDirectory()) {
        throw new IllegalArgumentException("Not an ann ensemble resource: " + name);
      }
      ANN_Network annNetwork = new ANN_Network();
      annNetwork.anns = loadNetworks(ens, name);

      File[] mainMetaFiles = ens.listFiles((FilenameFilter) new WildcardFileFilter(name + ".json"));
      // listFiles returns null on an I/O error; treat that like a missing file.
      if (mainMetaFiles == null || mainMetaFiles.length != 1) {
        throw new ServiceException("Cannot find that ANN Model's metadata JSON file.");
      }
      String metaDataString = FileUtils.readFileToString(mainMetaFiles[0], StandardCharsets.UTF_8);
      readModelMetaData(annNetwork, metaDataString);
      return annNetwork;
    } catch (ServiceException | IOException ex) {
      throw new IllegalArgumentException("Failed to load ensemble:: " + name, ex);
    }
  }

  /**
   * Deserializes every {@code *.ann} file in {@code dir}.
   * <p>
   * Files are sorted by path so the ensemble member order (and therefore the
   * reported "vals" array) is reproducible between runs.
   * <p>
   * NOTE(review): this uses Java-native deserialization. It is presumably
   * acceptable only because the files come from the service's own, trusted
   * resources. Never point it at user-supplied data.
   *
   * @param dir directory holding the serialized networks
   * @param name ensemble name, used in error messages
   * @return the deserialized networks, never empty
   * @throws IOException if the directory cannot be listed or a file cannot be read
   * @throws IllegalArgumentException if no {@code *.ann} files are present
   */
  private List<NEATNetwork> loadNetworks(File dir, String name) throws IOException {
    File[] files = dir.listFiles((FilenameFilter) new WildcardFileFilter("*.ann"));
    if (files == null) {
      throw new IOException("Cannot list the contents of " + dir);
    }
    Arrays.sort(files);
    List<NEATNetwork> networks = new ArrayList<>(files.length);
    for (File file : files) {
      // Each .ann file is Base64 text wrapping a serialized NEATNetwork.
      byte[] ann = parseBase64Binary(FileUtils.readFileToString(file, StandardCharsets.UTF_8));
      networks.add((NEATNetwork) SerializationUtils.deserialize(ann));
    }
    if (networks.isEmpty()) {
      throw new IllegalArgumentException("No anns found for zipped ensemble: " + name);
    }
    return networks;
  }

  /**
   * Evaluates every network on the same input.
   * <p>
   * Networks run in parallel. The returned list keeps the ensemble order
   * because the stream is ordered.
   * <p>
   * NOTE(review): cached networks may be shared by concurrent requests.
   * Confirm that {@code NEATNetwork.compute} is safe to call concurrently on
   * the same instance.
   *
   * @param nn the ensemble members
   * @param input the (already mapped) input vector
   * @return each network's raw output vector, in ensemble order
   */
  protected List<double[]> compute(List<NEATNetwork> nn, MLData input) {
    return nn.parallelStream()
        .map(n -> n.compute(input).getData())
        .collect(Collectors.toList());
  }

  /**
   * Returns the ensemble to evaluate for the current request.
   * Implemented by concrete services, typically via {@link #getANNs(String)}.
   *
   * @return the ensemble to evaluate
   * @throws ServiceException if the ensemble cannot be provided
   */
  protected abstract ANN_Network getAnnList() throws ServiceException;

  /**
   * Builds the network input vector from the request parameters and collects
   * the output variable definitions.
   * <p>
   * Variables are visited in metadata order.
   * <ul>
   * <li>For each {@code "in"} variable, the request parameter of the same name
   * is read. If {@code norm} is true the value is passed through unchanged;
   * otherwise it is normalized with
   * {@code Calc.normalize(value, min, max, norm_min, norm_max)}.</li>
   * <li>Each {@code "out"} variable is copied into {@code out}.</li>
   * </ul>
   *
   * @param nn the ensemble whose metadata drives the mapping
   * @param out receives a copy of each output variable, keyed by name
   * @return the input vector to feed to every network
   * @throws ServiceException if normalizing an input fails (and presumably
   *         when a request parameter cannot be read)
   */
  protected MLData getMappedInput(ANN_Network nn, Map<String, NetworkVariable> out) throws ServiceException {
    ArrayList<Double> inData = new ArrayList<>();
    for (NetworkVariable variable : nn.variables) {
      if (variable.type.equalsIgnoreCase("in")) {
        double value = parameter().getDouble(variable.name);
        if (variable.norm) {
          inData.add(value);
        } else {
          try {
            inData.add(Calc.normalize(value, variable.min, variable.max, variable.norm_min, variable.norm_max));
          } catch (IllegalArgumentException ex) {
            throw new ServiceException("For input variable, " + variable.name + ", there was a problem: " + ex.getMessage(), ex);
          }
        }
      } else if (variable.type.equalsIgnoreCase("out")) {
        out.put(variable.name, new NetworkVariable(variable.name, variable.type,
            variable.norm, variable.norm_min, variable.norm_max, variable.min, variable.max));
      }
    }
    return new CSIPMLData(inData);
  }

  @Override
  public void preProcess() throws Exception {
    annName = parameter().getString(ANN_REQUEST_NAME);
  }

  @Override
  public void doProcess() throws Exception {
    // Get the correct ANN
    nn = getAnnList();
    out = new ConcurrentHashMap<>();
    MLData mappedInput = getMappedInput(nn, out);
    outputData = compute(nn.anns, mappedInput);
  }

  /**
   * Reports ensemble statistics for every output variable.
   * <p>
   * For each output, the ensemble values are de-normalized unless
   * {@code norm} is true. The median becomes the result value, and
   * min/1q/3q/max and the individual values become meta info.
   */
  @Override
  public void postProcess() throws Exception {
    for (Map.Entry<String, Integer> slot : outputPositions().entrySet()) {
      String name = slot.getKey();
      int index = slot.getValue();
      NetworkVariable variable = out.get(name);
      double[] dataset = new double[outputData.size()];
      int nnindex = 0;
      for (double[] memberOutput : outputData) {
        dataset[nnindex++] = variable.norm
            ? memberOutput[index]
            : Calc.denormalize(memberOutput[index], variable.min, variable.max, variable.norm_min, variable.norm_max);
      }
      Map<Integer, Double> uncertResult = Quantiles.quartiles().indexes(0, 1, 2, 3, 4).compute(dataset);
      results().put(name, uncertResult.get(2));
      results().putMetaInfo(name, "min", uncertResult.get(0));
      results().putMetaInfo(name, "1q", uncertResult.get(1));
      results().putMetaInfo(name, "3q", uncertResult.get(3));
      results().putMetaInfo(name, "max", uncertResult.get(4));
      // NOTE(review): flagged "NEEDS TO BE FIXED" by the original author; the
      // intended fix is not recorded here.
      results().putMetaInfo(name, "vals", JSONUtils.toArray(dataset));
    }
  }

  /**
   * Maps each entry of {@code out} to its slot in the network output vector.
   * <p>
   * The slot of a declared output is its position among the {@code "out"}
   * variables of the metadata, which is the order in which
   * {@link #getMappedInput} visits them. The keys of {@code out} cannot be used
   * for this: {@code ConcurrentHashMap} iteration order is unspecified.
   * Entries not declared in the metadata (or all entries, when {@code nn} is
   * not set) are numbered sequentially after the declared outputs.
   *
   * @return output name to output-vector index, in reporting order
   */
  private Map<String, Integer> outputPositions() {
    Map<String, Integer> positions = new LinkedHashMap<>();
    int position = 0;
    if (nn != null) {
      for (NetworkVariable declared : nn.variables) {
        if (declared.type.equalsIgnoreCase("out")) {
          if (out.containsKey(declared.name)) {
            positions.put(declared.name, position);
          }
          position++;
        }
      }
    }
    for (String name : out.keySet()) {
      if (!positions.containsKey(name)) {
        positions.put(name, position++);
      }
    }
    return positions;
  }

  /**
   * Parses the ensemble metadata JSON into {@code annNetwork}.
   * <p>
   * Stores the whole document in {@code metaData}. Appends one
   * {@link NetworkVariable} per element of {@code metadata.variables} to
   * {@code variables}.
   *
   * @param annNetwork the ensemble being populated
   * @param metaFileData the metadata file contents
   * @throws ServiceException if the input is null/empty, the JSON is invalid,
   *         or the {@code metadata} object or its {@code variables} array is missing
   */
  private void readModelMetaData(ANN_Network annNetwork, String metaFileData) throws ServiceException {
    if (annNetwork == null || metaFileData == null || metaFileData.isEmpty()) {
      throw new ServiceException("Cannot read the ANN Model's metadata");
    }
    try {
      annNetwork.metaData = new JSONObject(metaFileData);
      JSONObject metaObject = annNetwork.metaData.optJSONObject("metadata");
      if (metaObject == null) {
        throw new ServiceException("Cannot read the ANN Model's metadata from the provided JSON file. Missing the JSONObject 'metadata'.");
      }
      JSONArray variables = metaObject.optJSONArray("variables");
      if (variables == null) {
        throw new ServiceException("Cannot read the ANN Model's metadata from the provided JSON file. Missing the JSONArray 'variables'.");
      }
      for (int i = 0; i < variables.length(); i++) {
        JSONObject variable = variables.getJSONObject(i);
        NetworkVariable newVar = new NetworkVariable();
        newVar.name = variable.getString("name");
        newVar.type = variable.getString("type");
        newVar.norm = variable.getBoolean("norm");
        newVar.norm_min = variable.getDouble("norm_min");
        newVar.norm_max = variable.getDouble("norm_max");
        newVar.min = variable.getDouble("min");
        newVar.max = variable.getDouble("max");
        annNetwork.variables.add(newVar);
      }
    } catch (JSONException ex) {
      throw new ServiceException("Cannot read the ANN Model's metadata from the provided JSON file: " + ex.getMessage(), ex);
    }
  }

  /**
   * An ensemble of NEAT networks together with its parsed metadata.
   * <p>
   * Declared {@code static} so that cached instances do not keep a reference
   * to the service instance that loaded them.
   */
  public static class ANN_Network {
    /** The ensemble members. */
    public List<NEATNetwork> anns = null;
    /** The full metadata JSON document. */
    public JSONObject metaData = null;
    /** Input/output variable definitions, in metadata order. */
    public ArrayList<NetworkVariable> variables = new ArrayList<>();
  }

  /**
   * Definition of one network input or output variable.
   * <p>
   * The {@code type} field is {@code "in"} or {@code "out"}. When {@code norm}
   * is false, values are mapped between the [{@code min}, {@code max}] range and
   * the [{@code norm_min}, {@code norm_max}] range via {@code Calc}. When it is
   * true, values are used as-is. Declared {@code static} for the same reason as
   * {@link ANN_Network}.
   */
  public static class NetworkVariable {
    public String name;
    public String type;
    public boolean norm;
    public double norm_min;
    public double norm_max;
    public double min;
    public double max;

    /** Creates an empty definition; numeric bounds default to NaN. */
    public NetworkVariable() {
      this.name = "";
      this.type = "";
      this.norm = false;
      this.norm_min = Double.NaN;
      this.norm_max = Double.NaN;
      this.min = Double.NaN;
      this.max = Double.NaN;
    }

    /** Creates a fully specified definition. */
    public NetworkVariable(String name, String type, boolean norm, double norm_min, double norm_max, double min, double max) {
      this.name = name;
      this.type = type;
      this.norm = norm;
      this.norm_min = norm_min;
      this.norm_max = norm_max;
      this.min = min;
      this.max = max;
    }
  }
}