aw_cluster.myKMeans.java Source code

Java tutorial

Introduction

Here is the source code for aw_cluster.myKMeans.java

Source

/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package aw_cluster;

import java.util.HashSet;
import java.util.Random;
import weka.clusterers.AbstractClusterer;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.Capabilities.Capability;
import weka.core.DistanceFunction;
import weka.core.EuclideanDistance;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Utils;

/**
 *
 * @author aodyra
 */
public class myKMeans extends AbstractClusterer {

    private int numCluster = 2;
    private Instances centroid;
    private int[] sizeEachCluster;
    private int[] assignments;
    private int numIteration = 0;
    private int seedRandom;
    private int maxIteration = 500;
    private double[] squaredError;
    private DistanceFunction distanceFunction = new EuclideanDistance();

    public myKMeans() {
        // default seed = 10
        seedRandom = 10;
    }

    public myKMeans(int seed) {
        seedRandom = 10;
    }

    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAll();
        result.enable(Capability.NO_CLASS);

        result.enable(Capability.NOMINAL_ATTRIBUTES);
        result.enable(Capability.NUMERIC_ATTRIBUTES);
        result.enable(Capability.MISSING_VALUES);

        return result;
    }

    @Override
    public void buildClusterer(Instances data) throws Exception {
        getCapabilities().testWithFail(data);

        Instances instances = new Instances(data);
        instances.setClassIndex(-1);

        if (instances.numInstances() == 0) {
            throw new RuntimeException("Dataset should not be empty");
        }

        assignments = new int[instances.numInstances()];
        centroid = new Instances(instances, numCluster);
        distanceFunction.setInstances(instances);
        squaredError = new double[numCluster];

        // Initialize Centroid Random From seed
        Random random = new Random(getSeedRandom());
        Instances tempInstances = new Instances(instances);

        int tI = tempInstances.numInstances() - 1;
        while (tI >= 0 && centroid.numInstances() < numCluster) {
            int indexCentroid = random.nextInt(tI + 1);
            centroid.add(tempInstances.instance(indexCentroid));
            tempInstances.swap(tI, indexCentroid);
            tI--;
        }

        tempInstances = null;

        boolean converged = false;
        while (!converged) {
            converged = true;
            numIteration++;
            for (int i = 0; i < instances.numInstances(); ++i) {
                Instance toCluster = instances.instance(i);
                int clusterResult = clusterInstanceProcess(toCluster, true);
                if (clusterResult != assignments[i])
                    converged = false;
                assignments[i] = clusterResult;
            }

            // update centroid
            Instances[] TempI = new Instances[numCluster];
            centroid = new Instances(instances, numCluster);
            for (int i = 0; i < TempI.length; ++i) {
                TempI[i] = new Instances(instances, 0);
            }
            for (int i = 0; i < instances.numInstances(); ++i) {
                TempI[assignments[i]].add(instances.instance(i));
            }
            for (int i = 0; i < TempI.length; ++i) {
                moveCentroid(TempI[i]);
            }
            if (converged)
                squaredError = new double[numCluster];
            if (numIteration == maxIteration)
                converged = true;
            sizeEachCluster = new int[numCluster];
            for (int i = 0; i < numCluster; ++i) {
                sizeEachCluster[i] = TempI[i].numInstances();
            }

        }
    }

    protected double[] moveCentroid(Instances members) {
        double[] vals = new double[members.numAttributes()];

        for (int j = 0; j < members.numAttributes(); j++) {
            vals[j] = members.meanOrMode(j);
        }
        centroid.add(new Instance(1.0, vals));

        return vals;
    }

    public int clusterInstanceProcess(Instance toCluster, boolean updateError) {
        double distance;
        double distanceMin = 0;
        int indexCluster = 0;
        for (int i = 0; i < numCluster; ++i) {
            distance = distanceFunction.distance(toCluster, centroid.instance(i));
            if (i == 0) {
                distanceMin = distance;
                indexCluster = i;
            }
            if (distance < distanceMin) {
                distanceMin = distance;
                indexCluster = i;
            }
        }
        if (updateError) {
            distanceMin *= distanceMin;
            squaredError[indexCluster] += distanceMin;
        }
        return indexCluster;
    }

    public int clusterInstance(Instance instance) {
        return clusterInstanceProcess(instance, false);
    }

    @Override
    public int numberOfClusters() throws Exception {
        return numCluster;
    }

    public int[] getAssignment() {
        return assignments;
    }

    public Instances getCentroid() {
        return centroid;
    }

    public int[] getSizeEachCluster() {
        return sizeEachCluster;
    }

    public int getNumIteration() {
        return numIteration;
    }

    public DistanceFunction getDistanceFunction() {
        return distanceFunction;
    }

    public int getMaxIteration() {
        return maxIteration;
    }

    public int getSeedRandom() {
        return seedRandom;
    }

    public double[] getSquaredError() {
        return squaredError;
    }

    public void setSquaredError(double[] squaredError) {
        this.squaredError = squaredError;
    }

    public void setNumCluster(int numCluster) throws Exception {
        if (numCluster <= 0) {
            throw new Exception("Number of clusters must be > 0");
        }
        this.numCluster = numCluster;
    }

    public void setDistanceFunction(DistanceFunction distanceFunction) {
        this.distanceFunction = distanceFunction;
    }

    public void setSeedRandom(int seedRandom) {
        this.seedRandom = seedRandom;
    }

    public void setMaxIteration(int maxIteration) {
        this.maxIteration = maxIteration;
    }

    @Override
    public String toString() {
        if (centroid == null) {
            return "No clusterer built yet!";
        }

        int maxWidth = 0;
        int maxAttWidth = 0;
        boolean containsNumeric = false;
        for (int i = 0; i < numCluster; i++) {
            for (int j = 0; j < centroid.numAttributes(); j++) {
                if (centroid.attribute(j).name().length() > maxAttWidth) {
                    maxAttWidth = centroid.attribute(j).name().length();
                }
                if (centroid.attribute(j).isNumeric()) {
                    containsNumeric = true;
                    double width = Math.log(Math.abs(centroid.instance(i).value(j))) / Math.log(10.0);
                    if (width < 0) {
                        width = 1;
                    }
                    width += 6.0;
                    if ((int) width > maxWidth) {
                        maxWidth = (int) width;
                    }
                }
            }
        }

        for (int i = 0; i < centroid.numAttributes(); i++) {
            if (centroid.attribute(i).isNominal()) {
                Attribute a = centroid.attribute(i);
                for (int j = 0; j < centroid.numInstances(); j++) {
                    String val = a.value((int) centroid.instance(j).value(i));
                    if (val.length() > maxWidth) {
                        maxWidth = val.length();
                    }
                }
                for (int j = 0; j < a.numValues(); j++) {
                    String val = a.value(j) + " ";
                    if (val.length() > maxAttWidth) {
                        maxAttWidth = val.length();
                    }
                }
            }
        }

        // check for size of cluster sizes
        for (int i = 0; i < sizeEachCluster.length; i++) {
            String size = "(" + sizeEachCluster[i] + ")";
            if (size.length() > maxWidth) {
                maxWidth = size.length();
            }
        }

        String plusMinus = "+/-";
        maxAttWidth += 2;
        if (maxAttWidth < "Attribute".length() + 2) {
            maxAttWidth = "Attribute".length() + 2;
        }

        if (maxWidth < "Full Data".length()) {
            maxWidth = "Full Data".length() + 1;
        }

        if (maxWidth < "missing".length()) {
            maxWidth = "missing".length() + 1;
        }

        StringBuffer temp = new StringBuffer();
        temp.append("\nkMeans\n======\n");
        temp.append("\nNumber of iterations: " + numIteration + "\n");

        if (distanceFunction instanceof EuclideanDistance) {
            temp.append("Within cluster sum of squared errors: " + Utils.sum(squaredError));
        } else {
            temp.append("Sum of within cluster distances: " + Utils.sum(squaredError));
        }

        temp.append("\n\nCluster centroid:\n");
        temp.append(pad("Cluster#", " ", (maxAttWidth + (maxWidth * 2 + 2)) - "Cluster#".length(), true));

        temp.append("\n");
        temp.append(pad("Attribute", " ", maxAttWidth - "Attribute".length(), false));

        // cluster numbers
        for (int i = 0; i < numCluster; i++) {
            String clustNum = "" + i;
            temp.append(pad(clustNum, " ", maxWidth + 1 - clustNum.length(), true));
        }
        temp.append("\n");

        // cluster sizes
        String cSize = "";
        temp.append(pad(cSize, " ", maxAttWidth - cSize.length(), true));
        for (int i = 0; i < numCluster; i++) {
            cSize = "(" + sizeEachCluster[i] + ")";
            temp.append(pad(cSize, " ", maxWidth + 1 - cSize.length(), true));
        }
        temp.append("\n");

        temp.append(
                pad("", "=", maxAttWidth + (maxWidth * (centroid.numInstances()) + centroid.numInstances()), true));
        temp.append("\n");

        for (int i = 0; i < centroid.numAttributes(); i++) {
            String attName = centroid.attribute(i).name();
            temp.append(attName);
            for (int j = 0; j < maxAttWidth - attName.length(); j++) {
                temp.append(" ");
            }

            String strVal;
            String valMeanMode;

            for (int j = 0; j < numCluster; j++) {
                if (centroid.attribute(i).isNominal()) {
                    if (centroid.instance(j).isMissing(i)) {
                        valMeanMode = pad("missing", " ", maxWidth + 1 - "missing".length(), true);
                    } else {
                        valMeanMode = pad(
                                (strVal = centroid.attribute(i).value((int) centroid.instance(j).value(i))), " ",
                                maxWidth + 1 - strVal.length(), true);
                    }
                } else {
                    if (centroid.instance(j).isMissing(i)) {
                        valMeanMode = pad("missing", " ", maxWidth + 1 - "missing".length(), true);
                    } else {
                        valMeanMode = pad(
                                (strVal = Utils.doubleToString(centroid.instance(j).value(i), maxWidth, 4).trim()),
                                " ", maxWidth + 1 - strVal.length(), true);
                    }
                }
                temp.append(valMeanMode);
            }
            temp.append("\n");
        }

        temp.append("\n\n");
        return temp.toString();
    }

    private String pad(String source, String padChar, int length, boolean leftPad) {
        StringBuffer temp = new StringBuffer();

        if (leftPad) {
            for (int i = 0; i < length; i++) {
                temp.append(padChar);
            }
            temp.append(source);
        } else {
            temp.append(source);
            for (int i = 0; i < length; i++) {
                temp.append(padChar);
            }
        }
        return temp.toString();
    }
}