// ANN_ModelDataService.java [src/java/crp/utils] Revision:   Date:
/*
 * $Id$
 *
 * This file is part of the Cloud Services Integration Platform (CSIP),
 * a Model-as-a-Service framework, API, and application suite.
 *
 * 2012-2019, OMSLab, Colorado State University.
 *
 * OMSLab licenses this file to you under the MIT license.
 * See the LICENSE file in the project root for more information.
 */
package crp.utils;

import com.google.common.math.Quantiles;
import csip.ModelDataService;
import csip.ServiceException;
import csip.utils.JSONUtils;
import java.io.File;
import java.io.FilenameFilter;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import static javax.xml.bind.DatatypeConverter.parseBase64Binary;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.filefilter.WildcardFileFilter;
import org.apache.commons.lang3.SerializationUtils;
import org.codehaus.jettison.json.JSONArray;
import org.codehaus.jettison.json.JSONException;
import org.codehaus.jettison.json.JSONObject;
import org.encog.ml.data.MLData;
import org.encog.neural.neat.NEATNetwork;

/**
 *
 * @author <a href="mailto:shaun.case@colostate.edu">Shaun Case</a>
 */
public abstract class ANN_ModelDataService extends ModelDataService {

  /** Name of the request parameter that selects which ANN ensemble to run. */
  public static final String ANN_REQUEST_NAME = "ann_name";

  /** Value of the {@value #ANN_REQUEST_NAME} request parameter, read in {@link #preProcess()}. */
  protected String annName;
  /** Output vector of every network in the ensemble, one entry per network, set by {@link #doProcess()}. */
  protected List<double[]> outputData;
  /** Output variables of the selected ensemble keyed by name, filled by {@link #getMappedInput}. */
  protected ConcurrentHashMap<String, NetworkVariable> out;
  /** The ANN ensemble selected for the current request. */
  protected ANN_Network nn;

  /**
   * Loads the ANN ensemble stored in the resource directory identified by
   * {@code name}.
   *
   * <p>Every {@code *.ann} file in that directory is read as Base64 text and
   * deserialized into a {@link NEATNetwork}. Exactly one file named
   * {@code <name>.json} in the same directory must supply the variable metadata.
   * The result is meant to be kept by the child class in a static cache.
   *
   * @param name resource name of the ensemble directory
   * @return the loaded ensemble with its networks and variable metadata
   * @throws IllegalArgumentException if the resource is not a directory,
   *     contains no {@code *.ann} files, lacks a unique metadata file, or cannot
   *     be read or parsed
   */
  protected ANN_Network getANNs(String name) {
    try {
      File ens = resources().getFile(name);
      if (ens == null || !ens.isDirectory()) {
        throw new IllegalArgumentException("Not an ann ensemble resource: " + name);
      }

      File[] files = ens.listFiles((FilenameFilter) new WildcardFileFilter("*.ann"));
      if (files == null) {
        // File.listFiles signals an I/O error by returning null instead of throwing.
        throw new IOException("Cannot list ensemble directory: " + ens);
      }
      // listFiles has no defined order; sort so the ensemble (and the reported
      // per-network "vals") is stable between runs.
      Arrays.sort(files);

      List<NEATNetwork> networks = new ArrayList<>(files.length);
      for (File file : files) {
        // NOTE(review): Java native deserialization -- only acceptable because the
        // .ann files ship with the service resources and are trusted input.
        byte[] ann = parseBase64Binary(FileUtils.readFileToString(file, StandardCharsets.UTF_8));
        networks.add((NEATNetwork) SerializationUtils.deserialize(ann));
      }
      if (networks.isEmpty()) {
        throw new IllegalArgumentException("No anns found for zipped ensemble: " + name);
      }

      File[] mainMetaFiles = ens.listFiles((FilenameFilter) new WildcardFileFilter(name + ".json"));
      if (mainMetaFiles == null || mainMetaFiles.length != 1) {
        throw new ServiceException("Cannot find that ANN Model's metadata JSON file.");
      }

      ANN_Network annNetwork = new ANN_Network();
      annNetwork.anns = networks;
      readModelMetaData(annNetwork, FileUtils.readFileToString(mainMetaFiles[0], StandardCharsets.UTF_8));
      return annNetwork;
    } catch (ServiceException | IOException ex) {
      throw new IllegalArgumentException("Failed to load ensemble:: " + name, ex);
    }
  }

  /**
   * Runs {@code input} through every network of the ensemble.
   *
   * @param nn the ensemble members
   * @param input the (already normalized) input vector
   * @return one output vector per network, in the order of {@code nn}
   */
  protected List<double[]> compute(List<NEATNetwork> nn, MLData input) {
    return nn.parallelStream()
        .map(network -> computeOne(network, input))
        .collect(Collectors.toList());
  }

  /**
   * Computes a single network's output while holding that network's monitor.
   * Ensembles are shared between requests through the child class's static
   * cache, and Encog's NEATNetwork.compute works in per-instance activation
   * buffers, so two requests must not run the same network at once. Different
   * networks still run in parallel.
   */
  private static double[] computeOne(NEATNetwork network, MLData input) {
    synchronized (network) {
      return network.compute(input).getData();
    }
  }

  /**
   * Returns the ensemble to use for the current request.
   *
   * @return the selected ensemble
   * @throws ServiceException if no ensemble can be provided
   */
  protected abstract ANN_Network getAnnList() throws ServiceException;

  /**
   * Builds the network input vector from the request parameters and records the
   * ensemble's output variables.
   *
   * <p>Each {@code "in"} variable is read from the request parameter of the same
   * name. The value is passed through unchanged when the variable's {@code norm}
   * flag is set; otherwise it is converted with {@code Calc.normalize}. Each
   * {@code "out"} variable is copied into {@code out}, keyed by name.
   *
   * @param nn the ensemble whose metadata describes the variables
   * @param out map that receives a copy of every output variable
   * @return the input vector, in metadata order of the input variables
   * @throws ServiceException if a parameter is missing or an input cannot be normalized
   */
  protected MLData getMappedInput(ANN_Network nn, Map<String, NetworkVariable> out) throws ServiceException {
    ArrayList<Double> inData = new ArrayList<>();

    for (NetworkVariable variable : nn.variables) {
      if (variable.type.equalsIgnoreCase("in")) {
        double value = parameter().getDouble(variable.name);
        if (variable.norm) {
          inData.add(value);
        } else {
          try {
            inData.add(Calc.normalize(value, variable.min, variable.max, variable.norm_min, variable.norm_max));
          } catch (IllegalArgumentException ex) {
            throw new ServiceException("For input variable, " + variable.name + ", there was a problem: " + ex.getMessage(), ex);
          }
        }
      } else if (variable.type.equalsIgnoreCase("out")) {
        out.put(variable.name, new NetworkVariable(variable.name, variable.type,
            variable.norm, variable.norm_min, variable.norm_max, variable.min, variable.max));
      }
    }

    return new CSIPMLData(inData);
  }

  /** Reads the requested ensemble name from the {@value #ANN_REQUEST_NAME} parameter. */
  @Override
  public void preProcess() throws Exception {
    annName = parameter().getString(ANN_REQUEST_NAME);
  }

  /** Selects the ensemble, maps the request inputs, and runs every network. */
  @Override
  public void doProcess() throws Exception {
    //  Get the correct ANN
    nn = getAnnList();

    out = new ConcurrentHashMap<>();

    MLData mappedInput = getMappedInput(nn, out);

    outputData = compute(nn.anns, mappedInput);
  }

  /**
   * Reports each output variable.
   *
   * <p>The result value is the median over all networks. The min, 1q, 3q and max
   * quartiles and the individual values are attached as meta info. Output column
   * {@code i} of every network vector is assigned to the {@code i}-th output
   * variable in metadata declaration order.
   */
  @Override
  public void postProcess() throws Exception {
    int index = 0;
    for (String name : orderedOutputNames()) {
      NetworkVariable variable = out.get(name);

      double[] dataset = new double[outputData.size()];
      int nnindex = 0;
      for (double[] networkOutput : outputData) {
        double raw = networkOutput[index];
        dataset[nnindex++] = variable.norm
            ? raw
            : Calc.denormalize(raw, variable.min, variable.max, variable.norm_min, variable.norm_max);
      }

      Map<Integer, Double> uncertResult = Quantiles.quartiles().indexes(0, 1, 2, 3, 4).compute(dataset);
      results().put(name, uncertResult.get(2));
      results().putMetaInfo(name, "min", uncertResult.get(0));
      results().putMetaInfo(name, "1q", uncertResult.get(1));
      results().putMetaInfo(name, "3q", uncertResult.get(3));
      results().putMetaInfo(name, "max", uncertResult.get(4));
      // NOTE(review): the original author marked this line "NEEDS TO BE FIXED";
      // the reason is not visible here.
      results().putMetaInfo(name, "vals", JSONUtils.toArray(dataset));
      index++;
    }
  }

  /**
   * Returns the keys of {@link #out} in the order their output variables are
   * declared in the ensemble metadata. Any keys not declared there (for example,
   * added by a subclass) are appended at the end. Iterating {@code out} directly
   * would follow hash order, which is unrelated to the networks' output columns.
   */
  private Set<String> orderedOutputNames() {
    Set<String> order = new LinkedHashSet<>();
    for (NetworkVariable variable : nn.variables) {
      if (variable.type.equalsIgnoreCase("out")) {
        order.add(variable.name);
      }
    }
    order.retainAll(out.keySet());
    order.addAll(out.keySet());
    return order;
  }

  /**
   * Parses the metadata JSON into {@code annNetwork}.
   *
   * <p>Stores the whole document in {@code metaData} and appends one
   * {@link NetworkVariable} per entry of {@code metadata.variables}.
   *
   * @param annNetwork the ensemble to populate
   * @param metaFileData contents of the metadata JSON file
   * @throws ServiceException if the input is missing or empty, the JSON is
   *     malformed, or the {@code metadata} object or {@code variables} array is absent
   */
  private void readModelMetaData(ANN_Network annNetwork, String metaFileData) throws ServiceException {
    if (annNetwork == null || metaFileData == null || metaFileData.isEmpty()) {
      throw new ServiceException("Cannot read the ANN Model's metadata");
    }
    try {
      annNetwork.metaData = new JSONObject(metaFileData);

      JSONObject metaObject = annNetwork.metaData.optJSONObject("metadata");
      if (metaObject == null) {
        throw new ServiceException("Cannot read the ANN Model's metadata from the provided JSON file.  Missing the JSONObject 'metadata'.");
      }

      JSONArray variables = metaObject.optJSONArray("variables");
      if (variables == null) {
        throw new ServiceException("Cannot read the ANN Model's metadata from the provided JSON file.  Missing the JSONArray 'variables'.");
      }

      for (int i = 0; i < variables.length(); i++) {
        JSONObject variable = variables.getJSONObject(i);
        annNetwork.variables.add(new NetworkVariable(
            variable.getString("name"),
            variable.getString("type"),
            variable.getBoolean("norm"),
            variable.getDouble("norm_min"),
            variable.getDouble("norm_max"),
            variable.getDouble("min"),
            variable.getDouble("max")));
      }
    } catch (JSONException ex) {
      throw new ServiceException("Cannot read the ANN Model's metadata from the provided JSON file: " + ex.getMessage(), ex);
    }
  }

  /**
   * An ANN ensemble: the deserialized networks plus the metadata describing
   * their input and output variables.
   *
   * <p>Declared {@code static} so that instances held in a cache across requests
   * do not keep the loading service instance reachable.
   */
  public static class ANN_Network {

    /** Ensemble members; each one is run on the same input. */
    public List<NEATNetwork> anns = null;
    /** The full parsed metadata document. */
    public JSONObject metaData = null;
    /** Declared input and output variables, in metadata order. */
    public ArrayList<NetworkVariable> variables = new ArrayList<>();
  }

  /**
   * Description of one network input or output variable, as read from the
   * ensemble metadata.
   */
  public static class NetworkVariable {

    /** Variable name; for inputs, also the request parameter name. */
    public String name;
    /** {@code "in"} or {@code "out"} (compared case-insensitively). */
    public String type;
    /**
     * When {@code true}, the value is passed to or from the network unchanged.
     * When {@code false}, it is converted with {@code Calc.normalize} or
     * {@code Calc.denormalize} using the ranges below.
     */
    public boolean norm;
    /** Lower bound of the normalized range (presumably; passed to Calc). */
    public double norm_min;
    /** Upper bound of the normalized range (presumably; passed to Calc). */
    public double norm_max;
    /** Lower bound of the raw value range (presumably; passed to Calc). */
    public double min;
    /** Upper bound of the raw value range (presumably; passed to Calc). */
    public double max;

    /** Creates an empty variable: blank name and type, {@code norm} false, NaN bounds. */
    public NetworkVariable() {
      this.name = "";
      this.type = "";
      this.norm = false;
      this.norm_min = Double.NaN;
      this.norm_max = Double.NaN;
      this.min = Double.NaN;
      this.max = Double.NaN;
    }

    /** Creates a fully specified variable. */
    public NetworkVariable(String name, String type, boolean norm, double norm_min, double norm_max, double min, double max) {
      this.name = name;
      this.type = type;
      this.norm = norm;
      this.norm_min = norm_min;
      this.norm_max = norm_max;
      this.min = min;
      this.max = max;
    }
  }
}