weka.classifiers.trees.SimpleCart.java Source code


Introduction

Here is the source code for weka.classifiers.trees.SimpleCart.java
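
A minimal usage sketch (not part of the original file) is shown below. It assumes
a standard Weka classpath and a hypothetical ARFF file iris.arff whose last
attribute is the class:

import weka.classifiers.trees.SimpleCart;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

public class SimpleCartDemo {
    public static void main(String[] args) throws Exception {
        // load a dataset and declare the last attribute as the class
        Instances data = DataSource.read("iris.arff");
        data.setClassIndex(data.numAttributes() - 1);

        // -M: minimal number of instances at leaves, -N: pruning folds
        SimpleCart cart = new SimpleCart();
        cart.setOptions(weka.core.Utils.splitOptions("-M 2 -N 5"));

        // grow and prune the tree, then print it
        cart.buildClassifier(data);
        System.out.println(cart);

        // predicted class index of the first training instance
        System.out.println(cart.classifyInstance(data.instance(0)));
    }
}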

Source

/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

/*
 * SimpleCart.java
 * Copyright (C) 2007 Haijian Shi
 *
 */

package weka.classifiers.trees;

import java.util.Arrays;
import java.util.Collections;
import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;

import weka.classifiers.Evaluation;
import weka.classifiers.RandomizableClassifier;
import weka.core.AdditionalMeasureProducer;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.Capabilities.Capability;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.RevisionUtils;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformation.Field;
import weka.core.TechnicalInformation.Type;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.core.matrix.Matrix;

/**
 * <!-- globalinfo-start --> Class implementing minimal cost-complexity pruning.<br/>
 * Note: when dealing with missing values, the "fractional instances" method is
 * used instead of the surrogate split method.<br/>
 * <br/>
 * For more information, see:<br/>
 * <br/>
 * Leo Breiman, Jerome H. Friedman, Richard A. Olshen, Charles J. Stone (1984).
 * Classification and Regression Trees. Wadsworth International Group, Belmont,
 * California.
 * <p/>
 * <!-- globalinfo-end -->
 * 
 * <!-- technical-bibtex-start --> BibTeX:
 * 
 * <pre>
 * &#64;book{Breiman1984,
 *    address = {Belmont, California},
 *    author = {Leo Breiman and Jerome H. Friedman and Richard A. Olshen and Charles J. Stone},
 *    publisher = {Wadsworth International Group},
 *    title = {Classification and Regression Trees},
 *    year = {1984}
 * }
 * </pre>
 * <p/>
 * <!-- technical-bibtex-end -->
 * 
 * <!-- options-start --> Valid options are:
 * <p/>
 * 
 * <pre>
 * -S &lt;num&gt;
 *  Random number seed.
 *  (default 1)
 * </pre>
 * 
 * <pre>
 * -D
 *  If set, classifier is run in debug mode and
 *  may output additional info to the console
 * </pre>
 * 
 * <pre>
 * -M &lt;min no&gt;
 *  The minimal number of instances at the terminal nodes.
 *  (default 2)
 * </pre>
 * 
 * <pre>
 * -N &lt;num folds&gt;
 *  The number of folds used in the minimal cost-complexity pruning.
 *  (default 5)
 * </pre>
 * 
 * <pre>
 * -U
 *  Don't use the minimal cost-complexity pruning.
 *  (default yes).
 * </pre>
 * 
 * <pre>
 * -H
 *  Don't use the heuristic method for binary split.
 *  (default true).
 * </pre>
 * 
 * <pre>
 * -A
 *  Use 1 SE rule to make pruning decision.
 *  (default no).
 * </pre>
 * 
 * <pre>
 * -C
 *  Percentage of training data size (0-1].
 *  (default 1).
 * </pre>
 * 
 * <!-- options-end -->
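 * 
 * A hypothetical command-line example (assumes weka.jar on the classpath and a
 * dataset data.arff; not part of the original documentation):
 * 
 * <pre>
 * java weka.classifiers.trees.SimpleCart -t data.arff -M 2 -N 5
 * </pre>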
 * 
 * @author Haijian Shi (hs69@cs.waikato.ac.nz)
 * @version $Revision$
 */
public class SimpleCart extends RandomizableClassifier
        implements AdditionalMeasureProducer, TechnicalInformationHandler {

    /** For serialization. */
    private static final long serialVersionUID = 4154189200352566053L;

    /** Training data. */
    protected Instances m_train;

    /** Successor nodes. */
    protected SimpleCart[] m_Successors;

    /** Attribute used to split data. */
    protected Attribute m_Attribute;

    /** Split point for a numeric attribute. */
    protected double m_SplitValue;

    /** Split subset used to split data for nominal attributes. */
    protected String m_SplitString;

    /** Class value if the node is a leaf. */
    protected double m_ClassValue;

    /** Class attribute of the data. */
    protected Attribute m_ClassAttribute;

    /** Minimum number of instances at the terminal nodes. */
    protected double m_minNumObj = 2;

    /** Number of folds for minimal cost-complexity pruning. */
    protected int m_numFoldsPruning = 5;

    /** Alpha-value (for pruning) at the node. */
    protected double m_Alpha;

    /**
     * Number of training examples misclassified if the subtree rooted at this
     * node were replaced by a leaf (used for pruning).
     */
    protected double m_numIncorrectModel;

    /**
     * Number of training examples misclassified by the subtree rooted at this
     * node (used for pruning).
     */
    protected double m_numIncorrectTree;

    /** Indicate if the node is a leaf node. */
    protected boolean m_isLeaf;

    /** Whether to use minimal cost-complexity pruning. */
    protected boolean m_Prune = true;

    /** Total number of instances used to build the classifier. */
    protected int m_totalTrainInstances;

    /** Proportion for each branch. */
    protected double[] m_Props;

    /** Class probabilities. */
    protected double[] m_ClassProbs = null;

    /**
     * Class distribution at the leaf node (or temporary leaf node during
     * minimal cost-complexity pruning).
     */
    protected double[] m_Distribution;

    /**
     * Whether to use heuristic search for nominal attributes in multi-class
     * problems (default true).
     */
    protected boolean m_Heuristic = true;

    /** Whether to use the 1 SE rule to select the final tree. */
    protected boolean m_UseOneSE = false;

    /** Proportion of the training data to use (0,1]. */
    protected double m_SizePer = 1;

    /**
     * Return a description suitable for displaying in the explorer/experimenter.
     * 
     * @return a description suitable for displaying in the explorer/experimenter
     */
    public String globalInfo() {
        return "Class implementing minimal cost-complexity pruning.\n"
                + "Note when dealing with missing values, use \"fractional "
                + "instances\" method instead of surrogate split method.\n\n" + "For more information, see:\n\n"
                + getTechnicalInformation().toString();
    }

    /**
     * Returns an instance of a TechnicalInformation object, containing detailed
     * information about the technical background of this class, e.g., paper
     * reference or book this class is based on.
     * 
     * @return the technical information about this class
     */
    @Override
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation result;

        result = new TechnicalInformation(Type.BOOK);
        result.setValue(Field.AUTHOR,
                "Leo Breiman and Jerome H. Friedman and Richard A. Olshen and Charles J. Stone");
        result.setValue(Field.YEAR, "1984");
        result.setValue(Field.TITLE, "Classification and Regression Trees");
        result.setValue(Field.PUBLISHER, "Wadsworth International Group");
        result.setValue(Field.ADDRESS, "Belmont, California");

        return result;
    }

    /**
     * Returns default capabilities of the classifier.
     * 
     * @return the capabilities of this classifier
     */
    @Override
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAll();

        // attributes
        result.enable(Capability.NOMINAL_ATTRIBUTES);
        result.enable(Capability.NUMERIC_ATTRIBUTES);
        result.enable(Capability.MISSING_VALUES);

        // class
        result.enable(Capability.NOMINAL_CLASS);

        return result;
    }

    /**
     * Build the classifier.
     * 
     * @param data the training instances
     * @throws Exception if something goes wrong
     */
    @Override
    public void buildClassifier(Instances data) throws Exception {

        getCapabilities().testWithFail(data);
        data = new Instances(data);
        data.deleteWithMissingClass();

        // unpruned CART decision tree
        if (!m_Prune) {

            // calculate sorted indices and weights, and compute initial class counts.
            int[][] sortedIndices = new int[data.numAttributes()][0];
            double[][] weights = new double[data.numAttributes()][0];
            double[] classProbs = new double[data.numClasses()];
            double totalWeight = computeSortedInfo(data, sortedIndices, weights, classProbs);

            makeTree(data, data.numInstances(), sortedIndices, weights, classProbs, totalWeight, m_minNumObj,
                    m_Heuristic);
            return;
        }

        Random random = new Random(m_Seed);
        Instances cvData = new Instances(data);
        cvData.randomize(random);
        cvData = new Instances(cvData, 0, (int) (cvData.numInstances() * m_SizePer) - 1);
        cvData.stratify(m_numFoldsPruning);
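
        // Minimal cost-complexity pruning via cross-validation: grow a tree on
        // each training fold, record its sequence of alpha-values and the
        // corresponding error rates on the held-out fold, then choose the
        // alpha that minimizes the cross-validated error.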

        double[][] alphas = new double[m_numFoldsPruning][];
        double[][] errors = new double[m_numFoldsPruning][];

        // calculate errors and alphas for each fold
        for (int i = 0; i < m_numFoldsPruning; i++) {

            // for every fold, grow tree on training set and fix error on test set.
            Instances train = cvData.trainCV(m_numFoldsPruning, i);
            Instances test = cvData.testCV(m_numFoldsPruning, i);

            // calculate sorted indices and weights, and compute initial class counts
            // for each fold
            int[][] sortedIndices = new int[train.numAttributes()][0];
            double[][] weights = new double[train.numAttributes()][0];
            double[] classProbs = new double[train.numClasses()];
            double totalWeight = computeSortedInfo(train, sortedIndices, weights, classProbs);

            makeTree(train, train.numInstances(), sortedIndices, weights, classProbs, totalWeight, m_minNumObj,
                    m_Heuristic);

            int numNodes = numInnerNodes();
            alphas[i] = new double[numNodes + 2];
            errors[i] = new double[numNodes + 2];

            // prune back and log alpha-values and errors on test set
            prune(alphas[i], errors[i], test);
        }

        // calculate sorted indices and weights, and compute initial class counts on
        // all training instances
        int[][] sortedIndices = new int[data.numAttributes()][0];
        double[][] weights = new double[data.numAttributes()][0];
        double[] classProbs = new double[data.numClasses()];
        double totalWeight = computeSortedInfo(data, sortedIndices, weights, classProbs);

        // build tree using all the data
        makeTree(data, data.numInstances(), sortedIndices, weights, classProbs, totalWeight, m_minNumObj,
                m_Heuristic);

        int numNodes = numInnerNodes();

        double[] treeAlphas = new double[numNodes + 2];

        // prune back and log alpha-values
        int iterations = prune(treeAlphas, null, null);

        double[] treeErrors = new double[numNodes + 2];

        // for each pruned subtree, find the cross-validated error
        for (int i = 0; i <= iterations; i++) {
            // compute midpoint alphas
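            // (geometric mean of two successive alpha-values, the standard
            // choice in minimal cost-complexity pruning)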
            double alpha = Math.sqrt(treeAlphas[i] * treeAlphas[i + 1]);
            double error = 0;
            for (int k = 0; k < m_numFoldsPruning; k++) {
                int l = 0;
                while (alphas[k][l] <= alpha) {
                    l++;
                }
                error += errors[k][l - 1];
            }
            treeErrors[i] = error / m_numFoldsPruning;
        }

        // find best alpha
        int best = -1;
        double bestError = Double.MAX_VALUE;
        for (int i = iterations; i >= 0; i--) {
            if (treeErrors[i] < bestError) {
                bestError = treeErrors[i];
                best = i;
            }
        }

        // 1 SE rule to choose expansion
        if (m_UseOneSE) {
            double oneSE = Math.sqrt(bestError * (1 - bestError) / (data.numInstances()));
            for (int i = iterations; i >= 0; i--) {
                if (treeErrors[i] <= bestError + oneSE) {
                    best = i;
                    break;
                }
            }
        }

        double bestAlpha = Math.sqrt(treeAlphas[best] * treeAlphas[best + 1]);

        // "unprune" final tree (faster than regrowing it)
        unprune();
        prune(bestAlpha);
    }

    /**
     * Make binary decision tree recursively.
     * 
     * @param data the training instances
     * @param totalInstances total number of instances
     * @param sortedIndices sorted indices of the instances
     * @param weights weights of the instances
     * @param classProbs class probabilities
     * @param totalWeight total weight of instances
     * @param minNumObj minimal number of instances at leaf nodes
     * @param useHeuristic whether to use heuristic search for nominal
     *          attributes in multi-class problems
     * @throws Exception if something goes wrong
     */
    protected void makeTree(Instances data, int totalInstances, int[][] sortedIndices, double[][] weights,
            double[] classProbs, double totalWeight, double minNumObj, boolean useHeuristic) throws Exception {

        // if no instances have reached this node (normally won't happen)
        if (totalWeight == 0) {
            m_Attribute = null;
            m_ClassValue = Utils.missingValue();
            m_Distribution = new double[data.numClasses()];
            return;
        }

        m_totalTrainInstances = totalInstances;
        m_isLeaf = true;
        m_Successors = null;

        m_ClassProbs = new double[classProbs.length];
        m_Distribution = new double[classProbs.length];
        System.arraycopy(classProbs, 0, m_ClassProbs, 0, classProbs.length);
        System.arraycopy(classProbs, 0, m_Distribution, 0, classProbs.length);
        if (Utils.sum(m_ClassProbs) != 0) {
            Utils.normalize(m_ClassProbs);
        }

        // Compute class distributions and value of splitting
        // criterion for each attribute
        double[][][] dists = new double[data.numAttributes()][0][0];
        double[][] props = new double[data.numAttributes()][0];
        double[][] totalSubsetWeights = new double[data.numAttributes()][2];
        double[] splits = new double[data.numAttributes()];
        String[] splitString = new String[data.numAttributes()];
        double[] giniGains = new double[data.numAttributes()];

        // for each attribute find split information
        for (int i = 0; i < data.numAttributes(); i++) {
            Attribute att = data.attribute(i);
            if (i == data.classIndex()) {
                continue;
            }
            if (att.isNumeric()) {
                // numeric attribute
                splits[i] = numericDistribution(props, dists, att, sortedIndices[i], weights[i], totalSubsetWeights,
                        giniGains, data);
            } else {
                // nominal attribute
                splitString[i] = nominalDistribution(props, dists, att, sortedIndices[i], weights[i],
                        totalSubsetWeights, giniGains, data, useHeuristic);
            }
        }

        // Find best attribute (split with maximum Gini gain)
        int attIndex = Utils.maxIndex(giniGains);
        m_Attribute = data.attribute(attIndex);

        m_train = new Instances(data, sortedIndices[attIndex].length);
        for (int i = 0; i < sortedIndices[attIndex].length; i++) {
            Instance inst = data.instance(sortedIndices[attIndex][i]);
            Instance instCopy = (Instance) inst.copy();
            instCopy.setWeight(weights[attIndex][i]);
            m_train.add(instCopy);
        }

        // Check whether the node contains too few instances, cannot be split,
        // or is pure. If so, make it a leaf.
        if (totalWeight < 2 * minNumObj || giniGains[attIndex] == 0 || props[attIndex][0] == 0
                || props[attIndex][1] == 0) {
            makeLeaf(data);
        }

        else {
            m_Props = props[attIndex];
            int[][][] subsetIndices = new int[2][data.numAttributes()][0];
            double[][][] subsetWeights = new double[2][data.numAttributes()][0];

            // numeric split
            if (m_Attribute.isNumeric()) {
                m_SplitValue = splits[attIndex];
            } else {
                m_SplitString = splitString[attIndex];
            }

            splitData(subsetIndices, subsetWeights, m_Attribute, m_SplitValue, m_SplitString, sortedIndices,
                    weights, data);

            // If the split results in a branch with fewer than the minimal
            // number of instances, make this node a leaf.
            if (subsetIndices[0][attIndex].length < minNumObj || subsetIndices[1][attIndex].length < minNumObj) {
                makeLeaf(data);
                return;
            }

            // Otherwise, split the node.
            m_isLeaf = false;
            m_Successors = new SimpleCart[2];
            for (int i = 0; i < 2; i++) {
                m_Successors[i] = new SimpleCart();
                m_Successors[i].makeTree(data, m_totalTrainInstances, subsetIndices[i], subsetWeights[i],
                        dists[attIndex][i], totalSubsetWeights[attIndex][i], minNumObj, useHeuristic);
            }
        }
    }

    /**
     * Prunes the original tree using the CART pruning scheme, given a
     * cost-complexity parameter alpha.
     * 
     * @param alpha the cost-complexity parameter
     * @throws Exception if something goes wrong
     */
    public void prune(double alpha) throws Exception {

        Vector<SimpleCart> nodeList;

        // determine the training error of the subtrees (both with and without
        // replacing each subtree by a leaf), and calculate alpha-values from them
        modelErrors();
        treeErrors();
        calculateAlphas();

        // get list of all inner nodes in the tree
        nodeList = getInnerNodes();

        boolean prune = (nodeList.size() > 0);
        double preAlpha = Double.MAX_VALUE;
        while (prune) {

            // select node with minimum alpha
            SimpleCart nodeToPrune = nodeToPrune(nodeList);

            // only prune while the node's alpha does not exceed the given alpha
            if (nodeToPrune.m_Alpha > alpha) {
                break;
            }

            nodeToPrune.makeLeaf(nodeToPrune.m_train);

            // normally would not happen
            if (nodeToPrune.m_Alpha == preAlpha) {
                nodeToPrune.makeLeaf(nodeToPrune.m_train);
                treeErrors();
                calculateAlphas();
                nodeList = getInnerNodes();
                prune = (nodeList.size() > 0);
                continue;
            }
            preAlpha = nodeToPrune.m_Alpha;

            // update tree errors and alphas
            treeErrors();
            calculateAlphas();

            nodeList = getInnerNodes();
            prune = (nodeList.size() > 0);
        }
    }

    /**
     * Method for performing one fold in the cross-validation of minimal
     * cost-complexity pruning. Generates a sequence of alpha-values with error
     * estimates for the corresponding (partially pruned) trees, given the test
     * set of that fold.
     * 
     * @param alphas array to hold the generated alpha-values
     * @param errors array to hold the corresponding error estimates
     * @param test test set of that fold (to obtain error estimates)
     * @return the number of pruning iterations
     * @throws Exception if something goes wrong
     */
    public int prune(double[] alphas, double[] errors, Instances test) throws Exception {

        Vector<SimpleCart> nodeList;

        // determine the training error of the subtrees (both with and without
        // replacing each subtree by a leaf), and calculate alpha-values from them
        modelErrors();
        treeErrors();
        calculateAlphas();

        // get list of all inner nodes in the tree
        nodeList = getInnerNodes();

        boolean prune = (nodeList.size() > 0);

        // alpha_0 is always zero (unpruned tree)
        alphas[0] = 0;

        Evaluation eval;

        // error of unpruned tree
        if (errors != null) {
            eval = new Evaluation(test);
            eval.evaluateModel(this, test);
            errors[0] = eval.errorRate();
        }

        int iteration = 0;
        double preAlpha = Double.MAX_VALUE;
        while (prune) {

            iteration++;

            // get node with minimum alpha
            SimpleCart nodeToPrune = nodeToPrune(nodeList);

            // do not set m_Successors to null; we want to be able to unprune later
            nodeToPrune.m_isLeaf = true;

            // normally would not happen
            if (nodeToPrune.m_Alpha == preAlpha) {
                iteration--;
                treeErrors();
                calculateAlphas();
                nodeList = getInnerNodes();
                prune = (nodeList.size() > 0);
                continue;
            }

            // get alpha-value of node
            alphas[iteration] = nodeToPrune.m_Alpha;

            // log error
            if (errors != null) {
                eval = new Evaluation(test);
                eval.evaluateModel(this, test);
                errors[iteration] = eval.errorRate();
            }
            preAlpha = nodeToPrune.m_Alpha;

            // update errors/alphas
            treeErrors();
            calculateAlphas();

            nodeList = getInnerNodes();
            prune = (nodeList.size() > 0);
        }

        // set the last alpha to 1 to mark the end of the sequence
        alphas[iteration + 1] = 1.0;
        return iteration;
    }

    /**
     * Method to "unprune" the CART tree. Sets all leaf-fields to false. Faster
     * than re-growing the tree because CART do not have to be fit again.
     */
    protected void unprune() {
        if (m_Successors != null) {
            m_isLeaf = false;
            for (SimpleCart m_Successor : m_Successors) {
                m_Successor.unprune();
            }
        }
    }

    /**
     * Compute distributions, proportions and total weights of two successor nodes
     * for a given numeric attribute.
     * 
     * @param props proportions of the two branches for each attribute
     * @param dists class distributions of the two branches for each attribute
     * @param att the numeric attribute to split on
     * @param sortedIndices sorted indices of instances for the attribute
     * @param weights weights of instances for the attribute
     * @param subsetWeights total weights of the two branches split based on the
     *          attribute
     * @param giniGains Gini gains for each attribute
     * @param data training instances
     * @return the best split point for the given numeric attribute
     * @throws Exception if something goes wrong
     */
    protected double numericDistribution(double[][] props, double[][][] dists, Attribute att, int[] sortedIndices,
            double[] weights, double[][] subsetWeights, double[] giniGains, Instances data) throws Exception {

        double splitPoint = Double.NaN;
        double[][] dist = null;
        int numClasses = data.numClasses();
        int i; // index separating instances with and without missing values

        double[][] currDist = new double[2][numClasses];
        dist = new double[2][numClasses];

        // Move all instances without missing values into second subset
        double[] parentDist = new double[numClasses];
        int missingStart = 0;
        for (int j = 0; j < sortedIndices.length; j++) {
            Instance inst = data.instance(sortedIndices[j]);
            if (!inst.isMissing(att)) {
                missingStart++;
                currDist[1][(int) inst.classValue()] += weights[j];
            }
            parentDist[(int) inst.classValue()] += weights[j];
        }
        System.arraycopy(currDist[1], 0, dist[1], 0, dist[1].length);

        // Try all possible split points
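        // Candidate split points are midpoints between successive distinct
        // attribute values; the weight of instances with missing values is
        // distributed across both branches in proportion to the branch weights
        // ("fractional instances").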
        double currSplit = data.instance(sortedIndices[0]).value(att);
        double currGiniGain;
        double bestGiniGain = -Double.MAX_VALUE;

        for (i = 0; i < sortedIndices.length; i++) {
            Instance inst = data.instance(sortedIndices[i]);
            if (inst.isMissing(att)) {
                break;
            }
            if (inst.value(att) > currSplit) {

                double[][] tempDist = new double[2][numClasses];
                for (int k = 0; k < 2; k++) {
                    // tempDist[k] = currDist[k];
                    System.arraycopy(currDist[k], 0, tempDist[k], 0, tempDist[k].length);
                }

                double[] tempProps = new double[2];
                for (int k = 0; k < 2; k++) {
                    tempProps[k] = Utils.sum(tempDist[k]);
                }

                if (Utils.sum(tempProps) != 0) {
                    Utils.normalize(tempProps);
                }

                // split missing values
                int index = missingStart;
                while (index < sortedIndices.length) {
                    Instance insta = data.instance(sortedIndices[index]);
                    for (int j = 0; j < 2; j++) {
                        tempDist[j][(int) insta.classValue()] += tempProps[j] * weights[index];
                    }
                    index++;
                }

                currGiniGain = computeGiniGain(parentDist, tempDist);

                if (currGiniGain > bestGiniGain) {
                    bestGiniGain = currGiniGain;

                    // clean split point
                    // splitPoint = Math.rint((inst.value(att) +
                    // currSplit)/2.0*100000)/100000.0;
                    splitPoint = (inst.value(att) + currSplit) / 2.0;

                    for (int j = 0; j < currDist.length; j++) {
                        System.arraycopy(tempDist[j], 0, dist[j], 0, dist[j].length);
                    }
                }
            }
            currSplit = inst.value(att);
            currDist[0][(int) inst.classValue()] += weights[i];
            currDist[1][(int) inst.classValue()] -= weights[i];
        }

        // Compute weights
        int attIndex = att.index();
        props[attIndex] = new double[2];
        for (int k = 0; k < 2; k++) {
            props[attIndex][k] = Utils.sum(dist[k]);
        }
        if (Utils.sum(props[attIndex]) != 0) {
            Utils.normalize(props[attIndex]);
        }

        // Compute subset weights
        subsetWeights[attIndex] = new double[2];
        for (int j = 0; j < 2; j++) {
            subsetWeights[attIndex][j] += Utils.sum(dist[j]);
        }

        // clean Gini gain
        // giniGains[attIndex] = Math.rint(bestGiniGain*10000000)/10000000.0;
        giniGains[attIndex] = bestGiniGain;
        dists[attIndex] = dist;

        return splitPoint;
    }

    /**
     * Compute distributions, proportions and total weights of two successor nodes
     * for a given nominal attribute.
     * 
     * @param props proportions of the two branches for each attribute
     * @param dists class distributions of the two branches for each attribute
     * @param att the nominal attribute to split on
     * @param sortedIndices sorted indices of instances for the attribute
     * @param weights weights of instances for the attribute
     * @param subsetWeights total weights of the two branches split based on the
     *          attribute
     * @param giniGains Gini gains for each attribute
     * @param data training instances
     * @param useHeuristic whether to use heuristic search
     * @return the best split subset (as a string) for the given nominal attribute
     * @throws Exception if something goes wrong
     */
    protected String nominalDistribution(double[][] props, double[][][] dists, Attribute att, int[] sortedIndices,
            double[] weights, double[][] subsetWeights, double[] giniGains, Instances data, boolean useHeuristic)
            throws Exception {

        String[] values = new String[att.numValues()];
        int numCat = values.length; // number of values of the attribute
        int numClasses = data.numClasses();

        String bestSplitString = "";
        double bestGiniGain = -Double.MAX_VALUE;

        // class frequency for each value
        int[] classFreq = new int[numCat];
        for (int j = 0; j < numCat; j++) {
            classFreq[j] = 0;
        }

        double[] parentDist = new double[numClasses];
        double[][] currDist = new double[2][numClasses];
        double[][] dist = new double[2][numClasses];
        int missingStart = 0;

        for (int i = 0; i < sortedIndices.length; i++) {
            Instance inst = data.instance(sortedIndices[i]);
            if (!inst.isMissing(att)) {
                missingStart++;
                classFreq[(int) inst.value(att)]++;
            }
            parentDist[(int) inst.classValue()] += weights[i];
        }

        // count the number of attribute values whose class frequency is not 0
        int nonEmpty = 0;
        for (int j = 0; j < numCat; j++) {
            if (classFreq[j] != 0) {
                nonEmpty++;
            }
        }

        // attribute values whose class frequency is not 0
        String[] nonEmptyValues = new String[nonEmpty];
        int nonEmptyIndex = 0;
        for (int j = 0; j < numCat; j++) {
            if (classFreq[j] != 0) {
                nonEmptyValues[nonEmptyIndex] = att.value(j);
                nonEmptyIndex++;
            }
        }

        // attribute values whose class frequency is 0
        int empty = numCat - nonEmpty;
        String[] emptyValues = new String[empty];
        int emptyIndex = 0;
        for (int j = 0; j < numCat; j++) {
            if (classFreq[j] == 0) {
                emptyValues[emptyIndex] = att.value(j);
                emptyIndex++;
            }
        }

        if (nonEmpty <= 1) {
            giniGains[att.index()] = 0;
            return "";
        }

        // for two-class problems
        if (data.numClasses() == 2) {

            // First, handle attribute values whose class frequency is not zero

            // probability of class 0 for each attribute value
            double[] pClass0 = new double[nonEmpty];
            // class distribution for each attribute value
            double[][] valDist = new double[nonEmpty][2];

            for (int j = 0; j < nonEmpty; j++) {
                for (int k = 0; k < 2; k++) {
                    valDist[j][k] = 0;
                }
            }

            for (int sortedIndice : sortedIndices) {
                Instance inst = data.instance(sortedIndice);
                if (inst.isMissing(att)) {
                    break;
                }

                for (int j = 0; j < nonEmpty; j++) {
                    if (att.value((int) inst.value(att)).compareTo(nonEmptyValues[j]) == 0) {
                        valDist[j][(int) inst.classValue()] += inst.weight();
                        break;
                    }
                }
            }

            for (int j = 0; j < nonEmpty; j++) {
                double distSum = Utils.sum(valDist[j]);
                if (distSum == 0) {
                    pClass0[j] = 0;
                } else {
                    pClass0[j] = valDist[j][0] / distSum;
                }
            }
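
            // For two-class problems, the best binary subset split is among
            // the splits of the values ordered by P(class 0 | value) (a result
            // from Breiman et al.), so only nonEmpty-1 ordered splits are
            // examined instead of all 2^(nonEmpty-1) subsets.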

            // sort attribute values by the probability of the first class
            String[] sortedValues = new String[nonEmpty];
            for (int j = 0; j < nonEmpty; j++) {
                sortedValues[j] = nonEmptyValues[Utils.minIndex(pClass0)];
                pClass0[Utils.minIndex(pClass0)] = Double.MAX_VALUE;
            }

            // Find the subset of attribute values that maximizes the Gini decrease

            // for the attribute values whose class frequency is not 0
            String tempStr = "";

            for (int j = 0; j < nonEmpty - 1; j++) {
                currDist = new double[2][numClasses];
                if (tempStr == "") {
                    tempStr = "(" + sortedValues[j] + ")";
                } else {
                    tempStr += "|" + "(" + sortedValues[j] + ")";
                }
                for (int i = 0; i < sortedIndices.length; i++) {
                    Instance inst = data.instance(sortedIndices[i]);
                    if (inst.isMissing(att)) {
                        break;
                    }

                    if (tempStr.indexOf("(" + att.value((int) inst.value(att)) + ")") != -1) {
                        currDist[0][(int) inst.classValue()] += weights[i];
                    } else {
                        currDist[1][(int) inst.classValue()] += weights[i];
                    }
                }

                double[][] tempDist = new double[2][numClasses];
                for (int kk = 0; kk < 2; kk++) {
                    tempDist[kk] = currDist[kk];
                }

                double[] tempProps = new double[2];
                for (int kk = 0; kk < 2; kk++) {
                    tempProps[kk] = Utils.sum(tempDist[kk]);
                }

                if (Utils.sum(tempProps) != 0) {
                    Utils.normalize(tempProps);
                }

                // split missing values
                int mstart = missingStart;
                while (mstart < sortedIndices.length) {
                    Instance insta = data.instance(sortedIndices[mstart]);
                    for (int jj = 0; jj < 2; jj++) {
                        tempDist[jj][(int) insta.classValue()] += tempProps[jj] * weights[mstart];
                    }
                    mstart++;
                }

                double currGiniGain = computeGiniGain(parentDist, tempDist);

                if (currGiniGain > bestGiniGain) {
                    bestGiniGain = currGiniGain;
                    bestSplitString = tempStr;
                    for (int jj = 0; jj < 2; jj++) {
                        // dist[jj] = new double[currDist[jj].length];
                        System.arraycopy(tempDist[jj], 0, dist[jj], 0, dist[jj].length);
                    }
                }
            }
        }

        // multi-class problems - exhaustive search
        else if (!useHeuristic || nonEmpty <= 4) {

            // First, handle attribute values whose class frequency is not zero
            for (int i = 0; i < (int) Math.pow(2, nonEmpty - 1); i++) {
                String tempStr = "";
                currDist = new double[2][numClasses];
                int mod;
                int bit10 = i;
                for (int j = nonEmpty - 1; j >= 0; j--) {
                    mod = bit10 % 2; // extract the next binary digit (decimal to binary)
                    if (mod == 1) {
                        if (tempStr == "") {
                            tempStr = "(" + nonEmptyValues[j] + ")";
                        } else {
                            tempStr += "|" + "(" + nonEmptyValues[j] + ")";
                        }
                    }
                    bit10 = bit10 / 2;
                }
                for (int j = 0; j < sortedIndices.length; j++) {
                    Instance inst = data.instance(sortedIndices[j]);
                    if (inst.isMissing(att)) {
                        break;
                    }

                    if (tempStr.indexOf("(" + att.value((int) inst.value(att)) + ")") != -1) {
                        currDist[0][(int) inst.classValue()] += weights[j];
                    } else {
                        currDist[1][(int) inst.classValue()] += weights[j];
                    }
                }

                double[][] tempDist = new double[2][numClasses];
                for (int k = 0; k < 2; k++) {
                    tempDist[k] = currDist[k];
                }

                double[] tempProps = new double[2];
                for (int k = 0; k < 2; k++) {
                    tempProps[k] = Utils.sum(tempDist[k]);
                }

                if (Utils.sum(tempProps) != 0) {
                    Utils.normalize(tempProps);
                }

                // split missing values
                int index = missingStart;
                while (index < sortedIndices.length) {
                    Instance insta = data.instance(sortedIndices[index]);
                    for (int j = 0; j < 2; j++) {
                        tempDist[j][(int) insta.classValue()] += tempProps[j] * weights[index];
                    }
                    index++;
                }

                double currGiniGain = computeGiniGain(parentDist, tempDist);

                if (currGiniGain > bestGiniGain) {
                    bestGiniGain = currGiniGain;
                    bestSplitString = tempStr;
                    for (int j = 0; j < 2; j++) {
                        // dist[jj] = new double[currDist[jj].length];
                        System.arraycopy(tempDist[j], 0, dist[j], 0, dist[j].length);
                    }
                }
            }
        }

        // heuristic search for multi-class problems
        else {
            // First, handle attribute values whose class frequency is not zero
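            // Heuristic (after Breiman et al.): project each value's class
            // probability vector onto the first principal component of their
            // covariance matrix, order the values by that score, and search
            // only the nonEmpty-1 ordered splits instead of all 2^(nonEmpty-1).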
            int n = nonEmpty;
            int k = data.numClasses(); // number of classes of the data
            double[][] P = new double[n][k]; // class probability matrix
            int[] numInstancesValue = new int[n]; // number of instances for an
                                                  // attribute value
            double[] meanClass = new double[k]; // vector of mean class probability
            int numInstances = data.numInstances(); // total number of instances

            // initialize the vector of mean class probability
            for (int j = 0; j < meanClass.length; j++) {
                meanClass[j] = 0;
            }

            for (int j = 0; j < numInstances; j++) {
                Instance inst = data.instance(j);
                int valueIndex = 0; // attribute value index in nonEmptyValues
                for (int i = 0; i < nonEmpty; i++) {
                    if (att.value((int) inst.value(att)).compareToIgnoreCase(nonEmptyValues[i]) == 0) {
                        valueIndex = i;
                        break;
                    }
                }
                P[valueIndex][(int) inst.classValue()]++;
                numInstancesValue[valueIndex]++;
                meanClass[(int) inst.classValue()]++;
            }

            // calculate the class probability matrix
            for (int i = 0; i < P.length; i++) {
                for (int j = 0; j < P[0].length; j++) {
                    if (numInstancesValue[i] == 0) {
                        P[i][j] = 0;
                    } else {
                        P[i][j] /= numInstancesValue[i];
                    }
                }
            }

            // calculate the vector of mean class probability
            for (int i = 0; i < meanClass.length; i++) {
                meanClass[i] /= numInstances;
            }

            // calculate the covariance matrix
            double[][] covariance = new double[k][k];
            for (int i1 = 0; i1 < k; i1++) {
                for (int i2 = 0; i2 < k; i2++) {
                    double element = 0;
                    for (int j = 0; j < n; j++) {
                        element += (P[j][i2] - meanClass[i2]) * (P[j][i1] - meanClass[i1]) * numInstancesValue[j];
                    }
                    covariance[i1][i2] = element;
                }
            }

            Matrix matrix = new Matrix(covariance);
            weka.core.matrix.EigenvalueDecomposition eigen = new weka.core.matrix.EigenvalueDecomposition(matrix);
            double[] eigenValues = eigen.getRealEigenvalues();

            // find index of the largest eigenvalue
            int index = 0;
            double largest = eigenValues[0];
            for (int i = 1; i < eigenValues.length; i++) {
                if (eigenValues[i] > largest) {
                    index = i;
                    largest = eigenValues[i];
                }
            }

            // calculate the first principal component
            double[] FPC = new double[k];
            Matrix eigenVector = eigen.getV();
            double[][] vectorArray = eigenVector.getArray();
            for (int i = 0; i < FPC.length; i++) {
                FPC[i] = vectorArray[i][index];
            }

            // calculate the first principal component scores
            // System.out.println("the first principle component scores: ");
            double[] Sa = new double[n];
            for (int i = 0; i < Sa.length; i++) {
                Sa[i] = 0;
                for (int j = 0; j < k; j++) {
                    Sa[i] += FPC[j] * P[i][j];
                }
            }

            // sort attribute values by their first-principal-component scores Sa
            double[] pCopy = new double[n];
            System.arraycopy(Sa, 0, pCopy, 0, n);
            String[] sortedValues = new String[n];
            Arrays.sort(Sa);

            for (int j = 0; j < n; j++) {
                sortedValues[j] = nonEmptyValues[Utils.minIndex(pCopy)];
                pCopy[Utils.minIndex(pCopy)] = Double.MAX_VALUE;
            }

            // for the attribute values whose class frequency is not 0
            String tempStr = "";

            for (int j = 0; j < nonEmpty - 1; j++) {
                currDist = new double[2][numClasses];
                if (tempStr == "") {
                    tempStr = "(" + sortedValues[j] + ")";
                } else {
                    tempStr += "|" + "(" + sortedValues[j] + ")";
                }
                for (int i = 0; i < sortedIndices.length; i++) {
                    Instance inst = data.instance(sortedIndices[i]);
                    if (inst.isMissing(att)) {
                        break;
                    }

                    if (tempStr.indexOf("(" + att.value((int) inst.value(att)) + ")") != -1) {
                        currDist[0][(int) inst.classValue()] += weights[i];
                    } else {
                        currDist[1][(int) inst.classValue()] += weights[i];
                    }
                }

                double[][] tempDist = new double[2][numClasses];
                for (int kk = 0; kk < 2; kk++) {
                    tempDist[kk] = currDist[kk];
                }

                double[] tempProps = new double[2];
                for (int kk = 0; kk < 2; kk++) {
                    tempProps[kk] = Utils.sum(tempDist[kk]);
                }

                if (Utils.sum(tempProps) != 0) {
                    Utils.normalize(tempProps);
                }

                // split missing values
                int mstart = missingStart;
                while (mstart < sortedIndices.length) {
                    Instance insta = data.instance(sortedIndices[mstart]);
                    for (int jj = 0; jj < 2; jj++) {
                        tempDist[jj][(int) insta.classValue()] += tempProps[jj] * weights[mstart];
                    }
                    mstart++;
                }

                double currGiniGain = computeGiniGain(parentDist, tempDist);

                if (currGiniGain > bestGiniGain) {
                    bestGiniGain = currGiniGain;
                    bestSplitString = tempStr;
                    for (int jj = 0; jj < 2; jj++) {
                        // dist[jj] = new double[currDist[jj].length];
                        System.arraycopy(tempDist[jj], 0, dist[jj], 0, dist[jj].length);
                    }
                }
            }
        }

        // Compute weights
        int attIndex = att.index();
        props[attIndex] = new double[2];
        for (int k = 0; k < 2; k++) {
            props[attIndex][k] = Utils.sum(dist[k]);
        }

        if (!(Utils.sum(props[attIndex]) > 0)) {
            for (int k = 0; k < props[attIndex].length; k++) {
                props[attIndex][k] = 1.0 / props[attIndex].length;
            }
        } else {
            Utils.normalize(props[attIndex]);
        }

        // Compute subset weights
        subsetWeights[attIndex] = new double[2];
        for (int j = 0; j < 2; j++) {
            subsetWeights[attIndex][j] += Utils.sum(dist[j]);
        }

        // Then assign the attribute values whose class frequency is 0 to the
        // more frequent branch
        for (int j = 0; j < empty; j++) {
            if (props[attIndex][0] >= props[attIndex][1]) {
                if (bestSplitString == "") {
                    bestSplitString = "(" + emptyValues[j] + ")";
                } else {
                    bestSplitString += "|" + "(" + emptyValues[j] + ")";
                }
            }
        }

        // clean Gini gain for the attribute
        // giniGains[attIndex] = Math.rint(bestGiniGain*10000000)/10000000.0;
        giniGains[attIndex] = bestGiniGain;

        dists[attIndex] = dist;
        return bestSplitString;
    }

    /**
     * Split data into two subsets and store sorted indices and weights for two
     * successor nodes.
     * 
     * @param subsetIndices sorted indices of instances for each attribute for
     *          the two successor nodes
     * @param subsetWeights weights of instances for each attribute for the two
     *          successor nodes
     * @param att the attribute the split is based on
     * @param splitPoint the split point, if att is numeric
     * @param splitStr the split subset, if att is nominal
     * @param sortedIndices sorted indices of the instances to be split
     * @param weights weights of the instances to be split
     * @param data training data
     * @throws Exception if something goes wrong
     */
    protected void splitData(int[][][] subsetIndices, double[][][] subsetWeights, Attribute att, double splitPoint,
            String splitStr, int[][] sortedIndices, double[][] weights, Instances data) throws Exception {

        int j;
        // For each attribute
        for (int i = 0; i < data.numAttributes(); i++) {
            if (i == data.classIndex()) {
                continue;
            }
            int[] num = new int[2];
            for (int k = 0; k < 2; k++) {
                subsetIndices[k][i] = new int[sortedIndices[i].length];
                subsetWeights[k][i] = new double[weights[i].length];
            }

            for (j = 0; j < sortedIndices[i].length; j++) {
                Instance inst = data.instance(sortedIndices[i][j]);
                if (inst.isMissing(att)) {
                    // Split instance up
                    for (int k = 0; k < 2; k++) {
                        if (m_Props[k] > 0) {
                            subsetIndices[k][i][num[k]] = sortedIndices[i][j];
                            subsetWeights[k][i][num[k]] = m_Props[k] * weights[i][j];
                            num[k]++;
                        }
                    }
                } else {
                    int subset;
                    if (att.isNumeric()) {
                        subset = (inst.value(att) < splitPoint) ? 0 : 1;
                    } else { // nominal attribute
                        if (splitStr.indexOf("(" + att.value((int) inst.value(att.index())) + ")") != -1) {
                            subset = 0;
                        } else {
                            subset = 1;
                        }
                    }
                    subsetIndices[subset][i][num[subset]] = sortedIndices[i][j];
                    subsetWeights[subset][i][num[subset]] = weights[i][j];
                    num[subset]++;
                }
            }

            // Trim arrays
            for (int k = 0; k < 2; k++) {
                int[] copy = new int[num[k]];
                System.arraycopy(subsetIndices[k][i], 0, copy, 0, num[k]);
                subsetIndices[k][i] = copy;
                double[] copyWeights = new double[num[k]];
                System.arraycopy(subsetWeights[k][i], 0, copyWeights, 0, num[k]);
                subsetWeights[k][i] = copyWeights;
            }
        }
    }

    /**
     * Updates the numIncorrectModel field for all nodes: the number of training
     * errors if the subtree rooted at a node were replaced by a leaf. This is
     * needed for calculating the alpha-values.
     * 
     * @throws Exception if something goes wrong
     */
    public void modelErrors() throws Exception {
        Evaluation eval = new Evaluation(m_train);

        if (!m_isLeaf) {
            m_isLeaf = true; // temporarily make leaf

            // calculate distribution for evaluation
            eval.evaluateModel(this, m_train);
            m_numIncorrectModel = eval.incorrect();

            m_isLeaf = false;

            for (SimpleCart m_Successor : m_Successors) {
                m_Successor.modelErrors();
            }

        } else {
            eval.evaluateModel(this, m_train);
            m_numIncorrectModel = eval.incorrect();
        }
    }

    /**
     * Updates the numIncorrectTree field for all nodes. This is needed for
     * calculating the alpha-values.
     * 
     * @throws Exception if something goes wrong
     */
    public void treeErrors() throws Exception {
        if (m_isLeaf) {
            m_numIncorrectTree = m_numIncorrectModel;
        } else {
            m_numIncorrectTree = 0;
            for (SimpleCart m_Successor : m_Successors) {
                m_Successor.treeErrors();
                m_numIncorrectTree += m_Successor.m_numIncorrectTree;
            }
        }
    }

    /**
     * Updates the alpha field for all nodes.
     * 
     * @throws Exception if something goes wrong
     */
    public void calculateAlphas() throws Exception {

        if (!m_isLeaf) {
            double errorDiff = m_numIncorrectModel - m_numIncorrectTree;
            if (errorDiff <= 0) {
                // split increases training error (should not normally happen).
                // prune it instantly.
                makeLeaf(m_train);
                m_Alpha = Double.MAX_VALUE;
            } else {
                // compute alpha
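                // alpha = (errors as a leaf - errors of the subtree)
                //         / (total training instances * (number of leaves - 1)),
                // rounded to 10 decimal places below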
                errorDiff /= m_totalTrainInstances;
                m_Alpha = errorDiff / (numLeaves() - 1);
                long alphaLong = Math.round(m_Alpha * Math.pow(10, 10));
                m_Alpha = alphaLong / Math.pow(10, 10);
                for (SimpleCart m_Successor : m_Successors) {
                    m_Successor.calculateAlphas();
                }
            }
        } else {
            // alpha = infinite for leaves (do not want to prune)
            m_Alpha = Double.MAX_VALUE;
        }
    }

    /**
     * Find the node with minimal alpha value. If two nodes have the same alpha,
     * choose the one with more leaf nodes.
     * 
     * @param nodeList list of inner nodes
     * @return the node to be pruned
     */
    protected SimpleCart nodeToPrune(Vector<SimpleCart> nodeList) {
        if (nodeList.size() == 0) {
            return null;
        }
        if (nodeList.size() == 1) {
            return nodeList.elementAt(0);
        }
        SimpleCart returnNode = nodeList.elementAt(0);
        double baseAlpha = returnNode.m_Alpha;
        for (int i = 1; i < nodeList.size(); i++) {
            SimpleCart node = nodeList.elementAt(i);
            if (node.m_Alpha < baseAlpha) {
                baseAlpha = node.m_Alpha;
                returnNode = node;
            } else if (node.m_Alpha == baseAlpha) { // break tie
                if (node.numLeaves() > returnNode.numLeaves()) {
                    returnNode = node;
                }
            }
        }
        return returnNode;
    }

    /**
     * Compute sorted indices, weights and class probabilities for a given
     * dataset. Return total weights of the data at the node.
     * 
     * @param data training data
     * @param sortedIndices sorted indices of instances at the node
     * @param weights weights of instances at the node
     * @param classProbs class probabilities at the node
     * @return total weights of instances at the node
     * @throws Exception if something goes wrong
     */
    protected double computeSortedInfo(Instances data, int[][] sortedIndices, double[][] weights,
            double[] classProbs) throws Exception {

        // Create array of sorted indices and weights
        double[] vals = new double[data.numInstances()];
        for (int j = 0; j < data.numAttributes(); j++) {
            if (j == data.classIndex()) {
                continue;
            }
            weights[j] = new double[data.numInstances()];

            if (data.attribute(j).isNominal()) {

                // Handling nominal attributes. Putting indices of
                // instances with missing values at the end.
                sortedIndices[j] = new int[data.numInstances()];
                int count = 0;
                for (int i = 0; i < data.numInstances(); i++) {
                    Instance inst = data.instance(i);
                    if (!inst.isMissing(j)) {
                        sortedIndices[j][count] = i;
                        weights[j][count] = inst.weight();
                        count++;
                    }
                }
                for (int i = 0; i < data.numInstances(); i++) {
                    Instance inst = data.instance(i);
                    if (inst.isMissing(j)) {
                        sortedIndices[j][count] = i;
                        weights[j][count] = inst.weight();
                        count++;
                    }
                }
            } else {

                // Sorted indices are computed for numeric attributes;
                // instances with missing values are placed at the end
                for (int i = 0; i < data.numInstances(); i++) {
                    Instance inst = data.instance(i);
                    vals[i] = inst.value(j);
                }
                sortedIndices[j] = Utils.sort(vals);
                for (int i = 0; i < data.numInstances(); i++) {
                    weights[j][i] = data.instance(sortedIndices[j][i]).weight();
                }
            }
        }

        // Compute initial class counts
        double totalWeight = 0;
        for (int i = 0; i < data.numInstances(); i++) {
            Instance inst = data.instance(i);
            classProbs[(int) inst.classValue()] += inst.weight();
            totalWeight += inst.weight();
        }

        return totalWeight;
    }

    /**
     * Compute and return the Gini gain for the given distributions of a node
     * and its successor nodes.
     * 
     * @param parentDist class distributions of parent node
     * @param childDist class distributions of successor nodes
     * @return Gini gain computed
     */
    protected double computeGiniGain(double[] parentDist, double[][] childDist) {
        double totalWeight = Utils.sum(parentDist);
        if (totalWeight == 0) {
            return 0;
        }

        double leftWeight = Utils.sum(childDist[0]);
        double rightWeight = Utils.sum(childDist[1]);

        double parentGini = computeGini(parentDist, totalWeight);
        double leftGini = computeGini(childDist[0], leftWeight);
        double rightGini = computeGini(childDist[1], rightWeight);
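
        // gain = Gini(parent) - w_left/w * Gini(left) - w_right/w * Gini(right)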

        return parentGini - leftWeight / totalWeight * leftGini - rightWeight / totalWeight * rightGini;
    }

    /**
     * Compute and return the Gini index for a given distribution of a node.
     * 
     * @param dist class distributions
     * @param total total weight of the class distributions
     * @return Gini index of the class distributions
     */
    protected double computeGini(double[] dist, double total) {
        if (total == 0) {
            return 0;
        }
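        // Gini(t) = 1 - sum_i (p_i)^2, where p_i = element / total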
        double val = 0;
        for (double element : dist) {
            val += (element / total) * (element / total);
        }
        return 1 - val;
    }

    /**
     * Computes class probabilities for instance using the decision tree.
     * 
     * @param instance the instance for which class probabilities are to be
     *          computed
     * @return the class probabilities for the given instance
     * @throws Exception if something goes wrong
     */
    @Override
    public double[] distributionForInstance(Instance instance) throws Exception {
        if (!m_isLeaf) {
            // value of split attribute is missing
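            // ("fractional instances": combine the successors' distributions,
            // weighted by the training proportions m_Props of the branches)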
            if (instance.isMissing(m_Attribute)) {
                double[] returnedDist = new double[m_ClassProbs.length];

                for (int i = 0; i < m_Successors.length; i++) {
                    double[] help = m_Successors[i].distributionForInstance(instance);
                    if (help != null) {
                        for (int j = 0; j < help.length; j++) {
                            returnedDist[j] += m_Props[i] * help[j];
                        }
                    }
                }
                return returnedDist;
            }

            // split attribute is nominal
            else if (m_Attribute.isNominal()) {
                if (m_SplitString.indexOf("(" + m_Attribute.value((int) instance.value(m_Attribute)) + ")") != -1) {
                    return m_Successors[0].distributionForInstance(instance);
                } else {
                    return m_Successors[1].distributionForInstance(instance);
                }
            }

            // split attribute is numeric
            else {
                if (instance.value(m_Attribute) < m_SplitValue) {
                    return m_Successors[0].distributionForInstance(instance);
                } else {
                    return m_Successors[1].distributionForInstance(instance);
                }
            }
        } else {
            return m_ClassProbs;
        }
    }
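
    // Sketch of the "fractional instances" case above, with hypothetical
    // numbers: for an instance missing the split attribute and branch
    // proportions m_Props = {0.7, 0.3}, the children's distributions are
    // blended, e.g. 0.7 * {0.9, 0.1} + 0.3 * {0.2, 0.8} = {0.69, 0.31}.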

    /**
     * Make the node a leaf node.
     * 
     * @param data training data
     */
    protected void makeLeaf(Instances data) {
        m_Attribute = null;
        m_isLeaf = true;
        m_ClassValue = Utils.maxIndex(m_ClassProbs);
        m_ClassAttribute = data.classAttribute();
    }

    /**
     * Prints the decision tree using the protected toString method from below.
     * 
     * @return a textual description of the classifier
     */
    @Override
    public String toString() {
        if ((m_ClassProbs == null) && (m_Successors == null)) {
            return "CART Tree: No model built yet.";
        }

        return "CART Decision Tree\n" + toString(0) + "\n\n" + "Number of Leaf Nodes: " + numLeaves() + "\n\n"
                + "Size of the Tree: " + numNodes();
    }

    /**
     * Outputs the tree at a certain level.
     * 
     * @param level the level at which the tree is to be printed
     * @return a textual description of the tree at the given level
     */
    protected String toString(int level) {

        StringBuffer text = new StringBuffer();
        // if leaf node
        if (m_Attribute == null) {
            if (Utils.isMissingValue(m_ClassValue)) {
                text.append(": null");
            } else {
                // truncate the displayed counts to two decimal places
                double correctNum = (int) (m_Distribution[Utils.maxIndex(m_Distribution)] * 100) / 100.0;
                double wrongNum = (int) ((Utils.sum(m_Distribution)
                        - m_Distribution[Utils.maxIndex(m_Distribution)]) * 100) / 100.0;
                String str = "(" + correctNum + "/" + wrongNum + ")";
                text.append(": " + m_ClassAttribute.value((int) m_ClassValue) + str);
            }
        } else {
            for (int j = 0; j < 2; j++) {
                text.append("\n");
                for (int i = 0; i < level; i++) {
                    text.append("|  ");
                }
                if (j == 0) {
                    if (m_Attribute.isNumeric()) {
                        text.append(m_Attribute.name() + " < " + m_SplitValue);
                    } else {
                        text.append(m_Attribute.name() + "=" + m_SplitString);
                    }
                } else {
                    if (m_Attribute.isNumeric()) {
                        text.append(m_Attribute.name() + " >= " + m_SplitValue);
                    } else {
                        text.append(m_Attribute.name() + "!=" + m_SplitString);
                    }
                }
                text.append(m_Successors[j].toString(level + 1));
            }
        }
        return text.toString();
    }
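
    // Sample of the output format produced above (attribute, value and
    // class names are made up for illustration):
    //
    //   petallength < 2.45: setosa(50.0/0.0)
    //   petallength >= 2.45
    //   |  petalwidth < 1.75: versicolor(49.0/5.0)
    //   |  petalwidth >= 1.75: virginica(45.0/1.0)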

    /**
     * Compute size of the tree.
     * 
     * @return size of the tree
     */
    public int numNodes() {
        if (m_isLeaf) {
            return 1;
        } else {
            int size = 1;
            for (SimpleCart successor : m_Successors) {
                size += successor.numNodes();
            }
            return size;
        }
    }

    /**
     * Method to count the number of inner nodes in the tree.
     * 
     * @return the number of inner nodes
     */
    public int numInnerNodes() {
        if (m_Attribute == null) {
            return 0;
        }
        int numNodes = 1;
        for (SimpleCart successor : m_Successors) {
            numNodes += successor.numInnerNodes();
        }
        return numNodes;
    }

    /**
     * Return a list of all inner nodes in the tree.
     * 
     * @return the list of all inner nodes
     */
    protected Vector<SimpleCart> getInnerNodes() {
        Vector<SimpleCart> nodeList = new Vector<SimpleCart>();
        fillInnerNodes(nodeList);
        return nodeList;
    }

    /**
     * Fills a list with all inner nodes in the tree.
     * 
     * @param nodeList the list to be filled
     */
    protected void fillInnerNodes(Vector<SimpleCart> nodeList) {
        if (!m_isLeaf) {
            nodeList.add(this);
            for (SimpleCart successor : m_Successors) {
                successor.fillInnerNodes(nodeList);
            }
        }
    }

    /**
     * Compute number of leaf nodes.
     * 
     * @return number of leaf nodes
     */
    public int numLeaves() {
        if (m_isLeaf) {
            return 1;
        } else {
            int size = 0;
            for (SimpleCart successor : m_Successors) {
                size += successor.numLeaves();
            }
            return size;
        }
    }
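
    // Since every inner node of this tree has exactly two successors, the
    // two counts are related by numNodes() == 2 * numLeaves() - 1.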

    /**
     * Returns an enumeration describing the available options.
     * 
     * @return an enumeration of all the available options.
     */
    @Override
    public Enumeration<Option> listOptions() {

        Vector<Option> result = new Vector<Option>(6);

        result.addElement(new Option("\tThe minimal number of instances at the terminal nodes.\n" + "\t(default 2)",
                "M", 1, "-M <min no>"));

        result.addElement(
                new Option("\tThe number of folds used in the minimal cost-complexity pruning.\n" + "\t(default 5)",
                        "N", 1, "-N <num folds>"));

        result.addElement(new Option("\tDon't use the minimal cost-complexity pruning.\n" + "\t(default yes).", "U",
                0, "-U"));

        result.addElement(new Option("\tDon't use the heuristic method for binary split.\n" + "\t(default true).",
                "H", 0, "-H"));

        result.addElement(
                new Option("\tUse 1 SE rule to make pruning decision.\n" + "\t(default no).", "A", 0, "-A"));

        result.addElement(
                new Option("\tPercentage of training data size (0-1].\n" + "\t(default 1).", "C", 1, "-C"));

        result.addAll(Collections.list(super.listOptions()));

        return result.elements();
    }

    /**
     * Parses a given list of options.
     * <p/>
     * 
     * <!-- options-start --> Valid options are:
     * <p/>
     * 
     * <pre>
     * -S &lt;num&gt;
     *  Random number seed.
     *  (default 1)
     * </pre>
     * 
     * <pre>
     * -M &lt;min no&gt;
     *  The minimal number of instances at the terminal nodes.
     *  (default 2)
     * </pre>
     * 
     * <pre>
     * -N &lt;num folds&gt;
     *  The number of folds used in the minimal cost-complexity pruning.
     *  (default 5)
     * </pre>
     * 
     * <pre>
     * -U
     *  Don't use the minimal cost-complexity pruning.
     *  (default yes).
     * </pre>
     * 
     * <pre>
     * -H
     *  Don't use the heuristic method for binary split.
     *  (default true).
     * </pre>
     * 
     * <pre>
     * -A
     *  Use 1 SE rule to make pruning decision.
     *  (default no).
     * </pre>
     * 
     * <pre>
     * -C
     *  Percentage of training data size (0-1].
     *  (default 1).
     * </pre>
     * 
     * <!-- options-end -->
     * 
     * @param options the list of options as an array of strings
     * @throws Exception if an option is not supported
     */
    @Override
    public void setOptions(String[] options) throws Exception {
        String tmpStr;

        tmpStr = Utils.getOption('M', options);
        if (tmpStr.length() != 0) {
            setMinNumObj(Double.parseDouble(tmpStr));
        } else {
            setMinNumObj(2);
        }

        tmpStr = Utils.getOption('N', options);
        if (tmpStr.length() != 0) {
            setNumFoldsPruning(Integer.parseInt(tmpStr));
        } else {
            setNumFoldsPruning(5);
        }

        setUsePrune(!Utils.getFlag('U', options));
        setHeuristic(!Utils.getFlag('H', options));
        setUseOneSE(Utils.getFlag('A', options));

        tmpStr = Utils.getOption('C', options);
        if (tmpStr.length() != 0) {
            setSizePer(Double.parseDouble(tmpStr));
        } else {
            setSizePer(1);
        }

        super.setOptions(options);

        Utils.checkForRemainingOptions(options);
    }
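
    // Programmatic usage sketch (values are hypothetical):
    //   SimpleCart cart = new SimpleCart();
    //   cart.setOptions(new String[] { "-M", "5", "-N", "10", "-A" });
    // sets the minimal leaf size to 5, uses 10 pruning folds and enables
    // the 1 SE rule; omitted options fall back to the defaults above.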

    /**
     * Gets the current settings of the classifier.
     * 
     * @return the current setting of the classifier
     */
    @Override
    public String[] getOptions() {

        Vector<String> result = new Vector<String>();

        result.add("-M");
        result.add("" + getMinNumObj());

        result.add("-N");
        result.add("" + getNumFoldsPruning());

        if (!getUsePrune()) {
            result.add("-U");
        }

        if (!getHeuristic()) {
            result.add("-H");
        }

        if (getUseOneSE()) {
            result.add("-A");
        }

        result.add("-C");
        result.add("" + getSizePer());

        Collections.addAll(result, super.getOptions());

        return result.toArray(new String[result.size()]);
    }

    /**
     * Return an enumeration of the measure names.
     * 
     * @return an enumeration of the measure names
     */
    @Override
    public Enumeration<String> enumerateMeasures() {
        Vector<String> result = new Vector<String>();

        result.addElement("measureTreeSize");

        return result.elements();
    }

    /**
     * Returns the size of the tree as an additional measure.
     * 
     * @return the number of nodes in the tree
     */
    public double measureTreeSize() {
        return numNodes();
    }

    /**
     * Returns the value of the named measure.
     * 
     * @param additionalMeasureName the name of the measure to query for its value
     * @return the value of the named measure
     * @throws IllegalArgumentException if the named measure is not supported
     */
    @Override
    public double getMeasure(String additionalMeasureName) {
        if (additionalMeasureName.compareToIgnoreCase("measureTreeSize") == 0) {
            return measureTreeSize();
        } else {
            throw new IllegalArgumentException(additionalMeasureName + " not supported (Cart pruning)");
        }
    }

    /**
     * Returns the tip text for this property
     * 
     * @return tip text for this property suitable for displaying in the
     *         explorer/experimenter gui
     */
    public String minNumObjTipText() {
        return "The minimal number of observations at the terminal nodes (default 2).";
    }

    /**
     * Set minimal number of instances at the terminal nodes.
     * 
     * @param value minimal number of instances at the terminal nodes
     */
    public void setMinNumObj(double value) {
        m_minNumObj = value;
    }

    /**
     * Get minimal number of instances at the terminal nodes.
     * 
     * @return minimal number of instances at the terminal nodes
     */
    public double getMinNumObj() {
        return m_minNumObj;
    }

    /**
     * Returns the tip text for this property
     * 
     * @return tip text for this property suitable for displaying in the
     *         explorer/experimenter gui
     */
    public String numFoldsPruningTipText() {
        return "The number of folds in the internal cross-validation (default 5).";
    }

    /**
     * Set number of folds in internal cross-validation.
     * 
     * @param value number of folds in internal cross-validation.
     */
    public void setNumFoldsPruning(int value) {
        m_numFoldsPruning = value;
    }

    /**
     * Get number of folds in internal cross-validation.
     * 
     * @return number of folds in internal cross-validation.
     */
    public int getNumFoldsPruning() {
        return m_numFoldsPruning;
    }

    /**
     * Return the tip text for this property
     * 
     * @return tip text for this property suitable for displaying in the
     *         explorer/experimenter gui.
     */
    public String usePruneTipText() {
        return "Use minimal cost-complexity pruning (default yes).";
    }

    /**
     * Set whether to use minimal cost-complexity pruning.
     * 
     * @param value whether to use minimal cost-complexity pruning
     */
    public void setUsePrune(boolean value) {
        m_Prune = value;
    }

    /**
     * Get whether minimal cost-complexity pruning is used.
     * 
     * @return whether minimal cost-complexity pruning is used
     */
    public boolean getUsePrune() {
        return m_Prune;
    }

    /**
     * Returns the tip text for this property
     * 
     * @return tip text for this property suitable for displaying in the
     *         explorer/experimenter gui.
     */
    public String heuristicTipText() {
        return "If heuristic search is used for binary split for nominal attributes "
                + "in multi-class problems (default yes).";
    }

    /**
     * Set whether to use heuristic search for nominal attributes in multi-class
     * problems.
     * 
     * @param value whether to use heuristic search for nominal attributes in
     *          multi-class problems
     */
    public void setHeuristic(boolean value) {
        m_Heuristic = value;
    }

    /**
     * Get whether heuristic search is used for nominal attributes in multi-class
     * problems.
     * 
     * @return whether heuristic search is used for nominal attributes in
     *         multi-class problems
     */
    public boolean getHeuristic() {
        return m_Heuristic;
    }

    /**
     * Returns the tip text for this property
     * 
     * @return tip text for this property suitable for displaying in the
     *         explorer/experimenter gui.
     */
    public String useOneSETipText() {
        return "Use the 1SE rule to make pruning decisoin.";
    }

    /**
     * Set whether to use the 1 SE rule to choose the final model.
     * 
     * @param value whether to use the 1 SE rule to choose the final model
     */
    public void setUseOneSE(boolean value) {
        m_UseOneSE = value;
    }

    /**
     * Get whether the 1 SE rule is used to choose the final model.
     * 
     * @return whether the 1 SE rule is used to choose the final model
     */
    public boolean getUseOneSE() {
        return m_UseOneSE;
    }

    /**
     * Returns the tip text for this property
     * 
     * @return tip text for this property suitable for displaying in the
     *         explorer/experimenter gui.
     */
    public String sizePerTipText() {
        return "The percentage of the training set size (0-1, 0 not included).";
    }

    /**
     * Set the training set size as a fraction (0-1] of the original data.
     * 
     * @param value the training set size
     */
    public void setSizePer(double value) {
        if ((value <= 0) || (value > 1)) {
            System.err.println("The percentage of the training set size must be in range 0 to 1 "
                    + "(0 not included) - ignored!");
        } else {
            m_SizePer = value;
        }
    }
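
    // E.g. setSizePer(0.5) builds the tree on half of the training data,
    // while an out-of-range call such as setSizePer(1.5) only prints a
    // warning and leaves the current value unchanged.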

    /**
     * Get the training set size as a fraction (0-1] of the original data.
     * 
     * @return the training set size
     */
    public double getSizePer() {
        return m_SizePer;
    }

    /**
     * Returns the revision string.
     * 
     * @return the revision
     */
    @Override
    public String getRevision() {
        return RevisionUtils.extract("$Revision$");
    }

    /**
     * Main method.
     * 
     * @param args the options for the classifier
     */
    public static void main(String[] args) {
        runClassifier(new SimpleCart(), args);
    }
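
    // Command-line usage sketch (the dataset path is hypothetical):
    //   java weka.classifiers.trees.SimpleCart -t data/weather.arff -M 2 -N 5
    // By default runClassifier evaluates the built model, e.g. via
    // cross-validation when no separate test file is supplied.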
}