myclusterer.MyKMeans.java Source code

Java tutorial

Introduction

Here is the source code for myclusterer.MyKMeans.java

Source

/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package myclusterer;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Random;
import java.util.Set;
import weka.clusterers.NumberOfClustersRequestable;
import weka.clusterers.RandomizableClusterer;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.Capabilities.Capability;
import weka.core.DistanceFunction;
import weka.core.EuclideanDistance;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Utils;

/**
 * A basic k-means clusterer for Weka.
 *
 * <p>The initial centroids are {@code K} distinct training instances drawn with a
 * {@link Random} seeded from {@link #getSeed()}, so builds with the same seed and data
 * are reproducible. Each iteration assigns every training instance to its nearest
 * centroid under {@link EuclideanDistance} and then recomputes every centroid. A numeric
 * attribute gets the weighted mean and a nominal attribute gets the weighted mode.
 * Iteration stops when no assignment changes or after {@code maxIterations} iterations.
 *
 * @author Visat
 */
public class MyKMeans extends RandomizableClusterer implements NumberOfClustersRequestable {

    /** Training data passed to the last successful {@link #buildClusterer(Instances)}. */
    protected Instances instances;
    /** One centroid per cluster; {@code null} until the clusterer has been built. */
    protected Instances centroids;
    /** Training instances assigned to each cluster by the last build. */
    protected List<Instance>[] clusters;
    /** Requested number of clusters. */
    protected int K = 3;
    /** Number of assign/update iterations performed by the last build. */
    protected int iterations;
    /** Upper bound on the number of assign/update iterations. */
    protected int maxIterations = 500;
    /** Distance used for assignment during training and in {@link #clusterInstance(Instance)}. */
    protected final DistanceFunction distanceFunction = new EuclideanDistance();

    /**
     * Builds the clustering from the given training data.
     *
     * @param instances the training data; must satisfy {@link #getCapabilities()}
     * @throws Exception if the data is not supported by this clusterer, or if it holds
     *         fewer instances than the requested number of clusters
     */
    @Override
    public void buildClusterer(Instances instances) throws Exception {
        // Validate first so unsupported or empty data is reported rather than silently ignored.
        getCapabilities().testWithFail(instances);

        int N = instances.numInstances();
        if (K < 1)
            K = 1;
        if (N < K)
            throw new Exception("Cannot build " + K + " clusters from only " + N + " instances");

        this.instances = instances;
        distanceFunction.setInstances(instances);

        // Honour the RandomizableClusterer seed so results are reproducible.
        centroids = chooseInitialCentroids(instances, new Random(getSeed()));

        int[] prevCluster = new int[N];
        for (int i = 0; i < N; ++i)
            prevCluster[i] = -1; // "unassigned", so the first pass can never look converged

        @SuppressWarnings("unchecked")
        List<Integer>[] memberIdx = new List[K];
        for (int i = 0; i < K; ++i)
            memberIdx[i] = new ArrayList<>();

        boolean converged = false;
        iterations = 0;
        while (!converged && iterations < maxIterations) {
            ++iterations;
            converged = true;

            // assignment step
            for (int i = 0; i < K; ++i)
                memberIdx[i].clear();
            for (int i = 0; i < N; ++i) {
                int cluster = clusterInstance(instances.instance(i));
                if (prevCluster[i] != cluster) {
                    converged = false;
                    prevCluster[i] = cluster;
                }
                memberIdx[cluster].add(i);
            }

            // update step
            centroids = recomputeCentroids(instances, memberIdx, centroids);
        }

        // Memberships reflect the last assignment step. They are computed against the
        // centroids in use before the final update, which matter only when maxIterations
        // stopped the loop early.
        @SuppressWarnings("unchecked")
        List<Instance>[] built = new List[K];
        for (int i = 0; i < K; ++i) {
            built[i] = new ArrayList<>(memberIdx[i].size());
            for (int member : memberIdx[i])
                built[i].add(instances.instance(member));
        }
        clusters = built;
    }

    /**
     * Picks {@code K} distinct training instances, drawn with {@code rand}, as starting centroids.
     * Callers must ensure {@code data} has at least {@code K} instances.
     */
    private Instances chooseInitialCentroids(Instances data, Random rand) {
        int n = data.numInstances();
        Set<Integer> chosen = new HashSet<>();
        while (chosen.size() < K)
            chosen.add(rand.nextInt(n));
        Instances initial = new Instances(data, K);
        for (int idx : chosen)
            initial.add(data.instance(idx));
        return initial;
    }

    /**
     * Computes each cluster's new centroid from its current members.
     *
     * <p>If a cluster has no members, it keeps its previous centroid. Averaging zero
     * instances would otherwise produce a centroid whose numeric attributes are all missing.
     *
     * @param data      the training data (also supplies the header)
     * @param memberIdx for each cluster, the indices into {@code data} of its members
     * @param previous  the centroids used for the assignment step just performed
     * @return the updated centroids, one per cluster, in cluster order
     */
    private Instances recomputeCentroids(Instances data, List<Integer>[] memberIdx, Instances previous) {
        Instances updated = new Instances(data, memberIdx.length);
        for (int c = 0; c < memberIdx.length; ++c) {
            if (memberIdx[c].isEmpty()) {
                updated.add(previous.instance(c));
                continue;
            }
            Instances members = new Instances(data, memberIdx[c].size());
            for (int member : memberIdx[c])
                members.add(data.instance(member));
            updated.add(createCentroid(members));
        }
        return updated;
    }

    /**
     * Returns the number of clusters. Before building, this is the requested count
     * {@code K}. After building, it is the number of centroids, which bounds what
     * {@link #clusterInstance(Instance)} can return.
     */
    @Override
    public int numberOfClusters() throws Exception {
        return centroids != null ? centroids.numInstances() : K;
    }

    /**
     * Sets the number of clusters used by the next build.
     *
     * @param K the number of clusters, must be positive
     * @throws Exception if {@code K <= 0}
     */
    @Override
    public void setNumClusters(int K) throws Exception {
        if (K <= 0)
            throw new Exception("Number of clusters must be > 0");
        this.K = K;
    }

    /** Returns the number of iterations performed by the last build. */
    public int getIterations() {
        return iterations;
    }

    /** Returns the maximum number of iterations a build may perform. */
    public int getMaxIterations() {
        return maxIterations;
    }

    /**
     * Sets the maximum number of iterations a build may perform.
     *
     * @param maxIterations the iteration cap, must be positive
     * @throws Exception if {@code maxIterations <= 0}
     */
    public void setMaxIterations(int maxIterations) throws Exception {
        if (maxIterations <= 0)
            throw new Exception("Number of iterations must be > 0");
        this.maxIterations = maxIterations;
    }

    /**
     * Returns the supported data: no class attribute, nominal and numeric attributes,
     * and missing values.
     */
    @Override
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAll();
        result.enable(Capability.NO_CLASS);

        // attributes
        result.enable(Capability.NOMINAL_ATTRIBUTES);
        result.enable(Capability.NUMERIC_ATTRIBUTES);
        result.enable(Capability.MISSING_VALUES);

        return result;
    }

    /**
     * Builds the centroid of {@code members}. Numeric attributes get the weighted mean
     * and nominal attributes get the weighted mode. Missing values are skipped.
     *
     * <p>A numeric attribute becomes missing when no member has a value for it. A nominal
     * attribute becomes missing when the missing weight exceeds the weight of its most
     * frequent value, or when the attribute declares no values at all.
     *
     * @param members the instances of one cluster
     * @return a new instance of weight 1.0 holding the centroid values
     */
    private Instance createCentroid(Instances members) {
        int numAtts = members.numAttributes();
        double[] vals = new double[numAtts];
        double[][] nominalDists = new double[numAtts][];
        double[] weightMissing = new double[numAtts];
        double[] weightNonMissing = new double[numAtts];

        for (int j = 0; j < numAtts; j++) {
            if (members.attribute(j).isNominal())
                nominalDists[j] = new double[members.attribute(j).numValues()];
        }
        for (int i = 0; i < members.numInstances(); ++i) {
            Instance inst = members.instance(i);
            for (int j = 0; j < numAtts; j++) {
                if (inst.isMissing(j)) {
                    weightMissing[j] += inst.weight();
                } else {
                    weightNonMissing[j] += inst.weight();
                    if (members.attribute(j).isNumeric())
                        vals[j] += inst.weight() * inst.value(j);
                    else
                        nominalDists[j][(int) inst.value(j)] += inst.weight();
                }
            }
        }
        for (int j = 0; j < numAtts; j++) {
            if (members.attribute(j).isNumeric()) {
                vals[j] = weightNonMissing[j] > 0 ? vals[j] / weightNonMissing[j] : Instance.missingValue();
            } else {
                double max = -Double.MAX_VALUE;
                int maxIndex = -1;
                for (int v = 0; v < nominalDists[j].length; v++) {
                    if (nominalDists[j][v] > max) {
                        max = nominalDists[j][v];
                        maxIndex = v;
                    }
                }
                // Decide once, after the most frequent value is known.
                vals[j] = (maxIndex < 0 || max < weightMissing[j]) ? Instance.missingValue() : maxIndex;
            }
        }
        return new Instance(1.0, vals);
    }

    /**
     * Returns the index of the centroid nearest to {@code instance}.
     *
     * @param instance the instance to assign
     * @return the index of the closest centroid (the lowest index on ties)
     * @throws Exception if the clusterer has not been built yet
     */
    @Override
    public int clusterInstance(Instance instance) throws Exception {
        if (centroids == null)
            throw new Exception("No clusterer built yet!");
        double min = Double.MAX_VALUE;
        int idx = 0;
        // Iterate over the centroids actually built; K may have been changed since by setNumClusters.
        for (int i = 0; i < centroids.numInstances(); ++i) {
            double dist = distanceFunction.distance(centroids.instance(i), instance);
            if (dist < min) {
                min = dist;
                idx = i;
            }
        }
        return idx;
    }

    /**
     * Renders the iteration count and a table of centroids: one column per cluster with
     * its size, and one row per attribute.
     */
    @Override
    public String toString() {
        if (centroids == null || clusters == null) {
            return "No clusterer built yet!";
        }

        int numClusters = centroids.numInstances();
        int maxWidth = 0;
        int maxAttWidth = 0;
        for (int i = 0; i < numClusters; i++) {
            for (int j = 0; j < centroids.numAttributes(); j++) {
                if (centroids.attribute(j).name().length() > maxAttWidth) {
                    maxAttWidth = centroids.attribute(j).name().length();
                }
                if (centroids.attribute(j).isNumeric()) {
                    double width = Math.log(Math.abs(centroids.instance(i).value(j))) / Math.log(10.0);
                    if (width < 0) {
                        width = 1;
                    }
                    // decimal + # decimal places + 1
                    width += 6.0;
                    if ((int) width > maxWidth) {
                        maxWidth = (int) width;
                    }
                }
            }
        }

        for (int i = 0; i < centroids.numAttributes(); i++) {
            if (centroids.attribute(i).isNominal()) {
                Attribute a = centroids.attribute(i);
                for (int j = 0; j < numClusters; j++) {
                    // Missing centroids print as "missing", whose width is handled below.
                    if (centroids.instance(j).isMissing(i)) {
                        continue;
                    }
                    String val = a.value((int) centroids.instance(j).value(i));
                    if (val.length() > maxWidth) {
                        maxWidth = val.length();
                    }
                }
                for (int j = 0; j < a.numValues(); j++) {
                    String val = a.value(j) + " ";
                    if (val.length() > maxAttWidth) {
                        maxAttWidth = val.length();
                    }
                }
            }
        }

        // check for size of cluster sizes
        for (int i = 0; i < clusters.length; i++) {
            String size = "(" + clusters[i].size() + ")";
            if (size.length() > maxWidth) {
                maxWidth = size.length();
            }
        }

        maxAttWidth += 2;
        if (maxAttWidth < "Attribute".length() + 2) {
            maxAttWidth = "Attribute".length() + 2;
        }

        if (maxWidth < "Full Data".length()) {
            maxWidth = "Full Data".length() + 1;
        }

        if (maxWidth < "missing".length()) {
            maxWidth = "missing".length() + 1;
        }

        StringBuilder temp = new StringBuilder();
        temp.append("\nkMeans\n======\n");
        temp.append("\nNumber of iterations: " + iterations + "\n");

        temp.append("\n\nCluster centroids:\n");
        temp.append(pad("Cluster#", " ", (maxAttWidth + (maxWidth * 2 + 2)) - "Cluster#".length(), true));

        temp.append("\n");
        temp.append(pad("Attribute", " ", maxAttWidth - "Attribute".length(), false));

        // cluster numbers
        for (int i = 0; i < numClusters; i++) {
            String clustNum = "" + i;
            temp.append(pad(clustNum, " ", maxWidth + 1 - clustNum.length(), true));
        }
        temp.append("\n");

        // cluster sizes
        String cSize = "";
        temp.append(pad(cSize, " ", maxAttWidth - cSize.length(), true));
        for (int i = 0; i < numClusters; i++) {
            cSize = "(" + clusters[i].size() + ")";
            temp.append(pad(cSize, " ", maxWidth + 1 - cSize.length(), true));
        }
        temp.append("\n");

        temp.append(pad("", "=", maxAttWidth + (maxWidth * numClusters + numClusters), true));
        temp.append("\n");

        for (int i = 0; i < centroids.numAttributes(); i++) {
            String attName = centroids.attribute(i).name();
            temp.append(pad(attName, " ", maxAttWidth - attName.length(), false));

            for (int j = 0; j < numClusters; j++) {
                Instance centroid = centroids.instance(j);
                String strVal;
                if (centroid.isMissing(i)) {
                    strVal = "missing";
                } else if (centroids.attribute(i).isNominal()) {
                    strVal = centroids.attribute(i).value((int) centroid.value(i));
                } else {
                    strVal = Utils.doubleToString(centroid.value(i), maxWidth, 4).trim();
                }
                temp.append(pad(strVal, " ", maxWidth + 1 - strVal.length(), true));
            }
            temp.append("\n");
        }

        temp.append("\n\n");
        return temp.toString();
    }

    /**
     * Pads {@code source} with {@code length} copies of {@code padChar} on the left
     * ({@code leftPad == true}) or on the right. A non-positive {@code length} adds nothing.
     */
    private String pad(String source, String padChar, int length, boolean leftPad) {
        StringBuilder temp = new StringBuilder();

        if (leftPad) {
            for (int i = 0; i < length; i++) {
                temp.append(padChar);
            }
            temp.append(source);
        } else {
            temp.append(source);
            for (int i = 0; i < length; i++) {
                temp.append(padChar);
            }
        }
        return temp.toString();
    }

    /** Command-line entry point using Weka's standard clusterer option handling. */
    public static void main(String[] args) {
        runClusterer(new MyKMeans(), args);
    }
}