ca.craigthomas.visualclassifier.nn.trainer.TestTrainer.java Source code

Java tutorial

Introduction

Here is the source code for `ca.craigthomas.visualclassifier.nn.trainer.TestTrainer.java`, the JUnit tests for the neural-network `Trainer` class.

Source

/*
 * Copyright (C) 2014 Craig Thomas
 * This project uses an MIT style license - see LICENSE for details.
 */
package ca.craigthomas.visualclassifier.nn.trainer;

import static org.junit.Assert.*;

import java.io.PrintStream;
import java.util.Arrays;
import java.util.List;
import java.util.Random;

import org.apache.commons.io.output.ByteArrayOutputStream;
import org.jblas.DoubleMatrix;
import org.junit.Assert;
import org.junit.Test;

import ca.craigthomas.visualclassifier.dataset.DataSet;
import ca.craigthomas.visualclassifier.nn.activation.HyperbolicTangent;
import ca.craigthomas.visualclassifier.nn.activation.IActivationFunction;
import ca.craigthomas.visualclassifier.nn.network.NeuralNetwork;
import ca.craigthomas.visualclassifier.nn.trainer.Trainer;

/**
 * Tests for {@link Trainer}.
 *
 * <p>The tests check that a small network can be trained to reproduce the NOT, AND and OR truth
 * tables. They also check that the builder options (activation function, cost recording,
 * heartbeat output) reach the trained network or the trainer's output.
 *
 * <p>Training data is drawn from a fixed-seed {@link Random}, so a failing run can be reproduced.
 * The trained network is evaluated on the complete truth table of the target function rather
 * than on a random sample, so every input combination is checked on every run.
 */
public class TestTrainer {

    /** Seed for the training-data generator; fixed so that failures are reproducible. */
    private static final long SEED = 20140101L;

    /** Number of random rows in each generated training set. */
    private static final int NUM_TRAINING_SAMPLES = 500;

    /** Maximum allowed absolute difference between a prediction and its expected value. */
    private static final double TOLERANCE = 0.15;

    /** Learning rate used by every trainer built in these tests. */
    private static final double LEARNING_RATE = 0.001;

    /** Boolean functions the trainer is asked to learn, each with its number of inputs. */
    private enum Gate {
        NOT(1) {
            @Override
            boolean apply(boolean[] in) {
                return !in[0];
            }
        },
        AND(2) {
            @Override
            boolean apply(boolean[] in) {
                return in[0] && in[1];
            }
        },
        OR(2) {
            @Override
            boolean apply(boolean[] in) {
                return in[0] || in[1];
            }
        };

        /** Number of boolean inputs (matrix columns) the gate takes. */
        final int arity;

        Gate(int arity) {
            this.arity = arity;
        }

        /** Evaluates the gate on {@code in}, which holds exactly {@link #arity} values. */
        abstract boolean apply(boolean[] in);
    }

    /** A matrix of input rows paired with a single-column matrix of expected outputs. */
    private static final class Samples {
        final DoubleMatrix inputs;
        final DoubleMatrix outputs;

        Samples(DoubleMatrix inputs, DoubleMatrix outputs) {
            this.inputs = inputs;
            this.outputs = outputs;
        }
    }

    /**
     * Builds a training set of uniformly random boolean inputs labelled by {@code gate}.
     *
     * @param gate the function that produces each row's expected output
     * @param numSamples number of rows to generate
     * @param random source of randomness for the inputs
     * @return inputs ({@code numSamples x arity}) and outputs ({@code numSamples x 1}), encoded
     *     as 1.0 for true and 0.0 for false
     */
    private static Samples randomSamples(Gate gate, int numSamples, Random random) {
        DoubleMatrix inputs = new DoubleMatrix(numSamples, gate.arity);
        DoubleMatrix outputs = new DoubleMatrix(numSamples, 1);
        boolean[] bits = new boolean[gate.arity];
        for (int row = 0; row < numSamples; row++) {
            for (int col = 0; col < gate.arity; col++) {
                bits[col] = random.nextBoolean();
            }
            fillRow(inputs, outputs, row, gate, bits);
        }
        return new Samples(inputs, outputs);
    }

    /**
     * Builds the full truth table of {@code gate}: one row per input combination.
     *
     * @param gate the function to tabulate
     * @return inputs ({@code 2^arity x arity}) and outputs ({@code 2^arity x 1}), encoded as 1.0
     *     for true and 0.0 for false
     */
    private static Samples truthTable(Gate gate) {
        int numRows = 1 << gate.arity;
        DoubleMatrix inputs = new DoubleMatrix(numRows, gate.arity);
        DoubleMatrix outputs = new DoubleMatrix(numRows, 1);
        boolean[] bits = new boolean[gate.arity];
        for (int row = 0; row < numRows; row++) {
            for (int col = 0; col < gate.arity; col++) {
                // Column col takes bit col of the row index, so the rows enumerate every combination.
                bits[col] = ((row >> col) & 1) == 1;
            }
            fillRow(inputs, outputs, row, gate, bits);
        }
        return new Samples(inputs, outputs);
    }

    /** Writes {@code bits} into row {@code row} of {@code inputs} and the gate's result into {@code outputs}. */
    private static void fillRow(DoubleMatrix inputs, DoubleMatrix outputs, int row, Gate gate, boolean[] bits) {
        for (int col = 0; col < bits.length; col++) {
            inputs.put(row, col, bits[col] ? 1.0 : 0.0);
        }
        outputs.put(row, 0, gate.apply(bits) ? 1.0 : 0.0);
    }

    /** Asserts that {@code network} reproduces every row of {@code gate}'s truth table within {@link #TOLERANCE}. */
    private static void assertLearned(NeuralNetwork network, Gate gate) {
        Samples table = truthTable(gate);
        DoubleMatrix predictions = network.predict(table.inputs);
        Assert.assertArrayEquals(table.outputs.toArray(), predictions.toArray(), TOLERANCE);
    }

    @Test
    public void testTrainerLearnNOTFunction() {
        List<Integer> layerSizes = Arrays.asList(1, 1);
        Samples training = randomSamples(Gate.NOT, NUM_TRAINING_SAMPLES, new Random(SEED));

        Trainer trainer = new Trainer.Builder(layerSizes, training.inputs, training.outputs)
                .learningRate(LEARNING_RATE).maxIterations(10000).heartBeat(0).lambda(1.0).build();
        trainer.train();

        assertLearned(trainer.getNeuralNetwork(), Gate.NOT);
    }

    /**
     * Same as {@link #testTrainerLearnNOTFunction()}, but the training data is passed to the
     * trainer through a {@link DataSet}. This test assumes that {@code DataSet} works correctly.
     */
    @Test
    public void testTrainerLearnNOTFunctionWithDataSet() {
        List<Integer> layerSizes = Arrays.asList(1, 1);
        Samples training = randomSamples(Gate.NOT, NUM_TRAINING_SAMPLES, new Random(SEED));
        DataSet dataSet = new DataSet(true, training.inputs, training.outputs);

        Trainer trainer = new Trainer.Builder(layerSizes, dataSet).learningRate(LEARNING_RATE)
                .maxIterations(10000).heartBeat(0).lambda(1.0).build();
        trainer.train();

        assertLearned(trainer.getNeuralNetwork(), Gate.NOT);
    }

    @Test
    public void testTrainerLearnANDFunction() {
        List<Integer> layerSizes = Arrays.asList(2, 1);
        Samples training = randomSamples(Gate.AND, NUM_TRAINING_SAMPLES, new Random(SEED));

        Trainer trainer = new Trainer.Builder(layerSizes, training.inputs, training.outputs)
                .learningRate(LEARNING_RATE).maxIterations(20000).heartBeat(0).lambda(1.0).build();
        trainer.train();

        assertLearned(trainer.getNeuralNetwork(), Gate.AND);
    }

    @Test
    public void testTrainerLearnORFunction() {
        List<Integer> layerSizes = Arrays.asList(2, 1);
        Samples training = randomSamples(Gate.OR, NUM_TRAINING_SAMPLES, new Random(SEED));

        Trainer trainer = new Trainer.Builder(layerSizes, training.inputs, training.outputs)
                .learningRate(LEARNING_RATE).maxIterations(15000).heartBeat(0).lambda(1.0).build();
        trainer.train();

        assertLearned(trainer.getNeuralNetwork(), Gate.OR);
    }

    @Test
    public void testActivationFunctionSentToNeuralNetwork() {
        List<Integer> layerSizes = Arrays.asList(2, 1);
        IActivationFunction activationFunction = new HyperbolicTangent();
        DoubleMatrix inputs = DoubleMatrix.ones(500, 2);
        DoubleMatrix outputs = DoubleMatrix.ones(500, 1);

        // maxIterations(0): only the wiring of the activation function is under test, not learning.
        Trainer trainer = new Trainer.Builder(layerSizes, inputs, outputs).learningRate(LEARNING_RATE)
                .maxIterations(0).heartBeat(0).activationFunction(activationFunction).build();
        trainer.train();

        assertEquals(activationFunction, trainer.getNeuralNetwork().getActivationFunction());
    }

    @Test
    public void testRecordCostsRecordsAllIterations() {
        List<Integer> layerSizes = Arrays.asList(2, 1);
        Samples training = randomSamples(Gate.OR, NUM_TRAINING_SAMPLES, new Random(SEED));

        Trainer trainer = new Trainer.Builder(layerSizes, training.inputs, training.outputs)
                .learningRate(LEARNING_RATE).maxIterations(200).heartBeat(0).recordCosts().build();
        trainer.train();

        List<Double> costs = trainer.getCosts();
        assertEquals(200, costs.size());
        for (int i = 0; i < costs.size(); i++) {
            assertTrue("cost at index " + i + " should be positive", costs.get(i).doubleValue() > 0.0);
        }
    }

    @Test
    public void testHeartbeatOutputToConsole() {
        List<Integer> layerSizes = Arrays.asList(2, 1);
        Samples training = randomSamples(Gate.OR, NUM_TRAINING_SAMPLES, new Random(SEED));

        // Capture everything printed while the trainer is built and run. The real System.out is
        // always restored afterwards, so later tests (and the test runner) print normally again.
        PrintStream originalOut = System.out;
        ByteArrayOutputStream stdOut = new ByteArrayOutputStream();
        PrintStream capture = new PrintStream(stdOut);
        System.setOut(capture);
        try {
            Trainer trainer = new Trainer.Builder(layerSizes, training.inputs, training.outputs)
                    .learningRate(LEARNING_RATE).maxIterations(200).heartBeat(1).recordCosts().build();
            trainer.train();
        } finally {
            capture.flush();
            System.setOut(originalOut);
        }

        // Accept both \n and \r\n, so the check does not depend on the platform line separator.
        String[] lines = stdOut.toString().split("\\r?\\n");
        assertEquals(200, lines.length);
        for (int i = 0; i < lines.length; i++) {
            assertTrue(lines[i].contains("Iteration: " + (i + 1)));
        }
    }
}