Example usage for weka.core Instances trainCV

List of usage examples for weka.core Instances trainCV

Introduction

This page shows example usage of the weka.core.Instances method trainCV.

Prototype



public Instances trainCV(int numFolds, int numFold, Random random) 

Source Link

Document

Creates the training set for one fold of a cross-validation on the dataset.

Usage

From source file: sirius.clustering.main.ClustererClassificationPane.java

License: Open Source License

/**
 * Validates the form and, if no job is already running, starts a background
 * thread that cross-validates the selected classifier on each cluster file.
 *
 * <p>The selected path is expected to name the last cluster file,
 * {@code <prefix>_cluster<N>.arff}. For every cluster index {@code 0..N} the
 * matching {@code _cluster<x>.arff} file is read, the classifier is trained and
 * tested fold by fold, and the per-instance scores are written to the sibling
 * {@code _cluster<x>.score} file. Two summary rows are appended to the results
 * table once all clusters have been processed.
 *
 * <p>NOTE(review): the worker thread updates Swing components directly rather
 * than through the event dispatch thread; this matches the original behaviour
 * and is left unchanged here - confirm whether it should be marshalled.
 */
private void start() {
    //Run Classifier
    if (this.inputDirectoryTextField.getText().length() == 0) {
        JOptionPane.showMessageDialog(parent, "Please set Input Directory to where the clusterer output are!",
                "Evaluate Classifier", JOptionPane.ERROR_MESSAGE);
        return;
    }
    if (m_ClassifierEditor.getValue() == null) {
        JOptionPane.showMessageDialog(parent, "Please choose Classifier!", "Evaluate Classifier",
                JOptionPane.ERROR_MESSAGE);
        return;
    }
    if (!validateStatsSettings(1)) {
        return;
    }
    if (this.clusteringClassificationThread != null) {
        JOptionPane.showMessageDialog(parent,
                "Cannot start new job as previous job still running. Click stop to terminate previous job",
                "ERROR", JOptionPane.ERROR_MESSAGE);
        return;
    }

    startButton.setEnabled(false);
    stopButton.setEnabled(true);
    tabbedClassifierPane.setSelectedIndex(0);
    this.clusteringClassificationThread = (new Thread() {
        @Override
        public void run() {
            try {
                //Clear the output text area
                levelOneClassifierOutputTextArea.setText("");
                resultsTableModel.reset();

                // Snapshot the cross-validation settings once so every cluster is
                // evaluated with the same user request. -1 marks jack-knife mode.
                final boolean jackKnife = jackKnifeRadioButton.isSelected();
                final int requestedFolds = jackKnife ? -1 : Integer.parseInt(foldsField.getText());
                int numFolds = requestedFolds;

                // The last path component is expected to look like <prefix>_cluster<N>.arff;
                // its second "_"/"." separated token carries the highest cluster index.
                final String inputPath = inputDirectoryTextField.getText();
                StringTokenizer st = new StringTokenizer(inputPath, File.separator);
                String filename = "";
                while (st.hasMoreTokens()) {
                    filename = st.nextToken();
                }
                StringTokenizer st2 = new StringTokenizer(filename, "_.");
                numOfCluster = 0;
                boolean validFilename = true;
                if (st2.countTokens() >= 2) {
                    st2.nextToken();
                    String numOfClusterString = st2.nextToken().replaceAll("cluster", "");
                    try {
                        numOfCluster = Integer.parseInt(numOfClusterString);
                    } catch (NumberFormatException e) {
                        validFilename = false;
                    }
                }

                // Abort unless the path really contains the suffix that is rewritten below.
                // Otherwise the derived .score name would equal the selected path and opening
                // it for writing would truncate the user's input file.
                final String lastClusterSuffix = "_cluster" + numOfCluster + ".arff";
                if (!validFilename || !inputPath.contains(lastClusterSuffix)) {
                    JOptionPane.showMessageDialog(parent,
                            "Please choose the correct file! (Output from Utilize Clusterer)", "ERROR",
                            JOptionPane.ERROR_MESSAGE);
                    return;
                }
                Classifier template = (Classifier) m_ClassifierEditor.getValue();

                for (int x = 0; x <= numOfCluster && clusteringClassificationThread != null; x++) {//Test each cluster
                    try {
                        long totalTimeStart = System.currentTimeMillis();
                        statusLabel.setText("Reading in cluster" + x + " file..");
                        // Literal replace: the suffix is a plain path fragment, not a regex.
                        String inputFilename = inputPath.replace(lastClusterSuffix, "_cluster" + x + ".arff");
                        String outputScoreFilename = inputPath.replace(lastClusterSuffix,
                                "_cluster" + x + ".score");
                        ClassifierResults classifierResults;
                        BufferedWriter output = new BufferedWriter(new FileWriter(outputScoreFilename));
                        try {
                            Instances inst;
                            FileReader reader = new FileReader(inputFilename);
                            try {
                                inst = new Instances(reader);
                            } finally {
                                reader.close();
                            }
                            //Assume that class attribute is the last attribute - This should be the case for all Sirius produced Arff files
                            inst.setClassIndex(inst.numAttributes() - 1);

                            // Derive this cluster's fold count from the original request, so a small
                            // cluster does not shrink the fold count of the clusters that follow.
                            // NOTE(review): jack-knife uses numInstances - 1 folds, as before; strict
                            // leave-one-out would be numInstances - confirm the intent.
                            int maxFolds = inst.numInstances() - 1;
                            numFolds = (jackKnife || requestedFolds > maxFolds) ? maxFolds : requestedFolds;

                            Random random = new Random(1);//Simply set to 1, shall implement the random seed option later
                            inst.randomize(random);
                            // Stratify with the same fold count that trainCV/testCV use below; this
                            // must come after the clamp because stratify rejects the -1 marker.
                            if (inst.attribute(inst.classIndex()).isNominal()) {
                                inst.stratify(numFolds);
                            }

                            classifierResults = new ClassifierResults(false, 0);
                            String classifierName = m_ClassifierEditor.getValue().getClass().getName();
                            classifierResults.updateList(classifierResults.getClassifierList(), "Classifier: ",
                                    classifierName);
                            classifierResults.updateList(classifierResults.getClassifierList(),
                                    "Training Data: ", inputFilename);
                            classifierResults.updateList(classifierResults.getClassifierList(), "Time Used: ",
                                    "NA");

                            for (int fold = 0; fold < numFolds
                                    && clusteringClassificationThread != null; fold++) {//Doing cross-validation
                                statusLabel.setText("Cluster: " + x + " - Training Fold " + (fold + 1) + "..");
                                Instances train = inst.trainCV(numFolds, fold, random);
                                try {
                                    Classifier current = Classifier.makeCopy(template);
                                    current.buildClassifier(train);
                                    Instances test = inst.testCV(numFolds, fold);
                                    statusLabel.setText(
                                            "Cluster: " + x + " - Testing Fold " + (fold + 1) + "..");
                                    for (int jj = 0; jj < test.numInstances(); jj++) {
                                        double[] result = current.distributionForInstance(test.instance(jj));
                                        output.write("Cluster: " + x);
                                        output.newLine();
                                        output.newLine();
                                        output.write(test.instance(jj).stringValue(test.classAttribute())
                                                + ",0=" + result[0]);
                                        output.newLine();
                                    }
                                } catch (Exception ex) {
                                    // Best effort: report the failed fold and carry on with the next one.
                                    ex.printStackTrace();
                                    statusLabel.setText("Error in cross-validation!");
                                    startButton.setEnabled(true);
                                    stopButton.setEnabled(false);
                                }
                            }
                        } finally {
                            // Always release the score file; it is re-read by PredictionStats below.
                            output.close();
                        }
                        long totalTimeElapsed = System.currentTimeMillis() - totalTimeStart;
                        classifierResults.updateList(classifierResults.getResultsList(), "Total Time Used: ",
                                Utils.doubleToString(totalTimeElapsed / 60000, 2) + " minutes "
                                        + Utils.doubleToString((totalTimeElapsed / 1000.0) % 60.0, 2)
                                        + " seconds");
                        double threshold = validateFieldAsThreshold(classifierOneThresholdTextField.getText(),
                                "Threshold Field", classifierOneThresholdTextField);
                        PredictionStats classifierStats = new PredictionStats(outputScoreFilename, 0,
                                threshold);

                        resultsTableModel.add("Cluster " + x, classifierResults, classifierStats);
                        resultsTable.setRowSelectionInterval(x, x);
                        computeStats(numFolds);//compute and display the results
                    } catch (Exception e) {
                        e.printStackTrace();
                        statusLabel.setText("Error in reading file!");
                        startButton.setEnabled(true);
                        stopButton.setEnabled(false);
                    }
                } //end of cluster for loop

                resultsTableModel.add("Summary - Equal Weightage", null, null);
                resultsTable.setRowSelectionInterval(numOfCluster + 1, numOfCluster + 1);
                computeStats(numFolds);
                resultsTableModel.add("Summary - Weighted Average", null, null);
                resultsTable.setRowSelectionInterval(numOfCluster + 2, numOfCluster + 2);
                computeStats(numFolds);

                if (clusteringClassificationThread != null) {
                    statusLabel.setText("Done!");
                } else {
                    statusLabel.setText("Interrupted..");
                }
                if (classifierOne != null) {
                    levelOneClassifierOutputScrollPane.getVerticalScrollBar()
                            .setValue(levelOneClassifierOutputScrollPane.getVerticalScrollBar().getMaximum());
                }
            } finally {
                // Reset the job state on every exit path so a failed run can never leave the
                // pane stuck reporting that a previous job is still running.
                startButton.setEnabled(true);
                stopButton.setEnabled(false);
                clusteringClassificationThread = null;
            }
        }
    });
    this.clusteringClassificationThread.setPriority(Thread.MIN_PRIORITY);
    this.clusteringClassificationThread.start();
}