joinery.impl.Display.java Source code

Java tutorial

Introduction

Here is the source code for the class joinery.impl.Display (file Display.java).

Source

/*
 * Joinery -- Data frames for Java
 * Copyright (c) 2014, 2015 IBM Corp.
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

package joinery.impl;

import java.awt.Color;
import java.awt.Container;
import java.awt.GridLayout;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Calendar;
import java.util.Date;
import java.util.EnumSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;

import javax.swing.JFrame;
import javax.swing.JScrollPane;
import javax.swing.JTable;
import javax.swing.SwingUtilities;
import javax.swing.table.AbstractTableModel;

import org.apache.commons.math3.stat.regression.SimpleRegression;

import com.xeiam.xchart.Chart;
import com.xeiam.xchart.ChartBuilder;
import com.xeiam.xchart.Series;
import com.xeiam.xchart.SeriesLineStyle;
import com.xeiam.xchart.SeriesMarker;
import com.xeiam.xchart.StyleManager.ChartType;
import com.xeiam.xchart.XChartPanel;

import joinery.DataFrame;
import joinery.DataFrame.PlotType;

/**
 * Swing helpers for visualizing a {@link DataFrame}: charts rendered with
 * XChart ({@code draw} / {@code plot}) and a table view ({@code show}).
 */
public class Display {
    /** Pixel extent that is divided by the column count to size each chart of a grid plot. */
    private static final int GRID_EXTENT = 800;

    /**
     * Builds the chart panel(s) for a data frame and adds them to a container.
     *
     * <p>Only the frame returned by {@code df.numeric().fillna(0)} is plotted.
     * For {@code GRID} and {@code GRID_WITH_TREND} one chart is created per
     * column and the container is given a {@link GridLayout}; for every other
     * plot type a single chart holding one series per column is created.
     *
     * @param df the data frame to plot
     * @param container the container that receives the chart panels
     * @param type the kind of plot to draw
     * @param <C> the concrete container type, returned unchanged for chaining
     * @param <V> the element type of the data frame
     * @return {@code container}, after the panels have been added
     */
    public static <C extends Container, V> C draw(final DataFrame<V> df, final C container, final PlotType type) {
        final List<XChartPanel> panels = new LinkedList<>();
        final DataFrame<Number> numeric = df.numeric().fillna(0);
        // A frame without numeric columns would give rows == 0 and make the
        // division below throw ArithmeticException, so keep at least one row.
        final int rows = Math.max(1, (int) Math.ceil(Math.sqrt(numeric.size())));
        final int cols = numeric.size() / rows + 1;

        final List<Object> xdata = xAxisValues(df, type);
        // The pattern depends only on the x values, so compute it once
        // instead of once per chart.
        final String datePattern = dateFormat(xdata);

        if (EnumSet.of(PlotType.GRID, PlotType.GRID_WITH_TREND).contains(type)) {
            final int extent = GRID_EXTENT / cols;
            for (final Object col : numeric.columns()) {
                final String name = String.valueOf(col);
                final Chart chart = new ChartBuilder().chartType(chartType(type)).width(extent)
                        .height(extent).title(name).build();
                final Series series = chart.addSeries(name, xdata, numeric.col(col));
                if (type == PlotType.GRID_WITH_TREND) {
                    addTrend(chart, series, xdata);
                    // show the raw data as markers only; the trend series supplies the line
                    series.setLineStyle(SeriesLineStyle.NONE);
                }
                chart.getStyleManager().setLegendVisible(false);
                chart.getStyleManager().setDatePattern(datePattern);
                panels.add(new XChartPanel(chart));
            }
        } else {
            final Chart chart = new ChartBuilder().chartType(chartType(type)).build();

            chart.getStyleManager().setDatePattern(datePattern);
            switch (type) {
            case SCATTER:
            case SCATTER_WITH_TREND:
            case LINE_AND_POINTS:
                // these plot types keep the default marker size
                break;
            default:
                chart.getStyleManager().setMarkerSize(0);
                break;
            }

            for (final Object col : numeric.columns()) {
                final Series series = chart.addSeries(String.valueOf(col), xdata, numeric.col(col));
                if (type == PlotType.SCATTER_WITH_TREND) {
                    addTrend(chart, series, xdata);
                    // show the raw data as markers only; the trend series supplies the line
                    series.setLineStyle(SeriesLineStyle.NONE);
                }
            }

            panels.add(new XChartPanel(chart));
        }

        if (panels.size() > 1) {
            container.setLayout(new GridLayout(rows, cols));
        }
        for (final XChartPanel p : panels) {
            container.add(p);
        }

        return container;
    }

    /**
     * Shows a plot of the data frame in a new window.
     *
     * <p>The frame is built and made visible through
     * {@link SwingUtilities#invokeLater(Runnable)}, so this method returns
     * without waiting for the window to appear. Closing the window disposes it.
     *
     * @param df the data frame to plot
     * @param type the kind of plot to draw
     * @param <V> the element type of the data frame
     */
    public static <V> void plot(final DataFrame<V> df, final PlotType type) {
        SwingUtilities.invokeLater(new Runnable() {
            @Override
            public void run() {
                final JFrame frame = draw(df, new JFrame(title(df)), type);
                frame.setDefaultCloseOperation(JFrame.DISPOSE_ON_CLOSE);
                frame.pack();
                frame.setVisible(true);
            }
        });
    }

    /**
     * Shows the data frame as a scrollable table in a new window.
     *
     * <p>Column names and column types are captured when this method is
     * called; row count, column count and cell values are read from
     * {@code df} whenever the table asks for them. The window is created
     * through {@link SwingUtilities#invokeLater(Runnable)}.
     *
     * @param df the data frame to display
     * @param <V> the element type of the data frame
     */
    public static <V> void show(final DataFrame<V> df) {
        final List<Object> columns = new ArrayList<>(df.columns());
        final List<Class<?>> types = df.types();
        SwingUtilities.invokeLater(new Runnable() {
            @Override
            public void run() {
                final JFrame frame = new JFrame(title(df));
                final JTable table = new JTable(new AbstractTableModel() {
                    private static final long serialVersionUID = 1L;

                    @Override
                    public int getRowCount() {
                        return df.length();
                    }

                    @Override
                    public int getColumnCount() {
                        return df.size();
                    }

                    @Override
                    public Object getValueAt(final int row, final int col) {
                        return df.get(row, col);
                    }

                    @Override
                    public String getColumnName(final int col) {
                        return String.valueOf(columns.get(col));
                    }

                    @Override
                    public Class<?> getColumnClass(final int col) {
                        return types.get(col);
                    }
                });
                // keep natural column widths and let the scroll pane handle overflow
                table.setAutoResizeMode(JTable.AUTO_RESIZE_OFF);
                frame.setDefaultCloseOperation(JFrame.DISPOSE_ON_CLOSE);
                frame.add(new JScrollPane(table));
                frame.pack();
                frame.setVisible(true);
            }
        });
    }

    /**
     * Builds the x-axis values for a plot from the index of the data frame.
     *
     * <p>Exactly {@code df.length()} values are produced. Index entries that
     * are a {@link Number} or a {@link Date} are used as they are. Any other
     * entry becomes its string form for {@code BAR} plots and the row
     * position otherwise. If the index has fewer entries than the frame has
     * rows, the row position stands in for the missing entries.
     */
    private static <V> List<Object> xAxisValues(final DataFrame<V> df, final PlotType type) {
        final int length = df.length();
        final List<Object> xdata = new ArrayList<>(length);
        final Iterator<Object> it = df.index().iterator();
        for (int i = 0; i < length; i++) {
            final Object value = it.hasNext() ? it.next() : i;
            if (value instanceof Number || value instanceof Date) {
                xdata.add(value);
            } else if (PlotType.BAR.equals(type)) {
                xdata.add(String.valueOf(value));
            } else {
                xdata.add(i);
            }
        }
        return xdata;
    }

    /**
     * Maps a plot type to the XChart chart type used to render it.
     *
     * <p>The {@code *_WITH_TREND} types map to a line chart; callers add a
     * trend series and switch off the line of the data series themselves.
     * Any type without an explicit case also maps to a line chart.
     */
    private static ChartType chartType(final PlotType type) {
        switch (type) {
        case AREA:
            return ChartType.Area;
        case BAR:
            return ChartType.Bar;
        case GRID:
        case SCATTER:
            return ChartType.Scatter;
        case SCATTER_WITH_TREND:
        case GRID_WITH_TREND:
        case LINE:
        default:
            return ChartType.Line;
        }
    }

    /** Returns the window title: the frame's class name followed by its row and column counts. */
    private static final String title(final DataFrame<?> df) {
        return String.format("%s (%d rows x %d columns)", df.getClass().getCanonicalName(), df.length(), df.size());
    }

    /**
     * Chooses a date pattern for the x axis based on how the x values vary.
     *
     * <p>Every value is compared with the first one, field by field, from
     * month down to second (the year field itself is never compared). The
     * result covers the finest field that differs anywhere plus the field
     * directly above it, e.g. {@code "MMM-d"} when the day of month is the
     * finest difference or {@code "H:mm"} when it is the minute. If the list
     * is empty, does not start with a {@link Date}, contains any non-date
     * element, or no compared field differs, the year-only pattern
     * {@code "yyy"} is returned.
     *
     * @param xdata the x-axis values
     * @return a date pattern string, never empty
     */
    private static final String dateFormat(final List<Object> xdata) {
        final int[] fields = new int[] { Calendar.YEAR, Calendar.MONTH, Calendar.DAY_OF_MONTH, Calendar.HOUR_OF_DAY,
                Calendar.MINUTE, Calendar.SECOND };
        // each entry carries its leading separator; it is stripped from the first piece used
        final String[] formats = new String[] { " yyy", "-MMM", "-d", " H", ":mm", ":ss" };
        final String yearOnly = formats[0].substring(1);

        if (xdata.isEmpty() || !(xdata.get(0) instanceof Date)) {
            return yearOnly;
        }

        final Calendar c1 = Calendar.getInstance(), c2 = Calendar.getInstance();
        int first = 0, last = 0;
        c1.setTime(Date.class.cast(xdata.get(0)));
        // iterate over all x-axis values comparing dates
        for (int i = 1; i < xdata.size(); i++) {
            // early exit for non-date elements
            if (!(xdata.get(i) instanceof Date)) {
                return yearOnly;
            }
            c2.setTime(Date.class.cast(xdata.get(i)));

            // check which components differ, those are the fields to output
            for (int j = 1; j < fields.length; j++) {
                if (c1.get(fields[j]) != c2.get(fields[j])) {
                    first = Math.max(j - 1, first);
                    last = Math.max(j, last);
                }
            }
        }

        // construct a format string for the fields that differ
        final StringBuilder format = new StringBuilder();
        for (int i = first; i <= last && i < formats.length; i++) {
            format.append(format.length() == 0 ? formats[i].substring(1) : formats[i]);
        }

        return format.toString();
    }

    /**
     * Adds a linear trend line for a series to the chart.
     *
     * <p>The regression is fitted against the row position (0, 1, 2, ...) of
     * each y value, not against the x values themselves; the resulting line
     * is drawn between the first and the last x value. The trend series is
     * named {@code "<series name> (trend)"}, has no markers and uses the
     * marker color of {@code series} with a reduced alpha.
     *
     * <p>Nothing is added when there are fewer than two x values, because a
     * line cannot be fitted to (or drawn between) fewer than two points.
     *
     * @param chart the chart that receives the trend series
     * @param series the data series to fit
     * @param xdata the x-axis values the series was plotted against
     */
    private static void addTrend(final Chart chart, final Series series, final List<Object> xdata) {
        if (xdata.size() < 2) {
            return;
        }

        final SimpleRegression model = new SimpleRegression();
        final Iterator<? extends Number> y = series.getYData().iterator();
        for (int x = 0; y.hasNext(); x++) {
            model.addData(x, y.next().doubleValue());
        }

        final int lastIndex = xdata.size() - 1;
        final Color mc = series.getMarkerColor();
        final Color c = new Color(mc.getRed(), mc.getGreen(), mc.getBlue(), 0x60);
        final Series trend = chart.addSeries(series.getName() + " (trend)",
                Arrays.asList(xdata.get(0), xdata.get(lastIndex)),
                Arrays.asList(model.predict(0), model.predict(lastIndex)));
        trend.setLineColor(c);
        trend.setMarker(SeriesMarker.NONE);
    }
}