myid3andc45classifier.Model.MyC45.java Source code

Java tutorial

Introduction

Here is the source code for myid3andc45classifier.Model.MyC45.java

Source

/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package myid3andc45classifier.Model;

import java.util.ArrayList;
import static java.util.Collections.copy;
import java.util.Enumeration;
import static javafx.collections.FXCollections.copy;
import weka.classifiers.Classifier;
import weka.core.Attribute;
import weka.core.AttributeStats;
import weka.core.Capabilities;
import weka.core.Capabilities.Capability;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Utils;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Add;

/**
 * A C4.5-style decision tree built on Weka's {@link Classifier} API.
 *
 * <p>Numeric attributes are never split on directly. Instead, for every candidate cut point a
 * binary nominal attribute named {@code "<name> <threshold>"} is added, with the labels
 * {@code "<=threshold"} and {@code ">threshold"}. Candidate cut points are midpoints between
 * adjacent (sorted) values whose class labels differ. The tree then splits on the nominal
 * attribute with the highest information-gain ratio. Missing attribute values are replaced with
 * the attribute mean (numeric) or its most frequent value (nominal) before the tree is grown.
 *
 * @author ryanyonata
 */
public class MyC45 extends Classifier {

    /** Gain ratios at or below this value are treated as "no useful split". */
    private static final double EPSILON = 1e-6;

    /** Child nodes, one per value of {@link #attribute}; null at a leaf. */
    private MyC45[] successors;
    /** Attribute this node splits on; null when the node is a leaf. */
    private Attribute attribute;
    /** Predicted class index at a leaf; a missing value if the leaf was built from no data. */
    private double label;
    /** Per-class training counts at an internal node (read by {@link #maxDistribution()}). */
    private double[] distribution;
    /** Class attribute of the training data; used to render leaf labels. */
    private Attribute classAttribute;
    /** Currently unused by this class. */
    private Attribute splittedAttribute;
    /** Currently unused by this class. */
    private boolean pruned = false;

    /**
     * Builds the tree from {@code data}.
     *
     * <p>The input is copied first, so the caller's dataset is not modified. Instances with a
     * missing class are dropped, missing attribute values are imputed, and every numeric
     * attribute is expanded into binary threshold attributes before
     * {@link #makeMyC45Tree(Instances)} grows the tree.
     *
     * @param data training instances; must satisfy {@link #getCapabilities()}
     * @throws Exception if the data is not supported or a Weka filter fails
     */
    @Override
    public void buildClassifier(Instances data) throws Exception {
        getCapabilities().testWithFail(data);

        data = new Instances(data);
        data.deleteWithMissingClass();

        // Iterate over the header as it was before any threshold attributes were added. The Add
        // filter inserts at its default position (the last attribute), so the indices of these
        // original attributes stay valid in the expanded datasets.
        Enumeration<?> originalAttributes = data.enumerateAttributes();
        while (originalAttributes.hasMoreElements()) {
            Attribute attr = (Attribute) originalAttributes.nextElement();
            if (attr.isNumeric()) {
                // Impute before deriving cut points so missing values neither create NaN
                // thresholds nor land on the wrong side of the derived attributes.
                replaceMissingWithMean(data, attr);
                data = addThresholdAttributes(data, attr);
            } else {
                replaceMissingWithMode(data, attr);
            }
        }
        makeMyC45Tree(data);
    }

    /** Replaces missing values of numeric {@code attr}, in place, with its mean (0 if undefined). */
    private static void replaceMissingWithMean(Instances data, Attribute attr) {
        AttributeStats stats = data.attributeStats(attr.index());
        double mean = stats.numericStats.mean;
        if (Double.isNaN(mean)) {
            mean = 0;
        }
        replaceMissing(data, attr.index(), mean);
    }

    /** Replaces missing values of nominal {@code attr}, in place, with its most frequent value. */
    private static void replaceMissingWithMode(Instances data, Attribute attr) {
        AttributeStats stats = data.attributeStats(attr.index());
        int mode = 0;
        for (int i = 1; i < attr.numValues(); i++) {
            if (stats.nominalCounts[mode] < stats.nominalCounts[i]) {
                mode = i;
            }
        }
        replaceMissing(data, attr.index(), mode);
    }

    /** Sets attribute {@code index} to {@code value} on every instance where it is missing. */
    private static void replaceMissing(Instances data, int index, double value) {
        for (int i = 0; i < data.numInstances(); i++) {
            Instance instance = data.instance(i);
            if (instance.isMissing(index)) {
                instance.setValue(index, value);
            }
        }
    }

    /**
     * Sorts {@code data} by numeric {@code attr} and adds a binary threshold attribute for each
     * midpoint between adjacent instances with different classes, skipping thresholds that
     * already exist in the dataset.
     *
     * @return the dataset with the new threshold attributes appended
     */
    private Instances addThresholdAttributes(Instances data, Attribute attr) throws Exception {
        data.sort(attr);
        ArrayList<Double> thresholds = new ArrayList<>();
        for (int i = 0; i + 1 < data.numInstances(); i++) {
            Instance current = data.instance(i);
            Instance next = data.instance(i + 1);
            if (current.classValue() != next.classValue()) {
                double mid = (next.value(attr) + current.value(attr)) / 2;
                if (!thresholds.contains(mid) && data.attribute(attr.name() + " " + mid) == null) {
                    thresholds.add(mid);
                }
            }
        }
        for (double threshold : thresholds) {
            data = convertInstances(data, attr, threshold);
        }
        return data;
    }

    /**
     * Predicts the class of {@code instance} by walking the tree.
     *
     * <p>At a node that splits on a derived threshold attribute, the instance's original numeric
     * value is compared with the threshold (a missing value, being NaN, goes to the
     * {@code ">"} branch). Otherwise the instance's value of the split attribute selects the
     * branch.
     *
     * @return the predicted class index, or a missing value if the reached leaf had no data
     */
    @Override
    public double classifyInstance(Instance instance) {
        if (attribute == null) {
            return label;
        }

        int branch;
        int source = numericSourceIndex(instance);
        if (source >= 0) {
            double threshold = getThreshold(attribute);
            String side = instance.value(source) <= threshold ? "<=" + threshold : ">" + threshold;
            branch = attribute.indexOfValue(side);
        } else {
            branch = (int) instance.value(attribute);
        }
        return successors[branch].classifyInstance(instance);
    }

    /**
     * If this node's attribute is a derived threshold attribute ({@code "<name> <threshold>"}
     * with labels {@code "<=t"} and {@code ">t"}), returns the index in {@code instance} of the
     * numeric attribute it was derived from; otherwise returns -1.
     */
    private int numericSourceIndex(Instance instance) {
        String name = attribute.name();
        int separator = name.lastIndexOf(' ');
        if (separator < 0 || !attribute.isNominal() || attribute.numValues() != 2
                || !attribute.value(0).startsWith("<=")) {
            return -1;
        }
        String baseName = name.substring(0, separator);
        for (int j = 0; j < instance.numAttributes(); j++) {
            Attribute candidate = instance.attribute(j);
            if (candidate.isNumeric() && candidate.name().equalsIgnoreCase(baseName)) {
                return j;
            }
        }
        return -1;
    }

    /**
     * Grows the (sub)tree rooted at this node from already preprocessed {@code data}.
     *
     * <p>An empty dataset yields a leaf with a missing label. Otherwise the non-numeric attribute
     * with the highest information-gain ratio is chosen. If that ratio does not exceed
     * {@link #EPSILON}, the node becomes a leaf predicting the majority class. Each non-empty
     * branch is built recursively with {@link #buildClassifier(Instances)}; an empty branch
     * becomes a leaf predicting this node's majority class.
     *
     * @param data preprocessed training instances reaching this node
     * @throws Exception if building a child fails
     */
    public void makeMyC45Tree(Instances data) throws Exception {
        // Drop any subtree left over from a previous build of this node.
        successors = null;
        distribution = null;

        if (data.numInstances() == 0) {
            attribute = null;
            label = Instance.missingValue();
            return;
        }

        // Numeric attributes are excluded; their derived threshold attributes compete instead.
        double[] infoGainRatios = new double[data.numAttributes()];
        Enumeration<?> attEnum = data.enumerateAttributes();
        while (attEnum.hasMoreElements()) {
            Attribute att = (Attribute) attEnum.nextElement();
            infoGainRatios[att.index()] = att.isNumeric()
                    ? Double.NEGATIVE_INFINITY
                    : computeInfoGainRatio(data, att);
        }

        int best = maxIndex(infoGainRatios);
        double bestRatio = infoGainRatios[best];
        double[] classCounts = listClassCountsValues(data);
        classAttribute = data.classAttribute();

        if (bestRatio <= EPSILON || Double.isNaN(bestRatio)) {
            // No worthwhile split: make a leaf that predicts the majority class.
            attribute = null;
            label = maxIndex(classCounts);
            return;
        }

        attribute = data.attribute(best);
        distribution = classCounts;
        int majority = maxIndex(classCounts);
        Instances[] splitData = splitInstancesByAttribute(data, attribute);
        successors = new MyC45[attribute.numValues()];
        for (int j = 0; j < successors.length; j++) {
            MyC45 child = new MyC45();
            if (splitData[j].numInstances() == 0) {
                // No training data reaches this branch: predict this node's majority class
                // rather than a missing value.
                child.attribute = null;
                child.label = majority;
                child.classAttribute = classAttribute;
            } else {
                child.buildClassifier(splitData[j]);
            }
            successors[j] = child;
        }
        // TODO: prune (pruneTree is not invoked yet)
    }

    /**
     * Counts the training instances of each class.
     *
     * @return array of length {@code data.numClasses()} with per-class counts
     */
    public double[] listClassCountsValues(Instances data) throws Exception {
        double[] classCounts = new double[data.numClasses()];
        for (int i = 0; i < data.numInstances(); i++) {
            classCounts[(int) data.instance(i).classValue()]++;
        }
        return classCounts;
    }

    /** Returns the class entropy of {@code data} in bits (0 for an empty dataset). */
    public double computeEntropy(Instances data) throws Exception {
        double total = data.numInstances();
        double entropy = 0;
        for (double count : listClassCountsValues(data)) {
            if (count > 0) {
                double p = count / total;
                entropy -= p * Utils.log2(p);
            }
        }
        return entropy;
    }

    /**
     * Partitions {@code data} by the values of nominal {@code attr}.
     *
     * @return one dataset per value of {@code attr}, in value-index order (possibly empty)
     */
    public Instances[] splitInstancesByAttribute(Instances data, Attribute attr) throws Exception {
        Instances[] splitData = new Instances[attr.numValues()];
        for (int i = 0; i < splitData.length; i++) {
            splitData[i] = new Instances(data, data.numInstances());
        }
        for (int i = 0; i < data.numInstances(); i++) {
            Instance inst = data.instance(i);
            splitData[(int) inst.value(attr)].add(inst);
        }
        for (Instances part : splitData) {
            part.compactify();
        }
        return splitData;
    }

    /**
     * Computes the information-gain ratio of splitting {@code data} on nominal {@code attr}.
     *
     * <p>Empty partitions contribute nothing (0 · log 0 is taken as 0), so an attribute is not
     * disqualified merely because some of its values do not occur in {@code data}. A split that
     * does not divide the data (split information 0) has ratio 0.
     */
    public double computeInfoGainRatio(Instances data, Attribute attr) throws Exception {
        double total = data.numInstances();
        double remainder = 0;
        double splitInfo = 0;
        for (Instances part : splitInstancesByAttribute(data, attr)) {
            if (part.numInstances() == 0) {
                continue;
            }
            double p = part.numInstances() / total;
            remainder += p * computeEntropy(part);
            splitInfo -= p * Utils.log2(p);
        }
        if (splitInfo <= 0) {
            return 0;
        }
        return (computeEntropy(data) - remainder) / splitInfo;
    }

    /** Computes the information gain (in bits) of splitting {@code data} on nominal {@code attr}. */
    public double computeInfoGain(Instances data, Attribute attr) throws Exception {
        double total = data.numInstances();
        double remainder = 0;
        for (Instances part : splitInstancesByAttribute(data, attr)) {
            if (part.numInstances() > 0) {
                remainder += (part.numInstances() / total) * computeEntropy(part);
            }
        }
        return computeEntropy(data) - remainder;
    }

    /**
     * Declares support for nominal, numeric and date attributes, missing values, and a nominal
     * class (possibly missing); zero training instances are allowed.
     */
    @Override
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAll();

        // attributes
        result.enable(Capability.NOMINAL_ATTRIBUTES);
        result.enable(Capability.NUMERIC_ATTRIBUTES);
        result.enable(Capability.DATE_ATTRIBUTES);
        result.enable(Capability.MISSING_VALUES);

        // class
        result.enable(Capability.NOMINAL_CLASS);
        result.enable(Capability.MISSING_CLASS_VALUES);

        // instances
        result.setMinimumNumberInstances(0);

        return result;
    }

    /**
     * Returns a copy of {@code data} with a binary nominal attribute {@code "<name> <threshold>"}
     * added by Weka's {@link Add} filter. Its value is {@code "<=threshold"} when the instance's
     * value of {@code att} is at most {@code threshold}, and {@code ">threshold"} otherwise.
     *
     * @throws Exception if the Add filter fails
     */
    private Instances convertInstances(Instances data, Attribute att, double threshold)
            throws Exception {
        String derivedName = att.name() + " " + threshold;
        String lowLabel = "<=" + threshold;
        String highLabel = ">" + threshold;

        Add filter = new Add();
        filter.setNominalLabels(lowLabel + "," + highLabel);
        filter.setAttributeName(derivedName);
        filter.setInputFormat(data);
        Instances newData = Filter.useFilter(data, filter);

        int sourceIndex = att.index();
        Attribute derived = newData.attribute(derivedName);
        for (int i = 0; i < newData.numInstances(); i++) {
            Instance inst = newData.instance(i);
            inst.setValue(derived, inst.value(sourceIndex) <= threshold ? lowLabel : highLabel);
        }
        return newData;
    }

    /** Returns true if {@code a} and {@code b} differ by less than {@link #EPSILON}. */
    private boolean isDoubleEqual(double a, double b) {
        return (a == b) || Math.abs(a - b) < EPSILON;
    }

    /**
     * Returns the index of the first strictly largest positive entry of {@code array}. Returns 0
     * if no entry is positive (NaN entries are never chosen), and -1 for an empty array.
     */
    private int maxIndex(double[] array) {
        if (array.length == 0) {
            return -1;
        }
        double max = 0;
        int index = 0;
        for (int i = 0; i < array.length; ++i) {
            if (array[i] > max) {
                max = array[i];
                index = i;
            }
        }
        return index;
    }

    /** Renders the subtree rooted at this node, indenting each level with {@code "|  "}. */
    public String toString(int level) {
        StringBuilder text = new StringBuilder();
        if (attribute == null) {
            if (Instance.isMissingValue(label)) {
                text.append(": null");
            } else {
                text.append(": ").append(classAttribute.value((int) label));
            }
        } else {
            for (int i = 0; i < attribute.numValues(); i++) {
                text.append("\n");
                for (int j = 0; j < level; j++) {
                    text.append("|  ");
                }
                text.append(attribute.name()).append(" = ").append(attribute.value(i));
                text.append(successors[i].toString(level + 1));
            }
        }
        return text.toString();
    }

    /** Renders the whole tree, or a placeholder message if no model has been built. */
    @Override
    public String toString() {
        if (!isBuilt()) {
            return "C45: No model built yet.";
        }
        return "C45\n\n" + toString(0);
    }

    /**
     * Returns true once {@link #makeMyC45Tree(Instances)} has run on this node. This covers a
     * tree whose root is a single leaf, which has no successors.
     */
    private boolean isBuilt() {
        return attribute != null || classAttribute != null || Instance.isMissingValue(label);
    }

    /** Parses the threshold from a derived attribute's first label ({@code "<=threshold"}). */
    private double getThreshold(Attribute attr) {
        return Double.parseDouble(attr.value(0).replace("<=", ""));
    }

    /** Returns true if the tree's prediction for {@code instance} matches its class value. */
    public boolean checkInstance(Instance instance) {
        double cv = instance.classValue();
        return isDoubleEqual(cv, classifyInstance(instance));
    }

    /**
     * Returns the fraction of {@code instances} this (sub)tree misclassifies. The result is NaN
     * when {@code instances} is empty.
     */
    public double countError(Instances instances) {
        int wrong = 0;
        int total = instances.numInstances();
        for (int i = 0; i < total; i++) {
            if (!checkInstance(instances.instance(i))) {
                wrong++;
            }
        }
        return (double) wrong / (double) total;
    }

    /**
     * Tentatively collapses this internal node into a leaf predicting its majority class. The
     * subtree is restored if that increases the error on {@code data}; otherwise the node stays
     * collapsed and "Pruned" is printed. Only this node is considered (no recursion into
     * children).
     */
    public void pruneTree(Instances data) throws Exception {
        if (successors == null) {
            return;
        }
        double error = countError(data);

        // Temporarily turn this node into a leaf.
        MyC45[] savedSuccessors = successors.clone();
        Attribute savedAttribute = attribute;
        for (int i = 0; i < successors.length; i++) {
            successors[i] = null;
        }
        label = maxDistribution();
        attribute = null;

        double prunedError = countError(data);
        if (error < prunedError) {
            // Collapsing made things worse: restore the subtree.
            label = Double.NaN;
            System.arraycopy(savedSuccessors, 0, successors, 0, successors.length);
            attribute = savedAttribute;
        } else {
            System.out.println("Pruned");
        }
    }

    /** Returns the index of the first largest entry of {@link #distribution}. */
    private int maxDistribution() {
        double max = Double.NEGATIVE_INFINITY;
        int idx = 0;
        for (int i = 0; i < distribution.length; i++) {
            if (distribution[i] > max) {
                idx = i;
                max = distribution[i];
            }
        }
        return idx;
    }
}