GeMSE.GS.Analysis.Stats.OneSamplePCAPanel.java Source code

Java tutorial

Introduction

Here is the source code for GeMSE.GS.Analysis.Stats.OneSamplePCAPanel.java.

Source

/** GenoMetric Space Explorer (GeMSE) Copyright (C) 2017 Vahid Jalili
 *  This program is free software; you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation; either version 3 of the License, or
 *  (at your option) any later version.
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU General Public License for more details.
 *  You should have received a copy of the GNU General Public License
 *  along with this program; if not, write to the Free Software Foundation,
 *  Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301  USA
 */
package GeMSE.GS.Analysis.Stats;

import java.awt.Color;
import java.awt.Dimension;
import java.awt.Font;
import java.awt.Shape;
import javax.swing.ButtonGroup;
import javax.swing.JOptionPane;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.EigenDecomposition;
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.stat.correlation.Covariance;
import org.jfree.chart.ChartFactory;
import org.jfree.chart.ChartPanel;
import org.jfree.chart.JFreeChart;
import org.jfree.chart.plot.XYPlot;
import org.jfree.chart.renderer.xy.XYItemRenderer;
import org.jfree.data.xy.XYDataset;
import org.jfree.data.xy.XYSeries;
import org.jfree.data.xy.XYSeriesCollection;
import org.jfree.util.ShapeUtilities;

/**
 * Swing panel that runs a principal component analysis (PCA) on one sample, or
 * on two samples combined into a single matrix, and shows a JFreeChart scatter
 * plot of the first two principal components.
 *
 * <p>The matrix handed to {@link Covariance} has one variable per column. The
 * "Tests are at" radio buttons choose whether the columns (default) or the rows
 * of the supplied sample(s) become those variables; changing the selection
 * re-runs the analysis on the most recently supplied data.
 *
 * <p>Each plotted point is one variable, positioned by its coefficients in the
 * first and second eigenvectors of the covariance matrix.
 *
 * @author Vahid Jalili
 */
public final class OneSamplePCAPanel extends javax.swing.JPanel {
    /**
     * Creates the panel, groups the two "Tests are at" radio buttons and selects
     * "columns". No analysis is run until one of the public {@code RunAnalysis}
     * overloads is called.
     */
    public OneSamplePCAPanel() {
        initComponents();
        Color.RGBtoHSB(214, 217, 223, bColor);

        ButtonGroup groupA = new ButtonGroup();
        groupA.add(testAreAtColsRB);
        groupA.add(testAreAtRowsRB);
        testAreAtColsRB.setSelected(true);

        _level = DEFAULT_LEVEL;
    }

    // HSB form of RGB (214, 217, 223), filled in by the constructor; not read elsewhere in this class.
    float[] bColor = new float[3];

    /** Default fraction of the total variance the retained principal components must explain. */
    public static final double DEFAULT_LEVEL = 0.9;

    /**
     * Kind of dispersion matrix. NOTE(review): not currently consulted; the
     * analysis in this class always uses the covariance matrix.
     */
    public enum CovarianceType {
        COVARIANCE, CORRELATION
    }

    // NOTE(review): not read anywhere in this class.
    private CovarianceType _covarianceType;

    private int _pcaIndices = -1;
    // Fraction of variance passed to numPCAIndices(); set to DEFAULT_LEVEL by the constructor.
    private double _level;

    // Matrix most recently analysed (one variable per column) and its RealMatrix wrapper.
    private double[][] _data;
    private RealMatrix _matrix;

    // First (or only) sample, as supplied by the caller.
    private double[][] _sampleAData;
    private String _sampleALabel;
    private String[] _sampleARowLabels;
    private String[] _sampleAColLabels;

    // Optional second sample; null for a one-sample analysis.
    private double[][] _sampleBData;
    private String _sampleBLabel;
    private String[] _sampleBRowLabels;
    private String[] _sampleBColLabels;

    // Labels computed by RunAnalysis() for the current layout; not currently read by Plot().
    private String[] _seriesLabels;

    private RealMatrix _covariance_old;
    private RealMatrix _covariance;

    // Retained eigenvectors, one per column (row j holds variable j's coefficients),
    // and the matching eigenvalues.
    private RealMatrix _principalComponents;
    private RealVector _variance;

    // When true, each of the two samples is flattened into a single row of the
    // analysis matrix. Never set to true within this class.
    private boolean _treatNodesSeparately = false;

    /**
     * This method is called from within the constructor to initialize the form.
     * WARNING: Do NOT modify this code. The content of this method is always
     * regenerated by the Form Editor.
     */
    @SuppressWarnings("unchecked")
    // <editor-fold defaultstate="collapsed" desc="Generated Code">//GEN-BEGIN:initComponents
    private void initComponents() {

        jPanel2 = new javax.swing.JPanel();
        testAreAtPanel = new javax.swing.JPanel();
        testAreAtL = new javax.swing.JLabel();
        testAreAtColsRB = new javax.swing.JRadioButton();
        testAreAtRowsRB = new javax.swing.JRadioButton();
        plotPanel = new javax.swing.JScrollPane();

        testAreAtPanel.setBorder(javax.swing.BorderFactory.createEtchedBorder());

        testAreAtL.setText("Tests are at: ");

        testAreAtColsRB.setText("columns");
        testAreAtColsRB.addActionListener(new java.awt.event.ActionListener() {
            public void actionPerformed(java.awt.event.ActionEvent evt) {
                testAreAtColsRBActionPerformed(evt);
            }
        });

        testAreAtRowsRB.setText("rows");
        testAreAtRowsRB.addActionListener(new java.awt.event.ActionListener() {
            public void actionPerformed(java.awt.event.ActionEvent evt) {
                testAreAtRowsRBActionPerformed(evt);
            }
        });

        javax.swing.GroupLayout testAreAtPanelLayout = new javax.swing.GroupLayout(testAreAtPanel);
        testAreAtPanel.setLayout(testAreAtPanelLayout);
        testAreAtPanelLayout.setHorizontalGroup(testAreAtPanelLayout
                .createParallelGroup(javax.swing.GroupLayout.Alignment.LEADING)
                .addGroup(testAreAtPanelLayout.createSequentialGroup().addGap(11, 11, 11).addComponent(testAreAtL)
                        .addPreferredGap(javax.swing.LayoutStyle.ComponentPlacement.RELATED)
                        .addComponent(testAreAtColsRB)
                        .addPreferredGap(javax.swing.LayoutStyle.ComponentPlacement.RELATED)
                        .addComponent(testAreAtRowsRB).addContainerGap(75, Short.MAX_VALUE)));
        testAreAtPanelLayout.setVerticalGroup(
                testAreAtPanelLayout.createParallelGroup(javax.swing.GroupLayout.Alignment.LEADING)
                        .addGroup(testAreAtPanelLayout.createSequentialGroup().addContainerGap()
                                .addGroup(testAreAtPanelLayout
                                        .createParallelGroup(javax.swing.GroupLayout.Alignment.BASELINE)
                                        .addComponent(testAreAtL).addComponent(testAreAtColsRB)
                                        .addComponent(testAreAtRowsRB))
                                .addContainerGap(javax.swing.GroupLayout.DEFAULT_SIZE, Short.MAX_VALUE)));

        javax.swing.GroupLayout jPanel2Layout = new javax.swing.GroupLayout(jPanel2);
        jPanel2.setLayout(jPanel2Layout);
        jPanel2Layout.setHorizontalGroup(jPanel2Layout
                .createParallelGroup(javax.swing.GroupLayout.Alignment.LEADING)
                .addGroup(jPanel2Layout.createSequentialGroup().addContainerGap()
                        .addComponent(testAreAtPanel, javax.swing.GroupLayout.PREFERRED_SIZE,
                                javax.swing.GroupLayout.DEFAULT_SIZE, javax.swing.GroupLayout.PREFERRED_SIZE)
                        .addContainerGap(351, Short.MAX_VALUE)));
        jPanel2Layout.setVerticalGroup(jPanel2Layout.createParallelGroup(javax.swing.GroupLayout.Alignment.LEADING)
                .addGroup(jPanel2Layout.createSequentialGroup().addContainerGap()
                        .addComponent(testAreAtPanel, javax.swing.GroupLayout.PREFERRED_SIZE,
                                javax.swing.GroupLayout.DEFAULT_SIZE, javax.swing.GroupLayout.PREFERRED_SIZE)
                        .addContainerGap(javax.swing.GroupLayout.DEFAULT_SIZE, Short.MAX_VALUE)));

        javax.swing.GroupLayout layout = new javax.swing.GroupLayout(this);
        this.setLayout(layout);
        layout.setHorizontalGroup(layout.createParallelGroup(javax.swing.GroupLayout.Alignment.LEADING)
                .addGroup(layout.createSequentialGroup().addContainerGap()
                        .addGroup(layout.createParallelGroup(javax.swing.GroupLayout.Alignment.LEADING)
                                .addComponent(jPanel2, javax.swing.GroupLayout.DEFAULT_SIZE,
                                        javax.swing.GroupLayout.DEFAULT_SIZE, Short.MAX_VALUE)
                                .addComponent(plotPanel))
                        .addContainerGap()));
        layout.setVerticalGroup(layout.createParallelGroup(javax.swing.GroupLayout.Alignment.LEADING)
                .addGroup(layout.createSequentialGroup()
                        .addComponent(jPanel2, javax.swing.GroupLayout.PREFERRED_SIZE,
                                javax.swing.GroupLayout.DEFAULT_SIZE, javax.swing.GroupLayout.PREFERRED_SIZE)
                        .addPreferredGap(javax.swing.LayoutStyle.ComponentPlacement.RELATED)
                        .addComponent(plotPanel, javax.swing.GroupLayout.DEFAULT_SIZE, 530, Short.MAX_VALUE)
                        .addContainerGap()));
    }// </editor-fold>//GEN-END:initComponents

    private void testAreAtColsRBActionPerformed(java.awt.event.ActionEvent evt)//GEN-FIRST:event_testAreAtColsRBActionPerformed
    {//GEN-HEADEREND:event_testAreAtColsRBActionPerformed
        RunAnalysis();
    }//GEN-LAST:event_testAreAtColsRBActionPerformed

    private void testAreAtRowsRBActionPerformed(java.awt.event.ActionEvent evt)//GEN-FIRST:event_testAreAtRowsRBActionPerformed
    {//GEN-HEADEREND:event_testAreAtRowsRBActionPerformed
        RunAnalysis();
    }//GEN-LAST:event_testAreAtRowsRBActionPerformed

    // Variables declaration - do not modify//GEN-BEGIN:variables
    private javax.swing.JPanel jPanel2;
    private javax.swing.JScrollPane plotPanel;
    private javax.swing.JRadioButton testAreAtColsRB;
    private javax.swing.JLabel testAreAtL;
    private javax.swing.JPanel testAreAtPanel;
    private javax.swing.JRadioButton testAreAtRowsRB;
    // End of variables declaration//GEN-END:variables

    /**
     * Runs the PCA on a single sample and redraws the plot.
     *
     * <p>Any second sample stored by an earlier call to the two-sample overload is
     * discarded, so the analysis is performed on {@code data} alone. Calls with
     * {@code null} data, or data without at least one row and one column, are
     * ignored.
     *
     * @param data      the sample matrix
     * @param label     label of the sample
     * @param rowLabels labels of the rows of {@code data}
     * @param colLabels labels of the columns of {@code data}
     */
    public void RunAnalysis(double[][] data, String label, String[] rowLabels, String[] colLabels) {
        if (data == null || data.length == 0 || data[0].length == 0) {
            return;
        }
        _sampleAData = data;
        _sampleALabel = label;
        _sampleARowLabels = rowLabels;
        _sampleAColLabels = colLabels;

        // Forget a second sample left over from a previous two-sample run;
        // otherwise RunAnalysis() would silently combine it with this one.
        _sampleBData = null;
        _sampleBLabel = null;
        _sampleBRowLabels = null;
        _sampleBColLabels = null;

        RunAnalysis();
    }

    /**
     * Runs the PCA on two samples combined into one matrix and redraws the plot.
     * The call is ignored if either matrix is {@code null} or lacks at least one
     * row and one column.
     *
     * @param sampleAData      first sample matrix
     * @param sampleALabel     label of the first sample
     * @param sampleARowLabels row labels of the first sample
     * @param sampleAColLabels column labels of the first sample
     * @param sampleBData      second sample matrix
     * @param sampleBLabel     label of the second sample
     * @param sampleBRowLabels row labels of the second sample
     * @param sampleBColLabels column labels of the second sample
     */
    public void RunAnalysis(double[][] sampleAData, String sampleALabel, String[] sampleARowLabels,
            String[] sampleAColLabels, double[][] sampleBData, String sampleBLabel, String[] sampleBRowLabels,
            String[] sampleBColLabels) {
        if (sampleAData == null || sampleAData.length == 0 || sampleAData[0].length == 0) {
            return;
        }
        if (sampleBData == null || sampleBData.length == 0 || sampleBData[0].length == 0) {
            return;
        }

        _sampleAData = sampleAData;
        _sampleALabel = sampleALabel;
        _sampleARowLabels = sampleARowLabels;
        _sampleAColLabels = sampleAColLabels;
        // Previously these were copied from the sample-A arguments, so the second
        // sample was silently replaced by the first.
        _sampleBData = sampleBData;
        _sampleBLabel = sampleBLabel;
        _sampleBRowLabels = sampleBRowLabels;
        _sampleBColLabels = sampleBColLabels;

        RunAnalysis();
    }

    /**
     * Builds the analysis matrix (one variable per column) from the stored
     * sample(s), computes the principal components and redraws the plot.
     *
     * <p>The radio buttons call this directly, so it returns without doing
     * anything until a sample has been supplied. If the resulting matrix has
     * fewer than two rows (the minimum {@link Covariance} accepts), an error
     * dialog is shown instead.
     */
    private void RunAnalysis() {
        if (_sampleAData == null) {
            return; // No data yet (e.g. a radio button was clicked first).
        }

        double[][] data;
        if (_sampleBData == null) { // One sample; no need to combine.
            if (testAreAtColsRB.isSelected()) {
                data = _sampleAData;
                _seriesLabels = _sampleAColLabels;
            } else {
                // Transpose so the original rows become the variables (columns).
                data = new double[_sampleAData[0].length][_sampleAData.length];
                for (int r = 0; r < _sampleAData.length; r++) {
                    for (int c = 0; c < _sampleAData[0].length; c++) {
                        data[c][r] = _sampleAData[r][c];
                    }
                }
                _seriesLabels = _sampleARowLabels;
            }
        } else { // Two samples, combine them in a single array
            if (_treatNodesSeparately) {
                int aCols = _sampleAData[0].length;
                int bCols = _sampleBData[0].length;
                data = new double[2][Math.max(_sampleAData.length * aCols, _sampleBData.length * bCols)];

                // Flatten each sample, row after row, into one row of `data`; the
                // destination offset must advance by a full row each time.
                for (int r = 0; r < _sampleAData.length; r++) {
                    System.arraycopy(_sampleAData[r], 0, data[0], r * aCols, aCols);
                }
                for (int r = 0; r < _sampleBData.length; r++) {
                    System.arraycopy(_sampleBData[r], 0, data[1], r * bCols, bCols);
                }
                _seriesLabels = new String[] { _sampleALabel, _sampleBLabel };
            } else {
                if (testAreAtColsRB.isSelected()) {
                    // Stack B's rows under A's rows.
                    data = new double[_sampleAData.length + _sampleBData.length][Math.max(_sampleAData[0].length,
                            _sampleBData[0].length)];

                    for (int r = 0; r < _sampleAData.length; r++) {
                        System.arraycopy(_sampleAData[r], 0, data[r], 0, _sampleAData[0].length);
                    }
                    for (int r = 0; r < _sampleBData.length; r++) {
                        System.arraycopy(_sampleBData[r], 0, data[_sampleAData.length + r], 0,
                                _sampleBData[0].length);
                    }

                    _seriesLabels = new String[_sampleAColLabels.length + _sampleBColLabels.length];
                    System.arraycopy(_sampleAColLabels, 0, _seriesLabels, 0, _sampleAColLabels.length);
                    System.arraycopy(_sampleBColLabels, 0, _seriesLabels, _sampleAColLabels.length,
                            _sampleBColLabels.length);
                } else {
                    // Transpose both samples and stack B's transposed rows under A's.
                    data = new double[_sampleAData[0].length + _sampleBData[0].length][Math.max(_sampleAData.length,
                            _sampleBData.length)];

                    for (int r = 0; r < _sampleAData.length; r++) {
                        for (int c = 0; c < _sampleAData[0].length; c++) {
                            data[c][r] = _sampleAData[r][c];
                        }
                    }
                    for (int r = 0; r < _sampleBData.length; r++) {
                        for (int c = 0; c < _sampleBData[0].length; c++) {
                            data[_sampleAData[0].length + c][r] = _sampleBData[r][c];
                        }
                    }

                    _seriesLabels = new String[_sampleARowLabels.length + _sampleBRowLabels.length];
                    System.arraycopy(_sampleARowLabels, 0, _seriesLabels, 0, _sampleARowLabels.length);
                    System.arraycopy(_sampleBRowLabels, 0, _seriesLabels, _sampleARowLabels.length,
                            _sampleBRowLabels.length);
                }
            }
        }

        // Covariance rejects matrices with fewer than two rows by throwing, so
        // report it here instead of letting the exception escape.
        if (data.length < 2) {
            JOptionPane.showMessageDialog(this,
                    "An error occured when computing principal components.     "
                            + "\nExpect at least two data entries, but found " + data.length + ". \n",
                    "Not enough data", JOptionPane.ERROR_MESSAGE);
            return;
        }

        _data = data;
        _matrix = new Array2DRowRealMatrix(data);
        computePrincipalComponents();
        Plot();
    }

    /**
     * Computes the covariance matrix of {@code _data}, its eigen-decomposition,
     * and stores the retained eigenvectors in {@code _principalComponents} (one
     * per column) and their eigenvalues in {@code _variance}.
     *
     * <p>The number retained is the smallest count explaining {@code _level} of
     * the variance, but never fewer than two when there are at least two
     * variables, because {@link #Plot()} always draws PC1 against PC2.
     */
    private void computePrincipalComponents() {
        RealMatrix realMatrix = MatrixUtils.createRealMatrix(_data);
        Covariance covariance = new Covariance(realMatrix);
        _covariance = covariance.getCovarianceMatrix();
        EigenDecomposition ed = new EigenDecomposition(_covariance);
        double[] realEigenvalues = ed.getRealEigenvalues();

        int eigenCount = realEigenvalues.length;
        int pcaCols = Math.max(numPCAIndices(realEigenvalues, _level), Math.min(2, eigenCount));
        _principalComponents = new Array2DRowRealMatrix(eigenCount, pcaCols);
        _variance = new ArrayRealVector(pcaCols);

        for (int i = 0; i < pcaCols; i++) {
            RealVector eigenVec = ed.getEigenvector(i);
            for (int j = 0; j < eigenCount; j++) {
                _principalComponents.setEntry(j, i, eigenVec.getEntry(j));
            }
            _variance.setEntry(i, realEigenvalues[i]);
        }
    }

    /**
     * Draws a scatter plot of the first two columns of
     * {@code _principalComponents} (one point per variable) into the scroll
     * pane. An error dialog is shown instead if fewer than two components were
     * retained.
     */
    private void Plot() {
        double[][] data = _principalComponents.getData();
        if (data[0].length < 2) {
            JOptionPane.showMessageDialog(this,
                    "An error occured when computing principal components.     "
                            + "\nRequire at least two principal components, but calculated "
                            + String.valueOf(data[0].length) + "\n",
                    "Not enough data", JOptionPane.ERROR_MESSAGE);
            return;
        }

        float[] yAxisColor = new float[3];
        Color.RGBtoHSB(255, 255, 255, yAxisColor);

        float[] hsbValues = new float[3];
        Color.RGBtoHSB(16, 23, 67, hsbValues);

        float[] pcColor = new float[3];
        Color.RGBtoHSB(255, 255, 0, pcColor);

        XYSeriesCollection dataset = new XYSeriesCollection();

        // x = coefficient in PC1, y = coefficient in PC2.
        XYSeries series = new XYSeries("PC");
        for (double[] d : data) {
            series.add(d[0], d[1]);
        }
        dataset.addSeries(series);

        JFreeChart chart = ChartFactory.createScatterPlot(null, "Principal component 1", "Principal component 2",
                (XYDataset) dataset);
        chart.setBackgroundPaint(Color.getHSBColor(hsbValues[0], hsbValues[1], hsbValues[2]));
        chart.removeLegend();

        XYPlot plot = (XYPlot) chart.getPlot();
        plot.setBackgroundPaint(Color.getHSBColor(hsbValues[0], hsbValues[1], hsbValues[2]));

        Font axisLabelFont = new Font("Dialog", Font.PLAIN, 14);
        Font axisTickLabelFont = new Font("Dialog", Font.PLAIN, 12);

        plot.setDomainGridlinePaint(Color.gray);
        plot.setRangeGridlinePaint(Color.gray);

        plot.getDomainAxis().setTickLabelPaint(Color.white);
        plot.getDomainAxis().setLabelPaint(Color.white);
        plot.getDomainAxis().setLabelFont(axisLabelFont);
        plot.getDomainAxis().setTickLabelFont(axisTickLabelFont);

        plot.getRangeAxis().setTickLabelPaint(Color.getHSBColor(yAxisColor[0], yAxisColor[1], yAxisColor[2]));
        plot.getRangeAxis().setLabelPaint(Color.getHSBColor(yAxisColor[0], yAxisColor[1], yAxisColor[2]));
        plot.getRangeAxis().setLabelFont(axisLabelFont);
        plot.getRangeAxis().setTickLabelFont(axisTickLabelFont);

        Shape shape = ShapeUtilities.createDiagonalCross(4, 0.5f);
        XYItemRenderer renderer = chart.getXYPlot().getRenderer();
        renderer.setSeriesShape(0, shape);
        renderer.setSeriesPaint(0, Color.getHSBColor(pcColor[0], pcColor[1], pcColor[2]));

        // Size the chart to 90% of the scroll pane's current size in each dimension.
        ChartPanel panel = new ChartPanel(chart);
        Dimension plotDim = plotPanel.getSize();
        plotDim.height -= (plotDim.height * 10) / 100;
        plotDim.width -= (plotDim.width * 10) / 100;
        panel.setPreferredSize(plotDim);
        plotPanel.setViewportView(panel);

        revalidate();
        repaint();
    }

    /**
     * Returns the smallest number of leading eigenvalues whose sum reaches at
     * least {@code level} of the total (with a tolerance of 1e-6).
     *
     * <p>The previous implementation counted down from the end and returned the
     * remaining index, i.e. the complement of this count. For example, it
     * returned 1 instead of 3 for {5, 3, 1, 1} at level 0.9.
     *
     * @param sortedEigenvalues eigenvalues in decreasing order. NOTE(review):
     *                          relies on EigenDecomposition sorting the
     *                          eigenvalues of a symmetric matrix this way;
     *                          confirm for the commons-math version in use.
     * @param level             target fraction of the total variance
     * @return the number of leading components to keep; always at least 1
     */
    private static int numPCAIndices(double[] sortedEigenvalues, double level) {
        int n = sortedEigenvalues.length;
        if (n <= 1) {
            return 1;
        }

        double sum = 0.0;
        for (double eigenvalue : sortedEigenvalues) {
            sum += eigenvalue;
        }

        double threshold = -Math.pow(10, -6.0);
        int count = 1;
        double explained = sortedEigenvalues[0] / sum;
        while ((explained - level) < threshold && count < n) {
            explained += sortedEigenvalues[count] / sum;
            count++;
        }
        return count;
    }
}