com.heatonresearch.aifh.examples.rbf.LearnIrisAnnealROC.java Source code

Java tutorial

Introduction

Here is the source code for com.heatonresearch.aifh.examples.rbf.LearnIrisAnnealROC.java

Source

/*
 * Artificial Intelligence for Humans
 * Volume 3: Deep Learning and Neural Networks
 * Java Version
 * http://www.aifh.org
 * http://www.jeffheaton.com
 *
 * Code repository:
 * https://github.com/jeffheaton/aifh
 *
 * Copyright 2014-2015 by Jeff Heaton
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 * For more information on Heaton Research copyrights, licenses
 * and trademarks visit:
 * http://www.heatonresearch.com/copyright
 */
package com.heatonresearch.aifh.examples.rbf;

import com.heatonresearch.aifh.general.data.BasicData;
import com.heatonresearch.aifh.learning.RBFNetwork;
import com.heatonresearch.aifh.learning.TrainAnneal;
import com.heatonresearch.aifh.learning.score.ScoreFunction;
import com.heatonresearch.aifh.learning.score.ScoreRegressionData;
import com.heatonresearch.aifh.normalize.DataSet;
import com.heatonresearch.aifh.randomize.MersenneTwisterGenerateRandom;
import org.jfree.chart.ChartFactory;
import org.jfree.chart.ChartPanel;
import org.jfree.chart.JFreeChart;
import org.jfree.chart.plot.PlotOrientation;
import org.jfree.data.xy.XYSeries;
import org.jfree.data.xy.XYSeriesCollection;

import javax.swing.*;
import java.awt.*;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.io.IOException;
import java.io.InputStream;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;

public class LearnIrisAnnealROC extends JFrame implements ActionListener {

    /** Button whose click runs more training iterations and redraws the chart. */
    private final JButton buttonAnneal;
    /** Button whose click re-randomizes the network before training and redrawing. */
    private final JButton buttonReset;
    /** Annealing trainer bound to {@link #network} and a score function over {@link #training}. */
    private final TrainAnneal trainer;
    /** The RBF network whose regression output is thresholded to produce the ROC points. */
    private final RBFNetwork network;
    /** Training data returned by {@code loadIrisData()}; also used to evaluate the ROC points. */
    private final List<BasicData> training;
    /** Data set handed to the JFreeChart line chart. */
    private final XYSeriesCollection dataset;
    /** The single plotted series (named "Threshold") holding the current ROC points. */
    private final XYSeries dataSeries1;
    /** Label used for the chart's y axis. */
    public static final String TPR = "True Positive Rate";
    /** Label used for the chart's x axis. */
    public static final String FPR = "False Positive Rate";
    /** Not referenced anywhere in this class. */
    public static final String THRESHOLD = "threshold";

    /**
     * One point of a ROC curve: {@code x} is the false positive rate and {@code y} the
     * true positive rate.
     *
     * <p>Points are ordered by {@code x} first and by {@code y} second. Ordering on both
     * coordinates matters because these points are collected in a sorted set: if only
     * {@code x} were compared, two distinct points with the same false positive rate
     * would be treated as duplicates and one of them would be silently discarded.
     * {@link #equals(Object)} and {@link #hashCode()} are consistent with this ordering.
     */
    class RocPair implements Comparable<RocPair> {
        private final double x;
        private final double y;

        /**
         * @param x the x coordinate (false positive rate)
         * @param y the y coordinate (true positive rate)
         */
        public RocPair(double x, double y) {
            this.x = x;
            this.y = y;
        }

        /**
         * @return the x coordinate
         */
        public double getX() {
            return this.x;
        }

        /**
         * @return the y coordinate
         */
        public double getY() {
            return this.y;
        }

        /**
         * Typed equality check, kept for existing callers.
         *
         * @param other the point to compare with
         * @return true if both coordinates match according to {@link #compareTo(RocPair)}
         */
        public boolean equal(RocPair other) {
            return compareTo(other) == 0;
        }

        /**
         * Two points are equal when both coordinates compare as equal, which keeps
         * equality consistent with the natural ordering.
         */
        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (!(obj instanceof RocPair)) {
                return false;
            }
            return compareTo((RocPair) obj) == 0;
        }

        @Override
        public int hashCode() {
            // Double's hash is based on the same bit-level notion of equality as Double.compare.
            return 31 * Double.valueOf(this.x).hashCode() + Double.valueOf(this.y).hashCode();
        }

        /**
         * Orders by x, breaking ties by y, so that points sharing an x value stay distinct.
         */
        @Override
        public int compareTo(RocPair other) {
            final int byX = Double.compare(this.x, other.x);
            if (byX != 0) {
                return byX;
            }
            return Double.compare(this.y, other.y);
        }

        @Override
        public String toString() {
            return "[" + this.x + "," + this.y + "]";
        }
    }

    /**
     * Builds the window and the model behind it.
     *
     * <p>The frame gets a line chart in the center (x axis labelled {@link #FPR}, y axis
     * labelled {@link #TPR}) backed by a single series, and a button row at the bottom.
     * This object is registered as the listener of both buttons. After the UI is laid out
     * the training data is loaded, a randomly reset RBF network is created, and an
     * annealing trainer is set up that scores the network against the training data.
     */
    public LearnIrisAnnealROC() {
        // Frame basics.
        setTitle("Iris Classify ROC");
        setDefaultCloseOperation(EXIT_ON_CLOSE);
        setSize(640, 480);
        final Container pane = getContentPane();
        pane.setLayout(new BorderLayout());

        // Chart model: one collection holding the one series that updateChart() refills.
        final XYSeries series = new XYSeries("Threshold");
        final XYSeriesCollection collection = new XYSeriesCollection();
        collection.addSeries(series);
        this.dataSeries1 = series;
        this.dataset = collection;

        // Chart view: vertical orientation, with legend and tooltips, without URLs.
        final JFreeChart chart = ChartFactory.createXYLineChart(
                "Line Chart Demo 6",
                FPR,
                TPR,
                collection,
                PlotOrientation.VERTICAL,
                true,
                true,
                false);
        final ChartPanel chartView = new ChartPanel(chart);
        chartView.setPreferredSize(new Dimension(560, 367));
        pane.add(chartView, BorderLayout.CENTER);

        // Button row; this frame handles the clicks itself.
        final JButton anneal = new JButton("Anneal");
        final JButton reset = new JButton("Reset");
        anneal.addActionListener(this);
        reset.addActionListener(this);
        final JPanel controls = new JPanel();
        controls.add(anneal);
        controls.add(reset);
        pane.add(controls, BorderLayout.SOUTH);
        this.buttonAnneal = anneal;
        this.buttonReset = reset;

        // Data, network and trainer.
        this.training = loadIrisData();

        // NOTE(review): arguments are presumably (inputs, RBF functions, outputs) - confirm
        // against RBFNetwork.
        final RBFNetwork rbf = new RBFNetwork(4, 4, 1);
        rbf.reset(new MersenneTwisterGenerateRandom());
        this.network = rbf;

        final ScoreFunction scorer = new ScoreRegressionData(this.training);
        this.trainer = new TrainAnneal(rbf, scorer);
    }

    /**
     * Loads the Iris data set from the classpath resource {@code /iris.csv} and converts it
     * into a binary "is it versicolor?" training set.
     *
     * <p>Columns 0 to 3 are range-normalized with the bounds 0 and 1, column 4 is encoded to
     * numeric class codes, and the ideal value of every extracted row is then rewritten to
     * 1 when the row is {@code Iris-versicolor} and 0 otherwise.
     *
     * <p>If the resource is missing, or an I/O error occurs, the JVM is terminated with
     * exit status 1.
     *
     * @return the training data with a single 0/1 ideal value per row
     * @throws IllegalStateException if the data contains no {@code Iris-versicolor} class
     */
    private List<BasicData> loadIrisData() {
        // try-with-resources closes the stream even if loading or normalizing fails part-way.
        try (InputStream istream = this.getClass().getResourceAsStream("/iris.csv")) {
            if (istream == null) {
                System.out.println("Cannot access data set, make sure the resources are available.");
                System.exit(1);
            }
            final DataSet ds = DataSet.load(istream);
            // The following ranges are set up for the Iris data set. To normalize a different
            // file, adjust the calls below to match its columns.
            ds.normalizeRange(0, 0, 1);
            ds.normalizeRange(1, 0, 1);
            ds.normalizeRange(2, 0, 1);
            ds.normalizeRange(3, 0, 1);
            final Map<String, Integer> species = ds.encodeNumeric(4);

            // Look the code up as an Integer first: unboxing a missing entry directly would
            // fail with an uninformative NullPointerException.
            final Integer versicolorCode = species.get("Iris-versicolor");
            if (versicolorCode == null) {
                throw new IllegalStateException(
                        "Class \"Iris-versicolor\" not found in /iris.csv; classes present: "
                                + species.keySet());
            }
            final int irisVersicolor = versicolorCode;

            // NOTE(review): presumably (inputStart, inputCount, idealStart, idealCount) - confirm
            // against DataSet.extractSupervised.
            final List<BasicData> trainingData = ds.extractSupervised(0, 4, 4, 1);

            // Collapse the class code into a binary target.
            for (BasicData aTrainingData : trainingData) {
                if (aTrainingData.getIdeal()[0] == irisVersicolor) {
                    aTrainingData.getIdeal()[0] = 1; // True, is versicolor
                } else {
                    aTrainingData.getIdeal()[0] = 0; // False, is not versicolor
                }
            }

            return trainingData;

        } catch (IOException ex) {
            ex.printStackTrace();
            // Non-zero status: this is a failure, not a normal shutdown.
            System.exit(1);
            return null; // unreachable, but required by the compiler
        }
    }

    /**
     * Evaluates the network on every training item at a single decision threshold and
     * returns the resulting false positive and true positive rates.
     *
     * <p>An item is predicted positive when the first regression output is strictly greater
     * than {@code thresh}. A predicted positive counts as a true positive when its ideal
     * value is above 0.5, otherwise as a false positive. A predicted negative counts as a
     * true negative when its ideal value is below 0.5, otherwise as a false negative.
     *
     * <p>A rate is NaN when its denominator is zero, i.e. when the training data holds no
     * actual negatives (false positive rate) or no actual positives (true positive rate).
     *
     * @param thresh the decision threshold applied to the network output
     * @return a two element array: index 0 is the false positive rate, index 1 the true
     *         positive rate
     */
    private double[] calculateTruePositiveFalsePositive(double thresh) {
        int truePos = 0;
        int falsePos = 0;
        int trueNeg = 0;
        int falseNeg = 0;

        for (final BasicData sample : this.training) {
            final double predicted = this.network.computeRegression(sample.getInput())[0];
            final double actual = sample.getIdeal()[0];
            final boolean flaggedPositive = predicted > thresh;

            if (flaggedPositive && actual > 0.5) {
                truePos++;
            } else if (flaggedPositive) {
                falsePos++;
            } else if (actual < 0.5) {
                trueNeg++;
            } else {
                falseNeg++;
            }
        }

        final double falsePositiveRate = ((double) falsePos) / (falsePos + trueNeg);
        final double truePositiveRate = ((double) truePos) / (truePos + falseNeg);
        return new double[]{falsePositiveRate, truePositiveRate};
    }

    /**
     * Runs five more training iterations, prints the trainer's last error to standard
     * output, and rebuilds the plotted series.
     *
     * <p>The points are gathered in a {@link TreeSet}, so they are plotted in the natural
     * ordering of {@code RocPair} and a point that compares equal to one already present
     * is not added. The set is seeded with (0,0), then receives one point for each
     * threshold 0.0, 0.1, ..., 1.0, and finally (1,1).
     */
    private void updateChart() {
        int remaining = 5;
        while (remaining > 0) {
            this.trainer.iteration();
            remaining--;
        }
        System.out.println(this.trainer.getLastError());

        final Set<RocPair> points = new TreeSet<>();
        points.add(new RocPair(0, 0));
        for (int step = 0; step <= 10; step++) {
            final double threshold = step / 10.0;
            final double[] rates = calculateTruePositiveFalsePositive(threshold);
            final double falsePositiveRate = rates[0];
            final double truePositiveRate = rates[1];
            points.add(new RocPair(falsePositiveRate, truePositiveRate));
        }
        points.add(new RocPair(1, 1));

        // Replace the old curve with the freshly computed one.
        this.dataSeries1.clear();
        for (final RocPair point : points) {
            this.dataSeries1.add(point.getX(), point.getY());
        }
    }

    /**
     * Program entry point: creates the window and shows it.
     *
     * <p>Swing components are not thread safe, so the frame is constructed and made visible
     * on the event dispatch thread rather than on the main thread.
     *
     * @param args command line arguments (not used)
     */
    public static void main(String[] args) {
        SwingUtilities.invokeLater(new Runnable() {
            @Override
            public void run() {
                JFrame prg = new LearnIrisAnnealROC();
                prg.setVisible(true);
            }
        });
    }

    /**
     * Handles clicks on the two buttons.
     *
     * <p>The anneal button trains further and redraws the chart. The reset button first
     * re-randomizes the network with a new random generator and then does the same train
     * and redraw step. Events from any other source are ignored.
     *
     * @param e the action event; its source identifies which button was pressed
     */
    @Override
    public void actionPerformed(ActionEvent e) {
        final Object origin = e.getSource();

        if (origin == this.buttonAnneal) {
            updateChart();
            return;
        }

        if (origin == this.buttonReset) {
            this.network.reset(new MersenneTwisterGenerateRandom());
            updateChart();
        }
    }
}