Example usage for weka.classifiers.meta FilteredClassifier setFilter

Introduction

This page collects example usages of weka.classifiers.meta.FilteredClassifier#setFilter from open-source projects.

Prototype

public void setFilter(Filter filter) 

Document

Sets the filter. The FilteredClassifier applies this filter to the training data before building the base classifier, and to each instance before classification, so both passes see identically transformed data.
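
Before the project examples below, here is a minimal, self-contained sketch of the pattern they share: a Remove filter and a base classifier are wrapped in a FilteredClassifier, so the filter is applied when the model is built and again at prediction time. The dataset path and the J48 base classifier are illustrative placeholders, not taken from any of the projects quoted below.

import java.util.Random;

import weka.classifiers.Evaluation;
import weka.classifiers.meta.FilteredClassifier;
import weka.classifiers.trees.J48;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;
import weka.filters.unsupervised.attribute.Remove;

public class SetFilterSketch {
    public static void main(String[] args) throws Exception {
        // Load a dataset; the path is a placeholder.
        Instances data = DataSource.read("data.arff");
        data.setClassIndex(data.numAttributes() - 1);

        // Filter that drops the first attribute (e.g., an instance ID).
        Remove remove = new Remove();
        remove.setAttributeIndices("first");

        // Wrap filter and base classifier; the filter's input format is
        // determined from the training data inside buildClassifier().
        FilteredClassifier fc = new FilteredClassifier();
        fc.setFilter(remove);
        fc.setClassifier(new J48());

        // Cross-validate the filtered classifier as a single unit.
        Evaluation eval = new Evaluation(data);
        eval.crossValidateModel(fc, data, 10, new Random(1));
        System.out.println(eval.toSummaryString());
    }
}

The examples that follow elaborate this pattern, most often pairing an AddID preprocessing step with a Remove filter set to "first", so instance IDs are ignored during training but remain available for mapping predictions back to instances.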

Usage

From source file:dkpro.similarity.experiments.rte.util.Evaluator.java

License:Open Source License

public static void runClassifierCV(WekaClassifier wekaClassifier, Dataset dataset) throws Exception {
    // Set parameters
    int folds = 10;
    Classifier baseClassifier = ClassifierSimilarityMeasure.getClassifier(wekaClassifier);

    // Set up the random number generator
    long seed = new Date().getTime();
    Random random = new Random(seed);

    // Add IDs to the instances
    AddID.main(new String[] { "-i", MODELS_DIR + "/" + dataset.toString() + ".arff", "-o",
            MODELS_DIR + "/" + dataset.toString() + "-plusIDs.arff" });
    Instances data = DataSource.read(MODELS_DIR + "/" + dataset.toString() + "-plusIDs.arff");
    data.setClassIndex(data.numAttributes() - 1);

    // Instantiate the Remove filter
    Remove removeIDFilter = new Remove();
    removeIDFilter.setAttributeIndices("first");

    // Randomize the data
    data.randomize(random);

    // Perform cross-validation
    Instances predictedData = null;
    Evaluation eval = new Evaluation(data);

    for (int n = 0; n < folds; n++) {
        Instances train = data.trainCV(folds, n, random);
        Instances test = data.testCV(folds, n);

        // Apply log filter
        // Filter logFilter = new LogFilter();
        // logFilter.setInputFormat(train);
        // train = Filter.useFilter(train, logFilter);
        // logFilter.setInputFormat(test);
        // test = Filter.useFilter(test, logFilter);

        // Copy the classifier
        Classifier classifier = AbstractClassifier.makeCopy(baseClassifier);

        // Instantiate the FilteredClassifier
        FilteredClassifier filteredClassifier = new FilteredClassifier();
        filteredClassifier.setFilter(removeIDFilter);
        filteredClassifier.setClassifier(classifier);

        // Build the classifier
        filteredClassifier.buildClassifier(train);

        // Evaluate
        eval.evaluateModel(filteredClassifier, test);

        // Add predictions
        AddClassification filter = new AddClassification();
        filter.setClassifier(classifier);
        filter.setOutputClassification(true);
        filter.setOutputDistribution(false);
        filter.setOutputErrorFlag(true);
        filter.setInputFormat(train);
        Filter.useFilter(train, filter); // trains the classifier

        Instances pred = Filter.useFilter(test, filter); // performs predictions on test set
        if (predictedData == null)
            predictedData = new Instances(pred, 0);
        for (int j = 0; j < pred.numInstances(); j++)
            predictedData.add(pred.instance(j));
    }

    System.out.println(eval.toSummaryString());
    System.out.println(eval.toMatrixString());

    // Prepare output scores
    String[] scores = new String[predictedData.numInstances()];

    for (Instance predInst : predictedData) {
        int id = (int) predInst.value(predInst.attribute(0)) - 1; // AddID assigns IDs starting at 1

        int valueIdx = predictedData.numAttributes() - 2;

        String value = predInst.stringValue(predInst.attribute(valueIdx));

        scores[id] = value;
    }

    // Output classifications
    StringBuilder sb = new StringBuilder();
    for (String score : scores)
        sb.append(score).append(LF);

    FileUtils.writeStringToFile(new File(OUTPUT_DIR + "/" + dataset.toString() + "/" + wekaClassifier.toString()
            + "/" + dataset.toString() + ".csv"), sb.toString());

    // Output prediction arff
    DataSink.write(OUTPUT_DIR + "/" + dataset.toString() + "/" + wekaClassifier.toString() + "/"
            + dataset.toString() + ".predicted.arff", predictedData);

    // Output meta information
    sb = new StringBuilder();
    sb.append(baseClassifier.toString() + LF);
    sb.append(eval.toSummaryString() + LF);
    sb.append(eval.toMatrixString() + LF);

    FileUtils.writeStringToFile(new File(OUTPUT_DIR + "/" + dataset.toString() + "/" + wekaClassifier.toString()
            + "/" + dataset.toString() + ".meta.txt"), sb.toString());
}

From source file:dkpro.similarity.experiments.sts2013.util.Evaluator.java

License:Open Source License

public static void runLinearRegressionCV(Mode mode, Dataset... datasets) throws Exception {
    for (Dataset dataset : datasets) {
        // Set parameters
        int folds = 10;
        Classifier baseClassifier = new LinearRegression();

        // Set up the random number generator
        long seed = new Date().getTime();
        Random random = new Random(seed);

        // Add IDs to the instances
        AddID.main(new String[] { "-i",
                MODELS_DIR + "/" + mode.toString().toLowerCase() + "/" + dataset.toString() + ".arff", "-o",
                MODELS_DIR + "/" + mode.toString().toLowerCase() + "/" + dataset.toString()
                        + "-plusIDs.arff" });
        Instances data = DataSource.read(
                MODELS_DIR + "/" + mode.toString().toLowerCase() + "/" + dataset.toString() + "-plusIDs.arff");
        data.setClassIndex(data.numAttributes() - 1);

        // Instantiate the Remove filter
        Remove removeIDFilter = new Remove();
        removeIDFilter.setAttributeIndices("first");

        // Randomize the data
        data.randomize(random);

        // Perform cross-validation
        Instances predictedData = null;
        Evaluation eval = new Evaluation(data);

        for (int n = 0; n < folds; n++) {
            Instances train = data.trainCV(folds, n, random);
            Instances test = data.testCV(folds, n);

            // Apply log filter
            Filter logFilter = new LogFilter();
            logFilter.setInputFormat(train);
            train = Filter.useFilter(train, logFilter);
            logFilter.setInputFormat(test);
            test = Filter.useFilter(test, logFilter);

            // Copy the classifier
            Classifier classifier = AbstractClassifier.makeCopy(baseClassifier);

            // Instantiate the FilteredClassifier
            FilteredClassifier filteredClassifier = new FilteredClassifier();
            filteredClassifier.setFilter(removeIDFilter);
            filteredClassifier.setClassifier(classifier);

            // Build the classifier
            filteredClassifier.buildClassifier(train);

            // Evaluate the trained filtered classifier, so the test data passes through the same Remove filter
            eval.evaluateModel(filteredClassifier, test);

            // Add predictions
            AddClassification filter = new AddClassification();
            filter.setClassifier(classifier);
            filter.setOutputClassification(true);
            filter.setOutputDistribution(false);
            filter.setOutputErrorFlag(true);
            filter.setInputFormat(train);
            Filter.useFilter(train, filter); // trains the classifier

            Instances pred = Filter.useFilter(test, filter); // performs predictions on test set
            if (predictedData == null) {
                predictedData = new Instances(pred, 0);
            }
            for (int j = 0; j < pred.numInstances(); j++) {
                predictedData.add(pred.instance(j));
            }
        }

        // Prepare output scores
        double[] scores = new double[predictedData.numInstances()];

        for (Instance predInst : predictedData) {
            int id = (int) predInst.value(predInst.attribute(0)) - 1; // AddID assigns IDs starting at 1

            int valueIdx = predictedData.numAttributes() - 2;

            double value = predInst.value(predInst.attribute(valueIdx));

            scores[id] = value;

            // Limit to interval [0;5]
            if (scores[id] > 5.0) {
                scores[id] = 5.0;
            }
            if (scores[id] < 0.0) {
                scores[id] = 0.0;
            }
        }

        // Output
        StringBuilder sb = new StringBuilder();
        for (Double score : scores) {
            sb.append(score.toString() + LF);
        }

        FileUtils.writeStringToFile(
                new File(OUTPUT_DIR + "/" + mode.toString().toLowerCase() + "/" + dataset.toString() + ".csv"),
                sb.toString());
    }
}

From source file:dkpro.similarity.experiments.sts2013baseline.util.Evaluator.java

License:Open Source License

public static void runLinearRegressionCV(Mode mode, Dataset... datasets) throws Exception {
    for (Dataset dataset : datasets) {
        // Set parameters
        int folds = 10;
        Classifier baseClassifier = new LinearRegression();

        // Set up the random number generator
        long seed = new Date().getTime();
        Random random = new Random(seed);

        // Add IDs to the instances
        AddID.main(new String[] { "-i",
                MODELS_DIR + "/" + mode.toString().toLowerCase() + "/" + dataset.toString() + ".arff", "-o",
                MODELS_DIR + "/" + mode.toString().toLowerCase() + "/" + dataset.toString()
                        + "-plusIDs.arff" });

        String location = MODELS_DIR + "/" + mode.toString().toLowerCase() + "/" + dataset.toString()
                + "-plusIDs.arff";

        Instances data = DataSource.read(location);

        if (data == null) {
            throw new IOException("Could not load data from: " + location);
        }

        data.setClassIndex(data.numAttributes() - 1);

        // Instantiate the Remove filter
        Remove removeIDFilter = new Remove();
        removeIDFilter.setAttributeIndices("first");

        // Randomize the data
        data.randomize(random);

        // Perform cross-validation
        Instances predictedData = null;
        Evaluation eval = new Evaluation(data);

        for (int n = 0; n < folds; n++) {
            Instances train = data.trainCV(folds, n, random);
            Instances test = data.testCV(folds, n);

            // Apply log filter
            Filter logFilter = new LogFilter();
            logFilter.setInputFormat(train);
            train = Filter.useFilter(train, logFilter);
            logFilter.setInputFormat(test);
            test = Filter.useFilter(test, logFilter);

            // Copy the classifier
            Classifier classifier = AbstractClassifier.makeCopy(baseClassifier);

            // Instantiate the FilteredClassifier
            FilteredClassifier filteredClassifier = new FilteredClassifier();
            filteredClassifier.setFilter(removeIDFilter);
            filteredClassifier.setClassifier(classifier);

            // Build the classifier
            filteredClassifier.buildClassifier(train);

            // Evaluate the trained filtered classifier, so the test data passes through the same Remove filter
            eval.evaluateModel(filteredClassifier, test);

            // Add predictions
            AddClassification filter = new AddClassification();
            filter.setClassifier(classifier);
            filter.setOutputClassification(true);
            filter.setOutputDistribution(false);
            filter.setOutputErrorFlag(true);
            filter.setInputFormat(train);
            Filter.useFilter(train, filter); // trains the classifier

            Instances pred = Filter.useFilter(test, filter); // performs predictions on test set
            if (predictedData == null) {
                predictedData = new Instances(pred, 0);
            }
            for (int j = 0; j < pred.numInstances(); j++) {
                predictedData.add(pred.instance(j));
            }
        }

        // Prepare output scores
        double[] scores = new double[predictedData.numInstances()];

        for (Instance predInst : predictedData) {
            int id = (int) predInst.value(predInst.attribute(0)) - 1; // AddID assigns IDs starting at 1

            int valueIdx = predictedData.numAttributes() - 2;

            double value = predInst.value(predInst.attribute(valueIdx));

            scores[id] = value;

            // Limit to interval [0;5]
            if (scores[id] > 5.0) {
                scores[id] = 5.0;
            }
            if (scores[id] < 0.0) {
                scores[id] = 0.0;
            }
        }

        // Output
        StringBuilder sb = new StringBuilder();
        for (Double score : scores) {
            sb.append(score.toString() + LF);
        }

        FileUtils.writeStringToFile(
                new File(OUTPUT_DIR + "/" + mode.toString().toLowerCase() + "/" + dataset.toString() + ".csv"),
                sb.toString());
    }
}

From source file:elh.eus.absa.WekaWrapper.java

License:Open Source License

/**
 * @param traindata
 * @param testdata
 * @param id : whether the first attribute represents the instance id and should be filtered out for classifying
 * @throws Exception
 */
public WekaWrapper(Instances traindata, Instances testdata, boolean id) throws Exception {

    // classifier
    weka.classifiers.functions.SMO SVM = new weka.classifiers.functions.SMO();
    SVM.setOptions(weka.core.Utils.splitOptions("-C 1.0 -L 0.0010 -P 1.0E-12 -N 0 -V -1 -W 1 "
            + "-K \"weka.classifiers.functions.supportVector.PolyKernel -C 250007 -E 1.0\""));
    setTraindata(traindata);
    setTestdata(testdata);

    // first attribute reflects instance id, delete it when building the classifier
    if (id) {
        //filter
        Remove rm = new Remove();
        rm.setAttributeIndices("1"); // remove 1st attribute
        // meta-classifier
        FilteredClassifier fc = new FilteredClassifier();
        fc.setFilter(rm);
        fc.setClassifier(SVM);
        setMLclass(fc);
    } else {
        setMLclass(SVM);
    }
}

From source file:elh.eus.absa.WekaWrapper.java

License:Open Source License

public void filterAttribute(String index) throws Exception {
    // filter
    Remove rm = new Remove();
    rm.setAttributeIndices(index); // remove the given attribute; indexes start from 1
    // meta-classifier
    FilteredClassifier fc = new FilteredClassifier();
    fc.setFilter(rm);
    fc.setClassifier(this.MLclass);
    setMLclass(fc);
}

From source file:gov.va.chir.tagline.TagLineTrainer.java

License:Open Source License

public void train(final Collection<Document> documents, final Feature... features) throws Exception {
    if (!DatasetUtil.hasLabels(documents)) {
        throw new IllegalArgumentException("All lines for training must have a label.");
    }

    // Setup extractor for feature calculation
    extractor = new Extractor();

    if (features != null && features.length > 0) {
        extractor.addFeatures(features);
    } else {
        extractor.addFeatures(Extractor.getDefaultFeatures());
    }

    // Setup any features that require the entire corpus
    extractor.setupCorpusProcessors(documents);

    // Calculate features at both document and line level
    for (Document document : documents) {
        extractor.calculateFeatureValues(document);
    }

    // Create dataset
    instances = DatasetUtil.createDataset(documents);

    // Remove IDs from dataset
    final Remove remove = new Remove();

    remove.setAttributeIndicesArray(new int[] { instances.attribute(DatasetUtil.DOC_ID).index(),
            instances.attribute(DatasetUtil.LINE_ID).index() });

    final FilteredClassifier fc = new FilteredClassifier();
    fc.setFilter(remove);
    fc.setClassifier(tagLineModel.getModel());

    // Train model
    fc.buildClassifier(instances);

    tagLineModel.setModel(fc.getClassifier());
}

From source file:kea.KEAFilter.java

License:Open Source License

/**
 * Builds the classifier.
 */
private void buildClassifier() throws Exception {

    // Generate input format for classifier
    FastVector atts = new FastVector();
    for (int i = 0; i < getInputFormat().numAttributes(); i++) {
        if (i == m_DocumentAtt) {
            atts.addElement(new Attribute("TFxIDF"));
            atts.addElement(new Attribute("First_occurrence"));
            if (m_KFused) {
                atts.addElement(new Attribute("Keyphrase_frequency"));
            }
        } else if (i == m_KeyphrasesAtt) {
            FastVector vals = new FastVector(2);
            vals.addElement("False");
            vals.addElement("True");
            atts.addElement(new Attribute("Keyphrase?", vals));
        }
    }
    m_ClassifierData = new Instances("ClassifierData", atts, 0);
    m_ClassifierData.setClassIndex(m_NumFeatures);

    if (m_Debug) {
        System.err.println("--- Converting instances for classifier");
    }

    // Convert pending input instances into data for classifier
    for (int i = 0; i < getInputFormat().numInstances(); i++) {
        Instance current = getInputFormat().instance(i);

        // Get the key phrases for the document
        String keyphrases = current.stringValue(m_KeyphrasesAtt);
        HashMap hashKeyphrases = getGivenKeyphrases(keyphrases, false);
        HashMap hashKeysEval = getGivenKeyphrases(keyphrases, true);

        // Get the phrases for the document
        HashMap hash = new HashMap();
        int length = getPhrases(hash, current.stringValue(m_DocumentAtt));

        // Compute the feature values for each phrase and
        // add the instance to the data for the classifier
        Iterator it = hash.keySet().iterator();
        while (it.hasNext()) {
            String phrase = (String) it.next();
            FastVector phraseInfo = (FastVector) hash.get(phrase);
            double[] vals = featVals(phrase, phraseInfo, true, hashKeysEval, hashKeyphrases, length);
            Instance inst = new Instance(current.weight(), vals);
            m_ClassifierData.add(inst);
        }
    }

    if (m_Debug) {
        System.err.println("--- Building classifier");
    }

    // Build classifier
    FilteredClassifier fclass = new FilteredClassifier();
    fclass.setClassifier(new NaiveBayesSimple());
    fclass.setFilter(new Discretize());
    m_Classifier = fclass;
    m_Classifier.buildClassifier(m_ClassifierData);

    if (m_Debug) {
        System.err.println(m_Classifier);
    }

    // Save space
    m_ClassifierData = new Instances(m_ClassifierData, 0);
}

From source file:miRdup.WekaModule.java

License:Open Source License

public static void trainModel(File arff, String keyword) {
    dec.setMaximumFractionDigits(3);
    System.out.println("\nTraining model on file " + arff);
    try {
        // load data
        DataSource source = new DataSource(arff.toString());
        Instances data = source.getDataSet();
        if (data.classIndex() == -1) {
            data.setClassIndex(data.numAttributes() - 1);
        }

        PrintWriter pwout = new PrintWriter(new FileWriter(keyword + Main.modelExtension + "Output"));
        PrintWriter pwroc = new PrintWriter(new FileWriter(keyword + Main.modelExtension + "roc.arff"));

        // remove the ID attribute (the first column)
        Remove rm = new Remove();
        rm.setAttributeIndices("1");
        FilteredClassifier fc = new FilteredClassifier();
        fc.setFilter(rm);

        // train model SVM
        // weka.classifiers.functions.LibSVM model = new weka.classifiers.functions.LibSVM();
        // model.setOptions(weka.core.Utils.splitOptions("-S 0 -K 2 -D 3 -G 0.0 -R 0.0 -N 0.5 -M 40.0 -C 1.0 -E 0.0010 -P 0.1 -B"));
        // train model MultilayerPerceptron
        // weka.classifiers.functions.MultilayerPerceptron model = new weka.classifiers.functions.MultilayerPerceptron();
        // model.setOptions(weka.core.Utils.splitOptions("-L 0.3 -M 0.2 -N 500 -V 0 -S 0 -E 20 -H a"));
        // train model AdaBoost on RIPPER
        // weka.classifiers.meta.AdaBoostM1 model = new weka.classifiers.meta.AdaBoostM1();
        // model.setOptions(weka.core.Utils.splitOptions("weka.classifiers.meta.AdaBoostM1 -P 100 -S 1 -I 10 -W weka.classifiers.rules.JRip -- -F 10 -N 2.0 -O 5 -S 1"));
        // train model AdaBoost on FURIA
        // weka.classifiers.meta.AdaBoostM1 model = new weka.classifiers.meta.AdaBoostM1();
        // model.setOptions(weka.core.Utils.splitOptions("weka.classifiers.meta.AdaBoostM1 -P 100 -S 1 -I 10 -W weka.classifiers.rules.FURIA -- -F 10 -N 2.0 -O 5 -S 1 -p 0 -s 0"));
        // train model AdaBoost on J48 trees
        // weka.classifiers.meta.AdaBoostM1 model = new weka.classifiers.meta.AdaBoostM1();
        // model.setOptions(weka.core.Utils.splitOptions("-P 100 -S 1 -I 10 -W weka.classifiers.trees.J48 -- -C 0.25 -M 2"));
        // train model AdaBoost on Random Forest trees
        weka.classifiers.meta.AdaBoostM1 model = new weka.classifiers.meta.AdaBoostM1();
        model.setOptions(weka.core.Utils
                .splitOptions("-P 100 -S 1 -I 10 -W weka.classifiers.trees.RandomForest -- -I 50 -K 0 -S 1"));

        if (Main.debug) {
            System.out.print("Model options: " + model.getClass().getName().trim() + " ");
        }
        System.out.print(model.getClass() + " ");
        for (String s : model.getOptions()) {
            System.out.print(s + " ");
        }

        pwout.print("Model options: " + model.getClass().getName().trim() + " ");
        for (String s : model.getOptions()) {
            pwout.print(s + " ");
        }

        // build model
        // model.buildClassifier(data);
        fc.setClassifier(model);
        fc.buildClassifier(data);

        // cross validation 10 times on the model
        Evaluation eval = new Evaluation(data);
        //eval.crossValidateModel(model, data, 10, new Random(1));
        StringBuffer sb = new StringBuffer();
        eval.crossValidateModel(fc, data, 10, new Random(1), sb, new Range("first,last"), false);

        //System.out.println(sb);
        pwout.println(sb);
        pwout.flush();

        // output
        pwout.println("\n" + eval.toSummaryString());
        System.out.println(eval.toSummaryString());

        pwout.println(eval.toClassDetailsString());
        System.out.println(eval.toClassDetailsString());

        // calculate important values
        String ev[] = eval.toClassDetailsString().split("\n");

        String ptmp[] = ev[3].trim().split(" ");
        String ntmp[] = ev[4].trim().split(" ");
        String avgtmp[] = ev[5].trim().split(" ");

        ArrayList<String> p = new ArrayList<String>();
        ArrayList<String> n = new ArrayList<String>();
        ArrayList<String> avg = new ArrayList<String>();

        for (String s : ptmp) {
            if (!s.trim().isEmpty()) {
                p.add(s);
            }
        }
        for (String s : ntmp) {
            if (!s.trim().isEmpty()) {
                n.add(s);
            }
        }
        for (String s : avgtmp) {
            if (!s.trim().isEmpty()) {
                avg.add(s);
            }
        }

        double tp = Double.parseDouble(p.get(0));
        double fp = Double.parseDouble(p.get(1));
        double tn = Double.parseDouble(n.get(0));
        double fn = Double.parseDouble(n.get(1));
        double auc = Double.parseDouble(avg.get(7));

        pwout.println("\nTP=" + tp + "\nFP=" + fp + "\nTN=" + tn + "\nFN=" + fn);
        System.out.println("\nTP=" + tp + "\nFP=" + fp + "\nTN=" + tn + "\nFN=" + fn);

        // specificity, sensitivity, Matthews correlation coefficient, prediction accuracy
        double sp = ((tn) / (tn + fp));
        double se = ((tp) / (tp + fn));
        double acc = ((tp + tn) / (tp + tn + fp + fn));
        double mcc = ((tp * tn) - (fp * fn)) / Math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn));

        String output = "\nse=" + dec.format(se).replace(",", ".") + "\nsp=" + dec.format(sp).replace(",", ".")
                + "\nACC=" + dec.format(acc).replace(",", ".") + "\nMCC=" + dec.format(mcc).replace(",", ".")
                + "\nAUC=" + dec.format(auc).replace(",", ".");

        pwout.println(output);
        System.out.println(output);

        pwout.println(eval.toMatrixString());
        System.out.println(eval.toMatrixString());

        pwout.flush();
        pwout.close();

        //Saving model
        System.out.println("Model saved: " + keyword + Main.modelExtension);
        weka.core.SerializationHelper.write(keyword + Main.modelExtension, fc.getClassifier() /*model*/);

        // get curve
        ThresholdCurve tc = new ThresholdCurve();
        int classIndex = 0;
        Instances result = tc.getCurve(eval.predictions(), classIndex);
        pwroc.print(result.toString());
        pwroc.flush();
        pwroc.close();

        // draw curve
        //rocCurve(eval);
    } catch (Exception e) {
        e.printStackTrace();
    }
}

From source file:mulan.classifier.transformation.MultiLabelStacking.java

License:Open Source License

/**
 * Builds the base-level classifiers.
 * Their predictions are gathered in the baseLevelPredictions member
 * @param trainingSet 
 * @throws Exception
 */
public void buildBaseLevel(MultiLabelInstances trainingSet) throws Exception {
    train = new Instances(trainingSet.getDataSet());
    baseLevelData = new Instances[numLabels];
    baseLevelEnsemble = AbstractClassifier.makeCopies(baseClassifier, numLabels);
    if (normalize) {
        maxProb = new double[numLabels];
        minProb = new double[numLabels];
        Arrays.fill(minProb, 1);
    }
    // initialize the table holding the predictions of the first level
    // classifiers for each label for every instance of the training set
    baseLevelPredictions = new double[train.numInstances()][numLabels];

    for (int labelIndex = 0; labelIndex < numLabels; labelIndex++) {
        debug("Label: " + labelIndex);
        // transform the dataset according to the BR method
        baseLevelData[labelIndex] = BinaryRelevanceTransformation.transformInstances(train, labelIndices,
                labelIndices[labelIndex]);
        // attach indexes in order to keep track of the original positions
        baseLevelData[labelIndex] = new Instances(attachIndexes(baseLevelData[labelIndex]));
        // prepare the transformed dataset for stratified x-fold cv
        Random random = new Random(1);
        baseLevelData[labelIndex].randomize(random);
        baseLevelData[labelIndex].stratify(numFolds);
        debug("Creating meta-data");
        for (int j = 0; j < numFolds; j++) {
            debug("Label=" + labelIndex + ", Fold=" + j);
            Instances subtrain = baseLevelData[labelIndex].trainCV(numFolds, j, random);
            // create a filtered meta classifier, used to ignore
            // the index attribute in the build process
            // perform stratified x-fold cv and get predictions
            // for each class for every instance
            FilteredClassifier fil = new FilteredClassifier();
            fil.setClassifier(baseLevelEnsemble[labelIndex]);
            Remove remove = new Remove();
            remove.setAttributeIndices("first");
            remove.setInputFormat(subtrain);
            fil.setFilter(remove);
            fil.buildClassifier(subtrain);

            // Classify test instance
            Instances subtest = baseLevelData[labelIndex].testCV(numFolds, j);
            for (int i = 0; i < subtest.numInstances(); i++) {
                double[] distribution = fil.distributionForInstance(subtest.instance(i));
                // Ensure correct predictions both for class values {0,1}
                // and {1,0}
                Attribute classAttribute = baseLevelData[labelIndex].classAttribute();
                baseLevelPredictions[(int) subtest.instance(i)
                        .value(0)][labelIndex] = distribution[classAttribute.indexOfValue("1")];
                if (normalize) {
                    if (distribution[classAttribute.indexOfValue("1")] > maxProb[labelIndex]) {
                        maxProb[labelIndex] = distribution[classAttribute.indexOfValue("1")];
                    }
                    if (distribution[classAttribute.indexOfValue("1")] < minProb[labelIndex]) {
                        minProb[labelIndex] = distribution[classAttribute.indexOfValue("1")];
                    }
                }
            }
        }
        // now we can detach the indexes from the first level datasets
        baseLevelData[labelIndex] = detachIndexes(baseLevelData[labelIndex]);

        debug("Building base classifier on full data");
        // build base classifier on the full training data
        baseLevelEnsemble[labelIndex].buildClassifier(baseLevelData[labelIndex]);
        baseLevelData[labelIndex].delete();
    }

    if (normalize) {
        normalizePredictions();
    }

}

From source file:nl.uva.sne.commons.ClusterUtils.java

private static FilteredClassifier buildModel(int[] indicesToRemove, int classIndex, Instances trainDataset,
        Classifier cl) throws Exception {
    FilteredClassifier model = new FilteredClassifier();
    model.setClassifier(AbstractClassifier.makeCopy(cl));
    Remove remove = new Remove();
    remove.setAttributeIndicesArray(indicesToRemove);
    remove.setInputFormat(trainDataset);
    remove.setInvertSelection(false);
    model.setFilter(remove);
    trainDataset.setClassIndex(classIndex);
    model.buildClassifier(trainDataset);
    //        int foldHash = trainDataset.toString().hashCode();
    //        String modelKey = createKey(indicesToRemove, foldHash);
    //        existingModels.put(modelKey, model);
    return model;
}