SameDistribution.java [src/java/m/ann/training/scale] Revision: default  Date:
/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package m.ann.training.scale;

import static csip.ModelDataService.OUT;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import org.apache.commons.math3.stat.inference.TestUtils;
import org.bson.Document;
import utils.Data;
import utils.MongoAccess;

/**
 * Scaling mechanism that splits row indices into a training and a validation
 * subset and keeps re-shuffling until, for every data column, a two-sample
 * Kolmogorov-Smirnov test does not reject the hypothesis that both subsets
 * follow the same distribution.
 *
 * <p>Indices flagged as "to have" (the min/max positions reported for the
 * data) are always forced into the training subset before the distributions
 * are compared.
 *
 * @author sidereus
 */
public class SameDistribution extends ScalingMechanism {

  /**
   * p-value threshold of the Kolmogorov-Smirnov test: a split is rejected as
   * soon as one column yields a p-value below this level.
   */
  private static final double SIGNIFICANCE_LEVEL = 0.05;


  /**
   * @return the fixed identifier of this strategy, {@code "SameDistribution"}
   */
  @Override
  public String getStrategy() {
    return "SameDistribution";
  }


  /**
   * Computes a training/validation split for data stored as documents.
   *
   * <p>The returned lists are {@code subList} views over one shuffled index
   * list, not independent copies.
   *
   * @param d documents holding the values and metadata of each variable
   * @param trainingPerc share of the data assigned to training, forwarded to
   *     {@code getTrainingLength}
   * @return the accepted training and validation indices
   */
  @Override
  public DataSetIndices compute(Iterable<Document> d, double trainingPerc) {
    int dataLength = getDataLength(d);
    int trainingLength = getTrainingLength(trainingPerc, dataLength);
    List<Integer> range = getRange(dataLength);
    Set<Integer> indicesToHave = getIndicesToHave(d, dataLength);

    List<Integer> trainingIndices;
    List<Integer> validIndices;
    // NOTE(review): the number of re-shuffles is unbounded; if no split ever
    // passes the distribution check this loop does not terminate.
    do {
      Collections.shuffle(range);
      trainingIndices = range.subList(0, trainingLength);
      validIndices = range.subList(trainingLength, dataLength);

      // required indices must sit in the training subset before it is judged
      moveIndicesToHaveToTraining(indicesToHave, trainingIndices, validIndices);
    } while (checkDistribution(d, trainingIndices, validIndices));

    return new DataSetIndices(trainingIndices, validIndices);
  }


  /**
   * Computes a training/validation split for data held in a {@code Data}
   * object; same procedure as the document based overload.
   *
   * <p>The returned lists are {@code subList} views over one shuffled index
   * list, not independent copies.
   *
   * @param d column oriented data together with its min/max indices
   * @param trainingPerc share of the data assigned to training, forwarded to
   *     {@code getTrainingLength}
   * @return the accepted training and validation indices
   */
  public DataSetIndices compute(Data d, double trainingPerc) {
    int dataLength = getDataLength(d);
    int trainingLength = getTrainingLength(trainingPerc, dataLength);
    List<Integer> range = getRange(dataLength);
    Set<Integer> indicesToHave = getIndicesToHave(d, dataLength);

    List<Integer> trainingIndices;
    List<Integer> validIndices;
    // NOTE(review): the number of re-shuffles is unbounded; if no split ever
    // passes the distribution check this loop does not terminate.
    do {
      Collections.shuffle(range);
      trainingIndices = range.subList(0, trainingLength);
      validIndices = range.subList(trainingLength, dataLength);

      // required indices must sit in the training subset before it is judged
      moveIndicesToHaveToTraining(indicesToHave, trainingIndices, validIndices);
    } while (checkDistribution(d, trainingIndices, validIndices));

    return new DataSetIndices(trainingIndices, validIndices);
  }


  /**
   * Tells whether the split has to be redone for document based data.
   *
   * @return {@code true} as soon as one document's values differ in
   *     distribution between the two subsets, {@code false} otherwise
   */
  private boolean checkDistribution(Iterable<Document> d, List<Integer> trainingIndices, List<Integer> validIndices) {
    for (Document dd : d) {
      // Document.get(key, Class) can only hand back the raw List type
      @SuppressWarnings("unchecked")
      List<Number> vals = dd.get(MongoAccess.VALUES, List.class);

      if (distributionsDiffer(select(vals, trainingIndices), select(vals, validIndices))) {
        return true;
      }
    }
    return false;
  }


  /**
   * Tells whether the split has to be redone for column oriented data.
   *
   * @return {@code true} as soon as one column differs in distribution
   *     between the two subsets, {@code false} otherwise
   */
  private boolean checkDistribution(Data d, List<Integer> trainingIndices, List<Integer> validIndices) {
    for (List<Double> colData : d.getDataPerCol()) {
      if (distributionsDiffer(select(colData, trainingIndices), select(colData, validIndices))) {
        return true;
      }
    }
    return false;
  }


  /**
   * Copies the values found at the given positions into a primitive array,
   * keeping the order of {@code indices}.
   */
  private static double[] select(List<? extends Number> values, List<Integer> indices) {
    double[] selected = new double[indices.size()];
    for (int i = 0; i < selected.length; i++) {
      selected[i] = values.get(indices.get(i)).doubleValue();
    }
    return selected;
  }


  /**
   * Runs the two-sample Kolmogorov-Smirnov test on both samples and reports
   * whether its p-value falls below {@link #SIGNIFICANCE_LEVEL}.
   */
  private static boolean distributionsDiffer(double[] trainingDataSet, double[] validDataSet) {
    double pValue = TestUtils.kolmogorovSmirnovTest(trainingDataSet, validDataSet);
    return pValue < SIGNIFICANCE_LEVEL;
  }


  /**
   * Makes sure every index of {@code indicesToHave} ends up in the training
   * list by swapping it with a training entry that is not itself an already
   * placed required index. Both lists are modified in place.
   *
   * @throws IndexOutOfBoundsException if a required index is in neither list,
   *     or if there are more required indices than training slots
   */
  private void moveIndicesToHaveToTraining(Set<Integer> indicesToHave, List<Integer> trainingIndices, List<Integer> validIndices) {
    // training positions that already hold a required index and must not be reused
    Set<Integer> protectedPositions = new HashSet<>();
    int nextFreePosition = 0;

    for (Integer index : indicesToHave) {
      int trainingPosition = trainingIndices.indexOf(index);
      if (trainingPosition >= 0) {
        protectedPositions.add(trainingPosition);
        continue;
      }

      int validPosition = validIndices.indexOf(index);
      if (validPosition < 0) {
        // fail before touching either list instead of after a half-done swap
        throw new IndexOutOfBoundsException("Required index " + index
            + " is in neither the training nor the validation indices");
      }

      while (protectedPositions.contains(nextFreePosition)) {
        nextFreePosition++;
      }
      if (nextFreePosition >= trainingIndices.size()) {
        throw new IndexOutOfBoundsException("Training subset of size " + trainingIndices.size()
            + " cannot hold all " + indicesToHave.size() + " required indices");
      }

      // swap: the required index enters training, the displaced one goes to validation
      Integer displaced = trainingIndices.get(nextFreePosition);
      trainingIndices.set(nextFreePosition, index);
      validIndices.set(validPosition, displaced);
      protectedPositions.add(nextFreePosition);
      nextFreePosition++;
    }
  }


  /**
   * Collects the min and max indices of every document whose metadata type
   * equals {@code OUT}. Each index is handed to {@code checkIndex} (defined
   * outside this class; presumably a range validation - confirm) before it
   * is recorded.
   */
  private Set<Integer> getIndicesToHave(Iterable<Document> d, int dataLength) {
    Set<Integer> indicesToHave = new HashSet<>();
    for (Document dd : d) {
      Document metadata = dd.get(MongoAccess.METADATA, Document.class);
      if (!metadata.getString(MongoAccess.TYPE).equals(OUT)) {
        continue;
      }

      int minIndex = metadata.getInteger(MongoAccess.MIN_INDEX);
      checkIndex(minIndex, metadata, dataLength);

      int maxIndex = metadata.getInteger(MongoAccess.MAX_INDEX);
      checkIndex(maxIndex, metadata, dataLength);

      indicesToHave.add(minIndex);
      indicesToHave.add(maxIndex);
    }
    return indicesToHave;
  }


  /**
   * Collects all min and max indices reported by {@code d}.
   *
   * <p>NOTE(review): {@code dataLength} is not used here, so unlike the
   * document based overload the indices are not validated at this point.
   */
  private Set<Integer> getIndicesToHave(Data d, int dataLength) {
    Set<Integer> indicesToHave = new HashSet<>();
    for (int index : d.getMinIndices()) {
      indicesToHave.add(index);
    }
    for (int index : d.getMaxIndices()) {
      indicesToHave.add(index);
    }
    return indicesToHave;
  }

}