Example usage for weka.classifiers.evaluation ThresholdCurve getROCArea

List of usage examples for weka.classifiers.evaluation ThresholdCurve getROCArea

Introduction

On this page you can find example usages of weka.classifiers.evaluation.ThresholdCurve.getROCArea.

Prototype

public static double getROCArea(Instances tcurve) 

Source Link

Document

Calculates the area under the ROC curve as the Wilcoxon-Mann-Whitney statistic.

Usage

From source file:adams.gui.menu.ROC.java

License:Open Source License

/**
 * Runs this menu item: loads a dataset through the file chooser's loader and
 * shows it as a threshold-curve plot in a new child frame.
 * <p>
 * The file comes from the first parameter when one was supplied; otherwise the
 * user is asked to pick one, and cancelling the dialog ends the method quietly.
 * Failures while loading the data or adding the plot are reported in an error
 * dialog and end the method.
 */
@Override
public void launch() {
    final File file;
    if (m_Parameters.length > 0) {
        file = new PlaceholderFile(m_Parameters[0]).getAbsoluteFile();
        m_FileChooser.setSelectedFile(file);
    } else {
        // no parameter given: let the user choose
        if (m_FileChooser.showOpenDialog(null) != JFileChooser.APPROVE_OPTION) {
            return;
        }
        file = m_FileChooser.getSelectedFile();
    }

    // load the curve data
    Instances curve;
    try {
        curve = m_FileChooser.getLoader().getDataSet();
    } catch (Exception e) {
        GUIHelper.showErrorMessage(getOwner(),
                "Error loading file '" + file + "':\n" + adams.core.Utils.throwableToString(e));
        return;
    }
    curve.setClassIndex(curve.numAttributes() - 1);

    // set up the panel and its caption
    ThresholdVisualizePanel panel = new ThresholdVisualizePanel();
    double area = ThresholdCurve.getROCArea(curve);
    panel.setROCString("(Area under ROC = " + Utils.doubleToString(area, 4) + ")");
    panel.setName(curve.relationName());

    PlotData2D plot = new PlotData2D(curve);
    plot.setPlotName(curve.relationName());
    plot.addInstanceNumberAttribute();

    // every point except the first is connected to its predecessor
    boolean[] connected = new boolean[curve.numInstances()];
    for (int i = 0; i < connected.length; i++) {
        connected[i] = i > 0;
    }

    try {
        plot.setConnectPoints(connected);
        panel.addPlot(plot);
    } catch (Exception e) {
        GUIHelper.showErrorMessage(getOwner(), "Error adding plot:\n" + adams.core.Utils.throwableToString(e));
        return;
    }

    ChildFrame frame = createChildFrame(panel, GUIHelper.getDefaultDialogDimension());
    frame.setTitle(frame.getTitle() + " - " + file);
}

From source file:bme.mace.logicdomain.Evaluation.java

License:Open Source License

/**
 * Returns the area under the ROC curve for the collected predictions and, as a
 * side effect, appends a sampled copy of the curve to {@code C://1.csv}.
 * <p>
 * An area below 0.5 is reported as {@code 1 - area}. The CSV dump consists of
 * two lines (false-positive rates, then true-positive rates), each holding 256
 * values sampled evenly from the last curve point down to the first. The dump
 * is best effort: an I/O failure is printed and does not affect the returned
 * value, and nothing is written when the curve has no points.
 *
 * @param classIndex the index of the class to consider as "positive"
 * @return the area under the ROC curve, or {@code Instance.missingValue()} if
 *         no predictions have been collected
 */
public double areaUnderROC(int classIndex) {

    // Check if any predictions have been collected
    if (m_Predictions == null) {
        return Instance.missingValue();
    }

    ThresholdCurve tc = new ThresholdCurve();
    Instances result = tc.getCurve(m_Predictions, classIndex);
    double rocArea = ThresholdCurve.getROCArea(result);
    if (rocArea < 0.5) {
        rocArea = 1 - rocArea;
    }

    int tpIndex = result.attribute(ThresholdCurve.TP_RATE_NAME).index();
    int fpIndex = result.attribute(ThresholdCurve.FP_RATE_NAME).index();
    double[] tpRate = result.attributeToDoubleArray(tpIndex);
    double[] fpRate = result.attributeToDoubleArray(fpIndex);

    // An empty curve has nothing to sample; the index arithmetic would go negative.
    if (fpRate.length > 0) {
        // try-with-resources closes the file even when a write fails part-way
        try (BufferedWriter bw = new BufferedWriter(new FileWriter("C://1.csv", true))) {
            writeSampledRow(bw, fpRate);
            writeSampledRow(bw, tpRate);
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    return rocArea;
}

/**
 * Writes 256 evenly spaced samples of {@code values}, last element first, as
 * one comma-terminated CSV line.
 *
 * @param out    the writer to append to
 * @param values the non-empty series to sample
 * @throws IOException if writing fails
 */
private static void writeSampledRow(BufferedWriter out, double[] values) throws IOException {
    int last = values.length - 1;
    for (int i = 255; i >= 0; i--) {
        out.write(values[i * last / 255] + ",");
    }
    out.write("\n");
}

From source file:com.evaluation.ConfidenceLabelBasedMeasures.java

License:Open Source License

/**
 * Computes the per-label AUC values together with their macro and micro
 * averages and stores them in {@code labelAUC} and {@code auc}.
 * <p>
 * Every (instance, label) pair becomes a two-class prediction whose class-1
 * probability is the label's confidence. The macro average is the mean of the
 * per-label areas; the micro average is the area of the curve built from all
 * predictions pooled together.
 *
 * @param output     the multi-label outputs, one per instance
 * @param trueLabels the ground truth, indexed by instance and then by label
 */
private void computeMeasures(MultiLabelOutput[] output, boolean[][] trueLabels) {
    final int labelCount = trueLabels[0].length;

    // one prediction list per label, plus one list pooling every label
    FastVector[] perLabel = new FastVector[labelCount];
    for (int l = 0; l < labelCount; l++) {
        perLabel[l] = new FastVector();
    }
    FastVector pooled = new FastVector();

    final int instanceCount = output.length;
    for (int i = 0; i < instanceCount; i++) {
        double[] confidences = output[i].getConfidences();
        for (int l = 0; l < labelCount; l++) {
            int actualClass = trueLabels[i][l] ? 1 : 0;

            double positive = confidences[l];
            double[] dist = { 1 - positive, positive };

            perLabel[l].addElement(new NominalPrediction(actualClass, dist, 1));
            pooled.addElement(new NominalPrediction(actualClass, dist, 1));
        }
    }

    // macro: average of the individual label areas
    labelAUC = new double[labelCount];
    for (int l = 0; l < labelCount; l++) {
        Instances labelCurve = new ThresholdCurve().getCurve(perLabel[l], 1);
        labelAUC[l] = ThresholdCurve.getROCArea(labelCurve);
    }
    auc[Averaging.MACRO.ordinal()] = Utils.mean(labelAUC);

    // micro: a single curve over the pooled predictions
    Instances pooledCurve = new ThresholdCurve().getCurve(pooled, 1);
    auc[Averaging.MICRO.ordinal()] = ThresholdCurve.getROCArea(pooledCurve);
}

From source file:com.sliit.views.DataVisualizerPanel.java

/**
 * Cross-validates a NaiveBayes classifier on the dataset file named in
 * {@code datasetPathText} and shows the resulting ROC curve in a new window.
 * <p>
 * The last attribute is used as the class, the evaluation is a 10-fold
 * cross-validation seeded with 1, and the curve is generated for class index 0.
 * Any failure is logged at SEVERE level; nothing is thrown to the caller.
 */
void getRocCurve() {
    try {
        Instances data;
        // close the file as soon as the dataset has been read
        try (BufferedReader reader = new BufferedReader(new FileReader(datasetPathText.getText()))) {
            data = new Instances(reader);
        }
        data.setClassIndex(data.numAttributes() - 1);

        // train classifier
        Classifier cl = new NaiveBayes();
        Evaluation eval = new Evaluation(data);
        eval.crossValidateModel(cl, data, 10, new Random(1));

        // generate curve
        ThresholdCurve tc = new ThresholdCurve();
        int classIndex = 0;
        Instances result = tc.getCurve(eval.predictions(), classIndex);

        // plot curve (getROCArea is static, so call it on the class)
        ThresholdVisualizePanel vmc = new ThresholdVisualizePanel();
        vmc.setROCString("(Area under ROC = " + Utils.doubleToString(ThresholdCurve.getROCArea(result), 4) + ")");
        vmc.setName(result.relationName());
        PlotData2D tempd = new PlotData2D(result);
        tempd.setPlotName(result.relationName());
        tempd.addInstanceNumberAttribute();
        // specify which points are connected: all but the first
        boolean[] cp = new boolean[result.numInstances()];
        for (int n = 1; n < cp.length; n++) {
            cp[n] = true;
        }
        tempd.setConnectPoints(cp);
        // add plot
        vmc.addPlot(tempd);

        // display curve
        String plotName = vmc.getName();
        final javax.swing.JFrame jf = new javax.swing.JFrame("Weka Classifier Visualize: " + plotName);
        jf.setSize(500, 400);
        jf.getContentPane().setLayout(new BorderLayout());
        jf.getContentPane().add(vmc, BorderLayout.CENTER);
        jf.addWindowListener(new java.awt.event.WindowAdapter() {
            @Override
            public void windowClosing(java.awt.event.WindowEvent e) {
                jf.dispose();
            }
        });
        jf.setVisible(true);
    } catch (Exception ex) {
        // I/O and evaluation failures were handled identically, so one handler suffices
        Logger.getLogger(DataVisualizerPanel.class.getName()).log(Level.SEVERE, null, ex);
    }
}

From source file:com.sliit.views.KNNView.java

/**
 * Cross-validates a NaiveBayes classifier on the dataset file named in
 * {@code PredictorPanel.modalText} and shows the resulting ROC curve inside
 * {@code rocPanel}, replacing whatever that panel contained before.
 * <p>
 * The last attribute is used as the class, the evaluation is a 10-fold
 * cross-validation seeded with 1, and the curve is generated for class index 0.
 * Any failure is logged at SEVERE level; nothing is thrown to the caller.
 */
void getRocCurve() {
    try {
        Instances data;
        // close the file as soon as the dataset has been read
        try (BufferedReader reader = new BufferedReader(
                new java.io.FileReader(PredictorPanel.modalText.getText()))) {
            data = new Instances(reader);
        }
        data.setClassIndex(data.numAttributes() - 1);

        // train classifier
        Classifier cl = new NaiveBayes();
        Evaluation eval = new Evaluation(data);
        eval.crossValidateModel(cl, data, 10, new Random(1));

        // generate curve
        ThresholdCurve tc = new ThresholdCurve();
        int classIndex = 0;
        Instances result = tc.getCurve(eval.predictions(), classIndex);

        // plot curve (getROCArea is static, so call it on the class)
        ThresholdVisualizePanel vmc = new ThresholdVisualizePanel();
        vmc.setROCString("(Area under ROC = " + Utils.doubleToString(ThresholdCurve.getROCArea(result), 4) + ")");
        vmc.setName(result.relationName());
        PlotData2D tempd = new PlotData2D(result);
        tempd.setPlotName(result.relationName());
        tempd.addInstanceNumberAttribute();
        // specify which points are connected: all but the first
        boolean[] cp = new boolean[result.numInstances()];
        for (int n = 1; n < cp.length; n++) {
            cp[n] = true;
        }
        tempd.setConnectPoints(cp);
        // add plot
        vmc.addPlot(tempd);

        rocPanel.removeAll();
        rocPanel.add(vmc, "vmc", 0);
        rocPanel.revalidate();
        // repaint so no remnants of the removed components stay on screen
        rocPanel.repaint();

    } catch (Exception ex) {
        // log under this view's own class rather than DataVisualizerPanel
        Logger.getLogger(getClass().getName()).log(Level.SEVERE, null, ex);
    }
}

From source file:com.sliit.views.SVMView.java

/**
 * Builds the ROC curve plot for the dataset file named in
 * {@code PredictorPanel.modalText}.
 * <p>
 * A NaiveBayes classifier is evaluated with a 10-fold cross-validation seeded
 * with 1, using the last attribute as the class, and the curve is generated
 * for class index 0. Any failure is logged at SEVERE level; nothing is thrown
 * to the caller.
 * <p>
 * NOTE(review): the finished panel is not attached to any container here, so
 * nothing becomes visible; the code that did so is commented out below.
 */
void getRocCurve() {
    try {
        Instances data;
        // close the file as soon as the dataset has been read
        try (BufferedReader reader = new BufferedReader(new FileReader(PredictorPanel.modalText.getText()))) {
            data = new Instances(reader);
        }
        data.setClassIndex(data.numAttributes() - 1);

        //train classifier
        Classifier cl = new NaiveBayes();
        Evaluation eval = new Evaluation(data);
        eval.crossValidateModel(cl, data, 10, new Random(1));

        // generate curve
        ThresholdCurve tc = new ThresholdCurve();
        int classIndex = 0;
        Instances result = tc.getCurve(eval.predictions(), classIndex);

        // plot curve (getROCArea is static, so call it on the class)
        ThresholdVisualizePanel vmc = new ThresholdVisualizePanel();
        vmc.setROCString("(Area under ROC = " + Utils.doubleToString(ThresholdCurve.getROCArea(result), 4) + ")");
        vmc.setName(result.relationName());
        PlotData2D tempd = new PlotData2D(result);
        tempd.setPlotName(result.relationName());
        tempd.addInstanceNumberAttribute();
        // specify which points are connected: all but the first
        boolean[] cp = new boolean[result.numInstances()];
        for (int n = 1; n < cp.length; n++) {
            cp[n] = true;
        }
        tempd.setConnectPoints(cp);
        // add plot
        vmc.addPlot(tempd);

        // Disabled display code, kept for reference:
        //            rocPanel.removeAll();
        //            rocPanel.add(vmc, "vmc", 0);
        //            rocPanel.revalidate();
    } catch (Exception ex) {
        // log under this view's own class rather than DataVisualizerPanel
        Logger.getLogger(getClass().getName()).log(Level.SEVERE, null, ex);
    }
}

From source file:cotraining.copy.Evaluation_D.java

License:Open Source License

/**
 * Computes the area under the ROC curve from the collected predictions,
 * treating the given class as the positive one.
 *
 * @param classIndex the index of the class to consider as "positive"
 * @return the area under the ROC curve, or {@code Instance.missingValue()}
 *         when no predictions have been collected
 */
public double areaUnderROC(int classIndex) {
    if (m_Predictions != null) {
        Instances curve = new ThresholdCurve().getCurve(m_Predictions, classIndex);
        return ThresholdCurve.getROCArea(curve);
    }
    // nothing collected yet
    return Instance.missingValue();
}

From source file:cs.man.ac.uk.classifiers.GetAUC.java

License:Open Source License

/**
 * Computes the AUC for the supplied learner.
 * @return the AUC as a double value./* ww w  . ja v  a  2s  .c o  m*/
 */
@SuppressWarnings("unused")
private static double validate5x2CV() {
    try {
        // other options
        int runs = 5;
        int folds = 2;
        double AUC_SUM = 0;

        // perform cross-validation
        for (int i = 0; i < runs; i++) {
            // randomize data
            int seed = i + 1;
            Random rand = new Random(seed);
            Instances randData = new Instances(data);
            randData.randomize(rand);

            if (randData.classAttribute().isNominal()) {
                System.out.println("Stratifying...");
                randData.stratify(folds);
            }

            Evaluation eval = new Evaluation(randData);

            for (int n = 0; n < folds; n++) {
                Instances train = randData.trainCV(folds, n);
                Instances test = randData.testCV(folds, n);

                // the above code is used by the StratifiedRemoveFolds filter, the
                // code below by the Explorer/Experimenter:
                // Instances train = randData.trainCV(folds, n, rand);

                // build and evaluate classifier
                String[] options = { "-U", "-A" };
                J48 classifier = new J48();
                //HTree classifier = new HTree();

                classifier.setOptions(options);
                classifier.buildClassifier(train);
                eval.evaluateModel(classifier, test);

                // generate curve
                ThresholdCurve tc = new ThresholdCurve();
                int classIndex = 0;
                Instances result = tc.getCurve(eval.predictions(), classIndex);

                // plot curve
                vmc = new ThresholdVisualizePanel();
                AUC_SUM += ThresholdCurve.getROCArea(result);
                System.out.println("AUC: " + ThresholdCurve.getROCArea(result) + " \tAUC SUM: " + AUC_SUM);
            }
        }

        return AUC_SUM / ((double) runs * (double) folds);
    } catch (Exception e) {
        System.out.println("Exception validating data!");
        return 0;
    }
}

From source file:cs.man.ac.uk.classifiers.GetAUC.java

License:Open Source License

/**
 * Computes the AUC for the supplied learner.
 * @param learner the learning algorithm to use.
 * @return the AUC as a double value./*from  w  w  w .  ja v  a 2s . c o m*/
 */
@SuppressWarnings("unused")
private static double validate(Classifier learner) {
    try {

        Evaluation eval = new Evaluation(data);
        eval.crossValidateModel(learner, data, 2, new Random(1));

        // generate curve
        ThresholdCurve tc = new ThresholdCurve();
        int classIndex = 0;
        Instances result = tc.getCurve(eval.predictions(), classIndex);

        // plot curve
        vmc = new ThresholdVisualizePanel();
        double AUC = ThresholdCurve.getROCArea(result);
        vmc.setROCString(
                "(Area under ROC = " + Utils.doubleToString(ThresholdCurve.getROCArea(result), 9) + ")");
        vmc.setName(result.relationName());

        PlotData2D tempd = new PlotData2D(result);
        tempd.setPlotName(result.relationName());
        tempd.addInstanceNumberAttribute();

        // specify which points are connected
        boolean[] cp = new boolean[result.numInstances()];
        for (int n = 1; n < cp.length; n++)
            cp[n] = true;

        tempd.setConnectPoints(cp);
        // add plot
        vmc.addPlot(tempd);

        return AUC;
    } catch (Exception e) {
        System.out.println("Exception validating data!");
        return 0;
    }
}

From source file:meka.core.Metrics.java

License:Open Source License

/**
 * Calculates the macro-averaged AUROC (area under the ROC curve): the mean of
 * the per-label areas.
 * <p>
 * Labels whose column of true values is reported as all missing by
 * {@code allMissing} are left out of the average entirely. They neither end
 * the loop early nor count as an area of 0.
 *
 * @param Y the true label values, one row per instance and one column per label
 * @param P the predicted confidences, laid out like {@code Y}
 * @return the mean area over the evaluated labels, or 0 if every label was skipped
 */
public static double P_macroAUROC(int Y[][], double P[][]) {
    // works with missing
    int L = Y[0].length;
    double sum = 0.0;
    int evaluated = 0;
    for (int j = 0; j < L; j++) {
        // test the label's column; Y[j] would be the j-th instance's row
        int[] yj = MatrixUtils.getCol(Y, j);
        if (allMissing(yj)) {
            continue;
        }
        ThresholdCurve curve = new ThresholdCurve();
        Instances result = curve.getCurve(MLUtils.toWekaPredictions(yj, MatrixUtils.getCol(P, j)));
        sum += ThresholdCurve.getROCArea(result);
        evaluated++;
    }
    // 0 matches what the previous implementation returned when nothing was evaluated
    return evaluated == 0 ? 0.0 : sum / evaluated;
}