Example usage of weka.core.Instances.stratify

List of usage examples for weka.core.Instances.stratify

Introduction

On this page you can find example usages of weka.core.Instances.stratify.

Prototype

public void stratify(int numFolds) 

Source Link

Document

Stratifies a set of instances according to its class values if the class attribute is nominal (so that afterwards a stratified cross-validation can be performed).

Usage

From source file:cezeri.evaluater.FactoryEvaluation.java

/**
 * Runs a k-fold cross-validation of {@code model} on a shuffled copy of {@code datax}
 * (stratified when the class attribute is nominal). Optionally plots the pooled
 * observed-vs-predicted values of all folds and prints an evaluation summary.
 *
 * @param model     classifier template; a fresh copy is trained for every fold
 * @param datax     data to evaluate on; it is copied, never reordered
 * @param folds     number of cross-validation folds
 * @param show_text if {@code true}, print the setup and the evaluation summary to stdout
 * @param show_plot if {@code true}, plot observed vs. predicted class values over all folds
 * @param attr      figure attributes used when {@code show_plot} is {@code true}
 * @return the accumulated {@link Evaluation}; {@code null} if it could not be created, or a
 *         partially filled one if a later step threw (the exception is logged)
 */
public static Evaluation performCrossValidate(Classifier model, Instances datax, int folds, boolean show_text,
        boolean show_plot, TFigureAttribute attr) {
    // Fixed seed so repeated runs produce the same folds; it is also reported in the text output.
    final int seed = 1;
    Random rand = new Random(seed);
    Instances randData = new Instances(datax);
    randData.randomize(rand);
    if (randData.classAttribute().isNominal()) {
        randData.stratify(folds);
    }
    Evaluation eval = null;
    try {
        // perform cross-validation
        eval = new Evaluation(randData);
        // Predictions and true class values pooled over all folds.
        double[] simulated = new double[0];
        double[] observed = new double[0];
        for (int n = 0; n < folds; n++) {
            Instances train = randData.trainCV(folds, n, rand);
            Instances validation = randData.testCV(folds, n);
            // build and evaluate a fresh copy so the caller's model is left untouched
            Classifier clsCopy = Classifier.makeCopy(model);
            clsCopy.buildClassifier(train);
            simulated = FactoryUtils.concatenate(simulated, eval.evaluateModel(clsCopy, validation));
            observed = FactoryUtils.concatenate(observed,
                    validation.attributeToDoubleArray(validation.classIndex()));
        }

        if (show_plot) {
            double[][] d = new double[2][simulated.length];
            d[0] = observed;
            d[1] = simulated;
            CMatrix f1 = CMatrix.getInstance(d);
            attr.figureCaption = "overall performance";
            f1.transpose().plot(attr);
        }
        if (show_text) {
            // output evaluation
            System.out.println();
            System.out.println("=== Setup for Overall Cross Validation===");
            System.out.println(
                    "Classifier: " + model.getClass().getName() + " " + Utils.joinOptions(model.getOptions()));
            System.out.println("Dataset: " + randData.relationName());
            System.out.println("Folds: " + folds);
            System.out.println("Seed: " + seed);
            System.out.println();
            System.out.println(eval.toSummaryString("=== " + folds + "-fold Cross-validation ===", false));
        }
    } catch (Exception ex) {
        Logger.getLogger(FactoryEvaluation.class.getName()).log(Level.SEVERE, null, ex);
    }
    return eval;
}

From source file:com.reactivetechnologies.analytics.core.eval.StackingWithBuiltClassifiers.java

License:Open Source License

/**
 * Buildclassifier selects a classifier from the set of classifiers
 * by minimising error on the training data.
 *
 * @param data the training data to be used for generating the
 * boosted classifier./*from   w  w w.  j  av a  2s  .  com*/
 * @throws Exception if the classifier could not be built successfully
 */
@Override
public void buildClassifier(Instances data) throws Exception {

    if (m_MetaClassifier == null) {
        throw new IllegalArgumentException("No meta classifier has been set");
    }

    // can classifier handle the data?
    getCapabilities().testWithFail(data);

    // remove instances with missing class
    Instances newData = new Instances(data);
    m_BaseFormat = new Instances(data, 0);
    newData.deleteWithMissingClass();

    Random random = new Random(m_Seed);
    newData.randomize(random);
    if (newData.classAttribute().isNominal()) {
        newData.stratify(m_NumFolds);
    }

    // Create meta level
    generateMetaLevel(newData, random);

    /** Changed here */
    // DO NOT Rebuilt all the base classifiers on the full training data
    /*for (int i = 0; i < m_Classifiers.length; i++) {
      getClassifier(i).buildClassifier(newData);
    }*/
    /** End change */
}

From source file:cotraining.copy.Evaluation_D.java

License:Open Source License

/**
 * Performs a (stratified if class is nominal) cross-validation
 * for a classifier on a set of instances. A deep copy of the classifier
 * is trained for every fold (just in case the classifier is not
 * initialized properly).
 *
 * @param classifier the classifier with any options set; only copies of it are trained
 * @param data the data on which the cross-validation is to be performed;
 * a shuffled copy is used, the argument itself is not reordered
 * @param numFolds the number of folds for the cross-validation
 * @param random random number generator for randomization
 * @param forPredictionsPrinting optional varargs; if supplied, it must hold
 * a StringBuffer to print predictions to, a Range of attributes to output
 * and a Boolean (true if the distribution is to be printed)
 * @throws IllegalArgumentException if forPredictionsPrinting is non-empty
 * but has fewer than three elements
 * @throws Exception if a classifier could not be generated
 * successfully or the class is not defined
 */
public void crossValidateModel(Classifier classifier, Instances data, int numFolds, Random random,
        Object... forPredictionsPrinting) throws Exception {

    // Fail fast on an incomplete printing spec instead of an ArrayIndexOutOfBoundsException below.
    if (forPredictionsPrinting.length > 0 && forPredictionsPrinting.length < 3) {
        throw new IllegalArgumentException("forPredictionsPrinting must contain a StringBuffer, a Range and a "
                + "Boolean, but got " + forPredictionsPrinting.length + " element(s)");
    }

    // Make a copy of the data we can reorder
    data = new Instances(data);
    data.randomize(random);
    if (data.classAttribute().isNominal()) {
        data.stratify(numFolds);
    }

    // The first element is a StringBuffer, the second a Range (attributes to output) and the
    // third a Boolean (whether to output a distribution instead of just a classification).
    if (forPredictionsPrinting.length > 0) {
        // print the header first
        StringBuffer buff = (StringBuffer) forPredictionsPrinting[0];
        Range attsToOutput = (Range) forPredictionsPrinting[1];
        boolean printDist = ((Boolean) forPredictionsPrinting[2]).booleanValue();
        printClassificationsHeader(data, attsToOutput, printDist, buff);
    }

    // Do the folds
    for (int i = 0; i < numFolds; i++) {
        Instances train = data.trainCV(numFolds, i, random);
        setPriors(train);
        Classifier copiedClassifier = Classifier.makeCopy(classifier);
        copiedClassifier.buildClassifier(train);
        Instances test = data.testCV(numFolds, i);
        evaluateModel(copiedClassifier, test, forPredictionsPrinting);
    }
    m_NumFolds = numFolds;
}

From source file:cs.man.ac.uk.classifiers.GetAUC.java

License:Open Source License

/**
 * Computes the AUC for the supplied stream learner.
 * @return the AUC as a double value./*from  w  w  w .  ja v a  2s .c  o m*/
 */
private static double validate5x2CVStream() {
    try {
        // Other options
        int runs = 5;
        int folds = 2;
        double AUC_SUM = 0;

        // perform cross-validation
        for (int i = 0; i < runs; i++) {
            // randomize data
            int seed = i + 1;
            Random rand = new Random(seed);
            Instances randData = new Instances(data);
            randData.randomize(rand);

            if (randData.classAttribute().isNominal()) {
                System.out.println("Stratifying...");
                randData.stratify(folds);
            }

            for (int n = 0; n < folds; n++) {
                Instances train = randData.trainCV(folds, n);
                Instances test = randData.testCV(folds, n);

                Distribution testDistribution = new Distribution(test);

                ArffSaver trainSaver = new ArffSaver();
                trainSaver.setInstances(train);
                trainSaver.setFile(new File(trainPath));
                trainSaver.writeBatch();

                ArffSaver testSaver = new ArffSaver();
                testSaver.setInstances(test);

                double[][] dist = testDistribution.matrix();
                int negativeClassSize = (int) dist[0][0];
                int positiveClassSize = (int) dist[0][1];
                double balance = (double) positiveClassSize / (double) negativeClassSize;

                String tempTestPath = testPath.replace(".arff",
                        "_" + positiveClassSize + "_" + negativeClassSize + "_" + balance + "_1.0.arff");// [Test-n-Set-n]_[+]_[-]_[K]_[L];
                testSaver.setFile(new File(tempTestPath));
                testSaver.writeBatch();

                ARFFFile file = new ARFFFile(tempTestPath, CLASS_INDEX, new DebugLogger(false));
                file.createMetaData();

                HoeffdingTreeTester streamClassifier = new HoeffdingTreeTester(trainPath, tempTestPath,
                        CLASS_INDEX, new String[] { "0", "1" }, new DebugLogger(true));

                streamClassifier.train();

                System.in.read();

                //AUC_SUM += streamClassifier.getROCExternalData("",(int)testDistribution.perClass(1),(int)testDistribution.perClass(0));
                streamClassifier.testStatic(homeDirectory + "/FuckSakeTest.txt");

                String[] files = Common.getFilePaths(scratch);
                for (int j = 0; j < files.length; j++)
                    Common.fileDelete(files[j]);
            }
        }

        return AUC_SUM / ((double) runs * (double) folds);
    } catch (Exception e) {
        System.out.println("Exception validating data!");
        e.printStackTrace();
        return 0;
    }
}

From source file:cs.man.ac.uk.classifiers.GetAUC.java

License:Open Source License

/**
 * Computes the AUC for the supplied learner.
 * @return the AUC as a double value./*from   ww  w.  j a v  a 2 s .c  o m*/
 */
@SuppressWarnings("unused")
private static double validate5x2CV() {
    try {
        // other options
        int runs = 5;
        int folds = 2;
        double AUC_SUM = 0;

        // perform cross-validation
        for (int i = 0; i < runs; i++) {
            // randomize data
            int seed = i + 1;
            Random rand = new Random(seed);
            Instances randData = new Instances(data);
            randData.randomize(rand);

            if (randData.classAttribute().isNominal()) {
                System.out.println("Stratifying...");
                randData.stratify(folds);
            }

            Evaluation eval = new Evaluation(randData);

            for (int n = 0; n < folds; n++) {
                Instances train = randData.trainCV(folds, n);
                Instances test = randData.testCV(folds, n);

                // the above code is used by the StratifiedRemoveFolds filter, the
                // code below by the Explorer/Experimenter:
                // Instances train = randData.trainCV(folds, n, rand);

                // build and evaluate classifier
                String[] options = { "-U", "-A" };
                J48 classifier = new J48();
                //HTree classifier = new HTree();

                classifier.setOptions(options);
                classifier.buildClassifier(train);
                eval.evaluateModel(classifier, test);

                // generate curve
                ThresholdCurve tc = new ThresholdCurve();
                int classIndex = 0;
                Instances result = tc.getCurve(eval.predictions(), classIndex);

                // plot curve
                vmc = new ThresholdVisualizePanel();
                AUC_SUM += ThresholdCurve.getROCArea(result);
                System.out.println("AUC: " + ThresholdCurve.getROCArea(result) + " \tAUC SUM: " + AUC_SUM);
            }
        }

        return AUC_SUM / ((double) runs * (double) folds);
    } catch (Exception e) {
        System.out.println("Exception validating data!");
        return 0;
    }
}

From source file:es.upm.dit.gsi.barmas.dataset.utils.DatasetSplitter.java

License:Open Source License

/**
 * Splits a CSV data set into {@code folds} stratified train/test partitions. For every
 * partition it writes the test set, the central ("bayes-central") training set, and one
 * training set per agent for every agent count from {@code minAgents} to {@code maxAgents}.
 *
 * <p>The data set is read, shuffled and stratified once. As a result, the test sets of the
 * different iterations are disjoint and together cover the whole data set. Files that already
 * exist are left untouched. The rows returned by {@code getEssentials} are appended to the
 * central training files and to every agent CSV/ARFF file created here. Training instances are
 * dealt to the agents round-robin.
 *
 * <p>Files go to {@code outputDir/<ratio>testRatio/iteration-<fold>}, where {@code ratio} is
 * {@code 1/folds} truncated to two decimals. Any error is logged and terminates the JVM with
 * exit status 1.
 *
 * @param folds number of folds (iterations); nothing is written when it is not positive
 * @param minAgents smallest number of agents to create data sets for
 * @param maxAgents largest number of agents to create data sets for
 * @param originalDatasetPath path of the CSV data set to split
 * @param outputDir root output directory
 * @param scenario scenario name, passed on to {@code createExperimentInfoFile}
 * @param logger logger for progress and error messages
 */
public void splitDataset(int folds, int minAgents, int maxAgents, String originalDatasetPath, String outputDir,
        String scenario, Logger logger) {

    logger.finer("--> splitDataset()");

    int ratioint = (int) ((1 / (double) folds) * 100);
    double roundedratio = ((double) ratioint) / 100;

    // Look for essentials
    List<String[]> essentials = this.getEssentials(originalDatasetPath, logger);

    if (folds <= 0) {
        // No iterations to generate.
        logger.finer("<-- splitDataset()");
        return;
    }

    // Read, shuffle and stratify ONCE for all folds. Re-doing this inside the fold loop (with a
    // new unseeded Random each time) drew every fold from a different permutation, so the test
    // sets overlapped and did not partition the data.
    Instances originalData;
    String[] headers;
    try {
        originalData = this.getDataFromCSV(originalDatasetPath);
        originalData.randomize(new Random());
        originalData.stratify(folds);

        CsvReader csvreader = new CsvReader(new FileReader(new File(originalDatasetPath)));
        try {
            csvreader.readHeaders();
            headers = csvreader.getHeaders();
        } finally {
            csvreader.close();
        }
    } catch (Exception e) {
        reportSplitFailureAndExit(logger, e);
        return;
    }

    for (int fold = 0; fold < folds; fold++) {
        String outputDirWithRatio = outputDir + "/" + roundedratio + "testRatio/iteration-" + fold;
        File dir = new File(outputDirWithRatio);
        if (!dir.exists() || !dir.isDirectory()) {
            dir.mkdirs();
        }

        logger.fine("Creating experiment.info...");

        try {
            CSVSaver saver = new CSVSaver();
            ArffSaver arffsaver = new ArffSaver();

            // TestDataSet
            Instances testData = originalData.testCV(folds, fold);
            File file = new File(outputDirWithRatio + File.separator + "test-dataset.csv");
            if (!file.exists()) {
                saver.resetOptions();
                saver.setInstances(testData);
                saver.setFile(file);
                saver.writeBatch();
            }

            file = new File(outputDirWithRatio + File.separator + "test-dataset.arff");
            if (!file.exists()) {
                arffsaver.resetOptions();
                arffsaver.setInstances(testData);
                arffsaver.setFile(file);
                arffsaver.writeBatch();
            }

            // BayesCentralDataset
            Instances trainData = originalData.trainCV(folds, fold);
            file = new File(outputDirWithRatio + File.separator + "bayes-central-dataset.csv");
            if (!file.exists()) {
                saver.resetOptions();
                saver.setInstances(trainData);
                saver.setFile(file);
                saver.writeBatch();
                this.copyFileUsingApacheCommonsIO(file,
                        new File(outputDirWithRatio + File.separator + "bayes-central-dataset-noEssentials.csv"),
                        logger);
                appendCsvRecords(file, essentials);
            }
            file = new File(outputDirWithRatio + File.separator + "bayes-central-dataset.arff");
            if (!file.exists()) {
                arffsaver.resetOptions();
                arffsaver.setInstances(trainData);
                arffsaver.setFile(file);
                arffsaver.writeBatch();
                this.copyFileUsingApacheCommonsIO(file,
                        new File(outputDirWithRatio + File.separator + "bayes-central-dataset-noEssentials.arff"),
                        logger);
                appendCsvRecords(file, essentials);
            }

            // Agent datasets
            for (int agents = minAgents; agents <= maxAgents; agents++) {
                this.createExperimentInfoFile(folds, agents, originalDatasetPath, outputDirWithRatio, scenario,
                        logger);
                String agentsDatasetsDir = outputDirWithRatio + File.separator + agents + "agents";
                File f = new File(agentsDatasetsDir);
                if (!f.isDirectory()) {
                    f.mkdirs();
                }
                writeAgentDatasets(agents, agentsDatasetsDir, trainData, headers, essentials, arffsaver, logger);
            }

        } catch (Exception e) {
            reportSplitFailureAndExit(logger, e);
        }

        logger.finest("Dataset for fold " + fold + " created.");
    }

    logger.finer("<-- splitDataset()");
}

/**
 * Writes the CSV and ARFF training files of {@code agents} agents into
 * {@code agentsDatasetsDir}, skipping files that already exist. The essentials are written
 * first, then the training instances are dealt round-robin. All writers are closed even on failure.
 */
private static void writeAgentDatasets(int agents, String agentsDatasetsDir, Instances trainData,
        String[] headers, List<String[]> essentials, ArffSaver arffsaver, Logger logger) throws Exception {
    // Empty data set with the training header; used to write each agent's ARFF header.
    Instances header = new Instances(trainData, 0);
    HashMap<String, CsvWriter> writers = new HashMap<String, CsvWriter>();
    HashMap<String, CsvWriter> arffWriters = new HashMap<String, CsvWriter>();
    try {
        for (int i = 0; i < agents; i++) {
            String agentKey = "AGENT" + i;
            String fileName = agentsDatasetsDir + File.separator + "agent-" + i + "-dataset.csv";
            if (!new File(fileName).exists()) {
                CsvWriter writer = new CsvWriter(new FileWriter(fileName), ',');
                writers.put(agentKey, writer);
                writer.writeRecord(headers);
            }
            fileName = agentsDatasetsDir + File.separator + "agent-" + i + "-dataset.arff";
            if (!new File(fileName).exists()) {
                arffsaver.resetOptions();
                arffsaver.setInstances(header);
                arffsaver.setFile(new File(fileName));
                arffsaver.writeBatch();
                // The ARFF data rows are appended after the header as comma-separated records.
                arffWriters.put(agentKey, new CsvWriter(new FileWriter(fileName, true), ','));
            }

            logger.fine(agentKey + " dataset created in csv and arff formats.");
        }

        // Append essentials to all
        for (String[] essential : essentials) {
            for (CsvWriter wr : writers.values()) {
                wr.writeRecord(essential);
            }
            for (CsvWriter arffwr : arffWriters.values()) {
                arffwr.writeRecord(essential);
            }
        }

        // Deal the training instances to the agents round-robin.
        int agentCounter = 0;
        for (int j = 0; j < trainData.numInstances(); j++) {
            Instance instance = trainData.instance(j);
            String[] row = new String[instance.numAttributes()];
            for (int a = 0; a < row.length; a++) {
                row[a] = instance.stringValue(a);
            }
            String agentKey = "AGENT" + agentCounter;
            CsvWriter writer = writers.get(agentKey);
            if (writer != null) {
                writer.writeRecord(row);
            }
            CsvWriter arffwriter = arffWriters.get(agentKey);
            if (arffwriter != null) {
                arffwriter.writeRecord(row);
            }
            agentCounter++;
            if (agentCounter == agents) {
                agentCounter = 0;
            }
        }
    } finally {
        for (CsvWriter wr : writers.values()) {
            wr.close();
        }
        for (CsvWriter arffwr : arffWriters.values()) {
            arffwr.close();
        }
    }
}

/** Appends {@code records} as comma-separated rows to the end of {@code file}, always closing the writer. */
private static void appendCsvRecords(File file, List<String[]> records) throws java.io.IOException {
    CsvWriter writer = new CsvWriter(new FileWriter(file, true), ',');
    try {
        for (String[] record : records) {
            writer.writeRecord(record);
        }
    } finally {
        writer.close();
    }
}

/** Logs a failure of {@code splitDataset} with its stack trace and terminates the JVM with status 1. */
private static void reportSplitFailureAndExit(Logger logger, Exception e) {
    logger.severe("Exception while splitting dataset. ->");
    logger.log(java.util.logging.Level.SEVERE, e.getMessage(), e);
    System.exit(1);
}

From source file:GClass.EvaluationInternal.java

License:Open Source License

/**
 * Cross-validates {@code classifier} on {@code data}, stratifying the folds when the class
 * attribute is nominal. For every fold, a fresh copy of the classifier is trained on the
 * training part and evaluated on the held-out part via this object's {@code evaluateModel}.
 *
 * @param classifier the classifier (options already set) to evaluate; only copies are trained
 * @param data the instances to cross-validate on; a shuffled copy is used
 * @param numFolds the number of folds
 * @param random the random number generator used for shuffling
 * @exception Exception if a classifier could not be generated
 * successfully or the class is not defined
 */
public void crossValidateModel(Classifier classifier, Instances data, int numFolds, Random random)
        throws Exception {
    // Work on a private copy so the caller's instance order is preserved.
    Instances shuffled = new Instances(data);
    shuffled.randomize(random);
    boolean nominalClass = shuffled.classAttribute().isNominal();
    if (nominalClass) {
        shuffled.stratify(numFolds);
    }

    for (int fold = 0; fold < numFolds; fold++) {
        Instances trainingSet = shuffled.trainCV(numFolds, fold, random);
        setPriors(trainingSet);
        Classifier foldModel = Classifier.makeCopy(classifier);
        foldModel.buildClassifier(trainingSet);
        evaluateModel(foldModel, shuffled.testCV(numFolds, fold));
    }
    m_NumFolds = numFolds;
}

From source file:it.unisa.gitdm.evaluation.WekaEvaluator.java

/**
 * Runs a 10-fold cross-validation of {@code pClassifier} on {@code pInstances}. The shuffle
 * uses seed 42, and the folds are stratified when the class is nominal. The attribute named
 * {@code isBuggy} is used as the class, and results are reported for its {@code TRUE} value.
 *
 * <p>Accuracy, precision, recall, F-measure and AUC are written, separated by {@code ;}, to
 * {@code baseFolderPath + projectName + "/predictors.csv"}. The same metrics, preceded by the
 * TP/FP/FN/TN counts, are printed to stdout.
 *
 * @param baseFolderPath output root; concatenated directly with {@code projectName}
 * @param projectName project name, used in the output path and the printed line
 * @param pClassifier classifier to evaluate; it is rebuilt on every fold, so afterwards it holds
 *        the model trained on the last fold
 * @param pInstances data to evaluate on; a shuffled copy is used
 * @param pModelName model name, only printed
 * @param pClassifierName classifier name, only printed
 * @throws Exception if training/evaluation fails or the output file cannot be created
 */
private static void evaluateModel(String baseFolderPath, String projectName, Classifier pClassifier,
        Instances pInstances, String pModelName, String pClassifierName) throws Exception {

    // other options
    int folds = 10;

    Random rand = new Random(42);
    Instances randData = new Instances(pInstances);

    // Locate the "isBuggy" class attribute (falls back to index 0) and the index of its "TRUE"
    // value. All folds share this header, so the lookup is done once instead of once per fold.
    int classFeatureIndex = 0;
    for (int i = 0; i < randData.numAttributes(); i++) {
        if (randData.attribute(i).name().equals("isBuggy")) {
            classFeatureIndex = i;
            break;
        }
    }
    Attribute classFeature = randData.attribute(classFeatureIndex);
    int positiveValueIndexOfClassFeature = 0;
    for (int i = 0; i < classFeature.numValues(); i++) {
        if (classFeature.value(i).equals("TRUE")) {
            positiveValueIndexOfClassFeature = i;
        }
    }
    // Stratify and evaluate on the same class attribute the classifier is trained on; the
    // train/test folds inherit this class index.
    randData.setClassIndex(classFeatureIndex);

    // randomize data
    randData.randomize(rand);
    if (randData.classAttribute().isNominal()) {
        randData.stratify(folds);
    }

    // perform cross-validation
    Evaluation eval = new Evaluation(randData);
    for (int n = 0; n < folds; n++) {
        // Same split as the StratifiedRemoveFolds filter; the Explorer/Experimenter would use
        // randData.trainCV(folds, n, rand) instead.
        Instances train = randData.trainCV(folds, n);
        Instances test = randData.testCV(folds, n);

        // build and evaluate classifier
        pClassifier.buildClassifier(train);
        eval.evaluateModel(pClassifier, test);
    }

    int positive = positiveValueIndexOfClassFeature;
    double tp = eval.numTruePositives(positive);
    double fp = eval.numFalsePositives(positive);
    double fn = eval.numFalseNegatives(positive);
    double tn = eval.numTrueNegatives(positive);
    double precision = eval.precision(positive);
    double recall = eval.recall(positive);
    double auc = eval.areaUnderROC(positive);

    double accuracy = (tp + tn) / (tp + fp + fn + tn);
    double fmeasure = 2 * ((precision * recall) / (precision + recall));

    File wekaOutput = new File(baseFolderPath + projectName + "/predictors.csv");
    // Close the writer so the buffered metrics are actually flushed to the file.
    PrintWriter pw1 = new PrintWriter(wekaOutput);
    try {
        pw1.write(accuracy + ";" + precision + ";" + recall + ";" + fmeasure + ";" + auc);
    } finally {
        pw1.close();
    }

    System.out.println(projectName + ";" + pClassifierName + ";" + pModelName + ";"
            + tp + ";"
            + fp + ";"
            + fn + ";"
            + tn + ";" + accuracy + ";"
            + precision + ";"
            + recall + ";" + fmeasure + ";"
            + auc + "\n");
}

From source file:j48.PruneableClassifierTree.java

License:Open Source License

/**
 * Builds a pruneable classifier tree. The data (without instances whose class is missing) is
 * stratified into {@code numSets} folds. The tree is built from all but the last fold, and the
 * last fold is passed to {@code buildTree} as the second (presumably pruning) set. The tree is
 * pruned only when {@code pruneTheTree} is set.
 *
 * @param data the data to build the tree from; it is copied, not modified
 * @throws Exception if tree can't be built successfully
 */
public void buildClassifier(Instances data) throws Exception {
    // can classifier tree handle the data?
    getCapabilities().testWithFail(data);

    // Work on a copy that has no instances with a missing class value.
    Instances cleaned = new Instances(data);
    cleaned.deleteWithMissingClass();

    Random rng = new Random(m_seed);
    cleaned.stratify(numSets);
    int holdOutFold = numSets - 1;
    Instances growingSet = cleaned.trainCV(numSets, holdOutFold, rng);
    Instances holdOutSet = cleaned.testCV(numSets, holdOutFold);
    buildTree(growingSet, holdOutSet, !m_cleanup);

    if (pruneTheTree) {
        prune();
    }
    if (m_cleanup) {
        cleanup(new Instances(cleaned, 0));
    }
}

From source file:jjj.asap.sas.ensemble.impl.CrossValidatedEnsemble.java

License:Open Source License

/**
 * Builds the strong learner on all weak learners, but reports a cross-validated kappa instead
 * of the training kappa.
 *
 * <p>The ids predicted by the first learner are split into {@code nFolds} folds, shuffled with
 * seed 1 and stratified on the gold-standard score. For each fold, an ensemble is built from
 * learner copies with that fold's ids removed. That ensemble then classifies the held-out ids.
 * The kappa of these out-of-fold predictions against the gold standard is set on the ensemble
 * built from all learners.
 *
 * @param essaySet the essay set being scored
 * @param ensembleName name passed through to the underlying ensemble
 * @param learners the weak learners to combine; per-fold work is done on copies
 * @return the ensemble built on all learners, with its kappa set to the cross-validated value
 * @throws IllegalStateException if the gold standard has no score for a predicted id
 */
@Override
public StrongLearner build(int essaySet, String ensembleName, List<WeakLearner> learners) {

    // can't handle empty case
    if (learners.isEmpty()) {
        return this.ensemble.build(essaySet, ensembleName, learners);
    }

    // Dummy dataset: one row per id, with its gold-standard score as the (nominal) class.
    DatasetBuilder builder = new DatasetBuilder();
    builder.addVariable("id");
    builder.addNominalVariable("class", Contest.getRubrics(essaySet));
    Instances dummy = builder.getDataset("dummy");

    Map<Double, Double> groundTruth = Contest.getGoldStandard(essaySet);
    for (double id : learners.get(0).getPreds().keySet()) {
        Double score = groundTruth.get(id);
        if (score == null) {
            throw new IllegalStateException(
                    "No gold standard score for id " + id + " in essay set " + essaySet);
        }
        dummy.add(new DenseInstance(1.0, new double[] { id, score }));
    }

    // Sorting first makes the seeded shuffle (and hence the folds) independent of the
    // key-set iteration order.
    dummy.sort(0);
    dummy.randomize(new Random(1));
    dummy.setClassIndex(1);
    dummy.stratify(nFolds);

    // now evaluate each fold
    Map<Double, Double> preds = new HashMap<Double, Double>();
    for (int k = 0; k < nFolds; k++) {
        Instances test = dummy.testCV(nFolds, k);

        // Training learners: copies with the held-out ids removed.
        List<WeakLearner> cvLearners = new ArrayList<WeakLearner>();
        for (WeakLearner learner : learners) {
            WeakLearner copy = learner.copyOf();
            for (int i = 0; i < test.numInstances(); i++) {
                double id = test.instance(i).value(0);
                copy.getPreds().remove(id);
                copy.getProbs().remove(id);
            }
            cvLearners.add(copy);
        }

        // train on fold
        StrongLearner cv = this.ensemble.build(essaySet, ensembleName, cvLearners);

        // Test learners: the learners used by the fold ensemble, restricted to the held-out ids.
        List<WeakLearner> testLearners = new ArrayList<WeakLearner>();
        for (WeakLearner learner : cv.getLearners()) {
            WeakLearner copy = learner.copyOf();
            copy.getPreds().clear();
            copy.getProbs().clear();
            WeakLearner source = find(copy.getName(), learners);
            for (int i = 0; i < test.numInstances(); i++) {
                double id = test.instance(i).value(0);
                copy.getPreds().put(id, source.getPreds().get(id));
                copy.getProbs().put(id, source.getProbs().get(id));
            }
            testLearners.add(copy);
        }

        preds.putAll(this.ensemble.classify(essaySet, ensembleName, testLearners, cv.getContext()));
    }

    // Build on all learners, but report the out-of-fold kappa instead of the training kappa.
    StrongLearner strong = this.ensemble.build(essaySet, ensembleName, learners);
    double cvKappa = Calc.kappa(essaySet, preds, groundTruth);
    strong.setKappa(cvKappa);
    return strong;
}