List of usage examples for weka.classifiers.evaluation.ThresholdCurve.getROCArea
public static double getROCArea(Instances tcurve)
From source file:adams.gui.menu.ROC.java
License:Open Source License
/** * Launches the functionality of the menu item. *///from w w w.java2s. c om @Override public void launch() { File file; if (m_Parameters.length == 0) { // choose file int retVal = m_FileChooser.showOpenDialog(null); if (retVal != JFileChooser.APPROVE_OPTION) return; file = m_FileChooser.getSelectedFile(); } else { file = new PlaceholderFile(m_Parameters[0]).getAbsoluteFile(); m_FileChooser.setSelectedFile(file); } // create plot Instances result; try { result = m_FileChooser.getLoader().getDataSet(); } catch (Exception e) { GUIHelper.showErrorMessage(getOwner(), "Error loading file '" + file + "':\n" + adams.core.Utils.throwableToString(e)); return; } result.setClassIndex(result.numAttributes() - 1); ThresholdVisualizePanel vmc = new ThresholdVisualizePanel(); vmc.setROCString("(Area under ROC = " + Utils.doubleToString(ThresholdCurve.getROCArea(result), 4) + ")"); vmc.setName(result.relationName()); PlotData2D tempd = new PlotData2D(result); tempd.setPlotName(result.relationName()); tempd.addInstanceNumberAttribute(); // specify which points are connected boolean[] cp = new boolean[result.numInstances()]; for (int n = 1; n < cp.length; n++) cp[n] = true; try { tempd.setConnectPoints(cp); vmc.addPlot(tempd); } catch (Exception e) { GUIHelper.showErrorMessage(getOwner(), "Error adding plot:\n" + adams.core.Utils.throwableToString(e)); return; } ChildFrame frame = createChildFrame(vmc, GUIHelper.getDefaultDialogDimension()); frame.setTitle(frame.getTitle() + " - " + file); }
From source file:bme.mace.logicdomain.Evaluation.java
License:Open Source License
/** * Returns the area under ROC for those predictions that have been collected * in the evaluateClassifier(Classifier, Instances) method. Returns * Instance.missingValue() if the area is not available. * //from w w w .j av a 2 s .c o m * @param classIndex the index of the class to consider as "positive" * @return the area under the ROC curve or not a number */ public double areaUnderROC(int classIndex) { // Check if any predictions have been collected if (m_Predictions == null) { return Instance.missingValue(); } else { ThresholdCurve tc = new ThresholdCurve(); Instances result = tc.getCurve(m_Predictions, classIndex); double rocArea = ThresholdCurve.getROCArea(result); if (rocArea < 0.5) { rocArea = 1 - rocArea; } int tpIndex = result.attribute(ThresholdCurve.TP_RATE_NAME).index(); int fpIndex = result.attribute(ThresholdCurve.FP_RATE_NAME).index(); double[] tpRate = result.attributeToDoubleArray(tpIndex); double[] fpRate = result.attributeToDoubleArray(fpIndex); try { FileWriter fw; if (classIndex == 0) fw = new FileWriter("C://1.csv", true); else fw = new FileWriter("C://1.csv", true); BufferedWriter bw = new BufferedWriter(fw); int length = fpRate.length; for (int i = 255; i >= 0; i--) { int index = i * (length - 1) / 255; bw.write(fpRate[index] + ","); } bw.write("\n"); for (int i = 255; i >= 0; i--) { int index = i * (length - 1) / 255; bw.write(tpRate[index] + ","); } bw.write("\n"); bw.close(); fw.close(); } catch (IOException e) { // TODO Auto-generated catch block e.printStackTrace(); } return rocArea; } }
From source file:com.evaluation.ConfidenceLabelBasedMeasures.java
License:Open Source License
private void computeMeasures(MultiLabelOutput[] output, boolean[][] trueLabels) { int numLabels = trueLabels[0].length; // AUC/*w ww . j ava2 s.co m*/ FastVector[] m_Predictions = new FastVector[numLabels]; for (int j = 0; j < numLabels; j++) m_Predictions[j] = new FastVector(); FastVector all_Predictions = new FastVector(); int numInstances = output.length; for (int instanceIndex = 0; instanceIndex < numInstances; instanceIndex++) { double[] confidences = output[instanceIndex].getConfidences(); for (int labelIndex = 0; labelIndex < numLabels; labelIndex++) { int classValue; boolean actual = trueLabels[instanceIndex][labelIndex]; if (actual) classValue = 1; else classValue = 0; double[] dist = new double[2]; dist[1] = confidences[labelIndex]; dist[0] = 1 - dist[1]; m_Predictions[labelIndex].addElement(new NominalPrediction(classValue, dist, 1)); all_Predictions.addElement(new NominalPrediction(classValue, dist, 1)); } } labelAUC = new double[numLabels]; for (int i = 0; i < numLabels; i++) { ThresholdCurve tc = new ThresholdCurve(); Instances result = tc.getCurve(m_Predictions[i], 1); labelAUC[i] = ThresholdCurve.getROCArea(result); } auc[Averaging.MACRO.ordinal()] = Utils.mean(labelAUC); ThresholdCurve tc = new ThresholdCurve(); Instances result = tc.getCurve(all_Predictions, 1); auc[Averaging.MICRO.ordinal()] = ThresholdCurve.getROCArea(result); }
From source file:com.sliit.views.DataVisualizerPanel.java
void getRocCurve() { try {/*from www. jav a2s .co m*/ Instances data; data = new Instances(new BufferedReader(new FileReader(datasetPathText.getText()))); data.setClassIndex(data.numAttributes() - 1); // train classifier Classifier cl = new NaiveBayes(); Evaluation eval = new Evaluation(data); eval.crossValidateModel(cl, data, 10, new Random(1)); // generate curve ThresholdCurve tc = new ThresholdCurve(); int classIndex = 0; Instances result = tc.getCurve(eval.predictions(), classIndex); // plot curve ThresholdVisualizePanel vmc = new ThresholdVisualizePanel(); vmc.setROCString("(Area under ROC = " + Utils.doubleToString(tc.getROCArea(result), 4) + ")"); vmc.setName(result.relationName()); PlotData2D tempd = new PlotData2D(result); tempd.setPlotName(result.relationName()); tempd.addInstanceNumberAttribute(); // specify which points are connected boolean[] cp = new boolean[result.numInstances()]; for (int n = 1; n < cp.length; n++) { cp[n] = true; } tempd.setConnectPoints(cp); // add plot vmc.addPlot(tempd); // display curve String plotName = vmc.getName(); final javax.swing.JFrame jf = new javax.swing.JFrame("Weka Classifier Visualize: " + plotName); jf.setSize(500, 400); jf.getContentPane().setLayout(new BorderLayout()); jf.getContentPane().add(vmc, BorderLayout.CENTER); jf.addWindowListener(new java.awt.event.WindowAdapter() { public void windowClosing(java.awt.event.WindowEvent e) { jf.dispose(); } }); jf.setVisible(true); } catch (IOException ex) { Logger.getLogger(DataVisualizerPanel.class.getName()).log(Level.SEVERE, null, ex); } catch (Exception ex) { Logger.getLogger(DataVisualizerPanel.class.getName()).log(Level.SEVERE, null, ex); } }
From source file:com.sliit.views.KNNView.java
void getRocCurve() { try {// w ww . ja va 2s .c o m Instances data; data = new Instances(new BufferedReader(new java.io.FileReader(PredictorPanel.modalText.getText()))); data.setClassIndex(data.numAttributes() - 1); // train classifier Classifier cl = new NaiveBayes(); Evaluation eval = new Evaluation(data); eval.crossValidateModel(cl, data, 10, new Random(1)); // generate curve ThresholdCurve tc = new ThresholdCurve(); int classIndex = 0; Instances result = tc.getCurve(eval.predictions(), classIndex); // plot curve ThresholdVisualizePanel vmc = new ThresholdVisualizePanel(); vmc.setROCString("(Area under ROC = " + Utils.doubleToString(tc.getROCArea(result), 4) + ")"); vmc.setName(result.relationName()); PlotData2D tempd = new PlotData2D(result); tempd.setPlotName(result.relationName()); tempd.addInstanceNumberAttribute(); // specify which points are connected boolean[] cp = new boolean[result.numInstances()]; for (int n = 1; n < cp.length; n++) { cp[n] = true; } tempd.setConnectPoints(cp); // add plot vmc.addPlot(tempd); rocPanel.removeAll(); rocPanel.add(vmc, "vmc", 0); rocPanel.revalidate(); } catch (IOException ex) { Logger.getLogger(DataVisualizerPanel.class.getName()).log(Level.SEVERE, null, ex); } catch (Exception ex) { Logger.getLogger(DataVisualizerPanel.class.getName()).log(Level.SEVERE, null, ex); } }
From source file:com.sliit.views.SVMView.java
/** * draw ROC curve// w ww.ja va2 s . c o m */ void getRocCurve() { try { Instances data; data = new Instances(new BufferedReader(new FileReader(PredictorPanel.modalText.getText()))); data.setClassIndex(data.numAttributes() - 1); //train classifier Classifier cl = new NaiveBayes(); Evaluation eval = new Evaluation(data); eval.crossValidateModel(cl, data, 10, new Random(1)); // generate curve ThresholdCurve tc = new ThresholdCurve(); int classIndex = 0; Instances result = tc.getCurve(eval.predictions(), classIndex); // plot curve ThresholdVisualizePanel vmc = new ThresholdVisualizePanel(); vmc.setROCString("(Area under ROC = " + Utils.doubleToString(tc.getROCArea(result), 4) + ")"); vmc.setName(result.relationName()); PlotData2D tempd = new PlotData2D(result); tempd.setPlotName(result.relationName()); tempd.addInstanceNumberAttribute(); // specify which points are connected boolean[] cp = new boolean[result.numInstances()]; for (int n = 1; n < cp.length; n++) { cp[n] = true; } tempd.setConnectPoints(cp); // add plot vmc.addPlot(tempd); // rocPanel.removeAll(); // rocPanel.add(vmc, "vmc", 0); // rocPanel.revalidate(); } catch (IOException ex) { Logger.getLogger(DataVisualizerPanel.class.getName()).log(Level.SEVERE, null, ex); } catch (Exception ex) { Logger.getLogger(DataVisualizerPanel.class.getName()).log(Level.SEVERE, null, ex); } }
From source file:cotraining.copy.Evaluation_D.java
License:Open Source License
/** * Returns the area under ROC for those predictions that have been collected * in the evaluateClassifier(Classifier, Instances) method. Returns * Instance.missingValue() if the area is not available. * * @param classIndex the index of the class to consider as "positive" * @return the area under the ROC curve or not a number *//* w ww . ja v a2 s. co m*/ public double areaUnderROC(int classIndex) { // Check if any predictions have been collected if (m_Predictions == null) { return Instance.missingValue(); } else { ThresholdCurve tc = new ThresholdCurve(); Instances result = tc.getCurve(m_Predictions, classIndex); return ThresholdCurve.getROCArea(result); } }
From source file:cs.man.ac.uk.classifiers.GetAUC.java
License:Open Source License
/** * Computes the AUC for the supplied learner. * @return the AUC as a double value./* ww w . ja v a 2s .c o m*/ */ @SuppressWarnings("unused") private static double validate5x2CV() { try { // other options int runs = 5; int folds = 2; double AUC_SUM = 0; // perform cross-validation for (int i = 0; i < runs; i++) { // randomize data int seed = i + 1; Random rand = new Random(seed); Instances randData = new Instances(data); randData.randomize(rand); if (randData.classAttribute().isNominal()) { System.out.println("Stratifying..."); randData.stratify(folds); } Evaluation eval = new Evaluation(randData); for (int n = 0; n < folds; n++) { Instances train = randData.trainCV(folds, n); Instances test = randData.testCV(folds, n); // the above code is used by the StratifiedRemoveFolds filter, the // code below by the Explorer/Experimenter: // Instances train = randData.trainCV(folds, n, rand); // build and evaluate classifier String[] options = { "-U", "-A" }; J48 classifier = new J48(); //HTree classifier = new HTree(); classifier.setOptions(options); classifier.buildClassifier(train); eval.evaluateModel(classifier, test); // generate curve ThresholdCurve tc = new ThresholdCurve(); int classIndex = 0; Instances result = tc.getCurve(eval.predictions(), classIndex); // plot curve vmc = new ThresholdVisualizePanel(); AUC_SUM += ThresholdCurve.getROCArea(result); System.out.println("AUC: " + ThresholdCurve.getROCArea(result) + " \tAUC SUM: " + AUC_SUM); } } return AUC_SUM / ((double) runs * (double) folds); } catch (Exception e) { System.out.println("Exception validating data!"); return 0; } }
From source file:cs.man.ac.uk.classifiers.GetAUC.java
License:Open Source License
/** * Computes the AUC for the supplied learner. * @param learner the learning algorithm to use. * @return the AUC as a double value./*from w w w . ja v a 2s . c o m*/ */ @SuppressWarnings("unused") private static double validate(Classifier learner) { try { Evaluation eval = new Evaluation(data); eval.crossValidateModel(learner, data, 2, new Random(1)); // generate curve ThresholdCurve tc = new ThresholdCurve(); int classIndex = 0; Instances result = tc.getCurve(eval.predictions(), classIndex); // plot curve vmc = new ThresholdVisualizePanel(); double AUC = ThresholdCurve.getROCArea(result); vmc.setROCString( "(Area under ROC = " + Utils.doubleToString(ThresholdCurve.getROCArea(result), 9) + ")"); vmc.setName(result.relationName()); PlotData2D tempd = new PlotData2D(result); tempd.setPlotName(result.relationName()); tempd.addInstanceNumberAttribute(); // specify which points are connected boolean[] cp = new boolean[result.numInstances()]; for (int n = 1; n < cp.length; n++) cp[n] = true; tempd.setConnectPoints(cp); // add plot vmc.addPlot(tempd); return AUC; } catch (Exception e) { System.out.println("Exception validating data!"); return 0; } }
From source file:meka.core.Metrics.java
License:Open Source License
/** Calculate AUROC: Area Under the ROC curve. */ public static double P_macroAUROC(int Y[][], double P[][]) { // works with missing int L = Y[0].length; double AUC[] = new double[L]; for (int j = 0; j < L; j++) { if (allMissing(Y[j])) { L--;/*from w ww . j ava 2 s .co m*/ continue; } ThresholdCurve curve = new ThresholdCurve(); Instances result = curve .getCurve(MLUtils.toWekaPredictions(MatrixUtils.getCol(Y, j), MatrixUtils.getCol(P, j))); AUC[j] = ThresholdCurve.getROCArea(result); } return Utils.mean(AUC); }