com.github.r351574nc3.amex.assignment2.App.java Source code

Java tutorial

Introduction

Here is the source code for com.github.r351574nc3.amex.assignment2.App.java

Source

/*
 * The MIT License (MIT)
 * 
 * Copyright (c) 2015 Leo Przybylski
 * 
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 * 
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 * 
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */
package com.github.r351574nc3.amex.assignment2;

import static com.github.r351574nc3.logging.FormattedLogger.*;

import org.apache.commons.cli.OptionBuilder;
import org.apache.commons.cli.Options;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.CommandLineParser;
import org.apache.commons.cli.BasicParser;
import org.apache.commons.cli.ParseException;
import org.apache.commons.cli.MissingArgumentException;

import weka.core.converters.ConverterUtils.DataSink;
import weka.core.converters.ConverterUtils.DataSource;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Utils;
import weka.core.converters.ArffLoader;
import weka.classifiers.Evaluation;
import weka.classifiers.functions.LinearRegression;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;

/**
 * Command-line entry point that trains a {@link LinearRegression} model on an ARFF dataset, evaluates it,
 * and writes predictions for an unlabeled dataset to an output file.
 *
 * @author Leo Przybylski
 */
public class App {
    protected static final String UNABLE_TO_LOAD_DATA_MESSAGE = "Unable to load the data file";
    protected static final String DEFAULT_MPG_DATA_FILE = "auto-mpg.data";
    protected static final String USER_DIR_KEY = "user.dir";

    /** Dataset the classifier is built from. */
    protected Instances trained;
    /** Dataset the classifier is evaluated against. */
    protected Instances test;
    /** Model produced by {@link #train()}; {@code null} until training has succeeded. */
    protected LinearRegression classifier;

    public App() {
    }

    /**
     * @param instances dataset used by {@link #test()} for evaluation
     */
    public void setTest(final Instances instances) {
        this.test = instances;
    }

    /**
     * @return dataset used by {@link #test()} for evaluation, or {@code null} if not set
     */
    public Instances getTest() {
        return this.test;
    }

    /**
     * @param instances dataset used by {@link #train()} to build the model
     */
    public void setTrained(final Instances instances) {
        this.trained = instances;
    }

    /**
     * @return dataset used by {@link #train()} to build the model, or {@code null} if not set
     */
    public Instances getTrained() {
        return this.trained;
    }

    /**
     * @param linearRegression model used by {@link #test()} and {@link #predict(String, String)}
     */
    public void setClassifier(final LinearRegression linearRegression) {
        this.classifier = linearRegression;
    }

    /**
     * @return the trained model, or {@code null} if {@link #train()} has not succeeded yet
     */
    public LinearRegression getClassifier() {
        return this.classifier;
    }

    /**
     * Loads a dataset, uses its first attribute as the class attribute and drops all string attributes.
     *
     * @param mpgFile path of the dataset to read (any format Weka's {@code DataSource} understands)
     * @return {@link Instances} containing all records of the dataset; never {@code null}
     * @throws IOException if the dataset cannot be read or prepared; the original failure is the cause
     */
    public Instances load(final String mpgFile) throws Exception {
        try {
            final Instances retval = DataSource.read(mpgFile);
            // The first attribute is treated as the target (presumably MPG for the UCI auto-mpg data).
            retval.setClassIndex(0);
            retval.deleteStringAttributes();
            return retval;
        } catch (Exception e) {
            // Previously this logged and returned null, which surfaced later as an unrelated NPE.
            throw new IOException(String.format("%s '%s': %s", UNABLE_TO_LOAD_DATA_MESSAGE, mpgFile, e), e);
        }
    }

    /**
     * Trains the model using a {@link LinearRegression} classifier on {@link #getTrained()}.
     * The classifier is only stored once it has been built successfully.
     *
     * @throws IllegalStateException if no training data has been set
     * @throws Exception if Weka fails to build the classifier
     */
    public void train() throws Exception {
        if (getTrained() == null) {
            throw new IllegalStateException("Training data must be set before calling train");
        }
        final LinearRegression model = new LinearRegression();
        model.buildClassifier(getTrained());
        setClassifier(model);
    }

    /**
     * Tests/evaluates the trained model against {@link #getTest()} and logs a summary. This method assumes
     * that {@link #train()} was previously called to assign a {@link LinearRegression} classifier.
     *
     * @throws IllegalStateException if train wasn't called prior, or no test data has been set
     * @throws Exception if Weka fails while evaluating the model
     */
    public void test() throws Exception {
        if (getClassifier() == null) {
            throw new IllegalStateException("Make sure train was run prior to this method call");
        }
        if (getTest() == null) {
            throw new IllegalStateException("Test data must be set before calling test");
        }

        final Evaluation eval = new Evaluation(getTrained());
        eval.evaluateModel(getClassifier(), getTest());
        info("%s", eval.toSummaryString("Results\n\n", false));
        // LinearRegression predicts a numeric class, for which "percent correct" is not meaningful.
        info("Correlation coefficient: %s", eval.correlationCoefficient());
    }

    /**
     * Labels every record of an unlabeled dataset with the trained model and writes the result.
     *
     * @param inputName unlabeled dataset to load
     * @param outputName path to the file where results will be stored
     * @throws IllegalStateException if {@link #train()} has not succeeded yet
     * @throws FileNotFoundException if the output's parent directory does not exist
     * @throws Exception if loading, classifying or writing fails
     */
    public void predict(final String inputName, final String outputName) throws Exception {
        if (getClassifier() == null) {
            throw new IllegalStateException("Make sure train was run prior to this method call");
        }

        final Instances input = load(inputName);
        final Instances labeled = new Instances(input);

        for (int i = 0; i < input.numInstances(); i++) {
            final double clsLabel = getClassifier().classifyInstance(input.instance(i));
            labeled.instance(i).setClassValue(clsLabel);
        }

        // Relative paths resolve against the working directory, so this covers bare file names as well as
        // paths written with either separator; the parent is null only for a filesystem root.
        final File pwd = new File(outputName).getAbsoluteFile().getParentFile();
        if (pwd != null && pwd.isDirectory()) {
            DataSink.write(outputName, labeled);
        } else {
            throw new FileNotFoundException("Cannot write to " + outputName);
        }
    }

    public static void main(final String... args) {
        if (args.length < 1) {
            printUsage();
            System.exit(0);
            return;
        }

        final Options options = new Options();
        options.addOption(OptionBuilder.withArgName("test").hasArg(true).isRequired(true)
                .withDescription("Path for ARFF test data").create("t"));
        options.addOption(OptionBuilder.withArgName("output").hasArg(true).isRequired(true)
                .withDescription("Path for ARFF output").create("o"));
        options.addOption(OptionBuilder.withArgName("input").hasArg(true).isRequired(true)
                .withDescription("Path for ARFF input").create("i"));

        final CommandLineParser parser = new BasicParser();
        final CommandLine cmd;
        try {
            cmd = parser.parse(options, args);
        } catch (ParseException e) {
            System.out.println(e.getMessage());
            printUsage();
            System.exit(1);
            return; // keeps cmd definitely assigned for the compiler
        }

        final String outputName = cmd.getOptionValue("o");
        final String inputName = cmd.getOptionValue("i");
        final String testName = cmd.getOptionValue("t");

        final App app = new App();
        try {
            // The same dataset is used for both training and evaluation (see usage text).
            final Instances data = app.load(testName);
            app.setTrained(data);
            app.setTest(new Instances(data));
        } catch (Exception e) {
            error("There was an exception loading training and test datasets: %s", e.getMessage());
            System.exit(1);
            return;
        }

        try {
            app.train();
            app.test();
        } catch (Exception e) {
            error("There was an exception testing the model: %s", e.getMessage());
            System.exit(1);
            return;
        }

        try {
            app.predict(inputName, outputName);
        } catch (Exception e) {
            error("There was an exception predicting MPG: %s", e.getMessage());
            e.printStackTrace();
            System.exit(1);
            return;
        }

        System.exit(0);
    }

    public static void printUsage() {
        System.out.println(new StringBuilder().append("Usage:\n").append(
                "    com.github.r351574nc3.amex.assignment2.App -t <test data.arff> -i <input data.arff> -o <output arff> \n\n")
                .append("            -t <test data.arff>   : Test data used for training and testing the model.\n")
                .append("            -i <input data.arff>  : Unlabeled dataset\n")
                .append("            -o <output data.arff> : Path where results will be output\n").toString());
    }
}