tr.gov.ulakbim.jDenetX.classifiers.OzaBagASHOT.java Source code

Java tutorial

Introduction

Here is the source code for tr.gov.ulakbim.jDenetX.classifiers.OzaBagASHOT.java

Source

/*
 *    OzaBagASHOT.java
 *    Copyright (C) 2008 University of Waikato, Hamilton, New Zealand
 *    @author Albert Bifet
 *
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */
package tr.gov.ulakbim.jDenetX.classifiers;

import tr.gov.ulakbim.jDenetX.core.DoubleVector;
import tr.gov.ulakbim.jDenetX.core.MiscUtils;
import tr.gov.ulakbim.jDenetX.options.FlagOption;
import tr.gov.ulakbim.jDenetX.options.IntOption;
import weka.core.Instance;
import weka.core.Utils;

/**
 * Online bagging ensemble whose members are {@link ASHoeffdingOptionTree}s with
 * geometrically growing size limits.
 *
 * <p>The first member is given the size limit configured by
 * {@link #firstClassifierSizeOption}; each following member gets twice the limit of
 * its predecessor. For every member an exponentially weighted moving average (EWMA)
 * of its 0/1 prediction error is maintained and, when {@link #useWeightOption} is
 * set, used to weight that member's vote by the inverse of the squared error.
 *
 * <p>The configured base learner must be an {@link ASHoeffdingOptionTree}; any other
 * learner makes {@link #resetLearningImpl()} fail with a {@link ClassCastException}.
 */
public class OzaBagASHOT extends OzaBag {

    private static final long serialVersionUID = 1L;

    /** EWMA error above which a warning is written to {@code System.err}. */
    private static final double ERROR_WARNING_THRESHOLD = 0.6;

    /**
     * Lower bound applied to the squared EWMA error before it is inverted into a vote
     * weight. Errors start at 0.0, so without this floor the weight would be infinite
     * and zero-valued vote entries would turn into NaN.
     */
    private static final double MIN_SQUARED_ERROR = 1e-12;

    /** Size limit handed to the first ensemble member; doubled for each next member. */
    public IntOption firstClassifierSizeOption = new IntOption("firstClassifierSize", 'f',
            "The size of first classifier in the bag.", 1, 1, Integer.MAX_VALUE);

    /** When set, member votes are weighted by the inverse of their squared EWMA error. */
    public FlagOption useWeightOption = new FlagOption("useWeight", 'u', "Enable weight classifiers.");

    /** When set, {@code setResetTree()} is invoked on every ensemble member at reset time. */
    public FlagOption resetTreesOption = new FlagOption("resetTrees", 'r',
            "Reset trees when size is higher than the max.");

    /** Per-member EWMA of the 0/1 prediction error, indexed like {@code ensemble}. */
    protected double[] error;

    /** Smoothing factor of the EWMA error estimate. */
    protected double alpha = 0.01;

    /**
     * Rebuilds the ensemble from fresh copies of the base learner, clears the error
     * estimates and assigns each member its size limit.
     *
     * @throws ClassCastException if the base learner is not an {@link ASHoeffdingOptionTree}
     */
    @Override
    public void resetLearningImpl() {
        final int ensembleSize = this.ensembleSizeOption.getValue();
        this.ensemble = new Classifier[ensembleSize];
        this.error = new double[ensembleSize];

        Classifier baseLearner = (Classifier) getPreparedClassOption(this.baseLearnerOption);
        baseLearner.resetLearning();

        final boolean resetTrees = (this.resetTreesOption != null) && this.resetTreesOption.isSet();
        int maxSize = this.firstClassifierSizeOption.getValue();
        for (int i = 0; i < this.ensemble.length; i++) {
            this.ensemble[i] = baseLearner.copy();
            this.error[i] = 0.0;

            final ASHoeffdingOptionTree tree = (ASHoeffdingOptionTree) this.ensemble[i];
            tree.setMaxSize(maxSize);
            if (resetTrees) {
                tree.setResetTree();
            }

            // Double the limit for the next member, saturating instead of letting the
            // int wrap around to a negative or zero limit for large ensembles.
            maxSize = (maxSize > Integer.MAX_VALUE / 2) ? Integer.MAX_VALUE : maxSize * 2;
        }
    }

    /**
     * Trains every ensemble member on {@code inst} with a weight multiplier drawn from
     * a Poisson(1) distribution, updating the member's EWMA error beforehand.
     *
     * <p>A member whose draw is 0 neither sees the instance nor has its error updated.
     *
     * @param inst the labelled training instance
     */
    @Override
    public void trainOnInstanceImpl(Instance inst) {
        final int trueClass = (int) inst.classValue();
        for (int i = 0; i < this.ensemble.length; i++) {
            final int k = MiscUtils.poisson(1.0, this.classifierRandom);
            if (k > 0) {
                final Instance weightedInst = (Instance) inst.copy();
                weightedInst.setWeight(inst.weight() * k);

                // Test-then-train: score the member on the instance before it learns from it.
                final boolean correct =
                        Utils.maxIndex(this.ensemble[i].getVotesForInstance(inst)) == trueClass;
                final double loss = correct ? 0.0 : 1.0;
                this.error[i] += alpha * (loss - this.error[i]); // EWMA

                this.ensemble[i].trainOnInstance(weightedInst);
            }
            if (this.error[i] > ERROR_WARNING_THRESHOLD) {
                // Single line on a single stream so the message cannot be interleaved.
                System.err.println("Warning: ensemble member " + i
                        + " has EWMA error " + this.error[i]);
            }
        }
    }

    /**
     * Combines the normalized votes of all ensemble members.
     *
     * <p>Members whose votes sum to zero or less are skipped. When
     * {@link #useWeightOption} is set, each normalized vote is scaled by the inverse of
     * the member's squared EWMA error; the squared error is bounded below so that a
     * member with zero recorded error gets a large finite weight rather than an
     * infinite one.
     *
     * @param inst the instance to classify
     * @return the accumulated per-class votes
     */
    @Override
    public double[] getVotesForInstance(Instance inst) {
        final boolean weighted = (this.useWeightOption != null) && this.useWeightOption.isSet();
        DoubleVector combinedVote = new DoubleVector();
        for (int i = 0; i < this.ensemble.length; i++) {
            DoubleVector vote = new DoubleVector(this.ensemble[i].getVotesForInstance(inst));
            if (vote.sumOfValues() > 0.0) {
                vote.normalize();
                if (weighted) {
                    final double squaredError =
                            Math.max(this.error[i] * this.error[i], MIN_SQUARED_ERROR);
                    vote.scaleValues(1.0 / squaredError);
                }
                combinedVote.addValues(vote);
            }
        }
        return combinedVote.getArrayRef();
    }

    /**
     * Intentionally writes nothing: no textual model description is produced for this
     * ensemble.
     *
     * @param out    the builder a description would be appended to (left untouched)
     * @param indent the indentation level (unused)
     */
    @Override
    public void getModelDescription(StringBuilder out, int indent) {
        // No model description is provided.
    }
}