Java tutorial
/**
 *
 * This file is part of STFUD.
 *
 * STFUD is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * STFUD is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with STFUD. If not, see <http://www.gnu.org/licenses/>.
 *
 * File name: GetAUC.java
 * Package: cs.man.ac.uk.classifiers
 * Created: Oct 23, 2013
 * Author: Rob Lyon
 *
 * Contact: rob@scienceguyrob.com or robert.lyon@cs.man.ac.uk
 * Web: <http://www.scienceguyrob.com> or <http://www.cs.manchester.ac.uk>
 *      or <http://www.jb.man.ac.uk>
 */
package cs.man.ac.uk.classifiers;

import java.awt.*;
import java.io.*;
import java.util.*;

import cs.man.ac.uk.arff.ARFFFile;
import cs.man.ac.uk.classifiers.streamLearners.trees.HoeffdingTreeGaussianHellingerTester;
import cs.man.ac.uk.classifiers.streamLearners.trees.HoeffdingTreeTester;
import cs.man.ac.uk.common.Common;
import cs.man.ac.uk.log.DebugLogger;
import weka.core.*;
import weka.core.converters.ArffSaver;
import weka.classifiers.*;
import weka.classifiers.Evaluation;
import weka.classifiers.evaluation.*;
import weka.classifiers.trees.J48;
import weka.classifiers.trees.j48.Distribution;
import weka.gui.visualize.*;

/**
 * The class GetAUC was written to reproduce the results reported in the papers:
 *
 * "Learning Decision Trees for Unbalanced Data", Cieslak, David. A. and Chawla, Nitesh. V.,
 * in Machine Learning and Knowledge Discovery in Databases (Daelemans, Goethals and Morik,
 * editors), vol. 5211 of Lecture notes in Computer Science, pp. 241-256, 2008.
 *
 * "Hellinger distance decision trees are robust and skew-insensitive", Cieslak, David. A.,
 * Hoens, Ryan, Chawla, Nitesh. V. and Kegelmeyer, Philip, in Data Mining and Knowledge
 * Discovery, vol.24, issue 1, pp. 136-158, 2012.
 *
 * The main purpose of this class, then, is to obtain the area under the ROC curve (AUC) using
 * the algorithms and data sets reported in these papers. As neither of these papers provides
 * a full implementation of the Hellinger distance based decision tree algorithm, this class
 * is also used to verify that the implementation of the Hellinger decision tree provided here
 * matches the description in these papers.
 *
 * This class is intended to be used manually as an experimental test bed: the data set and
 * the expected AUC bounds are chosen by commenting / uncommenting the alternatives below and
 * in {@link #main(String[])}.
 *
 * @author Rob Lyon
 *
 * @version 1.0, 10/23/13
 */
public class GetAUC {

    // *****************************************
    // *****************************************
    // Variables
    // *****************************************
    // *****************************************

    /**
     * Root directory under which the data sets and scratch files are located.
     */
    private static String homeDirectory = "/Users/Rob";

    // Letter Data Set

    /**
     * Class index passed to the stream-learning tools ({@link ARFFFile} and
     * {@link HoeffdingTreeTester}). NOTE(review): {@link #getData()} sets the Weka class
     * index to the last attribute instead; this value may follow a different (e.g. 1-based)
     * convention - confirm.
     */
    private static int CLASS_INDEX = 17;

    /**
     * The path to the Weka ARFF data set to use.
     */
    private static String dataPath = homeDirectory + "/Dropbox/ARFF/Letter/Letter_VowelsPositive_RestNegative.arff";

    /**
     * File the training partition of each fold is written to for the stream learner.
     */
    private static String trainPath = homeDirectory + "/Dropbox/ARFF/Letter/Scratch/LetterTrain.arff";

    /**
     * Template path for the test partition of each fold; the ".arff" suffix is replaced by a
     * name encoding the class counts before the file is written.
     */
    private static String testPath = homeDirectory + "/Dropbox/ARFF/Letter/Scratch/LetterTest.arff";

    /**
     * Scratch directory whose files are deleted after every fold of the stream experiment.
     */
    public static String scratch = homeDirectory + "/Dropbox/ARFF/Letter/Scratch";

    // Pen Data Set
    //private static int CLASS_INDEX = 17;
    //private static String dataPath = homeDirectory + "/Dropbox/ARFF/PenBasedRecognition_2Class_3_is_minority.arff";
    //private static String dataPath = homeDirectory + "/Dropbox/ARFF/PenBasedRecognition_2Class_5_is_minority.arff";

    // Statlog Landsat Data Set
    //private static int CLASS_INDEX = 37;
    //private static String dataPath = homeDirectory + "/Dropbox/ARFF/StatlogLandsat_2Class.arff";

    // Statlog Image Segmentation Data Set
    //private static int CLASS_INDEX = 20;
    //private static String dataPath = homeDirectory + "/Dropbox/ARFF/ImageSegmentation_2Class.arff";

    /**
     * Flag if set to true will cause the ROC curve to be displayed in a GUI.
     */
    private static boolean showGUI = false;

    /**
     * The data to use during validation.
     */
    private static Instances data = null;

    /**
     * The ROC curve panel used for visualization and computing the AUC.
     */
    private static ThresholdVisualizePanel vmc = null;

    // *****************************************
    // *****************************************
    // Main
    // *****************************************
    // *****************************************

    /**
     * Main method that executes the calls that calculate the AUC for the supplied data
     * set and learning algorithm, then prints whether that AUC lies within the interval
     * reported in the papers (reported AUC +/- reported error margin).
     * @param args unused arguments.
     */
    public static void main(String[] args) {
        // Declare classifier to test (only used by the commented-out call to validate()).
        @SuppressWarnings("unused")
        Classifier learner = new J48();

        if (getData()) {
            //double AUC = validate(learner);
            double AUC = validate5x2CVStream();

            /* C4.5 Bounds */

            // Letter dataset
            //double plus_or_minus = 0.004; // Error margin reported in paper.
            //double reportedAUC = 0.990;   // AUC reported in paper.

            // Pen digits
            //double plus_or_minus = 0.005; // Error margin reported in paper.
            //double reportedAUC = 0.985;   // AUC reported in paper.

            // Satellite image
            //double plus_or_minus = 0.009; // Error margin reported in paper.
            //double reportedAUC = 0.906;   // AUC reported in paper.

            // Image segmentation
            //double plus_or_minus = 0.006; // Error margin reported in paper.
            //double reportedAUC = 0.982;   // AUC reported in paper.

            /* HDDT Bounds */

            // Letter dataset
            double plus_or_minus = 0.004; // Error margin reported in paper.
            double reportedAUC = 0.990;   // AUC reported in paper.

            // Pen digits
            //double plus_or_minus = 0.002; // Error margin reported in paper.
            //double reportedAUC = 0.992;   // AUC reported in paper.

            // Satellite image
            //double plus_or_minus = 0.007; // Error margin reported in paper.
            //double reportedAUC = 0.911;   // AUC reported in paper.

            // Image segmentation
            //double plus_or_minus = 0.007; // Error margin reported in paper.
            //double reportedAUC = 0.984;   // AUC reported in paper.

            double lowerBound = reportedAUC - plus_or_minus;
            double upperBound = reportedAUC + plus_or_minus;
            boolean inInterval = inInterval(lowerBound, upperBound, AUC);

            System.out.println("Area under curve: " + AUC + " Interval [" + lowerBound + "," + upperBound + "] in interval: " + inInterval);

            if (showGUI) {
                displayGUI();
            }
        }
    }

    /**
     * Loads the data set stored at the path stored in dataPath and makes its last
     * attribute the class attribute. The reader is always closed, even on failure.
     * @return true if the data set was successfully loaded, else false.
     */
    private static boolean getData() {
        BufferedReader reader = null;
        try {
            reader = new BufferedReader(new FileReader(dataPath));
            data = new Instances(reader);
            data.setClassIndex(data.numAttributes() - 1);
            System.out.println("Data set loaded from: " + dataPath);
            return true;
        } catch (FileNotFoundException e) {
            System.out.println("Could not load data set! No data set file at: " + dataPath);
            return false;
        } catch (IOException e) {
            System.out.println("Could not load data set! IOException reading file at: " + dataPath);
            return false;
        } finally {
            if (reader != null) {
                try {
                    reader.close();
                } catch (IOException ignored) {
                    // Nothing useful can be done if closing a read-only file fails.
                }
            }
        }
    }

    /**
     * Runs 5x2 cross-validation for the stream learner ({@link HoeffdingTreeTester}).
     * For every fold the training partition is written to trainPath, the test partition is
     * written to a file derived from testPath whose name encodes the positive count, negative
     * count and their ratio, the stream learner is trained and tested on those files, and the
     * scratch directory is emptied.
     *
     * NOTE(review): the AUC sum is never incremented because the statement that would add the
     * stream learner's AUC is commented out, so this method currently always returns 0.
     *
     * @return the mean AUC over all runs and folds (see note above), or 0 on error.
     */
    private static double validate5x2CVStream() {
        try {
            // Other options
            int runs = 5;
            int folds = 2;
            double aucSum = 0;

            // perform cross-validation
            for (int i = 0; i < runs; i++) {
                // Randomize a copy of the data with a per-run seed so runs are reproducible.
                int seed = i + 1;
                Random rand = new Random(seed);
                Instances randData = new Instances(data);
                randData.randomize(rand);

                if (randData.classAttribute().isNominal()) {
                    System.out.println("Stratifying...");
                    randData.stratify(folds);
                }

                for (int n = 0; n < folds; n++) {
                    Instances train = randData.trainCV(folds, n);
                    Instances test = randData.testCV(folds, n);
                    Distribution testDistribution = new Distribution(test);

                    ArffSaver trainSaver = new ArffSaver();
                    trainSaver.setInstances(train);
                    trainSaver.setFile(new File(trainPath));
                    trainSaver.writeBatch();

                    ArffSaver testSaver = new ArffSaver();
                    testSaver.setInstances(test);

                    // NOTE(review): assumes row 0 of the matrix holds the per-class counts of the
                    // whole test set, with class value 0 negative and 1 positive - confirm.
                    double[][] dist = testDistribution.matrix();
                    int negativeClassSize = (int) dist[0][0];
                    int positiveClassSize = (int) dist[0][1];
                    double balance = (double) positiveClassSize / (double) negativeClassSize;

                    // File name pattern: [Test-n-Set-n]_[+]_[-]_[K]_[L]
                    String tempTestPath = testPath.replace(".arff", "_" + positiveClassSize + "_" + negativeClassSize + "_" + balance + "_1.0.arff");
                    testSaver.setFile(new File(tempTestPath));
                    testSaver.writeBatch();

                    ARFFFile file = new ARFFFile(tempTestPath, CLASS_INDEX, new DebugLogger(false));
                    file.createMetaData();

                    HoeffdingTreeTester streamClassifier = new HoeffdingTreeTester(trainPath, tempTestPath, CLASS_INDEX, new String[] { "0", "1" }, new DebugLogger(true));
                    streamClassifier.train();

                    // NOTE(review): blocks until a byte is available on standard input (manual
                    // pause left over from debugging); no prompt is printed.
                    System.in.read();

                    //aucSum += streamClassifier.getROCExternalData("",(int)testDistribution.perClass(1),(int)testDistribution.perClass(0));
                    streamClassifier.testStatic(homeDirectory + "/FuckSakeTest.txt");

                    // Empty the scratch directory before the next fold.
                    String[] files = Common.getFilePaths(scratch);
                    for (int j = 0; j < files.length; j++) {
                        Common.fileDelete(files[j]);
                    }
                }
            }

            return aucSum / ((double) runs * (double) folds);
        } catch (Exception e) {
            System.out.println("Exception validating data!");
            e.printStackTrace();
            return 0;
        }
    }

    /**
     * Computes the mean AUC of a J48 tree (options "-U" and "-A") over 5x2 cross-validation.
     * Each fold is scored with its own {@link Evaluation}, so every per-fold AUC is computed
     * only from that fold's test predictions.
     * @return the mean per-fold AUC, or 0 if an exception occurred.
     */
    @SuppressWarnings("unused")
    private static double validate5x2CV() {
        try {
            // other options
            int runs = 5;
            int folds = 2;
            double aucSum = 0;

            // perform cross-validation
            for (int i = 0; i < runs; i++) {
                // Randomize a copy of the data with a per-run seed so runs are reproducible.
                int seed = i + 1;
                Random rand = new Random(seed);
                Instances randData = new Instances(data);
                randData.randomize(rand);

                if (randData.classAttribute().isNominal()) {
                    System.out.println("Stratifying...");
                    randData.stratify(folds);
                }

                for (int n = 0; n < folds; n++) {
                    Instances train = randData.trainCV(folds, n);
                    Instances test = randData.testCV(folds, n);
                    // the above code is used by the StratifiedRemoveFolds filter, the
                    // code below by the Explorer/Experimenter:
                    // Instances train = randData.trainCV(folds, n, rand);

                    // A fresh Evaluation per fold: predictions() returns every prediction the
                    // Evaluation has collected, so a shared instance would mix folds together.
                    Evaluation eval = new Evaluation(randData);

                    // build and evaluate classifier
                    String[] options = { "-U", "-A" };
                    J48 classifier = new J48();
                    //HTree classifier = new HTree();
                    classifier.setOptions(options);
                    classifier.buildClassifier(train);
                    eval.evaluateModel(classifier, test);

                    // generate curve
                    ThresholdCurve tc = new ThresholdCurve();
                    int classIndex = 0;
                    Instances result = tc.getCurve(eval.predictions(), classIndex);

                    // NOTE(review): replaces the shared panel with an empty one; no plot is added here.
                    vmc = new ThresholdVisualizePanel();

                    double foldAUC = ThresholdCurve.getROCArea(result);
                    aucSum += foldAUC;
                    System.out.println("AUC: " + foldAUC + " \tAUC SUM: " + aucSum);
                }
            }

            return aucSum / ((double) runs * (double) folds);
        } catch (Exception e) {
            System.out.println("Exception validating data!");
            e.printStackTrace();
            return 0;
        }
    }

    /**
     * Computes the AUC for the supplied learner using 2-fold cross-validation (seed 1) and
     * prepares the ROC curve panel used by {@link #displayGUI()}.
     * @param learner the learning algorithm to use.
     * @return the AUC as a double value, or 0 if an exception occurred.
     */
    @SuppressWarnings("unused")
    private static double validate(Classifier learner) {
        try {
            Evaluation eval = new Evaluation(data);
            eval.crossValidateModel(learner, data, 2, new Random(1));

            // generate curve
            ThresholdCurve tc = new ThresholdCurve();
            int classIndex = 0;
            Instances result = tc.getCurve(eval.predictions(), classIndex);

            // plot curve
            vmc = new ThresholdVisualizePanel();
            double AUC = ThresholdCurve.getROCArea(result);
            vmc.setROCString("(Area under ROC = " + Utils.doubleToString(AUC, 9) + ")");
            vmc.setName(result.relationName());

            PlotData2D tempd = new PlotData2D(result);
            tempd.setPlotName(result.relationName());
            tempd.addInstanceNumberAttribute();

            // Connect every point to its predecessor; the first point has none.
            boolean[] cp = new boolean[result.numInstances()];
            for (int n = 1; n < cp.length; n++) {
                cp[n] = true;
            }
            tempd.setConnectPoints(cp);

            // add plot
            vmc.addPlot(tempd);

            return AUC;
        } catch (Exception e) {
            System.out.println("Exception validating data!");
            e.printStackTrace();
            return 0;
        }
    }

    /**
     * Displays a JFrame with a ROC curve and details of the AUC, if the threshold panel
     * has been created; otherwise prints a message.
     */
    private static void displayGUI() {
        if (vmc == null) {
            System.out.println("Unable to display GUI as Threshold panel not initialised!");
            return;
        }

        // display curve
        String plotName = vmc.getName();
        final javax.swing.JFrame jf = new javax.swing.JFrame("Weka Classifier Visualize: " + plotName);
        jf.setSize(500, 400);
        jf.getContentPane().setLayout(new BorderLayout());
        jf.getContentPane().add(vmc, BorderLayout.CENTER);
        jf.addWindowListener(new java.awt.event.WindowAdapter() {
            public void windowClosing(java.awt.event.WindowEvent e) {
                jf.dispose();
            }
        });
        jf.setVisible(true);
    }

    /**
     * Tests whether a value lies in the closed interval [lowerBound, upperBound].
     * Comparisons use {@link Double#compare(double, double)}, so a value equal to either
     * bound is always accepted, and a NaN value is rejected unless a bound is also NaN.
     * @param lowerBound the lower bound of the interval (inclusive).
     * @param upperBound the upper bound of the interval (inclusive).
     * @param value the value to test.
     * @return true if the value equals a bound or lies strictly between the bounds, else false.
     */
    public static boolean inInterval(double lowerBound, double upperBound, double value) {
        int lowerResult = Double.compare(value, lowerBound);
        int upperResult = Double.compare(value, upperBound);

        // Value equal to either of the bounds.
        if (lowerResult == 0 || upperResult == 0) {
            return true;
        }

        // Value greater than lower bound and lower than upper bound.
        return lowerResult > 0 && upperResult < 0;
    }
}