Example usage for weka.core Instances trainCV

List of usage examples for weka.core Instances trainCV

Introduction

On this page you can find example usages of weka.core.Instances.trainCV.

Prototype



public Instances trainCV(int numFolds, int numFold, Random random) 

Source Link

Document

Creates the training set for one fold of a cross-validation on the dataset.

Usage

From source file:REPTree.java

License:Open Source License

/**
 * Builds classifier./*from  w  w  w . j  av  a2s  . c o m*/
 * 
 * @param data the data to train with
 * @throws Exception if building fails
 */
public void buildClassifier(Instances data) throws Exception {

    // can classifier handle the data?
    getCapabilities().testWithFail(data);

    // remove instances with missing class
    data = new Instances(data);
    data.deleteWithMissingClass();

    Random random = new Random(m_Seed);

    m_zeroR = null;
    if (data.numAttributes() == 1) {
        m_zeroR = new ZeroR();
        m_zeroR.buildClassifier(data);
        return;
    }

    // Randomize and stratify
    data.randomize(random);
    if (data.classAttribute().isNominal()) {
        data.stratify(m_NumFolds);
    }

    // Split data into training and pruning set
    Instances train = null;
    Instances prune = null;
    if (!m_NoPruning) {
        train = data.trainCV(m_NumFolds, 0, random);
        prune = data.testCV(m_NumFolds, 0);
    } else {
        train = data;
    }

    // Create array of sorted indices and weights
    int[][][] sortedIndices = new int[1][train.numAttributes()][0];
    double[][][] weights = new double[1][train.numAttributes()][0];
    double[] vals = new double[train.numInstances()];
    for (int j = 0; j < train.numAttributes(); j++) {
        if (j != train.classIndex()) {
            weights[0][j] = new double[train.numInstances()];
            if (train.attribute(j).isNominal()) {

                // Handling nominal attributes. Putting indices of
                // instances with missing values at the end.
                sortedIndices[0][j] = new int[train.numInstances()];
                int count = 0;
                for (int i = 0; i < train.numInstances(); i++) {
                    Instance inst = train.instance(i);
                    if (!inst.isMissing(j)) {
                        sortedIndices[0][j][count] = i;
                        weights[0][j][count] = inst.weight();
                        count++;
                    }
                }
                for (int i = 0; i < train.numInstances(); i++) {
                    Instance inst = train.instance(i);
                    if (inst.isMissing(j)) {
                        sortedIndices[0][j][count] = i;
                        weights[0][j][count] = inst.weight();
                        count++;
                    }
                }
            } else {

                // Sorted indices are computed for numeric attributes
                for (int i = 0; i < train.numInstances(); i++) {
                    Instance inst = train.instance(i);
                    vals[i] = inst.value(j);
                }
                sortedIndices[0][j] = Utils.sort(vals);
                for (int i = 0; i < train.numInstances(); i++) {
                    weights[0][j][i] = train.instance(sortedIndices[0][j][i]).weight();
                }
            }
        }
    }

    // Compute initial class counts
    double[] classProbs = new double[train.numClasses()];
    double totalWeight = 0, totalSumSquared = 0;
    for (int i = 0; i < train.numInstances(); i++) {
        Instance inst = train.instance(i);
        if (data.classAttribute().isNominal()) {
            classProbs[(int) inst.classValue()] += inst.weight();
            totalWeight += inst.weight();
        } else {
            classProbs[0] += inst.classValue() * inst.weight();
            totalSumSquared += inst.classValue() * inst.classValue() * inst.weight();
            totalWeight += inst.weight();
        }
    }
    m_Tree = new Tree();
    double trainVariance = 0;
    if (data.classAttribute().isNumeric()) {
        trainVariance = m_Tree.singleVariance(classProbs[0], totalSumSquared, totalWeight) / totalWeight;
        classProbs[0] /= totalWeight;
    }

    // Build tree
    m_Tree.buildTree(sortedIndices, weights, train, totalWeight, classProbs, new Instances(train, 0), m_MinNum,
            m_MinVarianceProp * trainVariance, 0, m_MaxDepth);

    // Insert pruning data and perform reduced error pruning
    if (!m_NoPruning) {
        m_Tree.insertHoldOutSet(prune);
        m_Tree.reducedErrorPrune();
        m_Tree.backfitHoldOutSet();
    }
}

From source file:REPRandomTree.java

License:Open Source License

/**
 * Builds classifier.
 *
 * Works on a copy of the data without missing-class instances. When
 * pruning is enabled, fold 0 of an m_NumFolds split is held out as the
 * pruning set and the remaining folds are used to grow the tree. The same
 * random number generator is also passed on to the tree builder.
 *
 * @param data the data to train with
 * @throws Exception if building fails
 */
public void buildClassifier(Instances data) throws Exception {

    // can classifier handle the data?
    getCapabilities().testWithFail(data);

    // work on a copy from which instances with missing class are removed
    data = new Instances(data);
    data.deleteWithMissingClass();

    Random random = new Random(m_Seed);

    // a dataset with a single attribute is handed to ZeroR
    m_zeroR = null;
    if (data.numAttributes() == 1) {
        m_zeroR = new ZeroR();
        m_zeroR.buildClassifier(data);
        return;
    }

    // Randomize and stratify
    data.randomize(random);
    if (data.classAttribute().isNominal()) {
        data.stratify(m_NumFolds);
    }

    // Split data into training and pruning set (fold 0 is the pruning set)
    Instances train;
    Instances prune = null;
    if (m_NoPruning) {
        train = data;
    } else {
        train = data.trainCV(m_NumFolds, 0, random);
        prune = data.testCV(m_NumFolds, 0);
    }

    final int attCount = train.numAttributes();
    final int instCount = train.numInstances();

    // Create array of sorted indices and weights (class attribute keeps empty arrays)
    int[][][] sortedIndices = new int[1][attCount][0];
    double[][][] weights = new double[1][attCount][0];
    double[] vals = new double[instCount];
    for (int att = 0; att < attCount; att++) {
        if (att == train.classIndex()) {
            continue;
        }
        double[] attWeights = new double[instCount];
        weights[0][att] = attWeights;
        int[] attOrder;
        if (train.attribute(att).isNominal()) {

            // Nominal attribute: keep instance order, but put the indices of
            // instances with missing values at the end
            // (pass 0 = non-missing, pass 1 = missing).
            attOrder = new int[instCount];
            int slot = 0;
            for (int pass = 0; pass < 2; pass++) {
                boolean wantMissing = (pass == 1);
                for (int i = 0; i < instCount; i++) {
                    Instance inst = train.instance(i);
                    if (inst.isMissing(att) == wantMissing) {
                        attOrder[slot] = i;
                        attWeights[slot] = inst.weight();
                        slot++;
                    }
                }
            }
        } else {

            // Numeric attribute: order the instances by attribute value
            for (int i = 0; i < instCount; i++) {
                vals[i] = train.instance(i).value(att);
            }
            attOrder = Utils.sort(vals);
            for (int i = 0; i < instCount; i++) {
                attWeights[i] = train.instance(attOrder[i]).weight();
            }
        }
        sortedIndices[0][att] = attOrder;
    }

    // Compute initial class counts (nominal) or weighted sums (otherwise)
    double[] classProbs = new double[train.numClasses()];
    double totalWeight = 0;
    double totalSumSquared = 0;
    final boolean nominalClass = data.classAttribute().isNominal();
    for (int i = 0; i < instCount; i++) {
        Instance inst = train.instance(i);
        if (nominalClass) {
            classProbs[(int) inst.classValue()] += inst.weight();
        } else {
            classProbs[0] += inst.classValue() * inst.weight();
            totalSumSquared += inst.classValue() * inst.classValue() * inst.weight();
        }
        totalWeight += inst.weight();
    }

    m_Tree = new Tree();
    double trainVariance = 0;
    if (data.classAttribute().isNumeric()) {
        trainVariance = m_Tree.singleVariance(classProbs[0], totalSumSquared, totalWeight) / totalWeight;
        classProbs[0] /= totalWeight;
    }

    // Build tree
    m_Tree.buildTree(sortedIndices, weights, train, totalWeight, classProbs, new Instances(train, 0), m_MinNum,
            m_MinVarianceProp * trainVariance, 0, m_MaxDepth, m_FeatureFrac, random);

    // Insert pruning data and perform reduced error pruning
    if (!m_NoPruning) {
        m_Tree.insertHoldOutSet(prune);
        m_Tree.reducedErrorPrune();
        m_Tree.backfitHoldOutSet();
    }
}

From source file:adams.flow.transformer.WekaAttributeSelection.java

License:Open Source License

/**
 * Executes the flow item./*from   ww  w. j ava2 s  .co  m*/
 *
 * @return      null if everything is fine, otherwise error message
 */
@Override
protected String doExecute() {
    String result;
    Instances data;
    Instances reduced;
    Instances transformed;
    AttributeSelection eval;
    boolean crossValidate;
    int fold;
    Instances train;
    WekaAttributeSelectionContainer cont;
    SpreadSheet stats;
    int i;
    Row row;
    int[] selected;
    double[][] ranked;
    Range range;
    String rangeStr;
    boolean useReduced;

    result = null;

    try {
        if (m_InputToken.getPayload() instanceof Instances)
            data = (Instances) m_InputToken.getPayload();
        else
            data = (Instances) ((WekaTrainTestSetContainer) m_InputToken.getPayload())
                    .getValue(WekaTrainTestSetContainer.VALUE_TRAIN);

        if (result == null) {
            crossValidate = (m_Folds >= 2);

            // setup evaluation
            eval = new AttributeSelection();
            eval.setEvaluator(m_Evaluator);
            eval.setSearch(m_Search);
            eval.setFolds(m_Folds);
            eval.setSeed((int) m_Seed);
            eval.setXval(crossValidate);

            // select attributes
            if (crossValidate) {
                Random random = new Random(m_Seed);
                data = new Instances(data);
                data.randomize(random);
                if ((data.classIndex() > -1) && data.classAttribute().isNominal()) {
                    if (isLoggingEnabled())
                        getLogger().info("Stratifying instances...");
                    data.stratify(m_Folds);
                }
                for (fold = 0; fold < m_Folds; fold++) {
                    if (isLoggingEnabled())
                        getLogger().info("Creating splits for fold " + (fold + 1) + "...");
                    train = data.trainCV(m_Folds, fold, random);
                    if (isLoggingEnabled())
                        getLogger().info("Selecting attributes using all but fold " + (fold + 1) + "...");
                    eval.selectAttributesCVSplit(train);
                }
            } else {
                eval.SelectAttributes(data);
            }

            // generate reduced/transformed dataset
            reduced = null;
            transformed = null;
            if (!crossValidate) {
                reduced = eval.reduceDimensionality(data);
                if (m_Evaluator instanceof AttributeTransformer)
                    transformed = ((AttributeTransformer) m_Evaluator).transformedData(data);
            }

            // generated stats
            stats = null;
            if (!crossValidate) {
                stats = new DefaultSpreadSheet();
                row = stats.getHeaderRow();

                useReduced = false;
                if (m_Search instanceof RankedOutputSearch) {
                    i = reduced.numAttributes();
                    if (reduced.classIndex() > -1)
                        i--;
                    ranked = eval.rankedAttributes();
                    useReduced = (ranked.length == i);
                }

                if (useReduced) {
                    for (i = 0; i < reduced.numAttributes(); i++)
                        row.addCell("" + i).setContent(reduced.attribute(i).name());
                    row = stats.addRow();
                    for (i = 0; i < reduced.numAttributes(); i++)
                        row.addCell(i).setContent(0.0);
                } else {
                    for (i = 0; i < data.numAttributes(); i++)
                        row.addCell("" + i).setContent(data.attribute(i).name());
                    row = stats.addRow();
                    for (i = 0; i < data.numAttributes(); i++)
                        row.addCell(i).setContent(0.0);
                }

                if (m_Search instanceof RankedOutputSearch) {
                    ranked = eval.rankedAttributes();
                    for (i = 0; i < ranked.length; i++)
                        row.getCell((int) ranked[i][0]).setContent(ranked[i][1]);
                } else {
                    selected = eval.selectedAttributes();
                    for (i = 0; i < selected.length; i++)
                        row.getCell(selected[i]).setContent(1.0);
                }
            }

            // selected attributes
            rangeStr = null;
            if (!crossValidate) {
                range = new Range();
                range.setIndices(eval.selectedAttributes());
                rangeStr = range.getRange();
            }

            // setup container
            if (crossValidate)
                cont = new WekaAttributeSelectionContainer(data, reduced, transformed, eval, m_Seed, m_Folds);
            else
                cont = new WekaAttributeSelectionContainer(data, reduced, transformed, eval, stats, rangeStr);
            m_OutputToken = new Token(cont);
        }
    } catch (Exception e) {
        m_OutputToken = null;
        result = handleException("Failed to process data:", e);
    }

    return result;
}

From source file:ann.MyANN.java

/**
 * Evaluates the model by cross-validation: the instances are split into a
 * training set and a test set for each of the numFold folds, a copy of this
 * model is trained on the training set and evaluated on the test set, and
 * the per-fold results are summed element-wise.
 *
 * A fold that throws is logged and skipped; the remaining folds still
 * contribute to the total.
 *
 * @param instances the data to be tested (not modified, a copy is shuffled)
 * @param numFold the number of folds
 * @param rand the random number generator used for shuffling and splitting
 * @return confusion matrix summed over all successful folds, or null if no
 *         fold could be evaluated
 */
public int[][] crossValidation(Instances instances, int numFold, Random rand) {
    int[][] totalResult = null;
    instances = new Instances(instances);
    instances.randomize(rand);
    if (instances.classAttribute().isNominal()) {
        instances.stratify(numFold);
    }
    for (int i = 0; i < numFold; i++) {
        try {
            // split the instances according to the number of folds
            Instances train = instances.trainCV(numFold, i, rand);
            Instances test = instances.testCV(numFold, i);
            MyANN cc = new MyANN(this);
            cc.buildClassifier(train);

            // evaluate exactly once per fold
            int[][] result = cc.evaluate(test);
            if (totalResult == null) {
                // first successful fold (not necessarily fold 0) seeds the total
                totalResult = result;
            } else {
                for (int j = 0; j < totalResult.length; j++) {
                    for (int k = 0; k < totalResult[j].length; k++) {
                        totalResult[j][k] += result[j][k];
                    }
                }
            }
        } catch (Exception ex) {
            Logger.getLogger(MyANN.class.getName()).log(Level.SEVERE, null, ex);
        }
    }

    return totalResult;
}

From source file:br.ufrn.ia.core.clustering.EMIaProject.java

License:Open Source License

/**
 * Chooses the number of clusters by cross-validation and stores it in
 * m_num_clusters.
 *
 * Starting with one cluster, the candidate count is increased as long as
 * the mean over all folds of the value returned by E(cvTest, false)
 * (reported as "Loglikely" in the verbose output) improves on the best
 * value seen so far. The number of folds is 10, or the number of instances
 * if there are fewer than 10.
 *
 * NOTE(review): after the first fold that throws, ok is set to false and
 * is never set back to true, and CVincreased has already been cleared for
 * that pass, so the outer loop ends right away. restartCount can therefore
 * never exceed 1 and the "restartCount > 5" checks look unreachable.
 * The commented-out "Restarting after ..." messages suggest a retry with
 * the incremented seed was intended - confirm before relying on it.
 *
 * @throws Exception if EM_Init or the splitting of the data fails
 */
private void CVClusters() throws Exception {
    // best mean cross-validated value found so far
    double CVLogLikely = -Double.MAX_VALUE;
    double templl, tll;
    boolean CVincreased = true;
    m_num_clusters = 1;
    // candidate number of clusters currently being evaluated
    int num_clusters = m_num_clusters;
    int i;
    Random cvr;
    Instances trainCopy;
    int numFolds = (m_theInstances.numInstances() < 10) ? m_theInstances.numInstances() : 10;

    boolean ok = true;
    int seed = getSeed();
    int restartCount = 0;
    CLUSTER_SEARCH: while (CVincreased) {
        // theInstances.stratify(10);

        // only set again below if this candidate improves on the best value
        CVincreased = false;
        // the data is shuffled with the configured seed on every pass
        cvr = new Random(getSeed());
        trainCopy = new Instances(m_theInstances);
        trainCopy.randomize(cvr);
        // sum of the per-fold values for the current candidate
        templl = 0.0;
        for (i = 0; i < numFolds; i++) {
            Instances cvTrain = trainCopy.trainCV(numFolds, i, cvr);
            // stop the search once there are more clusters than training instances
            if (num_clusters > cvTrain.numInstances()) {
                break CLUSTER_SEARCH;
            }
            Instances cvTest = trainCopy.testCV(numFolds, i);
            // fresh generator per fold; the first 10 values are discarded
            m_rr = new Random(seed);
            for (int z = 0; z < 10; z++)
                m_rr.nextDouble();
            m_num_clusters = num_clusters;
            EM_Init(cvTrain);
            try {
                iterate(cvTrain, false);
            } catch (Exception ex) {
                // catch any problems - i.e. empty clusters occurring
                ex.printStackTrace();
                // System.err.println("Restarting after CV training failure
                // ("+num_clusters+" clusters");
                seed++;
                restartCount++;
                ok = false;
                if (restartCount > 5) {
                    break CLUSTER_SEARCH;
                }
                // abandon the remaining folds of this pass
                break;
            }
            try {
                tll = E(cvTest, false);
            } catch (Exception ex) {
                // catch any problems - i.e. empty clusters occurring
                // ex.printStackTrace();
                ex.printStackTrace();
                // System.err.println("Restarting after CV testing failure
                // ("+num_clusters+" clusters");
                // throw new Exception(ex);
                seed++;
                restartCount++;
                ok = false;
                if (restartCount > 5) {
                    break CLUSTER_SEARCH;
                }
                // abandon the remaining folds of this pass
                break;
            }

            if (m_verbose) {
                System.out.println("# clust: " + num_clusters + " Fold: " + i + " Loglikely: " + tll);
            }
            templl += tll;
        }

        if (ok) {
            restartCount = 0;
            seed = getSeed();
            // mean over the folds
            templl /= (double) numFolds;

            if (m_verbose) {
                System.out.println("===================================" + "==============\n# clust: "
                        + num_clusters + " Mean Loglikely: " + templl + "\n================================"
                        + "=================");
            }

            // improvement: remember it and try one more cluster
            if (templl > CVLogLikely) {
                CVLogLikely = templl;
                CVincreased = true;
                num_clusters++;
            }
        }
    }

    if (m_verbose) {
        System.out.println("Number of clusters: " + (num_clusters - 1));
    }

    // num_clusters is the candidate that was not accepted, hence the "- 1"
    m_num_clusters = num_clusters - 1;
}

From source file:cezeri.evaluater.FactoryEvaluation.java

/**
 * Cross-validates a classifier on a shuffled (and, for a nominal class,
 * stratified) copy of the data, using a fixed random seed of 1.
 *
 * For every fold a copy of the model is trained on the training split and
 * evaluated on the validation split; predictions and actual class values
 * of all folds are appended to {@code simulated} and {@code observed}.
 * Any exception is logged and the evaluation built so far is returned.
 *
 * NOTE(review): {@code simulated} and {@code observed} are not declared in
 * this method - presumably fields of the enclosing class; confirm where
 * they are initialized and whether they need resetting between calls.
 *
 * @param model the classifier to copy and train for each fold
 * @param datax the dataset (not modified, a copy is shuffled)
 * @param folds the number of folds
 * @param show_text whether to print the setup and the summary to stdout
 * @param show_plot whether to plot observed versus simulated values
 * @param attr the figure attributes used for plotting; its caption is overwritten
 * @return the evaluation, or null if it could not even be created
 */
public static Evaluation performCrossValidate(Classifier model, Instances datax, int folds, boolean show_text,
        boolean show_plot, TFigureAttribute attr) {
    Random rng = new Random(1);
    Instances shuffled = new Instances(datax);
    shuffled.randomize(rng);
    if (shuffled.classAttribute().isNominal()) {
        shuffled.stratify(folds);
    }

    Evaluation eval = null;
    try {
        // perform cross-validation
        eval = new Evaluation(shuffled);
        for (int fold = 0; fold < folds; fold++) {
            Instances trainFold = shuffled.trainCV(folds, fold, rng);
            Instances validationFold = shuffled.testCV(folds, fold);

            // build and evaluate a fresh copy of the classifier
            Classifier foldModel = Classifier.makeCopy(model);
            foldModel.buildClassifier(trainFold);

            // collect predictions and actual class values across the folds
            simulated = FactoryUtils.concatenate(simulated, eval.evaluateModel(foldModel, validationFold));
            observed = FactoryUtils.concatenate(observed,
                    validationFold.attributeToDoubleArray(validationFold.classIndex()));
        }

        if (show_plot) {
            // row 0: observed, row 1: simulated; transposed for plotting
            double[][] d = new double[2][simulated.length];
            d[0] = observed;
            d[1] = simulated;
            CMatrix f1 = CMatrix.getInstance(d);
            attr.figureCaption = "overall performance";
            f1.transpose().plot(attr);
        }

        if (show_text) {
            // output evaluation
            String classifierLine = "Classifier: " + model.getClass().getName() + " "
                    + Utils.joinOptions(model.getOptions());
            System.out.println();
            System.out.println("=== Setup for Overall Cross Validation===");
            System.out.println(classifierLine);
            System.out.println("Dataset: " + shuffled.relationName());
            System.out.println("Folds: " + folds);
            System.out.println("Seed: " + 1);
            System.out.println();
            System.out.println(eval.toSummaryString("=== " + folds + "-fold Cross-validation ===", false));
        }
    } catch (Exception ex) {
        Logger.getLogger(FactoryEvaluation.class.getName()).log(Level.SEVERE, null, ex);
    }
    return eval;
}

From source file:core.ClusterEvaluationEX.java

License:Open Source License

/**
 * Perform a cross-validation for DensityBasedClusterer on a set of instances.
 *
 * @param clusterer the clusterer to use
 * @param data the training data/* ww  w.  jav a 2 s .c o m*/
 * @param numFolds number of folds of cross validation to perform
 * @param random random number seed for cross-validation
 * @return the cross-validated log-likelihood
 * @throws Exception if an error occurs
 */
public static double crossValidateModel(DensityBasedClusterer clusterer, Instances data, int numFolds,
        Random random) throws Exception {
    Instances train, test;
    double foldAv = 0;
    ;
    data = new Instances(data);
    data.randomize(random);
    //    double sumOW = 0;
    for (int i = 0; i < numFolds; i++) {
        // Build and test clusterer
        train = data.trainCV(numFolds, i, random);

        clusterer.buildClusterer(train);

        test = data.testCV(numFolds, i);

        for (int j = 0; j < test.numInstances(); j++) {
            try {
                foldAv += ((DensityBasedClusterer) clusterer).logDensityForInstance(test.instance(j));
                //     sumOW += test.instance(j).weight();
                //   double temp = Utils.sum(tempDist);
            } catch (Exception ex) {
                // unclustered instances
            }
        }
    }

    //    return foldAv / sumOW;
    return foldAv / data.numInstances();
}

From source file:cotraining.copy.Evaluation_D.java

License:Open Source License

/**
 * Performs a (stratified if class is nominal) cross-validation 
 * for a classifier on a set of instances. Now performs
 * a deep copy of the classifier before each call to 
 * buildClassifier() (just in case the classifier is not
 * initialized properly)./*from   w w w. j a va  2 s.c om*/
 *
 * @param classifier the classifier with any options set.
 * @param data the data on which the cross-validation is to be 
 * performed 
 * @param numFolds the number of folds for the cross-validation
 * @param random random number generator for randomization 
 * @param forPredictionsString varargs parameter that, if supplied, is
 * expected to hold a StringBuffer to print predictions to, 
 * a Range of attributes to output and a Boolean (true if the distribution
 * is to be printed)
 * @throws Exception if a classifier could not be generated 
 * successfully or the class is not defined
 */
public void crossValidateModel(Classifier classifier, Instances data, int numFolds, Random random,
        Object... forPredictionsPrinting) throws Exception {

    // Make a copy of the data we can reorder
    data = new Instances(data);
    data.randomize(random);
    if (data.classAttribute().isNominal()) {
        data.stratify(numFolds);
    }

    // We assume that the first element is a StringBuffer, the second a Range (attributes
    // to output) and the third a Boolean (whether or not to output a distribution instead
    // of just a classification)
    if (forPredictionsPrinting.length > 0) {
        // print the header first
        StringBuffer buff = (StringBuffer) forPredictionsPrinting[0];
        Range attsToOutput = (Range) forPredictionsPrinting[1];
        boolean printDist = ((Boolean) forPredictionsPrinting[2]).booleanValue();
        printClassificationsHeader(data, attsToOutput, printDist, buff);
    }

    // Do the folds
    for (int i = 0; i < numFolds; i++) {
        Instances train = data.trainCV(numFolds, i, random);
        setPriors(train);
        Classifier copiedClassifier = Classifier.makeCopy(classifier);
        copiedClassifier.buildClassifier(train);
        Instances test = data.testCV(numFolds, i);
        evaluateModel(copiedClassifier, test, forPredictionsPrinting);
    }
    m_NumFolds = numFolds;
}

From source file:de.tudarmstadt.ukp.similarity.experiments.coling2012.util.Evaluator.java

License:Open Source License

/**
 * Runs a 10-fold cross-validation of the given classifier on a dataset and
 * writes the predicted class of every instance, one per line in original
 * ID order, to an output.csv file below OUTPUT_DIR.
 *
 * An ID attribute is added as first attribute so predictions can be put
 * back into file order after shuffling; the classifier is wrapped in a
 * FilteredClassifier that removes this first attribute again.
 *
 * @param wekaClassifier the classifier to evaluate
 * @param dataset the dataset whose ARFF file is read from MODELS_DIR
 * @throws Exception if reading, training, predicting or writing fails
 */
public static void runClassifierCV(WekaClassifier wekaClassifier, Dataset dataset) throws Exception {
    // Set parameters
    int folds = 10;
    Classifier baseClassifier = getClassifier(wekaClassifier);

    // Set up the random number generator (seeded with the current time)
    long seed = new Date().getTime();
    Random random = new Random(seed);

    // Add IDs to the instances
    String arffPath = MODELS_DIR + "/" + dataset.toString() + ".arff";
    String arffWithIdsPath = MODELS_DIR + "/" + dataset.toString() + "-plusIDs.arff";
    AddID.main(new String[] { "-i", arffPath, "-o", arffWithIdsPath });
    Instances data = DataSource.read(arffWithIdsPath);
    data.setClassIndex(data.numAttributes() - 1);

    // Instantiate the Remove filter (drops the ID attribute)
    Remove removeIDFilter = new Remove();
    removeIDFilter.setAttributeIndices("first");

    // Randomize the data
    data.randomize(random);

    // Perform cross-validation
    Instances predictedData = null;
    Evaluation eval = new Evaluation(data);

    for (int fold = 0; fold < folds; fold++) {
        Instances train = data.trainCV(folds, fold, random);
        Instances test = data.testCV(folds, fold);

        // Copy the classifier and wrap it so it never sees the ID attribute
        Classifier classifier = AbstractClassifier.makeCopy(baseClassifier);
        FilteredClassifier filteredClassifier = new FilteredClassifier();
        filteredClassifier.setFilter(removeIDFilter);
        filteredClassifier.setClassifier(classifier);

        // Build the classifier
        filteredClassifier.buildClassifier(train);

        // Evaluate
        eval.evaluateModel(filteredClassifier, test);

        // Add predictions
        AddClassification filter = new AddClassification();
        filter.setClassifier(filteredClassifier);
        filter.setOutputClassification(true);
        filter.setOutputDistribution(false);
        filter.setOutputErrorFlag(true);
        filter.setInputFormat(train);
        Filter.useFilter(train, filter); // trains the classifier

        Instances pred = Filter.useFilter(test, filter); // performs predictions on test set
        if (predictedData == null) {
            predictedData = new Instances(pred, 0);
        }
        for (int j = 0; j < pred.numInstances(); j++) {
            predictedData.add(pred.instance(j));
        }
    }

    // Prepare output classification: slot (ID - 1) receives the predicted label
    String[] scores = new String[predictedData.numInstances()];
    // the classification attribute is the second to last one
    int valueIdx = predictedData.numAttributes() - 2;
    for (Instance predInst : predictedData) {
        int id = (int) predInst.value(predInst.attribute(0)) - 1;
        scores[id] = predInst.stringValue(predInst.attribute(valueIdx));
    }

    // Output
    StringBuilder sb = new StringBuilder();
    for (String score : scores) {
        sb.append(score.toString()).append(LF);
    }

    File outputFile = new File(
            OUTPUT_DIR + "/" + dataset.toString() + "/" + wekaClassifier.toString() + "/output.csv");
    FileUtils.writeStringToFile(outputFile, sb.toString());
}

From source file:dkpro.similarity.experiments.rte.util.Evaluator.java

License:Open Source License

/**
 * Runs a 10-fold cross-validation of the given classifier on a dataset,
 * prints the evaluation summary and confusion matrix, and writes three
 * files below OUTPUT_DIR: the predicted class per instance in original ID
 * order (.csv), the instances with predictions (.predicted.arff) and the
 * classifier description plus evaluation results (.meta.txt).
 *
 * An ID attribute is added as first attribute so predictions can be put
 * back into file order after shuffling; the classifier is wrapped in a
 * FilteredClassifier that removes this first attribute again.
 *
 * @param wekaClassifier the classifier to evaluate
 * @param dataset the dataset whose ARFF file is read from MODELS_DIR
 * @throws Exception if reading, training, predicting or writing fails
 */
public static void runClassifierCV(WekaClassifier wekaClassifier, Dataset dataset) throws Exception {
    // Set parameters
    int folds = 10;
    Classifier baseClassifier = ClassifierSimilarityMeasure.getClassifier(wekaClassifier);

    // Set up the random number generator (seeded with the current time)
    long seed = new Date().getTime();
    Random random = new Random(seed);

    // Add IDs to the instances
    AddID.main(new String[] { "-i", MODELS_DIR + "/" + dataset.toString() + ".arff", "-o",
            MODELS_DIR + "/" + dataset.toString() + "-plusIDs.arff" });
    Instances data = DataSource.read(MODELS_DIR + "/" + dataset.toString() + "-plusIDs.arff");
    data.setClassIndex(data.numAttributes() - 1);

    // Instantiate the Remove filter (drops the ID attribute)
    Remove removeIDFilter = new Remove();
    removeIDFilter.setAttributeIndices("first");

    // Randomize the data
    data.randomize(random);

    // Perform cross-validation
    Instances predictedData = null;
    Evaluation eval = new Evaluation(data);

    for (int n = 0; n < folds; n++) {
        Instances train = data.trainCV(folds, n, random);
        Instances test = data.testCV(folds, n);

        // Copy the classifier
        Classifier classifier = AbstractClassifier.makeCopy(baseClassifier);

        // Instantiate the FilteredClassifier so the model never sees the ID attribute
        FilteredClassifier filteredClassifier = new FilteredClassifier();
        filteredClassifier.setFilter(removeIDFilter);
        filteredClassifier.setClassifier(classifier);

        // Build the classifier
        filteredClassifier.buildClassifier(train);

        // Evaluate
        eval.evaluateModel(filteredClassifier, test);

        // Add predictions. The ID-stripping wrapper must be used here too:
        // train/test still contain the ID attribute, so handing over the bare
        // classifier would train it with the ID as a feature and the written
        // predictions would not come from the same kind of model as the one
        // evaluated above.
        AddClassification filter = new AddClassification();
        filter.setClassifier(filteredClassifier);
        filter.setOutputClassification(true);
        filter.setOutputDistribution(false);
        filter.setOutputErrorFlag(true);
        filter.setInputFormat(train);
        Filter.useFilter(train, filter); // trains the classifier

        Instances pred = Filter.useFilter(test, filter); // performs predictions on test set
        if (predictedData == null) {
            predictedData = new Instances(pred, 0);
        }
        for (int j = 0; j < pred.numInstances(); j++) {
            predictedData.add(pred.instance(j));
        }
    }

    System.out.println(eval.toSummaryString());
    System.out.println(eval.toMatrixString());

    // Prepare output scores: slot (ID - 1) receives the predicted label
    String[] scores = new String[predictedData.numInstances()];
    // the classification attribute is the second to last one
    int valueIdx = predictedData.numAttributes() - 2;

    for (Instance predInst : predictedData) {
        int id = (int) predInst.value(predInst.attribute(0)) - 1;
        scores[id] = predInst.stringValue(predInst.attribute(valueIdx));
    }

    // Output classifications
    StringBuilder sb = new StringBuilder();
    for (String score : scores) {
        sb.append(score.toString()).append(LF);
    }

    FileUtils.writeStringToFile(new File(OUTPUT_DIR + "/" + dataset.toString() + "/" + wekaClassifier.toString()
            + "/" + dataset.toString() + ".csv"), sb.toString());

    // Output prediction arff
    DataSink.write(OUTPUT_DIR + "/" + dataset.toString() + "/" + wekaClassifier.toString() + "/"
            + dataset.toString() + ".predicted.arff", predictedData);

    // Output meta information
    sb = new StringBuilder();
    sb.append(baseClassifier.toString()).append(LF);
    sb.append(eval.toSummaryString()).append(LF);
    sb.append(eval.toMatrixString()).append(LF);

    FileUtils.writeStringToFile(new File(OUTPUT_DIR + "/" + dataset.toString() + "/" + wekaClassifier.toString()
            + "/" + dataset.toString() + ".meta.txt"), sb.toString());
}