CSIPMLDataSet.java [src/java/utils] Revision: default  Date:
package utils;

import static csip.ModelDataService.IN;
import java.util.ArrayList;
import org.encog.EncogError;
import org.encog.ml.data.MLData;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLData;
import org.encog.ml.data.basic.BasicMLDataPair;
import org.encog.util.EngineArray;

import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import org.bson.Document;

/**
 * Read-only Encog {@link MLDataSet} backed by column-oriented series.
 * <p>
 * Each element of {@code data} is one column (a series of values). The first
 * {@code inputCounter} columns are the network inputs. The remaining columns are the
 * ideal (target) outputs. A record is assembled by reading the same row position from
 * every column. The {@code indices} array selects which row positions belong to this
 * data set and in which order: record {@code i} of this data set is row
 * {@code indices[i]} of the underlying series.
 * <p>
 * The {@code add(...)} methods are no-ops, so the data set cannot be modified through
 * this interface.
 */
public class CSIPMLDataSet implements MLDataSet {

    /**
     * Iterates over the records of the enclosing data set in record order
     * (i.e. in the order given by {@code indices}).
     */
    public class CSIPMLDataSetIterator implements Iterator<MLDataPair> {

        /** Record position (into {@code indices}) of the next record to return. */
        private int currentIndex = 0;


        @Override
        public boolean hasNext() {
            return currentIndex < CSIPMLDataSet.this.indices.length;
        }


        /**
         * Returns the next record.
         *
         * @return the next record, or {@code null} once all records have been returned
         *     (this iterator does not throw {@code NoSuchElementException})
         */
        @Override
        public MLDataPair next() {
            if (!hasNext()) {
                return null;
            }
            return CSIPMLDataSet.this.get(currentIndex++);
        }


        /**
         * Removal is not supported.
         *
         * @throws EncogError always
         */
        @Override
        public void remove() {
            throw new EncogError("Called remove, unsupported operation.");
        }
    }

    /**
     * The number of input columns. Input columns occupy positions
     * {@code 0 .. inputCounter - 1} of {@code data}.
     */
    private int inputCounter = 0;

    /** Column-oriented values: input columns first, then ideal (output) columns. */
    private final List<List<Number>> data;

    /** Underlying row positions that make up this data set, in record order. */
    private final int[] indices;


    /**
     * Creates a data set view over already-loaded columns.
     *
     * @param data columns of values, with the input columns first
     * @param indices underlying row positions exposed by this data set, in record order
     * @param inputCounter number of leading input columns in {@code data}
     */
    public CSIPMLDataSet(List<List<Number>> data, int[] indices, int inputCounter) {
        this.data = data;
        this.indices = indices;
        this.inputCounter = inputCounter;
    }


    /**
     * Loads one column per document and exposes only the rows listed in
     * {@code indices}.
     * <p>
     * A document whose metadata type equals {@code IN} becomes an input column. Input
     * columns are kept in encounter order, ahead of all other columns. Every other
     * document is appended as an ideal column.
     *
     * @param d documents carrying the values list and the metadata type
     * @param indices underlying row positions exposed by this data set, in record order
     */
    // The BSON getter yields a raw List; its elements are assumed to be Numbers.
    @SuppressWarnings("unchecked")
    public CSIPMLDataSet(Iterable<Document> d, int[] indices) {
        LinkedList<List<Number>> data = new LinkedList<>();
        for (Document document : d) {
            List<Number> vals = document.get(MongoAccess.VALUES, List.class);
            String type = document.get(MongoAccess.METADATA, Document.class).getString(MongoAccess.TYPE);
            if (type.equals(IN)) {
                // Insert after the previously seen inputs, before any outputs.
                data.add(inputCounter, vals);
                inputCounter++;
            } else {
                // no downcasting to avoid execution overhead
                data.addLast(vals);
            }
        }
        // upcasting
        this.data = data;
        this.indices = indices;
    }


    /**
     * Loads one column per document and exposes every row.
     * <p>
     * Columns are classified as in
     * {@link #CSIPMLDataSet(Iterable, int[])}. The row count is taken from the size of
     * the values list of the last document read.
     *
     * @param d documents carrying the values list and the metadata type
     * @throws IllegalArgumentException if there are no documents, or the last
     *     document's values list is empty
     */
    // The BSON getter yields a raw List; its elements are assumed to be Numbers.
    @SuppressWarnings("unchecked")
    public CSIPMLDataSet(Iterable<Document> d) {
        LinkedList<List<Number>> data = new LinkedList<>();
        int totalVals = 0;
        for (Document document : d) {
            List<Number> vals = document.get(MongoAccess.VALUES, List.class);
            totalVals = vals.size();
            String type = document.get(MongoAccess.METADATA, Document.class).getString(MongoAccess.TYPE);
            if (type.equals(IN)) {
                // Insert after the previously seen inputs, before any outputs.
                data.add(inputCounter, vals);
                inputCounter++;
            } else {
                // no downcasting to avoid execution overhead
                data.addLast(vals);
            }
        }
        if (totalVals == 0) {
            String msg = "No values available. Check the data base";
            throw new IllegalArgumentException(msg);
        }
        // upcasting
        this.data = data;
        // Identity mapping: record i is row i.
        this.indices = new int[totalVals];
        for (int i = 0; i < totalVals; i++) {
            indices[i] = i;
        }
    }


    /** @return the number of ideal (output) columns */
    @Override
    public int getIdealSize() {
        return data.size() - inputCounter;
    }


    /** @return the number of input columns */
    @Override
    public int getInputSize() {
        return inputCounter;
    }


    /**
     * Reports whether this is a supervised data set.
     *
     * @return {@code true} when the records carry ideal (output) values
     */
    @Override
    public boolean isSupervised() {
        // Supervised means "has ideal outputs"; the previous check was inverted.
        return getIdealSize() > 0;
    }


    /**
     * @return the number of records, i.e. the number of selected row positions
     * @throws EncogError if no data is attached
     */
    @Override
    public long getRecordCount() {
        if (data == null) {
            throw new EncogError("You must normalize the dataset before using it.");
        }
        return indices.length;
    }


    /**
     * Copies record {@code index} into {@code pair}. Record {@code index} is row
     * {@code indices[index]} of the underlying series.
     *
     * @param index record position in {@code [0, getRecordCount())}
     * @param pair destination. Its input array must hold at least
     *     {@link #getInputSize()} values, and its ideal array at least
     *     {@link #getIdealSize()} values.
     * @throws EncogError if no data is attached
     */
    @Override
    public void getRecord(long index, MLDataPair pair) {
        if (data == null) {
            throw new EncogError("You must normalize the dataset before using it.");
        }

        // Map the record position to the underlying row. Without this mapping,
        // getRecord, get and getRecordCount would disagree on what record i is.
        double[] dataRow = lookupDataRow(indices[(int) index]);

        // Copy the input
        EngineArray.arrayCopy(dataRow, 0, pair.getInput().getData(), 0, getInputSize());

        // Copy the output
        EngineArray.arrayCopy(dataRow, getInputSize(), pair.getIdeal().getData(), 0, getIdealSize());
    }


    /**
     * Builds one full row (inputs followed by outputs) from the given underlying row
     * position.
     *
     * @param row underlying row position, not a record position
     * @return a new array holding the value at {@code row} from every column
     */
    private double[] lookupDataRow(int row) {
        double[] dataRow = new double[data.size()];
        int column = 0;
        for (List<Number> series : data) {
            dataRow[column++] = series.get(row).doubleValue();
        }
        return dataRow;
    }


    /**
     * Opens an additional instance of this data set. The data set is read-only, so
     * the new instance shares the same columns and row selection.
     *
     * @return a new view over the same data
     */
    @Override
    public MLDataSet openAdditional() {
        return new CSIPMLDataSet(data, indices, inputCounter);
    }


    /** Not supported: this data set is read-only, so the call is ignored. */
    @Override
    public void add(MLData data1) {
        // intentionally ignored (read-only data set)
    }


    /** Not supported: this data set is read-only, so the call is ignored. */
    @Override
    public void add(MLData inputData, MLData idealData) {
        // intentionally ignored (read-only data set)
    }


    /** Not supported: this data set is read-only, so the call is ignored. */
    @Override
    public void add(MLDataPair inputData) {
        // intentionally ignored (read-only data set)
    }


    /** No resources are held, so there is nothing to release. */
    @Override
    public void close() {
        // nothing to close
    }


    /** @return the number of records */
    @Override
    public int size() {
        return (int) getRecordCount();
    }


    /**
     * Returns record {@code index} as a freshly allocated pair.
     *
     * @param index record position
     * @return the record, or {@code null} if {@code index} is outside
     *     {@code [0, size())}
     */
    @Override
    public MLDataPair get(int index) {
        // The previous check (index > size()) let index == size() and negative
        // indices through to an ArrayIndexOutOfBoundsException.
        if (index < 0 || index >= size()) {
            return null;
        }
        BasicMLData input = new BasicMLData(getInputSize());
        BasicMLData ideal = new BasicMLData(getIdealSize());
        MLDataPair pair = new BasicMLDataPair(input, ideal);

        // getRecord performs the record-to-row mapping through indices.
        getRecord(index, pair);

        return pair;
    }


    /** @return an iterator over all records in record order */
    @Override
    public Iterator<MLDataPair> iterator() {
        return new CSIPMLDataSetIterator();
    }
}