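Usage examples for `weka.classifiers.Evaluation.predictions()`
Signature: `public ArrayList<Prediction> predictions()` (some examples below come from code written against older Weka releases and assign the result to a `FastVector`)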
From source file: algoritmogeneticocluster.NewClass.java
public static void main(String[] args) throws Exception { BufferedReader datafile = readDataFile("tabela10.arff"); Instances data = new Instances(datafile); data.setClassIndex(data.numAttributes() - 1); // Do 10-split cross validation Instances[][] split = crossValidationSplit(data, 10); // Separate split into training and testing arrays Instances[] trainingSplits = split[0]; Instances[] testingSplits = split[1]; // Use a set of classifiers Classifier[] models = { new SMO(), new J48(), // a decision tree new PART(), new DecisionTable(), //decision table majority classifier new DecisionStump() //one-level decision tree };/* w w w. ja va 2 s.com*/ // Run for each model for (int j = 0; j < models.length; j++) { // Collect every group of predictions for current model in a FastVector FastVector predictions = new FastVector(); // For each training-testing split pair, train and test the classifier for (int i = 0; i < trainingSplits.length; i++) { Evaluation validation = classify(models[j], trainingSplits[i], testingSplits[i]); predictions.appendElements(validation.predictions()); // Uncomment to see the summary for each training-testing pair. //System.out.println(models[j].toString()); } // Calculate overall accuracy of current classifier on all splits double accuracy = calculateAccuracy(predictions); // Print current classifier's name and accuracy in a complicated, // but nice-looking way. System.out.println("Accuracy of " + models[j].getClass().getSimpleName() + ": " + String.format("%.2f%%", accuracy) + "\n---------------------------------"); } }
From source file: com.sliit.rules.RuleContainer.java
/**
 * Evaluates the rule model on the instances in a CSV file and returns the class label predicted for
 * the first instance.
 *
 * <p>The last CSV column is treated as the class attribute. The label text is looked up in the class
 * attribute of the {@code instances} field, not in the CSV header.
 *
 * @param filePath path of the CSV file to classify
 * @return the predicted class label of the first instance
 * @throws IllegalStateException if the evaluation produced no predictions, or the first prediction is
 *     missing (NaN)
 * @throws Exception if the CSV cannot be loaded or the model fails to evaluate
 */
public String predictionResult(String filePath) throws Exception {
    CSVLoader loader = new CSVLoader();
    loader.setSource(new File(filePath));
    Instances testInstances = loader.getDataSet();
    testInstances.setClassIndex(testInstances.numAttributes() - 1);

    Evaluation eval = new Evaluation(testInstances);
    eval.evaluateModel(ruleMoldel, testInstances);

    ArrayList<Prediction> predictions = eval.predictions();
    if (predictions == null || predictions.isEmpty()) {
        throw new IllegalStateException("No predictions were produced for " + filePath);
    }

    double predicted = predictions.get(0).predicted();
    // Weka represents "no prediction" as a missing value (NaN); casting NaN to int yields 0 and would
    // silently report the first class label.
    if (Double.isNaN(predicted)) {
        throw new IllegalStateException("The model made no prediction for the first instance of " + filePath);
    }
    return instances.classAttribute().value((int) predicted);
}
From source file: com.sliit.views.DataVisualizerPanel.java
void getRocCurve() { try {//from www .j a v a 2 s.c o m Instances data; data = new Instances(new BufferedReader(new FileReader(datasetPathText.getText()))); data.setClassIndex(data.numAttributes() - 1); // train classifier Classifier cl = new NaiveBayes(); Evaluation eval = new Evaluation(data); eval.crossValidateModel(cl, data, 10, new Random(1)); // generate curve ThresholdCurve tc = new ThresholdCurve(); int classIndex = 0; Instances result = tc.getCurve(eval.predictions(), classIndex); // plot curve ThresholdVisualizePanel vmc = new ThresholdVisualizePanel(); vmc.setROCString("(Area under ROC = " + Utils.doubleToString(tc.getROCArea(result), 4) + ")"); vmc.setName(result.relationName()); PlotData2D tempd = new PlotData2D(result); tempd.setPlotName(result.relationName()); tempd.addInstanceNumberAttribute(); // specify which points are connected boolean[] cp = new boolean[result.numInstances()]; for (int n = 1; n < cp.length; n++) { cp[n] = true; } tempd.setConnectPoints(cp); // add plot vmc.addPlot(tempd); // display curve String plotName = vmc.getName(); final javax.swing.JFrame jf = new javax.swing.JFrame("Weka Classifier Visualize: " + plotName); jf.setSize(500, 400); jf.getContentPane().setLayout(new BorderLayout()); jf.getContentPane().add(vmc, BorderLayout.CENTER); jf.addWindowListener(new java.awt.event.WindowAdapter() { public void windowClosing(java.awt.event.WindowEvent e) { jf.dispose(); } }); jf.setVisible(true); } catch (IOException ex) { Logger.getLogger(DataVisualizerPanel.class.getName()).log(Level.SEVERE, null, ex); } catch (Exception ex) { Logger.getLogger(DataVisualizerPanel.class.getName()).log(Level.SEVERE, null, ex); } }
From source file: com.sliit.views.KNNView.java
void getRocCurve() { try {//from w ww . j a v a 2 s . c o m Instances data; data = new Instances(new BufferedReader(new java.io.FileReader(PredictorPanel.modalText.getText()))); data.setClassIndex(data.numAttributes() - 1); // train classifier Classifier cl = new NaiveBayes(); Evaluation eval = new Evaluation(data); eval.crossValidateModel(cl, data, 10, new Random(1)); // generate curve ThresholdCurve tc = new ThresholdCurve(); int classIndex = 0; Instances result = tc.getCurve(eval.predictions(), classIndex); // plot curve ThresholdVisualizePanel vmc = new ThresholdVisualizePanel(); vmc.setROCString("(Area under ROC = " + Utils.doubleToString(tc.getROCArea(result), 4) + ")"); vmc.setName(result.relationName()); PlotData2D tempd = new PlotData2D(result); tempd.setPlotName(result.relationName()); tempd.addInstanceNumberAttribute(); // specify which points are connected boolean[] cp = new boolean[result.numInstances()]; for (int n = 1; n < cp.length; n++) { cp[n] = true; } tempd.setConnectPoints(cp); // add plot vmc.addPlot(tempd); rocPanel.removeAll(); rocPanel.add(vmc, "vmc", 0); rocPanel.revalidate(); } catch (IOException ex) { Logger.getLogger(DataVisualizerPanel.class.getName()).log(Level.SEVERE, null, ex); } catch (Exception ex) { Logger.getLogger(DataVisualizerPanel.class.getName()).log(Level.SEVERE, null, ex); } }
From source file: com.sliit.views.SVMView.java
/** * draw ROC curve//from w ww . j a v a2s . c om */ void getRocCurve() { try { Instances data; data = new Instances(new BufferedReader(new FileReader(PredictorPanel.modalText.getText()))); data.setClassIndex(data.numAttributes() - 1); //train classifier Classifier cl = new NaiveBayes(); Evaluation eval = new Evaluation(data); eval.crossValidateModel(cl, data, 10, new Random(1)); // generate curve ThresholdCurve tc = new ThresholdCurve(); int classIndex = 0; Instances result = tc.getCurve(eval.predictions(), classIndex); // plot curve ThresholdVisualizePanel vmc = new ThresholdVisualizePanel(); vmc.setROCString("(Area under ROC = " + Utils.doubleToString(tc.getROCArea(result), 4) + ")"); vmc.setName(result.relationName()); PlotData2D tempd = new PlotData2D(result); tempd.setPlotName(result.relationName()); tempd.addInstanceNumberAttribute(); // specify which points are connected boolean[] cp = new boolean[result.numInstances()]; for (int n = 1; n < cp.length; n++) { cp[n] = true; } tempd.setConnectPoints(cp); // add plot vmc.addPlot(tempd); // rocPanel.removeAll(); // rocPanel.add(vmc, "vmc", 0); // rocPanel.revalidate(); } catch (IOException ex) { Logger.getLogger(DataVisualizerPanel.class.getName()).log(Level.SEVERE, null, ex); } catch (Exception ex) { Logger.getLogger(DataVisualizerPanel.class.getName()).log(Level.SEVERE, null, ex); } }
From source file: cs.man.ac.uk.classifiers.GetAUC.java
License: Open Source License
/** * Computes the AUC for the supplied learner. * @return the AUC as a double value./*w ww .j a v a 2 s .c om*/ */ @SuppressWarnings("unused") private static double validate5x2CV() { try { // other options int runs = 5; int folds = 2; double AUC_SUM = 0; // perform cross-validation for (int i = 0; i < runs; i++) { // randomize data int seed = i + 1; Random rand = new Random(seed); Instances randData = new Instances(data); randData.randomize(rand); if (randData.classAttribute().isNominal()) { System.out.println("Stratifying..."); randData.stratify(folds); } Evaluation eval = new Evaluation(randData); for (int n = 0; n < folds; n++) { Instances train = randData.trainCV(folds, n); Instances test = randData.testCV(folds, n); // the above code is used by the StratifiedRemoveFolds filter, the // code below by the Explorer/Experimenter: // Instances train = randData.trainCV(folds, n, rand); // build and evaluate classifier String[] options = { "-U", "-A" }; J48 classifier = new J48(); //HTree classifier = new HTree(); classifier.setOptions(options); classifier.buildClassifier(train); eval.evaluateModel(classifier, test); // generate curve ThresholdCurve tc = new ThresholdCurve(); int classIndex = 0; Instances result = tc.getCurve(eval.predictions(), classIndex); // plot curve vmc = new ThresholdVisualizePanel(); AUC_SUM += ThresholdCurve.getROCArea(result); System.out.println("AUC: " + ThresholdCurve.getROCArea(result) + " \tAUC SUM: " + AUC_SUM); } } return AUC_SUM / ((double) runs * (double) folds); } catch (Exception e) { System.out.println("Exception validating data!"); return 0; } }
From source file: cs.man.ac.uk.classifiers.GetAUC.java
License: Open Source License
/** * Computes the AUC for the supplied learner. * @param learner the learning algorithm to use. * @return the AUC as a double value.//from ww w . j a v a 2 s. com */ @SuppressWarnings("unused") private static double validate(Classifier learner) { try { Evaluation eval = new Evaluation(data); eval.crossValidateModel(learner, data, 2, new Random(1)); // generate curve ThresholdCurve tc = new ThresholdCurve(); int classIndex = 0; Instances result = tc.getCurve(eval.predictions(), classIndex); // plot curve vmc = new ThresholdVisualizePanel(); double AUC = ThresholdCurve.getROCArea(result); vmc.setROCString( "(Area under ROC = " + Utils.doubleToString(ThresholdCurve.getROCArea(result), 9) + ")"); vmc.setName(result.relationName()); PlotData2D tempd = new PlotData2D(result); tempd.setPlotName(result.relationName()); tempd.addInstanceNumberAttribute(); // specify which points are connected boolean[] cp = new boolean[result.numInstances()]; for (int n = 1; n < cp.length; n++) cp[n] = true; tempd.setConnectPoints(cp); // add plot vmc.addPlot(tempd); return AUC; } catch (Exception e) { System.out.println("Exception validating data!"); return 0; } }
From source file: miRdup.WekaModule.java
License: Open Source License
public static void trainModel(File arff, String keyword) { dec.setMaximumFractionDigits(3);/* w w w. j av a 2s. c o m*/ System.out.println("\nTraining model on file " + arff); try { // load data DataSource source = new DataSource(arff.toString()); Instances data = source.getDataSet(); if (data.classIndex() == -1) { data.setClassIndex(data.numAttributes() - 1); } PrintWriter pwout = new PrintWriter(new FileWriter(keyword + Main.modelExtension + "Output")); PrintWriter pwroc = new PrintWriter(new FileWriter(keyword + Main.modelExtension + "roc.arff")); //remove ID row Remove rm = new Remove(); rm.setAttributeIndices("1"); FilteredClassifier fc = new FilteredClassifier(); fc.setFilter(rm); // // train model svm // weka.classifiers.functions.LibSVM model = new weka.classifiers.functions.LibSVM(); // model.setOptions(weka.core.Utils.splitOptions("-S 0 -K 2 -D 3 -G 0.0 -R 0.0 -N 0.5 -M 40.0 -C 1.0 -E 0.0010 -P 0.1 -B")); // train model MultilayerPerceptron // weka.classifiers.functions.MultilayerPerceptron model = new weka.classifiers.functions.MultilayerPerceptron(); // model.setOptions(weka.core.Utils.splitOptions("-L 0.3 -M 0.2 -N 500 -V 0 -S 0 -E 20 -H a")); // train model Adaboost on RIPPER // weka.classifiers.meta.AdaBoostM1 model = new weka.classifiers.meta.AdaBoostM1(); // model.setOptions(weka.core.Utils.splitOptions("weka.classifiers.meta.AdaBoostM1 -P 100 -S 1 -I 10 -W weka.classifiers.rules.JRip -- -F 10 -N 2.0 -O 5 -S 1")); // train model Adaboost on FURIA // weka.classifiers.meta.AdaBoostM1 model = new weka.classifiers.meta.AdaBoostM1(); // model.setOptions(weka.core.Utils.splitOptions("weka.classifiers.meta.AdaBoostM1 -P 100 -S 1 -I 10 -W weka.classifiers.rules.FURIA -- -F 10 -N 2.0 -O 5 -S 1 -p 0 -s 0")); //train model Adaboot on J48 trees // weka.classifiers.meta.AdaBoostM1 model = new weka.classifiers.meta.AdaBoostM1(); // model.setOptions( // weka.core.Utils.splitOptions( // "-P 100 -S 1 -I 10 -W weka.classifiers.trees.J48 -- -C 0.25 -M 2")); //train 
model Adaboot on Random Forest trees weka.classifiers.meta.AdaBoostM1 model = new weka.classifiers.meta.AdaBoostM1(); model.setOptions(weka.core.Utils .splitOptions("-P 100 -S 1 -I 10 -W weka.classifiers.trees.RandomForest -- -I 50 -K 0 -S 1")); if (Main.debug) { System.out.print("Model options: " + model.getClass().getName().trim() + " "); } System.out.print(model.getClass() + " "); for (String s : model.getOptions()) { System.out.print(s + " "); } pwout.print("Model options: " + model.getClass().getName().trim() + " "); for (String s : model.getOptions()) { pwout.print(s + " "); } //build model // model.buildClassifier(data); fc.setClassifier(model); fc.buildClassifier(data); // cross validation 10 times on the model Evaluation eval = new Evaluation(data); //eval.crossValidateModel(model, data, 10, new Random(1)); StringBuffer sb = new StringBuffer(); eval.crossValidateModel(fc, data, 10, new Random(1), sb, new Range("first,last"), false); //System.out.println(sb); pwout.println(sb); pwout.flush(); // output pwout.println("\n" + eval.toSummaryString()); System.out.println(eval.toSummaryString()); pwout.println(eval.toClassDetailsString()); System.out.println(eval.toClassDetailsString()); //calculate importants values String ev[] = eval.toClassDetailsString().split("\n"); String ptmp[] = ev[3].trim().split(" "); String ntmp[] = ev[4].trim().split(" "); String avgtmp[] = ev[5].trim().split(" "); ArrayList<String> p = new ArrayList<String>(); ArrayList<String> n = new ArrayList<String>(); ArrayList<String> avg = new ArrayList<String>(); for (String s : ptmp) { if (!s.trim().isEmpty()) { p.add(s); } } for (String s : ntmp) { if (!s.trim().isEmpty()) { n.add(s); } } for (String s : avgtmp) { if (!s.trim().isEmpty()) { avg.add(s); } } double tp = Double.parseDouble(p.get(0)); double fp = Double.parseDouble(p.get(1)); double tn = Double.parseDouble(n.get(0)); double fn = Double.parseDouble(n.get(1)); double auc = Double.parseDouble(avg.get(7)); pwout.println("\nTP=" + 
tp + "\nFP=" + fp + "\nTN=" + tn + "\nFN=" + fn); System.out.println("\nTP=" + tp + "\nFP=" + fp + "\nTN=" + tn + "\nFN=" + fn); //specificity, sensitivity, Mathew's correlation, Prediction accuracy double sp = ((tn) / (tn + fp)); double se = ((tp) / (tp + fn)); double acc = ((tp + tn) / (tp + tn + fp + fn)); double mcc = ((tp * tn) - (fp * fn)) / Math.sqrt((tp + fp) * (tn + fn) * (tp + fn) * tn + fp); String output = "\nse=" + dec.format(se).replace(",", ".") + "\nsp=" + dec.format(sp).replace(",", ".") + "\nACC=" + dec.format(acc).replace(",", ".") + "\nMCC=" + dec.format(mcc).replace(",", ".") + "\nAUC=" + dec.format(auc).replace(",", "."); pwout.println(output); System.out.println(output); pwout.println(eval.toMatrixString()); System.out.println(eval.toMatrixString()); pwout.flush(); pwout.close(); //Saving model System.out.println("Model saved: " + keyword + Main.modelExtension); weka.core.SerializationHelper.write(keyword + Main.modelExtension, fc.getClassifier() /*model*/); // get curve ThresholdCurve tc = new ThresholdCurve(); int classIndex = 0; Instances result = tc.getCurve(eval.predictions(), classIndex); pwroc.print(result.toString()); pwroc.flush(); pwroc.close(); // draw curve //rocCurve(eval); } catch (Exception e) { e.printStackTrace(); } }
From source file: miRdup.WekaModule.java
License: Open Source License
public static void testModel(File testarff, String predictionsFile, String classifier, boolean predictMiRNA) { System.out.println("Testing model on " + predictionsFile + " adapted in " + testarff + ". Submitted to model " + classifier); try {//w ww .ja va 2 s . c om //add predictions sequences to object ArrayList<MirnaObject> alobj = new ArrayList<MirnaObject>(); BufferedReader br = null; try { br = new BufferedReader(new FileReader(predictionsFile + ".folded")); } catch (FileNotFoundException fileNotFoundException) { br = new BufferedReader(new FileReader(predictionsFile)); } BufferedReader br2 = new BufferedReader(new FileReader(testarff)); String line2 = br2.readLine(); while (!line2.startsWith("@data")) { line2 = br2.readLine(); } String line = " "; int cpt = 0; while (br.ready()) { line = br.readLine(); line2 = br2.readLine(); String[] tab = line.split("\t"); MirnaObject m = new MirnaObject(); m.setArff(line2); m.setId(cpt++); m.setIdName(tab[0]); m.setMatureSequence(tab[1]); m.setPrecursorSequence(tab[2]); m.setStructure(tab[3]); alobj.add(m); } br.close(); br2.close(); // load data DataSource source = new DataSource(testarff.toString()); Instances data = source.getDataSet(); if (data.classIndex() == -1) { data.setClassIndex(data.numAttributes() - 1); } //remove ID row data.deleteAttributeAt(0); //load model Classifier model = (Classifier) weka.core.SerializationHelper.read(classifier); // evaluate dataset on the model Evaluation eval = new Evaluation(data); eval.evaluateModel(model, data); FastVector fv = eval.predictions(); // output PrintWriter pw = new PrintWriter(new FileWriter(predictionsFile + "." + classifier + ".miRdup.txt")); PrintWriter pwt = new PrintWriter( new FileWriter(predictionsFile + "." + classifier + ".miRdup.tab.txt")); PrintWriter pwout = new PrintWriter( new FileWriter(predictionsFile + "." 
+ classifier + ".miRdupOutput.txt")); for (int i = 0; i < fv.size(); i++) { //System.out.println(fv.elementAt(i).toString()); String[] tab = fv.elementAt(i).toString().split(" "); int actual = Integer.valueOf(tab[1].substring(0, 1)); int predicted = Integer.valueOf(tab[2].substring(0, 1)); double score = 0.0; boolean validated = false; if (actual == predicted) { //case validated int s = tab[4].length(); try { score = Double.valueOf(tab[4]); //score = Double.valueOf(tab[4].substring(0, s - 1)); } catch (NumberFormatException numberFormatException) { score = 0.0; } validated = true; } else {// case not validated int s = tab[5].length(); try { score = Double.valueOf(tab[5]); //score = Double.valueOf(tab[5].substring(0, s - 1)); } catch (NumberFormatException numberFormatException) { score = 0.0; } validated = false; } MirnaObject m = alobj.get(i); m.setActual(actual); m.setPredicted(predicted); m.setScore(score); m.setValidated(validated); m.setNeedPrediction(predictMiRNA); String predictionMiRNA = ""; if (predictMiRNA && validated == false) { predictionMiRNA = miRdupPredictor.Predictor.predictionBySequence(m.getPrecursorSequence(), classifier, classifier + ".miRdupPrediction.txt"); try { m.setPredictedmiRNA(predictionMiRNA.split(",")[0]); m.setPredictedmiRNAstar(predictionMiRNA.split(",")[1]); } catch (Exception e) { m.setPredictedmiRNA(predictionMiRNA); m.setPredictedmiRNAstar(predictionMiRNA); } } pw.println(m.toStringFullPredictions()); pwt.println(m.toStringPredictions()); if (i % 100 == 0) { pw.flush(); pwt.flush(); } } //System.out.println(eval.toSummaryString("\nSummary results of predictions\n======\n", false)); String[] out = eval.toSummaryString("\nSummary results of predictions\n======\n", false).split("\n"); String info = out[0] + "\n" + out[1] + "\n" + out[2] + "\n" + out[4] + "\n" + out[5] + "\n" + out[6] + "\n" + out[7] + "\n" + out[11] + "\n"; System.out.println(info); //System.out.println("Predicted position of the miRNA by miRdup:"+predictionMiRNA); 
pwout.println( "File " + predictionsFile + " adapted in " + testarff + " submitted to model " + classifier); pwout.println(info); pw.flush(); pw.close(); pwt.flush(); pwt.close(); pwout.flush(); pwout.close(); System.out.println("Results in " + predictionsFile + "." + classifier + ".miRdup.txt"); // draw curve //rocCurve(eval); } catch (Exception e) { e.printStackTrace(); } }
From source file: miRdup.WekaModule.java
License: Open Source License
public static String testModel(File testarff, String classifier) { // System.out.println("Testing model on "+testarff+". Submitted to model "+classifier); try {/*w w w . ja va 2s. c o m*/ // load data DataSource source = new DataSource(testarff.toString()); Instances data = source.getDataSet(); if (data.classIndex() == -1) { data.setClassIndex(data.numAttributes() - 1); } //load model Classifier model = (Classifier) weka.core.SerializationHelper.read(classifier); // evaluate dataset on the model Evaluation eval = new Evaluation(data); eval.evaluateModel(model, data); FastVector fv = eval.predictions(); //calculate importants values String ev[] = eval.toClassDetailsString().split("\n"); String p = ev[3].trim(); String n = ev[4].trim(); double tp = Double.parseDouble(p.substring(0, 6).trim()); double fp = 0; try { fp = Double.parseDouble(p.substring(11, 16).trim()); } catch (Exception exception) { fp = Double.parseDouble(p.substring(7, 16).trim()); } double tn = Double.parseDouble(n.substring(0, 6).trim()); double fn = 0; try { fn = Double.parseDouble(n.substring(11, 16).trim()); } catch (Exception exception) { fn = Double.parseDouble(n.substring(7, 16).trim()); } //System.out.println("\nTP="+tp+"\nFP="+fp+"\nTN="+tn+"\nFN="+fn); //specificity, sensitivity, Mathew's correlation, Prediction accuracy double sp = ((tn) / (tn + fp)); double se = ((tp) / (tp + fn)); double acc = ((tp + tn) / (tp + tn + fp + fn)); double mcc = ((tp * tn) - (fp * fn)) / Math.sqrt((tp + fp) * (tn + fn) * (tp + fn) * tn + fp); // System.out.println("\nse="+se+"\nsp="+sp+"\nACC="+dec.format(acc).replace(",", ".")+"\nMCC="+dec.format(mcc).replace(",", ".")); // System.out.println(eval.toMatrixString()); String out = dec.format(acc).replace(",", "."); System.out.println(out); return out; } catch (Exception e) { e.printStackTrace(); return ""; } }