stratego.neural.net.NeuralNetTest.java Source code

Java tutorial

Introduction

Below is the full source code for stratego.neural.net.NeuralNetTest.java.

Source

/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package stratego.neural.net;

import java.io.IOException;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.eval.Evaluation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.SplitTestAndTrain;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
import org.nd4j.linalg.api.ndarray.INDArray;

import java.util.ArrayList;
import java.util.List;
import javax.swing.*;
import java.io.File;

import org.jfree.chart.ChartFactory;
import org.jfree.chart.ChartPanel;
import org.jfree.chart.JFreeChart;
import org.jfree.chart.plot.PlotOrientation;
import org.jfree.data.xy.XYSeries;
import org.jfree.data.xy.XYSeriesCollection;

/**
 *
 * @author s146928
 */

/**
 * Trains a two-layer feed-forward network on Stratego game data, records the
 * training and test accuracy after every epoch, and plots both curves so that
 * overfitting can be inspected visually.
 */
public class NeuralNetTest {

    public static void main(String[] args) throws Exception {
        int numInput = 12;          // number of input neurons (features per example)
        int numHidden = 50;         // SUBJECT TO CHANGE: hidden-layer size
        int numOutput = 9;          // number of output classes
        int rngSeed = 123;          // RNG seed for reproducible runs
        int batchSize = 150;        // SUBJECT TO CHANGE: mini-batch size
        int numEpochs = 150;        // SUBJECT TO CHANGE: number of training epochs
        int iterations = 10;        // SUBJECT TO CHANGE: optimizer iterations per fit() call
        double learningRate = 0.05; // SUBJECT TO CHANGE: learning rate of the network

        int labelIndex = 12;        // CSV column index holding the class label
        String data = "src/Data/test_data_1.csv"; // SUBJECT TO CHANGE: location of the data

        DataSet allData = readCSVDataset(data, batchSize, labelIndex, numOutput);

        allData.shuffle();
        double ratio = 0.9; // fraction of the data used for training (90% train / 10% test)

        SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(ratio);

        DataSet trainingData = testAndTrain.getTrain();
        DataSet testData = testAndTrain.getTest();

        // Normalize to zero mean / unit variance. Statistics are collected from the
        // training set only and applied to both sets, so the test data never leaks
        // into the normalization parameters.
        DataNormalization normalizer = new NormalizerStandardize();
        normalizer.fit(trainingData);       // collect statistics; does not modify the data
        normalizer.transform(trainingData); // apply normalization to the training data
        normalizer.transform(testData);     // apply the same (training-set) statistics to the test data

        System.out.println("Build Model....");
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(rngSeed).iterations(iterations)
                .learningRate(learningRate).updater(Updater.NESTEROVS).momentum(0.9) // Nesterov momentum
                .regularization(true).l2(1e-4) // L2 regularization to counter overfitting
                .list().layer(0, new DenseLayer.Builder().nIn(numInput).nOut(numHidden).activation("relu")
                        .weightInit(WeightInit.XAVIER).build())
                .layer(1,
                        new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD).nIn(numHidden).nOut(numOutput)
                                .activation("softmax") // softmax turns the outputs into class probabilities
                                .weightInit(WeightInit.XAVIER).build())
                .pretrain(false).backprop(true)
                .build();

        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        // Listens to the score every iteration; raise the interval for large datasets.
        model.setListeners(new ScoreIterationListener(1));

        double[] trainAccuracies = new double[numEpochs];
        double[] testAccuracies = new double[numEpochs];

        System.out.println("Train model....");

        // BUG FIX: removed the extra model.fit(trainingData) that used to run before
        // this loop — it trained one unrecorded epoch and shifted the plotted curves.
        for (int i = 0; i < numEpochs; i++) {
            System.out.println("=====================");
            System.out.println("     Epoch " + i);
            System.out.println("=====================");
            model.fit(trainingData);

            INDArray outputTraining = model.output(trainingData.getFeatureMatrix());
            INDArray outputTest = model.output(testData.getFeatureMatrix());

            // BUG FIX: DL4J's Evaluation accumulates statistics across eval() calls,
            // so a single shared instance would report cumulative accuracy mixing
            // all previous epochs (and mixing train with test). Use a fresh
            // Evaluation per data set per epoch to get true per-epoch accuracy.
            Evaluation trainEval = new Evaluation(numOutput);
            trainEval.eval(trainingData.getLabels(), outputTraining);
            trainAccuracies[i] = trainEval.accuracy();

            Evaluation testEval = new Evaluation(numOutput);
            testEval.eval(testData.getLabels(), outputTest);
            testAccuracies[i] = testEval.accuracy();
        }

        // Wrap the two accuracy series so they can be plotted together.
        List<NamedDataSet> accuracyData = new ArrayList<>();
        accuracyData.add(new NamedDataSet("Training", trainAccuracies));
        accuracyData.add(new NamedDataSet("Test", testAccuracies));

        plotDataSet(accuracyData);

        // Final evaluation on both sets, again with fresh Evaluation instances so
        // the reported statistics are not polluted by the per-epoch evaluations.
        System.out.println("Evaluate model....");

        INDArray outputTraining = model.output(trainingData.getFeatureMatrix());
        INDArray outputTest = model.output(testData.getFeatureMatrix());

        System.out.println("Scores on training data");
        Evaluation finalTrainEval = new Evaluation(numOutput);
        finalTrainEval.eval(trainingData.getLabels(), outputTraining);
        System.out.println(finalTrainEval.stats());

        System.out.println("Scores on test data");
        Evaluation finalTestEval = new Evaluation(numOutput);
        finalTestEval.eval(testData.getLabels(), outputTest);
        System.out.println(finalTestEval.stats());
    }

    /**
     * Reads a CSV file into a single in-memory {@link DataSet} for testing and
     * training.
     *
     * NOTE(review): this returns only the first batch produced by the iterator,
     * i.e. at most {@code batchSize} rows. Verify that {@code batchSize} covers
     * the whole file, or merge all batches, before training on larger data.
     *
     * @param csvFileClasspath path to the CSV file on disk
     * @param batchSize        number of rows to read into the returned DataSet
     * @param labelIndex       column index of the class label
     * @param numClasses       number of distinct label classes
     * @return the first batch of the file as a DataSet
     * @throws IOException          if the file cannot be read
     * @throws InterruptedException if the record reader is interrupted
     */
    private static DataSet readCSVDataset(String csvFileClasspath, int batchSize, int labelIndex, int numClasses)
            throws IOException, InterruptedException {

        RecordReader rr = new CSVRecordReader();
        rr.initialize(new FileSplit(new File(csvFileClasspath)));
        DataSetIterator iterator = new RecordReaderDataSetIterator(rr, batchSize, labelIndex, numClasses);
        return iterator.next();
    }

    /**
     * Generates a scatter plot (one XY series per named data set, indexed by
     * epoch) and shows it in a Swing frame.
     *
     * @param arraySetList the named accuracy series to plot
     */
    private static void plotDataSet(List<NamedDataSet> arraySetList) {

        XYSeriesCollection plotData = new XYSeriesCollection();

        for (NamedDataSet ns : arraySetList) {
            XYSeries series = new XYSeries(ns.getName());
            double[] values = ns.getArray();
            for (int i = 0; i < values.length; i++) {
                series.add((double) i, values[i]);
            }
            plotData.addSeries(series);
        }

        String title = "Overfitting Data";
        String xAxisLabel = "Epochs";
        String yAxisLabel = "Accuracy";
        PlotOrientation orientation = PlotOrientation.VERTICAL;
        boolean legend = true; // show which series is training vs. test
        boolean tooltips = false;
        boolean urls = false;
        JFreeChart chart = ChartFactory.createScatterPlot(title, xAxisLabel, yAxisLabel, plotData, orientation,
                legend, tooltips, urls);

        JPanel panel = new ChartPanel(chart);

        JFrame f = new JFrame();
        f.add(panel);
        // NOTE(review): EXIT_ON_CLOSE terminates the whole JVM when the plot window
        // is closed; fine for this standalone test, revisit if reused elsewhere.
        f.setDefaultCloseOperation(WindowConstants.EXIT_ON_CLOSE);
        f.pack();
        f.setTitle("Overfitting data");

        f.setVisible(true);
    }
}