cs.man.ac.uk.classifiers.GetAUC.java Source code

Java tutorial

Introduction

Here is the source code for cs.man.ac.uk.classifiers.GetAUC.java

Source

/**
 *
 * This file is part of STFUD.
 *
 * STFUD is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * STFUD is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with STFUD.  If not, see <http://www.gnu.org/licenses/>.
 *
 * File name:    GetAUC.java
 * Package: cs.man.ac.uk.classifiers
 * Created:   Oct 23, 2013
 * Author:   Rob Lyon
 * 
 * Contact:   rob@scienceguyrob.com or robert.lyon@cs.man.ac.uk
 * Web:      <http://www.scienceguyrob.com> or <http://www.cs.manchester.ac.uk> 
 *          or <http://www.jb.man.ac.uk>
 */
package cs.man.ac.uk.classifiers;

import java.awt.*;
import java.io.*;
import java.util.*;
import cs.man.ac.uk.arff.ARFFFile;
import cs.man.ac.uk.classifiers.streamLearners.trees.HoeffdingTreeGaussianHellingerTester;
import cs.man.ac.uk.classifiers.streamLearners.trees.HoeffdingTreeTester;
import cs.man.ac.uk.common.Common;
import cs.man.ac.uk.log.DebugLogger;
import weka.core.*;
import weka.core.converters.ArffSaver;
import weka.classifiers.*;
import weka.classifiers.Evaluation;
import weka.classifiers.evaluation.*;
import weka.classifiers.trees.J48;
import weka.classifiers.trees.j48.Distribution;
import weka.gui.visualize.*;

/**
 * The class GetAUC was written to reproduce the results reported in the papers:
 * 
 * "Learning Decision Trees for Unbalanced Data", Cieslak, David. A. and Chawla, Nitesh. V.,
 * in Machine Learning and Knowledge Discovery in Databases (Daelemans, Goethals and Morik,
 * editors), vol. 5211 of Lecture notes in Computer Science, pp. 241-256, 2008.
 * 
 * "Hellinger distance decision trees are robust and skew-insensitive", Cieslak, David. A.,
 * Hoens, Ryan, Chawla, Nitesh. V. and Kegelmeyer, Philip, in Data Mining and Knowledge
 * Discovery, vol.24, issue 1, pp. 136-158, 2012. 
 * 
 * The main purpose of this class then is to obtain the area under the ROC curve (AUC) using
 * the algorithms and data sets reported in these papers. As neither of these papers provide
 * a full implementation of the Hellinger distance based decision tree algorithm, this class
 * is also used to verify that the implementation of the Hellinger decision tree provided here
 * is as described in these papers.
 * 
 * This class is intended to be used manually as an experimental test bed.
 * 
 * @author Rob Lyon
 *
 * @version 1.0, 10/23/13
 */
public class GetAUC {
    // *****************************************
    // *****************************************
    //             Variables
    // *****************************************
    // *****************************************

    /**
     * Root directory from which every data set, scratch and output path below is built.
     */
    private static String homeDirectory = "/Users/Rob";

    // Letter Data Set
    // NOTE(review): CLASS_INDEX is only passed to ARFFFile and HoeffdingTreeTester, whereas
    // getData() sets the Weka class index to the last attribute. Confirm which indexing
    // convention (0- or 1-based) those project classes expect.
    private static int CLASS_INDEX = 17;

    /**
     * The path to the Weka ARFF data set to use.
     */
    private static String dataPath = homeDirectory + "/Dropbox/ARFF/Letter/Letter_VowelsPositive_RestNegative.arff";

    /**
     * Scratch file each training fold is written to before being handed to the stream learner.
     */
    private static String trainPath = homeDirectory + "/Dropbox/ARFF/Letter/Scratch/LetterTrain.arff";

    /**
     * Template path for each test fold; the fold's class counts and balance are inserted before ".arff".
     */
    private static String testPath = homeDirectory + "/Dropbox/ARFF/Letter/Scratch/LetterTest.arff";

    /**
     * Scratch directory whose files are all deleted after every fold of validate5x2CVStream().
     */
    public static String scratch = homeDirectory + "/Dropbox/ARFF/Letter/Scratch";

    // Pen Data Set
    //private static int CLASS_INDEX = 17;
    //private static String dataPath = homeDirectory + "/Dropbox/ARFF/PenBasedRecognition_2Class_3_is_minority.arff";
    //private static String dataPath = homeDirectory + "/Dropbox/ARFF/PenBasedRecognition_2Class_5_is_minority.arff";

    // Statlog Landsat Data Set
    //private static int CLASS_INDEX = 37;
    //private static String dataPath = homeDirectory + "/Dropbox/ARFF/StatlogLandsat_2Class.arff";

    // Statlog Image Segmentation Data Set
    //private static int CLASS_INDEX = 20;
    //private static String dataPath = homeDirectory + "/Dropbox/ARFF/ImageSegmentation_2Class.arff";

    /**
     * Flag if set to true will cause the ROC curve to be displayed in a GUI.
     */
    private static boolean showGUI = false;

    /**
     * Flag if set to true will make validate5x2CVStream() block until a key is pressed on
     * standard input after each fold's training phase. This is useful only when stepping
     * through an experiment by hand; leave it false for unattended runs.
     */
    private static boolean pauseAfterTraining = false;

    /**
     * The data to use during validation.
     */
    private static Instances data = null;

    /**
     * The ROC curve panel used for visualization; populated by validate(Classifier).
     */
    private static ThresholdVisualizePanel vmc = null;

    // *****************************************
    // *****************************************
    //             Main
    // *****************************************
    // *****************************************

    /**
     * Main method that executes the calls that calculate the AUC for the supplied data
     * set and learning algorithm, then reports whether it lies within the interval
     * reported in the reference papers.
     * @param args unused arguments.
     */
    public static void main(String[] args) {
        // Declare classifier to test (only used by validate(learner), which is disabled below).
        @SuppressWarnings("unused")
        Classifier learner = new J48();

        if (getData()) {
            //double AUC = validate(learner);
            double AUC = validate5x2CVStream();

            // ---------------------------------------------------------------
            // C4.5 bounds
            // ---------------------------------------------------------------

            // Letter dataset
            //double plus_or_minus = 0.004; // Error margin reported in paper.
            //double reportedAUC = 0.990; // AUC reported in paper.

            // Pen digits
            //double plus_or_minus = 0.005; // Error margin reported in paper.
            //double reportedAUC = 0.985; // AUC reported in paper.

            // Satellite image
            //double plus_or_minus = 0.009; // Error margin reported in paper.
            //double reportedAUC = 0.906; // AUC reported in paper.

            // Image segmentation
            //double plus_or_minus = 0.006; // Error margin reported in paper.
            //double reportedAUC = 0.982; // AUC reported in paper.

            // ---------------------------------------------------------------
            // HDDT bounds
            // ---------------------------------------------------------------

            // Letter dataset
            double plus_or_minus = 0.004; // Error margin reported in paper.
            double reportedAUC = 0.990; // AUC reported in paper.

            // Pen digits
            //double plus_or_minus = 0.002; // Error margin reported in paper.
            //double reportedAUC = 0.992; // AUC reported in paper.

            // Satellite image
            //double plus_or_minus = 0.007; // Error margin reported in paper.
            //double reportedAUC = 0.911; // AUC reported in paper.

            // Image segmentation
            //double plus_or_minus = 0.007; // Error margin reported in paper.
            //double reportedAUC = 0.984; // AUC reported in paper.

            double lowerBound = reportedAUC - plus_or_minus;
            double upperBound = reportedAUC + plus_or_minus;

            boolean withinBounds = inInterval(lowerBound, upperBound, AUC);

            System.out.println("Area under curve: " + AUC + " Interval [" + lowerBound + "," + upperBound
                    + "] in interval: " + withinBounds);

            if (showGUI) {
                displayGUI();
            }
        }
    }

    /**
     * Loads the data set stored at the path stored in dataPath, using the last attribute
     * as the class attribute. The underlying reader is always closed.
     * @return true if the data set was successfully loaded, else false.
     */
    private static boolean getData() {
        BufferedReader reader = null;
        try {
            reader = new BufferedReader(new FileReader(dataPath));
            data = new Instances(reader);
            data.setClassIndex(data.numAttributes() - 1);
            System.out.println("Data set loaded from: " + dataPath);
            return true;
        } catch (FileNotFoundException e) {
            System.out.println("Could not load data set! No data set file at: " + dataPath);
            return false;
        } catch (IOException e) {
            System.out.println("Could not load data set! IOException reading file at: " + dataPath);
            return false;
        } finally {
            if (reader != null) {
                try {
                    reader.close();
                } catch (IOException ignored) {
                    // Best effort: the data (if any) has already been parsed into memory.
                }
            }
        }
    }

    /**
     * Runs 5x2 cross-validation with the Hoeffding tree stream learner. Each fold is written
     * to ARFF files in the scratch directory, which is emptied again after the fold.
     *
     * NOTE(review): AUC_SUM is never incremented because the getROCExternalData(...) call
     * below is commented out, so this method currently always returns 0. Restore an AUC
     * accumulation before relying on the returned value.
     *
     * @return the mean AUC over all runs and folds (currently always 0, see note above),
     *         or 0 if an exception occurs.
     */
    private static double validate5x2CVStream() {
        try {
            // Other options
            int runs = 5;
            int folds = 2;
            double AUC_SUM = 0;

            // perform cross-validation
            for (int i = 0; i < runs; i++) {
                // Each run uses a different but reproducible shuffle (seeds 1..runs).
                int seed = i + 1;
                Random rand = new Random(seed);
                Instances randData = new Instances(data);
                randData.randomize(rand);

                if (randData.classAttribute().isNominal()) {
                    System.out.println("Stratifying...");
                    randData.stratify(folds);
                }

                for (int n = 0; n < folds; n++) {
                    Instances train = randData.trainCV(folds, n);
                    Instances test = randData.testCV(folds, n);

                    Distribution testDistribution = new Distribution(test);

                    // The stream learner is given file paths, so each fold is written to disk.
                    ArffSaver trainSaver = new ArffSaver();
                    trainSaver.setInstances(train);
                    trainSaver.setFile(new File(trainPath));
                    trainSaver.writeBatch();

                    ArffSaver testSaver = new ArffSaver();
                    testSaver.setInstances(test);

                    // Test-fold class counts: class 0 is treated as negative, class 1 as positive.
                    double[][] dist = testDistribution.matrix();
                    int negativeClassSize = (int) dist[0][0];
                    int positiveClassSize = (int) dist[0][1];
                    double balance = (double) positiveClassSize / (double) negativeClassSize;

                    String tempTestPath = testPath.replace(".arff",
                            "_" + positiveClassSize + "_" + negativeClassSize + "_" + balance + "_1.0.arff");// [Test-n-Set-n]_[+]_[-]_[K]_[L];
                    testSaver.setFile(new File(tempTestPath));
                    testSaver.writeBatch();

                    ARFFFile file = new ARFFFile(tempTestPath, CLASS_INDEX, new DebugLogger(false));
                    file.createMetaData();

                    HoeffdingTreeTester streamClassifier = new HoeffdingTreeTester(trainPath, tempTestPath,
                            CLASS_INDEX, new String[] { "0", "1" }, new DebugLogger(true));

                    streamClassifier.train();

                    // Optional manual checkpoint; previously this always blocked on stdin.
                    if (pauseAfterTraining) {
                        System.in.read();
                    }

                    //AUC_SUM += streamClassifier.getROCExternalData("",(int)testDistribution.perClass(1),(int)testDistribution.perClass(0));
                    streamClassifier.testStatic(homeDirectory + "/FuckSakeTest.txt");

                    // Remove this fold's train/test files before the next fold.
                    String[] files = Common.getFilePaths(scratch);
                    for (int j = 0; j < files.length; j++) {
                        Common.fileDelete(files[j]);
                    }
                }
            }

            return AUC_SUM / ((double) runs * (double) folds);
        } catch (Exception e) {
            System.out.println("Exception validating data!");
            e.printStackTrace();
            return 0;
        }
    }

    /**
     * Computes the mean per-fold AUC of an unpruned J48 tree (options -U -A) under 5x2
     * cross-validation.
     *
     * NOTE(review): the ROC curve is computed for class value 0, whereas
     * validate5x2CVStream() labels class 1 as positive; confirm which is intended.
     *
     * @return the mean AUC over all runs and folds, or 0 if an exception occurs.
     */
    @SuppressWarnings("unused")
    private static double validate5x2CV() {
        try {
            // other options
            int runs = 5;
            int folds = 2;
            double AUC_SUM = 0;

            // perform cross-validation
            for (int i = 0; i < runs; i++) {
                // Each run uses a different but reproducible shuffle (seeds 1..runs).
                int seed = i + 1;
                Random rand = new Random(seed);
                Instances randData = new Instances(data);
                randData.randomize(rand);

                if (randData.classAttribute().isNominal()) {
                    System.out.println("Stratifying...");
                    randData.stratify(folds);
                }

                for (int n = 0; n < folds; n++) {
                    Instances train = randData.trainCV(folds, n);
                    Instances test = randData.testCV(folds, n);

                    // the above code is used by the StratifiedRemoveFolds filter, the
                    // code below by the Explorer/Experimenter:
                    // Instances train = randData.trainCV(folds, n, rand);

                    // build and evaluate classifier
                    String[] options = { "-U", "-A" };
                    J48 classifier = new J48();
                    //HTree classifier = new HTree();

                    classifier.setOptions(options);
                    classifier.buildClassifier(train);

                    // Use a fresh Evaluation per fold: Evaluation accumulates predictions, so a
                    // shared instance made later folds' "per-fold" AUC include earlier folds.
                    Evaluation eval = new Evaluation(randData);
                    eval.evaluateModel(classifier, test);

                    // generate curve
                    ThresholdCurve tc = new ThresholdCurve();
                    int classIndex = 0;
                    Instances result = tc.getCurve(eval.predictions(), classIndex);

                    double foldAUC = ThresholdCurve.getROCArea(result);
                    AUC_SUM += foldAUC;
                    System.out.println("AUC: " + foldAUC + " \tAUC SUM: " + AUC_SUM);
                }
            }

            return AUC_SUM / ((double) runs * (double) folds);
        } catch (Exception e) {
            System.out.println("Exception validating data!");
            e.printStackTrace();
            return 0;
        }
    }

    /**
     * Computes the AUC for the supplied learner using 2-fold cross-validation (seed 1), and
     * builds the ROC curve panel that displayGUI() shows.
     * @param learner the learning algorithm to use.
     * @return the AUC as a double value, or 0 if an exception occurs.
     */
    @SuppressWarnings("unused")
    private static double validate(Classifier learner) {
        try {
            Evaluation eval = new Evaluation(data);
            eval.crossValidateModel(learner, data, 2, new Random(1));

            // generate curve
            ThresholdCurve tc = new ThresholdCurve();
            int classIndex = 0;
            Instances result = tc.getCurve(eval.predictions(), classIndex);
            double AUC = ThresholdCurve.getROCArea(result);

            // plot curve
            vmc = new ThresholdVisualizePanel();
            vmc.setROCString("(Area under ROC = " + Utils.doubleToString(AUC, 9) + ")");
            vmc.setName(result.relationName());

            PlotData2D tempd = new PlotData2D(result);
            tempd.setPlotName(result.relationName());
            tempd.addInstanceNumberAttribute();

            // Connect each point to its predecessor (the first point has none).
            boolean[] cp = new boolean[result.numInstances()];
            for (int n = 1; n < cp.length; n++) {
                cp[n] = true;
            }

            tempd.setConnectPoints(cp);
            // add plot
            vmc.addPlot(tempd);

            return AUC;
        } catch (Exception e) {
            System.out.println("Exception validating data!");
            e.printStackTrace();
            return 0;
        }
    }

    /**
     * Displays a JFrame with a ROC curve and details of the AUC. Prints a message instead
     * if no ROC panel has been built yet (see validate(Classifier)).
     */
    private static void displayGUI() {
        if (vmc != null) {
            // display curve
            String plotName = vmc.getName();
            final javax.swing.JFrame jf = new javax.swing.JFrame("Weka Classifier Visualize: " + plotName);
            jf.setSize(500, 400);
            jf.getContentPane().setLayout(new BorderLayout());
            jf.getContentPane().add(vmc, BorderLayout.CENTER);

            jf.addWindowListener(new java.awt.event.WindowAdapter() {
                public void windowClosing(java.awt.event.WindowEvent e) {
                    jf.dispose();
                }
            });
            jf.setVisible(true);
        } else {
            System.out.println("Unable to display GUI as Threshold panel not initialised!");
        }
    }

    /**
     * Tests whether a value lies in the closed interval [lowerBound, upperBound].
     * Comparisons use Double.compare, so NaN is never in the interval and -0.0 is
     * considered smaller than 0.0. A value equal to either bound is always accepted.
     * @param lowerBound the inclusive lower bound.
     * @param upperBound the inclusive upper bound.
     * @param value the value to test.
     * @return true if value equals either bound or lies strictly between them, else false.
     */
    public static boolean inInterval(double lowerBound, double upperBound, double value) {
        int lowerResult = Double.compare(value, lowerBound);
        int upperResult = Double.compare(value, upperBound);

        return lowerResult == 0 || upperResult == 0 || (lowerResult > 0 && upperResult < 0);
    }
}