newdtl.NewJ48.java Source code

Java tutorial

Introduction

Here is the source code for newdtl.NewJ48.java

Source

/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package newdtl;

import static java.lang.System.exit;
import java.util.Enumeration;
import java.util.stream.DoubleStream;
import weka.classifiers.Classifier;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.NoSupportForMissingValuesException;

public class NewJ48 extends Classifier {

    private final double DOUBLE_MISSING_VALUE = Double.NaN;
    private final double DOUBLE_ERROR_MAXIMUM = 1e-6;

    /**
     * The node's children.
     */
    private NewJ48[] children;

    /**
     * Attribute used for splitting.
     */
    private Attribute splitAttribute;

    /**
     * Threshold used for splitting if attribute is numeric.
     */
    private double splitThreshold;

    /**
     * Class value if node is leaf.
     */
    private double label;

    /**
     * Class distribution.
     */
    private double[] classDistributions;

    /**
     * Class attribute of dataset.
     */
    private Attribute classAttribute;

    /**
     * True if node is leaf.
     */
    private boolean isLeaf;

    /**
     * Returns default capabilities of the classifier.
     *
     * @return the capabilities of this classifier
     */
    @Override
    public Capabilities getCapabilities() {

        Capabilities result = super.getCapabilities();
        result.disableAll();

        // attributes
        result.enable(Capabilities.Capability.MISSING_VALUES);
        result.enable(Capabilities.Capability.NOMINAL_ATTRIBUTES);
        result.enable(Capabilities.Capability.NUMERIC_ATTRIBUTES);

        // class
        result.enable(Capabilities.Capability.MISSING_CLASS_VALUES);
        result.enable(Capabilities.Capability.NOMINAL_CLASS);

        // instances
        result.setMinimumNumberInstances(0);

        return result;
    }

    /**
     * Builds J48 tree classifier.
     *
     * @param data the training data
     * @exception Exception if classifier failed to build
     */
    @Override
    public void buildClassifier(Instances data) throws Exception {

        // Mengecek apakah data dapat dibuat classifier
        getCapabilities().testWithFail(data);

        // Menghapus instances dengan missing class
        data = new Instances(data);
        data.deleteWithMissingClass();

        makeTree(data);

        pruneTree(data);
    }

    /**
     * Creates a J48 tree.
     *
     * @param data the training data
     * @exception Exception if tree failed to build
     */
    private void makeTree(Instances data) throws Exception {

        // Mengecek apakah tidak terdapat instance dalam node ini
        if (data.numInstances() == 0) {
            splitAttribute = null;
            label = DOUBLE_MISSING_VALUE;
            classDistributions = new double[data.numClasses()];
            isLeaf = true;
        } else {
            // Mencari Gain Ratio maksimum
            double[] gainRatios = new double[data.numAttributes()];
            double[] thresholds = new double[data.numAttributes()];

            Enumeration attEnum = data.enumerateAttributes();
            while (attEnum.hasMoreElements()) {
                Attribute att = (Attribute) attEnum.nextElement();
                double[] result = computeGainRatio(data, att);
                gainRatios[att.index()] = result[0];
                thresholds[att.index()] = result[1];
            }

            splitAttribute = data.attribute(maxIndex(gainRatios));

            if (splitAttribute.isNumeric()) {
                splitThreshold = thresholds[maxIndex(gainRatios)];
            } else {
                splitThreshold = Double.NaN;
            }

            classDistributions = new double[data.numClasses()];
            for (int i = 0; i < data.numInstances(); i++) {
                Instance inst = (Instance) data.instance(i);
                classDistributions[(int) inst.classValue()]++;
            }

            // Membuat daun jika Gain Ratio-nya 0
            if (Double.compare(gainRatios[splitAttribute.index()], 0) == 0) {
                splitAttribute = null;

                label = maxIndex(classDistributions);
                classAttribute = data.classAttribute();
                isLeaf = true;
            } else {
                // Mengecek jika ada missing value
                if (isMissing(data, splitAttribute)) {
                    // cari modus
                    int index = modusIndex(data, splitAttribute);

                    // ubah data yang punya missing value
                    Enumeration dataEnum = data.enumerateInstances();
                    while (dataEnum.hasMoreElements()) {
                        Instance inst = (Instance) dataEnum.nextElement();
                        if (inst.isMissing(splitAttribute)) {
                            inst.setValue(splitAttribute, splitAttribute.value(index));
                        }
                    }
                }

                // Membuat tree baru di bawah node ini
                Instances[] splitData;
                if (splitAttribute.isNumeric()) {
                    splitData = splitData(data, splitAttribute, splitThreshold);
                    children = new NewJ48[2];
                    for (int j = 0; j < 2; j++) {
                        children[j] = new NewJ48();
                        children[j].makeTree(splitData[j]);
                    }
                } else {
                    splitData = splitData(data, splitAttribute);
                    children = new NewJ48[splitAttribute.numValues()];
                    for (int j = 0; j < splitAttribute.numValues(); j++) {
                        children[j] = new NewJ48();
                        children[j].makeTree(splitData[j]);
                    }
                }
                isLeaf = false;
            }
        }
    }

    /**
     * Creates a pruned J48 tree using expected error pruning.
     *
     * @param data the training data
     */
    private double pruneTree(Instances data) throws Exception {

        double staticError = staticErrorEstimate((int) DoubleStream.of(classDistributions).sum(),
                (int) classDistributions[maxIndex(classDistributions)], classDistributions.length);

        if (isLeaf) {
            return staticError;
        } else {
            double backupError = 0;
            double totalInstances = DoubleStream.of(classDistributions).sum();

            for (NewJ48 children1 : children) {
                double totalChildInstances = DoubleStream.of(children1.classDistributions).sum();
                backupError += totalChildInstances / totalInstances * children1.pruneTree(data);
            }

            if (staticError < backupError) {
                splitAttribute = null;
                label = maxIndex(classDistributions);
                classAttribute = data.classAttribute();
                isLeaf = true;
                children = null;

                return staticError;
            } else {
                return backupError;
            }
        }
    }

    /**
     * Classifies a given test instance using the decision tree.
     *
     * @param instance the instance to be classified
     * @return the classification
     */
    @Override
    public double classifyInstance(Instance instance) throws NoSupportForMissingValuesException {
        if (instance.hasMissingValue()) {
            throw new NoSupportForMissingValuesException("NewID3: Cannot handle missing values");
        }

        if (splitAttribute == null) {
            return label;
        } else {
            if (splitAttribute.isNumeric()) {
                if (Double.compare(instance.value(splitAttribute), splitThreshold) <= 0) {
                    return children[0].classifyInstance(instance);
                } else {
                    return children[1].classifyInstance(instance);
                }
            } else {
                return children[(int) instance.value(splitAttribute)].classifyInstance(instance);
            }
        }
    }

    /**
     * Computes class distribution for instance using decision tree.
     *
     * @param instance the instance for which distribution is to be computed
     * @return the class distribution for the given instance
     */
    @Override
    public double[] distributionForInstance(Instance instance) throws NoSupportForMissingValuesException {
        if (instance.hasMissingValue()) {
            throw new NoSupportForMissingValuesException("NewID3: Cannot handle missing values");
        }

        if (splitAttribute == null) {
            return normalize(classDistributions);
        } else {
            if (splitAttribute.isNumeric()) {
                if (Double.compare(instance.value(splitAttribute), splitThreshold) <= 0) {
                    return children[0].distributionForInstance(instance);
                } else {
                    return children[1].distributionForInstance(instance);
                }
            } else {
                return children[(int) instance.value(splitAttribute)].distributionForInstance(instance);
            }
        }
    }

    /**
     * split the dataset based on nominal attribute 
     *
     * @param data dataset used for splitting
     * @param att attribute used to split the dataset
     * @return array of instances which has been split by attribute
     */
    private Instances[] splitData(Instances data, Attribute att) {

        Instances[] splitData = new Instances[att.numValues()];
        for (int j = 0; j < att.numValues(); j++) {
            splitData[j] = new Instances(data, data.numInstances());
        }

        for (int i = 0; i < data.numInstances(); i++) {
            splitData[(int) data.instance(i).value(att)].add(data.instance(i));
        }

        for (Instances splitData1 : splitData) {
            splitData1.compactify();
        }

        return splitData;
    }

    /**
     * split the dataset based on attribute for numeric attribute
     *
     * @param data dataset used for splitting
     * @param att attribute used to split the dataset
     * @param threshold the threshold value
     * @return
     */
    private Instances[] splitData(Instances data, Attribute att, double threshold) {

        Instances[] splitData = new Instances[2];
        for (int j = 0; j < 2; j++) {
            splitData[j] = new Instances(data, data.numInstances());
        }

        for (int i = 0; i < data.numInstances(); i++) {
            if (Double.compare(data.instance(i).value(att), threshold) <= 0) {
                splitData[0].add(data.instance(i));
            } else {
                splitData[1].add(data.instance(i));
            }
        }

        for (Instances splitData1 : splitData) {
            splitData1.compactify();
        }

        return splitData;
    }

    /**
     * Computes Gain Ratio for an attribute.
     *
     * @param data the data for which gain ratio is to be computed
     * @param att the attribute
     * @return the gain ratio for the given attribute and data
     * @throws Exception if computation fails
     */
    private double[] computeGainRatio(Instances data, Attribute att) {
        if (att.isNumeric()) {
            data.sort(att);
            double[] threshold;
            double[] gainRatios;

            if (data.numInstances() == 1) {
                threshold = new double[1];
                gainRatios = new double[1];

                threshold[0] = data.instance(0).value(att);
                double infoGain = computeInfoGain(data, att, threshold[0]);
                double splitInfo = computeSplitInformation(data, att, threshold[0]);
                gainRatios[0] = infoGain > 0 ? infoGain / splitInfo : infoGain;
            } else {
                threshold = new double[data.numInstances() - 1];
                gainRatios = new double[data.numInstances() - 1];
                for (int i = 0; i < data.numInstances() - 1; i++) {
                    threshold[i] = data.instance(i).value(att);
                    double infoGain = computeInfoGain(data, att, threshold[i]);
                    double splitInfo = computeSplitInformation(data, att, threshold[i]);
                    gainRatios[i] = infoGain > 0 ? infoGain / splitInfo : infoGain;
                }
            }
            return new double[] { gainRatios[maxIndex(gainRatios)], threshold[maxIndex(gainRatios)] };
        } else {
            double infoGain = computeInfoGain(data, att);
            double splitInfo = computeSplitInformation(data, att);

            return new double[] { splitInfo > 0 ? infoGain / splitInfo : splitInfo, 0 };
        }
    }

    /**
     * Computes information gain for an attribute.
     *
     * @param data the data for which info gain is to be computed
     * @param att the attribute
     * @return the information gain for the given attribute and data
     * @throws Exception if computation fails
     */
    private double computeInfoGain(Instances data, Attribute att) {

        double infoGain = computeEntropy(data);
        Instances[] splitData = splitData(data, att);
        for (Instances splitdata : splitData) {
            if (splitdata.numInstances() > 0) {
                double splitNumInstances = splitdata.numInstances();
                double dataNumInstances = data.numInstances();
                double proportion = splitNumInstances / dataNumInstances;
                infoGain -= proportion * computeEntropy(splitdata);
            }
        }
        return infoGain;
    }

    /**
     * Computes information gain for a numeric attribute.
     *
     * @param data the data for which info gain is to be computed
     * @param att the attribute
     * @return the information gain for the given attribute and data
     * @throws Exception if computation fails
     */
    private double computeInfoGain(Instances data, Attribute att, double threshold) {

        double infoGain = computeEntropy(data);
        Instances[] splitData = splitData(data, att, threshold);
        for (Instances splitdata : splitData) {
            if (splitdata.numInstances() > 0) {
                double splitNumInstances = splitdata.numInstances();
                double dataNumInstances = data.numInstances();
                double proportion = splitNumInstances / dataNumInstances;
                infoGain -= proportion * computeEntropy(splitdata);
            }
        }
        return infoGain;
    }

    /**
     * Computes the entropy of a dataset.
     *
     * @param data the data for which entropy is to be computed
     * @return the entropy of the data class distribution
     * @throws Exception if computation fails
     */
    private double computeEntropy(Instances data) {

        double[] labelCounts = new double[data.numClasses()];
        for (int i = 0; i < data.numInstances(); ++i) {
            labelCounts[(int) data.instance(i).classValue()]++;
        }

        double entropy = 0;
        for (int i = 0; i < labelCounts.length; i++) {
            if (labelCounts[i] > 0) {
                double proportion = labelCounts[i] / data.numInstances();
                entropy -= (proportion) * log2(proportion);
            }
        }
        return entropy;
    }

    /**
     * Computes Split information for an attribute.
     *
     * @param data the data for which split information is to be computed
     * @param att the attribute
     * @return the split information for the given attribute and data
     * @throws Exception if computation fails
     */
    private double computeSplitInformation(Instances data, Attribute att) {

        double splitInfo = 0;
        Instances[] splitData = splitData(data, att);
        double dataNumInstances = data.numInstances();

        for (Instances splitdata : splitData) {
            if (splitdata.numInstances() > 0) {
                double splitNumInstances = splitdata.numInstances();
                double proportion = splitNumInstances / dataNumInstances;
                splitInfo -= proportion * log2(proportion);
            }
        }
        return splitInfo;
    }

    /**
     * Computes Split information for a numeric attribute.
     *
     * @param data the data for which split information is to be computed
     * @param att the attribute
     * @return the split information for the given attribute and data
     * @throws Exception if computation fails
     */
    private double computeSplitInformation(Instances data, Attribute att, double threshold) {

        double splitInfo = 0;
        Instances[] splitData = splitData(data, att, threshold);
        double dataNumInstances = data.numInstances();

        for (Instances splitdata : splitData) {
            if (splitdata.numInstances() > 0) {
                double splitNumInstances = splitdata.numInstances();
                double proportion = splitNumInstances / dataNumInstances;
                splitInfo -= proportion * log2(proportion);
            }
        }
        return splitInfo;
    }

    /**
     * Computes static error estimate for pruning.
     *
     * @param N the number of instances
     * @param n number of instance in majority class
     * @param k number of class value
     * @return the static error estimate
     * 
     */
    private double staticErrorEstimate(int N, int n, int k) {
        double E = (N - n + k - 1) / (double) (N + k);

        return E;
    }

    public double backUpError() {

        double E = 0;
        double totalInstances = DoubleStream.of(classDistributions).sum();
        for (NewJ48 child : children) {
            double totalChildInstances = DoubleStream.of(child.classDistributions).sum();
            E += totalChildInstances / totalInstances * staticErrorEstimate((int) totalChildInstances,
                    (int) child.classDistributions[(int) child.label], child.classDistributions.length);

        }

        return E;
    }

    /**
     * search data that has missing value for attribute
     *
     * @param data the data for searching
     * @param att the attribute for searching
     * @return if data has missing value for attribute
     */
    private boolean isMissing(Instances data, Attribute att) {

        boolean isMissingValue = false;
        Enumeration dataEnum = data.enumerateInstances();

        while (dataEnum.hasMoreElements() && !isMissingValue) {
            Instance inst = (Instance) dataEnum.nextElement();
            if (inst.isMissing(att)) {
                isMissingValue = true;
            }
        }

        return isMissingValue;
    }

    /**
     * search index of attribute that has most common value
     *
     * @param data the data for searching
     * @param att the attribute for searching
     * @return index of attribute that has most common value
     */
    private int modusIndex(Instances data, Attribute att) {
        // cari modus
        int[] modus = new int[att.numValues()];
        Enumeration dataEnumeration = data.enumerateInstances();

        while (dataEnumeration.hasMoreElements()) {
            Instance inst = (Instance) dataEnumeration.nextElement();
            if (!inst.isMissing(att)) {
                modus[(int) inst.value(att)]++;
            }
        }

        // cari modus terbesar
        int indexMax = 0;
        for (int i = 1; i < modus.length; ++i) {
            if (modus[i] > modus[indexMax]) {
                indexMax = i;
            }
        }
        return indexMax;
    }

    /**
     * Prints the decision tree using the private toString method from below.
     *
     * @return a textual description of the classifier
     */
    @Override
    public String toString() {

        if ((classDistributions == null) && (children == null)) {
            return "NewJ48: No model built yet.";
        }
        return "NewJ48 pruned tree\n------------------\n" + toString(0);
    }

    /**
     * Outputs a tree at a certain level.
     *
     * @param level the level at which the tree is to be printed
     * @return the tree as string at the given level
     */
    private String toString(int level) {

        StringBuilder text = new StringBuilder();

        if (splitAttribute == null) {
            if (Instance.isMissingValue(label)) {
                text.append(": null");
            } else {
                text.append(": ").append(classAttribute.value((int) label));

                double totalInstances = DoubleStream.of(classDistributions).sum();
                text.append(" (").append(totalInstances);

                double wrongClass = totalInstances - classDistributions[(int) label];
                if (wrongClass > 0) {
                    text.append("/").append(totalInstances - classDistributions[(int) label]);
                }

                text.append(")");
            }
        } else {
            if (splitAttribute.isNumeric()) {
                for (int j = 0; j < 2; j++) {
                    text.append("\n");
                    for (int i = 0; i < level; i++) {
                        text.append("|  ");
                    }
                    if (j == 0) {
                        text.append(splitAttribute.name()).append(" <= ").append(splitThreshold);
                    } else {
                        text.append(splitAttribute.name()).append(" > ").append(splitThreshold);
                    }
                    text.append(children[j].toString(level + 1));
                }
            } else {
                for (int j = 0; j < splitAttribute.numValues(); j++) {
                    text.append("\n");
                    for (int i = 0; i < level; i++) {
                        text.append("|  ");
                    }
                    text.append(splitAttribute.name()).append(" = ").append(splitAttribute.value(j));
                    text.append(children[j].toString(level + 1));
                }
            }
        }

        return text.toString();
    }

    /**
     * Count the logarithm value with base 2 of a number
     *
     * @param num number that will be counted
     * @return logarithm value with base 2
     */
    private double log2(double num) {

        return (num == 0) ? 0 : Math.log(num) / Math.log(2);
    }

    /**
     * Search for index with largest value from array of double
     *
     * @param array the array of double
     * @return index of array with maximum value, -1 if array empty
     */
    private int maxIndex(double[] array) {
        int max = 0;

        if (array.length > 0) {
            for (int i = 1; i < array.length; ++i) {
                if (array[i] > array[max]) {
                    max = i;
                }
            }
            return max;
        } else {
            return -1;
        }
    }

    /**
     * Normalize the class distribution
     *
     * @exception Exception if sum of class distribution is 0 or NAN
     */
    private double[] normalize(double[] array) {
        double sum = DoubleStream.of(array).sum();
        double[] newArray = new double[array.length];
        if (!Double.isNaN(sum) && sum != 0) {
            for (int i = 0; i < array.length; ++i) {
                newArray[i] = array[i] / sum;
            }
            return newArray;
        } else {
            return array;
        }
    }
}