Java tutorial
/* * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program. If not, see <http://www.gnu.org/licenses/>. */ /* * MarginCurve.java * Copyright (C) 2002-2012 University of Waikato, Hamilton, New Zealand * */ package weka.classifiers.evaluation; import java.util.ArrayList; import weka.core.Attribute; import weka.core.DenseInstance; import weka.core.Instance; import weka.core.Instances; import weka.core.RevisionHandler; import weka.core.RevisionUtils; import weka.core.Utils; /** * Generates points illustrating the prediction margin. The margin is defined as * the difference between the probability predicted for the actual class and the * highest probability predicted for the other classes. One hypothesis as to the * good performance of boosting algorithms is that they increaes the margins on * the training data and this gives better performance on test data. * * @author Len Trigg (len@reeltwo.com) * @version $Revision$ */ public class MarginCurve implements RevisionHandler { /** * Calculates the cumulative margin distribution for the set of predictions, * returning the result as a set of Instances. The structure of these * Instances is as follows: * <p> * <ul> * <li><b>Margin</b> contains the margin value (which should be plotted as an * x-coordinate) * <li><b>Current</b> contains the count of instances with the current margin * (plot as y axis) * <li><b>Cumulative</b> contains the count of instances with margin less than * or equal to the current margin (plot as y axis) * </ul> * <p> * * @return datapoints as a set of instances, null if no predictions have been * made. */ public Instances getCurve(ArrayList<Prediction> predictions) { if (predictions.size() == 0) { return null; } Instances insts = makeHeader(); double[] margins = getMargins(predictions); int[] sorted = Utils.sort(margins); int binMargin = 0; int totalMargin = 0; insts.add(makeInstance(-1, binMargin, totalMargin)); for (int element : sorted) { double current = margins[element]; double weight = ((NominalPrediction) predictions.get(element)).weight(); totalMargin += weight; binMargin += weight; if (true) { insts.add(makeInstance(current, binMargin, totalMargin)); binMargin = 0; } } return insts; } /** * Pulls all the margin values out of a vector of NominalPredictions. * * @param predictions a FastVector containing NominalPredictions * @return an array of margin values. */ private double[] getMargins(ArrayList<Prediction> predictions) { // sort by predicted probability of the desired class. double[] margins = new double[predictions.size()]; for (int i = 0; i < margins.length; i++) { NominalPrediction pred = (NominalPrediction) predictions.get(i); margins[i] = pred.margin(); } return margins; } /** * Creates an Instances object with the attributes we will be calculating. * * @return the Instances structure. */ private Instances makeHeader() { ArrayList<Attribute> fv = new ArrayList<Attribute>(); fv.add(new Attribute("Margin")); fv.add(new Attribute("Current")); fv.add(new Attribute("Cumulative")); return new Instances("MarginCurve", fv, 100); } /** * Creates an Instance object with the attributes calculated. * * @param margin the margin for this data point. * @param current the number of instances with this margin. * @param cumulative the number of instances with margin less than or equal to * this margin. * @return the Instance object. */ private Instance makeInstance(double margin, int current, int cumulative) { int count = 0; double[] vals = new double[3]; vals[count++] = margin; vals[count++] = current; vals[count++] = cumulative; return new DenseInstance(1.0, vals); } /** * Returns the revision string. * * @return the revision */ @Override public String getRevision() { return RevisionUtils.extract("$Revision$"); } /** * Tests the MarginCurve generation from the command line. The classifier is * currently hardcoded. Pipe in an arff file. * * @param args currently ignored */ public static void main(String[] args) { try { Utils.SMALL = 0; Instances inst = new Instances(new java.io.InputStreamReader(System.in)); inst.setClassIndex(inst.numAttributes() - 1); MarginCurve tc = new MarginCurve(); EvaluationUtils eu = new EvaluationUtils(); weka.classifiers.meta.LogitBoost classifier = new weka.classifiers.meta.LogitBoost(); classifier.setNumIterations(20); ArrayList<Prediction> predictions = eu.getTrainTestPredictions(classifier, inst, inst); Instances result = tc.getCurve(predictions); System.out.println(result); } catch (Exception ex) { ex.printStackTrace(); } } }