jjj.asap.sas.ensemble.impl.StackedClassifier.java Source code

Java tutorial


Here is the source code for jjj.asap.sas.ensemble.impl.StackedClassifier.java


/*
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

/*
 * Copyright (C) 2012 James Jesensky
 */

package jjj.asap.sas.ensemble.impl;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;

import jjj.asap.sas.ensemble.Ensemble;
import jjj.asap.sas.ensemble.StrongLearner;
import jjj.asap.sas.ensemble.WeakLearner;
import jjj.asap.sas.util.Calc;
import jjj.asap.sas.util.Contest;
import jjj.asap.sas.weka.DatasetBuilder;
import jjj.asap.sas.weka.Model;
import jjj.asap.sas.weka.Weka;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.core.DenseInstance;
import weka.core.Instances;
import weka.core.Utils;

/**
 * Implements stacking using labels only (that is, it doesn't consider
 * class support).
 */
public class StackedClassifier implements Ensemble {

    private boolean useNumericVariables;
    private Classifier prototype;

    public StackedClassifier(boolean useNumericVariables, Classifier prototype) {
        this.useNumericVariables = useNumericVariables;
        this.prototype = prototype;

    public StrongLearner build(int essaySet, String ensembleName, List<WeakLearner> learners) {

        if (learners.isEmpty()) {
            return StrongLearner.NO_MODEL[essaySet - 1];

        StrongLearner strong = new StrongLearner();

        // training
        try {

            Instances metaData = getMetaDataset(essaySet, learners);

            // hack
            //Instances hack = getMetaDataset(essaySet, learners);
            //Dataset.save("etc/stacking" + essaySet + ".arff", hack);
            // end hack

            Classifier metaClassifier = AbstractClassifier.makeCopy(prototype);

            Weka.trainClassifier(metaData, metaClassifier);
            Map<Double, double[]> probs = Weka.classifyInstances(metaData, metaClassifier);
            Map<Double, Double> preds = Model.getPredictions(essaySet, probs);
            double kappa = Calc.kappa(essaySet, preds, Contest.getGoldStandard(essaySet));

            strong.setLearners(new ArrayList<WeakLearner>(learners));

        } catch (Exception e) {
            throw new RuntimeException(e);

        return strong;

    public Map<Double, Double> classify(int essaySet, String ensembleName, List<WeakLearner> learners,
            Object context) {

        if (learners.isEmpty()) {
            return StrongLearner.NO_MODEL[essaySet - 1].getPreds();

        try {
            Map<Double, double[]> probs = Weka.classifyInstances(getMetaDataset(essaySet, learners),
                    (Classifier) context);

            return Model.getPredictions(essaySet, probs);

        } catch (Exception e) {
            throw new RuntimeException(e);

     * Returns a dataset representing the learners.
    private Instances getMetaDataset(int essaySet, List<WeakLearner> learners) {

        // create dataset headers

        DatasetBuilder builder = new DatasetBuilder();
        for (int i = 0; i < learners.size(); i++) {
            if (useNumericVariables) {
                builder.addVariable("x" + i);
            } else {
                builder.addNominalVariable("x" + i, Contest.getRubrics(essaySet));
        builder.addNominalVariable("score", Contest.getRubrics(essaySet));

        Instances dataset = builder.getDataset(this.getClass().getCanonicalName());
        Map<Double, Double> labels = Contest.getGoldStandard(essaySet);

        // now add the data
        for (double id : learners.get(0).getPreds().keySet()) {

            double[] data = new double[dataset.numAttributes()];
            data[0] = id;

            for (int i = 0; i < learners.size(); i++) {
                data[i + 1] = learners.get(i).getPreds().get(id);

            data[dataset.numAttributes() - 1] = labels.containsKey(id) ? labels.get(id) : Utils.missingValue();

            dataset.add(new DenseInstance(1.0, data));

        return dataset;
