Example usage for weka.gui.visualize ThresholdVisualizePanel ThresholdVisualizePanel

List of usage examples for weka.gui.visualize ThresholdVisualizePanel ThresholdVisualizePanel

Introduction

This page collects usage examples for the default constructor of weka.gui.visualize.ThresholdVisualizePanel, ThresholdVisualizePanel().

Prototype

public ThresholdVisualizePanel() 

Source Link

Document

default constructor

Usage

From source file:TextClassifierUI.java

/**
 * Handles a click on the run button: builds a {@code DocClassifier} from the
 * train/test files, evaluates either NaiveBayes or IBk (depending on the
 * {@code naiveBayes} selection), writes the evaluation summary into
 * {@code result} and shows the predictions in a {@link ThresholdVisualizePanel}.
 * <p>
 * The k-nearest and k-fold inputs are only read when the corresponding option
 * is selected; both must be strictly positive integers.
 *
 * @param evt the action event that triggered the handler (not used)
 */
private void runButtonActionPerformed(java.awt.event.ActionEvent evt) {//GEN-FIRST:event_runButtonActionPerformed
    try {
        DocClassifier dr = new DocClassifier(trainFiles, testFiles);
        Classifier cl;
        if (naiveBayes.isSelected()) {
            cl = new NaiveBayes();
        } else {
            cl = new IBk(parsePositiveInt(kNearest.getText()));
        }
        Evaluation ev;
        if (useCV.isSelected()) {
            ev = dr.cvClassify(cl, parsePositiveInt(kFold.getText()));
            result.setText(dr.performanceEval(ev));
        } else {
            ev = dr.classify(cl);
            result.setText(dr.performanceEval(ev));
            // without cross-validation the per-document predictions are listed as well
            result.append("\nDOCUMENT\t=>\tPREDICT\n");
            for (String p : dr.getDocPredList()) {
                result.append(p + "\n");
            }
        }
        ThresholdVisualizePanel vmc = new ThresholdVisualizePanel();
        setVMC(ev.predictions(), vmc, true);
        showVMC(vmc);
    } catch (NumberFormatException e) {
        JOptionPane.showMessageDialog(this, "K Nearest and K-Fold must be positive numbers.",
                "Number Format Error", JOptionPane.ERROR_MESSAGE);
    } catch (Exception e) {
        java.util.logging.Logger.getLogger(getClass().getName())
                .log(java.util.logging.Level.SEVERE, "Failed to classify", e);
        JOptionPane.showMessageDialog(this, "Failed to classify : " + e.getLocalizedMessage(),
                "Unexpected Error", JOptionPane.ERROR_MESSAGE);
    }
}

/**
 * Parses a strictly positive integer from user-entered text.
 * Surrounding whitespace is ignored.
 *
 * @param text the text to parse, may be null
 * @return the parsed value, always greater than zero
 * @throws NumberFormatException if the text is null, not an integer, or not positive
 */
private static int parsePositiveInt(String text) {
    if (text == null) {
        throw new NumberFormatException("null");
    }
    int value = Integer.parseInt(text.trim());
    if (value <= 0) {
        // reported through the same path as unparsable input so the user sees one consistent message
        throw new NumberFormatException("Expected a positive integer but got: " + value);
    }
    return value;
}

From source file:adams.flow.sink.WekaCostCurve.java

License:Open Source License

/**
 * Creates the panel to display in the dialog.
 *
 * @return      the panel/*ww w  .j a v  a  2s. c om*/
 */
@Override
protected BasePanel newPanel() {
    BasePanel result;

    result = new BasePanel(new BorderLayout());
    m_VisualizePanel = new ThresholdVisualizePanel();
    result.add(m_VisualizePanel, BorderLayout.CENTER);

    return result;
}

From source file:adams.flow.sink.WekaThresholdCurve.java

License:Open Source License

/**
 * Creates a new display panel for the token. The panel wraps a
 * {@link ThresholdVisualizePanel} and, when displaying a token, adds one
 * threshold-curve plot per selected class label.
 *
 * @param token   the token to display in a new panel, can be null
 * @return      the generated panel
 */
public AbstractDisplayPanel createDisplayPanel(Token token) {
    // the relation name is only available when there is a token to inspect
    final String title = (token == null)
            ? "Threshold curve"
            : "Threshold curve (" + getEvaluation(token).getHeader().relationName() + ")";

    final AbstractDisplayPanel panel = new AbstractComponentDisplayPanel(title) {
        private static final long serialVersionUID = -7362768698548152899L;
        protected ThresholdVisualizePanel m_VisualizePanel;

        /** Sets up the layout with the visualize panel in the center. */
        @Override
        protected void initGUI() {
            super.initGUI();
            setLayout(new BorderLayout());
            m_VisualizePanel = new ThresholdVisualizePanel();
            add(m_VisualizePanel, BorderLayout.CENTER);
        }

        /** Adds one connected threshold-curve plot per selected class label. */
        @Override
        public void display(Token token) {
            try {
                Evaluation evaluation = getEvaluation(token);
                m_ClassLabelRange.setMax(evaluation.getHeader().classAttribute().numValues());
                for (int labelIndex : m_ClassLabelRange.getIntIndices()) {
                    ThresholdCurve generator = new ThresholdCurve();
                    Instances curveData = generator.getCurve(evaluation.predictions(), labelIndex);

                    PlotData2D plot = new PlotData2D(curveData);
                    plot.setPlotName(evaluation.getHeader().classAttribute().value(labelIndex));
                    plot.m_displayAllPoints = true;

                    // every point except the first is connected to its predecessor
                    boolean[] connected = new boolean[curveData.numInstances()];
                    for (int pos = 1; pos < connected.length; pos++) {
                        connected[pos] = true;
                    }
                    plot.setConnectPoints(connected);
                    m_VisualizePanel.addPlot(plot);

                    // only switch axes when the configured attributes exist in the curve data
                    if (curveData.attribute(m_AttributeX.toDisplay()) != null) {
                        m_VisualizePanel.setXIndex(curveData.attribute(m_AttributeX.toDisplay()).index());
                    }
                    if (curveData.attribute(m_AttributeY.toDisplay()) != null) {
                        m_VisualizePanel.setYIndex(curveData.attribute(m_AttributeY.toDisplay()).index());
                    }
                }
            } catch (Exception e) {
                getLogger().log(Level.SEVERE, "Failed to display token: " + token, e);
            }
        }

        /** Returns the wrapped visualize panel. */
        @Override
        public JComponent supplyComponent() {
            return m_VisualizePanel;
        }

        /** Removes all plots from the visualize panel. */
        @Override
        public void clearPanel() {
            m_VisualizePanel.removeAllPlots();
        }

        /** Removes all plots from the visualize panel. */
        public void cleanUp() {
            m_VisualizePanel.removeAllPlots();
        }
    };

    if (token != null) {
        panel.display(token);
    }

    return panel;
}

From source file:adams.gui.menu.CostCurve.java

License:Open Source License

/**
 * Launches the functionality of the menu item: loads a dataset (either the
 * file given as first parameter or one chosen by the user), uses its last
 * attribute as class and shows it as a connected plot in a
 * {@link ThresholdVisualizePanel} inside a new child frame.
 */
@Override
public void launch() {
    final File input;
    if (m_Parameters.length > 0) {
        input = new PlaceholderFile(m_Parameters[0]).getAbsoluteFile();
        m_FileChooser.setSelectedFile(input);
    } else {
        // no file supplied, so ask the user for one
        if (m_FileChooser.showOpenDialog(null) != JFileChooser.APPROVE_OPTION) {
            return;
        }
        input = m_FileChooser.getSelectedFile();
    }

    // load the data through the chooser's loader
    Instances data;
    try {
        data = m_FileChooser.getLoader().getDataSet();
    } catch (Exception e) {
        GUIHelper.showErrorMessage(getOwner(),
                "Error loading file '" + input + "':\n" + adams.core.Utils.throwableToString(e));
        return;
    }
    data.setClassIndex(data.numAttributes() - 1);

    // assemble the plot
    ThresholdVisualizePanel panel = new ThresholdVisualizePanel();
    PlotData2D plot = new PlotData2D(data);
    plot.setPlotName(data.relationName());
    plot.m_displayAllPoints = true;

    // every point except the first is connected to its predecessor
    boolean[] connected = new boolean[data.numInstances()];
    for (int pos = 1; pos < connected.length; pos++) {
        connected[pos] = true;
    }
    try {
        plot.setConnectPoints(connected);
        panel.addPlot(plot);
    } catch (Exception e) {
        GUIHelper.showErrorMessage(getOwner(), "Error adding plot:\n" + adams.core.Utils.throwableToString(e));
        return;
    }

    ChildFrame frame = createChildFrame(panel, GUIHelper.getDefaultDialogDimension());
    frame.setTitle(frame.getTitle() + " - " + input);
}

From source file:adams.gui.menu.ROC.java

License:Open Source License

/**
 * Launches the functionality of the menu item: loads threshold-curve data
 * (either the file given as first parameter or one chosen by the user), uses
 * its last attribute as class and shows it, labelled with the area under the
 * ROC curve, in a {@link ThresholdVisualizePanel} inside a new child frame.
 */
@Override
public void launch() {
    final File input;
    if (m_Parameters.length > 0) {
        input = new PlaceholderFile(m_Parameters[0]).getAbsoluteFile();
        m_FileChooser.setSelectedFile(input);
    } else {
        // no file supplied, so ask the user for one
        if (m_FileChooser.showOpenDialog(null) != JFileChooser.APPROVE_OPTION) {
            return;
        }
        input = m_FileChooser.getSelectedFile();
    }

    // load the data through the chooser's loader
    Instances data;
    try {
        data = m_FileChooser.getLoader().getDataSet();
    } catch (Exception e) {
        GUIHelper.showErrorMessage(getOwner(),
                "Error loading file '" + input + "':\n" + adams.core.Utils.throwableToString(e));
        return;
    }
    data.setClassIndex(data.numAttributes() - 1);

    // configure the panel with the ROC area and the relation name
    ThresholdVisualizePanel panel = new ThresholdVisualizePanel();
    String rocArea = Utils.doubleToString(ThresholdCurve.getROCArea(data), 4);
    panel.setROCString("(Area under ROC = " + rocArea + ")");
    panel.setName(data.relationName());

    PlotData2D plot = new PlotData2D(data);
    plot.setPlotName(data.relationName());
    plot.addInstanceNumberAttribute();

    // every point except the first is connected to its predecessor
    boolean[] connected = new boolean[data.numInstances()];
    for (int pos = 1; pos < connected.length; pos++) {
        connected[pos] = true;
    }
    try {
        plot.setConnectPoints(connected);
        panel.addPlot(plot);
    } catch (Exception e) {
        GUIHelper.showErrorMessage(getOwner(), "Error adding plot:\n" + adams.core.Utils.throwableToString(e));
        return;
    }

    ChildFrame frame = createChildFrame(panel, GUIHelper.getDefaultDialogDimension());
    frame.setTitle(frame.getTitle() + " - " + input);
}

From source file:com.sliit.views.DataVisualizerPanel.java

/**
 * Loads the dataset named in {@code datasetPathText}, runs a 10-fold
 * cross-validation of NaiveBayes on it (last attribute as class) and shows
 * the threshold curve for class index 0 in a new frame.
 * <p>
 * Any failure is logged at SEVERE level; nothing is shown in that case.
 */
void getRocCurve() {
    try {
        // load the dataset; the reader is closed as soon as the data is parsed
        Instances data;
        try (BufferedReader reader = new BufferedReader(new FileReader(datasetPathText.getText()))) {
            data = new Instances(reader);
        }
        data.setClassIndex(data.numAttributes() - 1);

        // train classifier
        Classifier cl = new NaiveBayes();
        Evaluation eval = new Evaluation(data);
        eval.crossValidateModel(cl, data, 10, new Random(1));

        // generate curve
        ThresholdCurve tc = new ThresholdCurve();
        int classIndex = 0;
        Instances result = tc.getCurve(eval.predictions(), classIndex);

        // plot curve
        ThresholdVisualizePanel vmc = new ThresholdVisualizePanel();
        vmc.setROCString(
                "(Area under ROC = " + Utils.doubleToString(ThresholdCurve.getROCArea(result), 4) + ")");
        vmc.setName(result.relationName());
        PlotData2D tempd = new PlotData2D(result);
        tempd.setPlotName(result.relationName());
        tempd.addInstanceNumberAttribute();
        // specify which points are connected: all but the first
        boolean[] cp = new boolean[result.numInstances()];
        for (int n = 1; n < cp.length; n++) {
            cp[n] = true;
        }
        tempd.setConnectPoints(cp);
        // add plot
        vmc.addPlot(tempd);

        // display curve
        String plotName = vmc.getName();
        final javax.swing.JFrame jf = new javax.swing.JFrame("Weka Classifier Visualize: " + plotName);
        jf.setSize(500, 400);
        jf.getContentPane().setLayout(new BorderLayout());
        jf.getContentPane().add(vmc, BorderLayout.CENTER);
        jf.addWindowListener(new java.awt.event.WindowAdapter() {
            @Override
            public void windowClosing(java.awt.event.WindowEvent e) {
                jf.dispose();
            }
        });
        jf.setVisible(true);
    } catch (Exception ex) {
        // I/O and evaluation failures were handled identically, so one handler suffices
        Logger.getLogger(DataVisualizerPanel.class.getName())
                .log(Level.SEVERE, "Failed to display the ROC curve", ex);
    }
}

From source file:com.sliit.views.KNNView.java

/**
 * Loads the dataset named in {@code PredictorPanel.modalText}, runs a 10-fold
 * cross-validation on it (last attribute as class) and replaces the content
 * of {@code rocPanel} with the threshold curve for class index 0.
 * <p>
 * Any failure is logged at SEVERE level; {@code rocPanel} is left untouched
 * in that case.
 */
void getRocCurve() {
    try {
        // load the dataset; the reader is closed as soon as the data is parsed
        Instances data;
        try (BufferedReader reader = new BufferedReader(
                new java.io.FileReader(PredictorPanel.modalText.getText()))) {
            data = new Instances(reader);
        }
        data.setClassIndex(data.numAttributes() - 1);

        // train classifier
        // NOTE(review): NaiveBayes is evaluated here although this is the KNN view - confirm this is intended
        Classifier cl = new NaiveBayes();
        Evaluation eval = new Evaluation(data);
        eval.crossValidateModel(cl, data, 10, new Random(1));

        // generate curve
        ThresholdCurve tc = new ThresholdCurve();
        int classIndex = 0;
        Instances result = tc.getCurve(eval.predictions(), classIndex);

        // plot curve
        ThresholdVisualizePanel vmc = new ThresholdVisualizePanel();
        vmc.setROCString(
                "(Area under ROC = " + Utils.doubleToString(ThresholdCurve.getROCArea(result), 4) + ")");
        vmc.setName(result.relationName());
        PlotData2D tempd = new PlotData2D(result);
        tempd.setPlotName(result.relationName());
        tempd.addInstanceNumberAttribute();
        // specify which points are connected: all but the first
        boolean[] cp = new boolean[result.numInstances()];
        for (int n = 1; n < cp.length; n++) {
            cp[n] = true;
        }
        tempd.setConnectPoints(cp);
        // add plot
        vmc.addPlot(tempd);

        // swap the previous content of the ROC panel for the new curve
        rocPanel.removeAll();
        rocPanel.add(vmc, "vmc", 0);
        rocPanel.revalidate();

    } catch (Exception ex) {
        // log under this class rather than DataVisualizerPanel so records are attributed correctly
        Logger.getLogger(getClass().getName()).log(Level.SEVERE, "Failed to display the ROC curve", ex);
    }
}

From source file:com.sliit.views.SVMView.java

/**
 * draw ROC curve/*from   w  w w .ja va  2 s. c  o m*/
 */
void getRocCurve() {
    try {
        Instances data;
        data = new Instances(new BufferedReader(new FileReader(PredictorPanel.modalText.getText())));
        data.setClassIndex(data.numAttributes() - 1);

        //train classifier
        Classifier cl = new NaiveBayes();
        Evaluation eval = new Evaluation(data);
        eval.crossValidateModel(cl, data, 10, new Random(1));

        // generate curve
        ThresholdCurve tc = new ThresholdCurve();
        int classIndex = 0;
        Instances result = tc.getCurve(eval.predictions(), classIndex);

        // plot curve
        ThresholdVisualizePanel vmc = new ThresholdVisualizePanel();
        vmc.setROCString("(Area under ROC = " + Utils.doubleToString(tc.getROCArea(result), 4) + ")");
        vmc.setName(result.relationName());
        PlotData2D tempd = new PlotData2D(result);
        tempd.setPlotName(result.relationName());
        tempd.addInstanceNumberAttribute();
        // specify which points are connected
        boolean[] cp = new boolean[result.numInstances()];
        for (int n = 1; n < cp.length; n++) {
            cp[n] = true;
        }
        tempd.setConnectPoints(cp);
        // add plot
        vmc.addPlot(tempd);

        //            rocPanel.removeAll();
        //            rocPanel.add(vmc, "vmc", 0);
        //            rocPanel.revalidate();
    } catch (IOException ex) {
        Logger.getLogger(DataVisualizerPanel.class.getName()).log(Level.SEVERE, null, ex);
    } catch (Exception ex) {
        Logger.getLogger(DataVisualizerPanel.class.getName()).log(Level.SEVERE, null, ex);
    }
}

From source file:cs.man.ac.uk.classifiers.GetAUC.java

License:Open Source License

/**
 * Computes the AUC for the supplied learner.
 * @return the AUC as a double value./*from   ww  w .  ja va  2  s.  c om*/
 */
@SuppressWarnings("unused")
private static double validate5x2CV() {
    try {
        // other options
        int runs = 5;
        int folds = 2;
        double AUC_SUM = 0;

        // perform cross-validation
        for (int i = 0; i < runs; i++) {
            // randomize data
            int seed = i + 1;
            Random rand = new Random(seed);
            Instances randData = new Instances(data);
            randData.randomize(rand);

            if (randData.classAttribute().isNominal()) {
                System.out.println("Stratifying...");
                randData.stratify(folds);
            }

            Evaluation eval = new Evaluation(randData);

            for (int n = 0; n < folds; n++) {
                Instances train = randData.trainCV(folds, n);
                Instances test = randData.testCV(folds, n);

                // the above code is used by the StratifiedRemoveFolds filter, the
                // code below by the Explorer/Experimenter:
                // Instances train = randData.trainCV(folds, n, rand);

                // build and evaluate classifier
                String[] options = { "-U", "-A" };
                J48 classifier = new J48();
                //HTree classifier = new HTree();

                classifier.setOptions(options);
                classifier.buildClassifier(train);
                eval.evaluateModel(classifier, test);

                // generate curve
                ThresholdCurve tc = new ThresholdCurve();
                int classIndex = 0;
                Instances result = tc.getCurve(eval.predictions(), classIndex);

                // plot curve
                vmc = new ThresholdVisualizePanel();
                AUC_SUM += ThresholdCurve.getROCArea(result);
                System.out.println("AUC: " + ThresholdCurve.getROCArea(result) + " \tAUC SUM: " + AUC_SUM);
            }
        }

        return AUC_SUM / ((double) runs * (double) folds);
    } catch (Exception e) {
        System.out.println("Exception validating data!");
        return 0;
    }
}

From source file:cs.man.ac.uk.classifiers.GetAUC.java

License:Open Source License

/**
 * Computes the AUC for the supplied learner.
 * @param learner the learning algorithm to use.
 * @return the AUC as a double value./*from   ww  w .j  a v  a 2s . c  om*/
 */
@SuppressWarnings("unused")
private static double validate(Classifier learner) {
    try {

        Evaluation eval = new Evaluation(data);
        eval.crossValidateModel(learner, data, 2, new Random(1));

        // generate curve
        ThresholdCurve tc = new ThresholdCurve();
        int classIndex = 0;
        Instances result = tc.getCurve(eval.predictions(), classIndex);

        // plot curve
        vmc = new ThresholdVisualizePanel();
        double AUC = ThresholdCurve.getROCArea(result);
        vmc.setROCString(
                "(Area under ROC = " + Utils.doubleToString(ThresholdCurve.getROCArea(result), 9) + ")");
        vmc.setName(result.relationName());

        PlotData2D tempd = new PlotData2D(result);
        tempd.setPlotName(result.relationName());
        tempd.addInstanceNumberAttribute();

        // specify which points are connected
        boolean[] cp = new boolean[result.numInstances()];
        for (int n = 1; n < cp.length; n++)
            cp[n] = true;

        tempd.setConnectPoints(cp);
        // add plot
        vmc.addPlot(tempd);

        return AUC;
    } catch (Exception e) {
        System.out.println("Exception validating data!");
        return 0;
    }
}