ExperimentDemo.java Source code

Java tutorial

Introduction

Here is the source code for ExperimentDemo.java

Source

/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 * ExperimentDemo.java
 * Copyright (C) 2007 University of Waikato, Hamilton, New Zealand
 */

import weka.classifiers.Classifier;
import weka.core.Instances;
import weka.core.Range;
import weka.core.Utils;
import weka.experiment.ClassifierSplitEvaluator;
import weka.experiment.CrossValidationResultProducer;
import weka.experiment.Experiment;
import weka.experiment.InstancesResultListener;
import weka.experiment.PairedCorrectedTTester;
import weka.experiment.PairedTTester;
import weka.experiment.PropertyNode;
import weka.experiment.RandomSplitResultProducer;
import weka.experiment.RegressionSplitEvaluator;
import weka.experiment.ResultMatrix;
import weka.experiment.ResultMatrixPlainText;
import weka.experiment.SplitEvaluator;

import java.beans.IntrospectionException;
import java.beans.PropertyDescriptor;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;

import javax.swing.DefaultListModel;

/**
 * A class for demonstrating the use of the Experiment class for one classifier
 * and one or more datasets. Cross-validation or random splits are possible, as
 * well as classification or regression.
 * 
 * @author  fracpete (fracpete at waikato dot ac dot nz)
 * @version $Revision$
 */
public class ExperimentDemo {

    /**
     * Name of the result attribute holding the fold number. Only the
     * cross-validation result producer generates it; random splits do not.
     */
    private static final String KEY_FOLD = "Key_Fold";

    /**
     * Expects the following parameters: 
     * <ul>
     *   <li>-classifier "classifier incl. parameters"</li>
     *   <li>-exptype "classification|regression"</li>
     *   <li>-splittype "crossvalidation|randomsplit"</li>
     *   <li>-runs "# of runs"</li>
     *   <li>-folds "# of cross-validation folds"</li>
     *   <li>-percentage "percentage for randomsplit"</li>
     *   <li>-result "arff file for storing the results"</li>
     *   <li>-t "dataset" (can be supplied multiple times)</li>
     * </ul>
     * 
     * @param args   the commandline arguments
     * @throws IllegalArgumentException   if a required option is missing or invalid
     * @throws Exception   if setting up, running or evaluating the experiment fails
     */
    public static void main(String[] args) throws Exception {
        // parameters provided?
        if (args.length == 0) {
            System.out.println("\nUsage: weka.examples.experiment.ExperimentDemo\n"
                    + "\t   -classifier <classifier incl. parameters>\n"
                    + "\t   -exptype <classification|regression>\n"
                    + "\t   -splittype <crossvalidation|randomsplit>\n" + "\t   -runs <# of runs>\n"
                    + "\t   -folds <folds for CV>\n" + "\t   -percentage <percentage for randomsplit>\n"
                    + "\t   -result <ARFF file for storing the results>\n"
                    + "\t   -t dataset (can be supplied multiple times)\n");
            System.exit(1);
        }

        // 1. setup the experiment
        System.out.println("Setting up...");
        Experiment exp = new Experiment();
        exp.setPropertyArray(new Classifier[0]);
        exp.setUsePropertyIterator(true);

        String option;

        // classification or regression
        option = Utils.getOption("exptype", args);
        if (option.length() == 0) {
            throw new IllegalArgumentException("No experiment type provided!");
        }

        // The SplitEvaluator produces the results for a single split of a dataset
        // (ClassifierSplitEvaluator for a nominal class, RegressionSplitEvaluator for
        // a numeric one). 'sec' is the classifier it currently holds; it becomes the
        // value of the last node of the property path built below.
        SplitEvaluator se;
        Classifier sec;
        boolean classification;
        if (option.equals("classification")) {
            classification = true;
            ClassifierSplitEvaluator cse = new ClassifierSplitEvaluator();
            se = cse;
            sec = cse.getClassifier();
        } else if (option.equals("regression")) {
            classification = false;
            RegressionSplitEvaluator rse = new RegressionSplitEvaluator();
            se = rse;
            sec = rse.getClassifier();
        } else {
            throw new IllegalArgumentException("Unknown experiment type '" + option + "'!");
        }

        // crossvalidation or randomsplit
        option = Utils.getOption("splittype", args);
        if (option.length() == 0) {
            throw new IllegalArgumentException("No split type provided!");
        }

        if (option.equals("crossvalidation")) {
            // performs an n-fold cross-validation per run, delegating each fold to 'se'
            CrossValidationResultProducer cvrp = new CrossValidationResultProducer();
            option = Utils.getOption("folds", args);
            if (option.length() == 0) {
                throw new IllegalArgumentException("No folds provided!");
            }
            cvrp.setNumFolds(Integer.parseInt(option));
            cvrp.setSplitEvaluator(se);

            exp.setResultProducer(cvrp);
            exp.setPropertyPath(buildPropertyPath(CrossValidationResultProducer.class, se, sec));
        } else if (option.equals("randomsplit")) {
            RandomSplitResultProducer rsrp = new RandomSplitResultProducer();
            rsrp.setRandomizeData(true);
            option = Utils.getOption("percentage", args);
            if (option.length() == 0) {
                throw new IllegalArgumentException("No percentage provided!");
            }
            rsrp.setTrainPercent(Double.parseDouble(option));
            rsrp.setSplitEvaluator(se);

            exp.setResultProducer(rsrp);
            exp.setPropertyPath(buildPropertyPath(RandomSplitResultProducer.class, se, sec));
        } else {
            throw new IllegalArgumentException("Unknown split type '" + option + "'!");
        }

        // runs
        option = Utils.getOption("runs", args);
        if (option.length() == 0) {
            throw new IllegalArgumentException("No runs provided!");
        }
        exp.setRunLower(1);
        exp.setRunUpper(Integer.parseInt(option));

        // classifier: first token is the class name, the rest are its options
        option = Utils.getOption("classifier", args);
        if (option.length() == 0) {
            throw new IllegalArgumentException("No classifier provided!");
        }
        String[] options = Utils.splitOptions(option);
        if (options.length == 0) {
            // e.g. a whitespace-only value; avoids an ArrayIndexOutOfBoundsException below
            throw new IllegalArgumentException("No classifier provided!");
        }
        String classname = options[0];
        options[0] = "";
        Classifier c = (Classifier) Utils.forName(Classifier.class, classname, options);
        exp.setPropertyArray(new Classifier[] { c });

        // datasets: Utils.getOption consumes each "-t" it returns, so loop until none remain
        boolean data = false;
        DefaultListModel model = new DefaultListModel();
        do {
            option = Utils.getOption("t", args);
            if (option.length() > 0) {
                File file = new File(option);
                if (!file.exists()) {
                    throw new IllegalArgumentException("File '" + option + "' does not exist!");
                }
                data = true;
                model.addElement(file);
            }
        } while (option.length() > 0);
        if (!data) {
            throw new IllegalArgumentException("No data files provided!");
        }
        exp.setDatasets(model);

        // result
        option = Utils.getOption("result", args);
        if (option.length() == 0) {
            throw new IllegalArgumentException("No result file provided!");
        }
        InstancesResultListener irl = new InstancesResultListener();
        irl.setOutputFile(new File(option));
        exp.setResultListener(irl);

        // 2. run experiment
        System.out.println("Initializing...");
        exp.initialize();
        System.out.println("Running...");
        exp.runExperiment();
        System.out.println("Finishing...");
        exp.postProcess();

        // 3. calculate statistics and output them
        System.out.println("Evaluating...");
        PairedTTester tester = new PairedCorrectedTTester();
        Instances result;
        BufferedReader reader = new BufferedReader(new FileReader(irl.getOutputFile()));
        try {
            result = new Instances(reader);
        } finally {
            // always release the file handle, even if parsing the ARFF fails
            reader.close();
        }
        tester.setInstances(result);
        tester.setSortColumn(-1);
        tester.setRunColumn(result.attribute("Key_Run").index());
        // The fold column depends on the split type, not on the experiment type:
        // random splits produce no Key_Fold attribute, so only set it when present.
        if (result.attribute(KEY_FOLD) != null) {
            tester.setFoldColumn(result.attribute(KEY_FOLD).index());
        }
        tester.setDatasetKeyColumns(new Range("" + (result.attribute("Key_Dataset").index() + 1)));
        tester.setResultsetKeyColumns(new Range("" + (result.attribute("Key_Scheme").index() + 1) + ","
                + (result.attribute("Key_Scheme_options").index() + 1) + ","
                + (result.attribute("Key_Scheme_version_ID").index() + 1)));
        tester.setResultMatrix(new ResultMatrixPlainText());
        tester.setDisplayedResultsets(null);
        tester.setSignificanceLevel(0.05);
        tester.setShowStdDevs(true);
        // fill result matrix (but discarding the output)
        if (classification) {
            tester.multiResultsetFull(0, result.attribute("Percent_correct").index());
        } else {
            tester.multiResultsetFull(0, result.attribute("Correlation_coefficient").index());
        }

        printResults(tester.getResultMatrix(), classification);
    }

    /**
     * Builds the two-node property path (result producer -> split evaluator ->
     * classifier) that, together with setUsePropertyIterator(true), tells the
     * Experiment where to plug in the classifiers of its property array.
     *
     * @param producerClass  class of the result producer that owns the "splitEvaluator" property
     * @param se             the split evaluator set on the result producer
     * @param sec            the classifier currently held by the split evaluator
     * @return the property path
     * @throws IntrospectionException   if a property descriptor cannot be created
     */
    private static PropertyNode[] buildPropertyPath(Class<?> producerClass, SplitEvaluator se,
            Classifier sec) throws IntrospectionException {
        return new PropertyNode[] {
                new PropertyNode(se, new PropertyDescriptor("splitEvaluator", producerClass), producerClass),
                new PropertyNode(sec, new PropertyDescriptor("classifier", se.getClass()), se.getClass()) };
    }

    /**
     * Prints mean and standard deviation of the evaluated metric for every
     * dataset and classifier. In Weka's ResultMatrix, rows correspond to
     * datasets and columns to result sets, hence getMean(col, row).
     *
     * @param matrix           the result matrix filled by the tester
     * @param classification   true if the metric is percent correct, false for
     *                         the correlation coefficient (regression)
     */
    private static void printResults(ResultMatrix matrix, boolean classification) {
        String metric = classification ? "Perc. correct" : "Correlation coeff.";
        System.out.println("\nResult:");
        for (int row = 0; row < matrix.getRowCount(); row++) {
            System.out.println(matrix.getRowName(row));
            for (int col = 0; col < matrix.getColCount(); col++) {
                System.out.println("  " + matrix.getColName(col));
                System.out.println("    " + metric + ": " + matrix.getMean(col, row));
                System.out.println("    StdDev: " + matrix.getStdDev(col, row));
            }
        }
    }
}