decisiontreeclassifier.ITree2.java Source code

Java tutorial

Introduction

Here is the source code for decisiontreeclassifier.ITree2.java

Source

/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package decisiontreeclassifier;

import decisiontreeclassifier.InstanceTree.Node;
import java.util.ArrayList;
import java.util.List;
import java.util.ListIterator;
import weka.classifiers.Classifier;
import weka.core.Instance;
import weka.core.Instances;

/**
 * Decision-tree classifier that splits a node by dividing one attribute's
 * observed range into {@code binNum} equal-width bins (one child per bin).
 * At each node the unused attribute with the highest information gain is
 * chosen. Missing values are replaced by 0.0 both in the training copy and
 * in instances passed to {@link #classifyInstance(Instance)}.
 *
 * @author Mackenzie Bodily
 */
public class ITree2 extends Classifier {
    private Instances dataSet;
    InstanceTree iTree;
    int binNum;

    /********************************************************************
     * Creates a classifier whose splits divide an attribute's range into
     * {@code bn} equal-width bins, producing {@code bn} children per split.
     *
     * @param bn number of bins (children) per split; must be at least 1
     * @throws IllegalArgumentException if {@code bn < 1}
     ********************************************************************/
    public ITree2(int bn) {
        if (bn < 1) {
            throw new IllegalArgumentException("binNum must be at least 1, got " + bn);
        }
        binNum = bn;
    }

    /********************************************************************
     * Computes the entropy (in bits) of the class distribution of the
     * instances held by the node.
     *
     * @param theNode node whose data set is measured
     * @return the entropy; 0 for an empty node, a single instance, or a
     *         node for which {@code allTheSame()} is true
     ********************************************************************/
    public double findEntropy(Node theNode) {
        Instances data = theNode.dataSet;
        int numInstances = data.numInstances();

        // Must be checked before instance data is read: an empty node has
        // no instance(0) and would otherwise throw.
        if (numInstances <= 1 || theNode.allTheSame()) {
            return 0;
        }

        int numClasses;
        if (data.classAttribute().isNumeric()) {
            // Allows a numeric class attribute by counting its distinct values.
            // NOTE(review): classValue() of a numeric class is the raw value,
            // not an index, so it may exceed this array size - confirm usage.
            numClasses = data.numDistinctValues(data.classAttribute());
        } else {
            numClasses = data.numClasses();
        }

        // Counts the number of instances of each class in the data set.
        double[] counts = new double[numClasses];
        for (int i = 0; i < numInstances; i++) {
            counts[(int) data.instance(i).classValue()] += 1;
        }
        return entropyOf(counts, numInstances);
    }

    /********************************************************************
     * Entropy, in bits, of a histogram of class counts.
     *
     * @param counts number of instances per class
     * @param total  sum of {@code counts}; must be positive
     * @return -sum(p * log2(p)) over the non-zero counts
     ********************************************************************/
    private static double entropyOf(double[] counts, double total) {
        double entropy = 0.0;
        for (double count : counts) {
            if (count > 0) {
                double p = count / total;
                entropy -= p * (Math.log(p) / Math.log(2));
            }
        }
        return entropy;
    }

    /********************************************************************
     * Finds the range (highest minus lowest value) of an attribute in a
     * single pass, rather than looking for the highest and then calling
     * findLowest.
     *
     * @param attIndex attribute to measure
     * @param theNode  node whose instances are scanned
     * @return highest - lowest; negative infinity for an empty node
     ********************************************************************/
    public double findRange(int attIndex, Node theNode) {
        double highest = Double.NEGATIVE_INFINITY;
        double lowest = Double.POSITIVE_INFINITY;
        int numInstances = theNode.dataSet.numInstances();

        for (int i = 0; i < numInstances; i++) {
            double value = theNode.dataSet.instance(i).value(attIndex);
            if (value > highest) {
                highest = value;
            }
            if (value < lowest) {
                lowest = value;
            }
        }
        return highest - lowest;
    }

    /********************************************************************
     * Finds the lowest value of an attribute in the node's data set.
     * kthSmallestValue cannot be used because the attribute is sometimes
     * nominal, not numeric.
     *
     * @param attIndex attribute to scan
     * @param theNode  node whose instances are scanned
     * @return the lowest value; positive infinity for an empty node
     ********************************************************************/
    public double findLowest(int attIndex, Node theNode) {
        double lowest = Double.POSITIVE_INFINITY;
        int numInstances = theNode.dataSet.numInstances();

        for (int i = 0; i < numInstances; i++) {
            double value = theNode.dataSet.instance(i).value(attIndex);
            if (value < lowest) {
                lowest = value;
            }
        }
        return lowest;
    }

    /********************************************************************
     * Appends {@code numItems} zeros to the list.
     ********************************************************************/
    public void initializeList(List<Double> theList, int numItems) {
        for (int i = 0; i < numItems; i++) {
            theList.add(0.00);
        }
    }

    /********************************************************************
     * Computes the weighted average class entropy after splitting the node
     * on {@code attIndex} into {@code binNum} equal-width bins:
     * sum over bins of (binSize / nodeSize) * entropy(bin).
     * Lower values mean a purer split.
     *
     * @param range        attribute range used to size the bins
     * @param attIndex     attribute being evaluated
     * @param totalEntropy entropy of the whole node (not used here; callers
     *                     subtract this method's result from it to get the
     *                     information gain)
     * @param theNode      node being split
     * @return the weighted entropy; 0 for an empty node
     ********************************************************************/
    public double findNodeScore(double range, int attIndex, double totalEntropy, Node theNode) {
        Instances data = theNode.dataSet;
        int numInstances = data.numInstances();
        if (numInstances == 0) {
            return 0.0;
        }
        int numClasses = data.numClasses();
        double increment = range / binNum;
        double lowest = findLowest(attIndex, theNode);

        // Class histogram per bin, plus the number of instances per bin.
        double[][] classCounts = new double[binNum][numClasses];
        double[] binTotals = new double[binNum];
        for (int i = 0; i < numInstances; i++) {
            Instance row = data.instance(i);
            int bin = binIndex(row.value(attIndex), lowest, increment);
            classCounts[bin][(int) row.classValue()] += 1;
            binTotals[bin] += 1;
        }

        double score = 0.0;
        for (int bin = 0; bin < binNum; bin++) {
            if (binTotals[bin] > 0) {
                score += (binTotals[bin] / numInstances) * entropyOf(classCounts[bin], binTotals[bin]);
            }
        }
        return score;
    }

    /********************************************************************
     * Returns the bin a value falls into: the first bin whose upper edge
     * (lowest + increment * (bin + 1)) is >= value. The last bin takes
     * everything above, matching the unbounded last child created by
     * addTreeNode.
     ********************************************************************/
    private int binIndex(double value, double lowest, double increment) {
        for (int bin = 0; bin < binNum - 1; bin++) {
            if (value <= lowest + (increment * (bin + 1))) {
                return bin;
            }
        }
        return binNum - 1;
    }

    /********************************************************************
     * Chooses the attribute to split the node on: the unused, non-class
     * attribute with the highest information gain (node entropy minus
     * the weighted entropy of its bins).
     *
     * @param theNode node to split
     * @return index of the best attribute; 0 if no attribute is eligible
     ********************************************************************/
    public int scoreForNode(Node theNode) {
        double entropy = findEntropy(theNode);
        int numAttributes = theNode.dataSet.numAttributes();
        int classIndex = theNode.dataSet.classIndex();

        double bestGain = Double.NEGATIVE_INFINITY;
        int bestAttribute = 0;

        for (int i = 0; i < numAttributes; i++) {
            if (i == classIndex || theNode.usedFeatures.contains(i)) {
                continue;
            }
            double range = findRange(i, theNode);
            double gain = entropy - findNodeScore(range, i, entropy, theNode);
            if (gain > bestGain) {
                bestGain = gain;
                bestAttribute = i;
            }
        }
        return bestAttribute;
    }

    /********************************************************************
     * Returns a new list holding the node's used feature indices. Changes
     * to the returned list do not affect the node (Integer elements are
     * immutable, so sharing them is safe).
     ********************************************************************/
    public ArrayList<Integer> copyArrayList(Node root) {
        return new ArrayList<Integer>(root.usedFeatures);
    }

    /********************************************************************
     * Adds {@code binNum} children to the node, one per equal-width bin of
     * {@code attIndex}. The first and last bins are unbounded below and
     * above. Each child then keeps only the instances in its range.
     ********************************************************************/
    public void addTreeNode(int attIndex, Node theNode) {
        double lowest = findLowest(attIndex, theNode);
        double increment = findRange(attIndex, theNode) / binNum;

        for (int i = 0; i < binNum; i++) {
            double lower = (i == 0) ? Double.NEGATIVE_INFINITY : lowest + (increment * i);
            double upper = (i == binNum - 1) ? Double.POSITIVE_INFINITY : lowest + (increment * (i + 1));
            Instances instCopy = new Instances(theNode.dataSet);
            // NOTE(review): the last constructor argument is treated as the
            // parent. getParent() drives backtracking in buildTree and the
            // empty-leaf fallback in classifyInstance.
            Node child = new Node(instCopy, copyArrayList(theNode), i, lower, upper, attIndex, theNode);
            theNode.addChild(child);
        }

        filterInstances(theNode, attIndex);
    }

    /********************************************************************
     * For each child of the node, removes the instances whose value of
     * {@code attIndex} is not within that child's range.
     ********************************************************************/
    public void filterInstances(Node theNode, int attIndex) {
        for (int i = 0; i < theNode.numChildren; i++) {
            Node child = theNode.getChildAt(i);
            // Iterate backwards so deleting instance j does not shift the
            // indices of instances not yet checked.
            for (int j = child.dataSet.numInstances() - 1; j >= 0; j--) {
                double value = child.dataSet.instance(j).value(attIndex);
                if (!child.withinRange(value, attIndex)) {
                    child.dataSet.delete(j);
                }
            }
        }
    }

    /********************************************************************
     * True if the node holds instances, still has unused features, is not
     * uniform, and has not been split yet.
     ********************************************************************/
    private boolean canSplit(Node node) {
        return node.dataSet.numInstances() > 0
                && !node.allFeaturesUsed()
                && !node.allTheSame()
                && node.numChildren != binNum;
    }

    /********************************************************************
     * Grows the tree depth-first from {@code root}. Each splittable node is
     * split on its best attribute and its first child is visited next. A
     * finished node moves on to its next sibling. When it is the last
     * sibling (index binNum - 1), the walk climbs toward the root and
     * stops once the root is reached.
     ********************************************************************/
    public void buildTree(Node root) {
        Node current = root;

        while (true) {
            if (canSplit(current)) {
                int attribute2Split = scoreForNode(current);
                // Only the root records splitOn here; children receive the
                // split attribute through the Node constructor.
                if (current == root) {
                    current.splitOn = attribute2Split;
                }
                current.usedFeatures.add(attribute2Split);
                addTreeNode(attribute2Split, current);
                current = current.getChildAt(0);
                continue;
            }

            // current is finished: climb past every last-sibling node.
            while (current != root && current.index == (binNum - 1)) {
                current = current.getParent();
            }
            if (current == root) {
                return;
            }
            current = current.getSibling();
        }
    }

    /********************************************************************
     * Sets every missing value to 0.0, in place, and returns the same
     * object. For the voting data set this should be sufficient, since
     * 0.00 is more or less arbitrary.
     ********************************************************************/
    public Instances fixMissingData(Instances iToFix) {
        for (int i = 0; i < iToFix.numInstances(); i++) {
            Instance row = iToFix.instance(i);
            for (int j = 0; j < iToFix.numAttributes(); j++) {
                if (row.isMissing(j)) {
                    row.setValue(j, 0.0);
                }
            }
        }
        return iToFix;
    }

    /********************************************************************
     * Builds the tree from a copy of the training instances (so the
     * caller's data is not modified) and prints it.
     ********************************************************************/
    @Override
    public void buildClassifier(Instances i) throws Exception {
        dataSet = fixMissingData(new Instances(i));
        iTree = new InstanceTree(dataSet);
        buildTree(iTree.root);

        iTree.printTree();
    }

    /********************************************************************
     * Returns a copy of the instance with missing values set to 0.0,
     * mirroring the preprocessing applied to the training data.
     ********************************************************************/
    private static Instance withMissingAsZero(Instance inst) {
        Instance copy = (Instance) inst.copy();
        for (int j = 0; j < copy.numAttributes(); j++) {
            if (copy.isMissing(j)) {
                copy.setValue(j, 0.0);
            }
        }
        return copy;
    }

    /********************************************************************
     * Classifies the instance by descending to the child whose range
     * accepts it.
     *
     * If no child accepts the value, the descent stops at the current node.
     * An empty leaf defers to its nearest non-empty ancestor. The result is
     * that node's makeGuess().
     ********************************************************************/
    @Override
    public double classifyInstance(Instance inst) throws Exception {
        Instance query = withMissingAsZero(inst);
        Node current = iTree.root;

        while (current.hasChildren()) {
            Node next = current.findChild(query);
            if (next == null) {
                break;
            }
            current = next;
        }

        while (current != iTree.root
                && current.dataSet.numInstances() == 0
                && current.getParent() != null) {
            current = current.getParent();
        }
        return current.makeGuess();
    }
}