org.usfirst.frc.team2084.neuralnetwork.RobotHeadingTest.java Source code

Java tutorial

Introduction

Here is the source code for org.usfirst.frc.team2084.neuralnetwork.RobotHeadingTest.java

Source

/* 
 * Copyright (c) 2015 RobotsByTheC. All rights reserved.
 *
 * Open Source Software - may be modified and shared by FRC teams. The code must
 * be accompanied by the BSD license file in the root directory of the project.
 */
package org.usfirst.frc.team2084.neuralnetwork;

import java.awt.Container;
import java.awt.GridLayout;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.io.File;

import javax.swing.BoxLayout;
import javax.swing.JButton;
import javax.swing.JFrame;
import javax.swing.JPanel;
import javax.swing.SwingUtilities;

import org.jfree.chart.ChartFactory;
import org.jfree.chart.ChartPanel;
import org.jfree.chart.ChartPanel;
import org.jfree.chart.JFreeChart;
import org.jfree.chart.plot.CompassPlot;
import org.jfree.chart.plot.PlotOrientation;
import org.jfree.data.general.DefaultValueDataset;
import org.jfree.data.xy.XYSeries;
import org.jfree.data.xy.XYSeriesCollection;

/**
 * Interactive test that opens a Swing window with JFreeChart plots (a compass
 * plus heading-vs-time and error-vs-time line charts) and, while running,
 * steers a simulated {@link Robot} toward a randomly chosen heading using a
 * neural network loaded from {@code data/robot.txt}.
 * <p>
 * The simulation runs on its own worker thread; every update of a chart
 * dataset is posted to the Swing event dispatch thread.
 *
 * @author ben
 */
public class RobotHeadingTest implements Test {

    /**
     * Not referenced anywhere in this class. NOTE(review): presumably an error
     * threshold used by callers - confirm before relying on it.
     */
    public static final double MAX_ERROR = 5;

    /**
     * Factor applied to heading values before they are shown on the compass
     * plot. The actual and the desired heading must both use it so the two
     * needles are drawn in the same units.
     */
    private static final double COMPASS_SCALE = 180;

    /**
     * A robot "simulation", used to test the unsupervised learning ability.
     * It only tracks a rotation speed and the heading that results from it.
     */
    private static class Robot {

        public static final double MAX_ACCELERATION = 5;
        public static final double TIME_STEP = 0.05;
        public static final double MAX_SPEED = 50;

        public double speed = 0;
        public double heading = 0;

        /**
         * Advances the simulation by one {@link #TIME_STEP}: changes the speed
         * in proportion to {@code power}, clamps it to
         * {@code [-MAX_SPEED, MAX_SPEED]} and then integrates it into the
         * heading.
         *
         * @param power the drive signal; its sign selects the direction of
         *        the speed change
         */
        public void rotate(double power) {
            speed += power * TIME_STEP * MAX_ACCELERATION;
            if (speed > MAX_SPEED) {
                speed = MAX_SPEED;
            } else if (speed < -MAX_SPEED) {
                speed = -MAX_SPEED;
            }
            heading += speed * TIME_STEP * MAX_SPEED;
        }
    }

    /** Cleared by {@link #stop()} to ask the simulation loop to finish. */
    private volatile boolean running = false;
    /** The most recently started simulation thread, or {@code null}. */
    private volatile Thread thread;

    /**
     * Creates the chart datasets and charts and shows the window containing
     * them together with the Start/Stop buttons. The simulation itself only
     * begins once the Start button is pressed. Any exception raised while
     * setting up is printed to standard error.
     */
    @Override
    public void run() {
        try {
            final DefaultValueDataset headingData = new DefaultValueDataset(0);
            final DefaultValueDataset desiredHeadingData = new DefaultValueDataset(0);
            final CompassPlot headingPlot = new CompassPlot();
            headingPlot.addDataset(headingData);
            headingPlot.addDataset(desiredHeadingData);
            final JFreeChart headingChart = new JFreeChart("Heading", headingPlot);

            final XYSeries headingTimeSeries = new XYSeries("Heading");
            final XYSeriesCollection headingTimeData = new XYSeriesCollection();
            headingTimeData.addSeries(headingTimeSeries);
            final JFreeChart headingTimeChart = ChartFactory.createXYLineChart("Heading vs. Time", "Time",
                    "Heading", headingTimeData, PlotOrientation.VERTICAL, true, true, false);

            final XYSeries errorTimeSeries = new XYSeries("Error");
            final XYSeriesCollection errorTimeData = new XYSeriesCollection();
            errorTimeData.addSeries(errorTimeSeries);
            final JFreeChart errorTimeChart = ChartFactory.createXYLineChart("Error vs. Time", "Time", "Error",
                    errorTimeData, PlotOrientation.VERTICAL, true, true, false);

            final Runnable buildUi = () -> showFrame(headingChart, headingTimeChart, errorTimeChart,
                    headingData, desiredHeadingData, headingTimeSeries, errorTimeSeries);
            // invokeAndWait must not be called from the event dispatch thread
            // itself, so build the UI directly in that case.
            if (SwingUtilities.isEventDispatchThread()) {
                buildUi.run();
            } else {
                SwingUtilities.invokeAndWait(buildUi);
            }
        } catch (InterruptedException ex) {
            // Keep the interrupt visible to whoever called us.
            Thread.currentThread().interrupt();
            ex.printStackTrace();
        } catch (Exception ex) {
            ex.printStackTrace();
        }
    }

    /**
     * Builds and shows the window: the three chart panels in a grid and a
     * row with the Start and Stop buttons. Must be called on the event
     * dispatch thread.
     *
     * @param headingChart compass chart
     * @param headingTimeChart heading-vs-time chart
     * @param errorTimeChart error-vs-time chart
     * @param headingData dataset handed to the simulation for the current
     *        heading
     * @param desiredHeadingData dataset handed to the simulation for the
     *        desired heading
     * @param headingTimeSeries series handed to the simulation for the
     *        heading history
     * @param errorTimeSeries series handed to the simulation for the error
     *        history
     */
    private void showFrame(JFreeChart headingChart, JFreeChart headingTimeChart, JFreeChart errorTimeChart,
            DefaultValueDataset headingData, DefaultValueDataset desiredHeadingData,
            XYSeries headingTimeSeries, XYSeries errorTimeSeries) {
        final JFrame frame = new JFrame("Charts");
        frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
        final Container content = frame.getContentPane();
        content.setLayout(new BoxLayout(content, BoxLayout.PAGE_AXIS));

        final JPanel chartPanel = new JPanel();
        chartPanel.setLayout(new GridLayout(2, 2));
        content.add(chartPanel);

        chartPanel.add(new ChartPanel(headingChart));
        chartPanel.add(new ChartPanel(headingTimeChart));
        chartPanel.add(new ChartPanel(errorTimeChart));

        final JPanel buttonPanel = new JPanel();
        content.add(buttonPanel);

        final JButton startButton = new JButton("Start");
        final JButton stopButton = new JButton("Stop");

        startButton.addActionListener((ActionEvent event) -> {
            // Make sure a previous simulation has finished before starting
            // a new one.
            stop();
            startButton.setEnabled(false);
            stopButton.setEnabled(true);
            start(headingData, desiredHeadingData, headingTimeSeries, errorTimeSeries);
        });
        buttonPanel.add(startButton);

        stopButton.addActionListener((ActionEvent event) -> {
            stop();
            startButton.setEnabled(true);
            stopButton.setEnabled(false);
        });
        stopButton.setEnabled(false);
        buttonPanel.add(stopButton);

        frame.pack();
        frame.setVisible(true);
    }

    /**
     * Starts a new simulation on a fresh worker thread.
     *
     * @param headingData receives the current heading for the compass
     * @param desiredHeadingData receives the desired heading for the compass
     * @param headingTimeSeries receives (time, heading) samples
     * @param errorTimeSeries receives (time, error) samples
     */
    private void start(DefaultValueDataset headingData, DefaultValueDataset desiredHeadingData,
            XYSeries headingTimeSeries, XYSeries errorTimeSeries) {
        final Thread worker = new Thread(
                () -> simulate(headingData, desiredHeadingData, headingTimeSeries, errorTimeSeries),
                "robot-heading-test");
        thread = worker;
        // Raise the flag before the thread starts so the loop cannot observe
        // a stale "false" and exit after a single step.
        running = true;
        worker.start();
    }

    /**
     * Body of the worker thread: loads the network, picks a random desired
     * heading and keeps stepping the robot and training the network until
     * {@link #running} is cleared.
     * <p>
     * Chart datasets are never touched from this thread directly; every
     * update is posted with {@link SwingUtilities#invokeLater(Runnable)} so
     * the charts are only modified on the event dispatch thread.
     * {@code invokeLater} (rather than {@code invokeAndWait}) is required
     * because {@link #stop()} joins this thread from the event dispatch
     * thread.
     */
    private void simulate(DefaultValueDataset headingData, DefaultValueDataset desiredHeadingData,
            XYSeries headingTimeSeries, XYSeries errorTimeSeries) {
        try {
            final Robot robot = new Robot();
            // Load the robot network characteristics from a file
            final Data data = new Data(new File("data/robot.txt"));
            final Network network = data.getNetwork();

            double time = 0;
            double error = 1;
            // Pick a random desired heading
            final double desired = Math.random();
            System.out.println("Trying to rotate to heading: " + desired);
            SwingUtilities.invokeLater(() -> {
                // Same scale as the actual heading below, so both compass
                // needles are comparable.
                desiredHeadingData.setValue(desired * COMPASS_SCALE);
                headingTimeSeries.clear();
                errorTimeSeries.clear();
            });
            do {
                // Feed forward the heading error
                network.feedForward(desired - robot.heading);
                // Rotate the robot at the speed given by the output of last
                // hidden layer
                double speed = network.getLayer(network.getTotalLayers() - 2)[0].getOutputValue();
                robot.rotate(speed);
                // Apply that rotation speed for a certain amount of time
                Thread.sleep((long) (Robot.TIME_STEP * 1000));
                time += Robot.TIME_STEP;
                // Set the output neuron to the new heading error. This is
                // the kind of weird part, because it tricks the
                // back-propagation algorithm into minimizing the difference
                // between the real and desired headings.
                network.setLayerOutputs(network.getTotalLayers() - 1, desired - robot.heading);
                // Back-propagate to adjust weights to minimize error
                network.backPropagation(0);
                error = network.getRecentAverageError();

                // Snapshot the values for this step; the lambda runs later
                // on the event dispatch thread.
                final double sampleTime = time;
                final double sampleHeading = robot.heading;
                final double sampleError = error;
                SwingUtilities.invokeLater(() -> {
                    headingData.setValue(sampleHeading * COMPASS_SCALE);
                    headingTimeSeries.add(sampleTime, sampleHeading);
                    errorTimeSeries.add(sampleTime, sampleError);
                });
            } while (running);
            System.out.println("====================");
            System.out.println("Success!");
            System.out.println("Final error: " + error);
            System.out.println("====================");
        } catch (InterruptedException ex) {
            // Interrupted while sleeping: end the simulation quietly but
            // leave the interrupt flag set.
            Thread.currentThread().interrupt();
        } catch (Exception ex) {
            ex.printStackTrace();
        }
    }

    /**
     * Asks the running simulation (if any) to finish and waits for its thread
     * to terminate. If the waiting thread is interrupted, the wait is
     * abandoned and the interrupt flag is restored.
     */
    private void stop() {
        running = false;
        // Read the volatile field once so the null check and the join see
        // the same thread.
        final Thread worker = thread;
        if (worker != null) {
            try {
                worker.join();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        }
    }

    /**
     * Prints a prompt and blocks until a byte can be read from standard
     * input. Not called from within this class.
     */
    private static void pause() {
        System.out.println("Press any key to continue...");
        try {
            System.in.read();
        } catch (Exception ignored) {
            // Best effort: if standard input cannot be read, just carry on.
        }
    }
}