weka.gui.visualize.ThresholdVisualizePanel.java Source code

Java tutorial

Introduction

Here is the source code for weka.gui.visualize.ThresholdVisualizePanel.java

Source

/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

/*
 *    ThresholdVisualizePanel.java
 *    Copyright (C) 2003-2012 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.gui.visualize;

import java.awt.BorderLayout;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.awt.event.WindowAdapter;
import java.awt.event.WindowEvent;
import java.io.BufferedReader;
import java.io.FileReader;
import java.util.ArrayList;

import javax.swing.BorderFactory;
import javax.swing.JFrame;
import javax.swing.border.TitledBorder;

import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.evaluation.EvaluationUtils;
import weka.classifiers.evaluation.Prediction;
import weka.classifiers.evaluation.ThresholdCurve;
import weka.core.Instances;
import weka.core.SingleIndex;
import weka.core.Utils;

/**
 * This panel is a VisualizePanel, with the added ability to display the area
 * under the ROC curve if an ROC curve is chosen.
 * <p/>
 * The ROC area text is appended to the title of the plot border whenever the
 * X axis combo box shows the false positive rate and the Y axis combo box
 * shows the true positive rate; for any other axis selection the border text
 * that was present at construction time is shown.
 * 
 * @author Dale Fletcher (dale@cs.waikato.ac.nz)
 * @author FracPete (fracpete at waikato dot ac dot nz)
 * @version $Revision$
 */
public class ThresholdVisualizePanel extends VisualizePanel {

    /** for serialization */
    private static final long serialVersionUID = 3070002211779443890L;

    /** Text of the X axis combo box entry that identifies an ROC plot. */
    private static final String ROC_X_AXIS = "X: False Positive Rate (Num)";

    /** Text of the Y axis combo box entry that identifies an ROC plot. */
    private static final String ROC_Y_AXIS = "Y: True Positive Rate (Num)";

    /** The string to add to the Plot Border. */
    private String m_ROCString = "";

    /** Original border text */
    private final String m_savePanelBorderText;

    /**
     * Listener that refreshes the border text when an axis selection changes.
     * Created lazily in {@link #setUpComboBoxes(Instances)} and shared by both
     * combo boxes so that repeated calls do not pile up listeners. Deliberately
     * declared without an initializer and as transient: it is rebuilt on demand.
     */
    private transient ActionListener m_borderUpdater;

    /**
     * default constructor
     */
    public ThresholdVisualizePanel() {
        super();

        // Save the current border text
        TitledBorder tb = (TitledBorder) m_plotSurround.getBorder();
        m_savePanelBorderText = tb.getTitle();
    }

    /**
     * Set the string with ROC area
     * 
     * @param str ROC area string to add to border
     */
    public void setROCString(String str) {
        m_ROCString = str;
    }

    /**
     * This extracts the ROC area string
     * 
     * @return ROC area string
     */
    public String getROCString() {
        return m_ROCString;
    }

    /**
     * This overloads VisualizePanel's setUpComboBoxes to add an ActionListener
     * that watches for changes of the X/Y Axis comboboxes.
     * <p/>
     * The listener is registered at most once per combo box, no matter how often
     * this method is called.
     * 
     * @param inst a set of instances with data for plotting
     */
    @Override
    public void setUpComboBoxes(Instances inst) {
        super.setUpComboBoxes(inst);

        if (m_borderUpdater == null) {
            m_borderUpdater = new ActionListener() {
                @Override
                public void actionPerformed(ActionEvent e) {
                    setBorderText();
                }
            };
        }

        // Detach before attaching: removing a listener that is not registered
        // is a no-op, and this keeps exactly one registration per combo box.
        m_XCombo.removeActionListener(m_borderUpdater);
        m_XCombo.addActionListener(m_borderUpdater);
        m_YCombo.removeActionListener(m_borderUpdater);
        m_YCombo.addActionListener(m_borderUpdater);

        // Just in case the default is ROC
        setBorderText();
    }

    /**
     * This checks the current selected X/Y Axis comboBoxes to see if an ROC graph
     * is selected. If so, add the ROC area string to the plot border, otherwise
     * display the original border text. If either combo box has no selection,
     * the original border text is displayed.
     */
    private void setBorderText() {
        Object xItem = m_XCombo.getSelectedItem();
        Object yItem = m_YCombo.getSelectedItem();

        // getSelectedItem() yields null when nothing is selected
        boolean rocSelected = xItem != null && yItem != null && ROC_X_AXIS.equals(xItem.toString())
                && ROC_Y_AXIS.equals(yItem.toString());

        String title = m_savePanelBorderText;
        if (rocSelected) {
            title = m_savePanelBorderText + " " + m_ROCString;
        }
        m_plotSurround.setBorder(BorderFactory.createTitledBorder(title));
    }

    /**
     * displays the previously saved instances and updates the ROC area string
     * from them
     * 
     * @param insts the instances to display
     * @throws Exception if display is not possible
     */
    @Override
    protected void openVisibleInstances(Instances insts) throws Exception {
        super.openVisibleInstances(insts);

        setROCString("(Area under ROC = " + Utils.doubleToString(ThresholdCurve.getROCArea(insts), 4) + ")");

        setBorderText();
    }

    /**
     * Reads a dataset from the given ARFF file and closes the file afterwards,
     * regardless of whether reading succeeded.
     * 
     * @param filename the ARFF file to read
     * @return the instances read from the file
     * @throws Exception if the file cannot be opened or parsed
     */
    private static Instances loadInstances(String filename) throws Exception {
        BufferedReader reader = new BufferedReader(new FileReader(filename));
        try {
            return new Instances(reader);
        } finally {
            reader.close();
        }
    }

    /**
     * Prints the commandline options understood by {@link #main(String[])} to
     * stdout.
     */
    private static void printUsage() {
        System.out.println("\nOptions for " + ThresholdVisualizePanel.class.getName() + ":\n");
        System.out.println("-h\n\tThis help.");
        System.out.println("-t <file>\n\tDataset to process with given classifier.");
        System.out.println("-c <num>\n\tThe class index. first and last are valid, too (default: last).");
        System.out.println("-C <num>\n\tThe index of the class value to get the curve for (default: first).");
        System.out.println(
                "-W <classname>\n\tFull classname of classifier to run.\n\tOptions after '--' are passed to the classifier.\n\t(default: weka.classifiers.functions.Logistic)");
        System.out.println("-r <number>\n\tThe number of runs to perform (default: 1).");
        System.out.println("-x <number>\n\tThe number of Cross-validation folds (default: 10).");
        System.out.println("-S <number>\n\tThe seed value for randomizing the data (default: 1).");
        System.out.println("-l <file>\n\tPreviously saved threshold curve ARFF file.");
    }

    /**
     * Starts the ThresholdVisualizationPanel with parameters from the command
     * line.
     * <p/>
     * Either a previously saved threshold curve is loaded (-l), or a curve is
     * computed from the cross-validated predictions of a classifier on a
     * dataset (-t). If -l is given, all other options are ignored. Any error is
     * reported by printing the stack trace.
     * <p/>
     * 
     * Valid options are:
     * <p/>
     * -h <br/>
     * lists all the commandline parameters
     * <p/>
     * 
     * -t file <br/>
     * Dataset to process with given classifier. Required unless -l is given.
     * <p/>
     * 
     * -c num <br/>
     * The class index; first and last are valid, too (default last).
     * <p/>
     * 
     * -C num <br/>
     * The index of the class value to get the curve for (default first).
     * <p/>
     * 
     * -W classname <br/>
     * Full classname of classifier to run.<br/>
     * Options after '--' are passed to the classifier. <br/>
     * (default weka.classifiers.functions.Logistic)
     * <p/>
     * 
     * -r number <br/>
     * The number of runs to perform (default 1).
     * <p/>
     * 
     * -x number <br/>
     * The number of Cross-validation folds (default 10).
     * <p/>
     * 
     * -S number <br/>
     * The seed value for randomizing the data (default 1); run i uses seed + i.
     * <p/>
     * 
     * -l file <br/>
     * Previously saved threshold curve ARFF file.
     * <p/>
     * 
     * @param args optional commandline parameters
     */
    public static void main(String[] args) {
        Instances inst = null;
        Instances result;
        SingleIndex valueIndex = null;
        boolean compute = true;
        String tmpStr;

        try {
            // help?
            if (Utils.getFlag('h', args)) {
                printUsage();
                return;
            }

            // a previously saved curve takes precedence over computing one
            tmpStr = Utils.getOption('l', args);
            if (tmpStr.length() != 0) {
                result = loadInstances(tmpStr);
                compute = false;
            } else {
                int runs = 1;
                int folds = 10;
                int seed = 1;
                String[] options;

                tmpStr = Utils.getOption('r', args);
                if (tmpStr.length() != 0) {
                    runs = Integer.parseInt(tmpStr);
                }

                tmpStr = Utils.getOption('x', args);
                if (tmpStr.length() != 0) {
                    folds = Integer.parseInt(tmpStr);
                }

                tmpStr = Utils.getOption('S', args);
                if (tmpStr.length() != 0) {
                    seed = Integer.parseInt(tmpStr);
                }

                tmpStr = Utils.getOption('t', args);
                if (tmpStr.length() == 0) {
                    // without a dataset there is nothing to evaluate; fail with a
                    // clear message instead of a NullPointerException further down
                    throw new IllegalArgumentException(
                            "No dataset provided: use -t <file> to compute a curve or -l <file> to load one (-h for help).");
                }
                inst = loadInstances(tmpStr);

                tmpStr = Utils.getOption('W', args);
                if (tmpStr.length() != 0) {
                    options = Utils.partitionOptions(args);
                } else {
                    tmpStr = weka.classifiers.functions.Logistic.class.getName();
                    options = new String[0];
                }
                Classifier classifier = AbstractClassifier.forName(tmpStr, options);

                tmpStr = Utils.getOption('c', args);
                SingleIndex classIndex = new SingleIndex(tmpStr.length() != 0 ? tmpStr : "last");

                tmpStr = Utils.getOption('C', args);
                valueIndex = new SingleIndex(tmpStr.length() != 0 ? tmpStr : "first");

                // resolve the (possibly symbolic) indices against the dataset
                classIndex.setUpper(inst.numAttributes() - 1);
                inst.setClassIndex(classIndex.getIndex());
                valueIndex.setUpper(inst.classAttribute().numValues() - 1);

                // collect the cross-validated predictions of all runs
                ThresholdCurve tc = new ThresholdCurve();
                EvaluationUtils eu = new EvaluationUtils();
                ArrayList<Prediction> predictions = new ArrayList<Prediction>();
                for (int i = 0; i < runs; i++) {
                    eu.setSeed(seed + i);
                    predictions.addAll(eu.getCVPredictions(classifier, inst, folds));
                }

                result = tc.getCurve(predictions, valueIndex.getIndex());
            }

            // setup GUI
            ThresholdVisualizePanel vmc = new ThresholdVisualizePanel();
            vmc.setROCString(
                    "(Area under ROC = " + Utils.doubleToString(ThresholdCurve.getROCArea(result), 4) + ")");
            if (compute) {
                vmc.setName(result.relationName() + ". (Class value "
                        + inst.classAttribute().value(valueIndex.getIndex()) + ")");
            } else {
                vmc.setName(result.relationName() + " (display only)");
            }
            PlotData2D tempd = new PlotData2D(result);
            tempd.setPlotName(result.relationName());
            tempd.addInstanceNumberAttribute();
            vmc.addPlot(tempd);

            String plotName = vmc.getName();
            final JFrame jf = new JFrame("Weka Classifier Visualize: " + plotName);
            jf.setSize(500, 400);
            jf.getContentPane().setLayout(new BorderLayout());

            jf.getContentPane().add(vmc, BorderLayout.CENTER);
            jf.addWindowListener(new WindowAdapter() {
                @Override
                public void windowClosing(WindowEvent e) {
                    jf.dispose();
                }
            });

            jf.setVisible(true);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}