Source code listing for cs.man.ac.uk.predict.Predictor.java (Java example).

The complete source of the Predictor class follows below.

/**
 *
 * This file is part of STFUD.
 *
 * STFUD is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * STFUD is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with STFUD.  If not, see <http://www.gnu.org/licenses/>.
 *
 * File name:    Predictor.java
 * Package: cs.man.ac.uk.predict
 * Created:   October 8, 2013
 * Author:   Rob Lyon
 * 
 * Contact:   rob@scienceguyrob.com or robert.lyon@cs.man.ac.uk
 * Web:      <http://www.scienceguyrob.com> or <http://www.cs.manchester.ac.uk> 
 *          or <http://www.jb.man.ac.uk>
 */
package cs.man.ac.uk.predict;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.Vector;

import cs.man.ac.uk.common.Common;
import cs.man.ac.uk.io.Writer;
import cs.man.ac.uk.stats.Cast;
import weka.classifiers.bayes.NaiveBayes;
import weka.classifiers.functions.MultilayerPerceptron;
import weka.classifiers.functions.SMO;
import weka.classifiers.trees.J48;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.converters.ArffLoader;

/**
 * 
 * The class Predictor is used to make predictions on real data, i.e. outside an experimental framework.
 * The key here is presenting the results succinctly i.e. presenting the positive predictions in a simple
 * way that is useful for further analysis.
 *
 * Some of the methods in this class have been written specifically to accommodate some of the problems
 * associated with classifying pulsar data (or large data sets in general). These include recording the
 * positive predictions made, and linking these predictions back to the original data files for each
 * instance. This is important for radio astronomy as we may need to go back to the raw data to understand
 * why an instance received a positive label.
 * 
 * @author Rob Lyon
 *
 * @version 1.0, 10/09/13
 */
public class Predictor {
    //*****************************************
    //*****************************************
    //              Variables
    //*****************************************
    //*****************************************

    /** Number of score attributes expected per ARFF data row (excluding the class label). */
    private static final int SCORE_COUNT = 22;

    // Zero-based column indices of the candidate "scores" within each ARFF data row.
    private static final int score1 = 0;
    private static final int score2 = 1;
    private static final int score3 = 2;
    private static final int score4 = 3;
    private static final int score5 = 4;
    private static final int score6 = 5;
    private static final int score7 = 6;
    private static final int score8 = 7;
    private static final int score9 = 8;
    private static final int score10 = 9;
    private static final int score11 = 10;
    private static final int score12 = 11;
    private static final int score13 = 12;
    private static final int score14 = 13;
    private static final int score15 = 14;
    private static final int score16 = 15;
    private static final int score17 = 16;
    private static final int score18 = 17;
    private static final int score19 = 18;
    private static final int score20 = 19;
    private static final int score21 = 20;
    private static final int score22 = 21;

    //*****************************************
    //*****************************************
    //              Ensemble
    //*****************************************
    //*****************************************

    /**
     * Executes the data processor.
     * @param args unused command line arguments.
     */
    public static void main(String[] args) {
        String testPath = "/Users/rob/Experiments/PULSAR/Candidates.unlabelled.arff";
        String resultPath = "/Users/rob/Experiments/PULSAR/RuleBasedPredictions.csv";

        makePredictionsRuleBased(testPath, resultPath);
    }

    /**
     * Labels the instances of an unlabelled ARFF test file using hand-derived association
     * rules (listed in the method body). The 1-based index of every instance satisfying the
     * rules is appended to the result file, linking predictions back to the source data.
     *
     * @param testPath   path to the unlabelled ARFF file to classify.
     * @param resultPath path of the file the positive prediction indices are appended to.
     */
    public static void makePredictionsRuleBased(String testPath, String resultPath) {
        // Firstly try to read the file in.
        File testFile = new File(testPath);

        if (!testFile.exists()) {
            System.out.println("One of the supplied file paths is invalid.");
            return;
        }

        BufferedReader in = null;

        try {
            // Open stream to file.
            in = new BufferedReader(new FileReader(testFile));

            String line;

            // Skip the ARFF header: everything up to and including the @data marker.
            while ((line = in.readLine()) != null)
                if (line.toLowerCase().contains("@data"))
                    break;

            // Now process data.
            int instanceNumber = 0;

            while ((line = in.readLine()) != null) {
                instanceNumber += 1;
                String[] components = line.split(",");
                double[] data = new double[SCORE_COUNT];

                // The last component is assumed to be the (unknown) class label and is skipped.
                // The bound on data.length guards against malformed rows with extra columns.
                for (int i = 0; i < components.length - 1 && i < data.length; i++)
                    data[i] = Cast.StringToDouble(components[i]);

                // Now check rules:

                /*
                 * 1. Score2='(0.097258-2.072206]' Score4='(-inf-1652]' 1281 ==> class=0 1262    conf:(0.99)
                 * 2. Score1='(105.605028-inf)' Score2='(0.097258-2.072206]' Score6='(0.885545-1.966381]' 1338 ==> class=0 1318    conf:(0.99)
                 * 3. Score2='(0.097258-2.072206]' Score8='(20.883587-inf)' 1330 ==> class=0 1310    conf:(0.98)
                 * 4. Score1='(105.605028-inf)' Score5='(-inf-21.247557]' Score6='(0.885545-1.966381]' 1292 ==> class=0 1272    conf:(0.98)
                 * 5. Score1='(105.605028-inf)' Score6='(0.885545-1.966381]' 1386 ==> class=0 1364    conf:(0.98)
                 * 6. Score1='(105.605028-inf)' Score2='(0.097258-2.072206]' Score5='(-inf-21.247557]' 1285 ==> class=0 1264    conf:(0.98)
                 * 7. Score2='(0.097258-2.072206]' Score5='(-inf-21.247557]' Score6='(0.885545-1.966381]' 1417 ==> class=0 1393    conf:(0.98)
                 * 8. Score2='(0.097258-2.072206]' Score6='(0.885545-1.966381]' 1514 ==> class=0 1488    conf:(0.98)
                 * 9. Score1='(105.605028-inf)' Score5='(-inf-21.247557]' 1330 ==> class=0 1307    conf:(0.98)
                 * 10. Score2='(0.097258-2.072206]' Score5='(-inf-21.247557]' 1466 ==> class=0 1438    conf:(0.98)
                 * 11. Score1='(105.605028-inf)' Score2='(0.097258-2.072206]' Score9='(1321.969072-inf)' 1488 ==> class=0 1456    conf:(0.98)
                 * 12. Score22='(-inf-4.252805]' 1381 ==> class=0 1351    conf:(0.98)
                 * 13. Score1='(105.605028-inf)' Score2='(0.097258-2.072206]' Score3='(8.5-inf)' Score9='(1321.969072-inf)' 1381 ==> class=0 1350    conf:(0.98)
                 * 14. Score2='(0.097258-2.072206]' Score9='(1321.969072-inf)' 1596 ==> class=0 1560    conf:(0.98)
                 * 15. Score2='(0.097258-2.072206]' Score3='(8.5-inf)' Score9='(1321.969072-inf)' 1396 ==> class=0 1364    conf:(0.98)
                 * 16. Score4='(-inf-1652]' Score6='(0.885545-1.966381]' 1468 ==> class=0 1434    conf:(0.98)
                 * 17. Score5='(-inf-21.247557]' Score6='(0.885545-1.966381]' Score9='(1321.969072-inf)' 1422 ==> class=0 1389    conf:(0.98)
                 * 18. Score4='(-inf-1652]' Score5='(-inf-21.247557]' Score6='(0.885545-1.966381]' 1416 ==> class=0 1383    conf:(0.98)
                 * 19. Score1='(105.605028-inf)' Score2='(0.097258-2.072206]' 1786 ==> class=0 1744    conf:(0.98)
                 * 20. Score4='(-inf-1652]' 1565 ==> class=0 1528    conf:(0.98)
                 * 21. Score1='(105.605028-inf)' Score3='(8.5-inf)' Score9='(1321.969072-inf)' 1395 ==> class=0 1362    conf:(0.98)
                 * 22. Score3='(8.5-inf)' Score9='(1321.969072-inf)' 1412 ==> class=0 1378    conf:(0.98)
                 * 23. Score6='(0.885545-1.966381]' Score9='(1321.969072-inf)' 1527 ==> class=0 1490    conf:(0.98)
                 * 24. Score4='(-inf-1652]' Score5='(-inf-21.247557]' 1435 ==> class=0 1400    conf:(0.98)
                 * 25. Score1='(105.605028-inf)' Score2='(0.097258-2.072206]' Score3='(8.5-inf)' 1674 ==> class=0 1633    conf:(0.98)
                 * 26. Score5='(-inf-21.247557]' Score9='(1321.969072-inf)' 1458 ==> class=0 1422    conf:(0.98)
                 * 27. Score2='(0.097258-2.072206]' Score3='(8.5-inf)' 1797 ==> class=0 1752    conf:(0.97)
                 * 28. Score2='(0.097258-2.072206]' 2079 ==> class=0 2025    conf:(0.97)
                 * 29. Score11='(-inf-760.961651]' Score13='(11.544606-inf)' 1303 ==> class=1 1269    conf:(0.97)
                 * 30. Score1='(105.605028-inf)' Score3='(8.5-inf)' 1691 ==> class=0 1646    conf:(0.97)
                 * 
                 * RULES USED:
                 * 
                 * 1
                 * 2
                 * 3
                 * 4
                 * 5
                 * 29
                 */

                // All of the conditions below must hold simultaneously for a positive
                // prediction (the original nested ifs flattened into one short-circuit
                // conjunction — semantically identical).
                //
                // NOTE(review): the Score13 threshold (5.0) and the extra Score14 check do
                // not match rule 29's interval '(11.544606-inf)' in the listing above —
                // confirm the intended thresholds against the rule induction output.
                boolean positive =
                        !Common.inInterval(0.097258, 2.072206, data[score2])
                        && !Common.inInterval(Double.NEGATIVE_INFINITY, 1652, data[score4])
                        && !Common.inInterval(105.605028, Double.POSITIVE_INFINITY, data[score1])
                        && !Common.inInterval(20.883587, Double.POSITIVE_INFINITY, data[score8])
                        && !Common.inInterval(Double.NEGATIVE_INFINITY, 21.247557, data[score5])
                        && !Common.inInterval(0.885545, 1.966381, data[score6])
                        && Common.inInterval(Double.NEGATIVE_INFINITY, 760.961651, data[score11])
                        && Common.inInterval(5.0, Double.POSITIVE_INFINITY, data[score13])
                        && Common.inInterval(0, Double.POSITIVE_INFINITY, data[score14]);

                if (positive)
                    Writer.append(resultPath, instanceNumber + "\n");
            }
        } catch (Exception e) {
            // Previously swallowed silently; report so data problems are not hidden.
            System.out.println("Could not make rule based predictions due to an error: " + e);
        } finally {
            if (in != null) {
                try {
                    in.close();
                } catch (IOException e) {
                    // Nothing further can be done if the close itself fails.
                }
            }
        }
    }

    /**
     * Labels each instance of a (potentially very large) ARFF test file using a heterogeneous
     * ensemble (J48, SMO, NaiveBayes, MultilayerPerceptron) trained on the supplied training
     * file. An instance is recorded as positive only when ALL ensemble members predict class 1;
     * its 1-based index is then appended to the result file. The raw test file is read in
     * parallel with the incremental ARFF loader so the original text of each instance is
     * available for inspection.
     *
     * @param trainPath  path to the labelled ARFF training file.
     * @param testPath   path to the ARFF test file, read incrementally.
     * @param resultPath path of the file the positive prediction indices are appended to.
     */
    public static void makePredictionsEnsembleNew(String trainPath, String testPath, String resultPath) {
        System.out.println("Training set: " + trainPath);
        System.out.println("Test set: " + testPath);

        // The ensemble classifiers. This is a heterogeneous ensemble.
        J48 learner1 = new J48();
        SMO learner2 = new SMO();
        NaiveBayes learner3 = new NaiveBayes();
        MultilayerPerceptron learner5 = new MultilayerPerceptron();

        System.out.println("Training Ensemble.");
        long startTime = System.nanoTime();
        try {
            BufferedReader reader = new BufferedReader(new FileReader(trainPath));
            try {
                Instances data = new Instances(reader);
                data.setClassIndex(data.numAttributes() - 1);
                System.out.println("Training data length: " + data.numInstances());

                learner1.buildClassifier(data);
                learner2.buildClassifier(data);
                learner3.buildClassifier(data);
                learner5.buildClassifier(data);
            } finally {
                reader.close(); // Previously leaked.
            }

            long endTime = System.nanoTime();
            long nanoseconds = endTime - startTime;
            double seconds = (double) nanoseconds / 1000000000.0;
            System.out.println("Training Ensemble completed in " + nanoseconds + " (ns) or " + seconds + " (s).");
        } catch (IOException e) {
            System.out.println("Could not train Ensemble classifier IOException on training data file.");
            return; // Do not attempt to classify with untrained models.
        } catch (Exception e) {
            System.out.println("Could not train Ensemble classifier Exception building model.");
            return; // Do not attempt to classify with untrained models.
        }

        BufferedReader in = null;

        try {
            // Open a stream to the raw test file; it is read line-for-line alongside the
            // incremental loader below so the source text of each instance is to hand.
            in = new BufferedReader(new FileReader(testPath));

            String line;

            // Skip the ARFF header: everything up to and including the @data marker.
            while ((line = in.readLine()) != null)
                if (line.toLowerCase().contains("@data"))
                    break;

            // A different ARFF loader used here (compared to training) as the ARFF file
            // may be extremely large. In which case the whole file cannot be read in.
            // Instead it is read in incrementally.
            ArffLoader loader = new ArffLoader();
            loader.setFile(new File(testPath));

            Instances data = loader.getStructure();
            data.setClassIndex(data.numAttributes() - 1);

            System.out.println("Ensemble Classifier is ready.");
            System.out.println("Testing on all instances available.");

            startTime = System.nanoTime();

            int instanceNumber = 0;

            // Label instances.
            Instance current;

            while ((current = loader.getNextInstance(data)) != null) {
                instanceNumber += 1;

                // Raw text of the current instance (may be null if the two readers
                // fall out of step); retained for debugging/inspection.
                line = in.readLine();

                double classification1 = learner1.classifyInstance(current);
                double classification2 = learner2.classifyInstance(current);
                double classification3 = learner3.classifyInstance(current);
                double classification5 = learner5.classifyInstance(current);

                // All classifiers must agree. This is a very primitive ensemble strategy!
                if (classification1 == 1 && classification2 == 1 && classification3 == 1 && classification5 == 1)
                    Writer.append(resultPath, instanceNumber + "\n");
            }

            System.out.println("Test set instances: " + instanceNumber);

            long endTime = System.nanoTime();
            long duration = endTime - startTime;
            double seconds = (double) duration / 1000000000.0;

            System.out.println("Testing Ensemble completed in " + duration + " (ns) or " + seconds + " (s).");
        } catch (Exception e) {
            System.out.println("Could not test Ensemble classifier due to an error: " + e);
        } finally {
            if (in != null) {
                try {
                    in.close();
                } catch (IOException e) {
                    // Nothing further can be done if the close itself fails.
                }
            }
        }
    }

    /**
     * Streaming variant of {@link #makePredictionsEnsembleNew(String, String, String)}: labels
     * each instance of a (potentially very large) ARFF test file using a heterogeneous ensemble
     * (J48, SMO, NaiveBayes, MultilayerPerceptron) trained on the supplied training file, without
     * keeping a parallel raw-text reader. Positive (unanimous class 1) predictions are appended
     * to the result file as 1-based instance indices.
     *
     * @param trainPath  path to the labelled ARFF training file.
     * @param testPath   path to the ARFF test file, read incrementally.
     * @param resultPath path of the file the positive prediction indices are appended to.
     */
    public static void makePredictionsEnsembleStream(String trainPath, String testPath, String resultPath) {
        System.out.println("Training set: " + trainPath);
        System.out.println("Test set: " + testPath);

        // The ensemble classifiers. This is a heterogeneous ensemble.
        J48 learner1 = new J48();
        SMO learner2 = new SMO();
        NaiveBayes learner3 = new NaiveBayes();
        MultilayerPerceptron learner5 = new MultilayerPerceptron();

        System.out.println("Training Ensemble.");
        long startTime = System.nanoTime();
        try {
            BufferedReader reader = new BufferedReader(new FileReader(trainPath));
            try {
                Instances data = new Instances(reader);
                data.setClassIndex(data.numAttributes() - 1);
                System.out.println("Training data length: " + data.numInstances());

                learner1.buildClassifier(data);
                learner2.buildClassifier(data);
                learner3.buildClassifier(data);
                learner5.buildClassifier(data);
            } finally {
                reader.close(); // Previously leaked.
            }

            long endTime = System.nanoTime();
            long nanoseconds = endTime - startTime;
            double seconds = (double) nanoseconds / 1000000000.0;
            System.out.println("Training Ensemble completed in " + nanoseconds + " (ns) or " + seconds + " (s).");
        } catch (IOException e) {
            System.out.println("Could not train Ensemble classifier IOException on training data file.");
            return; // Do not attempt to classify with untrained models.
        } catch (Exception e) {
            System.out.println("Could not train Ensemble classifier Exception building model.");
            return; // Do not attempt to classify with untrained models.
        }

        try {
            // A different ARFF loader used here (compared to training) as the ARFF file
            // may be extremely large. In which case the whole file cannot be read in.
            // Instead it is read in incrementally.
            ArffLoader loader = new ArffLoader();
            loader.setFile(new File(testPath));

            Instances data = loader.getStructure();
            data.setClassIndex(data.numAttributes() - 1);

            System.out.println("Ensemble Classifier is ready.");
            System.out.println("Testing on all instances available.");

            startTime = System.nanoTime();

            int instanceNumber = 0;

            // Label instances.
            Instance current;

            while ((current = loader.getNextInstance(data)) != null) {
                instanceNumber += 1;

                double classification1 = learner1.classifyInstance(current);
                double classification2 = learner2.classifyInstance(current);
                double classification3 = learner3.classifyInstance(current);
                double classification5 = learner5.classifyInstance(current);

                // All classifiers must agree. This is a very primitive ensemble strategy!
                if (classification1 == 1 && classification2 == 1 && classification3 == 1 && classification5 == 1)
                    Writer.append(resultPath, instanceNumber + "\n");
            }

            System.out.println("Test set instances: " + instanceNumber);

            long endTime = System.nanoTime();
            long duration = endTime - startTime;
            double seconds = (double) duration / 1000000000.0;

            System.out.println("Testing Ensemble completed in " + duration + " (ns) or " + seconds + " (s).");
        } catch (Exception e) {
            System.out.println("Could not test Ensemble classifier due to an error: " + e);
        }
    }

    /**
     * Labels each instance of a (potentially very large) ARFF test file using a single J48
     * decision tree trained on the supplied training file. The 1-based index of every instance
     * predicted as class 1 is appended to the result file.
     *
     * @param trainPath  path to the labelled ARFF training file.
     * @param testPath   path to the ARFF test file, read incrementally.
     * @param resultPath path of the file the positive prediction indices are appended to.
     */
    public static void makePredictionsJ48(String trainPath, String testPath, String resultPath) {
        // The decision tree classifier.
        J48 learner = new J48();

        System.out.println("Training set: " + trainPath);
        System.out.println("Test set: " + testPath);

        System.out.println("Training J48");
        long startTime = System.nanoTime();
        try {
            BufferedReader reader = new BufferedReader(new FileReader(trainPath));
            try {
                Instances data = new Instances(reader);
                data.setClassIndex(data.numAttributes() - 1);
                System.out.println("Training data length: " + data.numInstances());
                learner.buildClassifier(data);
            } finally {
                reader.close(); // Previously leaked.
            }

            long endTime = System.nanoTime();
            long nanoseconds = endTime - startTime;
            double seconds = (double) nanoseconds / 1000000000.0;
            System.out.println("Training J48 completed in " + nanoseconds + " (ns) or " + seconds + " (s)");
        } catch (IOException e) {
            System.out.println("Could not train J48 classifier IOException on training data file");
            return; // Do not attempt to classify with an untrained model.
        } catch (Exception e) {
            System.out.println("Could not train J48 classifier Exception building model");
            return; // Do not attempt to classify with an untrained model.
        }

        try {
            // The ARFF test file may be extremely large, so it is read incrementally
            // rather than loaded whole into memory.
            ArffLoader loader = new ArffLoader();
            loader.setFile(new File(testPath));
            Instances data = loader.getStructure();
            data.setClassIndex(data.numAttributes() - 1);

            System.out.println("J48 Classifier is ready.");
            System.out.println("Testing on all instances available.");

            startTime = System.nanoTime();

            int instanceNumber = 0;

            // Label instances.
            Instance current;

            while ((current = loader.getNextInstance(data)) != null) {
                instanceNumber += 1;

                double classification = learner.classifyInstance(current);

                // Record the (1-based) index of every instance predicted positive.
                if (classification == 1) {
                    Writer.append(resultPath, instanceNumber + "\n");
                }
            }

            // numInstances() on the loader's structure is always zero for an incremental
            // load, so the count is reported from the instances actually streamed.
            System.out.println("Test set instances: " + instanceNumber);

            long endTime = System.nanoTime();
            long duration = endTime - startTime;
            double seconds = (double) duration / 1000000000.0;

            System.out.println("Testing J48 completed in " + duration + " (ns) or " + seconds + " (s)");
        } catch (Exception e) {
            System.out.println("Could not test J48 classifier due to an error: " + e);
        }
    }
}