List of usage examples for weka.core.Instances.stratify
public void stratify(int numFolds)
From source file:mao.datamining.ModelProcess.java
private void testCV(Classifier classifier, Instances finalTrainDataSet, FileOutputStream testCaseSummaryOut, TestResult result) {/*w ww. ja v a 2 s . c o m*/ long start, end, trainTime = 0, testTime = 0; Evaluation evalAll = null; double confusionMatrix[][] = null; // randomize data, and then stratify it into 10 groups Random rand = new Random(1); Instances randData = new Instances(finalTrainDataSet); randData.randomize(rand); if (randData.classAttribute().isNominal()) { //always run with 10 cross validation randData.stratify(folds); } try { evalAll = new Evaluation(randData); for (int i = 0; i < folds; i++) { Evaluation eval = new Evaluation(randData); Instances train = randData.trainCV(folds, i); Instances test = randData.testCV(folds, i); //counting traininig time start = System.currentTimeMillis(); Classifier j48ClassifierCopy = Classifier.makeCopy(classifier); j48ClassifierCopy.buildClassifier(train); end = System.currentTimeMillis(); trainTime += end - start; //counting test time start = System.currentTimeMillis(); eval.evaluateModel(j48ClassifierCopy, test); evalAll.evaluateModel(j48ClassifierCopy, test); end = System.currentTimeMillis(); testTime += end - start; } } catch (Exception e) { ModelProcess.logging(null, e); } //end test by cross validation // output evaluation try { ModelProcess.logging(""); //write into summary file testCaseSummaryOut .write((evalAll.toSummaryString("=== Cross Validation Summary ===", true)).getBytes()); testCaseSummaryOut.write("\n".getBytes()); testCaseSummaryOut.write( (evalAll.toClassDetailsString("=== " + folds + "-fold Cross-validation Class Detail ===\n")) .getBytes()); testCaseSummaryOut.write("\n".getBytes()); testCaseSummaryOut .write((evalAll.toMatrixString("=== Confusion matrix for all folds ===\n")).getBytes()); testCaseSummaryOut.flush(); confusionMatrix = evalAll.confusionMatrix(); result.setConfusionMatrix10Folds(confusionMatrix); } catch (Exception e) { ModelProcess.logging(null, e); } }
From source file:mlpoc.MLPOC.java
public static Evaluation crossValidate(String filename) { Evaluation eval = null;/*from w w w.j av a 2s. com*/ try { BufferedReader br = new BufferedReader(new FileReader(filename)); // loads data and set class index Instances data = new Instances(br); br.close(); /*File csv=new File(filename); CSVLoader loader = new CSVLoader(); loader.setSource(csv); Instances data = loader.getDataSet();*/ data.setClassIndex(data.numAttributes() - 1); // classifier String[] tmpOptions; String classname = "weka.classifiers.trees.J48 -C 0.25"; tmpOptions = classname.split(" "); classname = "weka.classifiers.trees.J48"; tmpOptions[0] = ""; Classifier cls = (Classifier) Utils.forName(Classifier.class, classname, tmpOptions); // other options int seed = 2; int folds = 10; // randomize data Random rand = new Random(seed); Instances randData = new Instances(data); randData.randomize(rand); if (randData.classAttribute().isNominal()) randData.stratify(folds); // perform cross-validation eval = new Evaluation(randData); for (int n = 0; n < folds; n++) { Instances train = randData.trainCV(folds, n); Instances test = randData.testCV(folds, n); // the above code is used by the StratifiedRemoveFolds filter, the // code below by the Explorer/Experimenter: // Instances train = randData.trainCV(folds, n, rand); // build and evaluate classifier Classifier clsCopy = Classifier.makeCopy(cls); clsCopy.buildClassifier(train); eval.evaluateModel(clsCopy, test); } // output evaluation System.out.println(); System.out.println("=== Setup ==="); System.out .println("Classifier: " + cls.getClass().getName() + " " + Utils.joinOptions(cls.getOptions())); System.out.println("Dataset: " + data.relationName()); System.out.println("Folds: " + folds); System.out.println("Seed: " + seed); System.out.println(); System.out.println(eval.toSummaryString("Summary for testing", true)); System.out.println("Correctly Classified Instances: " + eval.correct()); System.out.println("Percentage of Correctly Classified 
Instances: " + eval.pctCorrect()); System.out.println("InCorrectly Classified Instances: " + eval.incorrect()); System.out.println("Percentage of InCorrectly Classified Instances: " + eval.pctIncorrect()); } catch (Exception ex) { System.err.println(ex.getMessage()); } return eval; }
From source file:moa.classifiers.AccuracyWeightedEnsemble.java
License:Open Source License
/** * Computes the weight of a candidate classifier. * @param candidate Candidate classifier. * @param chunk Data chunk of examples./*from w w w . ja v a 2 s . com*/ * @param numFolds Number of folds in candidate classifier cross-validation. * @param useMseR Determines whether to use the MSEr threshold. * @return Candidate classifier weight. */ protected double computeCandidateWeight(Classifier candidate, Instances chunk, int numFolds) { double candidateWeight = 0.0; Random random = new Random(1); Instances randData = new Instances(chunk); randData.randomize(random); if (randData.classAttribute().isNominal()) { randData.stratify(numFolds); } for (int n = 0; n < numFolds; n++) { Instances train = randData.trainCV(numFolds, n, random); Instances test = randData.testCV(numFolds, n); Classifier learner = candidate.copy(); for (int num = 0; num < train.numInstances(); num++) { learner.trainOnInstance(train.instance(num)); } candidateWeight += computeWeight(learner, test); } double resultWeight = candidateWeight / numFolds; if (Double.isInfinite(resultWeight)) { return Double.MAX_VALUE; } else { return resultWeight; } }
From source file:mulan.data.LabelPowersetStratification.java
License:Open Source License
public MultiLabelInstances[] stratify(MultiLabelInstances data, int folds) { try {/*from w w w . ja va2 s . co m*/ MultiLabelInstances[] segments = new MultiLabelInstances[folds]; LabelPowersetTransformation transformation = new LabelPowersetTransformation(); Instances transformed; // transform to single-label transformed = transformation.transformInstances(data); // add id Add add = new Add(); add.setAttributeIndex("first"); add.setAttributeName("instanceID"); add.setInputFormat(transformed); transformed = Filter.useFilter(transformed, add); for (int i = 0; i < transformed.numInstances(); i++) { transformed.instance(i).setValue(0, i); } transformed.setClassIndex(transformed.numAttributes() - 1); // stratify transformed.randomize(new Random(seed)); transformed.stratify(folds); for (int i = 0; i < folds; i++) { //System.out.println("Fold " + (i + 1) + "/" + folds); Instances temp = transformed.testCV(folds, i); Instances test = new Instances(data.getDataSet(), 0); for (int j = 0; j < temp.numInstances(); j++) { test.add(data.getDataSet().instance((int) temp.instance(j).value(0))); } segments[i] = new MultiLabelInstances(test, data.getLabelsMetaData()); } return segments; } catch (Exception ex) { Logger.getLogger(LabelPowersetStratification.class.getName()).log(Level.SEVERE, null, ex); return null; } }
From source file:net.sf.bddbddb.order.WekaInterface.java
License:LGPL
public static double cvError(int numFolds, Instances data0, String cClassName) { if (data0.numInstances() < numFolds) return Double.NaN; //more folds than elements if (numFolds == 0) return Double.NaN; // no folds if (data0.numInstances() == 0) return 0; //no instances Instances data = new Instances(data0); //data.randomize(new Random(System.currentTimeMillis())); data.stratify(numFolds); Assert._assert(data.classAttribute() != null); double[] estimates = new double[numFolds]; for (int i = 0; i < numFolds; ++i) { Instances trainData = data.trainCV(numFolds, i); Assert._assert(trainData.classAttribute() != null); Assert._assert(trainData.numInstances() != 0, "Cannot train classifier on 0 instances."); Instances testData = data.testCV(numFolds, i); Assert._assert(testData.classAttribute() != null); Assert._assert(testData.numInstances() != 0, "Cannot test classifier on 0 instances."); int temp = FindBestDomainOrder.TRACE; FindBestDomainOrder.TRACE = 0;//w w w.j a v a2 s .c o m Classifier classifier = buildClassifier(cClassName, trainData); FindBestDomainOrder.TRACE = temp; int count = testData.numInstances(); double loss = 0; double sum = 0; for (Enumeration e = testData.enumerateInstances(); e.hasMoreElements();) { Instance instance = (Instance) e.nextElement(); Assert._assert(instance != null); Assert._assert(instance.classAttribute() != null && instance.classAttribute() == trainData.classAttribute()); try { double testClass = classifier.classifyInstance(instance); double weight = instance.weight(); if (testClass != instance.classValue()) loss += weight; sum += weight; } catch (Exception ex) { FindBestDomainOrder.out.println("Exception while classifying: " + instance + "\n" + ex); } } estimates[i] = 1 - loss / sum; } double average = 0; for (int i = 0; i < numFolds; ++i) average += estimates[i]; return average / numFolds; }
From source file:org.scripps.branch.classifier.ManualTree.java
License:Open Source License
/** * Builds classifier.// ww w .jav a 2 s. c om * * @param data * the data to train with * @throws Exception * if something goes wrong or the data doesn't fit */ @Override public void buildClassifier(Instances data) throws Exception { // Make sure K value is in range if (m_KValue > data.numAttributes() - 1) m_KValue = data.numAttributes() - 1; if (m_KValue < 1) m_KValue = (int) Utils.log2(data.numAttributes()) + 1; // can classifier handle the data? getCapabilities().testWithFail(data); // remove instances with missing class data = new Instances(data); data.deleteWithMissingClass(); // only class? -> build ZeroR model if (data.numAttributes() == 1) { System.err.println( "Cannot build model (only class attribute present in data!), " + "using ZeroR model instead!"); m_ZeroR = new weka.classifiers.rules.ZeroR(); m_ZeroR.buildClassifier(data); return; } else { m_ZeroR = null; } // Figure out appropriate datasets Instances train = null; Instances backfit = null; Random rand = data.getRandomNumberGenerator(m_randomSeed); if (m_NumFolds <= 0) { train = data; } else { data.randomize(rand); data.stratify(m_NumFolds); train = data.trainCV(m_NumFolds, 1, rand); backfit = data.testCV(m_NumFolds, 1); } //Set Default Instances for selection. 
setRequiredInst(data); // Create the attribute indices window int[] attIndicesWindow = new int[data.numAttributes() - 1]; int j = 0; for (int i = 0; i < attIndicesWindow.length; i++) { if (j == data.classIndex()) j++; // do not include the class attIndicesWindow[i] = j++; } // Compute initial class counts double[] classProbs = new double[train.numClasses()]; for (int i = 0; i < train.numInstances(); i++) { Instance inst = train.instance(i); classProbs[(int) inst.classValue()] += inst.weight(); } Instances requiredInstances = getRequiredInst(); // Build tree if (jsontree != null) { buildTree(train, classProbs, new Instances(data, 0), m_Debug, 0, jsontree, 0, m_distributionData, requiredInstances, listOfFc, cSetList, ccSer, d); } else { System.out.println("No json tree specified, failing to process tree"); } setRequiredInst(requiredInstances); // Backfit if required if (backfit != null) { backfitData(backfit); } }
From source file:sentinets.Prediction.java
License:Open Source License
/**
 * Incrementally updates the wrapped SGD model with newly loaded instances and reports
 * evaluation results.
 *
 * <p>The phases are:
 * <ol>
 * <li>Evaluate the current model on all new instances ("Initial Evaluation").</li>
 * <li>For each fold of a shuffled, stratified copy, update the model on the fold's training part
 * and evaluate it on the test part. This phase is skipped when there are fewer instances than
 * folds, which weka's {@code trainCV}/{@code testCV} reject.</li>
 * <li>Update the model on all new instances, evaluate again, and save the updated
 * FilteredClassifier as {@code updatedClassifier.model} under the models output directory.</li>
 * </ol>
 * After each evaluation, {@code {fMeasure(0), fMeasure(1), weightedFMeasure()}} is appended to
 * {@code metrics}.
 *
 * <p>NOTE(review): one {@link Evaluation} is reused for every phase. weka's Evaluation
 * accumulates statistics across {@code evaluateModel} calls, so each reported summary is
 * cumulative over all earlier phases. Confirm this is intended.
 *
 * @param inputFile file handed to {@code setInstances}; presumably populates {@code unlabled}
 * @param metrics receives one F-measure row per evaluation
 * @return a human-readable report. On an exception the stack trace is printed and the report
 *         built so far is returned.
 */
public String updateModel(String inputFile, ArrayList<Double[]> metrics) {
    String output = "";
    this.setInstances(inputFile);
    FilteredClassifier fcls = (FilteredClassifier) this.cls;
    SGD cls = (SGD) fcls.getClassifier();
    Filter filter = fcls.getFilter();
    Instances insAll;
    try {
        insAll = Filter.useFilter(this.unlabled, filter);
        if (insAll.size() > 0) {
            Random rand = new Random(10);
            int folds = 10 > insAll.size() ? 2 : 10;
            // With a single new instance, folds (2) exceeds the instance count and weka's
            // trainCV/testCV would throw, aborting the update before the model is saved.
            boolean crossValidate = insAll.size() >= folds;
            Instances randData = new Instances(insAll);
            randData.randomize(rand);
            if (crossValidate && randData.classAttribute().isNominal()) {
                randData.stratify(folds);
            }
            Evaluation eval = new Evaluation(randData);

            eval.evaluateModel(cls, insAll);
            System.out.println("Initial Evaluation");
            System.out.println(eval.toSummaryString());
            System.out.println(eval.toClassDetailsString());
            metrics.add(new Double[] { eval.fMeasure(0), eval.fMeasure(1), eval.weightedFMeasure() });
            output += "\n====" + "Initial Evaluation" + "====\n";
            output += "\n" + eval.toSummaryString();
            output += "\n" + eval.toClassDetailsString();

            if (crossValidate) {
                System.out.println("Cross Validated Evaluation");
                output += "\n====" + "Cross Validated Evaluation" + "====\n";
                for (int n = 0; n < folds; n++) {
                    Instances train = randData.trainCV(folds, n);
                    Instances test = randData.testCV(folds, n);
                    for (int i = 0; i < train.numInstances(); i++) {
                        cls.updateClassifier(train.instance(i));
                    }
                    eval.evaluateModel(cls, test);
                    System.out.println("Cross Validated Evaluation fold: " + n);
                    output += "\n====" + "Cross Validated Evaluation fold (" + n + ")====\n";
                    System.out.println(eval.toSummaryString());
                    System.out.println(eval.toClassDetailsString());
                    output += "\n" + eval.toSummaryString();
                    output += "\n" + eval.toClassDetailsString();
                    metrics.add(new Double[] { eval.fMeasure(0), eval.fMeasure(1), eval.weightedFMeasure() });
                }
            }

            for (int i = 0; i < insAll.numInstances(); i++) {
                cls.updateClassifier(insAll.instance(i));
            }
            eval.evaluateModel(cls, insAll);
            System.out.println("Final Evaluation");
            System.out.println(eval.toSummaryString());
            System.out.println(eval.toClassDetailsString());
            output += "\n====" + "Final Evaluation" + "====\n";
            output += "\n" + eval.toSummaryString();
            output += "\n" + eval.toClassDetailsString();
            metrics.add(new Double[] { eval.fMeasure(0), eval.fMeasure(1), eval.weightedFMeasure() });

            fcls.setClassifier(cls);
            String modelFilePath = outputDir + "/" + Utils.getOutDir(Utils.OutDirIndex.MODELS)
                    + "/updatedClassifier.model";
            weka.core.SerializationHelper.write(modelFilePath, fcls);
            output += "\n" + "Updated Model saved at: " + modelFilePath;
        } else {
            output += "No new instances for training the model.";
        }
    } catch (Exception e) {
        e.printStackTrace();
    }
    return output;
}
From source file:sirius.clustering.main.ClustererClassificationPane.java
License:Open Source License
private void start() { //Run Classifier if (this.inputDirectoryTextField.getText().length() == 0) { JOptionPane.showMessageDialog(parent, "Please set Input Directory to where the clusterer output are!", "Evaluate Classifier", JOptionPane.ERROR_MESSAGE); return;/*w w w . java 2 s . c o m*/ } if (m_ClassifierEditor.getValue() == null) { JOptionPane.showMessageDialog(parent, "Please choose Classifier!", "Evaluate Classifier", JOptionPane.ERROR_MESSAGE); return; } if (validateStatsSettings(1) == false) { return; } if (this.clusteringClassificationThread == null) { startButton.setEnabled(false); stopButton.setEnabled(true); tabbedClassifierPane.setSelectedIndex(0); this.clusteringClassificationThread = (new Thread() { public void run() { //Clear the output text area levelOneClassifierOutputTextArea.setText(""); resultsTableModel.reset(); //double threshold = Double.parseDouble(classifierOneThresholdTextField.getText()); //cross-validation int numFolds; if (jackKnifeRadioButton.isSelected()) numFolds = -1; else numFolds = Integer.parseInt(foldsField.getText()); StringTokenizer st = new StringTokenizer(inputDirectoryTextField.getText(), File.separator); String filename = ""; while (st.hasMoreTokens()) { filename = st.nextToken(); } StringTokenizer st2 = new StringTokenizer(filename, "_."); numOfCluster = 0; if (st2.countTokens() >= 2) { st2.nextToken(); String numOfClusterString = st2.nextToken().replaceAll("cluster", ""); try { numOfCluster = Integer.parseInt(numOfClusterString); } catch (NumberFormatException e) { JOptionPane.showMessageDialog(parent, "Please choose the correct file! 
(Output from Utilize Clusterer)", "ERROR", JOptionPane.ERROR_MESSAGE); } } Classifier template = (Classifier) m_ClassifierEditor.getValue(); for (int x = 0; x <= numOfCluster && clusteringClassificationThread != null; x++) {//Test each cluster try { long totalTimeStart = 0, totalTimeElapsed = 0; totalTimeStart = System.currentTimeMillis(); statusLabel.setText("Reading in cluster" + x + " file.."); String inputFilename = inputDirectoryTextField.getText() .replaceAll("_cluster" + numOfCluster + ".arff", "_cluster" + x + ".arff"); String outputScoreFilename = inputDirectoryTextField.getText() .replaceAll("_cluster" + numOfCluster + ".arff", "_cluster" + x + ".score"); BufferedWriter output = new BufferedWriter(new FileWriter(outputScoreFilename)); Instances inst = new Instances(new FileReader(inputFilename)); //Assume that class attribute is the last attribute - This should be the case for all Sirius produced Arff files inst.setClassIndex(inst.numAttributes() - 1); Random random = new Random(1);//Simply set to 1, shall implement the random seed option later inst.randomize(random); if (inst.attribute(inst.classIndex()).isNominal()) inst.stratify(numFolds); // for timing ClassifierResults classifierResults = new ClassifierResults(false, 0); String classifierName = m_ClassifierEditor.getValue().getClass().getName(); classifierResults.updateList(classifierResults.getClassifierList(), "Classifier: ", classifierName); classifierResults.updateList(classifierResults.getClassifierList(), "Training Data: ", inputFilename); classifierResults.updateList(classifierResults.getClassifierList(), "Time Used: ", "NA"); //ArrayList<Double> resultList = new ArrayList<Double>(); if (jackKnifeRadioButton.isSelected() || numFolds > inst.numInstances() - 1) numFolds = inst.numInstances() - 1; for (int fold = 0; fold < numFolds && clusteringClassificationThread != null; fold++) {//Doing cross-validation statusLabel.setText("Cluster: " + x + " - Training Fold " + (fold + 1) + ".."); Instances 
train = inst.trainCV(numFolds, fold, random); Classifier current = null; try { current = Classifier.makeCopy(template); current.buildClassifier(train); Instances test = inst.testCV(numFolds, fold); statusLabel.setText("Cluster: " + x + " - Testing Fold " + (fold + 1) + ".."); for (int jj = 0; jj < test.numInstances(); jj++) { double[] result = current.distributionForInstance(test.instance(jj)); output.write("Cluster: " + x); output.newLine(); output.newLine(); output.write(test.instance(jj).stringValue(test.classAttribute()) + ",0=" + result[0]); output.newLine(); } } catch (Exception ex) { ex.printStackTrace(); statusLabel.setText("Error in cross-validation!"); startButton.setEnabled(true); stopButton.setEnabled(false); } } output.close(); totalTimeElapsed = System.currentTimeMillis() - totalTimeStart; classifierResults.updateList(classifierResults.getResultsList(), "Total Time Used: ", Utils.doubleToString(totalTimeElapsed / 60000, 2) + " minutes " + Utils.doubleToString((totalTimeElapsed / 1000.0) % 60.0, 2) + " seconds"); double threshold = validateFieldAsThreshold(classifierOneThresholdTextField.getText(), "Threshold Field", classifierOneThresholdTextField); String filename2 = inputDirectoryTextField.getText() .replaceAll("_cluster" + numOfCluster + ".arff", "_cluster" + x + ".score"); PredictionStats classifierStats = new PredictionStats(filename2, 0, threshold); resultsTableModel.add("Cluster " + x, classifierResults, classifierStats); resultsTable.setRowSelectionInterval(x, x); computeStats(numFolds);//compute and display the results } catch (Exception e) { e.printStackTrace(); statusLabel.setText("Error in reading file!"); startButton.setEnabled(true); stopButton.setEnabled(false); } } //end of cluster for loop resultsTableModel.add("Summary - Equal Weightage", null, null); resultsTable.setRowSelectionInterval(numOfCluster + 1, numOfCluster + 1); computeStats(numFolds); resultsTableModel.add("Summary - Weighted Average", null, null); 
resultsTable.setRowSelectionInterval(numOfCluster + 2, numOfCluster + 2); computeStats(numFolds); if (clusteringClassificationThread != null) statusLabel.setText("Done!"); else statusLabel.setText("Interrupted.."); startButton.setEnabled(true); stopButton.setEnabled(false); if (classifierOne != null) { levelOneClassifierOutputScrollPane.getVerticalScrollBar() .setValue(levelOneClassifierOutputScrollPane.getVerticalScrollBar().getMaximum()); } clusteringClassificationThread = null; } }); this.clusteringClassificationThread.setPriority(Thread.MIN_PRIORITY); this.clusteringClassificationThread.start(); } else { JOptionPane.showMessageDialog(parent, "Cannot start new job as previous job still running. Click stop to terminate previous job", "ERROR", JOptionPane.ERROR_MESSAGE); } }