org.fhcrc.cpl.toolbox.gui.chart.PanelWithScatterPlot.java Source code

Java tutorial

Introduction

Here is the source code for org.fhcrc.cpl.toolbox.gui.chart.PanelWithScatterPlot.java

Source

/*
 * Copyright (c) 2003-2012 Fred Hutchinson Cancer Research Center
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.fhcrc.cpl.toolbox.gui.chart;

import org.jfree.chart.axis.NumberAxis;
import org.jfree.chart.renderer.xy.StandardXYItemRenderer;
import org.jfree.chart.plot.XYPlot;
import org.jfree.data.xy.XYSeriesCollection;
import org.jfree.data.xy.XYSeries;
import org.jfree.data.xy.XYDataItem;
import org.fhcrc.cpl.toolbox.statistics.MatrixUtil;
import org.fhcrc.cpl.toolbox.statistics.RegressionUtilities;
import org.apache.log4j.Logger;

import java.awt.geom.Ellipse2D;
import java.awt.*;
import java.util.*;

/**
 * PanelWithChart implementation to make it easy to put out scatterplots.
 * If you want to do anything super-serious with the chart, use
 * getChart() and getRenderer()
 */
public class PanelWithScatterPlot extends PanelWithChart {
    protected static Logger _log = Logger.getLogger(PanelWithScatterPlot.class);

    /** Number of points used to draw each straight line or curve. */
    private static final int NUM_LINE_DOTS = 1000;

    /** Number of series indices that init() pre-assigns the default shape to. */
    private static final int NUM_PRESHAPED_SERIES = 10;

    /** Shared error message for mismatched x/y inputs (kept from the original implementation). */
    private static final String CARDINALITY_MESSAGE =
            "PanelWithScatterPlot: x and y values have different cardinality";

    protected XYSeriesCollection dataset;
    protected NumberAxis xAxis = new NumberAxis("X");
    protected NumberAxis yAxis = new NumberAxis("Y");

    protected StandardXYItemRenderer renderer = null;

    /** Default point shape: a small filled circle near the data point. */
    protected static Shape defaultShape = new Ellipse2D.Double(-1, -1, 3, 3);

    /** Default palette; each instance works on its own copy (see init()). */
    protected static Color[] SERIES_COLORS = new Color[] { Color.BLUE, Color.RED, Color.GREEN, Color.CYAN,
            Color.MAGENTA, Color.ORANGE, Color.PINK, Color.YELLOW };
    protected Color[] seriesColors;

    /** Creates an empty scatterplot. */
    public PanelWithScatterPlot() {
        super();
        init();
    }

    /**
     * Creates an empty scatterplot.
     *
     * @param showLegend stored into the inherited showLegend field before init() runs
     */
    public PanelWithScatterPlot(boolean showLegend) {
        super();
        this.showLegend = showLegend;
        init();
    }

    /**
     * Creates a scatterplot with one series and the given axis labels.
     *
     * @throws IllegalArgumentException if xValues and yValues differ in size
     */
    public PanelWithScatterPlot(java.util.List<? extends Number> xValues, java.util.List<? extends Number> yValues,
            String dataSetName, String xAxisLabel, String yAxisLabel) {
        this(xValues, yValues, dataSetName);
        setAxisLabels(xAxisLabel, yAxisLabel);
    }

    /**
     * Creates a scatterplot with one series; the panel is named after the series.
     *
     * @throws IllegalArgumentException if xValues and yValues differ in size
     */
    public PanelWithScatterPlot(java.util.List<? extends Number> xValues, java.util.List<? extends Number> yValues,
            String dataSetName) {
        this();
        addData(xValues, yValues, dataSetName);
        setName(dataSetName);
    }

    /**
     * Creates a scatterplot with one series; the panel is named after the series.
     *
     * @throws IllegalArgumentException if xValues and yValues differ in length
     */
    public PanelWithScatterPlot(float[] xValues, float[] yValues, String dataSetName) {
        this();
        addData(xValues, yValues, dataSetName);
        setName(dataSetName);
    }

    /**
     * Creates a scatterplot with one series; the panel is named after the series.
     *
     * @throws IllegalArgumentException if xValues and yValues differ in length
     */
    public PanelWithScatterPlot(double[] xValues, double[] yValues, String dataSetName) {
        this();
        setName(dataSetName);
        addData(xValues, yValues, dataSetName);
    }

    /**
     * Creates a scatterplot with one series drawn with the given shape and color.
     *
     * @throws IllegalArgumentException if xValues and yValues differ in length
     */
    public PanelWithScatterPlot(double[] xValues, double[] yValues, String dataSetName, Shape shape, Color color) {
        this();
        setName(dataSetName);
        addData(xValues, yValues, dataSetName, shape, color);
    }

    /**
     * Create a scatterplot with the natural logs of the values passed in.
     *
     * @param xValues x values; non-positive values produce NaN or -Infinity
     * @param yValues y values, same size as xValues
     * @param dataSetName name of the single series (and of the panel)
     * @return the new plot
     * @throws IllegalArgumentException if xValues and yValues differ in size
     */
    public static PanelWithScatterPlot createPlotForLogValues(java.util.List<? extends Number> xValues,
            java.util.List<? extends Number> yValues, String dataSetName) {
        checkSameCardinality(xValues.size(), yValues.size());
        java.util.List<Double> xValuesLog = new ArrayList<Double>(xValues.size());
        java.util.List<Double> yValuesLog = new ArrayList<Double>(yValues.size());

        for (Number x : xValues)
            xValuesLog.add(Math.log(x.doubleValue()));
        for (Number y : yValues)
            yValuesLog.add(Math.log(y.doubleValue()));
        return new PanelWithScatterPlot(xValuesLog, yValuesLog, dataSetName);
    }

    /**
     * Builds the dataset, a points-only renderer and the XYPlot, then hands the plot to PanelWithChart.
     */
    protected void init() {
        seriesColors = SERIES_COLORS.clone();
        dataset = new XYSeriesCollection();
        renderer = new StandardXYItemRenderer();
        renderer.setPlotLines(false);
        renderer.setBaseShapesVisible(true);
        renderer.setShapesFilled(true);
        // addData() sets a shape per series itself. Pre-assigning the first few indices also covers
        // series added via addSeries(), which does not assign a shape.
        for (int i = 0; i < NUM_PRESHAPED_SERIES; i++)
            renderer.setSeriesShape(i, defaultShape);

        XYPlot scatterPlot = new XYPlot(dataset, xAxis, yAxis, renderer);
        init(scatterPlot);
    }

    public void setAxisLabels(String xLabel, String yLabel) {
        xAxis.setLabel(xLabel);
        yAxis.setLabel(yLabel);
    }

    /**
     * Returns the palette color for a series index, cycling through the palette.
     * Falls back to SERIES_COLORS if the instance palette is null or empty.
     */
    protected Color getDefaultColorForSeries(int index) {
        Color[] palette = (seriesColors == null || seriesColors.length == 0) ? SERIES_COLORS : seriesColors;
        int n = palette.length;
        return palette[((index % n) + n) % n];
    }

    /**
     * Adds the points (xValues[k], yValues[k]) as numShades series, binned by zValues[k] and colored
     * from blue (lowest z) to red (highest z).
     *
     * The range of finite z values is split into equal-width, half-open bins; the last bin also includes
     * the maximum, so every plotted point lands in exactly one bin. Each bin becomes one series named after
     * its lower bound; empty bins still add an (empty) series. Points whose z value is NaN or infinite are
     * not plotted. If all finite z values are equal, a single series is added.
     *
     * @param xValues x coordinates
     * @param yValues y coordinates, same length as xValues
     * @param zValues values that pick each point's color bin, same length as xValues
     * @param numShades number of bins; nothing is added if this is not positive
     * @throws IllegalArgumentException if the three arrays differ in length
     */
    public void addDataRedBlueHeatmap(double[] xValues, double[] yValues, double[] zValues, int numShades) {
        if (xValues.length != yValues.length || xValues.length != zValues.length)
            throw new IllegalArgumentException(
                    "PanelWithScatterPlot: x, y and z values have different cardinality");
        if (numShades <= 0)
            return;

        double minZ = Double.POSITIVE_INFINITY;
        double maxZ = Double.NEGATIVE_INFINITY;
        for (double z : zValues) {
            if (!isFinite(z))
                continue;
            if (z < minZ)
                minZ = z;
            if (z > maxZ)
                maxZ = z;
        }
        if (minZ > maxZ)
            return; // no finite z values, nothing to bin

        double zRange = maxZ - minZ;
        // With a zero-width range every bin would have identical bounds and series names.
        int shades = zRange > 0 ? numShades : 1;

        java.util.List<java.util.List<Double>> groupX = new ArrayList<java.util.List<Double>>(shades);
        java.util.List<java.util.List<Double>> groupY = new ArrayList<java.util.List<Double>>(shades);
        for (int j = 0; j < shades; j++) {
            groupX.add(new ArrayList<Double>());
            groupY.add(new ArrayList<Double>());
        }

        for (int k = 0; k < zValues.length; k++) {
            if (!isFinite(zValues[k]))
                continue;
            int bin = heatmapBin(zValues[k], minZ, zRange, shades);
            groupX.get(bin).add(xValues[k]);
            groupY.get(bin).add(yValues[k]);
        }

        for (int j = 0; j < shades; j++) {
            double minZValThisGroup = minZ + j * zRange / shades;
            addData(toDoubleArray(groupX.get(j)), toDoubleArray(groupY.get(j)), "" + minZValThisGroup,
                    defaultShape, heatmapColor(j, shades));
        }
    }

    /**
     * Adds a series with the default shape and the next palette color.
     *
     * @throws IllegalArgumentException if xValues and yValues differ in length
     */
    public void addData(float[] xValues, float[] yValues, String dataSetName) {
        checkSameCardinality(xValues.length, yValues.length);
        double[] xValuesDouble = new double[xValues.length];
        double[] yValuesDouble = new double[yValues.length];

        for (int i = 0; i < xValues.length; i++) {
            xValuesDouble[i] = xValues[i];
            yValuesDouble[i] = yValues[i];
        }
        addData(xValuesDouble, yValuesDouble, dataSetName);
    }

    /**
     * Adds a series with the default shape and the next palette color.
     *
     * @throws IllegalArgumentException if xValues and yValues differ in size
     */
    public void addData(java.util.List<? extends Number> xValues, java.util.List<? extends Number> yValues,
            String dataSetName) {
        checkSameCardinality(xValues.size(), yValues.size());
        addData(toDoubleArray(xValues), toDoubleArray(yValues), dataSetName, defaultShape,
                getDefaultColorForSeries(dataset.getSeriesCount()));
    }

    /**
     * Adds a series with the default shape and the next palette color.
     *
     * @throws IllegalArgumentException if xValues and yValues differ in length
     */
    public void addData(double[] xValues, double[] yValues, String dataSetName) {
        addData(xValues, yValues, dataSetName, defaultShape, getDefaultColorForSeries(dataset.getSeriesCount()));
    }

    /**
     * Adds a series and styles it with the given shape and color.
     *
     * @param shape point shape for this series; if null, the renderer's existing shape for the index is kept
     * @param color paint for this series (passed straight to setSeriesColor)
     * @throws IllegalArgumentException if xValues and yValues differ in length
     */
    public void addData(double[] xValues, double[] yValues, String dataSetName, Shape shape, Color color) {
        checkSameCardinality(xValues.length, yValues.length);

        XYSeries series = new XYSeries(dataSetName);
        for (int i = 0; i < xValues.length; i++)
            series.add(xValues[i], yValues[i]);
        dataset.addSeries(series);

        // The series just added is the last one in the collection.
        int seriesIndex = dataset.getSeriesCount() - 1;
        if (shape != null)
            renderer.setSeriesShape(seriesIndex, shape);
        setSeriesColor(seriesIndex, color);
    }

    /**
     * Adds a pre-built series as-is. Unlike addData, no palette color or shape is assigned here.
     */
    public void addSeries(XYSeries series) {
        dataset.addSeries(series);
    }

    /** Adds a linear regression line for series 0; see {@link #addRegressionLine(int, boolean)}. */
    public double[] addRegressionLine() {
        return addRegressionLine(0, false);
    }

    /**
     * Adds a regression line for series 0; see {@link #addRegressionLine(int, boolean)}.
     *
     * @param robustRegression use RegressionUtilities.robustRegression instead of MatrixUtil.linearRegression
     */
    public double[] addRegressionLine(boolean robustRegression) {
        return addRegressionLine(0, robustRegression);
    }

    /**
     * Fits a line to one series and plots it as a new series spanning that series' minimum to maximum X value.
     *
     * @param seriesIndex index of the series to fit
     * @param robustRegression use RegressionUtilities.robustRegression instead of MatrixUtil.linearRegression
     * @return the regression coefficients (element 0 is plotted as the intercept, element 1 as the slope),
     *         or null if the series does not exist or has fewer than two points
     */
    public double[] addRegressionLine(int seriesIndex, boolean robustRegression) {
        if (dataset == null || seriesIndex < 0 || dataset.getSeriesCount() <= seriesIndex)
            return null;
        XYSeries series = dataset.getSeries(seriesIndex);
        int n = series.getItemCount();
        if (n < 2)
            return null; // a line cannot be fitted to fewer than two points
        double[] xValues = new double[n];
        double[] yValues = new double[n];

        // Infinite seeds (not Double.MIN_VALUE, which is positive) so all-negative data is handled.
        double maxX = Double.NEGATIVE_INFINITY;
        double minX = Double.POSITIVE_INFINITY;

        for (int i = 0; i < n; i++) {
            XYDataItem dataItem = series.getDataItem(i);
            xValues[i] = dataItem.getX().doubleValue();
            yValues[i] = dataItem.getY().doubleValue();

            if (xValues[i] > maxX)
                maxX = xValues[i];
            if (xValues[i] < minX)
                minX = xValues[i];
        }

        _log.debug("addRegressionLine, minX = " + minX + ", maxX = " + maxX);

        double[] regressionCoefficients = robustRegression
                ? RegressionUtilities.robustRegression(xValues, yValues)
                : MatrixUtil.linearRegression(xValues, yValues);
        _log.debug("addRegressionLine, coeffs = " + regressionCoefficients[0] + ", " + regressionCoefficients[1]);

        addLine(regressionCoefficients[1], regressionCoefficients[0], minX, maxX);

        return regressionCoefficients;
    }

    /**
     * Adds a polynomial curve, sampled at evenly spaced points from minX to maxX inclusive.
     *
     * @param coefficients polynomial coefficients in ascending power order
     *        (coefficients[j] multiplies x^j), i.e. degree coefficients.length - 1
     * @param minX first x value sampled
     * @param maxX last x value sampled
     */
    public void addLineOrCurve(double[] coefficients, double minX, double maxX) {
        double[] lineXvals = evenlySpaced(minX, maxX, NUM_LINE_DOTS);
        double[] lineYvals = new double[NUM_LINE_DOTS];

        for (int i = 0; i < NUM_LINE_DOTS; i++)
            lineYvals[i] = evaluatePolynomial(coefficients, lineXvals[i]);

        _log.debug("addLine, Y vals:  first = " + lineYvals[0] + ", last = " + lineYvals[lineYvals.length - 1]);

        addData(lineXvals, lineYvals, "Line or Curve");
    }

    /** Adds a polynomial curve across the current domain bounds of the dataset. */
    public void addLineOrCurve(double[] coefficients) {
        addLineOrCurve(coefficients, dataset.getDomainLowerBound(true), dataset.getDomainUpperBound(true));
    }

    /** Adds the line y = slope * x + intercept from minX to maxX inclusive. */
    public void addLine(double slope, double intercept, double minX, double maxX) {
        addLineOrCurve(new double[] { intercept, slope }, minX, maxX);
    }

    /**
     * Adds a horizontal line at yValue from minX to maxX inclusive.
     * NOTE(review): the series key is always "Line"; some JFreeChart versions reject duplicate keys in
     * XYSeriesCollection.addSeries — confirm against the version in use.
     */
    public void addHorizontalLine(double yValue, double minX, double maxX) {
        double[] lineXvals = evenlySpaced(minX, maxX, NUM_LINE_DOTS);
        double[] lineYvals = new double[NUM_LINE_DOTS];
        Arrays.fill(lineYvals, yValue);

        _log.debug("addLine, Y vals:  first = " + lineYvals[0] + ", last = " + lineYvals[lineYvals.length - 1]);

        addData(lineXvals, lineYvals, "Line");
    }

    /** Adds a vertical line at xValue from minY to maxY inclusive (series key "Line"). */
    public void addVerticalLine(double xValue, double minY, double maxY) {
        double[] lineXvals = new double[NUM_LINE_DOTS];
        Arrays.fill(lineXvals, xValue);
        double[] lineYvals = evenlySpaced(minY, maxY, NUM_LINE_DOTS);

        _log.debug("addLine, Y vals:  first = " + lineYvals[0] + ", last = " + lineYvals[lineYvals.length - 1]);

        addData(lineXvals, lineYvals, "Line");
    }

    /**
     * Draws every point of every series as a circle of the given diameter, centered on the data point.
     */
    public void setPointSize(int pointSize) {
        double offset = -pointSize / 2.0;
        renderer.setShape(new Ellipse2D.Double(offset, offset, pointSize, pointSize));
    }

    public void addCrosshairs(double initialDomainCrosshairValue, double initialRangeCrosshairValue) {
        XYPlot plot = (XYPlot) getPlot();
        plot.setDomainCrosshairVisible(true);
        plot.setRangeCrosshairVisible(true);
        plot.setDomainCrosshairValue(initialDomainCrosshairValue);
        plot.setRangeCrosshairValue(initialRangeCrosshairValue);
    }

    public void addCrosshairsAndListener(CrosshairChangeListener crosshairListener,
            double initialDomainCrosshairValue, double initialRangeCrosshairValue) {
        addCrosshairs(initialDomainCrosshairValue, initialRangeCrosshairValue);
        getChart().addProgressListener(crosshairListener);
    }

    public void setSeriesColor(int i, Color color) {
        renderer.setSeriesPaint(i, color);
    }

    XYSeries getDataSeries(int i) {
        return dataset.getSeries(i);
    }

    public StandardXYItemRenderer getRenderer() {
        return renderer;
    }

    public NumberAxis getXAxis() {
        return xAxis;
    }

    public NumberAxis getYAxis() {
        return yAxis;
    }

    /** Returns the live palette array (not a copy). */
    public Color[] getSeriesColors() {
        return seriesColors;
    }

    /** Replaces the palette used for series added afterwards; the array is stored, not copied. */
    public void setSeriesColors(Color[] seriesColors) {
        this.seriesColors = seriesColors;
    }

    /** Throws the class's standard cardinality error if the x and y counts differ. */
    private static void checkSameCardinality(int xCount, int yCount) {
        if (xCount != yCount)
            throw new IllegalArgumentException(
                    CARDINALITY_MESSAGE + " (" + xCount + " x values, " + yCount + " y values)");
    }

    /** Copies a list of numbers into a primitive double array. */
    private static double[] toDoubleArray(java.util.List<? extends Number> values) {
        double[] result = new double[values.size()];
        int i = 0;
        for (Number value : values)
            result[i++] = value.doubleValue();
        return result;
    }

    /** Returns count evenly spaced values from min to max, both endpoints included. */
    private static double[] evenlySpaced(double min, double max, int count) {
        double[] result = new double[count];
        if (count == 1) {
            result[0] = min;
            return result;
        }
        double step = (max - min) / (count - 1);
        for (int i = 0; i < count; i++)
            result[i] = min + i * step;
        result[count - 1] = max; // avoid rounding drift at the far endpoint
        return result;
    }

    /** Evaluates sum(coefficients[j] * x^j) using Horner's rule. */
    private static double evaluatePolynomial(double[] coefficients, double x) {
        double y = 0;
        for (int j = coefficients.length - 1; j >= 0; j--)
            y = y * x + coefficients[j];
        return y;
    }

    /** True unless the value is NaN or infinite. */
    private static boolean isFinite(double value) {
        return !Double.isNaN(value) && !Double.isInfinite(value);
    }

    /** Maps z to a bin index in [0, numBins - 1]; the maximum z falls into the last bin. */
    private static int heatmapBin(double z, double minZ, double zRange, int numBins) {
        if (zRange <= 0)
            return 0;
        int bin = (int) ((z - minZ) / zRange * numBins);
        return Math.max(0, Math.min(bin, numBins - 1));
    }

    /** Color for bin shade of numShades: pure blue for the first bin, pure red for the last. */
    private static Color heatmapColor(int shade, int numShades) {
        int red = numShades > 1 ? (int) Math.round(255.0 * shade / (numShades - 1)) : 0;
        return new Color(red, 10, 255 - red);
    }
}