j48.GraftSplit.java Source code

Java tutorial

Introduction

Here is the source code for j48.GraftSplit.java

Source

/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *  GraftSplit.java
 *  Copyright (C) 2007 Geoff Webb & Janice Boughton
 *  a split object for nodes added to a tree during grafting.
 *  (used in classifier J48g).
 */

package j48;

import weka.core.*;

/**
 * Class implementing a split for nodes added to a tree during grafting.
 *
 * @author Janice Boughton (jrbought@infotech.monash.edu.au)
 * @version $Revision 1.0 $
 */
public class GraftSplit extends ClassifierSplitModel implements Comparable {

    /** for serialzation. */
    private static final long serialVersionUID = 722773260393182051L;

    /** the distribution for graft values, from cases in atbop */
    private Distribution m_graftdistro;

    /** the attribute we are splitting on */
    private int m_attIndex;

    /** value of split point (if numeric attribute) */
    private double m_splitPoint;

    /** dominant class of the subset specified by m_testType */
    private int m_maxClass;

    /** dominant class of the subset not specified by m_testType */
    private int m_otherLeafMaxClass;

    /** laplace value of the subset specified by m_testType for m_maxClass */
    private double m_laplace;

    /** leaf for the subset specified by m_testType */
    private Distribution m_leafdistro;

    /** 
     * type of test:
     * 0: <= test
     * 1: > test
     * 2: = test
     * 3: != test
     */
    private int m_testType;

    /**
     * constructor
     *
     * @param a the attribute to split on
     * @param v the value of a where split occurs
     * @param t the test type (0 is <=, 1 is >, 2 is =, 3 is !)
     * @param c the class to label the leaf node pointed to by test as.
     * @param l the laplace value (needed when sorting GraftSplits)
     */
    public GraftSplit(int a, double v, int t, double c, double l) {

        m_attIndex = a;
        m_splitPoint = v;
        m_testType = t;
        m_maxClass = (int) c;
        m_laplace = l;
    }

    /**
     * constructor
     *
     * @param a the attribute to split on
     * @param v the value of a where split occurs
     * @param t the test type (0 is <=, 1 is >, 2 is =, 3 is !=)
     * @param oC the class to label the leaf node not pointed to by test as.
     * @param counts the distribution for this split
     */
    public GraftSplit(int a, double v, int t, double oC, double[][] counts) throws Exception {
        m_attIndex = a;
        m_splitPoint = v;
        m_testType = t;
        m_otherLeafMaxClass = (int) oC;

        // only deal with binary cuts (<= and >; = and !=)
        m_numSubsets = 2;

        // which subset are we looking at for the graft?
        int subset = subsetOfInterest(); // this is the subset for m_leaf

        // create graft distribution, based on counts
        m_distribution = new Distribution(counts);

        // create a distribution object for m_leaf
        double[][] lcounts = new double[1][m_distribution.numClasses()];
        for (int c = 0; c < lcounts[0].length; c++) {
            lcounts[0][c] = counts[subset][c];
        }
        m_leafdistro = new Distribution(lcounts);

        // set the max class
        m_maxClass = m_distribution.maxClass(subset);

        // set the laplace value (assumes binary class) for subset of interest
        m_laplace = (m_distribution.perClassPerBag(subset, m_maxClass) + 1.0)
                / (m_distribution.perBag(subset) + 2.0);
    }

    /**
     * deletes the cases in data that belong to leaf pointed to by
     * the test (i.e. the subset of interest).  this is useful so
     * the instances belonging to that leaf aren't passed down the
     * other branch.
     *
     * @param data the instances to delete from
     */
    public void deleteGraftedCases(Instances data) {

        int subOfInterest = subsetOfInterest();
        for (int x = 0; x < data.numInstances(); x++) {
            if (whichSubset(data.instance(x)) == subOfInterest) {
                data.delete(x--);
            }
        }
    }

    /**
     * builds m_graftdistro using the passed data
     *
     * @param data the instances to use when creating the distribution
     */
    public void buildClassifier(Instances data) throws Exception {

        // distribution for the graft, not counting cases in atbop, only orig leaf
        m_graftdistro = new Distribution(2, data.numClasses());

        // which subset are we looking at for the graft?
        int subset = subsetOfInterest(); // this is the subset for m_leaf

        double thisNodeCount = 0;
        double knownCases = 0;
        boolean allKnown = true;
        // populate distribution
        for (int x = 0; x < data.numInstances(); x++) {
            Instance instance = data.instance(x);
            if (instance.isMissing(m_attIndex)) {
                allKnown = false;
                continue;
            }
            knownCases += instance.weight();
            int subst = whichSubset(instance);
            if (subst == -1)
                continue;
            m_graftdistro.add(subst, instance);
            if (subst == subset) { // instance belongs at m_leaf
                thisNodeCount += instance.weight();
            }
        }
        double factor = (knownCases == 0) ? (1.0 / (double) 2.0) : (thisNodeCount / knownCases);
        if (!allKnown) {
            for (int x = 0; x < data.numInstances(); x++) {
                if (data.instance(x).isMissing(m_attIndex)) {
                    Instance instance = data.instance(x);
                    int subst = whichSubset(instance);
                    if (subst == -1)
                        continue;
                    instance.setWeight(instance.weight() * factor);
                    m_graftdistro.add(subst, instance);
                }
            }
        }

        // if there are no cases at the leaf, make sure the desired
        // class is chosen, by setting counts to 0.01
        if (m_graftdistro.perBag(subset) == 0) {
            double[] counts = new double[data.numClasses()];
            counts[m_maxClass] = 0.01;
            m_graftdistro.add(subset, counts);
        }
        if (m_graftdistro.perBag((subset == 0) ? 1 : 0) == 0) {
            double[] counts = new double[data.numClasses()];
            counts[(int) m_otherLeafMaxClass] = 0.01;
            m_graftdistro.add((subset == 0) ? 1 : 0, counts);
        }
    }

    /**
     * @return the NoSplit object for the leaf pointed to by m_testType branch
     */
    public NoSplit getLeaf() {
        return new NoSplit(m_leafdistro);
    }

    /**
     * @return the NoSplit object for the leaf not pointed to by m_testType branch
     */
    public NoSplit getOtherLeaf() {

        // the bag (subset) that isn't pointed to by m_testType branch
        int bag = (subsetOfInterest() == 0) ? 1 : 0;

        double[][] counts = new double[1][m_graftdistro.numClasses()];
        double totals = 0;
        for (int c = 0; c < counts[0].length; c++) {
            counts[0][c] = m_graftdistro.perClassPerBag(bag, c);
            totals += counts[0][c];
        }
        // if empty, make sure proper class gets chosen
        if (totals == 0) {
            counts[0][m_otherLeafMaxClass] += 0.01;
        }
        return new NoSplit(new Distribution(counts));
    }

    /**
     * Prints label for subset index of instances (eg class).
     *
     * @param index the bag to dump label for
     * @param data to get attribute names and such
     * @return the label as a string
     * @exception Exception if something goes wrong
     */
    public final String dumpLabelG(int index, Instances data) throws Exception {

        StringBuffer text;

        text = new StringBuffer();
        text.append(((Instances) data).classAttribute()
                .value((index == subsetOfInterest()) ? m_maxClass : m_otherLeafMaxClass));
        text.append(" (" + Utils.roundDouble(m_graftdistro.perBag(index), 1));
        if (Utils.gr(m_graftdistro.numIncorrect(index), 0))
            text.append("/" + Utils.roundDouble(m_graftdistro.numIncorrect(index), 2));

        // show the graft values, only if this is subsetOfInterest()
        if (index == subsetOfInterest()) {
            text.append("|" + Utils.roundDouble(m_distribution.perBag(index), 2));
            if (Utils.gr(m_distribution.numIncorrect(index), 0))
                text.append("/" + Utils.roundDouble(m_distribution.numIncorrect(index), 2));
        }
        text.append(")");
        return text.toString();
    }

    /**
     * @return the subset that is specified by the test type
     */
    public int subsetOfInterest() {
        if (m_testType == 2)
            return 0;
        if (m_testType == 3)
            return 1;
        return m_testType;
    }

    /**
     * @return the number of positive cases in the subset of interest
     */
    public double positivesForSubsetOfInterest() {
        return (m_distribution.perClassPerBag(subsetOfInterest(), m_maxClass));
    }

    /**
     * @param subset the subset to get the positives for
     * @return the number of positive cases in the specified subset
     */
    public double positives(int subset) {
        return (m_distribution.perClassPerBag(subset, m_distribution.maxClass(subset)));
    }

    /**
     * @return the number of instances in the subset of interest
     */
    public double totalForSubsetOfInterest() {
        return (m_distribution.perBag(subsetOfInterest()));
    }

    /**
     * @param subset the index of the bag to get the total for
     * @return the number of instances in the subset
     */
    public double totalForSubset(int subset) {
        return (m_distribution.perBag(subset));
    }

    /**
     * Prints left side of condition satisfied by instances.
     *
     * @param data the data.
     */
    public String leftSide(Instances data) {
        return data.attribute(m_attIndex).name();
    }

    /**
     * @return the index of the attribute to split on
     */
    public int attribute() {
        return m_attIndex;
    }

    /**
     * Prints condition satisfied by instances in subset index.
     */
    public final String rightSide(int index, Instances data) {

        StringBuffer text;

        text = new StringBuffer();
        if (data.attribute(m_attIndex).isNominal())
            if (index == 0)
                text.append(" = " + data.attribute(m_attIndex).value((int) m_splitPoint));
            else
                text.append(" != " + data.attribute(m_attIndex).value((int) m_splitPoint));
        else if (index == 0)
            text.append(" <= " + Utils.doubleToString(m_splitPoint, 6));
        else
            text.append(" > " + Utils.doubleToString(m_splitPoint, 6));
        return text.toString();
    }

    /**
     * Returns a string containing java source code equivalent to the test
     * made at this node. The instance being tested is called "i".
     *
     * @param index index of the nominal value tested
     * @param data the data containing instance structure info
     * @return a value of type 'String'
     */
    public final String sourceExpression(int index, Instances data) {

        StringBuffer expr = null;
        if (index < 0) {
            return "i[" + m_attIndex + "] == null";
        }
        if (data.attribute(m_attIndex).isNominal()) {
            if (index == 0)
                expr = new StringBuffer("i[");
            else
                expr = new StringBuffer("!i[");
            expr.append(m_attIndex).append("]");
            expr.append(".equals(\"").append(data.attribute(m_attIndex).value((int) m_splitPoint)).append("\")");
        } else {
            expr = new StringBuffer("((Double) i[");
            expr.append(m_attIndex).append("])");
            if (index == 0) {
                expr.append(".doubleValue() <= ").append(m_splitPoint);
            } else {
                expr.append(".doubleValue() > ").append(m_splitPoint);
            }
        }
        return expr.toString();
    }

    /**
     * @param instance the instance to produce the weights for
     * @return a double array of weights, null if only belongs to one subset
     */
    public double[] weights(Instance instance) {

        double[] weights;
        int i;

        if (instance.isMissing(m_attIndex)) {
            weights = new double[m_numSubsets];
            for (i = 0; i < m_numSubsets; i++) {
                weights[i] = m_graftdistro.perBag(i) / m_graftdistro.total();
            }
            return weights;
        } else {
            return null;
        }
    }

    /**
     * @param instance the instance for which to determine the subset
     * @return an int indicating the subset this instance belongs to
     */
    public int whichSubset(Instance instance) {

        if (instance.isMissing(m_attIndex))
            return -1;

        if (instance.attribute(m_attIndex).isNominal()) {
            // in the case of nominal, m_splitPoint is the = value, all else is !=
            if (instance.value(m_attIndex) == m_splitPoint)
                return 0;
            else
                return 1;
        } else {
            if (Utils.smOrEq(instance.value(m_attIndex), m_splitPoint))
                return 0;
            else
                return 1;
        }
    }

    /**
     * @return the value of the split point
     */
    public double splitPoint() {
        return m_splitPoint;
    }

    /**
     * @return the dominate class for the subset of interest
     */
    public int maxClassForSubsetOfInterest() {
        return m_maxClass;
    }

    /**
     * @return the laplace value for maxClass of subset of interest
     */
    public double laplaceForSubsetOfInterest() {
        return m_laplace;
    }

    /**
     * returns the test type
     * @return value of testtype
     */
    public int testType() {
        return m_testType;
    }

    /**
     * method needed for sorting a collection of GraftSplits by laplace value
     * @param g the graft split to compare to this one
     * @return -1, 0, or 1 if this GraftSplit laplace is <, = or > than that of g
     */
    public int compareTo(Object g) {

        if (m_laplace > ((GraftSplit) g).laplaceForSubsetOfInterest())
            return 1;
        if (m_laplace < ((GraftSplit) g).laplaceForSubsetOfInterest())
            return -1;
        return 0;
    }

    /**
     * returns the probability for instance for the specified class
     * @param classIndex the index of the class
     * @param instance the instance to get the probability for
     * @param theSubset the subset
     */
    public final double classProb(int classIndex, Instance instance, int theSubset) throws Exception {

        if (theSubset <= -1) {
            double[] weights = weights(instance);
            if (weights == null) {
                return m_distribution.prob(classIndex);
            } else {
                double prob = 0;
                for (int i = 0; i < weights.length; i++) {
                    prob += weights[i] * m_distribution.prob(classIndex, i);
                }
                return prob;
            }
        } else {
            if (Utils.gr(m_distribution.perBag(theSubset), 0)) {
                return m_distribution.prob(classIndex, theSubset);
            } else {
                return m_distribution.prob(classIndex);
            }
        }
    }

    /**
     * method for returning information about this GraftSplit
     * @param data instances for determining names of attributes and values
     * @return a string showing this GraftSplit's information
     */
    public String toString(Instances data) {

        String theTest;
        if (m_testType == 0)
            theTest = " <= ";
        else if (m_testType == 1)
            theTest = " > ";
        else if (m_testType == 2)
            theTest = " = ";
        else
            theTest = " != ";

        if (data.attribute(m_attIndex).isNominal())
            theTest += data.attribute(m_attIndex).value((int) m_splitPoint);
        else
            theTest += Double.toString(m_splitPoint);

        return data.attribute(m_attIndex).name() + theTest + " (" + Double.toString(m_laplace) + ") --> "
                + data.attribute(data.classIndex()).value(m_maxClass);
    }

    /**
     * Returns the revision string.
     * 
     * @return      the revision
     */
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 1.2 $");
    }
}