V2_0.java [src/java/m/ann/run] Revision: default  Date:
/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package m.ann.run;

import utils.MongoAccess;
import com.google.common.math.Quantiles;
import csip.ModelDataService;
import csip.ServiceException;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import javax.ws.rs.Path;
import m.ann.run.tree.Tree;
import oms3.annotations.Description;
import oms3.annotations.Name;
import org.apache.commons.io.FileUtils;
import org.bson.Document;
import utils.Calc;
import utils.CSIPMLData;
import utils.Metadata;
import utils.MongoUtils.Sorting;
import utils.SurrogateModel;

/**
 *
 * @author sidereus
 */
@Name("Run the rusler ANN prototype")
@Description("Run the network with provided input data, a rusler prototype version.")
@Path("m/run/2.0")
public class V2_0 extends ModelDataService {

    @Override
    public void doProcess() throws ServiceException, IOException {
        String annName = parameter().getString("annName");

        Tree tree = new Tree(parseTreeStructure());
        List<Integer> computationalOrder = tree.postOrder();

        Iterable<Document> d = MongoAccess.getSortedNormalizedData(annName, MongoAccess.NORMALIZED, MongoAccess.NAME, Sorting.ASCENDING);

        Map<String, Metadata> in = new LinkedHashMap<>();
        Map<String, Metadata> out = new LinkedHashMap<>();
        Map<String, int[]> inOutMap = preProcess(d, in, out);
        Map<Integer, Data> totalData = parseInputData(in);

        SurrogateModel sm = new SurrogateModel(annName);

        Long startTime = System.nanoTime();
        for (Integer vertex : computationalOrder) {
            Data data = totalData.get(vertex);
            if (tree.isLeaf(vertex)) {
                data.setInputZero();
            } else {
                List<List<Double>> totOutputs = new ArrayList<>();
                Set<Integer> children = tree.getChildren(vertex);
                for (Integer child : children) {
                    totOutputs.add(totalData.get(child).output);
                }

                List<Double> res = math(totOutputs);
                data.setInput(res, inOutMap, vertex);

            }
            List<double[]> result = sm.compute(new CSIPMLData(data.input));
            data.put(outputProcess(out, result));
            totalData.replace(vertex, data);
        }

        Long endTime = System.nanoTime() - startTime;
        System.out.println("Estimated time: " + endTime);
        postProcess(out, totalData, endTime);

    }

    private List<Double> math(List<List<Double>> totOutputs) {
        List<Double> mathRes = new ArrayList<>();
        int nOut = totOutputs.get(0).size();
        for (int i = 0; i < nOut; i++) {
            Double sum = 0.0;
            for (List<Double> outChild : totOutputs) {
                sum += outChild.get(i);
            }
            mathRes.add(sum);
        }
        return mathRes;
    }

    private Map<String, int[]> preProcess(Iterable<Document> d, Map<String, Metadata> in, Map<String, Metadata> out) throws ServiceException {
        int inCount = 0;
        Map<String, int[]> inOutMapping = new LinkedHashMap<>();
        for (Document doc : d) {
            Document meta = doc.get(MongoAccess.METADATA, Document.class);
            String type = meta.getString(MongoAccess.TYPE);
            String varName = meta.getString(MongoAccess.NAME);
            double max = meta.getDouble("max");
            double min = meta.getDouble("min");
            boolean normal = meta.getBoolean(MongoAccess.NORM);
            double norm_max = meta.getDouble(MongoAccess.NORM_MAX);
            double norm_min = meta.getDouble(MongoAccess.NORM_MIN);
            if (type.equals(IN)) {
                if (varName.startsWith("in_")) {
                    String tmpVar = varName.substring(3);
                    inOutMapping.put(tmpVar, new int[]{inCount, 0});
                }
                in.put(varName, new Metadata(normal, min, max, norm_min, norm_max));
                inCount++;
            } else {
                out.put(varName, new Metadata(normal, min, max, norm_min, norm_max));
            }
        }

        int outCount = 0;
        for (Document doc : d) {
            Document meta = doc.get(MongoAccess.METADATA, Document.class);
            String type = meta.getString(MongoAccess.TYPE);
            String varName = meta.getString(MongoAccess.NAME);
            if (!type.equals(IN)) {
                if (inOutMapping.containsKey(varName)) {
                    int[] map = inOutMapping.get(varName);
                    map[1] = outCount;
                    inOutMapping.replace(varName, map);
                }
                outCount++;
            }
        }
        return inOutMapping;
    }

    private void postProcess(Map<String, Metadata> outputMetadata, Map<Integer, Data> totalData, Long endTime) throws IOException {

        File f = new File(getWorkspaceDir(), "tmp.csv");
        String header = "id";
        for (Map.Entry<String, Metadata> entry : outputMetadata.entrySet()) {
            header += "," + entry.getKey();
        }
        header += "\n";
        //FileUtils.writeStringToFile(f,endTime.toString());
        //FileUtils.writeStringToFile(f, header);

        List<String> lines = new ArrayList<>();
        lines.add(endTime.toString());
        lines.add(header);
        for (Map.Entry<Integer, Data> dataentry : totalData.entrySet()) {
            int index = 0;
            String row = dataentry.getKey().toString();
            List<Double> output = dataentry.getValue().output;
            for (Map.Entry<String, Metadata> entry : outputMetadata.entrySet()) {
                String name = entry.getKey();
                Metadata metadata = entry.getValue();
                double min = metadata.getMin();
                double max = metadata.getMax();
                boolean norm = metadata.getNorm();
                double norm_min = metadata.getNormMin();
                double norm_max = metadata.getNormMax();

                Double val = norm ? output.get(index) : Calc.denormalize(output.get(index), min, max, norm_min, norm_max);
                row += "," + val.toString();
                index++;
            }
            row += "\n";
            lines.add(row);
        }
        FileUtils.writeLines(f, lines);
        results().put(f);
    }

    /**
     * @TODO: improve algorithm efficency
     * @param outputMetadata
     * @param output
     */
    private List<Double> outputProcess(Map<String, Metadata> outputMetadata, List<double[]> output) {
        int index = 0;
        List<Double> results = new ArrayList<>();
        for (Map.Entry<String, Metadata> entry : outputMetadata.entrySet()) {

            double[] dataset = new double[output.size()];
            int nnindex = 0;
            for (double[] out : output) {
                dataset[nnindex] = out[index];
                nnindex++;
            }

            Map<Integer, Double> uncertResult = Quantiles.quartiles().indexes(0, 1, 2, 3, 4).compute(dataset);
            results.add(uncertResult.get(2));
            index++;
        }
        return results;
    }

    private Map<Integer, Data> parseInputData(Map<String, Metadata> in) throws FileNotFoundException, IOException, ServiceException {
        Map<Integer, Data> inputData = new ConcurrentHashMap<>();
        // use the first file.
        File file = null;
        for (File tmpfile : attachments().getFiles()) {
                if (tmpfile.getName().endsWith("data.csv")) {
                    file = tmpfile;
                }
        }
        try (BufferedReader r = new BufferedReader(new FileReader(file))) {
            Pattern p = Pattern.compile("\\s*,\\s*");

            String[] names = p.split(r.readLine());
            String[] actualNames = new String[names.length - 1];
            int namecount = 0;
            for (int i = 0; i < names.length; i++) {
                String name = names[i];
                if (!name.equals("pixel")) {
                    actualNames[namecount] = name;
                    namecount++;
                }
            }
            int[] headerIndices = getHeaderIndices(in, actualNames);
            String line = null;
            while ((line = r.readLine()) != null) {
                String[] row = p.split(line);
                Integer id = Integer.parseInt(row[0]);
                List<Double> data = new ArrayList<>();
                for (int i = 0; i < headerIndices.length; i++) {
                    int index = headerIndices[i];
                    if (index == 0) {
                        throw new UnsupportedOperationException("ZERO index. That is not possible");
                    }
                    if (index == -99) {
                        data.add(null);
                    } else {
                        String name = names[index];
                        Metadata m = in.get(name);
                        double val = Calc.normalize(Double.parseDouble(row[index]), m.getMin(), m.getMax(), m.getNormMin(), m.getNormMax());
                        data.add(val);
                    }
                }
                inputData.put(id, new Data(data));
            }
        }

        return inputData;
    }

    // check algorithm
    private int[] getHeaderIndices(Map<String, Metadata> in, String[] names) {
        int colCount = 1; // need to not account for pixel column
        int indCount = 0;
        boolean found = false;
        int[] indices = new int[in.keySet().size()];
        for (String var : in.keySet()) {
            for (String name : names) {
                if (var.toLowerCase().equals(name.toLowerCase())) {
                    indices[indCount] = colCount;
                    indCount++;
                    found = true;
                    break;
                }
                colCount++;
            }
            if (!found) {
                indices[indCount] = -99;
                indCount++;
            } else {
                found = false;
            }
            colCount = 1;
        }
        return indices;
    }

    private class Data {

        List<Double> input;
        List<Double> output;

        public Data(List<Double> input) {
            this.input = input;
        }

        public void setInputZero() {
            input = input.stream().map(v -> v == null ? 0.0 : v).collect(Collectors.toList());
        }

        public void setInput(List<Double> out, Map<String, int[]> inOutMap, int vertex) {
            for (Map.Entry<String, int[]> entry : inOutMap.entrySet()) {
                int inIndex = entry.getValue()[0];
                int outIndex = entry.getValue()[1];
                Double val = input.get(inIndex);
                if (input.get(inIndex) != null) {
                    throw new IllegalArgumentException("Something wrong with indices. Vertex " + vertex);
                }
                input.set(inIndex, out.get(outIndex));
            }
        }

        public void put(List<Double> results) {
            output = results;
        }

    }

    /**
     * This method parses a csv file structured as child, parent. It returns a
     * List of couples child, parent.
     *
     * @return
     * @throws FileNotFoundException
     * @throws IOException 
     */
    private List<int[]> parseTreeStructure() throws FileNotFoundException, IOException {
        List<int[]> treeStructure = new ArrayList<>();
        if (attachments().getFilesCount() > 0) { // add a csv to the database
            // use the first file.
            for (File file : attachments().getFiles()) {
                if (file.getName().endsWith("tree.csv")) {
                    try (BufferedReader r = new BufferedReader(new FileReader(file))) {
                        Pattern p = Pattern.compile("\\s*,\\s*");
                        String line = null;
                        r.readLine(); // skip header
                        while ((line = r.readLine()) != null) {
                            String[] row = p.split(line);
                            if (row.length != 2) {
                                String msg = "Child, parent structure required";
                                throw new UnsupportedOperationException(msg);
                            }
                            int[] nodes = new int[2];
                            // file has to be structured child,parent
                            nodes[0] = Integer.parseInt(row[0]);
                            nodes[1] = Integer.parseInt(row[1]);
                            treeStructure.add(nodes);
                        }
                    }
                    break;
                }
            }
        } else { // add a row to the database
            String msg = "Tree structure required";
            throw new NullPointerException(msg);
        }
        return treeStructure;
    }

    public static void main(String[] args) {
        List<Double> t = new ArrayList<>();
        t.add(1.0);
        t.add(5.0);
        t.add(null);

        t = t.stream().map(v -> v == null ? 0.0 : v).collect(Collectors.toList());

        System.out.println(t);
    }
}