myclassifier.Util.java Source code

Java tutorial

Introduction

Here is the source code for myclassifier.Util.java

Source

/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package myclassifier;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Add;

/**
 *
 * @author Visat
 */
public class Util {
    public static int indexOfMax(double[] array) {
        double max = -Double.MAX_VALUE;
        int idx = -1;

        for (int i = 0; i < array.length; ++i) {
            if (Double.compare(array[i], max) > 0) {
                max = array[i];
                idx = i;
            }
        }
        return idx;
    }

    public static int indexOfMax(int[] array) {
        double max = Integer.MIN_VALUE;
        int idx = -1;

        for (int i = 0; i < array.length; ++i) {
            if (array[i] > max) {
                max = array[i];
                idx = i;
            }
        }
        return idx;
    }

    public static boolean equalValue(double a, double b) {
        final double epsilon = 1e-6;
        return ((a == b) || Math.abs(a - b) < epsilon);
    }

    public static double log2(double val) {
        return equalValue(val, 0) ? 0.0 : (Math.log(val) / Math.log(2));
    }

    public static void normalizeClassDistribution(double[] array) {
        double sum = 0;
        for (double d : array)
            sum += d;
        if (!Double.isNaN(sum) && sum != 0) {
            for (int i = 0; i < array.length; ++i)
                array[i] /= sum;
        }
    }

    public static double calculateE(Instances instances) {
        double[] labelCounts = new double[instances.numClasses()];
        for (int i = 0; i < instances.numInstances(); ++i)
            labelCounts[(int) instances.instance(i).classValue()]++;

        double entropy = 0.0;
        for (int i = 0; i < labelCounts.length; i++) {
            if (labelCounts[i] > 0) {
                double proportion = labelCounts[i] / instances.numInstances();
                entropy -= (proportion) * log2(proportion);
            }
        }
        return entropy;
    }

    public static double calculateGainRatio(Instances data, Attribute attribute) {
        double IG = calculateIG(data, attribute);
        double IV = calculateIntrinsicValue(data, attribute);
        if (IG == 0 || IV == 0)
            return 0;
        return IG / IV;
    }

    private static double calculateIntrinsicValue(Instances data, Attribute attribute) {
        double IV = 0;
        Instances[] splitData = splitData(data, attribute);
        for (int i = 0; i < attribute.numValues(); i++) {
            if (splitData[i].numInstances() > 0) {
                double proportion = (double) splitData[i].numInstances() / (double) data.numInstances();
                IV -= (proportion * Util.log2(proportion));
            }
        }
        return IV;
    }

    public static double calculateIG(Instances instances, Attribute attribute) {
        double IG = calculateE(instances);
        int missingCount = 0;
        Instances[] splitData = splitData(instances, attribute);
        for (int j = 0; j < attribute.numValues(); j++) {
            if (splitData[j].numInstances() > 0) {
                IG -= ((double) splitData[j].numInstances() / (double) instances.numInstances())
                        * calculateE(splitData[j]);
            }
        }

        for (int i = 0; i < instances.numInstances(); i++) {
            Instance instance = instances.instance(i);
            if (instance.isMissing(attribute))
                missingCount++;
        }
        return IG * (instances.numInstances() - missingCount / instances.numInstances());
    }

    public static double calculateIGCont(Instances instances, Attribute attribute, double threshold) {
        double gain = calculateE(instances);
        Instances[] split = splitDataCont(instances, attribute, threshold);
        for (int i = 0; i < 2; i++) {
            if (split[i].numInstances() > 0) {
                gain -= ((double) split[i].numInstances() / (double) instances.numInstances())
                        * Util.calculateE(split[i]);
            }
        }
        return gain;
    }

    public static Instances[] splitData(Instances instances, Attribute attribute) {
        Instances[] splittedData = new Instances[attribute.numValues()];

        for (int i = 0; i < attribute.numValues(); i++)
            splittedData[i] = new Instances(instances, instances.numInstances());

        for (int i = 0; i < instances.numInstances(); i++) {
            int attValue = (int) instances.instance(i).value(attribute);
            if (attValue >= 0 && attValue < attribute.numValues())
                splittedData[attValue].add(instances.instance(i));
        }
        //
        //        for (Instances currentSplitData: splittedData)
        //            currentSplitData.compactify();

        return splittedData;
    }

    public static Instances[] splitDataCont(Instances instances, Attribute attribute, double threshold) {
        Instances[] split = new Instances[2];
        for (int i = 0; i < 2; i++) {
            split[i] = new Instances(instances, instances.numInstances());
        }

        for (int i = 0; i < instances.numInstances(); i++) {
            double temp = instances.instance(i).value(attribute);
            if (temp < threshold) {
                split[0].add(instances.instance(i));
            } else {
                split[1].add(instances.instance(i));
            }
        }

        return split;
    }

    public static Instances setAttributeThreshold(Instances data, Attribute att, int threshold) throws Exception {
        Instances temp = new Instances(data);
        Add filter = new Add();
        filter.setAttributeName("thresholded " + att.name());
        filter.setAttributeIndex(String.valueOf(att.index() + 2));
        filter.setNominalLabels("<=" + threshold + ",>" + threshold);
        filter.setInputFormat(temp);

        Instances thresholdedData = Filter.useFilter(data, filter);

        for (int i = 0; i < thresholdedData.numInstances(); i++) {
            if ((int) thresholdedData.instance(i).value(thresholdedData.attribute(att.name())) <= threshold)
                thresholdedData.instance(i).setValue(thresholdedData.attribute("thresholded " + att.name()),
                        "<=" + threshold);
            else
                thresholdedData.instance(i).setValue(thresholdedData.attribute("thresholded " + att.name()),
                        ">" + threshold);
        }
        thresholdedData = wekaCode.removeAttributes(thresholdedData, String.valueOf(att.index() + 1));
        thresholdedData.renameAttribute(thresholdedData.attribute("thresholded " + att.name()), att.name());
        return thresholdedData;
    }

    public static Instances toNominal(Instances data) throws Exception {
        for (int n = 0; n < data.numAttributes(); n++) {
            Attribute att = data.attribute(n);
            if (data.attribute(n).isNumeric()) {
                HashSet<Integer> uniqueValues = new HashSet();
                for (int i = 0; i < data.numInstances(); ++i) {
                    uniqueValues.add((int) (data.instance(i).value(att)));
                }
                List<Integer> dataValues = new ArrayList<>(uniqueValues);
                dataValues.sort((Integer o1, Integer o2) -> {
                    if (o1 > o2)
                        return 1;
                    else
                        return -1;
                });

                double[] infoGains = new double[dataValues.size() - 1];
                Instances[] tempInstances = new Instances[dataValues.size() - 1];
                for (int i = 0; i < dataValues.size() - 1; ++i) {
                    tempInstances[i] = setAttributeThreshold(data, att, dataValues.get(i));
                    infoGains[i] = calculateIG(tempInstances[i], tempInstances[i].attribute(att.name()));
                }
                data = new Instances(tempInstances[Util.indexOfMax(infoGains)]);
            }
        }
        return data;
    }

    public static void percentageSplit(Instances data, Classifier cls) throws Exception {
        int trainSize = (int) Math.round(data.numInstances() * 0.8);
        int testSize = data.numInstances() - trainSize;
        Instances train = new Instances(data, 0, trainSize);
        Instances test = new Instances(data, trainSize, testSize);

        Evaluation eval = new Evaluation(train);
        eval.evaluateModel(cls, test);
    }

    public static double percentageSplitRate(Instances data, Classifier cls) throws Exception {
        int trainSize = (int) Math.round(data.numInstances() * 0.8);
        int testSize = data.numInstances() - trainSize;
        Instances train = new Instances(data, 0, trainSize);
        Instances test = new Instances(data, trainSize, testSize);

        Evaluation eval = new Evaluation(train);
        eval.evaluateModel(cls, test);
        return eval.pctCorrect();
    }
}