Example usage for weka.core Instances randomize

List of usage examples for weka.core Instances randomize

Introduction

On this page you can find example usages of weka.core.Instances.randomize(Random).

Prototype

public void randomize(Random random) 

Source Link

Document

Shuffles the instances in the set so that they are ordered randomly.
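
A minimal, self-contained sketch of the call (the file name data.arff is a placeholder; a standard Weka classpath is assumed): copy the data, shuffle the copy with a seeded java.util.Random so the result is reproducible, and stratify nominal-class data before cross-validation, as most of the examples below do.

import java.util.Random;

import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

public class RandomizeExample {
    public static void main(String[] args) throws Exception {
        // Load a dataset and use the last attribute as the class (data.arff is a placeholder)
        Instances data = DataSource.read("data.arff");
        data.setClassIndex(data.numAttributes() - 1);

        // randomize shuffles in place, so work on a copy to keep the original order
        Instances randData = new Instances(data);

        // A fixed seed makes the shuffle reproducible across runs
        randData.randomize(new Random(42));

        // Stratification only applies when the class attribute is nominal
        int folds = 10;
        if (randData.classAttribute().isNominal()) {
            randData.stratify(folds);
        }

        System.out.println("First instance after shuffling: " + randData.instance(0));
    }
}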

Usage

From source file:ann.MyANN.java

/**
 * Evaluates the model by splitting the instances into train and test sets over numFold folds
 * @param instances the data to be evaluated
 * @param numFold the number of folds
 * @param rand the random number generator used for shuffling
 * @return confusion matrix
 */
public int[][] crossValidation(Instances instances, int numFold, Random rand) {
    int[][] totalResult = null;
    instances = new Instances(instances);
    instances.randomize(rand);
    if (instances.classAttribute().isNominal()) {
        instances.stratify(numFold);
    }
    for (int i = 0; i < numFold; i++) {
        try {
            // split the instances according to the number of folds
            Instances train = instances.trainCV(numFold, i, rand);
            Instances test = instances.testCV(numFold, i);
            MyANN cc = new MyANN(this);
            cc.buildClassifier(train);
            int[][] result = cc.evaluate(test);
            if (i == 0) {
                // the first fold initializes the accumulated confusion matrix
                totalResult = result;
            } else {
                // later folds are added element-wise
                for (int j = 0; j < totalResult.length; j++) {
                    for (int k = 0; k < totalResult[0].length; k++) {
                        totalResult[j][k] += result[j][k];
                    }
                }
            }
        } catch (Exception ex) {
            Logger.getLogger(MyANN.class.getName()).log(Level.SEVERE, null, ex);
        }
    }

    return totalResult;
}
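
Note that the three-argument trainCV(numFold, i, rand) additionally shuffles the resulting training fold with the supplied Random, whereas testCV(numFold, i) returns the test fold in its current order.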

From source file:app.RunApp.java

License:Open Source License

/**
 * Preprocess dataset
 * 
 * @return Positive number if successful and negative otherwise
 */
private int preprocess() {
    trainDatasets = new ArrayList();
    testDatasets = new ArrayList();

    Instances train, test;

    if (dataset == null) {
        JOptionPane.showMessageDialog(null, "You must load a dataset.", "alert", JOptionPane.ERROR_MESSAGE);
        return -1;
    }

    MultiLabelInstances preprocessDataset = dataset.clone();

    if (!radioNoIS.isSelected()) {
        //Do Instance Selection
        if (radioRandomIS.isSelected()) {
            int nInstances = Integer.parseInt(textRandomIS.getText());

            if (nInstances < 1) {
                JOptionPane.showMessageDialog(null,
                        "The number of instances must be a positive natural number.", "alert",
                        JOptionPane.ERROR_MESSAGE);
                return -1;
            } else if (nInstances > dataset.getNumInstances()) {
                JOptionPane.showMessageDialog(null,
                        "The number of instances to select must be less than the original.", "alert",
                        JOptionPane.ERROR_MESSAGE);
                return -1;
            }

            Instances dataIS;
            try {
                Randomize randomize = new Randomize();
                dataIS = dataset.getDataSet();

                randomize.setInputFormat(dataIS);
                dataIS = Filter.useFilter(dataIS, randomize);
                randomize.batchFinished();

                RemoveRange removeRange = new RemoveRange();
                removeRange.setInputFormat(dataIS);
                removeRange.setInstancesIndices((nInstances + 1) + "-last");

                dataIS = Filter.useFilter(dataIS, removeRange);
                removeRange.batchFinished();

                preprocessDataset = dataset.reintegrateModifiedDataSet(dataIS);
            } catch (Exception ex) {
                Logger.getLogger(RunApp.class.getName()).log(Level.SEVERE, null, ex);
            }

            if (preprocessDataset == null) {
                JOptionPane.showMessageDialog(null, "Error when selecting instances.", "alert",
                        JOptionPane.ERROR_MESSAGE);
                return -1;
            }

            preprocessedDataset = preprocessDataset;
        }
    }

    if (!radioNoFS.isSelected()) {
        //FS_BR
        if (radioBRFS.isSelected()) {
            int nFeatures = Integer.parseInt(textBRFS.getText());
            if (nFeatures < 1) {
                JOptionPane.showMessageDialog(null, "The number of features must be a positive natural number.",
                        "alert", JOptionPane.ERROR_MESSAGE);
                return -1;
            } else if (nFeatures > dataset.getFeatureIndices().length) {
                JOptionPane.showMessageDialog(null,
                        "The number of features to select must be less than the original.", "alert",
                        JOptionPane.ERROR_MESSAGE);
                return -1;
            }

            String combination = jComboBoxBRFSComb.getSelectedItem().toString();
            String normalization = jComboBoxBRFSNorm.getSelectedItem().toString();
            String output = jComboBoxBRFSOut.getSelectedItem().toString();

            FeatureSelector fs;
            if (radioNoIS.isSelected()) {
                fs = new FeatureSelector(dataset, nFeatures);
            } else {
                //If instance selection has been done
                fs = new FeatureSelector(preprocessDataset, nFeatures);
            }

            preprocessedDataset = fs.select(combination, normalization, output);

            if (preprocessedDataset == null) {
                JOptionPane.showMessageDialog(null, "Error when selecting features.", "alert",
                        JOptionPane.ERROR_MESSAGE);
                return -1;
            }

            preprocessDataset = preprocessedDataset;
        } else if (radioRandomFS.isSelected()) {
            int nFeatures = Integer.parseInt(textRandomFS.getText());

            if (nFeatures < 1) {
                JOptionPane.showMessageDialog(null, "The number of features must be a positive natural number.",
                        "alert", JOptionPane.ERROR_MESSAGE);
                return -1;
            } else if (nFeatures > dataset.getFeatureIndices().length) {
                JOptionPane.showMessageDialog(null,
                        "The number of features to select must be less than the original.", "alert",
                        JOptionPane.ERROR_MESSAGE);
                return -1;
            }

            FeatureSelector fs;

            if (radioNoIS.isSelected()) {
                fs = new FeatureSelector(dataset, nFeatures);
            } else {
                //If instance selection has been done
                fs = new FeatureSelector(preprocessDataset, nFeatures);
            }

            preprocessedDataset = fs.randomSelect();

            if (preprocessedDataset == null) {
                JOptionPane.showMessageDialog(null, "Error when selecting features.", "alert",
                        JOptionPane.ERROR_MESSAGE);
                return -1;
            }

            preprocessDataset = preprocessedDataset;
        }
    }

    if (!radioNoSplit.isSelected()) {
        //Random Holdout
        if (radioRandomHoldout.isSelected()) {
            String split = textRandomHoldout.getText();
            double percentage = Double.parseDouble(split);
            if ((percentage <= 0) || (percentage >= 100)) {
                JOptionPane.showMessageDialog(null, "The percentage must be a number in the range (0, 100).",
                        "alert", JOptionPane.ERROR_MESSAGE);
                return -1;
            }

            try {
                RandomTrainTest pre = new RandomTrainTest();
                MultiLabelInstances[] partitions = pre.split(preprocessDataset, percentage);
                trainDataset = partitions[0];
                testDataset = partitions[1];
            } catch (InvalidDataFormatException ex) {
                Logger.getLogger(RunApp.class.getName()).log(Level.SEVERE, null, ex);
            } catch (Exception ex) {
                Logger.getLogger(RunApp.class.getName()).log(Level.SEVERE, null, ex);
            }
        }
        //Random CV
        else if (radioRandomCV.isSelected()) {
            String split = textRandomCV.getText();

            if (split.equals("")) {
                JOptionPane.showMessageDialog(null, "You must enter the number of folds.", "alert",
                        JOptionPane.ERROR_MESSAGE);
                return -1;
            }

            int nFolds;

            try {
                nFolds = Integer.parseInt(split);
            } catch (Exception e) {
                JOptionPane.showMessageDialog(null, "Introduce a correct number of folds.", "alert",
                        JOptionPane.ERROR_MESSAGE);
                return -1;
            }

            if (nFolds < 2) {
                JOptionPane.showMessageDialog(null, "The number of folds must be greater or equal to 2.",
                        "alert", JOptionPane.ERROR_MESSAGE);
                return -1;
            } else if (nFolds > preprocessDataset.getNumInstances()) {
                JOptionPane.showMessageDialog(null,
                        "The number of folds can not be greater than the number of instances.", "alert",
                        JOptionPane.ERROR_MESSAGE);
                return -1;
            }

            try {
                MultiLabelInstances temp = preprocessDataset.clone();
                Instances dataTemp = temp.getDataSet();

                int seed = (int) (Math.random() * 100) + 100;
                Random rand = new Random(seed);

                dataTemp.randomize(rand);

                Instances[] foldsCV = new Instances[nFolds];
                for (int i = 0; i < nFolds; i++) {
                    foldsCV[i] = new Instances(dataTemp);
                    foldsCV[i].clear();
                }

                for (int i = 0; i < dataTemp.numInstances(); i++) {
                    foldsCV[i % nFolds].add(dataTemp.get(i));
                }

                train = new Instances(dataTemp);
                test = new Instances(dataTemp);
                for (int i = 0; i < nFolds; i++) {
                    train.clear();
                    test.clear();
                    for (int j = 0; j < nFolds; j++) {
                        if (i != j) {
                            System.out.println("Add fold " + j + " to train");
                            train.addAll(foldsCV[j]);
                        }
                    }
                    System.out.println("Add fold " + i + " to test");
                    test.addAll(foldsCV[i]);
                    System.out.println(train.get(0).toString());
                    System.out.println(test.get(0).toString());
                    trainDatasets.add(new MultiLabelInstances(new Instances(train),
                            preprocessDataset.getLabelsMetaData()));
                    testDatasets.add(new MultiLabelInstances(new Instances(test),
                            preprocessDataset.getLabelsMetaData()));
                    System.out.println(trainDatasets.get(i).getDataSet().get(0).toString());
                    System.out.println(testDatasets.get(i).getDataSet().get(0).toString());
                    System.out.println("---");
                }
            }

            catch (Exception ex) {
                Logger.getLogger(RunApp.class.getName()).log(Level.SEVERE, null, ex);
            }
        }
        //Iterative stratified holdout
        else if (radioIterativeStratifiedHoldout.isSelected()) {
            String split = textIterativeStratifiedHoldout.getText();
            double percentage = Double.parseDouble(split);
            if ((percentage <= 0) || (percentage >= 100)) {
                JOptionPane.showMessageDialog(null, "The percentage must be a number in the range (0, 100).",
                        "alert", JOptionPane.ERROR_MESSAGE);
                return -1;
            }

            try {
                IterativeTrainTest pre = new IterativeTrainTest();
                MultiLabelInstances[] partitions = pre.split(preprocessDataset, percentage);

                trainDataset = partitions[0];
                testDataset = partitions[1];
            } catch (Exception ex) {
                Logger.getLogger(RunApp.class.getName()).log(Level.SEVERE, null, ex);
            }
        }
        //Iterative stratified CV
        else if (radioIterativeStratifiedCV.isSelected()) {
            String split = textIterativeStratifiedCV.getText();

            if (split.equals("")) {
                JOptionPane.showMessageDialog(null, "You must enter the number of folds.", "alert",
                        JOptionPane.ERROR_MESSAGE);
                return -1;
            }

            int nFolds = 0;

            try {
                nFolds = Integer.parseInt(split);
            } catch (Exception e) {
                JOptionPane.showMessageDialog(null, "Introduce a correct number of folds.", "alert",
                        JOptionPane.ERROR_MESSAGE);
                return -1;
            }

            if (nFolds < 2) {
                JOptionPane.showMessageDialog(null, "The number of folds must be greater or equal to 2.",
                        "alert", JOptionPane.ERROR_MESSAGE);
                return -1;
            } else if (nFolds > preprocessDataset.getNumInstances()) {
                JOptionPane.showMessageDialog(null,
                        "The number of folds can not be greater than the number of instances.", "alert",
                        JOptionPane.ERROR_MESSAGE);
                return -1;
            }

            IterativeStratification strat = new IterativeStratification();
            MultiLabelInstances folds[] = strat.stratify(preprocessDataset, nFolds);

            for (int i = 0; i < nFolds; i++) {
                try {

                    int trainSize = 0, testSize = 0;
                    for (int j = 0; j < nFolds; j++) {
                        if (i != j) {
                            trainSize += folds[j].getNumInstances();
                        }
                    }
                    testSize += folds[i].getNumInstances();

                    train = new Instances(preprocessDataset.getDataSet(), trainSize);
                    test = new Instances(preprocessDataset.getDataSet(), testSize);
                    for (int j = 0; j < nFolds; j++) {
                        if (i != j) {
                            train.addAll(folds[j].getDataSet());
                        }
                    }
                    test.addAll(folds[i].getDataSet());

                    trainDatasets.add(new MultiLabelInstances(train, preprocessDataset.getLabelsMetaData()));
                    testDatasets.add(new MultiLabelInstances(test, preprocessDataset.getLabelsMetaData()));
                } catch (InvalidDataFormatException ex) {
                    Logger.getLogger(RunApp.class.getName()).log(Level.SEVERE, null, ex);
                }
            }

        }
        //LP stratified holdout
        else if (radioLPStratifiedHoldout.isSelected()) {
            String split = textLPStratifiedHoldout.getText();
            double percentage = Double.parseDouble(split);
            if ((percentage <= 0) || (percentage >= 100)) {
                JOptionPane.showMessageDialog(null, "The percentage must be a number in the range (0, 100).",
                        "alert", JOptionPane.ERROR_MESSAGE);
                return -1;
            }

            try {
                IterativeTrainTest pre = new IterativeTrainTest();
                MultiLabelInstances[] partitions = pre.split(preprocessDataset, percentage);

                trainDataset = partitions[0];
                testDataset = partitions[1];
            } catch (Exception ex) {
                Logger.getLogger(RunApp.class.getName()).log(Level.SEVERE, null, ex);
            }
        }
        //LP stratified CV
        else if (radioLPStratifiedCV.isSelected()) {
            String split = textLPStratifiedCV.getText();

            if (split.equals("")) {
                JOptionPane.showMessageDialog(null, "You must enter the number of folds.", "alert",
                        JOptionPane.ERROR_MESSAGE);
                return -1;
            }

            int nFolds = 0;

            try {
                nFolds = Integer.parseInt(split);
            } catch (Exception e) {
                JOptionPane.showMessageDialog(null, "Introduce a correct number of folds.", "alert",
                        JOptionPane.ERROR_MESSAGE);
                return -1;
            }

            if (nFolds < 2) {
                JOptionPane.showMessageDialog(null, "The number of folds must be greater or equal to 2.",
                        "alert", JOptionPane.ERROR_MESSAGE);
                return -1;
            } else if (nFolds > preprocessDataset.getNumInstances()) {
                JOptionPane.showMessageDialog(null,
                        "The number of folds can not be greater than the number of instances.", "alert",
                        JOptionPane.ERROR_MESSAGE);
                return -1;
            }

            LabelPowersetTrainTest strat = new LabelPowersetTrainTest();
            MultiLabelInstances folds[] = strat.stratify(preprocessDataset, nFolds);

            for (int i = 0; i < nFolds; i++) {
                try {
                    train = new Instances(preprocessDataset.getDataSet(), 0);
                    test = new Instances(preprocessDataset.getDataSet(), 0);

                    for (int j = 0; j < nFolds; j++) {
                        if (i != j) {
                            train.addAll(folds[j].getDataSet());
                        }
                    }
                    test.addAll(folds[i].getDataSet());

                    trainDatasets.add(new MultiLabelInstances(train, preprocessDataset.getLabelsMetaData()));
                    testDatasets.add(new MultiLabelInstances(test, preprocessDataset.getLabelsMetaData()));
                } catch (InvalidDataFormatException ex) {
                    Logger.getLogger(RunApp.class.getName()).log(Level.SEVERE, null, ex);
                }
            }
        }
    }

    jButtonSaveDatasets.setEnabled(true);
    jComboBoxSaveFormat.setEnabled(true);

    return 1;
}

From source file:asap.CrossValidation.java

/**
 *
 * @param dataInput
 * @param classIndex
 * @param removeIndices
 * @param cls
 * @param seed
 * @param folds
 * @param modelOutputFile
 * @return
 * @throws Exception
 */
public static String performCrossValidation(String dataInput, String classIndex, String removeIndices,
        AbstractClassifier cls, int seed, int folds, String modelOutputFile) throws Exception {

    PerformanceCounters.startTimer("cross-validation ST");

    PerformanceCounters.startTimer("cross-validation init ST");

    // loads data and set class index
    Instances data = DataSource.read(dataInput);
    String clsIndex = classIndex;

    switch (clsIndex) {
    case "first":
        data.setClassIndex(0);
        break;
    case "last":
        data.setClassIndex(data.numAttributes() - 1);
        break;
    default:
        try {
            data.setClassIndex(Integer.parseInt(clsIndex) - 1);
        } catch (NumberFormatException e) {
            data.setClassIndex(data.attribute(clsIndex).index());
        }
        break;
    }

    Remove removeFilter = new Remove();
    removeFilter.setAttributeIndices(removeIndices);
    removeFilter.setInputFormat(data);
    data = Filter.useFilter(data, removeFilter);

    // randomize data
    Random rand = new Random(seed);
    Instances randData = new Instances(data);
    randData.randomize(rand);
    if (randData.classAttribute().isNominal()) {
        randData.stratify(folds);
    }

    // perform cross-validation and add predictions
    Evaluation eval = new Evaluation(randData);
    Instances trainSets[] = new Instances[folds];
    Instances testSets[] = new Instances[folds];
    Classifier foldCls[] = new Classifier[folds];

    for (int n = 0; n < folds; n++) {
        trainSets[n] = randData.trainCV(folds, n);
        testSets[n] = randData.testCV(folds, n);
        foldCls[n] = AbstractClassifier.makeCopy(cls);
    }

    PerformanceCounters.stopTimer("cross-validation init ST");
    PerformanceCounters.startTimer("cross-validation folds+train ST");
    //parallelize!!:--------------------------------------------------------------
    for (int n = 0; n < folds; n++) {
        Instances train = trainSets[n];
        Instances test = testSets[n];

        // the above code is used by the StratifiedRemoveFolds filter, the
        // code below by the Explorer/Experimenter:
        // Instances train = randData.trainCV(folds, n, rand);
        // build and evaluate classifier
        Classifier clsCopy = foldCls[n];
        clsCopy.buildClassifier(train);
        eval.evaluateModel(clsCopy, test);
    }

    cls.buildClassifier(data);
    //until here!-----------------------------------------------------------------

    PerformanceCounters.stopTimer("cross-validation folds+train ST");
    PerformanceCounters.startTimer("cross-validation post ST");
    // output evaluation
    String out = "\n" + "=== Setup ===\n" + "Classifier: " + cls.getClass().getName() + " "
            + Utils.joinOptions(cls.getOptions()) + "\n" + "Dataset: " + data.relationName() + "\n" + "Folds: "
            + folds + "\n" + "Seed: " + seed + "\n" + "\n"
            + eval.toSummaryString("=== " + folds + "-fold Cross-validation ===", false) + "\n";

    if (!modelOutputFile.isEmpty()) {
        SerializationHelper.write(modelOutputFile, cls);
    }

    PerformanceCounters.stopTimer("cross-validation post ST");
    PerformanceCounters.stopTimer("cross-validation ST");

    return out;
}

From source file:asap.CrossValidation.java

/**
 *
 * @param dataInput
 * @param classIndex
 * @param removeIndices
 * @param cls
 * @param seed
 * @param folds
 * @param modelOutputFile
 * @return
 * @throws Exception
 */
public static String performCrossValidationMT(String dataInput, String classIndex, String removeIndices,
        AbstractClassifier cls, int seed, int folds, String modelOutputFile) throws Exception {

    PerformanceCounters.startTimer("cross-validation MT");

    PerformanceCounters.startTimer("cross-validation init MT");

    // loads data and set class index
    Instances data = DataSource.read(dataInput);
    String clsIndex = classIndex;

    switch (clsIndex) {
    case "first":
        data.setClassIndex(0);
        break;
    case "last":
        data.setClassIndex(data.numAttributes() - 1);
        break;
    default:
        try {
            data.setClassIndex(Integer.parseInt(clsIndex) - 1);
        } catch (NumberFormatException e) {
            data.setClassIndex(data.attribute(clsIndex).index());
        }
        break;
    }

    Remove removeFilter = new Remove();
    removeFilter.setAttributeIndices(removeIndices);
    removeFilter.setInputFormat(data);
    data = Filter.useFilter(data, removeFilter);

    // randomize data
    Random rand = new Random(seed);
    Instances randData = new Instances(data);
    randData.randomize(rand);
    if (randData.classAttribute().isNominal()) {
        randData.stratify(folds);
    }

    // perform cross-validation and add predictions
    Evaluation eval = new Evaluation(randData);
    List<Thread> foldThreads = (List<Thread>) Collections.synchronizedList(new LinkedList<Thread>());

    List<FoldSet> foldSets = (List<FoldSet>) Collections.synchronizedList(new LinkedList<FoldSet>());

    for (int n = 0; n < folds; n++) {
        foldSets.add(new FoldSet(randData.trainCV(folds, n), randData.testCV(folds, n),
                AbstractClassifier.makeCopy(cls)));

        if (n < Config.getNumThreads() - 1) {
            Thread foldThread = new Thread(new CrossValidationFoldThread(n, foldSets, eval));
            foldThreads.add(foldThread);
        }
    }

    PerformanceCounters.stopTimer("cross-validation init MT");
    PerformanceCounters.startTimer("cross-validation folds+train MT");
    //parallelize!!:--------------------------------------------------------------
    if (Config.getNumThreads() > 1) {
        for (Thread foldThread : foldThreads) {
            foldThread.start();
        }
    } else {
        //use the current thread to run the cross-validation instead of using the Thread instance created here:
        new CrossValidationFoldThread(0, foldSets, eval).run();
    }

    cls.buildClassifier(data);

    for (Thread foldThread : foldThreads) {
        foldThread.join();
    }

    //until here!-----------------------------------------------------------------
    PerformanceCounters.stopTimer("cross-validation folds+train MT");
    PerformanceCounters.startTimer("cross-validation post MT");
    // evaluation for output:
    String out = "\n" + "=== Setup ===\n" + "Classifier: " + cls.getClass().getName() + " "
            + Utils.joinOptions(cls.getOptions()) + "\n" + "Dataset: " + data.relationName() + "\n" + "Folds: "
            + folds + "\n" + "Seed: " + seed + "\n" + "\n"
            + eval.toSummaryString("=== " + folds + "-fold Cross-validation ===", false) + "\n";

    if (!modelOutputFile.isEmpty()) {
        SerializationHelper.write(modelOutputFile, cls);
    }

    PerformanceCounters.stopTimer("cross-validation post MT");
    PerformanceCounters.stopTimer("cross-validation MT");
    return out;
}

From source file:asap.CrossValidation.java

static String performCrossValidationMT(Instances data, AbstractClassifier cls, int seed, int folds,
        String modelOutputFile) {

    PerformanceCounters.startTimer("cross-validation MT");

    PerformanceCounters.startTimer("cross-validation init MT");

    // randomize data
    Random rand = new Random(seed);
    Instances randData = new Instances(data);
    randData.randomize(rand);
    if (randData.classAttribute().isNominal()) {
        randData.stratify(folds);
    }

    // perform cross-validation and add predictions
    Evaluation eval;
    try {
        eval = new Evaluation(randData);
    } catch (Exception ex) {
        Logger.getLogger(CrossValidation.class.getName()).log(Level.SEVERE, null, ex);
        return "Error creating evaluation instance for given data!";
    }
    List<Thread> foldThreads = (List<Thread>) Collections.synchronizedList(new LinkedList<Thread>());

    List<FoldSet> foldSets = (List<FoldSet>) Collections.synchronizedList(new LinkedList<FoldSet>());

    for (int n = 0; n < folds; n++) {
        try {
            foldSets.add(new FoldSet(randData.trainCV(folds, n), randData.testCV(folds, n),
                    AbstractClassifier.makeCopy(cls)));
        } catch (Exception ex) {
            Logger.getLogger(CrossValidation.class.getName()).log(Level.SEVERE, null, ex);
        }

        //TODO: use Config.getNumThreads() for limiting these:
        if (n < Config.getNumThreads() - 1) {
            Thread foldThread = new Thread(new CrossValidationFoldThread(n, foldSets, eval));
            foldThreads.add(foldThread);
        }
    }

    PerformanceCounters.stopTimer("cross-validation init MT");
    PerformanceCounters.startTimer("cross-validation folds+train MT");
    //parallelize!!:--------------------------------------------------------------
    if (Config.getNumThreads() > 1) {
        for (Thread foldThread : foldThreads) {
            foldThread.start();
        }
    } else {
        new CrossValidationFoldThread(0, foldSets, eval).run();
    }

    try {
        cls.buildClassifier(data);
    } catch (Exception ex) {
        Logger.getLogger(CrossValidation.class.getName()).log(Level.SEVERE, null, ex);
    }

    for (Thread foldThread : foldThreads) {
        try {
            foldThread.join();
        } catch (InterruptedException ex) {
            Logger.getLogger(CrossValidation.class.getName()).log(Level.SEVERE, null, ex);
        }
    }

    //until here!-----------------------------------------------------------------
    PerformanceCounters.stopTimer("cross-validation folds+train MT");
    PerformanceCounters.startTimer("cross-validation post MT");
    // evaluation for output:
    String out = "\n" + "=== Setup ===\n" + "Classifier: " + cls.getClass().getName() + " "
            + Utils.joinOptions(cls.getOptions()) + "\n" + "Dataset: " + data.relationName() + "\n" + "Folds: "
            + folds + "\n" + "Seed: " + seed + "\n" + "\n"
            + eval.toSummaryString("=== " + folds + "-fold Cross-validation ===", false) + "\n";

    if (modelOutputFile != null) {
        if (!modelOutputFile.isEmpty()) {
            try {
                SerializationHelper.write(modelOutputFile, cls);
            } catch (Exception ex) {
                Logger.getLogger(CrossValidation.class.getName()).log(Level.SEVERE, null, ex);
            }
        }
    }

    PerformanceCounters.stopTimer("cross-validation post MT");
    PerformanceCounters.stopTimer("cross-validation MT");
    return out;
}

From source file:asap.NLPSystem.java

private String crossValidate(int seed, int folds, String modelOutputFile) {

    PerformanceCounters.startTimer("cross-validation");
    PerformanceCounters.startTimer("cross-validation init");

    AbstractClassifier abstractClassifier = (AbstractClassifier) classifier;
    // randomize data
    Random rand = new Random(seed);
    Instances randData = new Instances(trainingSet);
    randData.randomize(rand);
    if (randData.classAttribute().isNominal()) {
        randData.stratify(folds);
    }

    // perform cross-validation and add predictions
    Evaluation eval;
    try {
        eval = new Evaluation(randData);
    } catch (Exception ex) {
        Logger.getLogger(NLPSystem.class.getName()).log(Level.SEVERE, null, ex);
        return "Error creating evaluation instance for given data!";
    }
    List<Thread> foldThreads = (List<Thread>) Collections.synchronizedList(new LinkedList<Thread>());

    List<FoldSet> foldSets = (List<FoldSet>) Collections.synchronizedList(new LinkedList<FoldSet>());

    for (int n = 0; n < folds; n++) {
        try {
            foldSets.add(new FoldSet(randData.trainCV(folds, n), randData.testCV(folds, n),
                    AbstractClassifier.makeCopy(abstractClassifier)));
        } catch (Exception ex) {
            Logger.getLogger(NLPSystem.class.getName()).log(Level.SEVERE, null, ex);
        }

        if (n < Config.getNumThreads() - 1) {
            Thread foldThread = new Thread(new CrossValidationFoldThread(n, foldSets, eval));
            foldThreads.add(foldThread);
        }
    }

    PerformanceCounters.stopTimer("cross-validation init");
    PerformanceCounters.startTimer("cross-validation folds+train");

    if (Config.getNumThreads() > 1) {
        for (Thread foldThread : foldThreads) {
            foldThread.start();
        }
    } else {
        new CrossValidationFoldThread(0, foldSets, eval).run();
    }

    for (Thread foldThread : foldThreads) {
        while (foldThread.isAlive()) {
            try {
                foldThread.join();
            } catch (InterruptedException ex) {
                Logger.getLogger(NLPSystem.class.getName()).log(Level.SEVERE, null, ex);
            }
        }
    }

    PerformanceCounters.stopTimer("cross-validation folds+train");
    PerformanceCounters.startTimer("cross-validation post");
    // evaluation for output:
    String out = String.format(
            "\n=== Setup ===\nClassifier: %s %s\n" + "Dataset: %s\nFolds: %s\nSeed: %s\n\n%s\n",
            abstractClassifier.getClass().getName(), Utils.joinOptions(abstractClassifier.getOptions()),
            trainingSet.relationName(), folds, seed,
            eval.toSummaryString(String.format("=== %s-fold Cross-validation ===", folds), false));

    try {
        crossValidationPearsonsCorrelation = eval.correlationCoefficient();
    } catch (Exception ex) {
        Logger.getLogger(NLPSystem.class.getName()).log(Level.SEVERE, null, ex);
    }
    if (modelOutputFile != null) {
        if (!modelOutputFile.isEmpty()) {
            try {
                SerializationHelper.write(modelOutputFile, abstractClassifier);
            } catch (Exception ex) {
                Logger.getLogger(NLPSystem.class.getName()).log(Level.SEVERE, null, ex);
            }
        }
    }

    classifierBuiltWithCrossValidation = true;
    PerformanceCounters.stopTimer("cross-validation post");
    PerformanceCounters.stopTimer("cross-validation");
    return out;
}

From source file:assign00.ExperimentShell.java

/**
 * @param args the command line arguments
 */
public static void main(String[] args) throws Exception {
    DataSource source = new DataSource(file);
    Instances dataSet = source.getDataSet();

    //Set up data
    dataSet.setClassIndex(dataSet.numAttributes() - 1);
    dataSet.randomize(new Random(1));

    //determine sizes
    int trainingSize = (int) Math.round(dataSet.numInstances() * .7);
    int testSize = dataSet.numInstances() - trainingSize;

    Instances training = new Instances(dataSet, 0, trainingSize);

    Instances test = new Instances(dataSet, trainingSize, testSize);

    Standardize standardizedData = new Standardize();
    standardizedData.setInputFormat(training);

    Instances newTest = Filter.useFilter(test, standardizedData);
    Instances newTraining = Filter.useFilter(training, standardizedData);

    NeuralNetworkClassifier NWC = new NeuralNetworkClassifier();
    NWC.buildClassifier(newTraining);

    Evaluation eval = new Evaluation(newTraining);
    eval.evaluateModel(NWC, newTest);

    System.out.println(eval.toSummaryString("\nResults\n======\n", false));
}
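
Note the design choice in this snippet: the Standardize filter is initialized on the training split only and then applied to both sets, so statistics from the test data never leak into the preprocessing.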

From source file:br.com.ufu.lsi.rebfnetwork.RBFModel.java

License:Open Source License

/**
 * Method used to pre-process the data, perform clustering, and
 * set the initial parameter vector.
 */
protected Instances initializeClassifier(Instances data) throws Exception {

    // can classifier handle the data?
    getCapabilities().testWithFail(data);

    data = new Instances(data);
    data.deleteWithMissingClass();

    // Make sure data is shuffled
    Random random = new Random(m_Seed);
    if (data.numInstances() > 2) {
        random = data.getRandomNumberGenerator(m_Seed);
    }
    data.randomize(random);

    double y0 = data.instance(0).classValue(); // This stuff is not relevant in classification case
    int index = 1;
    while (index < data.numInstances() && data.instance(index).classValue() == y0) {
        index++;
    }
    if (index == data.numInstances()) {
        // degenerate case, all class values are equal
        // we don't want to deal with this, too much hassle
        throw new Exception("All class values are the same. At least two class values should be different");
    }
    double y1 = data.instance(index).classValue();

    // Replace missing values   
    m_ReplaceMissingValues = new ReplaceMissingValues();
    m_ReplaceMissingValues.setInputFormat(data);
    data = Filter.useFilter(data, m_ReplaceMissingValues);

    // Remove useless attributes
    m_AttFilter = new RemoveUseless();
    m_AttFilter.setInputFormat(data);
    data = Filter.useFilter(data, m_AttFilter);

    // only class? -> build ZeroR model
    if (data.numAttributes() == 1) {
        System.err.println(
                "Cannot build model (only class attribute present in data after removing useless attributes!), "
                        + "using ZeroR model instead!");
        m_ZeroR = new weka.classifiers.rules.ZeroR();
        m_ZeroR.buildClassifier(data);
        return data;
    } else {
        m_ZeroR = null;
    }

    // Transform attributes
    m_NominalToBinary = new NominalToBinary();
    m_NominalToBinary.setInputFormat(data);
    data = Filter.useFilter(data, m_NominalToBinary);

    m_Filter = new Normalize();
    ((Normalize) m_Filter).setIgnoreClass(true);
    m_Filter.setInputFormat(data);
    data = Filter.useFilter(data, m_Filter);
    double z0 = data.instance(0).classValue(); // This stuff is not relevant in classification case
    double z1 = data.instance(index).classValue();
    m_x1 = (y0 - y1) / (z0 - z1); // no division by zero, since y0 != y1 guaranteed => z0 != z1 ???
    m_x0 = (y0 - m_x1 * z0); // = y1 - m_x1 * z1

    m_classIndex = data.classIndex();
    m_numClasses = data.numClasses();
    m_numAttributes = data.numAttributes();

    // Run k-means
    SimpleKMeans skm = new SimpleKMeans();
    skm.setMaxIterations(10000);
    skm.setNumClusters(m_numUnits);
    Remove rm = new Remove();
    data.setClassIndex(-1);
    rm.setAttributeIndices((m_classIndex + 1) + "");
    rm.setInputFormat(data);
    Instances dataRemoved = Filter.useFilter(data, rm);
    data.setClassIndex(m_classIndex);
    skm.buildClusterer(dataRemoved);
    Instances centers = skm.getClusterCentroids();

    if (centers.numInstances() < m_numUnits) {
        m_numUnits = centers.numInstances();
    }

    // Set up arrays
    OFFSET_WEIGHTS = 0;
    if (m_useAttributeWeights) {
        OFFSET_ATTRIBUTE_WEIGHTS = (m_numUnits + 1) * m_numClasses;
        OFFSET_CENTERS = OFFSET_ATTRIBUTE_WEIGHTS + m_numAttributes;
    } else {
        OFFSET_ATTRIBUTE_WEIGHTS = -1;
        OFFSET_CENTERS = (m_numUnits + 1) * m_numClasses;
    }
    OFFSET_SCALES = OFFSET_CENTERS + m_numUnits * m_numAttributes;

    switch (m_scaleOptimizationOption) {
    case USE_GLOBAL_SCALE:
        m_RBFParameters = new double[OFFSET_SCALES + 1];
        break;
    case USE_SCALE_PER_UNIT_AND_ATTRIBUTE:
        m_RBFParameters = new double[OFFSET_SCALES + m_numUnits * m_numAttributes];
        break;
    default:
        m_RBFParameters = new double[OFFSET_SCALES + m_numUnits];
        break;
    }

    // Set initial radius based on distance to nearest other basis function
    double maxMinDist = -1;
    for (int i = 0; i < centers.numInstances(); i++) {
        double minDist = Double.MAX_VALUE;
        for (int j = i + 1; j < centers.numInstances(); j++) {
            double dist = 0;
            for (int k = 0; k < centers.numAttributes(); k++) {
                if (k != centers.classIndex()) {
                    double diff = centers.instance(i).value(k) - centers.instance(j).value(k);
                    dist += diff * diff;
                }
            }
            if (dist < minDist) {
                minDist = dist;
            }
        }
        if ((minDist != Double.MAX_VALUE) && (minDist > maxMinDist)) {
            maxMinDist = minDist;
        }
    }

    // Initialize parameters
    if (m_scaleOptimizationOption == USE_GLOBAL_SCALE) {
        m_RBFParameters[OFFSET_SCALES] = Math.sqrt(maxMinDist);
    }
    for (int i = 0; i < m_numUnits; i++) {
        if (m_scaleOptimizationOption == USE_SCALE_PER_UNIT) {
            m_RBFParameters[OFFSET_SCALES + i] = Math.sqrt(maxMinDist);
        }
        int k = 0;
        for (int j = 0; j < m_numAttributes; j++) {
            if (k == centers.classIndex()) {
                k++;
            }
            if (j != data.classIndex()) {
                if (m_scaleOptimizationOption == USE_SCALE_PER_UNIT_AND_ATTRIBUTE) {
                    m_RBFParameters[OFFSET_SCALES + (i * m_numAttributes + j)] = Math.sqrt(maxMinDist);
                }
                m_RBFParameters[OFFSET_CENTERS + (i * m_numAttributes) + j] = centers.instance(i).value(k);
                k++;
            }
        }
    }

    if (m_useAttributeWeights) {
        for (int j = 0; j < m_numAttributes; j++) {
            if (j != data.classIndex()) {
                m_RBFParameters[OFFSET_ATTRIBUTE_WEIGHTS + j] = 1.0;
            }
        }
    }

    initializeOutputLayer(random);

    return data;
}

From source file:br.fapesp.myutils.MyUtils.java

License:Open Source License

/**
 * Generates a Gaussian data set with K clusters and m dimensions
 *
 * @param centers
 *            K x m matrix
 * @param sigmas
 *            K x m matrix
 * @param pointsPerCluster
 *            number of points per cluster
 * @param seed
 *            for the RNG
 * @param randomize
 *            should the order of the instances be randomized?
 * @param supervised
 *            should class label be present? if true, the class is the m+1
 *            attribute
 * 
 * @return
 */
public static Instances genGaussianDataset(double[][] centers, double[][] sigmas, int pointsPerCluster,
        long seed, boolean randomize, boolean supervised) {
    Random r = new Random(seed);

    int K = centers.length; // number of clusters
    int m = centers[0].length; // number of dimensions

    FastVector atts = new FastVector(m);
    for (int i = 0; i < m; i++)
        atts.addElement(new Attribute("at" + i));

    if (supervised) {
        FastVector cls = new FastVector(K);
        for (int i = 0; i < K; i++)
            cls.addElement("Gauss-" + i);
        atts.addElement(new Attribute("Class", cls));
    }

    Instances data;
    if (supervised)
        data = new Instances(K + "-Gaussians-supervised", atts, K * pointsPerCluster);
    else
        data = new Instances(K + "-Gaussians", atts, K * pointsPerCluster);

    if (supervised)
        data.setClassIndex(m);

    Instance ith;

    for (int i = 0; i < K; i++) {
        for (int j = 0; j < pointsPerCluster; j++) {
            if (!supervised)
                ith = new DenseInstance(m);
            else
                ith = new DenseInstance(m + 1);
            ith.setDataset(data);
            for (int k = 0; k < m; k++)
                ith.setValue(k, centers[i][k] + (r.nextGaussian() * sigmas[i][k]));
            if (supervised)
                ith.setValue(m, "Gauss-" + i);
            data.add(ith);
        }
    }

    // run randomization filter if desired
    if (randomize)
        data.randomize(r);

    return data;
}
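
A hypothetical invocation of this helper (the class name MyUtils comes from the source file above; centers, sigmas, and seed are illustrative values only):

double[][] centers = { { 0, 0 }, { 5, 5 } };
double[][] sigmas = { { 1, 1 }, { 1, 1 } };
// 100 points per cluster, seed 1234, shuffled order, with class labels
Instances twoGaussians = MyUtils.genGaussianDataset(centers, sigmas, 100, 1234L, true, true);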

From source file:br.ufrn.ia.core.clustering.EMIaProject.java

License:Open Source License

private void CVClusters() throws Exception {
    double CVLogLikely = -Double.MAX_VALUE;
    double templl, tll;
    boolean CVincreased = true;
    m_num_clusters = 1;
    int num_clusters = m_num_clusters;
    int i;
    Random cvr;
    Instances trainCopy;
    int numFolds = (m_theInstances.numInstances() < 10) ? m_theInstances.numInstances() : 10;

    boolean ok = true;
    int seed = getSeed();
    int restartCount = 0;
    CLUSTER_SEARCH: while (CVincreased) {
        // theInstances.stratify(10);

        CVincreased = false;
        cvr = new Random(getSeed());
        trainCopy = new Instances(m_theInstances);
        trainCopy.randomize(cvr);
        templl = 0.0;
        for (i = 0; i < numFolds; i++) {
            Instances cvTrain = trainCopy.trainCV(numFolds, i, cvr);
            if (num_clusters > cvTrain.numInstances()) {
                break CLUSTER_SEARCH;
            }
            Instances cvTest = trainCopy.testCV(numFolds, i);
            m_rr = new Random(seed);
            for (int z = 0; z < 10; z++)
                m_rr.nextDouble();
            m_num_clusters = num_clusters;
            EM_Init(cvTrain);
            try {
                iterate(cvTrain, false);
            } catch (Exception ex) {
                // catch any problems - i.e. empty clusters occurring
                ex.printStackTrace();
                // System.err.println("Restarting after CV training failure
                // ("+num_clusters+" clusters");
                seed++;
                restartCount++;
                ok = false;
                if (restartCount > 5) {
                    break CLUSTER_SEARCH;
                }
                break;
            }
            try {
                tll = E(cvTest, false);
            } catch (Exception ex) {
                // catch any problems - i.e. empty clusters occurring
                // ex.printStackTrace();
                ex.printStackTrace();
                // System.err.println("Restarting after CV testing failure
                // ("+num_clusters+" clusters");
                // throw new Exception(ex);
                seed++;
                restartCount++;
                ok = false;
                if (restartCount > 5) {
                    break CLUSTER_SEARCH;
                }
                break;
            }

            if (m_verbose) {
                System.out.println("# clust: " + num_clusters + " Fold: " + i + " Loglikely: " + tll);
            }
            templl += tll;
        }

        if (ok) {
            restartCount = 0;
            seed = getSeed();
            templl /= (double) numFolds;

            if (m_verbose) {
                System.out.println("===================================" + "==============\n# clust: "
                        + num_clusters + " Mean Loglikely: " + templl + "\n================================"
                        + "=================");
            }

            if (templl > CVLogLikely) {
                CVLogLikely = templl;
                CVincreased = true;
                num_clusters++;
            }
        }
    }

    if (m_verbose) {
        System.out.println("Number of clusters: " + (num_clusters - 1));
    }

    m_num_clusters = num_clusters - 1;
}