mrmc.chart.ROCCurvePlot.java Source code

Java tutorial

Introduction

Here is the source code for mrmc.chart.ROCCurvePlot.java

Source

/**
 * ROCCurvePlot.java
 * 
 * This software and documentation (the "Software") were developed at the Food and Drug Administration (FDA) 
 * by employees of the Federal Government in the course of their official duties. Pursuant to Title 17, Section 
 * 105 of the United States Code, this work is not subject to copyright protection and is in the public domain. 
 * Permission is hereby granted, free of charge, to any person obtaining a copy of the Software, to deal in the 
 * Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, 
 * distribute, sublicense, or sell copies of the Software or derivatives, and to permit persons to whom the 
 * Software is furnished to do so. FDA assumes no responsibility whatsoever for use by other parties of the 
 * Software, its source code, documentation or compiled executables, and makes no guarantees, expressed or 
 * implied, about its quality, reliability, or any other characteristic.   Further, use of this code in no way 
 * implies endorsement by the FDA or confers any advantage in regulatory decisions.  Although this software 
 * can be redistributed and/or modified freely, we ask that any derivative works bear some notice that they 
 * are derived from it, and any modified versions bear some notice that they have been modified.
 */

package mrmc.chart;

import java.awt.BorderLayout;
import java.awt.event.ItemEvent;
import java.awt.event.ItemListener;
import java.util.ArrayList;
import java.util.TreeMap;
import java.util.TreeSet;

import javax.swing.JCheckBox;
import javax.swing.JFrame;
import javax.swing.JPanel;

import org.jfree.chart.ChartFactory;
import org.jfree.chart.ChartPanel;
import org.jfree.chart.JFreeChart;
import org.jfree.chart.axis.NumberAxis;
import org.jfree.chart.axis.NumberTickUnit;
import org.jfree.chart.plot.PlotOrientation;
import org.jfree.chart.plot.XYPlot;
import org.jfree.chart.renderer.xy.XYLineAndShapeRenderer;
import org.jfree.data.xy.XYSeries;
import org.jfree.data.xy.XYSeriesCollection;

/**
 * Frame showing the ROC curve of every reader of one modality together with
 * summary curves. The reader curves are averaged in the direction of
 * sensitivity (vertically), specificity (horizontally), and sensitivity +
 * specificity (diagonally). A fourth series, the "pooled average", is created
 * empty and is meant to be filled by the caller through
 * {@link #addData(TreeSet, String)}.
 * <p>
 * A row of checkboxes below the chart toggles each reader curve and each
 * average individually, plus two master checkboxes that toggle all readers or
 * all averages at once. Reader curves start hidden; averages start visible.
 * 
 * @author Rohan Pathare
 */
public class ROCCurvePlot extends JFrame {

    private static final long serialVersionUID = 1L;

    /*
     * Series keys. Each string doubles as the label of the checkbox that
     * toggles the series, because the listeners look a series up by the text
     * of the checkbox that fired.
     */
    private static final String VERTICAL_AVERAGE = "Vertical Average";
    private static final String HORIZONTAL_AVERAGE = "Horizontal Average";
    private static final String DIAGONAL_AVERAGE = "Diagonal Average";
    private static final String POOLED_AVERAGE = "Pooled Average";

    /** Averaged curves are sampled every 1 / SAMPLES_PER_UNIT along their axis. */
    private static final int SAMPLES_PER_UNIT = 100;

    /** Line width used for the four average curves. */
    private static final float AVERAGE_STROKE_WIDTH = 3f;

    private final XYLineAndShapeRenderer renderer = new XYLineAndShapeRenderer();
    private final XYSeriesCollection seriesCollection = new XYSeriesCollection();
    /** One interpolated curve per reader, in reader-key order. */
    private final ArrayList<InterpolatedLine> allLines = new ArrayList<InterpolatedLine>();
    /** Series keys of the reader curves, in reader-key order. */
    private final ArrayList<String> readerSeriesTitles = new ArrayList<String>();
    /** Checkboxes of the reader curves, parallel to {@link #readerSeriesTitles}. */
    private final ArrayList<JCheckBox> readerSeriesBoxes = new ArrayList<JCheckBox>();
    private final JCheckBox vert;
    private final JCheckBox horiz;
    private final JCheckBox diag;
    private final JCheckBox pooled;

    /**
     * Sole constructor. Builds the dataset, the chart (both axes fixed to
     * [0, 1] with ticks every 0.1) and the checkbox panel. The frame is
     * neither packed nor made visible here; that is left to the caller.
     * 
     * @param title Title of the chart and of the frame
     * @param xaxis x-axis label
     * @param yaxis y-axis label
     * @param treeMap Mapping of readers to a set of points defining an ROC curve
     * @throws NullPointerException if {@code treeMap} is {@code null}
     */
    public ROCCurvePlot(final String title, String xaxis, String yaxis, TreeMap<String, TreeSet<XYPair>> treeMap) {
        super(title);
        if (treeMap == null) {
            throw new NullPointerException("treeMap must not be null");
        }

        createDataset(treeMap);
        final JFreeChart chart = ChartFactory.createScatterPlot(title, xaxis, yaxis, seriesCollection,
                PlotOrientation.VERTICAL, true, true, false);
        XYPlot xyplot = (XYPlot) chart.getPlot();
        xyplot.setDomainCrosshairVisible(true);
        xyplot.setRangeCrosshairVisible(true);
        configureUnitAxis((NumberAxis) xyplot.getDomainAxis());
        configureUnitAxis((NumberAxis) xyplot.getRangeAxis());
        // Replace the scatter-plot default so that lines can be drawn too.
        xyplot.setRenderer(renderer);
        ChartPanel chartPanel = new ChartPanel(chart);

        JPanel readerSelect = new JPanel(new WrapLayout());
        // The listener is stateless, so one instance serves every checkbox.
        ItemListener seriesToggle = new SeriesSelectListener();

        // Reader curves start unchecked and hidden.
        for (String reader : treeMap.keySet()) {
            JCheckBox readerBox = newCheckBox(reader, false, seriesToggle);
            hideSeries(reader);
            readerSeriesBoxes.add(readerBox);
            readerSelect.add(readerBox);
        }

        // The three computed averages are thick lines without point markers;
        // the pooled average gets the thick line but keeps its markers.
        java.awt.BasicStroke averageStroke = new java.awt.BasicStroke(AVERAGE_STROKE_WIDTH);
        String[] lineOnlyAverages = { VERTICAL_AVERAGE, HORIZONTAL_AVERAGE, DIAGONAL_AVERAGE };
        for (String average : lineOnlyAverages) {
            int index = seriesCollection.getSeriesIndex(average);
            renderer.setSeriesShapesVisible(index, false);
            renderer.setSeriesStroke(index, averageStroke);
        }
        renderer.setSeriesStroke(seriesCollection.getSeriesIndex(POOLED_AVERAGE), averageStroke);

        vert = newCheckBox(VERTICAL_AVERAGE, true, seriesToggle);
        readerSelect.add(vert);
        horiz = newCheckBox(HORIZONTAL_AVERAGE, true, seriesToggle);
        readerSelect.add(horiz);
        diag = newCheckBox(DIAGONAL_AVERAGE, true, seriesToggle);
        readerSelect.add(diag);
        pooled = newCheckBox(POOLED_AVERAGE, true, seriesToggle);
        readerSelect.add(pooled);

        // Master switches.
        readerSelect.add(newCheckBox("Show Readers", false, new ReadersSelectListener()));
        readerSelect.add(newCheckBox("Show Averages", true, new AverageSelectListener()));

        chartPanel.setPreferredSize(new java.awt.Dimension(700, 700));
        this.add(chartPanel);
        this.add(readerSelect, BorderLayout.PAGE_END);
    }

    /**
     * Fixes an axis to the range [0, 1] with a tick every 0.1.
     * 
     * @param axis Axis to configure
     */
    private static void configureUnitAxis(NumberAxis axis) {
        axis.setRange(0.00, 1.00);
        axis.setTickUnit(new NumberTickUnit(0.1));
    }

    /**
     * Creates a checkbox. The initial state is applied before the listener is
     * attached, so the listener is not invoked for it.
     * 
     * @param label Text of the checkbox
     * @param selected Initial state
     * @param listener Listener notified of later state changes
     * @return The new checkbox
     */
    private static JCheckBox newCheckBox(String label, boolean selected, ItemListener listener) {
        JCheckBox box = new JCheckBox(label);
        box.setSelected(selected);
        box.addItemListener(listener);
        return box;
    }

    /**
     * Tells whether a series is one of the three computed averages, which are
     * drawn as plain lines without point markers.
     * 
     * @param series Series key
     * @return {@code true} for the vertical, horizontal and diagonal averages
     */
    private static boolean isLineOnlyAverage(String series) {
        return VERTICAL_AVERAGE.equals(series) || HORIZONTAL_AVERAGE.equals(series)
                || DIAGONAL_AVERAGE.equals(series);
    }

    /**
     * Converts the mapping of readers to curve points into a collection of
     * separate XY data: one series per reader (in key order), followed by the
     * vertical, horizontal and diagonal averages and an empty pooled-average
     * series.
     * 
     * @param treeMap Mapping of readers to points defining a curve
     */
    private void createDataset(TreeMap<String, TreeSet<XYPair>> treeMap) {
        for (String reader : treeMap.keySet()) {
            TreeSet<XYPair> points = treeMap.get(reader);

            // autoSort is off: points keep the iteration order of the set.
            XYSeries series = new XYSeries(reader, false);
            for (XYPair point : points) {
                series.add(point.x, point.y);
            }
            seriesCollection.addSeries(series);
            readerSeriesTitles.add(reader);
            allLines.add(new InterpolatedLine(points));
        }

        // The averages need allLines to be complete, so they come last.
        seriesCollection.addSeries(generateVerticalROC());
        seriesCollection.addSeries(generateHorizontalROC());
        seriesCollection.addSeries(generateDiagonalROC(treeMap));
        seriesCollection.addSeries(new XYSeries(POOLED_AVERAGE, false));
    }

    /**
     * Adds a set of XY points to one of the existing series of the chart.
     * The series is looked up once and listeners are notified once, after all
     * points have been appended.
     * 
     * @param newData Set of XY coordinates
     * @param type Key of the series to append to; must name a series already
     *            in the collection (a reader name or one of the averages)
     */
    public void addData(TreeSet<XYPair> newData, String type) {
        if (newData.isEmpty()) {
            return;
        }
        XYSeries target = seriesCollection.getSeries(type);
        for (XYPair point : newData) {
            target.add(point.x, point.y, false);
        }
        target.fireSeriesChanged();
    }

    /**
     * Creates an ROC curve that averages together the scores for all readers in
     * the diagonal direction. Every reader curve is rotated 45 degrees
     * clockwise about the origin, the rotated curves are averaged vertically,
     * and each averaged point is rotated back.
     * 
     * @param treeMap Mapping of readers to points defining a curve
     * @return Series containing the ROC curve points, always ending in (1, 1)
     */
    private XYSeries generateDiagonalROC(TreeMap<String, TreeSet<XYPair>> treeMap) {
        XYSeries diagAvg = new XYSeries(DIAGONAL_AVERAGE, false);
        final double sqrt2 = Math.sqrt(2.0);

        // Rotate every reader's points 45 degrees clockwise about the origin
        // and interpolate the rotated points.
        ArrayList<InterpolatedLine> rotatedLines = new ArrayList<InterpolatedLine>();
        for (TreeSet<XYPair> points : treeMap.values()) {
            TreeSet<XYPair> rotated = new TreeSet<XYPair>();
            for (XYPair point : points) {
                double x2 = (point.x + point.y) / sqrt2;
                double y2 = (point.y - point.x) / sqrt2;
                rotated.add(new XYPair(x2, y2));
            }
            rotatedLines.add(new InterpolatedLine(rotated));
        }

        // Without readers there is nothing to average (0 / 0 would be NaN).
        if (!rotatedLines.isEmpty()) {
            final double cos45 = Math.cos(Math.toRadians(45));
            final double sin45 = Math.sin(Math.toRadians(45));

            // Sample along the rotated x-axis from 0 up to sqrt(2), the length
            // of the unit square's diagonal. Integer steps keep the samples on
            // the 1 / SAMPLES_PER_UNIT grid; accumulating a double step drifts.
            int lastStep = (int) Math.floor(sqrt2 * SAMPLES_PER_UNIT);
            for (int step = 0; step <= lastStep; step++) {
                double x1 = step / (double) SAMPLES_PER_UNIT;
                double sum = 0;
                for (InterpolatedLine line : rotatedLines) {
                    sum += line.getYatDiag(x1);
                }
                double y1 = sum / rotatedLines.size();

                // rotate the averaged point back 45 degrees counterclockwise
                diagAvg.add((x1 * cos45) - (y1 * sin45), (x1 * sin45) + (y1 * cos45));
            }
        }

        // The last sample stops short of sqrt(2); close the curve at the corner.
        diagAvg.add(1, 1);
        return diagAvg;
    }

    /**
     * Creates an ROC curve that averages together the scores for all readers in
     * the horizontal direction: for each y in 0, 0.01, ..., 1 the x values of
     * all reader curves at that y are averaged.
     * 
     * @return Series containing the ROC curve points; empty if there are no
     *         readers
     */
    private XYSeries generateHorizontalROC() {
        XYSeries horizAvg = new XYSeries(HORIZONTAL_AVERAGE, false);
        if (allLines.isEmpty()) {
            return horizAvg;
        }
        // Integer steps hit 0 and 1 exactly; accumulating 0.01 would not.
        for (int step = 0; step <= SAMPLES_PER_UNIT; step++) {
            double y = step / (double) SAMPLES_PER_UNIT;
            double sum = 0;
            for (InterpolatedLine line : allLines) {
                sum += line.getXat(y);
            }
            horizAvg.add(sum / allLines.size(), y);
        }
        return horizAvg;
    }

    /**
     * Creates an ROC curve that averages together the scores for all readers in
     * the vertical direction: for each x in 0, 0.01, ..., 1 the y values of
     * all reader curves at that x are averaged.
     * 
     * @return Series containing the ROC curve points; empty if there are no
     *         readers
     */
    private XYSeries generateVerticalROC() {
        XYSeries vertAvg = new XYSeries(VERTICAL_AVERAGE, false);
        if (allLines.isEmpty()) {
            return vertAvg;
        }
        // Integer steps hit 0 and 1 exactly; accumulating 0.01 would not.
        for (int step = 0; step <= SAMPLES_PER_UNIT; step++) {
            double x = step / (double) SAMPLES_PER_UNIT;
            double sum = 0;
            for (InterpolatedLine line : allLines) {
                sum += line.getYat(x);
            }
            vertAvg.add(x, sum / allLines.size());
        }
        return vertAvg;
    }

    /**
     * Hides the specified series from the chart: its line, its point markers
     * and its legend entry.
     * 
     * @param series Which series to hide
     */
    private void hideSeries(String series) {
        int index = seriesCollection.getSeriesIndex(series);
        renderer.setSeriesLinesVisible(index, false);
        renderer.setSeriesShapesVisible(index, false);
        renderer.setSeriesVisibleInLegend(index, false);
    }

    /**
     * Displays the specified series on the chart and in the legend.
     * 
     * @param series Which series to display
     * @param shapes Whether or not to display shapes indicating individual data
     *            points; {@code false} leaves the current marker visibility
     *            untouched rather than switching markers off
     */
    private void showSeries(String series, boolean shapes) {
        int index = seriesCollection.getSeriesIndex(series);
        renderer.setSeriesLinesVisible(index, true);
        if (shapes) {
            renderer.setSeriesShapesVisible(index, true);
        }
        renderer.setSeriesVisibleInLegend(index, true);
    }

    /**
     * Handler for checkbox to show or hide all readers from chart. Besides
     * updating the chart it brings every reader checkbox in line with the
     * master checkbox.
     */
    class ReadersSelectListener implements ItemListener {
        @Override
        public void itemStateChanged(ItemEvent e) {
            int change = e.getStateChange();
            if (change != ItemEvent.SELECTED && change != ItemEvent.DESELECTED) {
                return;
            }
            boolean show = (change == ItemEvent.SELECTED);

            for (JCheckBox readerBox : readerSeriesBoxes) {
                readerBox.setSelected(show);
            }
            // Applied explicitly as well: a checkbox that already had the
            // requested state fires no event of its own.
            for (String title : readerSeriesTitles) {
                if (show) {
                    showSeries(title, true);
                } else {
                    hideSeries(title);
                }
            }
        }
    }

    /**
     * Handler for checkbox to show or hide all averages from chart. Besides
     * updating the chart it brings the four average checkboxes in line with
     * the master checkbox.
     */
    class AverageSelectListener implements ItemListener {
        @Override
        public void itemStateChanged(ItemEvent e) {
            int change = e.getStateChange();
            if (change != ItemEvent.SELECTED && change != ItemEvent.DESELECTED) {
                return;
            }
            boolean show = (change == ItemEvent.SELECTED);

            vert.setSelected(show);
            horiz.setSelected(show);
            diag.setSelected(show);
            pooled.setSelected(show);

            if (show) {
                showSeries(VERTICAL_AVERAGE, false);
                showSeries(HORIZONTAL_AVERAGE, false);
                showSeries(DIAGONAL_AVERAGE, false);
                // Only the pooled curve is drawn with point markers.
                showSeries(POOLED_AVERAGE, true);
            } else {
                hideSeries(VERTICAL_AVERAGE);
                hideSeries(HORIZONTAL_AVERAGE);
                hideSeries(DIAGONAL_AVERAGE);
                hideSeries(POOLED_AVERAGE);
            }
        }
    }

    /**
     * Handler for checkboxes to show/hide a specific series. The series is
     * identified by the text of the checkbox that fired the event.
     */
    class SeriesSelectListener implements ItemListener {
        @Override
        public void itemStateChanged(ItemEvent e) {
            String series = ((JCheckBox) e.getItem()).getText();
            int change = e.getStateChange();
            if (change == ItemEvent.DESELECTED) {
                hideSeries(series);
            } else if (change == ItemEvent.SELECTED) {
                // Reader curves and the pooled average show point markers;
                // the three computed averages are lines only.
                showSeries(series, !isLineOnlyAverage(series));
            }
        }
    }
}