SameDistribution.java [src/java/m/ann/training/scale] Revision: default Date:
/*
* To change this license header, choose License Headers in Project Properties.
* To change this template file, choose Tools | Templates
* and open the template in the editor.
*/
package m.ann.training.scale;
import static csip.ModelDataService.OUT;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import org.apache.commons.math3.stat.inference.TestUtils;
import org.bson.Document;
import utils.Data;
import utils.MongoAccess;
/**
 * Scaling mechanism that randomly splits the sample indices of a data set
 * into a training and a validation subset and reshuffles until, for every
 * column, a two-sample Kolmogorov-Smirnov test no longer rejects the split.
 *
 * <p>The min/max indices of the output variables are always forced into the
 * training subset before the distributions are compared.
 *
 * @author sidereus
 */
public class SameDistribution extends ScalingMechanism {

    /** A split is rejected when the KS p-value of any column falls below this level. */
    private static final double SIGNIFICANCE_LEVEL = 0.05;

    /**
     * Upper bound on the number of shuffles tried before giving up, so that a
     * data set for which no split passes the check cannot hang the caller.
     */
    private static final int MAX_SPLIT_ATTEMPTS = 10000;

    @Override
    public String getStrategy() {
        return "SameDistribution";
    }

    /**
     * Splits the samples described by the given documents.
     *
     * @param d one document per variable; each is read for its
     *     {@code MongoAccess.VALUES} list and {@code MongoAccess.METADATA}
     * @param trainingPerc share of the samples assigned to training, passed
     *     on to {@code getTrainingLength}
     * @return the accepted training and validation indices
     * @throws IllegalStateException if no shuffle passes the distribution
     *     check within {@code MAX_SPLIT_ATTEMPTS} attempts
     * @throws IndexOutOfBoundsException if the required indices cannot all be
     *     placed in the training subset
     */
    @Override
    public DataSetIndices compute(Iterable<Document> d, double trainingPerc) {
        int dataLength = getDataLength(d);
        int trainingLength = getTrainingLength(trainingPerc, dataLength);
        List<Integer> range = getRange(dataLength);
        Set<Integer> indicesToHave = getIndicesToHave(d, dataLength);

        for (int attempt = 0; attempt < MAX_SPLIT_ATTEMPTS; attempt++) {
            Collections.shuffle(range);
            // Both sub-lists are views on 'range': the swaps performed below
            // write through to it and keep the two subsets disjoint.
            List<Integer> trainingIndices = range.subList(0, trainingLength);
            List<Integer> validIndices = range.subList(trainingLength, dataLength);
            moveIndicesToHaveToTraining(indicesToHave, trainingIndices, validIndices);
            if (!checkDistribution(d, trainingIndices, validIndices)) {
                return new DataSetIndices(trainingIndices, validIndices);
            }
        }
        throw noAcceptableSplit();
    }

    /**
     * Splits the samples held by a {@code Data} instance.
     *
     * @param d data whose columns ({@code getDataPerCol()}) are compared and
     *     whose {@code getMinIndices()}/{@code getMaxIndices()} must end up in
     *     the training subset
     * @param trainingPerc share of the samples assigned to training, passed
     *     on to {@code getTrainingLength}
     * @return the accepted training and validation indices
     * @throws IllegalStateException if no shuffle passes the distribution
     *     check within {@code MAX_SPLIT_ATTEMPTS} attempts
     * @throws IndexOutOfBoundsException if the required indices cannot all be
     *     placed in the training subset
     */
    public DataSetIndices compute(Data d, double trainingPerc) {
        int dataLength = getDataLength(d);
        int trainingLength = getTrainingLength(trainingPerc, dataLength);
        List<Integer> range = getRange(dataLength);
        Set<Integer> indicesToHave = getIndicesToHave(d, dataLength);

        for (int attempt = 0; attempt < MAX_SPLIT_ATTEMPTS; attempt++) {
            Collections.shuffle(range);
            // Views on 'range', see the Iterable<Document> overload.
            List<Integer> trainingIndices = range.subList(0, trainingLength);
            List<Integer> validIndices = range.subList(trainingLength, dataLength);
            moveIndicesToHaveToTraining(indicesToHave, trainingIndices, validIndices);
            if (!checkDistribution(d, trainingIndices, validIndices)) {
                return new DataSetIndices(trainingIndices, validIndices);
            }
        }
        throw noAcceptableSplit();
    }

    /** Builds the exception raised when every attempted split was rejected. */
    private static IllegalStateException noAcceptableSplit() {
        return new IllegalStateException("No training/validation split passed the "
                + "Kolmogorov-Smirnov check after " + MAX_SPLIT_ATTEMPTS + " shuffles");
    }

    /**
     * Checks every document's value list against the proposed split.
     *
     * @return {@code true} if the data must be resplit, i.e. at least one
     *     variable fails the Kolmogorov-Smirnov check
     */
    private boolean checkDistribution(Iterable<Document> d, List<Integer> trainingIndices, List<Integer> validIndices) {
        for (Document dd : d) {
            // The driver hands back a raw List; its elements are consumed as Numbers.
            @SuppressWarnings("unchecked")
            List<Number> vals = dd.get(MongoAccess.VALUES, List.class);
            if (distributionsDiffer(vals, trainingIndices, validIndices)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Checks every column of {@code d} against the proposed split.
     *
     * @return {@code true} if the data must be resplit, i.e. at least one
     *     column fails the Kolmogorov-Smirnov check
     */
    private boolean checkDistribution(Data d, List<Integer> trainingIndices, List<Integer> validIndices) {
        for (List<Double> colData : d.getDataPerCol()) {
            if (distributionsDiffer(colData, trainingIndices, validIndices)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Runs the two-sample Kolmogorov-Smirnov test on the training and
     * validation samples of one column.
     *
     * <p>NOTE(review): commons-math appears to require at least two values in
     * each sample; confirm the behaviour for very small subsets.
     *
     * @return {@code true} if the p-value is below {@code SIGNIFICANCE_LEVEL}
     */
    private static boolean distributionsDiffer(List<? extends Number> column, List<Integer> trainingIndices, List<Integer> validIndices) {
        double[] trainingDataSet = select(column, trainingIndices);
        double[] validDataSet = select(column, validIndices);
        double pValue = TestUtils.kolmogorovSmirnovTest(trainingDataSet, validDataSet);
        return pValue < SIGNIFICANCE_LEVEL;
    }

    /** Copies the values of {@code column} found at {@code indices} into an array, in order. */
    private static double[] select(List<? extends Number> column, List<Integer> indices) {
        double[] selected = new double[indices.size()];
        int i = 0;
        for (int index : indices) {
            selected[i++] = column.get(index).doubleValue();
        }
        return selected;
    }

    /**
     * Ensures every required index sits in the training list by swapping it
     * with a training entry that is not itself pinned by an earlier required
     * index. Both lists are modified in place.
     *
     * @throws IndexOutOfBoundsException if there are more required indices
     *     than training slots, or a required index is in neither list
     */
    private void moveIndicesToHaveToTraining(Set<Integer> indicesToHave, List<Integer> trainingIndices, List<Integer> validIndices) {
        if (indicesToHave.size() > trainingIndices.size()) {
            throw new IndexOutOfBoundsException("Training subset of size " + trainingIndices.size()
                    + " cannot hold " + indicesToHave.size() + " required indices");
        }
        // Training positions that already hold a required index and must not be overwritten.
        Set<Integer> pinnedPositions = new HashSet<>();
        int nextFree = 0;
        for (Integer index : indicesToHave) {
            int trainingPos = trainingIndices.indexOf(index);
            if (trainingPos >= 0) {
                pinnedPositions.add(trainingPos);
                continue;
            }
            int validPos = validIndices.indexOf(index);
            if (validPos < 0) {
                throw new IndexOutOfBoundsException("Required index " + index
                        + " is not part of the data set");
            }
            while (pinnedPositions.contains(nextFree)) {
                nextFree++;
            }
            // Swap: the evicted training entry takes the required index's place
            // in the validation list. If the evicted entry is itself required,
            // it is moved back when its turn comes in this loop.
            Integer evicted = trainingIndices.get(nextFree);
            trainingIndices.set(nextFree, index);
            validIndices.set(validPos, evicted);
            pinnedPositions.add(nextFree);
            nextFree++;
        }
    }

    /**
     * Collects the min and max indices of every document whose metadata type
     * equals {@code OUT}, validating each with {@code checkIndex}.
     */
    private Set<Integer> getIndicesToHave(Iterable<Document> d, int dataLength) {
        Set<Integer> indicesToHave = new HashSet<>();
        for (Document dd : d) {
            Document metadata = dd.get(MongoAccess.METADATA, Document.class);
            if (!metadata.getString(MongoAccess.TYPE).equals(OUT)) {
                continue;
            }
            int minIndex = metadata.getInteger(MongoAccess.MIN_INDEX);
            checkIndex(minIndex, metadata, dataLength);
            int maxIndex = metadata.getInteger(MongoAccess.MAX_INDEX);
            checkIndex(maxIndex, metadata, dataLength);
            indicesToHave.add(minIndex);
            indicesToHave.add(maxIndex);
        }
        return indicesToHave;
    }

    /**
     * Collects the indices reported by {@code getMinIndices()} and
     * {@code getMaxIndices()} of {@code d}. {@code dataLength} is not used
     * here; it is kept so both overloads share the same shape.
     */
    private Set<Integer> getIndicesToHave(Data d, int dataLength) {
        Set<Integer> indicesToHave = new HashSet<>();
        for (int index : d.getMinIndices()) {
            indicesToHave.add(index);
        }
        for (int index : d.getMaxIndices()) {
            indicesToHave.add(index);
        }
        return indicesToHave;
    }
}