package m.ann.run;
import utils.MongoAccess;
import com.google.common.math.Quantiles;
import csip.ModelDataService;
import csip.ServiceException;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import javax.ws.rs.Path;
import m.ann.run.tree.Tree;
import oms3.annotations.Description;
import oms3.annotations.Name;
import org.apache.commons.io.FileUtils;
import org.bson.Document;
import utils.Calc;
import utils.CSIPMLData;
import utils.Metadata;
import utils.MongoUtils.Sorting;
import utils.SurrogateModel;
/**
 * CSIP service that runs an ANN surrogate model over a tree of vertices,
 * propagating each vertex's outputs into its parent's inputs in post-order.
 *
 * @author sidereus
 */
@Name("Run the rusler ANN prototype")
@Description("Runs the network on the provided input data (rusler prototype version).")
@Path("m/run/2.0")
public class V2_0 extends ModelDataService {
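/**
 * Walks the tree in post-order: leaves are computed directly from the
 * normalized input file, while each internal vertex first receives the
 * element-wise sum of its children's outputs in its "in_" slots.
 */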
@Override
public void doProcess() throws ServiceException, IOException {
String annName = parameter().getString("annName");
Tree tree = new Tree(parseTreeStructure());
List<Integer> computationalOrder = tree.postOrder();
Iterable<Document> d = MongoAccess.getSortedNormalizedData(annName, MongoAccess.NORMALIZED, MongoAccess.NAME, Sorting.ASCENDING);
Map<String, Metadata> in = new LinkedHashMap<>();
Map<String, Metadata> out = new LinkedHashMap<>();
Map<String, int[]> inOutMap = preProcess(d, in, out);
Map<Integer, Data> totalData = parseInputData(in);
SurrogateModel sm = new SurrogateModel(annName);
long startTime = System.nanoTime();
for (Integer vertex : computationalOrder) {
Data data = totalData.get(vertex);
if (tree.isLeaf(vertex)) {
data.setInputZero();
} else {
List<List<Double>> totOutputs = new ArrayList<>();
Set<Integer> children = tree.getChildren(vertex);
for (Integer child : children) {
totOutputs.add(totalData.get(child).output);
}
List<Double> res = math(totOutputs);
data.setInput(res, inOutMap, vertex);
}
List<double[]> result = sm.compute(new CSIPMLData(data.input));
data.put(outputProcess(out, result));
totalData.replace(vertex, data);
}
long elapsedTime = System.nanoTime() - startTime;
System.out.println("Estimated time (ns): " + elapsedTime);
postProcess(out, totalData, elapsedTime);
}
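/**
 * Element-wise sum of the children's output vectors: entry i of the result
 * is the sum of entry i over all child outputs.
 */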
private List<Double> math(List<List<Double>> totOutputs) {
List<Double> mathRes = new ArrayList<>();
int nOut = totOutputs.get(0).size();
for (int i = 0; i < nOut; i++) {
double sum = 0.0;
for (List<Double> outChild : totOutputs) {
sum += outChild.get(i);
}
mathRes.add(sum);
}
return mathRes;
}
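/**
 * Splits the per-variable metadata into input and output maps and builds the
 * in/out index mapping: an input named "in_<var>" is fed by the output
 * variable "<var>" of the child vertices.
 */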
private Map<String, int[]> preProcess(Iterable<Document> d, Map<String, Metadata> in, Map<String, Metadata> out) throws ServiceException {
int inCount = 0;
Map<String, int[]> inOutMapping = new LinkedHashMap<>();
for (Document doc : d) {
Document meta = doc.get(MongoAccess.METADATA, Document.class);
String type = meta.getString(MongoAccess.TYPE);
String varName = meta.getString(MongoAccess.NAME);
double max = meta.getDouble("max");
double min = meta.getDouble("min");
boolean normal = meta.getBoolean(MongoAccess.NORM);
double norm_max = meta.getDouble(MongoAccess.NORM_MAX);
double norm_min = meta.getDouble(MongoAccess.NORM_MIN);
if (type.equals(MongoAccess.IN)) { // IN is assumed to be a MongoAccess type constant, like TYPE/NAME/NORM above
if (varName.startsWith("in_")) {
String tmpVar = varName.substring(3);
inOutMapping.put(tmpVar, new int[]{inCount, 0});
}
in.put(varName, new Metadata(normal, min, max, norm_min, norm_max));
inCount++;
} else {
out.put(varName, new Metadata(normal, min, max, norm_min, norm_max));
}
}
// second pass: resolve each mapped "in_" variable to the positional
// index of its matching output variable
int outCount = 0;
for (Document doc : d) {
Document meta = doc.get(MongoAccess.METADATA, Document.class);
String type = meta.getString(MongoAccess.TYPE);
String varName = meta.getString(MongoAccess.NAME);
if (!type.equals(MongoAccess.IN)) {
if (inOutMapping.containsKey(varName)) {
int[] map = inOutMapping.get(varName);
map[1] = outCount;
inOutMapping.replace(varName, map);
}
outCount++;
}
}
return inOutMapping;
}
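/**
 * Writes the elapsed time, a CSV header, and one output row per vertex
 * (denormalized where required) to tmp.csv in the workspace, then registers
 * the file as a result.
 */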
private void postProcess(Map<String, Metadata> outputMetadata, Map<Integer, Data> totalData, long elapsedTime) throws IOException {
File f = new File(getWorkspaceDir(), "tmp.csv");
StringBuilder header = new StringBuilder("id");
for (String name : outputMetadata.keySet()) {
header.append(",").append(name);
}
// FileUtils.writeLines appends a line separator per entry, so the
// entries themselves must not carry a trailing newline
List<String> lines = new ArrayList<>();
lines.add(Long.toString(elapsedTime));
lines.add(header.toString());
for (Map.Entry<Integer, Data> dataentry : totalData.entrySet()) {
int index = 0;
StringBuilder row = new StringBuilder(dataentry.getKey().toString());
List<Double> output = dataentry.getValue().output;
for (Map.Entry<String, Metadata> entry : outputMetadata.entrySet()) {
Metadata metadata = entry.getValue();
double min = metadata.getMin();
double max = metadata.getMax();
boolean norm = metadata.getNorm();
double norm_min = metadata.getNormMin();
double norm_max = metadata.getNormMax();
double val = norm ? output.get(index) : Calc.denormalize(output.get(index), min, max, norm_min, norm_max);
row.append(",").append(val);
index++;
}
lines.add(row.toString());
}
FileUtils.writeLines(f, lines);
results().put(f);
}
/**
 * Reduces the ensemble of network outputs to one value per output variable
 * by taking the median (second quartile) across ensemble members.
 * TODO: improve algorithm efficiency.
 *
 * @param outputMetadata ordered metadata of the output variables
 * @param output one double[] per ensemble member, indexed by output variable
 * @return the per-variable medians, in metadata order
 */
private List<Double> outputProcess(Map<String, Metadata> outputMetadata, List<double[]> output) {
int index = 0;
List<Double> results = new ArrayList<>();
for (Map.Entry<String, Metadata> entry : outputMetadata.entrySet()) {
double[] dataset = new double[output.size()];
int nnindex = 0;
for (double[] out : output) {
dataset[nnindex] = out[index];
nnindex++;
}
// quartiles across the ensemble (indexes 0..4 = min, Q1, median, Q3, max)
Map<Integer, Double> uncertResult = Quantiles.quartiles().indexes(0, 1, 2, 3, 4).compute(dataset);
results.add(uncertResult.get(2)); // keep the median
index++;
}
return results;
}
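/**
 * Reads the first attachment whose name ends in "data.csv": the first column
 * must be the "pixel" id, followed by one column per variable, e.g. (variable
 * names illustrative):
 *
 * <pre>
 * pixel,slope,aspect
 * 1,0.35,120.0
 * </pre>
 *
 * Values are normalized with the variable's metadata; input variables with no
 * matching column become null placeholders, filled later from child outputs.
 */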
private Map<Integer, Data> parseInputData(Map<String, Metadata> in) throws FileNotFoundException, IOException, ServiceException {
Map<Integer, Data> inputData = new ConcurrentHashMap<>();
// use the first matching file
File file = null;
for (File tmpfile : attachments().getFiles()) {
if (tmpfile.getName().endsWith("data.csv")) {
file = tmpfile;
break;
}
}
if (file == null) {
throw new FileNotFoundException("A *data.csv attachment is required");
}
try (BufferedReader r = new BufferedReader(new FileReader(file))) {
Pattern p = Pattern.compile("\\s*,\\s*");
String[] names = p.split(r.readLine());
String[] actualNames = new String[names.length - 1];
int namecount = 0;
for (int i = 0; i < names.length; i++) {
String name = names[i];
if (!name.equals("pixel")) {
actualNames[namecount] = name;
namecount++;
}
}
int[] headerIndices = getHeaderIndices(in, actualNames);
String line = null;
while ((line = r.readLine()) != null) {
String[] row = p.split(line);
Integer id = Integer.parseInt(row[0]);
List<Double> data = new ArrayList<>();
for (int i = 0; i < headerIndices.length; i++) {
int index = headerIndices[i];
if (index == 0) {
throw new UnsupportedOperationException("Header index 0 points at the pixel id column; the header mapping is broken");
}
if (index == -99) {
data.add(null);
} else {
String name = names[index];
Metadata m = in.get(name);
double val = Calc.normalize(Double.parseDouble(row[index]), m.getMin(), m.getMax(), m.getNormMin(), m.getNormMax());
data.add(val);
}
}
inputData.put(id, new Data(data));
}
}
return inputData;
}
/**
 * For each input variable, finds the 1-based column index of the matching
 * header name (column 0 is the pixel id), or -99 when the variable has no
 * column in the file.
 */
private int[] getHeaderIndices(Map<String, Metadata> in, String[] names) {
int colCount = 1; // 1-based, so the pixel id column is skipped
int indCount = 0;
boolean found = false;
int[] indices = new int[in.size()];
for (String var : in.keySet()) {
for (String name : names) {
if (var.equalsIgnoreCase(name)) {
indices[indCount] = colCount;
indCount++;
found = true;
break;
}
colCount++;
}
if (!found) {
indices[indCount] = -99;
indCount++;
} else {
found = false;
}
colCount = 1;
}
return indices;
}
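/**
 * Per-vertex input/output pair. The input vector starts with null
 * placeholders for the "in_" variables, which are filled with child outputs
 * (or zeros, for leaves) before the surrogate model is run.
 */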
private static class Data {
List<Double> input;
List<Double> output;
public Data(List<Double> input) {
this.input = input;
}
public void setInputZero() {
// leaves have no children: replace the null placeholders with 0.0
input = input.stream().map(v -> v == null ? 0.0 : v).collect(Collectors.toList());
}
public void setInput(List<Double> out, Map<String, int[]> inOutMap, int vertex) {
for (Map.Entry<String, int[]> entry : inOutMap.entrySet()) {
int inIndex = entry.getValue()[0];
int outIndex = entry.getValue()[1];
// only a null placeholder may be overwritten by a child output
if (input.get(inIndex) != null) {
throw new IllegalArgumentException("Input slot " + inIndex + " of vertex " + vertex + " is already set; the in/out mapping is inconsistent");
}
input.set(inIndex, out.get(outIndex));
}
}
public void put(List<Double> results) {
output = results;
}
}
/**
 * Parses a CSV attachment named *tree.csv, structured as one
 * {@code child,parent} pair per line after a header row, e.g.:
 *
 * <pre>
 * child,parent
 * 2,1
 * 3,1
 * </pre>
 *
 * @return the list of {child, parent} pairs
 * @throws FileNotFoundException if no attachment is provided
 * @throws IOException on read errors
 */
private List<int[]> parseTreeStructure() throws FileNotFoundException, IOException {
List<int[]> treeStructure = new ArrayList<>();
if (attachments().getFilesCount() > 0) {
// use the first file.
for (File file : attachments().getFiles()) {
if (file.getName().endsWith("tree.csv")) {
try (BufferedReader r = new BufferedReader(new FileReader(file))) {
Pattern p = Pattern.compile("\\s*,\\s*");
String line = null;
r.readLine(); // skip header
while ((line = r.readLine()) != null) {
String[] row = p.split(line);
if (row.length != 2) {
String msg = "Child, parent structure required";
throw new UnsupportedOperationException(msg);
}
int[] nodes = new int[2];
// file has to be structured child,parent
nodes[0] = Integer.parseInt(row[0]);
nodes[1] = Integer.parseInt(row[1]);
treeStructure.add(nodes);
}
}
break;
}
}
} else {
String msg = "Tree structure required: attach a *tree.csv file";
throw new FileNotFoundException(msg);
}
return treeStructure;
}
// ad-hoc smoke test for the null-to-zero mapping used in Data.setInputZero()
public static void main(String[] args) {
List<Double> t = new ArrayList<>();
t.add(1.0);
t.add(5.0);
t.add(null);
t = t.stream().map(v -> v == null ? 0.0 : v).collect(Collectors.toList());
System.out.println(t);
}
}