Statistics.WekaFunctions.java Source code

Java tutorial

Introduction

Here is the source code for Statistics.WekaFunctions.java

Source

/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package Statistics;

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.io.StringReader;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.logging.Level;
import java.util.logging.Logger;
import weka.classifiers.Evaluation;
import weka.classifiers.functions.LinearRegression;
import weka.classifiers.meta.FilteredClassifier;
import weka.classifiers.trees.RandomForest;
import weka.core.Instance;
import weka.core.Instances;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Remove;

/**
 *
 * @author Norbert Heijne
 */
// Code inspired by
// http://stackoverflow.com/questions/20017957/how-to-reuse-saved-classifier-created-from-explorerin-weka-in-eclipse-java
public class WekaFunctions {

    // Full data set parsed from ARFF text by loadTrainInstance().
    Instances train;
    // Attribute-filtered view of 'train'; used to build and evaluate the model.
    Instances trainParameters;
    // Attribute-filtered test set produced by loadTestInstances().
    Instances RF_testInstances;
    // Empty data set sharing the header of 'trainParameters'; scratch space for predict().
    Instances predict;
    // Classifier; starts untrained, replaced by loadModel() or trained by buildLRcls().
    RandomForest model = new RandomForest();
    // Attribute filter, (re)configured on every load*Instance(s) call.
    Remove removefilter = new Remove();
    // NOTE(review): not referenced by any method visible in this file.
    FilteredClassifier fc = new FilteredClassifier();
    // Class distribution computed by the most recent loadTestInstances() call.
    public double[] distributions;

    /**
     * Creates an instance with an untrained model and no data loaded.
     * All state is set up by the field initializers and the load methods.
     */
    public WekaFunctions() {
        // intentionally empty
    }

    /**
     * Reads the complete file at {@code path} and decodes its bytes as UTF-8.
     *
     * @param path location of the file to read
     * @return the decoded file content, or an empty string when an
     *         {@link IOException} occurs (the failure is logged and reported
     *         on standard output)
     */
    static String readFile(String path) {
        try {
            ByteBuffer rawContent = ByteBuffer.wrap(Files.readAllBytes(Paths.get(path)));
            return StandardCharsets.UTF_8.decode(rawContent).toString();
        } catch (IOException readFailure) {
            Logger.getLogger(WekaFunctions.class.getName()).log(Level.SEVERE, null, readFailure);
            System.out.println("Could not read file " + path);
            return "";
        }
    }

    /**
     * Replaces the current classifier with a serialized model read from disk.
     * Any failure (including a stored object that is not a
     * {@code RandomForest}) is logged and leaves the existing {@code model}
     * unchanged.
     *
     * @param modelPath path of the serialized model file
     */
    public void loadModel(String modelPath) {
        try {
            // Deserialize first, then narrow the type in a separate step.
            Object deserialized = weka.core.SerializationHelper.read(modelPath);
            model = (RandomForest) deserialized;
        } catch (Exception loadFailure) {
            Logger.getLogger(WekaFunctions.class.getName()).log(Level.SEVERE, null, loadFailure);
        }
    }

    /**
     * Parses ARFF-formatted text into the training set and derives the
     * filtered data used for model building ({@code trainParameters}) and the
     * empty scratch set used by {@code predict()}.
     *
     * <p>The fields {@code train}, {@code trainParameters} and {@code predict}
     * are only updated when every step succeeds, so a failed load leaves the
     * previously loaded data untouched. Failures are logged; the success
     * message is printed only when loading actually succeeded.
     *
     * @param instances the content of an ARFF file (not a file path)
     */
    public void loadTrainInstance(String instances) {
        System.out.println("Load instance starting");
        System.out.println(instances);
        try {
            // Parse the ARFF text.
            Instances parsed = new Instances(new BufferedReader(new StringReader(instances)));
            // The class to predict is the second-to-last attribute.
            parsed.setClassIndex(parsed.numAttributes() - 2);

            // Weka attribute ranges are 1-based; -V inverts the selection so the
            // listed attributes are the ones that are kept.
            // NOTE(review): the original comment names attributes 61-66 (parameters)
            // and 69 (challenge), but the range below is 1-65,69 - confirm which
            // one is intended.
            removefilter.setOptions(weka.core.Utils.splitOptions("-R 1-65,69 -V"));
            // The filter must learn the input format AFTER its options are set.
            removefilter.setInputFormat(parsed);

            Instances filtered = Filter.useFilter(parsed, removefilter);
            filtered.setClassIndex(filtered.numAttributes() - 1);

            // Header-only copy (capacity 0) used as scratch space by predict().
            Instances scratch = new Instances(filtered, 0);
            scratch.setClassIndex(scratch.numAttributes() - 1);

            // Publish the new state only once everything above has succeeded.
            train = parsed;
            trainParameters = filtered;
            predict = scratch;

            System.out.println(trainParameters);
            System.out.println("Load instance succesful");
        } catch (Exception ex) {
            Logger.getLogger(WekaFunctions.class.getName())
                    .log(Level.SEVERE, "Could not load training instances", ex);
        }
    }

    /**
     * Parses ARFF-formatted test data, applies the same attribute filter as
     * the training data and classifies the last test instance. The resulting
     * class distribution is stored in {@code distributions}.
     *
     * <p>Any failure is reported on standard output and leaves
     * {@code distributions} at its previous value.
     *
     * @param verbose           whether the classification step prints the
     *                          per-class probabilities
     * @param instances         the content of an ARFF file (not a file path)
     * @param personalize_mode  forwarded to
     *                          {@code classifyInstance(Instance, boolean, int)}
     */
    public void loadTestInstances(boolean verbose, String instances, int personalize_mode) {
        try {
            Instances data = new Instances(new BufferedReader(new StringReader(instances)));

            // 2nd to last attribute is used for classification (last is timestamp).
            data.setClassIndex(data.numAttributes() - 2);

            // NOTE(review): same attribute range as loadTrainInstance(); the comment
            // there names attributes 61-66 and 69 - confirm the range is intended.
            removefilter.setOptions(weka.core.Utils.splitOptions("-R 1-65,69 -V"));
            // The filter must learn the input format AFTER its options are set.
            removefilter.setInputFormat(data);

            RF_testInstances = Filter.useFilter(data, removefilter);
            RF_testInstances.setClassIndex(RF_testInstances.numAttributes() - 1);

            System.out.println("-done loading " + RF_testInstances.numInstances() + " test instance(s)");

            // Honour the caller's verbosity choice instead of always printing.
            this.distributions = classifyInstance(selectTestInstance(), verbose, personalize_mode);

        } catch (Exception e) {
            // Parsing, filtering or instance selection failed.
            System.out.println("ERROR!!! - In function loadTestInstances()...");
            System.out.println("-" + e);
        }
    }

    /**
     * Computes the class distribution for one test instance.
     *
     * <p>With {@code personalize_mode == 1} the model is bypassed and a fixed
     * distribution with all probability on the third class is returned. Every
     * other mode asks {@code model} for the distribution.
     *
     * @param testInstance      the instance to classify (unused in mode 1)
     * @param verbose           when {@code true}, prints one line per class
     *                          probability
     * @param personalize_mode  1 for the fixed distribution, anything else for
     *                          the model's prediction
     * @return the class distribution; five zeros when classification fails
     */
    public double[] classifyInstance(Instance testInstance, boolean verbose, int personalize_mode) {
        try {
            double[] fDistribution;
            String banner;
            if (personalize_mode == 1) {
                // Adaptation based on ZeroQ classification; its majority class is 3.
                fDistribution = new double[] { 0.0, 0.0, 1.0, 0.0, 0.0 };
                banner = "Classifying selected test instance with forced ZeroQ clasisfication...";
            } else {
                // All other personalization modes: likelihood of each class from the model.
                fDistribution = model.distributionForInstance(testInstance);
                banner = "Classifying selected test instance...";
            }

            if (verbose) {
                printDistribution(banner, fDistribution);
            }
            return fDistribution;
        } catch (Exception e) {
            // The model could not classify the instance.
            System.out.println("ERROR!!! - In function classifyInstance()...");
            System.out.println("-" + e);
            return new double[5]; // dummy values (all zeros)
        }
    }

    /**
     * Prints a banner followed by one probability line per class. Iterates
     * over the actual array length so that a model with fewer (or more) than
     * five classes cannot cause an index error.
     */
    private static void printDistribution(String banner, double[] distribution) {
        System.out.println("");
        System.out.println(banner);
        for (int i = 0; i < distribution.length; i++) {
            System.out.println("-probability of instance being class " + (i + 1) + ": " + distribution[i]);
        }
    }

    /**
     * Copies the last instance of the loaded test set and attaches it to the
     * training data set so that it uses that set's attribute description.
     *
     * @return a copy of the last instance in {@code RF_testInstances}
     */
    public Instance selectTestInstance() {
        Instance lastOfTestSet = RF_testInstances.lastInstance();
        Instance selected = new Instance(lastOfTestSet);

        // The copy has to reference the training set to inherit its header.
        selected.setDataset(trainParameters);

        String description = selected.toString();
        System.out.println("-selected last instance in test set: " + description);
        return selected;
    }

    /**
     * Trains {@code model} on {@code trainParameters} and prints an evaluation
     * summary. The evaluation is run on the same data that was used for
     * training. Any failure is logged.
     */
    public void buildLRcls() {
        try {
            System.out.println("building classifier");

            String[] buildOptions = weka.core.Utils.splitOptions("-S 0 -D");
            model.setOptions(buildOptions);
            model.buildClassifier(trainParameters);

            // Evaluate on the training data and print some statistics.
            Evaluation trainingEvaluation = new Evaluation(trainParameters);
            trainingEvaluation.evaluateModel(model, trainParameters);
            String summary = trainingEvaluation.toSummaryString();
            System.out.println(summary);
        } catch (Exception buildFailure) {
            Logger.getLogger(WekaFunctions.class.getName()).log(Level.SEVERE, null, buildFailure);
        }
    }

    /**
     * Predicts the class value for one set of parameter values using
     * {@code model}.
     *
     * <p>The values are copied into a temporary instance that is sized to the
     * loaded data set; positions not supplied by the caller (normally the
     * class attribute) are left at 0. The temporary instance is always removed
     * again, even when classification fails.
     *
     * @param parameterValues attribute values in data-set order; may be
     *                        shorter than the number of attributes
     * @return the predicted value, or 0 when the model fails to classify (the
     *         failure is logged)
     * @throws IllegalStateException    if no training data has been loaded
     * @throws IllegalArgumentException if more values are supplied than the
     *                                  data set has attributes
     */
    public double predict(double[] parameterValues) {
        if (predict == null) {
            throw new IllegalStateException(
                    "No training data loaded; call loadTrainInstance() before predict()");
        }
        int attributeCount = predict.numAttributes();
        if (parameterValues.length > attributeCount) {
            throw new IllegalArgumentException("Got " + parameterValues.length
                    + " values but the data set has only " + attributeCount + " attributes");
        }

        // Size the buffer from the data set instead of assuming a fixed width.
        double[] inputvalues = new double[attributeCount];
        System.arraycopy(parameterValues, 0, inputvalues, 0, parameterValues.length);

        predict.add(new Instance(1, inputvalues));
        double predictedValue = 0;
        try {
            predictedValue = model.classifyInstance(predict.lastInstance());
        } catch (Exception ex) {
            Logger.getLogger(WekaFunctions.class.getName())
                    .log(Level.SEVERE, "Could not classify instance", ex);
        } finally {
            // Remove exactly the instance added above so the scratch set stays empty.
            predict.delete(predict.numInstances() - 1);
        }

        // Arrays.toString prints the values; array.toString() only prints its identity.
        System.out.println("The predicted Challenge for parameter values "
                + java.util.Arrays.toString(parameterValues) + ": " + predictedValue);
        return predictedValue;
    }

    /**
     * Small command-line demo: loads training data, builds the classifier and
     * prints the prediction for one fixed set of parameter values.
     *
     * <p>Training data has to be loaded before predicting; calling
     * {@code predict} on a fresh object fails because no data set exists yet.
     *
     * @param args optional; {@code args[0]} is the path of an ARFF file with
     *             training data. Without it a default relative path is used.
     */
    public static void main(String[] args) {
        String arffPath = args.length > 0 ? args[0] : "../../MAINOOR/MarioPOMDP-testinstances.arff";

        WekaFunctions wf = new WekaFunctions();
        wf.loadTrainInstance(readFile(arffPath));
        if (wf.predict == null) {
            // loadTrainInstance() already logged the cause.
            System.out.println("No training data could be loaded from " + arffPath + "; nothing to predict");
            return;
        }

        wf.buildLRcls();
        wf.predict(new double[] { 4, 4, 4, 4, 4, 4 });
    }
}