Example usage for weka.core Instances Instances

List of usage examples for weka.core Instances Instances

Introduction

On this page you can find example usages of the weka.core.Instances constructor.

Prototype

public Instances(Instances dataset) 

Source Link

Document

Constructor copying all instances and references to the header information from the given set of instances.

Usage

From source file:ann.MyANN.java

/**
 * Trains the network on the given data.
 * <p>
 * The training data is copied, instances with a missing class are removed and
 * the NumericToBinary filter is applied. The layer/node structure is then built
 * from {@code nbLayers}, the weights are initialised (from {@code weights} when
 * {@code isInitialWeightSet} is true, otherwise from a fixed-seed random
 * generator) and the configured learning algorithm is run until the configured
 * termination condition holds.
 *
 * @param instances training data
 * @throws Exception any exception that causes the training to fail
 */
@Override
public void buildClassifier(Instances instances) throws Exception {

    // check that the classifier can handle the input data
    getCapabilities().testWithFail(instances);
    // copy the data and drop every instance with a missing class
    instances = new Instances(instances);
    instances.deleteWithMissingClass();

    // filter
    NumericToBinary ntb = new NumericToBinary();
    ntb.setInputFormat(instances);
    instances = Filter.useFilter(instances, ntb);

    // convert the instances into the internal data representation
    instancesToDatas(instances);

    // build the ANN according to nbLayers
    // create one (still empty) node list per layer
    ArrayList<ArrayList<Node>> layers = new ArrayList<>();
    for (int i = 0; i < nbLayers.length; i++) {
        layers.add(new ArrayList<>());
    }

    // initialise the input layer
    for (int i = 0; i < nbLayers[0]; i++) {
        // set id, prevLayer = null, nextLayer = layers[1]
        layers.get(0).add(new Node("node-0" + "-" + i, null, layers.get(1)));
    }

    // initialise the hidden layers
    for (int i = 1; i < nbLayers.length - 1; i++) {
        for (int j = 0; j < nbLayers[i]; j++) {
            // set id, prevLayer = layers[i-1], nextLayer = layers[i+1]
            layers.get(i).add(new Node("node-" + i + "-" + j, layers.get(i - 1), layers.get(i + 1)));
        }
    }

    // initialise the output layer
    for (int i = 0; i < nbLayers[nbLayers.length - 1]; i++) {
        // set id, prevLayer = layers[n-1], nextLayer = null
        layers.get(nbLayers.length - 1).add(
                new Node("node-" + (nbLayers.length - 1) + "-" + i, layers.get(nbLayers.length - 2), null));
    }

    // one bias input of 1.0 per non-input layer (nbLayers.length - 1 of them)
    ArrayList<Double> bias = new ArrayList<>();
    for (int i = 0; i < nbLayers.length - 1; i++) {
        bias.add(1.0);
    }

    // fixed seed, so that the random initial weights are reproducible between runs
    Random rand = new Random(1);

    // bias weights: one per node of every non-input layer; j runs over all of
    // those nodes and indexes into weights[1] when initial weights were supplied
    int j = 0;
    Map<Integer, Map<Node, Double>> biasesWeight = new HashMap<>();
    for (int i = 0; i < nbLayers.length - 1; i++) {
        ArrayList<Node> arrNode = layers.get(i + 1);
        Map<Node, Double> map = new HashMap<>();
        for (Node node : arrNode) {
            if (isInitialWeightSet) {
                map.put(node, weights[1][j]);
            } else {
                map.put(node, rand.nextDouble());
            }
            j++;
        }
        biasesWeight.put(i, map);
    }

    // connection weights: one per (node, next node) pair; j runs over all of
    // those pairs and indexes into weights[0] when initial weights were supplied
    j = 0;
    Map<Node, Map<Node, Double>> mapWeight = new HashMap<>();
    for (int i = 0; i < nbLayers.length - 1; i++) {
        ArrayList<Node> arrNode = layers.get(i);
        for (Node node : arrNode) {
            Map<Node, Double> map = new HashMap<>();
            for (Node nextNode : node.getNextNodes()) {
                if (isInitialWeightSet) {
                    map.put(nextNode, weights[0][j]);
                } else {
                    map.put(nextNode, rand.nextDouble());
                }
                j++;
            }
            mapWeight.put(node, map);
        }
    }

    // create the ANN model from the values above
    annModel = new ANNModel(layers, mapWeight, bias, biasesWeight);
    // set the initial configuration of the model
    annModel.setDataSet(datas);
    annModel.setLearningRate(learningRate);
    annModel.setMomentum(momentum);
    switch (activationFunction) {
    case SIGMOID_FUNCTION:
        annModel.setActivationFunction(ANNModel.SIGMOID);
        break;
    case SIGN_FUNCTION:
        // intended to change the targets from 0 to -1
        // NOTE(review): assigning to the loop variable dd only rebinds the local;
        // d.target itself is not modified by this loop. Fixing it needs an indexed
        // write whose form depends on the declared type of Data.target - confirm.
        for (Data d : datas) {
            for (Double dd : d.target) {
                if (dd == 0.0) {
                    dd = -1.0;
                }
            }
        }
        annModel.setActivationFunction(ANNModel.SIGN);
        break;
    case STEP_FUNCTION:
        annModel.setActivationFunction(ANNModel.STEP);
        break;
    default:
        break;
    }
    // the learning rule and the topology override the activation function chosen above
    if (learningRule == BATCH_GRADIENT_DESCENT || learningRule == DELTA_RULE) {
        annModel.setActivationFunction(ANNModel.NO_FUNC);
    }
    if (topology == MULTILAYER_PERCEPTRON) {
        annModel.setActivationFunction(ANNModel.SIGMOID);
    }
    annModel.setThreshold(threshold);

    // run the algorithm
    boolean stop = false;
    iteration = 0;

    annModel.resetDeltaWeight();
    do {
        if (topology == ONE_PERCEPTRON) {
            switch (learningRule) {
            case PERCEPTRON_TRAINING_RULE:
                annModel.perceptronTrainingRule();
                break;
            case BATCH_GRADIENT_DESCENT:
                annModel.batchGradienDescent();
                break;
            case DELTA_RULE:
                annModel.deltaRule();
                break;
            default:
                break;
            }
        } else if (topology == MULTILAYER_PERCEPTRON) {
            annModel.backProp();
        }
        iteration++;

        // stop once the termination condition is met
        switch (terminationCondition) {
        case TERMINATE_MAX_ITERATION:
            if (iteration >= maxIteration) {
                stop = true;
            }
            break;
        case TERMINATE_MSE:
            if (annModel.error < deltaMSE) {
                stop = true;
            }
            break;
        case TERMINATE_BOTH:
            // same iteration bound as TERMINATE_MAX_ITERATION, so that at most
            // maxIteration iterations are run
            if (iteration >= maxIteration || annModel.error < deltaMSE) {
                stop = true;
            }
            break;
        default:
            break;
        }
    } while (!stop);
}

From source file:ann.MyANN.java

/**
 * Evaluates the model with cross-validation: the instances are split into
 * {@code numFold} train/test pairs, a copy of this classifier is trained on
 * each training set and the per-fold confusion matrices are summed.
 * <p>
 * A fold that throws is logged and skipped; it contributes nothing to the total.
 *
 * @param instances data to evaluate on (copied, the caller's object is not reordered)
 * @param numFold number of folds
 * @param rand random generator used to shuffle the data
 * @return the summed confusion matrix, or {@code null} when no fold completed
 */
public int[][] crossValidation(Instances instances, int numFold, Random rand) {
    int[][] totalResult = null;
    instances = new Instances(instances);
    instances.randomize(rand);
    if (instances.classAttribute().isNominal()) {
        instances.stratify(numFold);
    }
    for (int i = 0; i < numFold; i++) {
        try {
            // split the instances for this fold
            Instances train = instances.trainCV(numFold, i, rand);
            Instances test = instances.testCV(numFold, i);
            MyANN cc = new MyANN(this);
            cc.buildClassifier(train);
            // evaluate the test fold exactly once
            int[][] result = cc.evaluate(test);
            if (totalResult == null) {
                // first fold that succeeded; keyed on null rather than i == 0 so
                // that a failure in fold 0 does not break every later fold
                totalResult = result;
            } else {
                for (int j = 0; j < totalResult.length; j++) {
                    for (int k = 0; k < totalResult[j].length; k++) {
                        totalResult[j][k] += result[j][k];
                    }
                }
            }
        } catch (Exception ex) {
            Logger.getLogger(MyANN.class.getName()).log(Level.SEVERE, null, ex);
        }
    }

    return totalResult;
}

From source file:ann.SingleLayerPerceptron.java

/**
 * Builds the perceptron from the given training data.
 * <p>
 * The configuration is reloaded, the data is copied, instances with a missing
 * class are removed, and the NominalToBinary and Normalize filters are applied
 * before the data is handed to {@code doPerceptron}.
 *
 * @param data training data; the caller's object is not modified
 * @throws Exception if the capabilities test, a filter or the training fails
 */
@Override
public void buildClassifier(Instances data) throws Exception {
    // can classifier handle the data?
    getCapabilities().testWithFail(data);
    annOptions = new ANNOptions();
    annOptions = annOptions.loadConfiguration();
    normalize = new Normalize();
    ntb = new NominalToBinary();
    // the output neurons come straight from the loaded configuration
    output = annOptions.output;

    // remove instances with missing class (on a copy of the caller's data)
    data = new Instances(data);
    data.deleteWithMissingClass();

    // nominal to binary filter; useFilter already returns a new data set,
    // so no further copy is needed
    ntb.setInputFormat(data);
    data = Filter.useFilter(data, ntb);

    // normalize filter
    normalize.setInputFormat(data);
    data = Filter.useFilter(data, normalize);

    // do main function
    doPerceptron(data);
}

From source file:ann.SingleLayerPerceptron.java

/**
 * Classifies the given instances with the current output neurons and prints,
 * per instance, the target and the predicted class, followed by the fraction
 * of matching predictions.
 *
 * @param data instances to classify; a copy is filtered, the caller's object is untouched
 * @return one predicted class index per slot of the incoming data set
 * @throws Exception if the NominalToBinary filter fails
 */
public int[] classifyInstances(Instances data) throws Exception {
    // NOTE(review): sized from the incoming data, i.e. before instances with a
    // missing class are removed below; slots past the filtered size stay 0.
    int[] classValue = new int[data.numInstances()];

    // drop instances with a missing class, working on a copy
    data = new Instances(data);
    data.deleteWithMissingClass();

    // nominal to binary filter
    // NOTE(review): the Normalize filter used in buildClassifier is not applied here - confirm.
    ntb.setInputFormat(data);
    data = new Instances(Filter.useFilter(data, ntb));

    int correct = 0;
    for (int idx = 0; idx < data.numInstances(); idx++) {
        int neuronCount = output.size();
        double[] activations = new double[neuronCount];

        for (int n = 0; n < neuronCount; n++) {
            double sum = 0.0;
            for (int a = 0; a < data.numAttributes(); a++) {
                // the last attribute position is fed a constant 1 instead of its value
                double in = 1;
                if (a < data.numAttributes() - 1) {
                    in = data.instance(idx).value(a);
                }
                sum += output.get(n).weights.get(a) * in;
            }
            activations[n] = Util.activationFunction(sum, annOptions);
        }

        if (neuronCount < 2) {
            // single neuron: its truncated activation is the class
            classValue[idx] = (int) activations[0];
        } else {
            // several neurons: the one with the largest activation wins
            for (int n = 0; n < neuronCount; n++) {
                if (activations[n] > activations[classValue[idx]]) {
                    classValue[idx] = n;
                }
            }
        }

        double target = data.instance(idx).classValue();
        double predicted = classValue[idx];
        System.out.println("Intance-" + idx + " target: " + target + " output: " + predicted);
        if (target == predicted) {
            correct++;
        }
    }

    System.out.println("Percentage: " + ((double) correct / (double) data.numInstances()));

    return classValue;
}

From source file:ant.Game.java

/**
 * Creates a game: initialises grid, food and ant, loads the data set from
 * {@code m_fileName} and creates the Weka wrappers.
 *
 * @param dim stored as the game dimension; the move budget is set to twice this value
 * @param first when true, {@code initFile()} is called before the data file is read
 * @param fov passed to {@code initAnt}
 * @throws FileNotFoundException if the data file cannot be opened
 * @throws IOException if reading or closing the data file fails
 */
public Game(int dim, boolean first, int fov) throws FileNotFoundException, IOException {
    m_dim = dim;
    m_score = 0;
    m_moves = m_dim * 2;
    initGrid();
    initFood();
    initAnt(fov);
    if (first) {
        initFile();
    }
    // the reader is only needed while the data set is being parsed; close it
    // afterwards so the file handle is not leaked
    try (BufferedReader reader = new BufferedReader(new FileReader(m_fileName))) {
        m_data = new Instances(reader);
    }
    // the last attribute is the class
    m_data.setClassIndex(m_data.numAttributes() - 1);
    m_wrapper1x1 = new WekaWrapper1x1();
    m_wrapper5x1 = new WekaWrapper5x1();
    m_wrapper10x1 = new WekaWrapper10x1();
    m_wrapper1x2 = new WekaWrapper1x2();
    m_wrapper5x2 = new WekaWrapper5x2();
    m_wrapper10x2 = new WekaWrapper10x2();
    m_wrapper50x2 = new WekaWrapper50x2();
}

From source file:app.RunApp.java

License:Open Source License

/**
 * Preprocesses the loaded dataset according to the options selected in the GUI.
 * <p>
 * Three optional stages run in order: instance selection, feature selection
 * and splitting into train/test partitions (holdout) or folds (cross-validation).
 * Invalid user input shows an error dialog and aborts with -1.
 * <p>
 * NOTE(review): exceptions raised inside the stages are only logged; in most
 * branches the method then still reaches the end and returns 1.
 *
 * @return 1 when the end of the method is reached, -1 when no dataset is loaded
 *         or a validation/selection step fails
 */
private int preprocess() {
    // NOTE(review): raw ArrayList types; the element type is not visible here
    trainDatasets = new ArrayList();
    testDatasets = new ArrayList();

    Instances train, test;

    if (dataset == null) {
        JOptionPane.showMessageDialog(null, "You must load a dataset.", "alert", JOptionPane.ERROR_MESSAGE);
        return -1;
    }

    // working copy that the stages below replace step by step
    MultiLabelInstances preprocessDataset = dataset.clone();

    if (!radioNoIS.isSelected()) {
        //Do Instance Selection
        if (radioRandomIS.isSelected()) {
            // NOTE(review): parseInt throws NumberFormatException on non-numeric text; not caught here
            int nInstances = Integer.parseInt(textRandomIS.getText());

            if (nInstances < 1) {
                JOptionPane.showMessageDialog(null,
                        "The number of instances must be a positive natural number.", "alert",
                        JOptionPane.ERROR_MESSAGE);
                return -1;
            } else if (nInstances > dataset.getNumInstances()) {
                JOptionPane.showMessageDialog(null,
                        "The number of instances to select must be less than the original.", "alert",
                        JOptionPane.ERROR_MESSAGE);
                return -1;
            }

            Instances dataIS;
            try {
                // shuffle the instances, then drop everything after the first nInstances
                Randomize randomize = new Randomize();
                dataIS = dataset.getDataSet();

                randomize.setInputFormat(dataIS);
                dataIS = Filter.useFilter(dataIS, randomize);
                randomize.batchFinished();

                RemoveRange removeRange = new RemoveRange();
                removeRange.setInputFormat(dataIS);
                removeRange.setInstancesIndices((nInstances + 1) + "-last");

                dataIS = Filter.useFilter(dataIS, removeRange);
                removeRange.batchFinished();

                preprocessDataset = dataset.reintegrateModifiedDataSet(dataIS);
            } catch (Exception ex) {
                // NOTE(review): on failure preprocessDataset keeps its previous value
                Logger.getLogger(RunApp.class.getName()).log(Level.SEVERE, null, ex);
            }

            if (preprocessDataset == null) {
                JOptionPane.showMessageDialog(null, "Error when selecting instances.", "alert",
                        JOptionPane.ERROR_MESSAGE);
                return -1;
            }

            preprocessedDataset = preprocessDataset;
        }
    }

    if (!radioNoFS.isSelected()) {
        //FS_BR
        if (radioBRFS.isSelected()) {
            // NOTE(review): parseInt throws NumberFormatException on non-numeric text; not caught here
            int nFeatures = Integer.parseInt(textBRFS.getText());
            if (nFeatures < 1) {
                JOptionPane.showMessageDialog(null, "The number of features must be a positive natural number.",
                        "alert", JOptionPane.ERROR_MESSAGE);
                return -1;
            } else if (nFeatures > dataset.getFeatureIndices().length) {
                JOptionPane.showMessageDialog(null,
                        "The number of features to select must be less than the original.", "alert",
                        JOptionPane.ERROR_MESSAGE);
                return -1;
            }

            String combination = jComboBoxBRFSComb.getSelectedItem().toString();
            String normalization = jComboBoxBRFSNorm.getSelectedItem().toString();
            String output = jComboBoxBRFSOut.getSelectedItem().toString();

            // select on the original dataset, or on the instance-selected one if IS ran
            FeatureSelector fs;
            if (radioNoIS.isSelected()) {
                fs = new FeatureSelector(dataset, nFeatures);
            } else {
                //If IS have been done
                fs = new FeatureSelector(preprocessDataset, nFeatures);
            }

            preprocessedDataset = fs.select(combination, normalization, output);

            if (preprocessedDataset == null) {
                JOptionPane.showMessageDialog(null, "Error when selecting features.", "alert",
                        JOptionPane.ERROR_MESSAGE);
                return -1;
            }

            preprocessDataset = preprocessedDataset;
        } else if (radioRandomFS.isSelected()) {
            // NOTE(review): parseInt throws NumberFormatException on non-numeric text; not caught here
            int nFeatures = Integer.parseInt(textRandomFS.getText());

            if (nFeatures < 1) {
                JOptionPane.showMessageDialog(null, "The number of features must be a positive natural number.",
                        "alert", JOptionPane.ERROR_MESSAGE);
                return -1;
            } else if (nFeatures > dataset.getFeatureIndices().length) {
                JOptionPane.showMessageDialog(null,
                        "The number of features to select must be less than the original.", "alert",
                        JOptionPane.ERROR_MESSAGE);
                return -1;
            }

            FeatureSelector fs;

            if (radioNoIS.isSelected()) {
                fs = new FeatureSelector(dataset, nFeatures);
            } else {
                //If IS have been done
                fs = new FeatureSelector(preprocessDataset, nFeatures);
            }

            preprocessedDataset = fs.randomSelect();

            if (preprocessedDataset == null) {
                JOptionPane.showMessageDialog(null, "Error when selecting features.", "alert",
                        JOptionPane.ERROR_MESSAGE);
                return -1;
            }

            preprocessDataset = preprocessedDataset;
        }
    }

    if (!radioNoSplit.isSelected()) {
        //Random Holdout
        if (radioRandomHoldout.isSelected()) {
            String split = textRandomHoldout.getText();
            // NOTE(review): parseDouble throws NumberFormatException on non-numeric text; not caught here
            double percentage = Double.parseDouble(split);
            if ((percentage <= 0) || (percentage >= 100)) {
                JOptionPane.showMessageDialog(null, "The percentage must be a number in the range (0, 100).",
                        "alert", JOptionPane.ERROR_MESSAGE);
                return -1;
            }

            try {
                RandomTrainTest pre = new RandomTrainTest();
                MultiLabelInstances[] partitions = pre.split(preprocessDataset, percentage);
                trainDataset = partitions[0];
                testDataset = partitions[1];
            } catch (InvalidDataFormatException ex) {
                Logger.getLogger(RunApp.class.getName()).log(Level.SEVERE, null, ex);
            } catch (Exception ex) {
                Logger.getLogger(RunApp.class.getName()).log(Level.SEVERE, null, ex);
            }
        }
        //Random CV
        else if (radioRandomCV.isSelected()) {
            String split = textRandomCV.getText();

            if (split.equals("")) {
                JOptionPane.showMessageDialog(null, "You must enter the number of folds.", "alert",
                        JOptionPane.ERROR_MESSAGE);
                return -1;
            }

            int nFolds;

            try {
                nFolds = Integer.parseInt(split);
            } catch (Exception e) {
                JOptionPane.showMessageDialog(null, "Introduce a correct number of folds.", "alert",
                        JOptionPane.ERROR_MESSAGE);
                return -1;
            }

            if (nFolds < 2) {
                JOptionPane.showMessageDialog(null, "The number of folds must be greater or equal to 2.",
                        "alert", JOptionPane.ERROR_MESSAGE);
                return -1;
            } else if (nFolds > preprocessDataset.getNumInstances()) {
                JOptionPane.showMessageDialog(null,
                        "The number of folds can not be greater than the number of instances.", "alert",
                        JOptionPane.ERROR_MESSAGE);
                return -1;
            }

            try {
                MultiLabelInstances temp = preprocessDataset.clone();
                Instances dataTemp = temp.getDataSet();

                // seed is an integer in [100, 199]
                int seed = (int) (Math.random() * 100) + 100;
                Random rand = new Random(seed);

                dataTemp.randomize(rand);

                // each fold starts as a copy of the data that is emptied again,
                // which leaves an empty set with the same header
                Instances[] foldsCV = new Instances[nFolds];
                for (int i = 0; i < nFolds; i++) {
                    foldsCV[i] = new Instances(dataTemp);
                    foldsCV[i].clear();
                }

                // deal the shuffled instances round-robin into the folds
                for (int i = 0; i < dataTemp.numInstances(); i++) {
                    foldsCV[i % nFolds].add(dataTemp.get(i));
                }

                // train/test are reused as scratch sets; a copy is stored per fold below
                train = new Instances(dataTemp);
                test = new Instances(dataTemp);
                for (int i = 0; i < nFolds; i++) {
                    train.clear();
                    test.clear();
                    for (int j = 0; j < nFolds; j++) {
                        if (i != j) {
                            System.out.println("Add fold " + j + " to train");
                            train.addAll(foldsCV[j]);
                        }
                    }
                    System.out.println("Add fold " + i + " to test");
                    test.addAll(foldsCV[i]);
                    System.out.println(train.get(0).toString());
                    System.out.println(test.get(0).toString());
                    trainDatasets.add(new MultiLabelInstances(new Instances(train),
                            preprocessDataset.getLabelsMetaData()));
                    testDatasets.add(new MultiLabelInstances(new Instances(test),
                            preprocessDataset.getLabelsMetaData()));
                    System.out.println(trainDatasets.get(i).getDataSet().get(0).toString());
                    System.out.println(testDatasets.get(i).getDataSet().get(0).toString());
                    System.out.println("---");
                }
            }

            catch (Exception ex) {
                Logger.getLogger(RunApp.class.getName()).log(Level.SEVERE, null, ex);
            }
        }
        //Iterative stratified holdout
        else if (radioIterativeStratifiedHoldout.isSelected()) {
            String split = textIterativeStratifiedHoldout.getText();
            // NOTE(review): parseDouble throws NumberFormatException on non-numeric text; not caught here
            double percentage = Double.parseDouble(split);
            if ((percentage <= 0) || (percentage >= 100)) {
                JOptionPane.showMessageDialog(null, "The percentage must be a number in the range (0, 100).",
                        "alert", JOptionPane.ERROR_MESSAGE);
                return -1;
            }

            try {
                IterativeTrainTest pre = new IterativeTrainTest();
                MultiLabelInstances[] partitions = pre.split(preprocessDataset, percentage);

                trainDataset = partitions[0];
                testDataset = partitions[1];
            } catch (Exception ex) {
                Logger.getLogger(RunApp.class.getName()).log(Level.SEVERE, null, ex);
            }
        }
        //Iterative stratified CV
        else if (radioIterativeStratifiedCV.isSelected()) {
            String split = textIterativeStratifiedCV.getText();

            if (split.equals("")) {
                JOptionPane.showMessageDialog(null, "You must enter the number of folds.", "alert",
                        JOptionPane.ERROR_MESSAGE);
                return -1;
            }

            int nFolds = 0;

            try {
                nFolds = Integer.parseInt(split);
            } catch (Exception e) {
                JOptionPane.showMessageDialog(null, "Introduce a correct number of folds.", "alert",
                        JOptionPane.ERROR_MESSAGE);
                return -1;
            }

            if (nFolds < 2) {
                JOptionPane.showMessageDialog(null, "The number of folds must be greater or equal to 2.",
                        "alert", JOptionPane.ERROR_MESSAGE);
                return -1;
            } else if (nFolds > preprocessDataset.getNumInstances()) {
                JOptionPane.showMessageDialog(null,
                        "The number of folds can not be greater than the number of instances.", "alert",
                        JOptionPane.ERROR_MESSAGE);
                return -1;
            }

            IterativeStratification strat = new IterativeStratification();
            MultiLabelInstances folds[] = strat.stratify(preprocessDataset, nFolds);

            for (int i = 0; i < nFolds; i++) {
                try {

                    // count the instances first so both sets are created with the needed capacity
                    int trainSize = 0, testSize = 0;
                    for (int j = 0; j < nFolds; j++) {
                        if (i != j) {
                            trainSize += folds[j].getNumInstances();
                        }
                    }
                    testSize += folds[i].getNumInstances();

                    train = new Instances(preprocessDataset.getDataSet(), trainSize);
                    test = new Instances(preprocessDataset.getDataSet(), testSize);
                    // fold i is the test set, all other folds form the training set
                    for (int j = 0; j < nFolds; j++) {
                        if (i != j) {
                            train.addAll(folds[j].getDataSet());
                        }
                    }
                    test.addAll(folds[i].getDataSet());

                    trainDatasets.add(new MultiLabelInstances(train, preprocessDataset.getLabelsMetaData()));
                    testDatasets.add(new MultiLabelInstances(test, preprocessDataset.getLabelsMetaData()));
                } catch (InvalidDataFormatException ex) {
                    Logger.getLogger(RunApp.class.getName()).log(Level.SEVERE, null, ex);
                }
            }

        }
        //LP stratified holdout
        else if (radioLPStratifiedHoldout.isSelected()) {
            String split = textLPStratifiedHoldout.getText();
            // NOTE(review): parseDouble throws NumberFormatException on non-numeric text; not caught here
            double percentage = Double.parseDouble(split);
            if ((percentage <= 0) || (percentage >= 100)) {
                JOptionPane.showMessageDialog(null, "The percentage must be a number in the range (0, 100).",
                        "alert", JOptionPane.ERROR_MESSAGE);
                return -1;
            }

            try {
                // NOTE(review): this LP branch uses IterativeTrainTest, exactly like the
                // iterative stratified holdout branch above - confirm this is intended.
                IterativeTrainTest pre = new IterativeTrainTest();
                MultiLabelInstances[] partitions = pre.split(preprocessDataset, percentage);

                trainDataset = partitions[0];
                testDataset = partitions[1];
            } catch (Exception ex) {
                Logger.getLogger(RunApp.class.getName()).log(Level.SEVERE, null, ex);
            }
        }
        //LP stratified CV
        else if (radioLPStratifiedCV.isSelected()) {
            String split = textLPStratifiedCV.getText();

            if (split.equals("")) {
                JOptionPane.showMessageDialog(null, "You must enter the number of folds.", "alert",
                        JOptionPane.ERROR_MESSAGE);
                return -1;
            }

            int nFolds = 0;

            try {
                nFolds = Integer.parseInt(split);
            } catch (Exception e) {
                JOptionPane.showMessageDialog(null, "Introduce a correct number of folds.", "alert",
                        JOptionPane.ERROR_MESSAGE);
                return -1;
            }

            if (nFolds < 2) {
                JOptionPane.showMessageDialog(null, "The number of folds must be greater or equal to 2.",
                        "alert", JOptionPane.ERROR_MESSAGE);
                return -1;
            } else if (nFolds > preprocessDataset.getNumInstances()) {
                JOptionPane.showMessageDialog(null,
                        "The number of folds can not be greater than the number of instances.", "alert",
                        JOptionPane.ERROR_MESSAGE);
                return -1;
            }

            LabelPowersetTrainTest strat = new LabelPowersetTrainTest();
            MultiLabelInstances folds[] = strat.stratify(preprocessDataset, nFolds);

            for (int i = 0; i < nFolds; i++) {
                try {
                    // empty sets that share the header of the preprocessed data
                    train = new Instances(preprocessDataset.getDataSet(), 0);
                    test = new Instances(preprocessDataset.getDataSet(), 0);

                    // fold i is the test set, all other folds form the training set
                    for (int j = 0; j < nFolds; j++) {
                        if (i != j) {
                            train.addAll(folds[j].getDataSet());
                        }
                    }
                    test.addAll(folds[i].getDataSet());

                    trainDatasets.add(new MultiLabelInstances(train, preprocessDataset.getLabelsMetaData()));
                    testDatasets.add(new MultiLabelInstances(test, preprocessDataset.getLabelsMetaData()));
                } catch (InvalidDataFormatException ex) {
                    Logger.getLogger(RunApp.class.getName()).log(Level.SEVERE, null, ex);
                }
            }
        }
    }

    // saving is enabled whenever this point is reached
    jButtonSaveDatasets.setEnabled(true);
    jComboBoxSaveFormat.setEnabled(true);

    return 1;
}

From source file:asap.CrossValidation.java

/**
 *
 * @param dataInput/*www. ja v a  2  s.c om*/
 * @param classIndex
 * @param removeIndices
 * @param cls
 * @param seed
 * @param folds
 * @param modelOutputFile
 * @return
 * @throws Exception
 */
public static String performCrossValidation(String dataInput, String classIndex, String removeIndices,
        AbstractClassifier cls, int seed, int folds, String modelOutputFile) throws Exception {

    PerformanceCounters.startTimer("cross-validation ST");

    PerformanceCounters.startTimer("cross-validation init ST");

    // loads data and set class index
    Instances data = DataSource.read(dataInput);
    String clsIndex = classIndex;

    switch (clsIndex) {
    case "first":
        data.setClassIndex(0);
        break;
    case "last":
        data.setClassIndex(data.numAttributes() - 1);
        break;
    default:
        try {
            data.setClassIndex(Integer.parseInt(clsIndex) - 1);
        } catch (NumberFormatException e) {
            data.setClassIndex(data.attribute(clsIndex).index());
        }
        break;
    }

    Remove removeFilter = new Remove();
    removeFilter.setAttributeIndices(removeIndices);
    removeFilter.setInputFormat(data);
    data = Filter.useFilter(data, removeFilter);

    // randomize data
    Random rand = new Random(seed);
    Instances randData = new Instances(data);
    randData.randomize(rand);
    if (randData.classAttribute().isNominal()) {
        randData.stratify(folds);
    }

    // perform cross-validation and add predictions
    Evaluation eval = new Evaluation(randData);
    Instances trainSets[] = new Instances[folds];
    Instances testSets[] = new Instances[folds];
    Classifier foldCls[] = new Classifier[folds];

    for (int n = 0; n < folds; n++) {
        trainSets[n] = randData.trainCV(folds, n);
        testSets[n] = randData.testCV(folds, n);
        foldCls[n] = AbstractClassifier.makeCopy(cls);
    }

    PerformanceCounters.stopTimer("cross-validation init ST");
    PerformanceCounters.startTimer("cross-validation folds+train ST");
    //paralelize!!:--------------------------------------------------------------
    for (int n = 0; n < folds; n++) {
        Instances train = trainSets[n];
        Instances test = testSets[n];

        // the above code is used by the StratifiedRemoveFolds filter, the
        // code below by the Explorer/Experimenter:
        // Instances train = randData.trainCV(folds, n, rand);
        // build and evaluate classifier
        Classifier clsCopy = foldCls[n];
        clsCopy.buildClassifier(train);
        eval.evaluateModel(clsCopy, test);
    }

    cls.buildClassifier(data);
    //until here!-----------------------------------------------------------------

    PerformanceCounters.stopTimer("cross-validation folds+train ST");
    PerformanceCounters.startTimer("cross-validation post ST");
    // output evaluation
    String out = "\n" + "=== Setup ===\n" + "Classifier: " + cls.getClass().getName() + " "
            + Utils.joinOptions(cls.getOptions()) + "\n" + "Dataset: " + data.relationName() + "\n" + "Folds: "
            + folds + "\n" + "Seed: " + seed + "\n" + "\n"
            + eval.toSummaryString("=== " + folds + "-fold Cross-validation ===", false) + "\n";

    if (!modelOutputFile.isEmpty()) {
        SerializationHelper.write(modelOutputFile, cls);
    }

    PerformanceCounters.stopTimer("cross-validation post ST");
    PerformanceCounters.stopTimer("cross-validation ST");

    return out;
}

From source file:asap.CrossValidation.java

/**
 * Runs a k-fold cross-validation of {@code cls} on a dataset loaded from
 * {@code dataInput}, trains {@code cls} itself on the whole (filtered) dataset
 * and optionally serializes it.
 *
 * <p>The data is shuffled with {@code seed} and, for a nominal class attribute,
 * stratified. One {@code FoldSet} (train split, test split, classifier copy) is
 * created per fold. When {@code Config.getNumThreads()} is greater than one, up
 * to {@code Config.getNumThreads() - 1} {@code CrossValidationFoldThread}
 * workers are started while the calling thread trains {@code cls} on the full
 * data; otherwise a single {@code CrossValidationFoldThread} is run directly on
 * the calling thread.
 *
 * @param dataInput       location of the dataset, handed to {@code DataSource.read}
 * @param classIndex      {@code "first"}, {@code "last"}, a 1-based attribute
 *                        number, or the name of an attribute
 * @param removeIndices   attribute indices handed to the {@code Remove} filter
 *                        (applied after the class index has been set)
 * @param cls             classifier to evaluate; copied once per fold and then
 *                        trained on the full dataset
 * @param seed            seed for the random shuffling of the data
 * @param folds           number of cross-validation folds
 * @param modelOutputFile file the trained classifier is serialized to; nothing
 *                        is written when this is {@code null} or empty
 * @return the setup description followed by the evaluation summary
 * @throws IllegalArgumentException if {@code classIndex} is neither a keyword,
 *                                  a number, nor the name of an existing attribute
 * @throws Exception                if loading, filtering, training, joining the
 *                                  workers or serializing the model fails
 */
public static String performCrossValidationMT(String dataInput, String classIndex, String removeIndices,
        AbstractClassifier cls, int seed, int folds, String modelOutputFile) throws Exception {

    PerformanceCounters.startTimer("cross-validation MT");

    PerformanceCounters.startTimer("cross-validation init MT");

    // loads data and set class index
    Instances data = DataSource.read(dataInput);

    switch (classIndex) {
    case "first":
        data.setClassIndex(0);
        break;
    case "last":
        data.setClassIndex(data.numAttributes() - 1);
        break;
    default:
        try {
            // numeric class indices are 1-based on input
            data.setClassIndex(Integer.parseInt(classIndex) - 1);
        } catch (NumberFormatException e) {
            // not a number: interpret the value as an attribute name
            if (data.attribute(classIndex) == null) {
                throw new IllegalArgumentException("Unknown class attribute: " + classIndex, e);
            }
            data.setClassIndex(data.attribute(classIndex).index());
        }
        break;
    }

    Remove removeFilter = new Remove();
    removeFilter.setAttributeIndices(removeIndices);
    removeFilter.setInputFormat(data);
    data = Filter.useFilter(data, removeFilter);

    // randomize data
    Random rand = new Random(seed);
    Instances randData = new Instances(data);
    randData.randomize(rand);
    if (randData.classAttribute().isNominal()) {
        randData.stratify(folds);
    }

    // perform cross-validation and add predictions
    Evaluation eval = new Evaluation(randData);
    final int numThreads = Config.getNumThreads();
    List<Thread> foldThreads = Collections.synchronizedList(new LinkedList<Thread>());
    List<FoldSet> foldSets = Collections.synchronizedList(new LinkedList<FoldSet>());

    for (int n = 0; n < folds; n++) {
        foldSets.add(new FoldSet(randData.trainCV(folds, n), randData.testCV(folds, n),
                AbstractClassifier.makeCopy(cls)));

        // one slot is kept for the calling thread, which trains on the full data
        if (n < numThreads - 1) {
            foldThreads.add(new Thread(new CrossValidationFoldThread(n, foldSets, eval)));
        }
    }

    PerformanceCounters.stopTimer("cross-validation init MT");
    PerformanceCounters.startTimer("cross-validation folds+train MT");

    if (numThreads > 1) {
        for (Thread foldThread : foldThreads) {
            foldThread.start();
        }
    } else {
        // use the current thread to run the cross-validation instead of a separate Thread
        new CrossValidationFoldThread(0, foldSets, eval).run();
    }

    // train the final model on all data while the workers process the folds
    cls.buildClassifier(data);

    for (Thread foldThread : foldThreads) {
        foldThread.join();
    }

    PerformanceCounters.stopTimer("cross-validation folds+train MT");
    PerformanceCounters.startTimer("cross-validation post MT");
    // evaluation for output:
    String out = "\n" + "=== Setup ===\n" + "Classifier: " + cls.getClass().getName() + " "
            + Utils.joinOptions(cls.getOptions()) + "\n" + "Dataset: " + data.relationName() + "\n" + "Folds: "
            + folds + "\n" + "Seed: " + seed + "\n" + "\n"
            + eval.toSummaryString("=== " + folds + "-fold Cross-validation ===", false) + "\n";

    // a null path means "do not save", same as an empty one
    if (modelOutputFile != null && !modelOutputFile.isEmpty()) {
        SerializationHelper.write(modelOutputFile, cls);
    }

    PerformanceCounters.stopTimer("cross-validation post MT");
    PerformanceCounters.stopTimer("cross-validation MT");
    return out;
}

From source file:asap.CrossValidation.java

/**
 * Runs a k-fold cross-validation of {@code cls} on {@code data}, trains
 * {@code cls} itself on the whole dataset and optionally serializes it.
 *
 * <p>The data is copied, shuffled with {@code seed} and, for a nominal class
 * attribute, stratified. One {@code FoldSet} is created per fold. When
 * {@code Config.getNumThreads()} is greater than one, up to
 * {@code Config.getNumThreads() - 1} {@code CrossValidationFoldThread} workers
 * are started while the calling thread trains {@code cls}; otherwise a single
 * {@code CrossValidationFoldThread} is run directly on the calling thread.
 *
 * <p>This overload does not throw checked exceptions: failures are logged, and
 * a failure to create the {@code Evaluation} is reported through the returned
 * string.
 *
 * @param data            dataset with its class attribute already set
 * @param cls             classifier to evaluate; copied once per fold and then
 *                        trained on the full dataset
 * @param seed            seed for the random shuffling of the data
 * @param folds           number of cross-validation folds
 * @param modelOutputFile file the trained classifier is serialized to; nothing
 *                        is written when this is {@code null} or empty
 * @return the setup description followed by the evaluation summary, or an
 *         error message when the {@code Evaluation} could not be created
 */
static String performCrossValidationMT(Instances data, AbstractClassifier cls, int seed, int folds,
        String modelOutputFile) {

    PerformanceCounters.startTimer("cross-validation MT");

    PerformanceCounters.startTimer("cross-validation init MT");

    // randomize data
    Random rand = new Random(seed);
    Instances randData = new Instances(data);
    randData.randomize(rand);
    if (randData.classAttribute().isNominal()) {
        randData.stratify(folds);
    }

    // perform cross-validation and add predictions
    Evaluation eval;
    try {
        eval = new Evaluation(randData);
    } catch (Exception ex) {
        Logger.getLogger(CrossValidation.class.getName()).log(Level.SEVERE, null, ex);
        return "Error creating evaluation instance for given data!";
    }

    final int numThreads = Config.getNumThreads();
    List<Thread> foldThreads = Collections.synchronizedList(new LinkedList<Thread>());
    List<FoldSet> foldSets = Collections.synchronizedList(new LinkedList<FoldSet>());

    for (int n = 0; n < folds; n++) {
        try {
            foldSets.add(new FoldSet(randData.trainCV(folds, n), randData.testCV(folds, n),
                    AbstractClassifier.makeCopy(cls)));
        } catch (Exception ex) {
            // best effort: the failure is logged and this fold is not added to the list
            Logger.getLogger(CrossValidation.class.getName()).log(Level.SEVERE, null, ex);
        }

        // one slot is kept for the calling thread, which trains on the full data
        if (n < numThreads - 1) {
            foldThreads.add(new Thread(new CrossValidationFoldThread(n, foldSets, eval)));
        }
    }

    PerformanceCounters.stopTimer("cross-validation init MT");
    PerformanceCounters.startTimer("cross-validation folds+train MT");

    if (numThreads > 1) {
        for (Thread foldThread : foldThreads) {
            foldThread.start();
        }
    } else {
        // single-threaded configuration: process the folds on the calling thread
        new CrossValidationFoldThread(0, foldSets, eval).run();
    }

    // train the final model on all data while the workers process the folds
    try {
        cls.buildClassifier(data);
    } catch (Exception ex) {
        Logger.getLogger(CrossValidation.class.getName()).log(Level.SEVERE, null, ex);
    }

    // Wait until every worker has really finished: an interrupt must not let the
    // summary below be read while folds are still being evaluated. The interrupt
    // is remembered and re-raised once all workers are done.
    boolean interrupted = false;
    for (Thread foldThread : foldThreads) {
        while (foldThread.isAlive()) {
            try {
                foldThread.join();
            } catch (InterruptedException ex) {
                interrupted = true;
                Logger.getLogger(CrossValidation.class.getName()).log(Level.SEVERE, null, ex);
            }
        }
    }
    if (interrupted) {
        Thread.currentThread().interrupt();
    }

    PerformanceCounters.stopTimer("cross-validation folds+train MT");
    PerformanceCounters.startTimer("cross-validation post MT");
    // evaluation for output:
    String out = "\n" + "=== Setup ===\n" + "Classifier: " + cls.getClass().getName() + " "
            + Utils.joinOptions(cls.getOptions()) + "\n" + "Dataset: " + data.relationName() + "\n" + "Folds: "
            + folds + "\n" + "Seed: " + seed + "\n" + "\n"
            + eval.toSummaryString("=== " + folds + "-fold Cross-validation ===", false) + "\n";

    if (modelOutputFile != null && !modelOutputFile.isEmpty()) {
        try {
            SerializationHelper.write(modelOutputFile, cls);
        } catch (Exception ex) {
            Logger.getLogger(CrossValidation.class.getName()).log(Level.SEVERE, null, ex);
        }
    }

    PerformanceCounters.stopTimer("cross-validation post MT");
    PerformanceCounters.stopTimer("cross-validation MT");
    return out;
}

From source file:asap.NLPSystem.java

/**
 * Runs a k-fold cross-validation of this system's {@code classifier} on
 * {@code trainingSet} and returns a textual report.
 *
 * <p>The training set is copied, shuffled with {@code seed} and, for a nominal
 * class attribute, stratified. One {@code FoldSet} is created per fold. When
 * {@code Config.getNumThreads()} is greater than one, up to
 * {@code Config.getNumThreads() - 1} {@code CrossValidationFoldThread} workers
 * are started and awaited; otherwise a single {@code CrossValidationFoldThread}
 * is run directly on the calling thread.
 *
 * <p>Side effects: stores the evaluation's correlation coefficient in
 * {@code crossValidationPearsonsCorrelation} (left unchanged if it cannot be
 * computed) and sets {@code classifierBuiltWithCrossValidation} to
 * {@code true}. This method does not itself call {@code buildClassifier} on
 * {@code classifier}; the instance is serialized in whatever state it is in.
 *
 * @param seed            seed for the random shuffling of the data
 * @param folds           number of cross-validation folds
 * @param modelOutputFile file the classifier is serialized to; nothing is
 *                        written when this is {@code null} or empty
 * @return the setup description followed by the evaluation summary, or an
 *         error message when the {@code Evaluation} could not be created
 */
private String crossValidate(int seed, int folds, String modelOutputFile) {

    PerformanceCounters.startTimer("cross-validation");
    PerformanceCounters.startTimer("cross-validation init");

    AbstractClassifier abstractClassifier = (AbstractClassifier) classifier;
    // randomize data
    Random rand = new Random(seed);
    Instances randData = new Instances(trainingSet);
    randData.randomize(rand);
    if (randData.classAttribute().isNominal()) {
        randData.stratify(folds);
    }

    // perform cross-validation and add predictions
    Evaluation eval;
    try {
        eval = new Evaluation(randData);
    } catch (Exception ex) {
        Logger.getLogger(NLPSystem.class.getName()).log(Level.SEVERE, null, ex);
        return "Error creating evaluation instance for given data!";
    }

    final int numThreads = Config.getNumThreads();
    List<Thread> foldThreads = Collections.synchronizedList(new LinkedList<Thread>());
    List<FoldSet> foldSets = Collections.synchronizedList(new LinkedList<FoldSet>());

    for (int n = 0; n < folds; n++) {
        try {
            foldSets.add(new FoldSet(randData.trainCV(folds, n), randData.testCV(folds, n),
                    AbstractClassifier.makeCopy(abstractClassifier)));
        } catch (Exception ex) {
            // best effort: the failure is logged and this fold is not added to the list
            Logger.getLogger(NLPSystem.class.getName()).log(Level.SEVERE, null, ex);
        }

        if (n < numThreads - 1) {
            foldThreads.add(new Thread(new CrossValidationFoldThread(n, foldSets, eval)));
        }
    }

    PerformanceCounters.stopTimer("cross-validation init");
    PerformanceCounters.startTimer("cross-validation folds+train");

    if (numThreads > 1) {
        for (Thread foldThread : foldThreads) {
            foldThread.start();
        }
    } else {
        // single-threaded configuration: process the folds on the calling thread
        new CrossValidationFoldThread(0, foldSets, eval).run();
    }

    // Keep waiting until every worker has finished, even if interrupted, so the
    // evaluation is complete before it is read. The interrupt is remembered and
    // re-raised afterwards instead of being silently dropped.
    boolean interrupted = false;
    for (Thread foldThread : foldThreads) {
        while (foldThread.isAlive()) {
            try {
                foldThread.join();
            } catch (InterruptedException ex) {
                interrupted = true;
                Logger.getLogger(NLPSystem.class.getName()).log(Level.SEVERE, null, ex);
            }
        }
    }
    if (interrupted) {
        Thread.currentThread().interrupt();
    }

    PerformanceCounters.stopTimer("cross-validation folds+train");
    PerformanceCounters.startTimer("cross-validation post");
    // evaluation for output:
    String out = String.format(
            "\n=== Setup ===\nClassifier: %s %s\n" + "Dataset: %s\nFolds: %s\nSeed: %s\n\n%s\n",
            abstractClassifier.getClass().getName(), Utils.joinOptions(abstractClassifier.getOptions()),
            trainingSet.relationName(), folds, seed,
            eval.toSummaryString(String.format("=== %s-fold Cross-validation ===", folds), false));

    try {
        crossValidationPearsonsCorrelation = eval.correlationCoefficient();
    } catch (Exception ex) {
        Logger.getLogger(NLPSystem.class.getName()).log(Level.SEVERE, null, ex);
    }

    if (modelOutputFile != null && !modelOutputFile.isEmpty()) {
        try {
            SerializationHelper.write(modelOutputFile, abstractClassifier);
        } catch (Exception ex) {
            Logger.getLogger(NLPSystem.class.getName()).log(Level.SEVERE, null, ex);
        }
    }

    classifierBuiltWithCrossValidation = true;
    PerformanceCounters.stopTimer("cross-validation post");
    PerformanceCounters.stopTimer("cross-validation");
    return out;
}