ca.craigthomas.visualclassifier.commandline.TrainCommand.java Source code

Java tutorial

Introduction

Here is the source code for ca.craigthomas.visualclassifier.commandline.TrainCommand.java

Source

/*
 * Copyright (C) 2014 Craig Thomas
 * This project uses an MIT style license - see LICENSE for details.
 */
package ca.craigthomas.visualclassifier.commandline;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.logging.Level;
import java.util.logging.Logger;

import javax.imageio.ImageIO;

import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.math3.stat.StatUtils;
import org.jblas.DoubleMatrix;

import ca.craigthomas.visualclassifier.dataset.DataSet;
import ca.craigthomas.visualclassifier.dataset.Prediction;
import ca.craigthomas.visualclassifier.imageprocessing.Image;
import ca.craigthomas.visualclassifier.nn.network.NeuralNetwork;
import ca.craigthomas.visualclassifier.nn.trainer.Trainer;

/**
 * The TrainCommand is used to train a neural network based upon a number of
 * positive and negative examples. 
 */
public class TrainCommand {
    // The logger for the class
    private final static Logger LOGGER = Logger.getLogger(Runner.class.getName());
    // The underlying data set
    private DataSet mDataSet;
    // The arguments passed to the command
    TrainArguments arguments;

    public TrainCommand(TrainArguments arguments) {
        this.arguments = arguments;
    }

    /**
     * Load the data from a CSV file.
     */
    public void loadFromCSV() {
        mDataSet = new DataSet(true);
        try {
            mDataSet.addFromCSVFile(arguments.csvFile);
            LOGGER.log(Level.INFO, "loaded " + mDataSet.getNumSamples() + " sample(s)");
        } catch (IOException e) {
            LOGGER.log(Level.SEVERE, e.getMessage());
            mDataSet = null;
        }
    }

    /**
     * Load data from a directory. Assumes that all samples are images.
     * The truth value indicates whether it is a positive or negative sample.
     * 
     * @param directory the directory to load images from
     * @param truth whether the samples are positive or negative
     */
    public void loadFromDirectory(File directory, double truth) {
        File[] files = directory.listFiles();
        for (File file : files) {
            String filename = file.getAbsolutePath();
            Image image = new Image(filename);

            if (image.getWidth() != arguments.requiredWidth || image.getHeight() != arguments.requiredHeight) {
                LOGGER.log(Level.WARNING,
                        "file " + filename + " not correct size, skipping (want " + arguments.requiredWidth + "x"
                                + arguments.requiredHeight + ", got " + image.getWidth() + "x" + image.getHeight()
                                + ")");
            } else {
                if (arguments.color) {
                    mDataSet.addSample(image.convertColorToMatrix(truth));
                } else {
                    mDataSet.addSample(image.convertGrayscaleToMatrix(truth));
                }
            }
        }
    }

    /**
     * Loads up the files from the specified directories.
     */
    public void loadFromDirectories() {
        File positiveDir = new File(arguments.positiveDir);
        File negativeDir = new File(arguments.negativeDir);

        if (!positiveDir.isDirectory()) {
            LOGGER.log(Level.SEVERE, "positives directory [" + arguments.positiveDir + "] is not a directory");
            return;
        }

        if (!negativeDir.isDirectory()) {
            LOGGER.log(Level.SEVERE, "negatives directory [" + arguments.negativeDir + "] is not a directory");
            return;
        }

        mDataSet = new DataSet(true);
        loadFromDirectory(positiveDir, 1.0);
        loadFromDirectory(negativeDir, 0.0);
        LOGGER.log(Level.INFO, "loaded " + mDataSet.getNumSamples() + " sample(s)");
    }

    public void saveImage(Image image, File path, String filename) {
        File saveFile = new File(path, filename);
        try {
            ImageIO.write(image.getBufferedImage(), "png", saveFile);
        } catch (IOException e) {
            LOGGER.log(Level.SEVERE, "could not save file [" + saveFile.getAbsolutePath() + "]");
        }
    }

    public void saveResults(NeuralNetwork bestModel, DataSet bestFold) {
        File directory = new File(arguments.saveDir);
        if (!directory.isDirectory()) {
            LOGGER.log(Level.SEVERE, "save directory [" + arguments.saveDir + "] is not a directory");
            return;
        }

        Prediction predictions = new Prediction(bestModel, arguments.predictionThreshold);
        predictions.predict(bestFold);
        DoubleMatrix falsePositives = predictions.getFalsePositiveSamples();
        DoubleMatrix falseNegatives = predictions.getFalseNegativeSamples();
        for (int i = 0; i < falsePositives.rows; i++) {
            Image image = new Image(falsePositives.getRow(i), arguments.requiredWidth, arguments.requiredHeight,
                    arguments.color);
            saveImage(image, directory, "fp" + (i + 1) + ".png");
        }
        for (int i = 0; i < falseNegatives.rows; i++) {
            Image image = new Image(falseNegatives.getRow(i), arguments.requiredWidth, arguments.requiredHeight,
                    arguments.color);
            saveImage(image, directory, "fn" + (i + 1) + ".png");
        }
    }

    public void execute() {
        NeuralNetwork bestModel = null;
        DataSet bestFold = null;
        double[] tp = new double[arguments.folds];
        double[] fp = new double[arguments.folds];
        double[] tn = new double[arguments.folds];
        double[] fn = new double[arguments.folds];
        double[] precision = new double[arguments.folds];
        double[] recall = new double[arguments.folds];
        double[] f1 = new double[arguments.folds];
        double bestF1 = 0;

        // Step 1: create the dataset
        if (!arguments.csvFile.isEmpty()) {
            loadFromCSV();
        } else {
            loadFromDirectories();
        }

        if (mDataSet == null) {
            LOGGER.log(Level.SEVERE, "no data set could be built, exiting");
            return;
        }

        // Step 2: Generate layer information
        List<Integer> layerSizes = new ArrayList<>();
        layerSizes.add(mDataSet.getNumColsSamples());
        if (arguments.layer1 != 0) {
            layerSizes.add(arguments.layer1);
        }
        if (arguments.layer2 != 0) {
            layerSizes.add(arguments.layer2);
        }
        layerSizes.add(arguments.outputLayer);

        // Step 3: generate the folds and train the model
        for (int fold = 0; fold < arguments.folds; fold++) {
            LOGGER.log(Level.INFO, "processing fold " + (fold + 1));
            LOGGER.log(Level.INFO, "randomizing dataset");
            mDataSet.randomize();
            LOGGER.log(Level.INFO, "generating training and testing sets");
            Pair<DataSet, DataSet> split = mDataSet.splitEqually(arguments.split);
            DataSet trainingData = split.getLeft();
            DataSet testingData = split.getRight();
            LOGGER.log(Level.INFO, "training neural network...");
            trainingData.randomize();
            Trainer trainer = new Trainer.Builder(layerSizes, trainingData).maxIterations(arguments.iterations)
                    .heartBeat(arguments.heartBeat).learningRate(arguments.learningRate).lambda(arguments.lambda)
                    .build();
            trainer.train();

            // Step 4: evaluate each model
            NeuralNetwork model = trainer.getNeuralNetwork();
            Prediction prediction = new Prediction(model, arguments.predictionThreshold);
            prediction.predict(testingData);
            System.out.println("True Positives " + prediction.getTruePositives());
            System.out.println("False Positives " + prediction.getFalsePositives());
            System.out.println("True Negatives " + prediction.getTrueNegatives());
            System.out.println("False Negatives " + prediction.getFalseNegatives());
            System.out.println("Precision " + prediction.getPrecision());
            System.out.println("Recall " + prediction.getRecall());
            System.out.println("F1 " + prediction.getF1());

            tp[fold] = prediction.getTruePositives();
            fp[fold] = prediction.getFalsePositives();
            tn[fold] = prediction.getTrueNegatives();
            fn[fold] = prediction.getFalseNegatives();
            precision[fold] = prediction.getPrecision();
            recall[fold] = prediction.getRecall();
            f1[fold] = prediction.getF1();
            if (f1[fold] > bestF1) {
                bestModel = model;
                bestFold = mDataSet.dup();
                bestF1 = f1[fold];
            }
        }

        // Step 6: save the best information to the specified directory
        if (!arguments.saveDir.isEmpty()) {
            saveResults(bestModel, bestFold);
        }

        // Step 5: compute the overall statistics
        System.out.println("Overall Statistics");
        System.out.println("True Positives " + StatUtils.mean(tp) + " (" + StatUtils.variance(tp) + ")");
        System.out.println("False Positives " + StatUtils.mean(fp) + " (" + StatUtils.variance(fp) + ")");
        System.out.println("True Negatives " + StatUtils.mean(tn) + " (" + StatUtils.variance(tn) + ")");
        System.out.println("False Negatives " + StatUtils.mean(fn) + " (" + StatUtils.variance(fn) + ")");
        System.out.println("Precision " + StatUtils.mean(precision) + " (" + StatUtils.variance(precision) + ")");
        System.out.println("Recall " + StatUtils.mean(recall) + " (" + StatUtils.variance(recall) + ")");
        System.out.println("F1 " + StatUtils.mean(f1) + " (" + StatUtils.variance(f1) + ")");
    }
}