List of usage examples for weka.classifiers.evaluation.ThresholdCurve.getCurve
public Instances getCurve(ArrayList<Prediction> predictions, int classIndex)
From source file: adams.data.conversion.WekaEvaluationToThresholdCurve.java
License: Open Source License
/** * Performs the actual conversion./* w ww .j av a 2 s . c o m*/ * * @return the converted data * @throws Exception if something goes wrong with the conversion */ @Override protected Object doConvert() throws Exception { Evaluation eval; ThresholdCurve curve; Instances cost; eval = (Evaluation) m_Input; m_ClassLabelIndex.setMax(eval.getHeader().classAttribute().numValues()); curve = new ThresholdCurve(); cost = curve.getCurve(eval.predictions(), m_ClassLabelIndex.getIntIndex()); return cost; }
From source file: adams.flow.sink.WekaCostBenefitAnalysis.java
License: Open Source License
/** * Plots the token (the panel and dialog have already been created at * this stage).// w w w . j a v a 2 s .c om * * @param token the token to display */ @Override protected void display(Token token) { Evaluation eval; Attribute classAtt; Attribute classAttToUse; int classValue; ThresholdCurve tc; Instances result; ArrayList<String> newNames; CostBenefitAnalysis cbAnalysis; PlotData2D tempd; boolean[] cp; int n; try { if (token.getPayload() instanceof WekaEvaluationContainer) eval = (Evaluation) ((WekaEvaluationContainer) token.getPayload()) .getValue(WekaEvaluationContainer.VALUE_EVALUATION); else eval = (Evaluation) token.getPayload(); if (eval.predictions() == null) { getLogger().severe("No predictions available from Evaluation object!"); return; } classAtt = eval.getHeader().classAttribute(); m_ClassIndex.setData(classAtt); classValue = m_ClassIndex.getIntIndex(); tc = new ThresholdCurve(); result = tc.getCurve(eval.predictions(), classValue); // Create a dummy class attribute with the chosen // class value as index 0 (if necessary). classAttToUse = eval.getHeader().classAttribute(); if (classValue != 0) { newNames = new ArrayList<>(); newNames.add(classAtt.value(classValue)); for (int k = 0; k < classAtt.numValues(); k++) { if (k != classValue) newNames.add(classAtt.value(k)); } classAttToUse = new Attribute(classAtt.name(), newNames); } // assemble plot data tempd = new PlotData2D(result); tempd.setPlotName(result.relationName()); tempd.m_alwaysDisplayPointsOfThisSize = 10; // specify which points are connected cp = new boolean[result.numInstances()]; for (n = 1; n < cp.length; n++) cp[n] = true; tempd.setConnectPoints(cp); // add plot m_CostBenefitPanel.setCurveData(tempd, classAttToUse); } catch (Exception e) { handleException("Failed to display token: " + token, e); } }
From source file: adams.flow.sink.WekaCostBenefitAnalysis.java
License: Open Source License
/** * Creates a new panel for the token.//from www.ja va2 s .co m * * @param token the token to display in a new panel, can be null * @return the generated panel */ public AbstractDisplayPanel createDisplayPanel(Token token) { AbstractDisplayPanel result; String name; if (token != null) name = "Cost curve (" + getEvaluation(token).getHeader().relationName() + ")"; else name = "Cost curve"; result = new AbstractComponentDisplayPanel(name) { private static final long serialVersionUID = -3513994354297811163L; protected CostBenefitAnalysis m_VisualizePanel; @Override protected void initGUI() { super.initGUI(); setLayout(new BorderLayout()); m_VisualizePanel = new CostBenefitAnalysis(); add(m_VisualizePanel, BorderLayout.CENTER); } @Override public void display(Token token) { try { Evaluation eval = getEvaluation(token); Attribute classAtt = eval.getHeader().classAttribute(); m_ClassIndex.setData(classAtt); int classValue = m_ClassIndex.getIntIndex(); ThresholdCurve tc = new ThresholdCurve(); Instances result = tc.getCurve(eval.predictions(), classValue); // Create a dummy class attribute with the chosen // class value as index 0 (if necessary). 
Attribute classAttToUse = eval.getHeader().classAttribute(); if (classValue != 0) { ArrayList<String> newNames = new ArrayList<>(); newNames.add(classAtt.value(classValue)); for (int k = 0; k < classAtt.numValues(); k++) { if (k != classValue) newNames.add(classAtt.value(k)); } classAttToUse = new Attribute(classAtt.name(), newNames); } // assemble plot data PlotData2D tempd = new PlotData2D(result); tempd.setPlotName(result.relationName()); tempd.m_alwaysDisplayPointsOfThisSize = 10; // specify which points are connected boolean[] cp = new boolean[result.numInstances()]; for (int n = 1; n < cp.length; n++) cp[n] = true; tempd.setConnectPoints(cp); // add plot m_VisualizePanel.setCurveData(tempd, classAttToUse); } catch (Exception e) { getLogger().log(Level.SEVERE, "Failed to display token: " + token, e); } } @Override public JComponent supplyComponent() { return m_VisualizePanel; } @Override public void clearPanel() { } public void cleanUp() { } }; if (token != null) result.display(token); return result; }
From source file: adams.flow.sink.WekaThresholdCurve.java
License: Open Source License
/** * Plots the token (the panel and dialog have already been created at * this stage)./*from w w w .j a va 2s .c om*/ * * @param token the token to display */ @Override protected void display(Token token) { ThresholdCurve curve; Evaluation eval; PlotData2D plot; boolean[] connectPoints; int cp; Instances data; int[] indices; try { if (token.getPayload() instanceof WekaEvaluationContainer) eval = (Evaluation) ((WekaEvaluationContainer) token.getPayload()) .getValue(WekaEvaluationContainer.VALUE_EVALUATION); else eval = (Evaluation) token.getPayload(); if (eval.predictions() == null) { getLogger().severe("No predictions available from Evaluation object!"); return; } m_ClassLabelRange.setData(eval.getHeader().classAttribute()); indices = m_ClassLabelRange.getIntIndices(); for (int index : indices) { curve = new ThresholdCurve(); data = curve.getCurve(eval.predictions(), index); plot = new PlotData2D(data); plot.setPlotName(eval.getHeader().classAttribute().value(index)); plot.m_displayAllPoints = true; connectPoints = new boolean[data.numInstances()]; for (cp = 1; cp < connectPoints.length; cp++) connectPoints[cp] = true; plot.setConnectPoints(connectPoints); m_VisualizePanel.addPlot(plot); if (data.attribute(m_AttributeX.toDisplay()) != null) m_VisualizePanel.setXIndex(data.attribute(m_AttributeX.toDisplay()).index()); if (data.attribute(m_AttributeY.toDisplay()) != null) m_VisualizePanel.setYIndex(data.attribute(m_AttributeY.toDisplay()).index()); } } catch (Exception e) { handleException("Failed to display token: " + token, e); } }
From source file: adams.flow.sink.WekaThresholdCurve.java
License: Open Source License
/**
 * Creates a new panel for the token.
 * <br>
 * The panel wraps a {@link ThresholdVisualizePanel}. Displaying a token adds
 * one threshold curve per class label selected via m_ClassLabelRange.
 *
 * @param token the token to display in a new panel, can be null
 * @return the generated panel
 */
public AbstractDisplayPanel createDisplayPanel(Token token) {
  AbstractDisplayPanel result;
  String name;

  if (token != null)
    name = "Threshold curve (" + getEvaluation(token).getHeader().relationName() + ")";
  else
    name = "Threshold curve";

  result = new AbstractComponentDisplayPanel(name) {
    private static final long serialVersionUID = -7362768698548152899L;

    protected ThresholdVisualizePanel m_VisualizePanel;

    @Override
    protected void initGUI() {
      super.initGUI();
      setLayout(new BorderLayout());
      m_VisualizePanel = new ThresholdVisualizePanel();
      add(m_VisualizePanel, BorderLayout.CENTER);
    }

    @Override
    public void display(Token token) {
      try {
        Evaluation eval = getEvaluation(token);
        // same guard as the sink's own display(Token)
        if (eval.predictions() == null) {
          getLogger().severe("No predictions available from Evaluation object!");
          return;
        }

        // configure the range the same way as the sink's display(Token)
        m_ClassLabelRange.setData(eval.getHeader().classAttribute());
        int[] indices = m_ClassLabelRange.getIntIndices();
        for (int index : indices) {
          ThresholdCurve curve = new ThresholdCurve();
          Instances data = curve.getCurve(eval.predictions(), index);
          // getCurve returns null when no curve can be computed; skip this label
          if (data == null) {
            getLogger().warning("No threshold curve available for class label #" + (index + 1));
            continue;
          }

          PlotData2D plot = new PlotData2D(data);
          plot.setPlotName(eval.getHeader().classAttribute().value(index));
          plot.m_displayAllPoints = true;
          // specify which points are connected (all but the first)
          boolean[] connectPoints = new boolean[data.numInstances()];
          for (int cp = 1; cp < connectPoints.length; cp++)
            connectPoints[cp] = true;
          plot.setConnectPoints(connectPoints);
          m_VisualizePanel.addPlot(plot);

          if (data.attribute(m_AttributeX.toDisplay()) != null)
            m_VisualizePanel.setXIndex(data.attribute(m_AttributeX.toDisplay()).index());
          if (data.attribute(m_AttributeY.toDisplay()) != null)
            m_VisualizePanel.setYIndex(data.attribute(m_AttributeY.toDisplay()).index());
        }
      }
      catch (Exception e) {
        getLogger().log(Level.SEVERE, "Failed to display token: " + token, e);
      }
    }

    @Override
    public JComponent supplyComponent() {
      return m_VisualizePanel;
    }

    @Override
    public void clearPanel() {
      m_VisualizePanel.removeAllPlots();
    }

    public void cleanUp() {
      m_VisualizePanel.removeAllPlots();
    }
  };

  if (token != null)
    result.display(token);

  return result;
}
From source file: bme.mace.logicdomain.Evaluation.java
License: Open Source License
/** * Returns the area under ROC for those predictions that have been collected * in the evaluateClassifier(Classifier, Instances) method. Returns * Instance.missingValue() if the area is not available. * //w w w.jav a 2 s. c o m * @param classIndex the index of the class to consider as "positive" * @return the area under the ROC curve or not a number */ public double areaUnderROC(int classIndex) { // Check if any predictions have been collected if (m_Predictions == null) { return Instance.missingValue(); } else { ThresholdCurve tc = new ThresholdCurve(); Instances result = tc.getCurve(m_Predictions, classIndex); double rocArea = ThresholdCurve.getROCArea(result); if (rocArea < 0.5) { rocArea = 1 - rocArea; } int tpIndex = result.attribute(ThresholdCurve.TP_RATE_NAME).index(); int fpIndex = result.attribute(ThresholdCurve.FP_RATE_NAME).index(); double[] tpRate = result.attributeToDoubleArray(tpIndex); double[] fpRate = result.attributeToDoubleArray(fpIndex); try { FileWriter fw; if (classIndex == 0) fw = new FileWriter("C://1.csv", true); else fw = new FileWriter("C://1.csv", true); BufferedWriter bw = new BufferedWriter(fw); int length = fpRate.length; for (int i = 255; i >= 0; i--) { int index = i * (length - 1) / 255; bw.write(fpRate[index] + ","); } bw.write("\n"); for (int i = 255; i >= 0; i--) { int index = i * (length - 1) / 255; bw.write(tpRate[index] + ","); } bw.write("\n"); bw.close(); fw.close(); } catch (IOException e) { // TODO Auto-generated catch block e.printStackTrace(); } return rocArea; } }
From source file: com.evaluation.ConfidenceLabelBasedMeasures.java
License: Open Source License
private void computeMeasures(MultiLabelOutput[] output, boolean[][] trueLabels) { int numLabels = trueLabels[0].length; // AUC//from w w w. jav a 2 s . c om FastVector[] m_Predictions = new FastVector[numLabels]; for (int j = 0; j < numLabels; j++) m_Predictions[j] = new FastVector(); FastVector all_Predictions = new FastVector(); int numInstances = output.length; for (int instanceIndex = 0; instanceIndex < numInstances; instanceIndex++) { double[] confidences = output[instanceIndex].getConfidences(); for (int labelIndex = 0; labelIndex < numLabels; labelIndex++) { int classValue; boolean actual = trueLabels[instanceIndex][labelIndex]; if (actual) classValue = 1; else classValue = 0; double[] dist = new double[2]; dist[1] = confidences[labelIndex]; dist[0] = 1 - dist[1]; m_Predictions[labelIndex].addElement(new NominalPrediction(classValue, dist, 1)); all_Predictions.addElement(new NominalPrediction(classValue, dist, 1)); } } labelAUC = new double[numLabels]; for (int i = 0; i < numLabels; i++) { ThresholdCurve tc = new ThresholdCurve(); Instances result = tc.getCurve(m_Predictions[i], 1); labelAUC[i] = ThresholdCurve.getROCArea(result); } auc[Averaging.MACRO.ordinal()] = Utils.mean(labelAUC); ThresholdCurve tc = new ThresholdCurve(); Instances result = tc.getCurve(all_Predictions, 1); auc[Averaging.MICRO.ordinal()] = ThresholdCurve.getROCArea(result); }
From source file: com.sliit.views.DataVisualizerPanel.java
void getRocCurve() { try {// ww w . j av a2 s . com Instances data; data = new Instances(new BufferedReader(new FileReader(datasetPathText.getText()))); data.setClassIndex(data.numAttributes() - 1); // train classifier Classifier cl = new NaiveBayes(); Evaluation eval = new Evaluation(data); eval.crossValidateModel(cl, data, 10, new Random(1)); // generate curve ThresholdCurve tc = new ThresholdCurve(); int classIndex = 0; Instances result = tc.getCurve(eval.predictions(), classIndex); // plot curve ThresholdVisualizePanel vmc = new ThresholdVisualizePanel(); vmc.setROCString("(Area under ROC = " + Utils.doubleToString(tc.getROCArea(result), 4) + ")"); vmc.setName(result.relationName()); PlotData2D tempd = new PlotData2D(result); tempd.setPlotName(result.relationName()); tempd.addInstanceNumberAttribute(); // specify which points are connected boolean[] cp = new boolean[result.numInstances()]; for (int n = 1; n < cp.length; n++) { cp[n] = true; } tempd.setConnectPoints(cp); // add plot vmc.addPlot(tempd); // display curve String plotName = vmc.getName(); final javax.swing.JFrame jf = new javax.swing.JFrame("Weka Classifier Visualize: " + plotName); jf.setSize(500, 400); jf.getContentPane().setLayout(new BorderLayout()); jf.getContentPane().add(vmc, BorderLayout.CENTER); jf.addWindowListener(new java.awt.event.WindowAdapter() { public void windowClosing(java.awt.event.WindowEvent e) { jf.dispose(); } }); jf.setVisible(true); } catch (IOException ex) { Logger.getLogger(DataVisualizerPanel.class.getName()).log(Level.SEVERE, null, ex); } catch (Exception ex) { Logger.getLogger(DataVisualizerPanel.class.getName()).log(Level.SEVERE, null, ex); } }
From source file: com.sliit.views.KNNView.java
void getRocCurve() { try {//from www .j av a2 s. co m Instances data; data = new Instances(new BufferedReader(new java.io.FileReader(PredictorPanel.modalText.getText()))); data.setClassIndex(data.numAttributes() - 1); // train classifier Classifier cl = new NaiveBayes(); Evaluation eval = new Evaluation(data); eval.crossValidateModel(cl, data, 10, new Random(1)); // generate curve ThresholdCurve tc = new ThresholdCurve(); int classIndex = 0; Instances result = tc.getCurve(eval.predictions(), classIndex); // plot curve ThresholdVisualizePanel vmc = new ThresholdVisualizePanel(); vmc.setROCString("(Area under ROC = " + Utils.doubleToString(tc.getROCArea(result), 4) + ")"); vmc.setName(result.relationName()); PlotData2D tempd = new PlotData2D(result); tempd.setPlotName(result.relationName()); tempd.addInstanceNumberAttribute(); // specify which points are connected boolean[] cp = new boolean[result.numInstances()]; for (int n = 1; n < cp.length; n++) { cp[n] = true; } tempd.setConnectPoints(cp); // add plot vmc.addPlot(tempd); rocPanel.removeAll(); rocPanel.add(vmc, "vmc", 0); rocPanel.revalidate(); } catch (IOException ex) { Logger.getLogger(DataVisualizerPanel.class.getName()).log(Level.SEVERE, null, ex); } catch (Exception ex) { Logger.getLogger(DataVisualizerPanel.class.getName()).log(Level.SEVERE, null, ex); } }
From source file: com.sliit.views.SVMView.java
/** * draw ROC curve//from w ww . j a va2s .c o m */ void getRocCurve() { try { Instances data; data = new Instances(new BufferedReader(new FileReader(PredictorPanel.modalText.getText()))); data.setClassIndex(data.numAttributes() - 1); //train classifier Classifier cl = new NaiveBayes(); Evaluation eval = new Evaluation(data); eval.crossValidateModel(cl, data, 10, new Random(1)); // generate curve ThresholdCurve tc = new ThresholdCurve(); int classIndex = 0; Instances result = tc.getCurve(eval.predictions(), classIndex); // plot curve ThresholdVisualizePanel vmc = new ThresholdVisualizePanel(); vmc.setROCString("(Area under ROC = " + Utils.doubleToString(tc.getROCArea(result), 4) + ")"); vmc.setName(result.relationName()); PlotData2D tempd = new PlotData2D(result); tempd.setPlotName(result.relationName()); tempd.addInstanceNumberAttribute(); // specify which points are connected boolean[] cp = new boolean[result.numInstances()]; for (int n = 1; n < cp.length; n++) { cp[n] = true; } tempd.setConnectPoints(cp); // add plot vmc.addPlot(tempd); // rocPanel.removeAll(); // rocPanel.add(vmc, "vmc", 0); // rocPanel.revalidate(); } catch (IOException ex) { Logger.getLogger(DataVisualizerPanel.class.getName()).log(Level.SEVERE, null, ex); } catch (Exception ex) { Logger.getLogger(DataVisualizerPanel.class.getName()).log(Level.SEVERE, null, ex); } }