// CSIPMLDataSet.java [src/java/utils] Revision: default Date:
package utils;
import static csip.ModelDataService.IN;
import java.util.ArrayList;
import org.encog.EncogError;
import org.encog.ml.data.MLData;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLData;
import org.encog.ml.data.basic.BasicMLDataPair;
import org.encog.util.EngineArray;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import org.bson.Document;
/**
 * Read-only, column-oriented {@link MLDataSet}.
 *
 * <p>The backing store is a list of columns ({@code List<List<Number>>}): the
 * first {@link #getInputSize()} columns are inputs and the remaining
 * {@link #getIdealSize()} columns are ideal (output) values. A record is formed
 * by reading the same row position from every column.
 *
 * <p>The {@code indices} array selects which rows belong to this data set and
 * in what order: logical record {@code i} is row {@code indices[i]} of every
 * column. The number of records is therefore {@code indices.length}.
 *
 * <p>The {@code add(...)} methods and {@link #close()} are no-ops;
 * {@link #openAdditional()} returns {@code null}.
 */
public class CSIPMLDataSet implements MLDataSet {

    /**
     * Iterator over the logical records of the enclosing data set, in the
     * order given by {@code indices}.
     */
    public class CSIPMLDataSetIterator implements Iterator<MLDataPair> {

        /** Logical index of the next record to hand out. */
        private int currentIndex = 0;

        @Override
        public boolean hasNext() {
            return currentIndex < CSIPMLDataSet.this.indices.length;
        }

        /**
         * Returns the next record, or {@code null} when the iteration is
         * exhausted (no exception is thrown in that case).
         */
        @Override
        public MLDataPair next() {
            if (!hasNext()) {
                return null;
            }
            return CSIPMLDataSet.this.get(currentIndex++);
        }

        /**
         * Not supported.
         *
         * @throws EncogError always
         */
        @Override
        public void remove() {
            throw new EncogError("Called remove, unsupported operation.");
        }
    }

    /** The number of input columns; they occupy the front of {@link #data}. */
    private final int inputCounter;

    /** Columns of values: inputs first, ideal (output) columns after them. */
    private final List<List<Number>> data;

    /** Maps a logical record index to a row position inside each column. */
    private final int[] indices;

    /**
     * Creates a data set over already prepared columns. The arguments are
     * stored as given (no copies are made).
     *
     * @param data         the columns, input columns first
     * @param indices      row positions that make up this data set
     * @param inputCounter how many leading columns of {@code data} are inputs
     */
    public CSIPMLDataSet(List<List<Number>> data, int[] indices, int inputCounter) {
        this.data = data;
        this.indices = indices;
        this.inputCounter = inputCounter;
    }

    /**
     * Creates a data set from documents, restricted to the given rows.
     *
     * <p>Each document contributes one column. Documents whose metadata type
     * equals {@code IN} become input columns (kept in encounter order at the
     * front); all other documents are appended as ideal columns.
     *
     * @param d       the documents to read columns from
     * @param indices row positions that make up this data set (stored as given)
     */
    public CSIPMLDataSet(Iterable<Document> d, int[] indices) {
        LinkedList<List<Number>> columns = new LinkedList<>();
        int inputs = 0;
        for (Document document : d) {
            List<Number> vals = valuesOf(document);
            if (isInput(document)) {
                // Insert right after the inputs collected so far, ahead of any outputs.
                columns.add(inputs, vals);
                inputs++;
            } else {
                columns.addLast(vals);
            }
        }
        this.data = columns;
        this.indices = indices;
        this.inputCounter = inputs;
    }

    /**
     * Creates a data set from documents that covers every row.
     *
     * <p>Columns are classified as in
     * {@link #CSIPMLDataSet(Iterable, int[])}. The row count is taken from the
     * size of the last document's value list, and {@code indices} becomes the
     * identity mapping {@code 0..rowCount-1}.
     * NOTE(review): this assumes all documents carry value lists of the same
     * length; that is not verified here.
     *
     * @param d the documents to read columns from
     * @throws IllegalArgumentException if there are no documents or the last
     *                                  document's value list is empty
     */
    public CSIPMLDataSet(Iterable<Document> d) {
        LinkedList<List<Number>> columns = new LinkedList<>();
        int inputs = 0;
        int totalVals = 0;
        for (Document document : d) {
            List<Number> vals = valuesOf(document);
            totalVals = vals.size();
            if (isInput(document)) {
                // Insert right after the inputs collected so far, ahead of any outputs.
                columns.add(inputs, vals);
                inputs++;
            } else {
                columns.addLast(vals);
            }
        }
        if (totalVals == 0) {
            String msg = "No values available. Check the data base";
            throw new IllegalArgumentException(msg);
        }
        this.data = columns;
        this.inputCounter = inputs;
        this.indices = new int[totalVals];
        for (int i = 0; i < totalVals; i++) {
            indices[i] = i;
        }
    }

    /**
     * Reads the value list of a document.
     *
     * <p>The document API only offers the raw {@code List} class token, so the
     * element type cannot be checked here; a non-numeric element would only
     * surface later, when a row is read.
     */
    @SuppressWarnings("unchecked")
    private static List<Number> valuesOf(Document document) {
        return document.get(MongoAccess.VALUES, List.class);
    }

    /** Tells whether the document's metadata type marks it as an input column. */
    private static boolean isInput(Document document) {
        String type = document.get(MongoAccess.METADATA, Document.class).getString(MongoAccess.TYPE);
        return type.equals(IN);
    }

    /** @return the number of ideal (output) columns */
    @Override
    public int getIdealSize() {
        return data.size() - inputCounter;
    }

    /** @return the number of input columns */
    @Override
    public int getInputSize() {
        return inputCounter;
    }

    /**
     * @return {@code true} when the data set has at least one ideal column
     */
    @Override
    public boolean isSupervised() {
        return getIdealSize() > 0;
    }

    /**
     * @return the number of records, i.e. the length of the index mapping
     * @throws EncogError if no column data is present
     */
    @Override
    public long getRecordCount() {
        if (data == null) {
            throw new EncogError("You must normalize the dataset before using it.");
        }
        return indices.length;
    }

    /**
     * Copies one record into the supplied pair.
     *
     * @param index logical record index in {@code [0, getRecordCount())}; it is
     *              translated to a row position through the index mapping so
     *              that direct callers see the same records as {@link #get(int)}
     *              and the iterator
     * @param pair  destination; its input and ideal arrays are overwritten
     * @throws EncogError if no column data is present
     */
    @Override
    public void getRecord(long index, MLDataPair pair) {
        if (data == null) {
            throw new EncogError("You must normalize the dataset before using it.");
        }
        double[] dataRow = lookupDataRow(indices[(int) index]);
        // Leading part of the row holds the inputs ...
        EngineArray.arrayCopy(dataRow, 0, pair.getInput().getData(), 0, getInputSize());
        // ... and the remainder holds the ideal values.
        EngineArray.arrayCopy(dataRow, getInputSize(), pair.getIdeal().getData(), 0, getIdealSize());
    }

    /**
     * Gathers the values at one row position from every column.
     *
     * @param index row position inside each column
     * @return a new array with one value per column, in column order
     */
    private double[] lookupDataRow(int index) {
        double[] dataRow = new double[data.size()];
        int localCounter = 0;
        for (List<Number> val : data) {
            dataRow[localCounter] = val.get(index).doubleValue();
            localCounter++;
        }
        return dataRow;
    }

    /** @return always {@code null} */
    @Override
    public MLDataSet openAdditional() {
        return null;
    }

    /** No-op: the argument is ignored. */
    @Override
    public void add(MLData data1) {
    }

    /** No-op: the arguments are ignored. */
    @Override
    public void add(MLData inputData, MLData idealData) {
    }

    /** No-op: the argument is ignored. */
    @Override
    public void add(MLDataPair inputData) {
    }

    /** No-op. */
    @Override
    public void close() {
    }

    /** @return the record count as an {@code int} */
    @Override
    public int size() {
        return (int) getRecordCount();
    }

    /**
     * Returns a freshly allocated pair holding the requested record.
     *
     * @param index logical record index
     * @return the record, or {@code null} if {@code index} is outside
     *         {@code [0, size())}
     */
    @Override
    public MLDataPair get(int index) {
        if (index < 0 || index >= size()) {
            return null;
        }
        BasicMLData input = new BasicMLData(getInputSize());
        BasicMLData ideal = new BasicMLData(getIdealSize());
        MLDataPair pair = new BasicMLDataPair(input, ideal);
        // getRecord performs the logical-index -> row translation.
        getRecord(index, pair);
        return pair;
    }

    /** @return a new iterator positioned before the first record */
    @Override
    public Iterator<MLDataPair> iterator() {
        return new CSIPMLDataSetIterator();
    }
}