Example usage for weka.classifiers Evaluation crossValidateModel

List of usage examples for weka.classifiers Evaluation crossValidateModel

Introduction

On this page you can find example usages of weka.classifiers.Evaluation.crossValidateModel.

Prototype

public void crossValidateModel(Classifier classifier, Instances data, int numFolds, Random random)
        throws Exception 

Source Link

Document

Performs a (stratified if class is nominal) cross-validation for a classifier on a set of instances.

Usage

From source file:binarizer.LayoutAnalysis.java

/**
 * Cross-validates a NaiveBayes classifier on the data set found at the given location.
 *
 * The data is loaded through a DataSource; when it has no class index yet
 * (classIndex() == -1) the last attribute is used as the class. A 10-fold
 * cross-validation seeded with Random(1) is then run with supervised
 * discretization switched on, and the evaluation summary is printed to
 * standard output.
 *
 * @param arffFile location of the data set handed to DataSource
 * @return the error rate reported by the evaluation
 * @throws Exception if loading the data or running the cross-validation fails
 */
public double crossValidation(String arffFile) throws Exception {
    Instances dataset = new DataSource(arffFile).getDataSet();
    boolean classUnset = dataset.classIndex() == -1;
    if (classUnset) {
        dataset.setClassIndex(dataset.numAttributes() - 1);
    }

    NaiveBayes bayes = new NaiveBayes();
    bayes.setUseSupervisedDiscretization(true);

    Evaluation result = new Evaluation(dataset);
    result.crossValidateModel(bayes, dataset, 10, new Random(1));
    System.out.println(result.toSummaryString());
    return result.errorRate();
}

From source file:c4.pkg5crossv.Classifier.java

/**
 * Loads ./src/data/irysy.arff, shuffles it, splits it 60% / 40%, builds a J48
 * tree configured with the options "-U -M 10" on the first part and prints a
 * cross-validation summary computed over the second part.
 *
 * @throws FileNotFoundException declared by the original signature
 * @throws IOException declared by the original signature
 * @throws Exception if loading, training or evaluating fails
 */
public static void trainAndTest() throws FileNotFoundException, IOException, Exception {
    Instances dataset = DataLoad.loadData("./src/data/irysy.arff");
    dataset.setClassIndex(dataset.numAttributes() - 1);

    // Random shuffle (unseeded, so the split differs between runs), then a 60/40 cut.
    dataset.randomize(new Random());
    double trainPercent = 60.0;
    int trainCount = (int) Math.round(dataset.numInstances() * trainPercent / 100);
    int testCount = dataset.numInstances() - trainCount;
    Instances trainPart = new Instances(dataset, 0, trainCount);
    Instances testPart = new Instances(dataset, trainCount, testCount);

    J48 tree = new J48();
    tree.setOptions(Utils.splitOptions("-U -M 10"));
    tree.buildClassifier(trainPart);

    // NOTE(review): the Evaluation is initialised from the training part, yet the
    // 10-fold cross-validation below runs over the test part only — confirm that
    // this is intended rather than an evaluateModel(tree, testPart) call.
    Evaluation evaluation = new Evaluation(trainPart);
    evaluation.crossValidateModel(tree, testPart, 10, new Random(1));

    // Print the cross-validation results.
    String summary = evaluation.toSummaryString("Wyniki:", false);
    System.out.println(summary);
}

From source file:ca.uottawa.balie.WekaLearner.java

License:Open Source License

/**
 * Estimates the error of the scheme by a 10-fold cross-validation over the
 * training set, using an unseeded random number generator for the folds.
 *
 * Any exception is caught and only its message is printed to standard output.
 * NOTE(review): when the cross-validation itself fails, the Evaluation object
 * created just before is still returned — callers presumably cannot tell it
 * apart from a completed one; confirm whether that is acceptable.
 *
 * @return evaluation module from which many types of errors are exposed
 *         (e.g.: mean absolute error); null when the Evaluation could not
 *         even be created
 */
public Evaluation EstimateConfidence() {
    final int folds = 10;
    Evaluation result = null;
    try {
        result = new Evaluation(m_TrainingSet);
        result.crossValidateModel(m_Scheme, m_TrainingSet, folds, new Random());
    } catch (Exception failure) {
        System.out.println(failure.getMessage());
    }
    // Which error measure is the best one depends on the application.
    return result;
}

From source file:ca.uqac.florentinth.speakerauthentication.Learning.Learning.java

License:Apache License

/**
 * Trains the selected classifier on the supplied data set, optionally
 * cross-validates it, and serializes the trained model.
 *
 * When cross-validation is requested, the kappa statistic, the weighted
 * F-measure and the confusion matrix text are stored in the corresponding
 * fields of this object.
 *
 * @param classifier which classifier to instantiate; KNN and NB are handled
 * @param trainingDataset reader supplying the training instances
 * @param trainingModel stream the trained classifier is serialized to; it is
 *        closed once the write stage has been attempted, whether or not it succeeded
 * @param crossValidationFoldNumber number of folds; cross-validation is skipped
 *        unless this is greater than zero (must not be null, it is unboxed)
 * @throws Exception if reading the data, training, evaluating or writing fails
 */
public void trainClassifier(Classifier classifier, FileReader trainingDataset, FileOutputStream trainingModel,
        Integer crossValidationFoldNumber) throws Exception {
    Instances instances = new Instances(new BufferedReader(trainingDataset));

    switch (classifier) {
    case KNN:
        // Neighbourhood size: square root of the number of instances, rounded up.
        int K = (int) Math.ceil(Math.sqrt(instances.numInstances()));
        this.classifier = new IBk(K);
        break;
    case NB:
        this.classifier = new NaiveBayes();
        break;
    default:
        // NOTE(review): any other value leaves this.classifier as it was — confirm
        // whether an unsupported value should be rejected instead.
        break;
    }

    // Fall back to the last attribute as the class when none is set.
    if (instances.classIndex() == -1) {
        instances.setClassIndex(instances.numAttributes() - 1);
    }

    this.classifier.buildClassifier(instances);

    if (crossValidationFoldNumber > 0) {
        Evaluation evaluation = new Evaluation(instances);
        evaluation.crossValidateModel(this.classifier, instances, crossValidationFoldNumber, new Random(1));
        kappa = evaluation.kappa();
        fMeasure = evaluation.weightedFMeasure();
        confusionMatrix = evaluation.toMatrixString("Confusion matrix: ");
    }

    // Always release the output stream, even when serialization fails part-way.
    ObjectOutputStream outputStream = null;
    try {
        outputStream = new ObjectOutputStream(trainingModel);
        outputStream.writeObject(this.classifier);
        outputStream.flush();
    } finally {
        if (outputStream != null) {
            outputStream.close();
        } else {
            // The wrapper could not be created, so close the underlying stream directly.
            trainingModel.close();
        }
    }
}

From source file:classif.ExperimentsLauncher.java

License:Open Source License

/**
 * Searches for a per-class prototype count with a low test error.
 *
 * For every starting count from 1 to 10, all classes begin with that count.
 * The count of one class at a time is then increased for as long as the
 * 10-fold cross-validation error on the training data keeps dropping; once it
 * stops dropping, the last increase is undone and the search moves on to the
 * next class. The resulting configuration is scored by its average error on
 * the test data over nbExp runs, and the configuration with the lowest
 * average is reported on standard output.
 *
 * Exceptions are caught and their stack trace printed.
 */
public void launchseq() {
    try {
        nbPrototypesMax = 10;
        int[] bestCounts = new int[train.numClasses()];
        double bestError = 1.0;
        for (int startCount = 1; startCount <= nbPrototypesMax; startCount++) {
            // Every class starts with the same number of prototypes.
            int[] counts = new int[train.numClasses()];
            Arrays.fill(counts, startCount);

            double previousCvError = 1;
            int classIdx = 0;
            do {
                Unbalancecluster candidate = new Unbalancecluster();
                candidate.setNbPrototypesPerClass(counts);
                System.out.println(Arrays.toString(counts));

                Evaluation cvEval = new Evaluation(train);
                cvEval.crossValidateModel(candidate, train, 10, new Random(1));
                double currentCvError = cvEval.errorRate();
                System.out.println("errorBefore " + previousCvError);
                System.out.println("errorNow " + currentCvError);

                if (currentCvError < previousCvError) {
                    // Improvement: keep it and try one more prototype for this class.
                    counts[classIdx]++;
                    previousCvError = currentCvError;
                } else {
                    // No improvement: take back the last increase and move to the next class.
                    counts[classIdx]--;
                    classIdx++;
                    if (classIdx >= counts.length) {
                        break;
                    }
                    counts[classIdx]++;
                }
            } while (classIdx < counts.length);

            // Score the configuration on the test data, averaged over nbExp runs.
            double testErrorSum = 0;
            for (int run = 0; run < nbExp; run++) {
                Unbalancecluster model = new Unbalancecluster();
                model.setNbPrototypesPerClass(counts);
                model.buildClassifier(train);
                Evaluation testEval = new Evaluation(train);
                testEval.evaluateModel(model, test);
                testErrorSum += testEval.errorRate();
            }
            double avgTestError = testErrorSum / nbExp;
            System.out.println(avgTestError);
            if (avgTestError < bestError) {
                bestCounts = counts;
                bestError = avgTestError;
            }
        }
        System.out.println("Best prototypes:" + Arrays.toString(bestCounts) + "\n");
        System.out.println("Best errorRate:" + bestError + "\n");
    } catch (Exception e) {
        e.printStackTrace();
    }
}

From source file:classify.Classifier.java

/**
 * Loads no_missing_values.csv, then trains and 10-fold cross-validates two
 * classifiers on it — AdaBoostM1 over decision stumps (25 iterations) and
 * logistic regression — printing each model followed by its cross-validation
 * statistics and confusion matrix.
 *
 * Any exception ends the program: a fixed notice is printed to standard
 * output and the cause is reported on standard error.
 *
 * @param args the command line arguments (not used)
 */
public static void main(String[] args) {
    try {
        // read in data
        DataSource input = new DataSource("no_missing_values.csv");
        Instances data = input.getDataSet();
        missingValuesRows(data);

        setAttributeValues(data);
        data.setClassIndex(data.numAttributes() - 1);

        DecimalFormat nf = new DecimalFormat("0.000");

        // boosting
        AdaBoostM1 boosting = new AdaBoostM1();
        boosting.setNumIterations(25);
        boosting.setClassifier(new DecisionStump());

        // build the classifier on the full data so the model can be printed below
        boosting.buildClassifier(data);

        // evaluate using 10-fold cross validation
        Evaluation boostingEval = new Evaluation(data);
        boostingEval.crossValidateModel(boosting, data, 10, new Random(1));

        System.out.println("Results of Boosting with Decision Stumps:");
        System.out.println(boosting.toString());
        printCrossValidationResults(boostingEval, nf);
        System.out.println();
        System.out.println();
        System.out.println();

        // logistic regression
        Logistic l = new Logistic();
        l.buildClassifier(data);

        Evaluation logisticEval = new Evaluation(data);
        logisticEval.crossValidateModel(l, data, 10, new Random(1));

        System.out.println("Results of Logistic Regression:");
        System.out.println(l.toString());
        printCrossValidationResults(logisticEval, nf);
    } catch (Exception ex) {
        // Any failure (loading, training or evaluating) ends the program; report
        // the actual cause instead of discarding it.
        System.out.println("Exception thrown, program ending.");
        System.err.println("Cause: " + ex);
    }
}

/**
 * Prints the cross-validation statistics of an evaluation to standard output:
 * correct/incorrect counts with percentages, weighted TP rate, FP rate,
 * precision and recall, and the confusion matrix row by row.
 *
 * @param evaluation the evaluation holding the cross-validation results
 * @param nf the number format applied to percentages
 */
private static void printCrossValidationResults(Evaluation evaluation, DecimalFormat nf) {
    System.out.println("Results of Cross Validation:");
    System.out.println("Number of correctly classified instances: " + evaluation.correct() + " ("
            + nf.format(evaluation.pctCorrect()) + "%)");
    System.out.println("Number of incorrectly classified instances: " + evaluation.incorrect() + " ("
            + nf.format(evaluation.pctIncorrect()) + "%)");

    System.out.println("TP Rate: " + nf.format(evaluation.weightedTruePositiveRate() * 100) + "%");
    System.out.println("FP Rate: " + nf.format(evaluation.weightedFalsePositiveRate() * 100) + "%");
    System.out.println("Precision: " + nf.format(evaluation.weightedPrecision() * 100) + "%");
    System.out.println("Recall: " + nf.format(evaluation.weightedRecall() * 100) + "%");

    System.out.println();
    System.out.println("Confusion Matrix:");
    // Fetch the matrix once instead of on every loop test.
    double[][] matrix = evaluation.confusionMatrix();
    for (double[] row : matrix) {
        for (double cell : row) {
            System.out.print(cell + "   ");
        }
        System.out.println();
    }
}

From source file:com.daniel.convert.IncrementalClassifier.java

License:Open Source License

/**
 * Builds a BayesNet on the ARFF file named by the first argument and prints
 * the summary of a 15-fold cross-validation of a second, fresh BayesNet on the
 * same data. The class attribute is assumed to be the last attribute.
 *
 * @param args
 *            the commandline arguments; args[0] is the ARFF file
 * @return the BayesNet built on the complete data set
 * @throws Exception
 *             if something goes wrong
 */
public static BayesNet treinar(String[] args) throws Exception {
    // Open the file and read its header only.
    ArffLoader arffLoader = new ArffLoader();
    arffLoader.setFile(new File(args[0]));
    Instances dataset = arffLoader.getStructure();
    dataset.setClassIndex(dataset.numAttributes() - 1);

    BayesNet trainedNet = new BayesNet();

    // Pull the instances one at a time and collect them in the data set.
    for (Instance row = arffLoader.getNextInstance(dataset); row != null; row = arffLoader
            .getNextInstance(dataset)) {
        dataset.add(row);
    }
    trainedNet.buildClassifier(dataset);

    // A separate, untrained network is the one handed to the cross-validation.
    BayesNet crossValidatedNet = new BayesNet();
    Evaluation crossValidation = new Evaluation(dataset);
    crossValidation.crossValidateModel(crossValidatedNet, dataset, 15, new Random(1));

    // Print the result in the style of the Weka explorer.
    System.out.println(crossValidation.toSummaryString());

    return trainedNet;
}

From source file:com.guidefreitas.locator.services.PredictionService.java

/**
 * Generates the training data, cross-validates a LibSVM classifier on it and
 * then builds that classifier on the complete data.
 *
 * The generated text is read as UTF-8 through a DataSource and the last
 * attribute is used as the class. The classifier is set to a polynomial
 * kernel and the C-SVC type, and is 10-fold cross-validated with Random(1)
 * before the final build.
 *
 * @return the cross-validation results, or null when any step fails (the
 *         failure is logged at SEVERE level)
 */
public Evaluation train() {
    try {
        byte[] arffBytes = this.generateTrainData().getBytes(StandardCharsets.UTF_8);
        InputStream arffStream = new ByteArrayInputStream(arffBytes);
        Instances dataset = new DataSource(arffStream).getDataSet();
        dataset.setClassIndex(dataset.numAttributes() - 1);

        this.classifier = new LibSVM();
        this.classifier.setKernelType(new SelectedTag(LibSVM.KERNELTYPE_POLYNOMIAL, LibSVM.TAGS_KERNELTYPE));
        this.classifier.setSVMType(new SelectedTag(LibSVM.SVMTYPE_C_SVC, LibSVM.TAGS_SVMTYPE));

        Evaluation crossValidation = new Evaluation(dataset);
        crossValidation.crossValidateModel(this.classifier, dataset, 10, new Random(1));

        // The model kept in this.classifier is built on all of the data afterwards.
        this.classifier.buildClassifier(dataset);
        return crossValidation;
    } catch (Exception ex) {
        Logger.getLogger(PredictionService.class.getName()).log(Level.SEVERE, null, ex);
        return null;
    }
}

From source file:com.ivanrf.smsspam.SpamClassifier.java

License:Apache License

/**
 * Runs a 10-fold cross-validation of the configured filtered classifier on
 * SMSSpamCollection.arff and publishes the results to the given log area.
 *
 * The class attribute is the first attribute (index 0). The summary, the
 * per-class details, the confusion matrix and the elapsed time are published
 * in that order. On failure the stack trace is printed and an error line is
 * published instead.
 *
 * @param wordsToKeep passed on to the model name and classifier set-up
 * @param tokenizerOp passed on to the model name and classifier set-up
 * @param useAttributeSelection passed on to the model name and classifier set-up
 * @param classifierOp passed on to the model name and classifier set-up
 * @param boosting passed on to the model name and classifier set-up
 * @param log text area that receives the progress and result messages
 */
public static void evaluate(int wordsToKeep, String tokenizerOp, boolean useAttributeSelection,
        String classifierOp, boolean boosting, JTextArea log) {
    try {
        final long startedAt = System.currentTimeMillis();

        showEstimatedTime(false,
                getModelName(wordsToKeep, tokenizerOp, useAttributeSelection, classifierOp, boosting), log);

        Instances dataset = loadDataset("SMSSpamCollection.arff", log);
        dataset.setClassIndex(0);
        FilteredClassifier filtered = initFilterClassifier(wordsToKeep, tokenizerOp, useAttributeSelection,
                classifierOp, boosting);

        publishEstado("=== Performing cross-validation ===", log);
        Evaluation crossValidation = new Evaluation(dataset);
        crossValidation.crossValidateModel(filtered, dataset, 10, new Random(1));

        publishEstado(crossValidation.toSummaryString(), log);
        publishEstado(crossValidation.toClassDetailsString(), log);
        publishEstado(crossValidation.toMatrixString(), log);
        publishEstado("=== Evaluation finished ===", log);

        long elapsedMillis = System.currentTimeMillis() - startedAt;
        publishEstado("Elapsed time: " + Utils.getDateHsMinSegString(elapsedMillis), log);
    } catch (Exception e) {
        e.printStackTrace();
        publishEstado("Error found when evaluating", log);
    }
}

From source file:com.sliit.views.DataVisualizerPanel.java

/**
 * Cross-validates a NaiveBayes classifier on the data set named in
 * datasetPathText and shows the resulting threshold curve in a new window.
 *
 * The last attribute is used as the class. The curve is generated from the
 * predictions of a 10-fold cross-validation (Random(1)) for class value index
 * 0, and the window title and plot carry the area under the ROC curve and the
 * relation name. Any failure is logged at SEVERE level.
 */
void getRocCurve() {
    try {
        // Load the data and release the file handle as soon as it has been read.
        Instances data;
        BufferedReader reader = new BufferedReader(new FileReader(datasetPathText.getText()));
        try {
            data = new Instances(reader);
        } finally {
            reader.close();
        }
        data.setClassIndex(data.numAttributes() - 1);

        // cross-validate the classifier to collect predictions
        Classifier cl = new NaiveBayes();
        Evaluation eval = new Evaluation(data);
        eval.crossValidateModel(cl, data, 10, new Random(1));

        // generate curve
        ThresholdCurve tc = new ThresholdCurve();
        int classIndex = 0;
        Instances result = tc.getCurve(eval.predictions(), classIndex);

        // plot curve
        ThresholdVisualizePanel vmc = new ThresholdVisualizePanel();
        vmc.setROCString("(Area under ROC = " + Utils.doubleToString(tc.getROCArea(result), 4) + ")");
        vmc.setName(result.relationName());
        PlotData2D tempd = new PlotData2D(result);
        tempd.setPlotName(result.relationName());
        tempd.addInstanceNumberAttribute();
        // specify which points are connected: every point except the first
        boolean[] cp = new boolean[result.numInstances()];
        for (int n = 1; n < cp.length; n++) {
            cp[n] = true;
        }
        tempd.setConnectPoints(cp);
        // add plot
        vmc.addPlot(tempd);

        // display curve
        String plotName = vmc.getName();
        final javax.swing.JFrame jf = new javax.swing.JFrame("Weka Classifier Visualize: " + plotName);
        jf.setSize(500, 400);
        jf.getContentPane().setLayout(new BorderLayout());
        jf.getContentPane().add(vmc, BorderLayout.CENTER);
        jf.addWindowListener(new java.awt.event.WindowAdapter() {
            @Override
            public void windowClosing(java.awt.event.WindowEvent e) {
                jf.dispose();
            }
        });
        jf.setVisible(true);
    } catch (Exception ex) {
        // I/O failures and all other failures are handled identically.
        Logger.getLogger(DataVisualizerPanel.class.getName()).log(Level.SEVERE, null, ex);
    }
}