cezeri.evaluater.FactoryEvaluation.java Source code

Java tutorial

Introduction

Here is the source code for cezeri.evaluater.FactoryEvaluation.java

Source

/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package cezeri.evaluater;

import cezeri.types.TFigureAttribute;
import cezeri.matrix.CMatrix;
import cezeri.utils.FactoryInstance;
import cezeri.utils.FactoryUtils;
import java.util.Random;
import java.util.logging.Level;
import java.util.logging.Logger;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.core.Instances;
import weka.core.Utils;

/**
 * Static helpers for evaluating Weka classifiers: k-fold and leave-one-out
 * cross-validation, train/test evaluation, attribute-subset variants of both,
 * and saving/loading of trained classifiers.
 *
 * <p>The predictions and the matching class values of the most recent run are
 * published through the public static arrays {@link #simulated} and
 * {@link #observed}. Because that state is static, unsynchronized and
 * overwritten by every evaluation call, this class must not be used from
 * several threads at once.
 *
 * <p>Evaluation methods never propagate exceptions raised while building or
 * evaluating a model; they log them at {@code SEVERE} level and return
 * whatever {@link Evaluation} object exists at that point (possibly
 * {@code null}).
 *
 * @author BAP1
 */
public class FactoryEvaluation {

    /** Seed used for every shuffle so that repeated runs give the same folds. */
    private static final int SEED = 1;

    /** Fold count used by the {@code performSubsetCrossValidate} overloads. */
    private static final int SUBSET_FOLDS = 10;

    private static final Logger LOGGER = Logger.getLogger(FactoryEvaluation.class.getName());

    /** Predictions produced by the most recent evaluation call. */
    public static double[] simulated = new double[0];
    /** Class values of the instances evaluated by the most recent evaluation call. */
    public static double[] observed = new double[0];

    /**
     * Convenience overload of
     * {@link #performCrossValidate(Classifier, Instances, int, boolean, boolean, TFigureAttribute)}
     * that takes the figure attribute first.
     *
     * @param attr      figure settings used when {@code show_plot} is true
     * @param model     classifier template; it is copied per fold, never trained itself
     * @param datax     data set to cross-validate on
     * @param folds     number of folds
     * @param show_text print a summary to standard output
     * @param show_plot plot observed versus simulated values
     * @return the evaluation, or {@code null} if it could not be created
     */
    public static Evaluation performCrossValidate(TFigureAttribute attr, Classifier model, Instances datax,
            int folds, boolean show_text, boolean show_plot) {
        return performCrossValidate(model, datax, folds, show_text, show_plot, attr);
    }

    /**
     * Cross-validates {@code model} using a default {@link TFigureAttribute}.
     *
     * @param model     classifier template; it is copied per fold, never trained itself
     * @param datax     data set to cross-validate on
     * @param folds     number of folds
     * @param show_text print a summary to standard output
     * @param show_plot plot observed versus simulated values
     * @return the evaluation, or {@code null} if it could not be created
     */
    public static Evaluation performCrossValidate(Classifier model, Instances datax, int folds, boolean show_text,
            boolean show_plot) {
        TFigureAttribute attr = new TFigureAttribute();
        return performCrossValidate(model, datax, folds, show_text, show_plot, attr);
    }

    /**
     * Runs a {@code folds}-fold cross-validation of {@code model} on a shuffled
     * copy of {@code datax} (the copy is stratified when the class attribute is
     * nominal). The per-fold predictions and class values are collected in
     * {@link #simulated} and {@link #observed}, which are cleared at the start
     * of every call so they only describe this run.
     *
     * @param model     classifier template; it is copied per fold, never trained itself
     * @param datax     data set to cross-validate on; not modified
     * @param folds     number of folds
     * @param show_text print a summary to standard output
     * @param show_plot plot observed versus simulated values
     * @param attr      figure settings; its caption is overwritten when plotting
     * @return the evaluation accumulated over all folds; {@code null} if it could
     *         not be created, or partially filled if a fold failed (the failure
     *         is logged)
     */
    public static Evaluation performCrossValidate(Classifier model, Instances datax, int folds, boolean show_text,
            boolean show_plot, TFigureAttribute attr) {
        Random rand = new Random(SEED);
        Instances randData = new Instances(datax);
        randData.randomize(rand);
        if (randData.classAttribute().isNominal()) {
            randData.stratify(folds);
        }
        // Start from empty buffers; otherwise results of earlier calls leak into this run.
        resetResultBuffers();
        Evaluation eval = null;
        try {
            eval = new Evaluation(randData);
            for (int n = 0; n < folds; n++) {
                Instances train = randData.trainCV(folds, n, rand);
                Instances validation = randData.testCV(folds, n);
                // Train a fresh copy so the caller's model is left untouched.
                Classifier clsCopy = Classifier.makeCopy(model);
                clsCopy.buildClassifier(train);
                appendFoldResults(eval.evaluateModel(clsCopy, validation), validation);
            }

            if (show_plot) {
                attr.figureCaption = "overall performance";
                plotObservedVsSimulated(attr);
            }
            if (show_text) {
                printCrossValidationReport(model, randData, folds, eval);
            }
        } catch (Exception ex) {
            LOGGER.log(Level.SEVERE, null, ex);
        }
        return eval;
    }

    /**
     * Runs a leave-one-out cross-validation of {@code model}: the number of
     * folds equals the number of instances in {@code datax}. The data is
     * shuffled but not stratified. {@link #simulated} and {@link #observed} are
     * cleared at the start of every call and then filled fold by fold.
     *
     * @param model     classifier template; it is copied per fold, never trained itself
     * @param datax     data set to cross-validate on; not modified
     * @param test      currently unused; kept so existing callers keep compiling
     * @param show_text print a summary to standard output
     * @param show_plot plot observed versus simulated values
     * @return the evaluation accumulated over all folds; {@code null} if it could
     *         not be created, or partially filled if a fold failed (the failure
     *         is logged)
     */
    public static Evaluation performCrossValidateTestAlso(Classifier model, Instances datax, Instances test,
            boolean show_text, boolean show_plot) {
        TFigureAttribute attr = new TFigureAttribute();
        Random rand = new Random(SEED);
        Instances randData = new Instances(datax);
        randData.randomize(rand);

        // Start from empty buffers; otherwise results of earlier calls leak into this run.
        resetResultBuffers();
        Evaluation eval = null;
        int folds = randData.numInstances();
        try {
            eval = new Evaluation(randData);
            for (int n = 0; n < folds; n++) {
                Instances train = randData.trainCV(folds, n);
                Classifier clsCopy = Classifier.makeCopy(model);
                clsCopy.buildClassifier(train);
                Instances validation = randData.testCV(folds, n);
                appendFoldResults(eval.evaluateModel(clsCopy, validation), validation);
            }

            if (show_plot) {
                attr.figureCaption = "overall performance";
                plotObservedVsSimulated(attr);
            }
            if (show_text) {
                printCrossValidationReport(model, randData, folds, eval);
            }
        } catch (Exception ex) {
            LOGGER.log(Level.SEVERE, null, ex);
        }
        return eval;
    }

    /**
     * Evaluates a classifier on {@code test} and stores the predictions in
     * {@link #simulated} and the test class values in {@link #observed}.
     *
     * @param isTrained if true, {@code model} is evaluated as is; otherwise a
     *                  copy of it is trained on {@code train} first
     * @param model     the classifier (or classifier template)
     * @param train     data used to initialise the evaluation and, when
     *                  {@code isTrained} is false, to train the copy
     * @param test      data to evaluate on
     * @param show_text print a summary to standard output
     * @param show_plot plot observed versus simulated values
     * @param attr      figure settings; its items and caption are overwritten
     *                  when plotting
     * @return the evaluation, or {@code null} if it could not be created; any
     *         failure is logged rather than thrown
     */
    private static Evaluation doTest(boolean isTrained, Classifier model, Instances train, Instances test,
            boolean show_text, boolean show_plot, TFigureAttribute attr) {
        Evaluation eval = null;
        try {
            eval = new Evaluation(train);
            if (isTrained) {
                simulated = eval.evaluateModel(model, test);
            } else {
                Classifier clsCopy = Classifier.makeCopy(model);
                clsCopy.buildClassifier(train);
                simulated = eval.evaluateModel(clsCopy, test);
            }
            // Always refresh the observed values so both public buffers describe the same run,
            // not only when a plot is requested.
            observed = test.attributeToDoubleArray(test.classIndex());
            if (show_plot) {
                String[] items = { "Observed", "Simulated" };
                attr.items = items;
                attr.figureCaption = model.getClass().getCanonicalName();
                plotObservedVsSimulated(attr);
            }
            if (show_text) {
                System.out.println();
                System.out.println("=== Setup for Test ===");
                System.out.println(
                        "Classifier: " + model.getClass().getName() + " " + Utils.joinOptions(model.getOptions()));
                System.out.println("Dataset: " + test.relationName());
                System.out.println();
                System.out.println(eval.toSummaryString("=== Test Results ===", false));
            }
        } catch (Exception ex) {
            LOGGER.log(Level.SEVERE, null, ex);
        }
        return eval;
    }

    /**
     * Trains a copy of {@code model} on {@code train} and evaluates it on
     * {@code test}, plotting with axis labels "index" and "value".
     *
     * @param model     classifier template; a copy is trained
     * @param train     training data
     * @param test      test data
     * @param show_text print a summary to standard output
     * @param show_plot plot observed versus simulated values
     * @return the evaluation, or {@code null} if it could not be created
     */
    public static Evaluation performTest(Classifier model, Instances train, Instances test, boolean show_text,
            boolean show_plot) {
        String[] s = { "index", "value" };
        TFigureAttribute attr = new TFigureAttribute();
        attr.axis = s;
        return performTest(model, train, test, show_text, show_plot, attr);
    }

    /**
     * Trains a copy of {@code model} on {@code train} and evaluates it on
     * {@code test}.
     *
     * @param model     classifier template; a copy is trained
     * @param train     training data
     * @param test      test data
     * @param show_text print a summary to standard output
     * @param show_plot plot observed versus simulated values
     * @param attr      figure settings used when plotting
     * @return the evaluation, or {@code null} if it could not be created
     */
    public static Evaluation performTest(Classifier model, Instances train, Instances test, boolean show_text,
            boolean show_plot, TFigureAttribute attr) {
        return doTest(false, model, train, test, show_text, show_plot, attr);
    }

    /**
     * Evaluates an already trained {@code model} on {@code test} using a default
     * {@link TFigureAttribute}.
     *
     * @param model     trained classifier; it is not retrained
     * @param test      test data (also used to initialise the evaluation)
     * @param show_text print a summary to standard output
     * @param show_plot plot observed versus simulated values
     * @return the evaluation, or {@code null} if it could not be created
     */
    public static Evaluation performTestWithTrainedClassifier(Classifier model, Instances test, boolean show_text,
            boolean show_plot) {
        return doTest(true, model, test, test, show_text, show_plot, new TFigureAttribute());
    }

    /**
     * Evaluates an already trained {@code model} on {@code test}.
     *
     * @param model     trained classifier; it is not retrained
     * @param test      test data (also used to initialise the evaluation)
     * @param show_text print a summary to standard output
     * @param show_plot plot observed versus simulated values
     * @param attr      figure settings used when plotting
     * @return the evaluation, or {@code null} if it could not be created
     */
    public static Evaluation performTestWithTrainedClassifier(Classifier model, Instances test, boolean show_text,
            boolean show_plot, TFigureAttribute attr) {
        return doTest(true, model, test, test, show_text, show_plot, attr);
    }

    /**
     * Evaluates an already trained {@code model} on the subset of {@code data}
     * selected by {@code lstComb}, using a default {@link TFigureAttribute}.
     *
     * @param lstComb   comma-separated selection passed to
     *                  {@code FactoryInstance.getSubsetData}
     * @param model     trained classifier; it is not retrained
     * @param data      full data set the subset is taken from
     * @param show_text print a summary to standard output
     * @param show_plot plot observed versus simulated values
     * @return the evaluation, or {@code null} if it could not be created
     */
    public static Evaluation performSubsetTestWithTrainedClassifier(String lstComb, Classifier model,
            Instances data, boolean show_text, boolean show_plot) {
        Instances subsetData = FactoryInstance.getSubsetData(data, lstComb.split(","));
        return doTest(true, model, subsetData, subsetData, show_text, show_plot, new TFigureAttribute());
    }

    /**
     * Evaluates an already trained {@code model} on the subset of {@code data}
     * selected by {@code lstComb}.
     *
     * @param lstComb   comma-separated selection passed to
     *                  {@code FactoryInstance.getSubsetData}
     * @param model     trained classifier; it is not retrained
     * @param data      full data set the subset is taken from
     * @param show_text print a summary to standard output
     * @param show_plot plot observed versus simulated values
     * @param attr      figure settings used when plotting
     * @return the evaluation, or {@code null} if it could not be created
     */
    public static Evaluation performSubsetTestWithTrainedClassifier(String lstComb, Classifier model,
            Instances data, boolean show_text, boolean show_plot, TFigureAttribute attr) {
        Instances subsetData = FactoryInstance.getSubsetData(data, lstComb.split(","));
        return doTest(true, model, subsetData, subsetData, show_text, show_plot, attr);
    }

    /**
     * Runs a 10-fold cross-validation of {@code model} on the subset of
     * {@code data} selected by {@code lstComb}, using a default
     * {@link TFigureAttribute}.
     *
     * @param lstComb   comma-separated selection passed to
     *                  {@code FactoryInstance.getSubsetData}
     * @param data      full data set the subset is taken from
     * @param model     classifier template; it is copied per fold
     * @param show_text print a summary to standard output
     * @param show_plot plot observed versus simulated values
     * @return the evaluation, or {@code null} if it could not be created
     */
    public static Evaluation performSubsetCrossValidate(String lstComb, Instances data, Classifier model,
            boolean show_text, boolean show_plot) {
        Instances subsetData = FactoryInstance.getSubsetData(data, lstComb.split(","));
        return performCrossValidate(model, subsetData, SUBSET_FOLDS, show_text, show_plot,
                new TFigureAttribute());
    }

    /**
     * Runs a 10-fold cross-validation of {@code model} on the subset of
     * {@code data} selected by {@code lstComb}.
     *
     * @param lstComb   comma-separated selection passed to
     *                  {@code FactoryInstance.getSubsetData}
     * @param data      full data set the subset is taken from
     * @param model     classifier template; it is copied per fold
     * @param show_text print a summary to standard output
     * @param show_plot plot observed versus simulated values
     * @param attr      figure settings used when plotting
     * @return the evaluation, or {@code null} if it could not be created
     */
    public static Evaluation performSubsetCrossValidate(String lstComb, Instances data, Classifier model,
            boolean show_text, boolean show_plot, TFigureAttribute attr) {
        Instances subsetData = FactoryInstance.getSubsetData(data, lstComb.split(","));
        return performCrossValidate(model, subsetData, SUBSET_FOLDS, show_text, show_plot, attr);
    }

    /**
     * Applies the same {@code lstComb} selection to {@code train} and
     * {@code test}, then trains a copy of {@code model} on the former and
     * evaluates it on the latter.
     *
     * @param lstComb   comma-separated selection passed to
     *                  {@code FactoryInstance.getSubsetData}
     * @param train     full training data
     * @param test      full test data
     * @param model     classifier template; a copy is trained
     * @param show_text print a summary to standard output
     * @param show_plot plot observed versus simulated values
     * @param attr      figure settings used when plotting
     * @return the evaluation, or {@code null} if it could not be created
     */
    public static Evaluation performSubsetTest(String lstComb, Instances train, Instances test, Classifier model,
            boolean show_text, boolean show_plot, TFigureAttribute attr) {
        String[] selection = lstComb.split(",");
        Instances subsetDataTrain = FactoryInstance.getSubsetData(train, selection);
        Instances subsetDataTest = FactoryInstance.getSubsetData(test, lstComb.split(","));
        return performTest(model, subsetDataTrain, subsetDataTest, show_text, show_plot, attr);
    }

    /**
     * Applies the same {@code lstComb} selection to {@code train} and
     * {@code test}, then trains a copy of {@code model} on the former and
     * evaluates it on the latter, plotting with axis labels "index" and "value".
     *
     * @param lstComb   comma-separated selection passed to
     *                  {@code FactoryInstance.getSubsetData}
     * @param train     full training data
     * @param test      full test data
     * @param model     classifier template; a copy is trained
     * @param show_text print a summary to standard output
     * @param show_plot plot observed versus simulated values
     * @return the evaluation, or {@code null} if it could not be created
     */
    public static Evaluation performSubsetTest(String lstComb, Instances train, Instances test, Classifier model,
            boolean show_text, boolean show_plot) {
        String[] selection = lstComb.split(",");
        Instances subsetDataTrain = FactoryInstance.getSubsetData(train, selection);
        Instances subsetDataTest = FactoryInstance.getSubsetData(test, lstComb.split(","));
        return performTest(model, subsetDataTrain, subsetDataTest, show_text, show_plot);
    }

    /**
     * Trains a copy of {@code model} on the subset of {@code data} selected by
     * {@code lstComb} and serializes the trained copy to {@code filePath}.
     * Failures are logged, not thrown.
     *
     * @param lstComb  comma-separated selection passed to
     *                 {@code FactoryInstance.getSubsetData}
     * @param model    classifier template; a copy is trained
     * @param data     full data set the subset is taken from
     * @param filePath destination file
     */
    public static void saveSubsetClassifier(String lstComb, Classifier model, Instances data, String filePath) {
        try {
            Instances subsetData = FactoryInstance.getSubsetData(data, lstComb.split(","));
            trainCopyAndWrite(model, subsetData, filePath);
        } catch (Exception ex) {
            LOGGER.log(Level.SEVERE, null, ex);
        }
    }

    /**
     * Trains a copy of {@code model} on {@code train} and serializes the trained
     * copy to {@code filePath}. Failures are logged, not thrown.
     *
     * @param model    classifier template; a copy is trained
     * @param train    training data
     * @param filePath destination file
     */
    public static void saveClassifier(Classifier model, Instances train, String filePath) {
        try {
            trainCopyAndWrite(model, train, filePath);
        } catch (Exception ex) {
            LOGGER.log(Level.SEVERE, null, ex);
        }
    }

    /**
     * Reads a classifier previously written with
     * {@code weka.core.SerializationHelper}.
     *
     * <p>NOTE(review): this presumably relies on Java object deserialization,
     * so it should only be pointed at trusted files - confirm against the Weka
     * documentation.
     *
     * @param filePath file to read
     * @return the classifier, or {@code null} if reading failed (the failure is
     *         logged)
     */
    public static Classifier loadClassifier(String filePath) {
        Classifier cls = null;
        try {
            cls = (Classifier) weka.core.SerializationHelper.read(filePath);
        } catch (Exception ex) {
            LOGGER.log(Level.SEVERE, null, ex);
        }
        return cls;
    }

    /** Empties the public result buffers so a new run does not inherit old values. */
    private static void resetResultBuffers() {
        simulated = new double[0];
        observed = new double[0];
    }

    /** Appends one fold's predictions and the fold's class values to the public buffers. */
    private static void appendFoldResults(double[] predictions, Instances validation) {
        simulated = FactoryUtils.concatenate(simulated, predictions);
        observed = FactoryUtils.concatenate(observed,
                validation.attributeToDoubleArray(validation.classIndex()));
    }

    /** Plots the current {@link #observed} and {@link #simulated} buffers as two series. */
    private static void plotObservedVsSimulated(TFigureAttribute attr) {
        double[][] d = { observed, simulated };
        CMatrix f1 = CMatrix.getInstance(d);
        f1.transpose().plot(attr);
    }

    /**
     * Prints the cross-validation setup and summary to standard output.
     * Exceptions are passed on so the caller's catch block logs them.
     */
    private static void printCrossValidationReport(Classifier model, Instances data, int folds, Evaluation eval)
            throws Exception {
        System.out.println();
        System.out.println("=== Setup for Overall Cross Validation===");
        System.out.println(
                "Classifier: " + model.getClass().getName() + " " + Utils.joinOptions(model.getOptions()));
        System.out.println("Dataset: " + data.relationName());
        System.out.println("Folds: " + folds);
        System.out.println("Seed: " + SEED);
        System.out.println();
        System.out.println(eval.toSummaryString("=== " + folds + "-fold Cross-validation ===", false));
    }

    /** Trains a copy of {@code model} on {@code train} and serializes the copy to {@code filePath}. */
    private static void trainCopyAndWrite(Classifier model, Instances train, String filePath) throws Exception {
        Classifier clsCopy = Classifier.makeCopy(model);
        clsCopy.buildClassifier(train);
        weka.core.SerializationHelper.write(filePath, clsCopy);
    }

}