tr.gov.ulakbim.jDenetX.classifiers.CoOzaBagASHT.java Source code

Java tutorial

Introduction

Here is the source code for tr.gov.ulakbim.jDenetX.classifiers.CoOzaBagASHT.java

Source

/*
 *    OzaBagASHT.java
 *    Copyright (C) 2008 University of Waikato, Hamilton, New Zealand
 *    @author Caglar
 *
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

package tr.gov.ulakbim.jDenetX.classifiers;

import tr.gov.ulakbim.jDenetX.core.DoubleVector;
import tr.gov.ulakbim.jDenetX.core.MiscUtils;
import tr.gov.ulakbim.jDenetX.core.VotedInstancePool;
import tr.gov.ulakbim.jDenetX.options.FlagOption;
import tr.gov.ulakbim.jDenetX.options.IntOption;
import weka.core.Instance;
import weka.core.Utils;

import java.util.ArrayList;

public class CoOzaBagASHT extends OzaBag {

    /** Serialization version identifier for this class. */
    private static final long serialVersionUID = 1L;

    /**
     * Option "firstClassifierSize" ('f'): the maximum size handed to the first ensemble member in
     * {@code resetLearningImpl}; each following member receives double the previous value.
     * The trailing arguments (1, 1, Integer.MAX_VALUE) are presumably default/min/max — confirm
     * against IntOption.
     */
    public IntOption firstClassifierSizeOption = new IntOption("firstClassifierSize", 'f',
            "The size of first classifier in the bag.", 1, 1, Integer.MAX_VALUE);

    /**
     * Option "useWeight" ('u'): when set, the vote methods scale each member's normalized vote by
     * the inverse square of that member's entry in {@link #error}.
     */
    public FlagOption useWeightOption = new FlagOption("useWeight", 'u', "Enable weight classifiers.");

    /**
     * Option "resetTrees" ('r'): when set, {@code resetLearningImpl} calls {@code setResetTree()}
     * on every ensemble member.
     */
    public FlagOption resetTreesOption = new FlagOption("resetTrees", 'r',
            "Reset trees when size is higher than the max.");

    /**
     * Per-member error estimate, index-aligned with the ensemble array. Zero-initialized in
     * {@code resetLearningImpl}, updated in {@code trainOnInstanceImpl}, and read by the vote
     * methods for weighting and for filtering out high-error members.
     */
    protected double[] error;

    /** Not read or written by any method visible in this file. */
    protected ArrayList<Instance> centroids;

    /** Smoothing factor of the moving-average update applied to {@link #error} during training. */
    protected double alpha = 0.01;

    /**
     * Pool of instances labeled by the ensemble. It is static, so it is shared by every instance
     * of this class; {@code resetLearningImpl} replaces it and the vote methods add to it.
     */
    private static VotedInstancePool instConfPool = new VotedInstancePool();

    /**
     * Counter reset to 0 in {@code resetLearningImpl} and incremented by
     * {@code getVotesForInstance} each time it adds an instance to the pool. Static and public,
     * so it is shared and writable from anywhere.
     */
    public static int instConfCount = 0;

    /**
     * Minimum confidence value an instance must reach in the vote methods before it is labeled
     * and added to the pool.
     */
    private final static double confidenceThreshold = 9.7;

    /**
     * Rebuilds the ensemble from the configured base learner and clears all learned state.
     *
     * <p>Every member is a copy of the freshly reset base learner, cast to
     * {@code ASHoeffdingOptionTree}. The first member gets the maximum size configured by
     * {@code firstClassifierSizeOption}; each later member gets twice the size of the previous
     * one. The doubling saturates at {@code Integer.MAX_VALUE} so that large ensembles cannot
     * wrap around to negative or zero sizes.
     *
     * <p>Side effects: replaces the static voted-instance pool and resets the static
     * {@code instConfCount}, both of which are shared by every instance of this class.
     *
     * @throws ClassCastException if the configured base learner is not an
     *                            {@code ASHoeffdingOptionTree}
     */
    @Override
    public void resetLearningImpl() {
        final int ensembleSize = this.ensembleSizeOption.getValue();
        this.ensemble = new Classifier[ensembleSize];
        // A new double[] is zero-filled, so every member starts with an error estimate of 0.0.
        this.error = new double[ensembleSize];
        instConfPool = new VotedInstancePool();
        instConfCount = 0;

        Classifier baseLearner = (Classifier) getPreparedClassOption(this.baseLearnerOption);
        baseLearner.resetLearning();

        final boolean resetTrees = (this.resetTreesOption != null) && this.resetTreesOption.isSet();
        int maxSize = this.firstClassifierSizeOption.getValue(); // EXTENSION TO ASHT
        for (int i = 0; i < this.ensemble.length; i++) {
            this.ensemble[i] = baseLearner.copy();
            ASHoeffdingOptionTree tree = (ASHoeffdingOptionTree) this.ensemble[i];
            tree.setMaxSize(maxSize); // EXTENSION TO ASHT
            if (resetTrees) {
                tree.setResetTree();
            }
            // Double the size for the next member, saturating instead of overflowing int.
            maxSize = (maxSize > Integer.MAX_VALUE / 2) ? Integer.MAX_VALUE : maxSize * 2;
        }
    }

    /**
     * Computes the base-2 Shannon entropy {@code -sum(p * log2(p))} of the given values.
     *
     * <p>The input array is not modified. Entries that are not strictly positive (zero, negative
     * or NaN) contribute nothing, following the convention {@code 0 * log(0) = 0}.
     *
     * @param votes the values to measure, normally a probability distribution
     * @return the entropy in bits; {@code 0.0} for an empty array
     * @throws NullPointerException if {@code votes} is {@code null}
     */
    public double getEntropyForArray(double[] votes) {
        // Math.log is the natural logarithm; dividing by ln(2) converts it to base 2.
        final double ln2 = Math.log(2);
        double entropy = 0.0;
        for (double p : votes) {
            if (p > 0.0) {
                entropy -= p * (Math.log(p) / ln2);
            }
        }
        return entropy;
    }

    /**
     * Computes one vote-entropy term {@code -(r * log2(r))} where {@code r = vote / success}.
     *
     * <p>When the ratio is not a finite, strictly positive number (for example because
     * {@code success} or {@code vote} is zero) the term is defined as {@code 0.0}. This follows
     * the convention {@code 0 * log(0) = 0} and avoids returning NaN or an infinity.
     *
     * @param vote    the vote mass received by one label
     * @param success the divisor applied to {@code vote}; callers in this class pass the number
     *                of ensemble members that were allowed to vote
     * @return the entropy term in bits, or {@code 0.0} when the ratio is undefined
     */
    public double getQBCEntropy(double vote, int success) {
        final double ratio = vote / success;
        // !(ratio > 0) is also true for NaN, which is what 0.0 / 0 produces.
        if (!(ratio > 0.0) || Double.isInfinite(ratio)) {
            return 0.0;
        }
        return 0.0 - ratio * Utils.log2(ratio);
    }

    /**
     * Query-by-committee vote-entropy score.
     *
     * <p>For every label index below {@code noOfClasses} whose entry in {@code ensembleVotes} is
     * non-zero, the entry is divided by the ensemble size and passed to
     * {@link #getQBCEntropy(double, int)}; the returned term is subtracted from the running
     * total. The result is therefore the negated sum of those terms.
     *
     * <p>Reference formula from the original comment:
     * {@code xve = argmax -Sigma (V(yi)/C) * log(V(yi)/C)}, where C is the committee size and
     * V(yi) is the number of votes a label receives among the committee members' votes.
     *
     * @param ensembleVotes vote mass per label; must have at least {@code noOfClasses} entries
     * @param noOfClasses   number of leading entries of {@code ensembleVotes} to examine
     * @param success       passed through unchanged to {@code getQBCEntropy}
     * @return the accumulated score; {@code 0.0} when no examined entry is non-zero
     */
    public double queryByCommitee(double[] ensembleVotes, int noOfClasses, int success) {
        double score = 0.0;
        for (int label = 0; label < noOfClasses; label++) {
            final double labelVotes = ensembleVotes[label];
            if (labelVotes == 0) {
                continue; // labels nobody voted for add nothing
            }
            final double share = labelVotes / ((double) ensemble.length);
            score -= getQBCEntropy(share, success);
        }
        return score;
    }

    /**
     * Trains the ensemble on one instance using Poisson(1) resampling.
     *
     * <p>For each member a count is drawn from {@code MiscUtils.poisson(1.0, classifierRandom)}.
     * Members that draw zero are skipped. Otherwise the member's current prediction for the
     * instance is compared with the true class to update its error estimate, and the member is
     * then trained on a copy of the instance whose weight is multiplied by the drawn count.
     *
     * @param inst the labeled training instance; it is copied, not modified
     */
    @Override
    public void trainOnInstanceImpl(Instance inst) {
        final int actualClass = (int) inst.classValue();
        for (int member = 0; member < this.ensemble.length; member++) {
            final int draws = MiscUtils.poisson(1.0, this.classifierRandom);
            if (draws <= 0) {
                continue;
            }
            final Instance resampled = (Instance) inst.copy();
            resampled.setWeight(inst.weight() * draws);

            // Score the member on the instance before it learns from it.
            final int predictedClass = Utils.maxIndex(this.ensemble[member].getVotesForInstance(inst));
            final double outcome = (predictedClass == actualClass) ? 0.0 : 1.0;
            this.error[member] += alpha * (outcome - this.error[member]); // EWMA

            this.ensemble[member].trainOnInstance(resampled);
        }
    }

    /**
     * Main classification entry point: combines the members' votes for an instance.
     *
     * <p>Each member's vote is normalized and, when {@code useWeightOption} is set, scaled by
     * the inverse square of that member's error estimate. The error is floored at a small
     * positive value first, so a member whose error is still exactly 0.0 gets a large finite
     * weight instead of an infinite one that would turn the combined vote into Infinity/NaN.
     *
     * <p>Side effects: when the summed normalized votes for the winning class reach
     * {@code confidenceThreshold}, the class value of {@code inst} itself is overwritten with
     * the winning class, the instance is added to the static voted-instance pool, and the static
     * {@code instConfCount} is incremented.
     *
     * @param inst the instance to classify; may be relabeled as described above
     * @return the backing array of the combined vote vector (not a copy)
     */
    public double[] getVotesForInstance(Instance inst) {
        // Floor for the error used in inverse-square weighting; prevents division by zero.
        final double minWeightingError = 1e-6;
        // Members whose error estimate is at or above this value do not count as committee votes.
        final double maxCommitteeError = 0.23;
        final boolean weighted = (this.useWeightOption != null) && this.useWeightOption.isSet();

        DoubleVector combinedVote = new DoubleVector();
        DoubleVector confidenceVec = new DoubleVector();
        double[] ensembleVotes = new double[inst.numClasses()];
        int success = 0;

        for (int i = 0; i < this.ensemble.length; i++) {
            DoubleVector vote = new DoubleVector(this.ensemble[i].getVotesForInstance(inst));
            if (vote.sumOfValues() > 0.0) {
                vote.normalize();
                // Confidence is accumulated from the normalized vote, before any weighting.
                confidenceVec.addValues(vote);
                if (weighted) {
                    final double memberError = Math.max(this.error[i], minWeightingError);
                    vote.scaleValues(1.0 / (memberError * memberError));
                }
                combinedVote.addValues(vote);
            }
            if (this.error[i] < maxCommitteeError) {
                success++;
                // NOTE(review): this records the leader of the running combined vote, not the
                // member's own top class — confirm that this is the intended committee vote.
                final int leader = combinedVote.maxIndex();
                // A negative index is presumably what maxIndex() reports while no vote has been
                // accumulated yet (confirm against DoubleVector); there is nothing to record.
                if (leader >= 0) {
                    ensembleVotes[leader] += combinedVote.getValue(leader);
                }
            }
        }

        final int winner = combinedVote.maxIndex();
        if (winner >= 0 && confidenceVec.getValue(winner) >= confidenceThreshold) {
            final double winnerVote = combinedVote.getValue(winner);
            final double qbcEntropy = queryByCommitee(ensembleVotes, inst.numClasses(), success);
            final double activeLearningRatio = qbcEntropy * (winnerVote / this.ensemble.length);
            inst.setClassValue(winner); // Set the class value of the instance
            instConfPool.addVotedInstance(inst, winnerVote, activeLearningRatio);
            instConfCount++;
        }
        return combinedVote.getArrayRef();
    }

    /**
     * Earlier variant of {@link #getVotesForInstance(Instance)}, kept alongside it.
     *
     * <p>Differences visible in the code: only members whose error estimate is below 0.3
     * contribute to the combined vote; the confidence test divides the winning combined vote by
     * the number of such members; the active-learning ratio adds (rather than multiplies) the
     * entropy and the scaled vote; {@code instConfCount} is not incremented; and diagnostic
     * lines are printed to standard output.
     *
     * <p>Fixes relative to the previous revision: the real member count is passed to
     * {@code queryByCommitee} instead of a literal 0 (which forced a division by zero), the
     * error used for inverse-square weighting is floored so a zero error cannot produce an
     * infinite weight, and a negative winner index is no longer used as an array index.
     *
     * <p>Side effects: when the confidence test passes, the class value of {@code inst} itself
     * is overwritten and the instance is added to the static voted-instance pool.
     *
     * @param inst the instance to classify; may be relabeled as described above
     * @return the backing array of the combined vote vector (not a copy)
     */
    public double[] getVotesForInstanceOrig(Instance inst) {
        // Floor for the error used in inverse-square weighting; prevents division by zero.
        final double minWeightingError = 1e-6;
        // Members whose error estimate is at or above this value are ignored.
        final double maxCommitteeError = 0.3;
        final boolean weighted = (this.useWeightOption != null) && this.useWeightOption.isSet();

        DoubleVector combinedVote = new DoubleVector();
        double[] ensembleVotes = new double[inst.numClasses()];
        int success = 0;

        for (int i = 0; i < this.ensemble.length; i++) {
            final boolean trusted = this.error[i] < maxCommitteeError;
            DoubleVector vote = new DoubleVector(this.ensemble[i].getVotesForInstance(inst));
            if (vote.sumOfValues() > 0.0) {
                vote.normalize();
                if (weighted) {
                    final double memberError = Math.max(this.error[i], minWeightingError);
                    vote.scaleValues(1.0 / (memberError * memberError));
                    System.out.println("Ensemble : " + i + " Error: " + this.error[i]);
                }
                if (trusted) {
                    combinedVote.addValues(vote);
                }
            }
            if (trusted) {
                success++;
                final int leader = combinedVote.maxIndex();
                // A negative index is presumably what maxIndex() reports while no vote has been
                // accumulated yet (confirm against DoubleVector); there is nothing to record.
                if (leader >= 0) {
                    ensembleVotes[leader] += combinedVote.getValue(leader);
                }
            }
        }

        // Divide by the number of trusted members to bring the confidence onto a per-member scale.
        final int winner = combinedVote.maxIndex();
        if (winner >= 0 && success > 0
                && (combinedVote.getValue(winner) / success) >= confidenceThreshold) {
            final double winnerVote = combinedVote.getValue(winner);
            final double qbcEntropy = queryByCommitee(ensembleVotes, inst.numClasses(), success);
            System.out.println("QBC Entropy: " + qbcEntropy);
            final double activeLearningRatio = qbcEntropy + (winnerVote / this.ensemble.length);
            inst.setClassValue(winner);
            instConfPool.addVotedInstance(inst, winnerVote, activeLearningRatio);
        }
        return combinedVote.getArrayRef();
    }

    /**
     * Returns the static pool of voted instances.
     *
     * <p>The pool is a static field, so the same object is returned regardless of which
     * classifier instance filled it, and it is replaced whenever {@code resetLearningImpl} runs.
     *
     * @return the current pool object itself, not a copy
     */
    public static VotedInstancePool getVotedInstancePool() {
        return instConfPool;
    }

    /**
     * Delegates unchanged to the superclass implementation; this class contributes no
     * description of its own.
     *
     * @param out    builder passed through to the superclass
     * @param indent indentation value passed through to the superclass
     */
    @Override
    public void getModelDescription(StringBuilder out, int indent) {
        // Nothing subclass-specific is written here; the parent class does all the work.
        super.getModelDescription(out, indent);
    }
}