org.wkwk.classifier.MyC45.java Source code

Java tutorial

Introduction

Here is the source code for org.wkwk.classifier.MyC45.java

Source

package org.wkwk.classifier;

import java.util.ArrayList;
import java.util.Enumeration;
import weka.classifiers.AbstractClassifier;
import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Utils;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Remove;

/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */

/**
 * A C4.5-style decision tree classifier for nominal and numeric attributes.
 *
 * <p>Each node is either a leaf ({@code splitAttribute == null}) that predicts {@code classValue},
 * or an internal node that splits on the attribute with the highest information gain. A nominal
 * split attribute gets one successor per value. A numeric split attribute gets two successors:
 * values below {@code attrThreshold}, and all others.
 *
 * @author adarwawan
 */
public class MyC45 extends AbstractClassifier {

    /** The node's successors; {@code null} for a leaf. */
    private MyC45[] successors;

    /** Attribute used for splitting at this node; {@code null} if the node is a leaf. */
    private Attribute splitAttribute;

    /** Class value predicted if the node is a leaf (missing value for a leaf that saw no data). */
    private double classValue;

    /** Normalized class distribution if the node is a leaf (all zeros for a leaf with no data). */
    private double[] classDistribution;

    /**
     * Intended to enable pruning of the tree.
     *
     * <p>NOTE(review): currently not consulted. {@link #prune(Instances)} filters low-gain
     * attributes out of a dataset rather than pruning the tree. Calling it and discarding its
     * result has no effect on the tree, so {@link #makeTree(Instances)} does not call it.
     */
    private boolean isPruned = true;

    /** Class attribute of the training data; set on leaves. */
    private Attribute classAttribute;

    /**
     * Split threshold when {@code splitAttribute} is numeric. Values {@code < attrThreshold} go to
     * {@code successors[0]}. All others, including missing (NaN) values, go to
     * {@code successors[1]}.
     */
    private double attrThreshold;

    /**
     * Builds the tree from the given training data.
     *
     * <p>The data is copied first, so the caller's dataset is not modified. Missing-value filling,
     * instance deletion and sorting all happen on the copy. Instances with a missing class are
     * removed. Missing nominal values are then replaced by the attribute's mode, and missing
     * numeric values by its mean (see {@link #fillMissingValue(Instances, Attribute)}).
     *
     * @param data the training data
     * @throws Exception if the data fails the capability test or the tree cannot be built
     */
    @Override
    public void buildClassifier(Instances data) throws Exception {
        getCapabilities().testWithFail(data);

        // Instances' copy constructor copies every instance, so edits below stay local.
        Instances train = new Instances(data);
        train.deleteWithMissingClass();

        for (int a = 0; a < train.numAttributes(); a++) {
            Attribute attr = train.attribute(a);
            if (a == train.classIndex() || !(attr.isNominal() || attr.isNumeric())) {
                continue;
            }
            // Compute the replacement once per attribute, before any cell is overwritten.
            double replacement = fillMissingValue(train, attr);
            for (int i = 0; i < train.numInstances(); i++) {
                Instance instance = train.instance(i);
                if (instance.isMissing(attr)) {
                    instance.setValue(attr, replacement);
                }
            }
        }

        makeTree(train);
    }

    /**
     * Returns the value used to replace missing values of {@code attr}.
     *
     * <p>For a nominal attribute this is the index of its most frequent value. For a numeric
     * attribute it is the mean. Missing values are ignored when computing either statistic.
     *
     * @param data the dataset to compute the statistic over
     * @param attr the attribute whose replacement value is wanted
     * @return the mode index (nominal) or mean (numeric) of {@code attr} in {@code data}
     */
    public double fillMissingValue(Instances data, Attribute attr) {
        return data.meanOrMode(attr);
    }

    /**
     * Classifies an instance by walking the tree from this node down to a leaf.
     *
     * <p>A missing nominal split value is cast to branch 0 ({@code (int) NaN == 0}). A missing
     * numeric split value fails the {@code <} test and goes to {@code successors[1]}.
     *
     * @param data the instance to classify
     * @return the predicted class index, a missing value if the reached leaf saw no training data,
     *     or -1 if the split attribute is neither nominal nor numeric
     */
    @Override
    public double classifyInstance(Instance data) {
        if (splitAttribute == null) {
            return classValue;
        }
        if (splitAttribute.isNominal()) {
            return successors[(int) data.value(splitAttribute)].classifyInstance(data);
        }
        if (splitAttribute.isNumeric()) {
            int branch = data.value(splitAttribute) < attrThreshold ? 0 : 1;
            return successors[branch].classifyInstance(data);
        }
        return -1;
    }

    /**
     * Recursively grows the (sub)tree rooted at this node from {@code data}.
     *
     * <p>An empty {@code data} yields a leaf predicting a missing value. Otherwise the attribute
     * with the largest information gain is chosen, and the node becomes a leaf when that gain is
     * zero. Note that {@code data} may be re-sorted as a side effect of the numeric threshold
     * search.
     *
     * @param data the instances that reach this node
     * @throws Exception propagated from tree construction of the successors
     */
    public void makeTree(Instances data) throws Exception {
        successors = null;
        if (data.numInstances() == 0) {
            // An empty subset (e.g. a nominal value never seen in this branch) has nothing to
            // predict. Returning here also avoids indexing into empty data below.
            splitAttribute = null;
            classAttribute = data.classAttribute();
            classDistribution = new double[data.numClasses()];
            classValue = Utils.missingValue();
            return;
        }

        int numAttributes = data.numAttributes();
        double[] infoGains = new double[numAttributes];
        double[] thresholds = new double[numAttributes];
        for (int i = 0; i < numAttributes; i++) {
            if (i == data.classIndex()) {
                continue; // never split on the class itself, wherever it is positioned
            }
            Attribute attr = data.attribute(i);
            if (attr.isNominal()) {
                infoGains[i] = computeInfoGain(data, attr);
            } else if (attr.isNumeric()) {
                // Remember the threshold so it is not searched for a second time below.
                thresholds[i] = bestThreshold(data, attr);
                infoGains[i] = computeInfoGainCont(data, attr, thresholds[i]);
            }
        }

        int best = Utils.maxIndex(infoGains);
        if (Utils.eq(infoGains[best], 0)) {
            makeLeaf(data);
            return;
        }

        // Only nominal and numeric attributes can have a non-zero gain here.
        splitAttribute = data.attribute(best);
        Instances[] subsets;
        if (splitAttribute.isNumeric()) {
            attrThreshold = thresholds[best];
            subsets = splitDataCont(data, splitAttribute, attrThreshold);
        } else {
            subsets = splitData(data, splitAttribute);
        }

        successors = new MyC45[subsets.length];
        for (int i = 0; i < subsets.length; i++) {
            successors[i] = new MyC45();
            successors[i].makeTree(subsets[i]);
        }
    }

    /** Turns this node into a leaf predicting the majority class of non-empty {@code data}. */
    private void makeLeaf(Instances data) {
        splitAttribute = null;
        successors = null;
        classAttribute = data.classAttribute();
        classDistribution = new double[data.numClasses()];
        for (int i = 0; i < data.numInstances(); i++) {
            classDistribution[(int) data.instance(i).classValue()]++;
        }
        Utils.normalize(classDistribution);
        classValue = Utils.maxIndex(classDistribution);
    }

    /**
     * Removes every non-class attribute whose information gain on {@code data} is below 1.0.
     *
     * <p>This is an attribute filter, not tree pruning. Numeric attributes are evaluated at their
     * best threshold, which re-sorts {@code data} in place.
     *
     * @param data the dataset to filter
     * @return {@code data} itself if nothing is removed, otherwise a filtered copy
     * @throws Exception if the {@code Remove} filter fails
     */
    public Instances prune(Instances data) throws Exception {
        return prune(data, 1.0);
    }

    /**
     * Removes every non-class attribute whose information gain on {@code data} is below
     * {@code minInfoGain}.
     *
     * @param data the dataset to filter (may be re-sorted in place by the threshold search)
     * @param minInfoGain attributes with an information gain strictly below this are removed
     * @return {@code data} itself if nothing is removed, otherwise a filtered copy
     * @throws Exception if the {@code Remove} filter fails
     */
    public Instances prune(Instances data, double minInfoGain) throws Exception {
        StringBuilder lowGainIndices = new StringBuilder();
        for (int i = 0; i < data.numAttributes(); i++) {
            if (i == data.classIndex()) {
                continue;
            }
            Attribute att = data.attribute(i);
            double infoGain = att.isNominal()
                    ? computeInfoGain(data, att)
                    : computeInfoGainCont(data, att, bestThreshold(data, att));
            if (infoGain < minInfoGain) {
                if (lowGainIndices.length() > 0) {
                    lowGainIndices.append(',');
                }
                lowGainIndices.append(i + 1); // the Remove filter takes 1-based indices
            }
        }
        if (lowGainIndices.length() == 0) {
            return data;
        }
        return removeAttr(data, lowGainIndices.toString());
    }

    /**
     * Information gain of splitting {@code data} on nominal {@code attr}, with one subset per
     * value.
     */
    public double computeInfoGain(Instances data, Attribute attr) {
        return splitInfoGain(data, splitData(data, attr));
    }

    /**
     * Information gain of splitting {@code data} on numeric {@code attr} into
     * {@code value < threshold} and the rest.
     */
    public double computeInfoGainCont(Instances data, Attribute attr, double threshold) {
        return splitInfoGain(data, splitDataCont(data, attr, threshold));
    }

    /**
     * Entropy of {@code data} minus the size-weighted entropies of {@code subsets}. Empty subsets
     * contribute nothing.
     */
    private double splitInfoGain(Instances data, Instances[] subsets) {
        double infoGain = computeEntropy(data);
        double total = data.numInstances();
        for (Instances subset : subsets) {
            if (subset.numInstances() > 0) {
                infoGain -= (subset.numInstances() / total) * computeEntropy(subset);
            }
        }
        return infoGain;
    }

    /** Class entropy (base 2) of {@code data}; 0 for empty data. */
    public double computeEntropy(Instances data) {
        // Count how often each class occurs.
        double[] classCounts = new double[data.numClasses()];
        for (int i = 0; i < data.numInstances(); i++) {
            classCounts[(int) data.instance(i).classValue()]++;
        }
        return entropyOf(classCounts, data.numInstances());
    }

    /** Base-2 entropy of the class counts {@code counts}, which sum to {@code total}. */
    private static double entropyOf(double[] counts, double total) {
        double entropy = 0;
        for (double count : counts) {
            if (count > 0) {
                double p = count / total;
                entropy -= p * Utils.log2(p);
            }
        }
        return entropy;
    }

    /**
     * Partitions {@code data} by the value of nominal {@code attr}. A missing value is cast to
     * index 0, so it lands in the first subset.
     */
    public Instances[] splitData(Instances data, Attribute attr) {
        Instances[] splitData = new Instances[attr.numValues()];
        for (int i = 0; i < splitData.length; i++) {
            splitData[i] = new Instances(data, data.numInstances());
        }
        for (int i = 0; i < data.numInstances(); i++) {
            Instance inst = data.instance(i);
            splitData[(int) inst.value(attr)].add(inst);
        }
        return splitData;
    }

    /**
     * Partitions {@code data} into {@code [value < threshold, everything else]} for numeric
     * {@code attr}. Missing (NaN) values go to the second subset.
     */
    public Instances[] splitDataCont(Instances data, Attribute attr, double threshold) {
        Instances[] splitData = new Instances[2];
        for (int i = 0; i < 2; i++) {
            splitData[i] = new Instances(data, data.numInstances());
        }
        for (int i = 0; i < data.numInstances(); i++) {
            Instance inst = data.instance(i);
            splitData[inst.value(attr) < threshold ? 0 : 1].add(inst);
        }
        return splitData;
    }

    /**
     * Finds the numeric threshold on {@code attr} that maximizes information gain.
     *
     * <p>{@code data} is sorted by {@code attr} in place (missing values last). Every midpoint
     * between two adjacent, distinct, non-missing values is evaluated, using running class counts
     * rather than materializing the two subsets. Missing values are counted on the upper side,
     * matching {@link #splitDataCont}.
     *
     * @return the best midpoint, or 0 if {@code data} is empty or no threshold has positive gain
     */
    public double bestThreshold(Instances data, Attribute attr) {
        double bestThr = 0;
        int n = data.numInstances();
        if (n == 0) {
            return bestThr;
        }
        data.sort(attr);

        double[] below = new double[data.numClasses()];
        double[] above = new double[data.numClasses()];
        for (int i = 0; i < n; i++) {
            above[(int) data.instance(i).classValue()]++;
        }
        double parentEntropy = entropyOf(above, n);

        double bestGain = 0;
        for (int i = 0; i < n - 1; i++) {
            Instance curr = data.instance(i);
            Instance next = data.instance(i + 1);
            if (curr.isMissing(attr) || next.isMissing(attr)) {
                break; // missing values are sorted to the end
            }
            int cls = (int) curr.classValue();
            below[cls]++;
            above[cls]--;

            double lo = curr.value(attr);
            double hi = next.value(attr);
            if (lo < hi) { // only cut between distinct values
                int nBelow = i + 1;
                int nAbove = n - nBelow;
                double gain = parentEntropy
                        - ((double) nBelow / n) * entropyOf(below, nBelow)
                        - ((double) nAbove / n) * entropyOf(above, nAbove);
                if (gain > bestGain) {
                    bestGain = gain;
                    bestThr = lo + (hi - lo) / 2;
                }
            }
        }
        return bestThr;
    }

    /**
     * Removes the attributes listed in {@code attr} from {@code data} using Weka's
     * {@code Remove} filter.
     *
     * @param data the dataset to filter
     * @param attr comma-separated, 1-based attribute indices to delete
     * @return the filtered dataset
     * @throws Exception if the filter rejects the indices or the data
     */
    private Instances removeAttr(Instances data, String attr) throws Exception {
        Remove remove = new Remove();
        remove.setAttributeIndices(attr);
        remove.setInputFormat(data);
        return Filter.useFilter(data, remove);
    }
}