List of usage examples for weka.core.Instances.trainCV
public Instances trainCV(int numFolds, int numFold, Random random)
From source file: sirius.clustering.main.ClustererClassificationPane.java
License: Open Source License
private void start() { //Run Classifier if (this.inputDirectoryTextField.getText().length() == 0) { JOptionPane.showMessageDialog(parent, "Please set Input Directory to where the clusterer output are!", "Evaluate Classifier", JOptionPane.ERROR_MESSAGE); return;/*from ww w .j a va2s .c o m*/ } if (m_ClassifierEditor.getValue() == null) { JOptionPane.showMessageDialog(parent, "Please choose Classifier!", "Evaluate Classifier", JOptionPane.ERROR_MESSAGE); return; } if (validateStatsSettings(1) == false) { return; } if (this.clusteringClassificationThread == null) { startButton.setEnabled(false); stopButton.setEnabled(true); tabbedClassifierPane.setSelectedIndex(0); this.clusteringClassificationThread = (new Thread() { public void run() { //Clear the output text area levelOneClassifierOutputTextArea.setText(""); resultsTableModel.reset(); //double threshold = Double.parseDouble(classifierOneThresholdTextField.getText()); //cross-validation int numFolds; if (jackKnifeRadioButton.isSelected()) numFolds = -1; else numFolds = Integer.parseInt(foldsField.getText()); StringTokenizer st = new StringTokenizer(inputDirectoryTextField.getText(), File.separator); String filename = ""; while (st.hasMoreTokens()) { filename = st.nextToken(); } StringTokenizer st2 = new StringTokenizer(filename, "_."); numOfCluster = 0; if (st2.countTokens() >= 2) { st2.nextToken(); String numOfClusterString = st2.nextToken().replaceAll("cluster", ""); try { numOfCluster = Integer.parseInt(numOfClusterString); } catch (NumberFormatException e) { JOptionPane.showMessageDialog(parent, "Please choose the correct file! 
(Output from Utilize Clusterer)", "ERROR", JOptionPane.ERROR_MESSAGE); } } Classifier template = (Classifier) m_ClassifierEditor.getValue(); for (int x = 0; x <= numOfCluster && clusteringClassificationThread != null; x++) {//Test each cluster try { long totalTimeStart = 0, totalTimeElapsed = 0; totalTimeStart = System.currentTimeMillis(); statusLabel.setText("Reading in cluster" + x + " file.."); String inputFilename = inputDirectoryTextField.getText() .replaceAll("_cluster" + numOfCluster + ".arff", "_cluster" + x + ".arff"); String outputScoreFilename = inputDirectoryTextField.getText() .replaceAll("_cluster" + numOfCluster + ".arff", "_cluster" + x + ".score"); BufferedWriter output = new BufferedWriter(new FileWriter(outputScoreFilename)); Instances inst = new Instances(new FileReader(inputFilename)); //Assume that class attribute is the last attribute - This should be the case for all Sirius produced Arff files inst.setClassIndex(inst.numAttributes() - 1); Random random = new Random(1);//Simply set to 1, shall implement the random seed option later inst.randomize(random); if (inst.attribute(inst.classIndex()).isNominal()) inst.stratify(numFolds); // for timing ClassifierResults classifierResults = new ClassifierResults(false, 0); String classifierName = m_ClassifierEditor.getValue().getClass().getName(); classifierResults.updateList(classifierResults.getClassifierList(), "Classifier: ", classifierName); classifierResults.updateList(classifierResults.getClassifierList(), "Training Data: ", inputFilename); classifierResults.updateList(classifierResults.getClassifierList(), "Time Used: ", "NA"); //ArrayList<Double> resultList = new ArrayList<Double>(); if (jackKnifeRadioButton.isSelected() || numFolds > inst.numInstances() - 1) numFolds = inst.numInstances() - 1; for (int fold = 0; fold < numFolds && clusteringClassificationThread != null; fold++) {//Doing cross-validation statusLabel.setText("Cluster: " + x + " - Training Fold " + (fold + 1) + ".."); Instances 
train = inst.trainCV(numFolds, fold, random); Classifier current = null; try { current = Classifier.makeCopy(template); current.buildClassifier(train); Instances test = inst.testCV(numFolds, fold); statusLabel.setText("Cluster: " + x + " - Testing Fold " + (fold + 1) + ".."); for (int jj = 0; jj < test.numInstances(); jj++) { double[] result = current.distributionForInstance(test.instance(jj)); output.write("Cluster: " + x); output.newLine(); output.newLine(); output.write(test.instance(jj).stringValue(test.classAttribute()) + ",0=" + result[0]); output.newLine(); } } catch (Exception ex) { ex.printStackTrace(); statusLabel.setText("Error in cross-validation!"); startButton.setEnabled(true); stopButton.setEnabled(false); } } output.close(); totalTimeElapsed = System.currentTimeMillis() - totalTimeStart; classifierResults.updateList(classifierResults.getResultsList(), "Total Time Used: ", Utils.doubleToString(totalTimeElapsed / 60000, 2) + " minutes " + Utils.doubleToString((totalTimeElapsed / 1000.0) % 60.0, 2) + " seconds"); double threshold = validateFieldAsThreshold(classifierOneThresholdTextField.getText(), "Threshold Field", classifierOneThresholdTextField); String filename2 = inputDirectoryTextField.getText() .replaceAll("_cluster" + numOfCluster + ".arff", "_cluster" + x + ".score"); PredictionStats classifierStats = new PredictionStats(filename2, 0, threshold); resultsTableModel.add("Cluster " + x, classifierResults, classifierStats); resultsTable.setRowSelectionInterval(x, x); computeStats(numFolds);//compute and display the results } catch (Exception e) { e.printStackTrace(); statusLabel.setText("Error in reading file!"); startButton.setEnabled(true); stopButton.setEnabled(false); } } //end of cluster for loop resultsTableModel.add("Summary - Equal Weightage", null, null); resultsTable.setRowSelectionInterval(numOfCluster + 1, numOfCluster + 1); computeStats(numFolds); resultsTableModel.add("Summary - Weighted Average", null, null); 
resultsTable.setRowSelectionInterval(numOfCluster + 2, numOfCluster + 2); computeStats(numFolds); if (clusteringClassificationThread != null) statusLabel.setText("Done!"); else statusLabel.setText("Interrupted.."); startButton.setEnabled(true); stopButton.setEnabled(false); if (classifierOne != null) { levelOneClassifierOutputScrollPane.getVerticalScrollBar() .setValue(levelOneClassifierOutputScrollPane.getVerticalScrollBar().getMaximum()); } clusteringClassificationThread = null; } }); this.clusteringClassificationThread.setPriority(Thread.MIN_PRIORITY); this.clusteringClassificationThread.start(); } else { JOptionPane.showMessageDialog(parent, "Cannot start new job as previous job still running. Click stop to terminate previous job", "ERROR", JOptionPane.ERROR_MESSAGE); } }