List of usage examples for the weka.classifiers.evaluation.ThresholdCurve constructor ThresholdCurve()
ThresholdCurve
From source file:miRdup.WekaModule.java
License:Open Source License
public static void trainModel(File arff, String keyword) { dec.setMaximumFractionDigits(3);/*from ww w . j av a2s.c om*/ System.out.println("\nTraining model on file " + arff); try { // load data DataSource source = new DataSource(arff.toString()); Instances data = source.getDataSet(); if (data.classIndex() == -1) { data.setClassIndex(data.numAttributes() - 1); } PrintWriter pwout = new PrintWriter(new FileWriter(keyword + Main.modelExtension + "Output")); PrintWriter pwroc = new PrintWriter(new FileWriter(keyword + Main.modelExtension + "roc.arff")); //remove ID row Remove rm = new Remove(); rm.setAttributeIndices("1"); FilteredClassifier fc = new FilteredClassifier(); fc.setFilter(rm); // // train model svm // weka.classifiers.functions.LibSVM model = new weka.classifiers.functions.LibSVM(); // model.setOptions(weka.core.Utils.splitOptions("-S 0 -K 2 -D 3 -G 0.0 -R 0.0 -N 0.5 -M 40.0 -C 1.0 -E 0.0010 -P 0.1 -B")); // train model MultilayerPerceptron // weka.classifiers.functions.MultilayerPerceptron model = new weka.classifiers.functions.MultilayerPerceptron(); // model.setOptions(weka.core.Utils.splitOptions("-L 0.3 -M 0.2 -N 500 -V 0 -S 0 -E 20 -H a")); // train model Adaboost on RIPPER // weka.classifiers.meta.AdaBoostM1 model = new weka.classifiers.meta.AdaBoostM1(); // model.setOptions(weka.core.Utils.splitOptions("weka.classifiers.meta.AdaBoostM1 -P 100 -S 1 -I 10 -W weka.classifiers.rules.JRip -- -F 10 -N 2.0 -O 5 -S 1")); // train model Adaboost on FURIA // weka.classifiers.meta.AdaBoostM1 model = new weka.classifiers.meta.AdaBoostM1(); // model.setOptions(weka.core.Utils.splitOptions("weka.classifiers.meta.AdaBoostM1 -P 100 -S 1 -I 10 -W weka.classifiers.rules.FURIA -- -F 10 -N 2.0 -O 5 -S 1 -p 0 -s 0")); //train model Adaboot on J48 trees // weka.classifiers.meta.AdaBoostM1 model = new weka.classifiers.meta.AdaBoostM1(); // model.setOptions( // weka.core.Utils.splitOptions( // "-P 100 -S 1 -I 10 -W weka.classifiers.trees.J48 -- -C 0.25 -M 2")); //train 
model Adaboot on Random Forest trees weka.classifiers.meta.AdaBoostM1 model = new weka.classifiers.meta.AdaBoostM1(); model.setOptions(weka.core.Utils .splitOptions("-P 100 -S 1 -I 10 -W weka.classifiers.trees.RandomForest -- -I 50 -K 0 -S 1")); if (Main.debug) { System.out.print("Model options: " + model.getClass().getName().trim() + " "); } System.out.print(model.getClass() + " "); for (String s : model.getOptions()) { System.out.print(s + " "); } pwout.print("Model options: " + model.getClass().getName().trim() + " "); for (String s : model.getOptions()) { pwout.print(s + " "); } //build model // model.buildClassifier(data); fc.setClassifier(model); fc.buildClassifier(data); // cross validation 10 times on the model Evaluation eval = new Evaluation(data); //eval.crossValidateModel(model, data, 10, new Random(1)); StringBuffer sb = new StringBuffer(); eval.crossValidateModel(fc, data, 10, new Random(1), sb, new Range("first,last"), false); //System.out.println(sb); pwout.println(sb); pwout.flush(); // output pwout.println("\n" + eval.toSummaryString()); System.out.println(eval.toSummaryString()); pwout.println(eval.toClassDetailsString()); System.out.println(eval.toClassDetailsString()); //calculate importants values String ev[] = eval.toClassDetailsString().split("\n"); String ptmp[] = ev[3].trim().split(" "); String ntmp[] = ev[4].trim().split(" "); String avgtmp[] = ev[5].trim().split(" "); ArrayList<String> p = new ArrayList<String>(); ArrayList<String> n = new ArrayList<String>(); ArrayList<String> avg = new ArrayList<String>(); for (String s : ptmp) { if (!s.trim().isEmpty()) { p.add(s); } } for (String s : ntmp) { if (!s.trim().isEmpty()) { n.add(s); } } for (String s : avgtmp) { if (!s.trim().isEmpty()) { avg.add(s); } } double tp = Double.parseDouble(p.get(0)); double fp = Double.parseDouble(p.get(1)); double tn = Double.parseDouble(n.get(0)); double fn = Double.parseDouble(n.get(1)); double auc = Double.parseDouble(avg.get(7)); pwout.println("\nTP=" + 
tp + "\nFP=" + fp + "\nTN=" + tn + "\nFN=" + fn); System.out.println("\nTP=" + tp + "\nFP=" + fp + "\nTN=" + tn + "\nFN=" + fn); //specificity, sensitivity, Mathew's correlation, Prediction accuracy double sp = ((tn) / (tn + fp)); double se = ((tp) / (tp + fn)); double acc = ((tp + tn) / (tp + tn + fp + fn)); double mcc = ((tp * tn) - (fp * fn)) / Math.sqrt((tp + fp) * (tn + fn) * (tp + fn) * tn + fp); String output = "\nse=" + dec.format(se).replace(",", ".") + "\nsp=" + dec.format(sp).replace(",", ".") + "\nACC=" + dec.format(acc).replace(",", ".") + "\nMCC=" + dec.format(mcc).replace(",", ".") + "\nAUC=" + dec.format(auc).replace(",", "."); pwout.println(output); System.out.println(output); pwout.println(eval.toMatrixString()); System.out.println(eval.toMatrixString()); pwout.flush(); pwout.close(); //Saving model System.out.println("Model saved: " + keyword + Main.modelExtension); weka.core.SerializationHelper.write(keyword + Main.modelExtension, fc.getClassifier() /*model*/); // get curve ThresholdCurve tc = new ThresholdCurve(); int classIndex = 0; Instances result = tc.getCurve(eval.predictions(), classIndex); pwroc.print(result.toString()); pwroc.flush(); pwroc.close(); // draw curve //rocCurve(eval); } catch (Exception e) { e.printStackTrace(); } }
From source file:miRdup.WekaModule.java
License:Open Source License
public static void rocCurve(Evaluation eval) { try {//from w w w . j a v a 2 s.c om // generate curve ThresholdCurve tc = new ThresholdCurve(); int classIndex = 0; Instances result = tc.getCurve(eval.predictions(), classIndex); result.toString(); // plot curve ThresholdVisualizePanel vmc = new ThresholdVisualizePanel(); vmc.setROCString("(Area under ROC = " + Utils.doubleToString(tc.getROCArea(result), 4) + ")"); vmc.setName(result.relationName()); PlotData2D tempd = new PlotData2D(result); tempd.setPlotName(result.relationName()); tempd.addInstanceNumberAttribute(); // specify which points are connected boolean[] cp = new boolean[result.numInstances()]; for (int n = 1; n < cp.length; n++) { cp[n] = true; } tempd.setConnectPoints(cp); // add plot vmc.addPlot(tempd); // result.toString(); // display curve String plotName = vmc.getName(); final javax.swing.JFrame jf = new javax.swing.JFrame("Weka Classifier Visualize: " + plotName); jf.setSize(500, 400); jf.getContentPane().setLayout(new BorderLayout()); jf.getContentPane().add(vmc, BorderLayout.CENTER); jf.addWindowListener(new java.awt.event.WindowAdapter() { public void windowClosing(java.awt.event.WindowEvent e) { jf.dispose(); } }); jf.setVisible(true); System.out.println(""); } catch (Exception e) { e.printStackTrace(); } }
From source file:mlflex.WekaInMemoryLearner.java
License:Open Source License
/** This method calculates the area under the curve for a set of predictions and is designed to support classification of more than two classes. This code was derived from Weka's source code. * * @param predictions Predictions that have been made * @return Area under the curve, weighted by the proportion of instances for each class * @throws Exception// w ww .j av a 2 s . com */ // public static double CalculateWeightedAreaUnderRoc(Predictions predictions) throws Exception { ArrayList<String> uniqueActualClasses = predictions.GetUniqueActualClasses(); if (uniqueActualClasses.size() == 0) return Double.NaN; if (predictions.Size() == 1) { if (predictions.Get(0).WasCorrect()) return 1.0; return 0.5; } if (uniqueActualClasses.size() == 1) return 0.5; ArrayList<String> dependentVariableClasses = Utilities.ProcessorVault.DependentVariableDataProcessor .GetUniqueDependentVariableValues(); FastVector predictionVector = new FastVector(); for (Prediction prediction : predictions) predictionVector.addElement( new NominalPrediction(dependentVariableClasses.indexOf(prediction.DependentVariableValue), Lists.ConvertToDoubleArray(prediction.ClassProbabilities))); double aucTotal = 0; for (int i = 0; i < dependentVariableClasses.size(); i++) { String dependentVariableClass = dependentVariableClasses.get(i); Instances result = new ThresholdCurve().getCurve(predictionVector, i); double auc = ThresholdCurve.getROCArea(result); if (!Instance.isMissingValue(auc)) aucTotal += (auc * new PredictionResults(predictions) .GetNumActualsWithDependentVariableClass(dependentVariableClass)); } return aucTotal / predictions.Size(); }
From source file:mulan.evaluation.measure.MacroAUC.java
License:Open Source License
public double getValue() { double[] labelAUC = new double[numOfLabels]; for (int i = 0; i < numOfLabels; i++) { ThresholdCurve tc = new ThresholdCurve(); Instances result = tc.getCurve(m_Predictions[i], 1); labelAUC[i] = ThresholdCurve.getROCArea(result); }//from w ww. j a va 2 s . c om return Utils.mean(labelAUC); }
From source file:mulan.evaluation.measure.MicroAUC.java
License:Open Source License
/**
 * Computes the area under the ROC curve built from {@code all_Predictions} for class
 * index 1.
 *
 * @return the AUC of that curve
 */
public double getValue() {
    Instances pooledCurve = new ThresholdCurve().getCurve(all_Predictions, 1);
    double area = ThresholdCurve.getROCArea(pooledCurve);
    return area;
}
From source file:trainableSegmentation.Weka_Segmentation.java
License:GNU General Public License
/** * Display the threshold curve window (for precision/recall, ROC, etc.). * * @param data input instances/* w w w .jav a2 s . c o m*/ * @param classifier classifier to evaluate */ public static void displayGraphs(Instances data, AbstractClassifier classifier) { ThresholdCurve tc = new ThresholdCurve(); FastVector predictions = null; try { final EvaluationUtils eu = new EvaluationUtils(); predictions = eu.getTestPredictions(classifier, data); } catch (Exception e) { IJ.log("Error while evaluating data!"); e.printStackTrace(); return; } Instances result = tc.getCurve(predictions); ThresholdVisualizePanel vmc = new ThresholdVisualizePanel(); vmc.setName(result.relationName() + " (display only)"); PlotData2D tempd = new PlotData2D(result); tempd.setPlotName(result.relationName()); tempd.addInstanceNumberAttribute(); try { vmc.addPlot(tempd); } catch (Exception e) { IJ.log("Error while adding plot to visualization panel!"); e.printStackTrace(); return; } String plotName = vmc.getName(); JFrame jf = new JFrame("Weka Classifier Visualize: " + plotName); jf.setSize(500, 400); jf.getContentPane().setLayout(new BorderLayout()); jf.getContentPane().add(vmc, BorderLayout.CENTER); jf.setVisible(true); }