org.encog.workbench.dialogs.validate.ResultValidationChart.java Source code

Java tutorial

Introduction

Here is the source code for org.encog.workbench.dialogs.validate.ResultValidationChart.java.

Source

/*
 * Encog(tm) Workbench v3.0
 * http://www.heatonresearch.com/encog/
 * http://code.google.com/p/encog-java/
     
 * Copyright 2008-2011 Heaton Research, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *   
 * For more information on Heaton Research copyrights, licenses 
 * and trademarks visit:
 * http://www.heatonresearch.com/copyright
 */
package org.encog.workbench.dialogs.validate;

import java.awt.BorderLayout;
import java.awt.Color;
import java.util.ArrayList;
import java.util.Vector;

import javax.swing.JScrollPane;
import javax.swing.JTabbedPane;
import javax.swing.JTable;

import org.encog.ml.MLClassification;
import org.encog.ml.MLMethod;
import org.encog.ml.MLRegression;
import org.encog.ml.data.MLData;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLData;
import org.encog.workbench.WorkBenchError;
import org.encog.workbench.tabs.EncogCommonTab;
import org.jfree.chart.ChartFactory;
import org.jfree.chart.ChartPanel;
import org.jfree.chart.ChartUtilities;
import org.jfree.chart.JFreeChart;
import org.jfree.chart.plot.PlotOrientation;
import org.jfree.chart.plot.XYPlot;
import org.jfree.chart.renderer.xy.StandardXYItemRenderer;
import org.jfree.chart.renderer.xy.XYItemRenderer;
import org.jfree.data.xy.XYSeries;
import org.jfree.data.xy.XYSeriesCollection;

/**
 * Workbench tab that compares the output of a machine learning method against
 * the ideal values of a validation data set.
 *
 * <p>For every output column that has both an ideal and a computed value, one
 * line chart is shown plotting the ideal ("Validation", blue) and computed
 * ("Computation", red) values against the row index. A final "Data" tab lists
 * every input, ideal and computed value in a read-only table.
 */
public class ResultValidationChart extends EncogCommonTab {
    private static final long serialVersionUID = -2859655432840760344L;

    /** Table placeholder for a computed column that has no matching ideal value. */
    private static final String NOT_AVAILABLE = "N/A";

    private final JTabbedPane tabs = new JTabbedPane();
    /** Charts, one per charted output column; kept parallel to {@link #chartPanels}. */
    private final ArrayList<JFreeChart> charts = new ArrayList<JFreeChart>();
    /** Panels wrapping the entries of {@link #charts}, index for index. */
    private final ArrayList<ChartPanel> chartPanels = new ArrayList<ChartPanel>();

    public ResultValidationChart() {
        super(null);
        setLayout(new BorderLayout());
        this.add(tabs, BorderLayout.CENTER);
    }

    /**
     * Evaluates {@code method} on every row of {@code validationData} and
     * rebuilds the chart tabs and the data table from the results.
     *
     * <p>Tabs and charts produced by a previous call are discarded first, so the
     * view always reflects only the most recent data.
     *
     * @param validationData the rows to evaluate; a row's ideal data may be null
     * @param method the method to evaluate; must be an {@link MLRegression} or
     *        an {@link MLClassification}
     * @throws WorkBenchError if {@code method} is null or of an unsupported type
     */
    public void setData(MLDataSet validationData, MLMethod method) {
        // Without this reset, every call would append another "Data" tab and
        // leave orphaned charts behind while drawGraphs reused the old ones.
        tabs.removeAll();
        charts.clear();
        chartPanels.clear();

        ArrayList<XYSeries> validation = new ArrayList<XYSeries>();
        ArrayList<XYSeries> computation = new ArrayList<XYSeries>();

        Vector<Vector<String>> tableData = new Vector<Vector<String>>();
        Vector<String> tableHeaders = null;

        int key = 0;
        for (MLDataPair dataRow : validationData) {
            MLData input = dataRow.getInput();
            MLData validIdeal = dataRow.getIdeal();
            MLData computedIdeal = getCalculatedResult(dataRow, method);
            int inputCount = input.size();
            int idealCount = validIdeal == null ? 0 : validIdeal.size();
            int computedCount = computedIdeal.size();
            // Only columns with both an ideal and a computed value can be
            // compared; indexing beyond either size would throw.
            int comparableCount = Math.min(idealCount, computedCount);

            // Headers are derived from the shape of the first row.
            if (tableHeaders == null) {
                tableHeaders = buildHeaders(inputCount, computedCount);
            }

            // Lazily add one chart per newly seen comparable output column.
            for (int i = validation.size(); i < comparableCount; i++) {
                validation.add(new XYSeries("Validation"));
                computation.add(new XYSeries("Computation"));
                createChart();
            }

            Vector<String> tableDataRow = new Vector<String>();
            for (int i = 0; i < inputCount; i++) {
                tableDataRow.add(Double.toString(input.getData(i)));
            }

            for (int i = 0; i < computedCount; i++) {
                double c = computedIdeal.getData(i);

                if (i < comparableCount) {
                    double v = validIdeal.getData(i);
                    validation.get(i).add(key, v);
                    computation.get(i).add(key, c);
                    tableDataRow.add(Double.toString(v));
                } else {
                    tableDataRow.add(NOT_AVAILABLE);
                }

                tableDataRow.add(Double.toString(c));
            }

            tableData.add(tableDataRow);
            key++;
        }

        drawGraphs(validation, computation);
        drawTable(tableData, tableHeaders);
    }

    /**
     * Builds the table header row: one "Input i" column per input, followed by
     * an "Ideal i" / "Result i" pair per computed output.
     */
    private static Vector<String> buildHeaders(int inputCount, int outputCount) {
        Vector<String> headers = new Vector<String>();
        for (int i = 0; i < inputCount; i++) {
            headers.add("Input " + i);
        }
        for (int i = 0; i < outputCount; i++) {
            headers.add("Ideal " + i);
            headers.add("Result " + i);
        }
        return headers;
    }

    /**
     * Attaches each validation/computation series pair to its chart and adds
     * one tab per chart.
     *
     * @param validation ideal-value series, one per charted output column
     * @param computation computed-value series, parallel to {@code validation}
     */
    private void drawGraphs(ArrayList<XYSeries> validation, ArrayList<XYSeries> computation) {
        int size = validation.size();
        for (int i = 0; i < size; i++) {
            JFreeChart chart = charts.get(i);
            XYPlot plot = chart.getXYPlot();

            attachSeries(plot, 0, validation.get(i), Color.blue);
            attachSeries(plot, 1, computation.get(i), Color.red);

            ChartUtilities.applyCurrentTheme(chart);

            tabs.addTab("Ideal" + (i + 1), chartPanels.get(i));
        }
    }

    /**
     * Installs {@code series} as dataset {@code index} of {@code plot}, drawn by
     * its own renderer in the given colour.
     */
    private static void attachSeries(XYPlot plot, int index, XYSeries series, Color color) {
        plot.setDataset(index, new XYSeriesCollection(series));
        XYItemRenderer renderer = new StandardXYItemRenderer();
        renderer.setSeriesPaint(0, color);
        plot.setRenderer(index, renderer);
    }

    /**
     * Adds a read-only "Data" tab showing the given rows.
     *
     * @param tableData table rows, already formatted as strings
     * @param tableHeaders column names; null when there were no rows
     */
    private void drawTable(Vector<Vector<String>> tableData, Vector<String> tableHeaders) {
        Vector<String> headers = tableHeaders == null ? new Vector<String>() : tableHeaders;
        JTable table = new JTable(tableData, headers) {
            private static final long serialVersionUID = 8364655578079933961L;

            @Override
            public boolean isCellEditable(int rowIndex, int vColIndex) {
                return false;
            }
        };
        // Keep each column at its natural width and scroll horizontally
        // instead of squeezing many columns into the viewport.
        table.setAutoResizeMode(JTable.AUTO_RESIZE_OFF);
        tabs.addTab("Data", new JScrollPane(table));
    }

    /**
     * Runs {@code method} on the input of {@code data}.
     *
     * <p>Regression methods return their full output vector. Classification
     * methods return a single-element vector holding the class index.
     *
     * @param data the row whose input is evaluated
     * @param method the method to evaluate
     * @return the computed output
     * @throws WorkBenchError if {@code method} is null or of an unsupported type
     */
    private MLData getCalculatedResult(MLDataPair data, MLMethod method) {
        if (method instanceof MLRegression) {
            return ((MLRegression) method).compute(data.getInput());
        }
        if (method instanceof MLClassification) {
            MLData out = new BasicMLData(1);
            out.setData(0, ((MLClassification) method).classify(data.getInput()));
            return out;
        }
        // A null method would otherwise surface as an NPE from getClass().
        String typeName = method == null ? "null" : method.getClass().getSimpleName();
        throw new WorkBenchError("Unsupported Machine Learning Method:" + typeName);
    }

    /**
     * Creates an empty line chart and its panel and appends them to
     * {@link #charts} and {@link #chartPanels}. Datasets are attached later
     * by {@link #drawGraphs}.
     */
    private void createChart() {
        // X is the row index ("Increment"), Y is the ideal/computed value ("Result").
        JFreeChart chart = ChartFactory.createXYLineChart(null, "Increment", "Result", null,
                PlotOrientation.VERTICAL, true, true, false);

        ChartPanel chartPanel = new ChartPanel(chart);
        chartPanel.setPreferredSize(new java.awt.Dimension(600, 360));
        chartPanel.setDomainZoomable(true);
        chartPanel.setRangeZoomable(true);

        charts.add(chart);
        chartPanels.add(chartPanel);
    }

    @Override
    public String getName() {
        return "Validation";
    }
}