Java tutorial
/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package myID3;

import WekaInterface.Weka;
import java.io.IOException;
import java.util.Enumeration;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.Capabilities.Capability;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.NoSupportForMissingValuesException;
import weka.core.Utils;

/**
 * ID3 decision tree classifier for nominal attributes and a nominal class.
 *
 * <p>Each object of this class is one node of the tree. An inner node stores
 * the attribute it splits on plus one child per attribute value; a leaf stores
 * the predicted class value and the class distribution observed while
 * training.
 *
 * @author ahmadshahab
 */
public class MyId3 extends AbstractClassifier {

    /** Child nodes, one per value of {@link #currentAttribute}; unused for a leaf. */
    private MyId3[] nodes;

    /** Attribute this node splits on; {@code null} marks the node as a leaf. */
    private Attribute currentAttribute;

    /** Predicted class index for a leaf; missing when the leaf saw no training data. */
    private double classValue;

    /** Class distribution for a leaf (normalized unless the leaf saw no data). */
    private double[] classDistribution;

    /** Class attribute of the training data for a leaf; used to print the class label. */
    private Attribute classAttribute;

    /**
     * Builds an Id3 classifier.
     *
     * @param instances dataset used for building the model
     * @throws Exception if the data fails the capability test
     */
    @Override
    public void buildClassifier(Instances instances) throws Exception {
        // Can Id3 handle this kind of data at all?
        getCapabilities().testWithFail(instances);

        // Work on a copy so the caller's dataset is not modified.
        Instances data = new Instances(instances);
        data.deleteWithMissingClass();

        buildTree(data);
    }

    /**
     * Constructs the (sub)tree rooted at this node from the given data.
     *
     * <p>The attribute with the highest information gain becomes the split
     * attribute. When the best gain is zero, or the data is empty, the node
     * becomes a leaf.
     *
     * @param data training instances that reach this node
     */
    public void buildTree(Instances data) {
        // Forget whatever an earlier build left behind. Without this, rebuilding
        // on empty data would keep the previous split attribute and children.
        nodes = null;
        currentAttribute = null;
        classAttribute = null;
        classDistribution = null;
        classValue = Utils.missingValue();

        if (data.numInstances() == 0) {
            // Empty leaf: no class can be predicted.
            classDistribution = new double[data.numClasses()];
            return;
        }

        // Information gain per attribute, indexed by attribute index. The slot
        // of the class attribute is never written and therefore stays 0.
        double[] gains = new double[data.numAttributes()];
        Enumeration<Attribute> attributes = data.enumerateAttributes();
        while (attributes.hasMoreElements()) {
            Attribute attribute = attributes.nextElement();
            gains[attribute.index()] = informationGain(data, attribute);
        }

        Attribute bestAttribute = data.attribute(maxIndex(gains));

        if (Utils.eq(gains[bestAttribute.index()], 0)) {
            // No attribute separates the classes any further: become a leaf
            // that predicts the most frequent class.
            classDistribution = new double[data.numClasses()];
            for (int i = 0; i < data.numInstances(); i++) {
                classDistribution[(int) data.instance(i).classValue()]++;
            }
            Utils.normalize(classDistribution);
            classValue = Utils.maxIndex(classDistribution);
            classAttribute = data.classAttribute();
        } else {
            // Inner node: one child per value of the chosen attribute.
            currentAttribute = bestAttribute;
            Instances[] splitData = splitDataByAttribute(data, currentAttribute);
            nodes = new MyId3[currentAttribute.numValues()];
            for (int i = 0; i < nodes.length; i++) {
                nodes[i] = new MyId3();
                nodes[i].buildTree(splitData[i]);
            }
        }
    }

    /**
     * Computes the information gain of an attribute on the given dataset.
     *
     * <p>Information gain = entropy before the split minus the size-weighted
     * entropy of the subsets after the split.
     *
     * @param data dataset to evaluate; must contain at least one instance
     * @param attribute attribute to split on
     * @return the information gain
     */
    private double informationGain(Instances data, Attribute attribute) {
        double gain = entropy(data);
        Instances[] subsets = splitDataByAttribute(data, attribute);

        for (Instances subset : subsets) {
            // Empty subsets contribute nothing; calling entropy() on them
            // would divide by zero.
            if (subset.numInstances() > 0) {
                double weight = (double) subset.numInstances() / data.numInstances();
                gain -= entropy(subset) * weight;
            }
        }
        return gain;
    }

    /**
     * Computes the class entropy of a dataset.
     *
     * <p>Entropy = -(p1 log2 p1) - (p2 log2 p2) - ...
     *
     * @param data dataset to evaluate; must contain at least one instance
     * @return the entropy of the class distribution
     */
    private double entropy(Instances data) {
        double numInstances = data.numInstances();
        double[] classCounts = new double[data.numClasses()];

        for (int i = 0; i < data.numInstances(); i++) {
            classCounts[(int) data.instance(i).classValue()]++;
        }

        double sum = 0;
        for (double count : classCounts) {
            double probability = count / numInstances;
            // A class that never occurs contributes 0 (and log2(0) is undefined).
            if (probability > 0.0) {
                sum += probability * Utils.log2(probability);
            }
        }
        return -1 * sum;
    }

    /**
     * Splits a dataset into one subset per value of the given attribute.
     *
     * @param data dataset to split
     * @param attribute nominal attribute whose values define the subsets
     * @return array of subsets, indexed by attribute value index
     */
    private Instances[] splitDataByAttribute(Instances data, Attribute attribute) {
        Instances[] subsets = new Instances[attribute.numValues()];
        for (int i = 0; i < subsets.length; i++) {
            // Same header as data, with capacity for the worst case.
            subsets[i] = new Instances(data, data.numInstances());
        }

        for (int i = 0; i < data.numInstances(); i++) {
            Instance instance = data.instance(i);
            subsets[(int) instance.value(attribute)].add(instance);
        }

        // Release the capacity that was reserved but not used.
        for (Instances subset : subsets) {
            subset.compactify();
        }
        return subsets;
    }

    /**
     * Returns the capabilities of this Id3 classifier: nominal attributes,
     * nominal class, missing class values allowed, no minimum instance count.
     *
     * @return the capabilities of this classifier
     */
    @Override
    public Capabilities getCapabilities() {
        Capabilities capabilities = super.getCapabilities();
        capabilities.disableAll();

        // Attribute type capability
        capabilities.enable(Capability.NOMINAL_ATTRIBUTES);

        // Class capability
        capabilities.enable(Capability.NOMINAL_CLASS);
        capabilities.enable(Capability.MISSING_CLASS_VALUES);

        // Minimum number of instances allowed to be used
        capabilities.setMinimumNumberInstances(0);

        return capabilities;
    }

    /**
     * Returns the index of the largest non-negative value in the array.
     *
     * <p>Ties are resolved in favour of the last index. NaN entries and
     * negative entries are never selected; if nothing qualifies, index 0 is
     * returned. The array is not modified.
     *
     * @param arr values to scan
     * @return index of the selected value
     */
    private int maxIndex(double[] arr) {
        double maxValue = 0;
        int maxIndex = 0;
        for (int i = 0; i < arr.length; i++) {
            // NaN never compares equal to anything (not even itself), so it
            // has to be detected with Double.isNaN.
            if (Double.isNaN(arr[i])) {
                continue;
            }
            if (maxValue <= arr[i]) {
                maxValue = arr[i];
                maxIndex = i;
            }
        }
        return maxIndex;
    }

    /**
     * Classifies a given test instance using the decision tree.
     *
     * @param instance the instance to be classified
     * @return the classification
     * @throws NoSupportForMissingValuesException if instance has missing values
     */
    @Override
    public double classifyInstance(Instance instance)
            throws NoSupportForMissingValuesException {
        if (instance.hasMissingValue()) {
            throw new NoSupportForMissingValuesException("Id3: no missing values, please.");
        }
        if (currentAttribute == null) {
            return classValue;
        }
        return nodes[(int) instance.value(currentAttribute)].classifyInstance(instance);
    }

    /**
     * Computes class distribution for instance using decision tree.
     *
     * @param instance the instance for which distribution is to be computed
     * @return the class distribution for the given instance
     * @throws NoSupportForMissingValuesException if instance has missing values
     */
    @Override
    public double[] distributionForInstance(Instance instance)
            throws NoSupportForMissingValuesException {
        if (instance.hasMissingValue()) {
            throw new NoSupportForMissingValuesException("Id3: no missing values, please.");
        }
        if (currentAttribute == null) {
            return classDistribution;
        }
        return nodes[(int) instance.value(currentAttribute)].distributionForInstance(instance);
    }

    /**
     * Prints the decision tree using {@link #printTree(int)}.
     *
     * @return a textual description of the classifier
     */
    @Override
    public String toString() {
        if ((classDistribution == null) && (nodes == null)) {
            return "Id3: No model built yet.";
        }
        return "Id3\n\n" + printTree(0);
    }

    /**
     * Renders the subtree rooted at this node as text.
     *
     * @param level depth of this node; controls the indentation prefix
     * @return textual representation of the subtree
     */
    public String printTree(int level) {
        StringBuilder text = new StringBuilder();

        if (currentAttribute == null) {
            // Leaf: print the predicted class, or "null" for an empty leaf.
            if (Utils.isMissingValue(classValue)) {
                text.append(": null");
            } else {
                text.append(": ").append(classAttribute.value((int) classValue));
            }
            return text.toString();
        }

        for (int j = 0; j < currentAttribute.numValues(); j++) {
            text.append("\n");
            for (int i = 0; i < level; i++) {
                text.append("| ");
            }
            text.append(currentAttribute.name()).append(" = ").append(currentAttribute.value(j));
            text.append(nodes[j].printTree(level + 1));
        }
        return text.toString();
    }

    /**
     * Demo entry point: trains the classifier on "weather.nominal.arff" and
     * prints the resulting tree.
     *
     * @param args command line arguments (unused)
     * @throws IOException declared by the original signature
     * @throws Exception if building the classifier fails
     */
    public static void main(String[] args) throws IOException, Exception {
        // NOTE(review): Weka is a project helper; presumably it loads the ARFF
        // file and exposes it as Instances - confirm against WekaInterface.Weka.
        Weka a = new Weka();
        a.setTraining("weather.nominal.arff");

        Classifier b = new MyId3();
        b.buildClassifier(a.getM_Training());

        System.out.println(b.toString());
    }
}