Usage examples for weka.core.Instances.stratify
public void stratify(int numFolds)
From source file:cezeri.evaluater.FactoryEvaluation.java
public static Evaluation performCrossValidate(Classifier model, Instances datax, int folds, boolean show_text, boolean show_plot, TFigureAttribute attr) { Random rand = new Random(1); Instances randData = new Instances(datax); randData.randomize(rand);/*w ww.j a v a 2 s .com*/ if (randData.classAttribute().isNominal()) { randData.stratify(folds); } Evaluation eval = null; try { // perform cross-validation eval = new Evaluation(randData); // double[] simulated = new double[0]; // double[] observed = new double[0]; // double[] sim = new double[0]; // double[] obs = new double[0]; for (int n = 0; n < folds; n++) { Instances train = randData.trainCV(folds, n, rand); Instances validation = randData.testCV(folds, n); // build and evaluate classifier Classifier clsCopy = Classifier.makeCopy(model); clsCopy.buildClassifier(train); // sim = eval.evaluateModel(clsCopy, validation); // obs = validation.attributeToDoubleArray(validation.classIndex()); // if (show_plot) { // double[][] d = new double[2][sim.length]; // d[0] = obs; // d[1] = sim; // CMatrix f1 = CMatrix.getInstance(d); // f1.transpose().plot(attr); // } // if (show_text) { // // output evaluation // System.out.println(); // System.out.println("=== Setup for each Cross Validation fold==="); // System.out.println("Classifier: " + model.getClass().getName() + " " + Utils.joinOptions(model.getOptions())); // System.out.println("Dataset: " + randData.relationName()); // System.out.println("Folds: " + folds); // System.out.println("Seed: " + 1); // System.out.println(); // System.out.println(eval.toSummaryString("=== " + folds + "-fold Cross-validation ===", false)); // } simulated = FactoryUtils.concatenate(simulated, eval.evaluateModel(clsCopy, validation)); observed = FactoryUtils.concatenate(observed, validation.attributeToDoubleArray(validation.classIndex())); // simulated = FactoryUtils.mean(simulated,eval.evaluateModel(clsCopy, validation)); // observed = 
FactoryUtils.mean(observed,validation.attributeToDoubleArray(validation.classIndex())); } if (show_plot) { double[][] d = new double[2][simulated.length]; d[0] = observed; d[1] = simulated; CMatrix f1 = CMatrix.getInstance(d); attr.figureCaption = "overall performance"; f1.transpose().plot(attr); } if (show_text) { // output evaluation System.out.println(); System.out.println("=== Setup for Overall Cross Validation==="); System.out.println( "Classifier: " + model.getClass().getName() + " " + Utils.joinOptions(model.getOptions())); System.out.println("Dataset: " + randData.relationName()); System.out.println("Folds: " + folds); System.out.println("Seed: " + 1); System.out.println(); System.out.println(eval.toSummaryString("=== " + folds + "-fold Cross-validation ===", false)); } } catch (Exception ex) { Logger.getLogger(FactoryEvaluation.class.getName()).log(Level.SEVERE, null, ex); } return eval; }
From source file:com.reactivetechnologies.analytics.core.eval.StackingWithBuiltClassifiers.java
License:Open Source License
/** * Buildclassifier selects a classifier from the set of classifiers * by minimising error on the training data. * * @param data the training data to be used for generating the * boosted classifier./*from w w w. j av a 2s . com*/ * @throws Exception if the classifier could not be built successfully */ @Override public void buildClassifier(Instances data) throws Exception { if (m_MetaClassifier == null) { throw new IllegalArgumentException("No meta classifier has been set"); } // can classifier handle the data? getCapabilities().testWithFail(data); // remove instances with missing class Instances newData = new Instances(data); m_BaseFormat = new Instances(data, 0); newData.deleteWithMissingClass(); Random random = new Random(m_Seed); newData.randomize(random); if (newData.classAttribute().isNominal()) { newData.stratify(m_NumFolds); } // Create meta level generateMetaLevel(newData, random); /** Changed here */ // DO NOT Rebuilt all the base classifiers on the full training data /*for (int i = 0; i < m_Classifiers.length; i++) { getClassifier(i).buildClassifier(newData); }*/ /** End change */ }
From source file:cotraining.copy.Evaluation_D.java
License:Open Source License
/** * Performs a (stratified if class is nominal) cross-validation * for a classifier on a set of instances. Now performs * a deep copy of the classifier before each call to * buildClassifier() (just in case the classifier is not * initialized properly).//w w w . j a v a 2 s. com * * @param classifier the classifier with any options set. * @param data the data on which the cross-validation is to be * performed * @param numFolds the number of folds for the cross-validation * @param random random number generator for randomization * @param forPredictionsString varargs parameter that, if supplied, is * expected to hold a StringBuffer to print predictions to, * a Range of attributes to output and a Boolean (true if the distribution * is to be printed) * @throws Exception if a classifier could not be generated * successfully or the class is not defined */ public void crossValidateModel(Classifier classifier, Instances data, int numFolds, Random random, Object... forPredictionsPrinting) throws Exception { // Make a copy of the data we can reorder data = new Instances(data); data.randomize(random); if (data.classAttribute().isNominal()) { data.stratify(numFolds); } // We assume that the first element is a StringBuffer, the second a Range (attributes // to output) and the third a Boolean (whether or not to output a distribution instead // of just a classification) if (forPredictionsPrinting.length > 0) { // print the header first StringBuffer buff = (StringBuffer) forPredictionsPrinting[0]; Range attsToOutput = (Range) forPredictionsPrinting[1]; boolean printDist = ((Boolean) forPredictionsPrinting[2]).booleanValue(); printClassificationsHeader(data, attsToOutput, printDist, buff); } // Do the folds for (int i = 0; i < numFolds; i++) { Instances train = data.trainCV(numFolds, i, random); setPriors(train); Classifier copiedClassifier = Classifier.makeCopy(classifier); copiedClassifier.buildClassifier(train); Instances test = data.testCV(numFolds, i); 
evaluateModel(copiedClassifier, test, forPredictionsPrinting); } m_NumFolds = numFolds; }
From source file:cs.man.ac.uk.classifiers.GetAUC.java
License:Open Source License
/** * Computes the AUC for the supplied stream learner. * @return the AUC as a double value./*from w w w . ja v a 2s .c o m*/ */ private static double validate5x2CVStream() { try { // Other options int runs = 5; int folds = 2; double AUC_SUM = 0; // perform cross-validation for (int i = 0; i < runs; i++) { // randomize data int seed = i + 1; Random rand = new Random(seed); Instances randData = new Instances(data); randData.randomize(rand); if (randData.classAttribute().isNominal()) { System.out.println("Stratifying..."); randData.stratify(folds); } for (int n = 0; n < folds; n++) { Instances train = randData.trainCV(folds, n); Instances test = randData.testCV(folds, n); Distribution testDistribution = new Distribution(test); ArffSaver trainSaver = new ArffSaver(); trainSaver.setInstances(train); trainSaver.setFile(new File(trainPath)); trainSaver.writeBatch(); ArffSaver testSaver = new ArffSaver(); testSaver.setInstances(test); double[][] dist = testDistribution.matrix(); int negativeClassSize = (int) dist[0][0]; int positiveClassSize = (int) dist[0][1]; double balance = (double) positiveClassSize / (double) negativeClassSize; String tempTestPath = testPath.replace(".arff", "_" + positiveClassSize + "_" + negativeClassSize + "_" + balance + "_1.0.arff");// [Test-n-Set-n]_[+]_[-]_[K]_[L]; testSaver.setFile(new File(tempTestPath)); testSaver.writeBatch(); ARFFFile file = new ARFFFile(tempTestPath, CLASS_INDEX, new DebugLogger(false)); file.createMetaData(); HoeffdingTreeTester streamClassifier = new HoeffdingTreeTester(trainPath, tempTestPath, CLASS_INDEX, new String[] { "0", "1" }, new DebugLogger(true)); streamClassifier.train(); System.in.read(); //AUC_SUM += streamClassifier.getROCExternalData("",(int)testDistribution.perClass(1),(int)testDistribution.perClass(0)); streamClassifier.testStatic(homeDirectory + "/FuckSakeTest.txt"); String[] files = Common.getFilePaths(scratch); for (int j = 0; j < files.length; j++) Common.fileDelete(files[j]); } } return AUC_SUM / 
((double) runs * (double) folds); } catch (Exception e) { System.out.println("Exception validating data!"); e.printStackTrace(); return 0; } }
From source file:cs.man.ac.uk.classifiers.GetAUC.java
License:Open Source License
/** * Computes the AUC for the supplied learner. * @return the AUC as a double value./*from ww w. j a v a 2 s .c o m*/ */ @SuppressWarnings("unused") private static double validate5x2CV() { try { // other options int runs = 5; int folds = 2; double AUC_SUM = 0; // perform cross-validation for (int i = 0; i < runs; i++) { // randomize data int seed = i + 1; Random rand = new Random(seed); Instances randData = new Instances(data); randData.randomize(rand); if (randData.classAttribute().isNominal()) { System.out.println("Stratifying..."); randData.stratify(folds); } Evaluation eval = new Evaluation(randData); for (int n = 0; n < folds; n++) { Instances train = randData.trainCV(folds, n); Instances test = randData.testCV(folds, n); // the above code is used by the StratifiedRemoveFolds filter, the // code below by the Explorer/Experimenter: // Instances train = randData.trainCV(folds, n, rand); // build and evaluate classifier String[] options = { "-U", "-A" }; J48 classifier = new J48(); //HTree classifier = new HTree(); classifier.setOptions(options); classifier.buildClassifier(train); eval.evaluateModel(classifier, test); // generate curve ThresholdCurve tc = new ThresholdCurve(); int classIndex = 0; Instances result = tc.getCurve(eval.predictions(), classIndex); // plot curve vmc = new ThresholdVisualizePanel(); AUC_SUM += ThresholdCurve.getROCArea(result); System.out.println("AUC: " + ThresholdCurve.getROCArea(result) + " \tAUC SUM: " + AUC_SUM); } } return AUC_SUM / ((double) runs * (double) folds); } catch (Exception e) { System.out.println("Exception validating data!"); return 0; } }
From source file:es.upm.dit.gsi.barmas.dataset.utils.DatasetSplitter.java
License:Open Source License
/** * @param folds//w w w.j av a2 s . c o m * @param minAgents * @param maxAgents * @param originalDatasetPath * @param outputDir * @param scenario * @param logger */ public void splitDataset(int folds, int minAgents, int maxAgents, String originalDatasetPath, String outputDir, String scenario, Logger logger) { int ratioint = (int) ((1 / (double) folds) * 100); double roundedratio = ((double) ratioint) / 100; // Look for essentials List<String[]> essentials = this.getEssentials(originalDatasetPath, logger); for (int fold = 0; fold < folds; fold++) { String outputDirWithRatio = outputDir + "/" + roundedratio + "testRatio/iteration-" + fold; File dir = new File(outputDirWithRatio); if (!dir.exists() || !dir.isDirectory()) { dir.mkdirs(); } logger.finer("--> splitDataset()"); logger.fine("Creating experiment.info..."); try { Instances originalData = this.getDataFromCSV(originalDatasetPath); originalData.randomize(new Random()); originalData.stratify(folds); // TestDataSet Instances testData = originalData.testCV(folds, fold); CSVSaver saver = new CSVSaver(); ArffSaver arffsaver = new ArffSaver(); File file = new File(outputDirWithRatio + File.separator + "test-dataset.csv"); if (!file.exists()) { saver.resetOptions(); saver.setInstances(testData); saver.setFile(file); saver.writeBatch(); } file = new File(outputDirWithRatio + File.separator + "test-dataset.arff"); if (!file.exists()) { arffsaver.resetOptions(); arffsaver.setInstances(testData); arffsaver.setFile(file); arffsaver.writeBatch(); } // BayesCentralDataset Instances trainData = originalData.trainCV(folds, fold); file = new File(outputDirWithRatio + File.separator + "bayes-central-dataset.csv"); if (!file.exists()) { saver.resetOptions(); saver.setInstances(trainData); saver.setFile(file); saver.writeBatch(); this.copyFileUsingApacheCommonsIO(file, new File( outputDirWithRatio + File.separator + "bayes-central-dataset-noEssentials.csv"), logger); CsvWriter w = new CsvWriter(new FileWriter(file, true), ','); 
for (String[] essential : essentials) { w.writeRecord(essential); } w.close(); } file = new File(outputDirWithRatio + File.separator + "bayes-central-dataset.arff"); if (!file.exists()) { arffsaver.resetOptions(); arffsaver.setInstances(trainData); arffsaver.setFile(file); arffsaver.writeBatch(); this.copyFileUsingApacheCommonsIO(file, new File( outputDirWithRatio + File.separator + "bayes-central-dataset-noEssentials.arff"), logger); CsvWriter w = new CsvWriter(new FileWriter(file, true), ','); for (String[] essential : essentials) { w.writeRecord(essential); } w.close(); } // Agent datasets CsvReader csvreader = new CsvReader(new FileReader(new File(originalDatasetPath))); csvreader.readHeaders(); String[] headers = csvreader.getHeaders(); csvreader.close(); for (int agents = minAgents; agents <= maxAgents; agents++) { this.createExperimentInfoFile(folds, agents, originalDatasetPath, outputDirWithRatio, scenario, logger); HashMap<String, CsvWriter> writers = new HashMap<String, CsvWriter>(); String agentsDatasetsDir = outputDirWithRatio + File.separator + agents + "agents"; HashMap<String, CsvWriter> arffWriters = new HashMap<String, CsvWriter>(); File f = new File(agentsDatasetsDir); if (!f.isDirectory()) { f.mkdirs(); } Instances copy = new Instances(trainData); copy.delete(); for (int i = 0; i < agents; i++) { String fileName = agentsDatasetsDir + File.separator + "agent-" + i + "-dataset.csv"; file = new File(fileName); if (!file.exists()) { CsvWriter writer = new CsvWriter(new FileWriter(fileName), ','); writer.writeRecord(headers); writers.put("AGENT" + i, writer); } fileName = agentsDatasetsDir + File.separator + "agent-" + i + "-dataset.arff"; file = new File(fileName); if (!file.exists()) { arffsaver.resetOptions(); arffsaver.setInstances(copy); arffsaver.setFile(new File(fileName)); arffsaver.writeBatch(); CsvWriter arffwriter = new CsvWriter(new FileWriter(fileName, true), ','); arffWriters.put("AGENT" + i, arffwriter); } logger.fine("AGENT" + i + " 
dataset created in csv and arff formats."); } // Append essentials to all for (String[] essential : essentials) { for (CsvWriter wr : writers.values()) { wr.writeRecord(essential); } for (CsvWriter arffwr : arffWriters.values()) { arffwr.writeRecord(essential); } } int agentCounter = 0; for (int j = 0; j < trainData.numInstances(); j++) { Instance instance = trainData.instance(j); CsvWriter writer = writers.get("AGENT" + agentCounter); CsvWriter arffwriter = arffWriters.get("AGENT" + agentCounter); String[] row = new String[instance.numAttributes()]; for (int a = 0; a < instance.numAttributes(); a++) { row[a] = instance.stringValue(a); } if (writer != null) { writer.writeRecord(row); } if (arffwriter != null) { arffwriter.writeRecord(row); } agentCounter++; if (agentCounter == agents) { agentCounter = 0; } } for (CsvWriter wr : writers.values()) { wr.close(); } for (CsvWriter arffwr : arffWriters.values()) { arffwr.close(); } } } catch (Exception e) { logger.severe("Exception while splitting dataset. ->"); logger.severe(e.getMessage()); System.exit(1); } logger.finest("Dataset for fold " + fold + " created."); } logger.finer("<-- splitDataset()"); }
From source file:GClass.EvaluationInternal.java
License:Open Source License
/** * Performs a (stratified if class is nominal) cross-validation * for a classifier on a set of instances. * * @param classifier the classifier with any options set. * @param data the data on which the cross-validation is to be * performed//from www . j av a2 s .c o m * @param numFolds the number of folds for the cross-validation * @param random random number generator for randomization * @exception Exception if a classifier could not be generated * successfully or the class is not defined */ public void crossValidateModel(Classifier classifier, Instances data, int numFolds, Random random) throws Exception { // Make a copy of the data we can reorder data = new Instances(data); data.randomize(random); if (data.classAttribute().isNominal()) { data.stratify(numFolds); } // Do the folds for (int i = 0; i < numFolds; i++) { Instances train = data.trainCV(numFolds, i, random); setPriors(train); Classifier copiedClassifier = Classifier.makeCopy(classifier); copiedClassifier.buildClassifier(train); Instances test = data.testCV(numFolds, i); evaluateModel(copiedClassifier, test); } m_NumFolds = numFolds; }
From source file:it.unisa.gitdm.evaluation.WekaEvaluator.java
private static void evaluateModel(String baseFolderPath, String projectName, Classifier pClassifier, Instances pInstances, String pModelName, String pClassifierName) throws Exception { // other options int folds = 10; // randomize data Random rand = new Random(42); Instances randData = new Instances(pInstances); randData.randomize(rand);//from w w w. j a va 2 s .co m if (randData.classAttribute().isNominal()) { randData.stratify(folds); } // perform cross-validation and add predictions Instances predictedData = null; Evaluation eval = new Evaluation(randData); int positiveValueIndexOfClassFeature = 0; for (int n = 0; n < folds; n++) { Instances train = randData.trainCV(folds, n); Instances test = randData.testCV(folds, n); // the above code is used by the StratifiedRemoveFolds filter, the // code below by the Explorer/Experimenter: // Instances train = randData.trainCV(folds, n, rand); int classFeatureIndex = 0; for (int i = 0; i < train.numAttributes(); i++) { if (train.attribute(i).name().equals("isBuggy")) { classFeatureIndex = i; break; } } Attribute classFeature = train.attribute(classFeatureIndex); for (int i = 0; i < classFeature.numValues(); i++) { if (classFeature.value(i).equals("TRUE")) { positiveValueIndexOfClassFeature = i; } } train.setClassIndex(classFeatureIndex); test.setClassIndex(classFeatureIndex); // build and evaluate classifier pClassifier.buildClassifier(train); eval.evaluateModel(pClassifier, test); // add predictions // AddClassification filter = new AddClassification(); // filter.setClassifier(pClassifier); // filter.setOutputClassification(true); // filter.setOutputDistribution(true); // filter.setOutputErrorFlag(true); // filter.setInputFormat(train); // Filter.useFilter(train, filter); // Instances pred = Filter.useFilter(test, filter); // if (predictedData == null) // predictedData = new Instances(pred, 0); // // for (int j = 0; j < pred.numInstances(); j++) // predictedData.add(pred.instance(j)); } double accuracy = 
(eval.numTruePositives(positiveValueIndexOfClassFeature) + eval.numTrueNegatives(positiveValueIndexOfClassFeature)) / (eval.numTruePositives(positiveValueIndexOfClassFeature) + eval.numFalsePositives(positiveValueIndexOfClassFeature) + eval.numFalseNegatives(positiveValueIndexOfClassFeature) + eval.numTrueNegatives(positiveValueIndexOfClassFeature)); double fmeasure = 2 * ((eval.precision(positiveValueIndexOfClassFeature) * eval.recall(positiveValueIndexOfClassFeature)) / (eval.precision(positiveValueIndexOfClassFeature) + eval.recall(positiveValueIndexOfClassFeature))); File wekaOutput = new File(baseFolderPath + projectName + "/predictors.csv"); PrintWriter pw1 = new PrintWriter(wekaOutput); pw1.write(accuracy + ";" + eval.precision(positiveValueIndexOfClassFeature) + ";" + eval.recall(positiveValueIndexOfClassFeature) + ";" + fmeasure + ";" + eval.areaUnderROC(positiveValueIndexOfClassFeature)); System.out.println(projectName + ";" + pClassifierName + ";" + pModelName + ";" + eval.numTruePositives(positiveValueIndexOfClassFeature) + ";" + eval.numFalsePositives(positiveValueIndexOfClassFeature) + ";" + eval.numFalseNegatives(positiveValueIndexOfClassFeature) + ";" + eval.numTrueNegatives(positiveValueIndexOfClassFeature) + ";" + accuracy + ";" + eval.precision(positiveValueIndexOfClassFeature) + ";" + eval.recall(positiveValueIndexOfClassFeature) + ";" + fmeasure + ";" + eval.areaUnderROC(positiveValueIndexOfClassFeature) + "\n"); }
From source file:j48.PruneableClassifierTree.java
License:Open Source License
/** * Method for building a pruneable classifier tree. * * @param data the data to build the tree from * @throws Exception if tree can't be built successfully *//*from w w w . j av a2 s.c o m*/ public void buildClassifier(Instances data) throws Exception { // can classifier tree handle the data? getCapabilities().testWithFail(data); // remove instances with missing class data = new Instances(data); data.deleteWithMissingClass(); Random random = new Random(m_seed); data.stratify(numSets); buildTree(data.trainCV(numSets, numSets - 1, random), data.testCV(numSets, numSets - 1), !m_cleanup); if (pruneTheTree) { prune(); } if (m_cleanup) { cleanup(new Instances(data, 0)); } }
From source file:jjj.asap.sas.ensemble.impl.CrossValidatedEnsemble.java
License:Open Source License
@Override public StrongLearner build(int essaySet, String ensembleName, List<WeakLearner> learners) { // can't handle empty case if (learners.isEmpty()) { return this.ensemble.build(essaySet, ensembleName, learners); }//w ww.ja v a 2 s. c o m // create a dummy dataset. DatasetBuilder builder = new DatasetBuilder(); builder.addVariable("id"); builder.addNominalVariable("class", Contest.getRubrics(essaySet)); Instances dummy = builder.getDataset("dummy"); // add data Map<Double, Double> groundTruth = Contest.getGoldStandard(essaySet); for (double id : learners.get(0).getPreds().keySet()) { dummy.add(new DenseInstance(1.0, new double[] { id, groundTruth.get(id) })); } // stratify dummy.sort(0); dummy.randomize(new Random(1)); dummy.setClassIndex(1); dummy.stratify(nFolds); // now evaluate each fold Map<Double, Double> preds = new HashMap<Double, Double>(); for (int k = 0; k < nFolds; k++) { Instances train = dummy.trainCV(nFolds, k); Instances test = dummy.testCV(nFolds, k); List<WeakLearner> cvLeaners = new ArrayList<WeakLearner>(); for (WeakLearner learner : learners) { WeakLearner copy = learner.copyOf(); for (int i = 0; i < test.numInstances(); i++) { copy.getPreds().remove(test.instance(i).value(0)); copy.getProbs().remove(test.instance(i).value(0)); } cvLeaners.add(copy); } // train on fold StrongLearner cv = this.ensemble.build(essaySet, ensembleName, cvLeaners); List<WeakLearner> testLeaners = new ArrayList<WeakLearner>(); for (WeakLearner learner : cv.getLearners()) { WeakLearner copy = learner.copyOf(); copy.getPreds().clear(); copy.getProbs().clear(); WeakLearner source = find(copy.getName(), learners); for (int i = 0; i < test.numInstances(); i++) { double id = test.instance(i).value(0); copy.getPreds().put(id, source.getPreds().get(id)); copy.getProbs().put(id, source.getProbs().get(id)); } testLeaners.add(copy); } preds.putAll(this.ensemble.classify(essaySet, ensembleName, testLeaners, cv.getContext())); } // now prepare final result StrongLearner strong 
= this.ensemble.build(essaySet, ensembleName, learners); double trainingError = strong.getKappa(); double cvError = Calc.kappa(essaySet, preds, groundTruth); // Job.log(essaySet+"-"+ensembleName, "XVAL: training error = " + trainingError + " cv error = " + cvError); strong.setKappa(cvError); return strong; }