// Java tutorial
package edu.uc.rphash.tests.kmeanspp;

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Random;

import edu.uc.rphash.Centroid;
import edu.uc.rphash.Clusterer;
import edu.uc.rphash.Readers.RPHashObject;
import edu.uc.rphash.Readers.SimpleArrayReader;
import edu.uc.rphash.tests.StatTests;
import edu.uc.rphash.tests.generators.GenerateData;
import edu.uc.rphash.util.VectorUtil;

/**
 * Clustering algorithm based on David Arthur and Sergei Vassilvitski k-means++ algorithm,
 * wrapped so that it can be driven through the {@link Clusterer} interface.
 *
 * <p>Outside of tests, prefer
 * {@code org.apache.commons.math3.ml.clustering.KMeansPlusPlusClusterer}.
 *
 * @param <T> type of the points to cluster
 * @see <a href="http://en.wikipedia.org/wiki/K-means%2B%2B">K-means++ (wikipedia)</a>
 * @since 2.0
 */
public class KMeansPlusPlus<T extends Clusterable<T>> implements Clusterer {

    /** Strategies to use for replacing an empty cluster. */
    public enum EmptyClusterStrategy {

        /** Split the cluster with largest distance variance. */
        LARGEST_VARIANCE,

        /** Split the cluster with largest number of points. */
        LARGEST_POINTS_NUMBER,

        /** Create a cluster around the point farthest from its centroid. */
        FARTHEST_POINT,

        /** Generate an error. */
        ERROR
    }

    /** Random generator for choosing initial centers. */
    private Random random;

    /** Selected strategy for empty clusters. */
    private EmptyClusterStrategy emptyStrategy;

    /** Length of the first vector handed to the constructor; used as the centroid length. */
    final int dim;

    /** Points to cluster when driven through the {@link Clusterer} interface. */
    Collection<T> points;

    /** Iteration cap used by {@link #getCentroids()}. */
    final int maxIterations;

    /** Number of clusters requested by {@link #getCentroids()}. */
    int k;

    /** Optional precomputed centroid vectors; {@code null} means "cluster on demand". */
    List<float[]> centroids;

    /**
     * Builds a clusterer over the given vectors, using a freshly seeded {@link Random},
     * the {@link EmptyClusterStrategy#LARGEST_POINTS_NUMBER} strategy and at most
     * 10000 iterations.
     *
     * @param data vectors to cluster; must contain at least one element because the
     *             first one determines {@link #dim}
     * @param k    number of clusters to produce
     */
    @SuppressWarnings("unchecked")
    public KMeansPlusPlus(List<float[]> data, int k) {
        random = new Random();
        emptyStrategy = EmptyClusterStrategy.LARGEST_POINTS_NUMBER;
        this.k = k;
        maxIterations = 10000;
        points = new ArrayList<>();
        this.dim = data.get(0).length;
        for (float[] f : data) {
            DoublePoint p = new DoublePoint(f);
            // Unchecked: this wrapper always stores DoublePoint instances as T.
            points.add((T) p);
        }
        centroids = null;
    }

    /**
     * Runs the K-means++ clustering algorithm several times and keeps the result whose
     * summed per-cluster distance variance is smallest.
     *
     * @param points the points to cluster
     * @param k the number of clusters to split the data into
     * @param numTrials number of trial runs
     * @param maxIterationsPerTrial the maximum number of iterations to run the algorithm
     *     for at each trial run. If negative, no maximum will be used
     * @return the best list of clusters found, or {@code null} if {@code numTrials}
     *     is not positive
     * @throws Exception if the number of clusters is larger than the number of data
     *     points, or if an empty cluster is encountered and the
     *     {@link #emptyStrategy} is set to {@code ERROR}
     */
    public List<Cluster<T>> cluster(final Collection<T> points, final int k, int numTrials,
            int maxIterationsPerTrial) throws Exception {

        // at first, we have not found any clusters list yet
        List<Cluster<T>> best = null;
        double bestVarianceSum = Double.POSITIVE_INFINITY;

        // do several clustering trials
        for (int i = 0; i < numTrials; ++i) {

            // compute a clusters list
            List<Cluster<T>> clusters = cluster(points, k, maxIterationsPerTrial);

            // compute the variance of the current list
            double varianceSum = 0.0;
            for (final Cluster<T> cluster : clusters) {
                if (!cluster.getPoints().isEmpty()) {
                    varianceSum += distanceVariance(cluster);
                }
            }

            if (varianceSum <= bestVarianceSum) {
                // this one is the best we have found so far, remember it
                best = clusters;
                bestVarianceSum = varianceSum;
            }
        }

        // return the best clusters list found
        return best;
    }

    /**
     * Runs the K-means++ clustering algorithm.
     *
     * @param points the points to cluster
     * @param k the number of clusters to split the data into
     * @param maxIterations the maximum number of iterations to run the algorithm
     *     for. If negative, no maximum will be used
     * @return a list of clusters containing the points
     * @throws Exception if the number of clusters is larger than the number of data
     *     points, or if an empty cluster is encountered and the
     *     {@link #emptyStrategy} is set to {@code ERROR}
     */
    public List<Cluster<T>> cluster(final Collection<T> points, final int k, final int maxIterations)
            throws Exception {

        // number of clusters has to be smaller or equal the number of data points
        if (points.size() < k) {
            throw new IllegalArgumentException(
                    "cannot split " + points.size() + " points into " + k + " clusters");
        }

        // create the initial clusters
        List<Cluster<T>> clusters = chooseInitialCenters(points, k, random);

        // create an array containing the latest assignment of a point to a cluster
        // no need to initialize the array, as it will be filled with the first assignment
        int[] assignments = new int[points.size()];
        assignPointsToClusters(clusters, points, assignments);

        // iterate through updating the centers until we're done
        final int max = (maxIterations < 0) ? Integer.MAX_VALUE : maxIterations;
        for (int count = 0; count < max; count++) {
            boolean emptyCluster = false;
            List<Cluster<T>> newClusters = new ArrayList<Cluster<T>>();
            for (final Cluster<T> cluster : clusters) {
                final T newCenter;
                if (cluster.getPoints().isEmpty()) {
                    switch (emptyStrategy) {
                        case LARGEST_VARIANCE:
                            newCenter = getPointFromLargestVarianceCluster(clusters);
                            break;
                        case LARGEST_POINTS_NUMBER:
                            newCenter = getPointFromLargestNumberCluster(clusters);
                            break;
                        case FARTHEST_POINT:
                            newCenter = getFarthestPoint(clusters);
                            break;
                        default:
                            throw new IllegalStateException(
                                    "empty cluster encountered with strategy " + emptyStrategy);
                    }
                    emptyCluster = true;
                } else {
                    newCenter = cluster.getCenter().centroidOf(cluster.getPoints());
                }
                newClusters.add(new Cluster<T>(newCenter));
            }
            int changes = assignPointsToClusters(newClusters, points, assignments);
            clusters = newClusters;

            // if there were no more changes in the point-to-cluster assignment
            // and there are no empty clusters left, return the current clusters
            if (changes == 0 && !emptyCluster) {
                return clusters;
            }
        }
        return clusters;
    }

    /**
     * Computes the bias-corrected sample variance of the distances between a cluster's
     * points and its center, in a single pass (running mean / sum of squared deviations).
     *
     * @param <T> type of the points to cluster
     * @param cluster the cluster to measure
     * @return the variance of the point-to-center distances; {@code 0} when the cluster
     *     holds fewer than two points, so that the result is never {@code NaN}
     */
    private static <T extends Clusterable<T>> double distanceVariance(final Cluster<T> cluster) {
        final T center = cluster.getCenter();
        long n = 0;
        double mean = 0.0;
        double m2 = 0.0;
        for (final T point : cluster.getPoints()) {
            final double x = point.distanceFrom(center);
            n++;
            final double delta = x - mean;
            mean += delta / n;
            m2 += delta * (x - mean);
        }
        // With a single sample m2 / (n - 1) would be 0 / 0; report zero spread instead.
        return (n > 1) ? m2 / (n - 1) : 0.0;
    }

    /**
     * Adds the given points to the closest {@link Cluster}.
     *
     * @param <T> type of the points to cluster
     * @param clusters the {@link Cluster}s to add the points to
     * @param points the points to add to the given {@link Cluster}s
     * @param assignments points assignments to clusters
     * @return the number of points assigned to different clusters as the iteration before
     */
    private static <T extends Clusterable<T>> int assignPointsToClusters(final List<Cluster<T>> clusters,
            final Collection<T> points, final int[] assignments) {
        int assignedDifferently = 0;
        int pointIndex = 0;
        for (final T p : points) {
            int clusterIndex = getNearestCluster(clusters, p);
            if (clusterIndex != assignments[pointIndex]) {
                assignedDifferently++;
            }

            Cluster<T> cluster = clusters.get(clusterIndex);
            cluster.addPoint(p);
            assignments[pointIndex++] = clusterIndex;
        }

        return assignedDifferently;
    }

    /**
     * Use K-means++ to choose the initial centers. The first center is always picked,
     * so the result holds at least one cluster.
     *
     * @param <T> type of the points to cluster
     * @param points the points to choose the initial centers from
     * @param k the number of centers to choose
     * @param random random generator to use
     * @return the initial centers
     */
    private static <T extends Clusterable<T>> List<Cluster<T>> chooseInitialCenters(
            final Collection<T> points, final int k, final Random random) {

        // Convert to list for indexed access. Make it unmodifiable, since removal of items
        // would screw up the logic of this method.
        final List<T> pointList = Collections.unmodifiableList(new ArrayList<T>(points));

        // The number of points in the list.
        final int numPoints = pointList.size();

        // Set the corresponding element in this array to indicate when
        // elements of pointList are no longer available.
        final boolean[] taken = new boolean[numPoints];

        // The resulting list of initial centers.
        final List<Cluster<T>> resultSet = new ArrayList<Cluster<T>>();

        // Choose one center uniformly at random from among the data points.
        final int firstPointIndex = random.nextInt(numPoints);

        final T firstPoint = pointList.get(firstPointIndex);

        resultSet.add(new Cluster<T>(firstPoint));

        // Must mark it as taken
        taken[firstPointIndex] = true;

        // To keep track of the minimum distance squared of elements of
        // pointList to elements of resultSet.
        final double[] minDistSquared = new double[numPoints];

        // Initialize the elements. Since the only point in resultSet is firstPoint,
        // this is very easy.
        for (int i = 0; i < numPoints; i++) {
            if (i != firstPointIndex) { // That point isn't considered
                double d = firstPoint.distanceFrom(pointList.get(i));
                minDistSquared[i] = d * d;
            }
        }

        while (resultSet.size() < k) {

            // Sum up the squared distances for the points in pointList not
            // already taken.
            double distSqSum = 0.0;

            for (int i = 0; i < numPoints; i++) {
                if (!taken[i]) {
                    distSqSum += minDistSquared[i];
                }
            }

            // Add one new data point as a center. Each point x is chosen with
            // probability proportional to D(x)^2
            final double r = random.nextDouble() * distSqSum;

            // The index of the next point to be added to the resultSet.
            int nextPointIndex = -1;

            // Sum through the squared min distances again, stopping when
            // sum >= r.
            double sum = 0.0;
            for (int i = 0; i < numPoints; i++) {
                if (!taken[i]) {
                    sum += minDistSquared[i];
                    if (sum >= r) {
                        nextPointIndex = i;
                        break;
                    }
                }
            }

            // If it's not set to >= 0, the point wasn't found in the previous
            // for loop, probably because distances are extremely small. Just pick
            // the last available point.
            if (nextPointIndex == -1) {
                for (int i = numPoints - 1; i >= 0; i--) {
                    if (!taken[i]) {
                        nextPointIndex = i;
                        break;
                    }
                }
            }

            // We found one.
            if (nextPointIndex >= 0) {

                final T p = pointList.get(nextPointIndex);

                resultSet.add(new Cluster<T>(p));

                // Mark it as taken.
                taken[nextPointIndex] = true;

                if (resultSet.size() < k) {
                    // Now update elements of minDistSquared. We only have to compute
                    // the distance to the new center to do this.
                    for (int j = 0; j < numPoints; j++) {
                        // Only have to worry about the points still not taken.
                        if (!taken[j]) {
                            double d = p.distanceFrom(pointList.get(j));
                            double d2 = d * d;
                            if (d2 < minDistSquared[j]) {
                                minDistSquared[j] = d2;
                            }
                        }
                    }
                }

            } else {
                // None found --
                // Break from the while loop to prevent
                // an infinite loop.
                break;
            }
        }

        return resultSet;
    }

    /**
     * Get a random point from the {@link Cluster} with the largest distance variance.
     * The chosen point is removed from that cluster.
     *
     * @param clusters the {@link Cluster}s to search
     * @return a random point from the selected cluster
     * @throws Exception if clusters are all empty
     */
    private T getPointFromLargestVarianceCluster(final Collection<Cluster<T>> clusters) throws Exception {

        double maxVariance = Double.NEGATIVE_INFINITY;
        Cluster<T> selected = null;
        for (final Cluster<T> cluster : clusters) {
            if (!cluster.getPoints().isEmpty()) {

                // compute the distance variance of the current cluster
                final double variance = distanceVariance(cluster);

                // select the cluster with the largest variance
                if (variance > maxVariance) {
                    maxVariance = variance;
                    selected = cluster;
                }
            }
        }

        // did we find at least one non-empty cluster ?
        if (selected == null) {
            throw new IllegalStateException("all clusters are empty");
        }

        // extract a random point from the cluster
        final List<T> selectedPoints = selected.getPoints();
        return selectedPoints.remove(random.nextInt(selectedPoints.size()));
    }

    /**
     * Get a random point from the {@link Cluster} with the largest number of points.
     * The chosen point is removed from that cluster.
     *
     * @param clusters the {@link Cluster}s to search
     * @return a random point from the selected cluster
     * @throws Exception if clusters are all empty
     */
    private T getPointFromLargestNumberCluster(final Collection<Cluster<T>> clusters) throws Exception {

        int maxNumber = 0;
        Cluster<T> selected = null;
        for (final Cluster<T> cluster : clusters) {

            // get the number of points of the current cluster
            final int number = cluster.getPoints().size();

            // select the cluster with the largest number of points
            if (number > maxNumber) {
                maxNumber = number;
                selected = cluster;
            }
        }

        // did we find at least one non-empty cluster ?
        if (selected == null) {
            throw new IllegalStateException("all clusters are empty");
        }

        // extract a random point from the cluster
        final List<T> selectedPoints = selected.getPoints();
        return selectedPoints.remove(random.nextInt(selectedPoints.size()));
    }

    /**
     * Get the point farthest to its cluster center. The chosen point is removed from
     * its cluster.
     *
     * @param clusters the {@link Cluster}s to search
     * @return point farthest to its cluster center
     * @throws Exception if clusters are all empty
     */
    private T getFarthestPoint(final Collection<Cluster<T>> clusters) throws Exception {

        double maxDistance = Double.NEGATIVE_INFINITY;
        Cluster<T> selectedCluster = null;
        int selectedPoint = -1;
        for (final Cluster<T> cluster : clusters) {

            // get the farthest point
            final T center = cluster.getCenter();
            final List<T> points = cluster.getPoints();
            for (int i = 0; i < points.size(); ++i) {
                final double distance = points.get(i).distanceFrom(center);
                if (distance > maxDistance) {
                    maxDistance = distance;
                    selectedCluster = cluster;
                    selectedPoint = i;
                }
            }
        }

        // did we find at least one non-empty cluster ?
        if (selectedCluster == null) {
            throw new IllegalStateException("all clusters are empty");
        }

        return selectedCluster.getPoints().remove(selectedPoint);
    }

    /**
     * Returns the nearest {@link Cluster} to the given point.
     *
     * @param <T> type of the points to cluster
     * @param clusters the {@link Cluster}s to search
     * @param point the point to find the nearest {@link Cluster} for
     * @return the index of the nearest {@link Cluster} to the given point
     */
    private static <T extends Clusterable<T>> int getNearestCluster(final Collection<Cluster<T>> clusters,
            final T point) {
        double minDistance = Double.MAX_VALUE;
        int clusterIndex = 0;
        int minCluster = 0;
        for (final Cluster<T> c : clusters) {
            final double distance = point.distanceFrom(c.getCenter());
            if (distance < minDistance) {
                minDistance = distance;
                minCluster = clusterIndex;
            }
            clusterIndex++;
        }
        return minCluster;
    }

    /**
     * Clusters the stored points into {@link #k} groups and returns one centroid per
     * cluster, computed as the arithmetic mean of the cluster's member points.
     *
     * <p>If {@link #centroids} has been populated, those vectors are wrapped and
     * returned without re-clustering. NOTE(review): nothing in this class populates
     * that field; confirm this is the intended use of it.
     *
     * @return the centroids, one per cluster
     * @throws IllegalStateException if the underlying clustering run fails
     */
    @Override
    public List<Centroid> getCentroids() {
        final List<Centroid> ret = new ArrayList<>();

        if (centroids != null) {
            for (float[] cached : centroids) {
                ret.add(new Centroid(cached, 0));
            }
            return ret;
        }

        final List<Cluster<T>> clusters;
        try {
            clusters = cluster(points, k, maxIterations);
        } catch (Exception e) {
            // Surface the failure with its cause instead of logging and then
            // dereferencing a null cluster list.
            throw new IllegalStateException("k-means++ clustering failed for k=" + k, e);
        }

        for (Cluster<T> c : clusters) {
            final List<T> members = c.getPoints();
            final float[] mean = new float[dim];

            if (members.isEmpty()) {
                // No members to average: report the cluster center itself.
                final double[] center = ((DoublePoint) c.getCenter()).getPoint();
                for (int i = 0; i < dim; i++) {
                    mean[i] = (float) center[i];
                }
            } else {
                // Accumulate in double so large clusters do not lose precision.
                final double[] sum = new double[dim];
                for (T point : members) {
                    final double[] dpoint = ((DoublePoint) point).getPoint();
                    for (int i = 0; i < dim; i++) {
                        sum[i] += dpoint[i];
                    }
                }
                final double scale = 1.0 / members.size();
                for (int i = 0; i < dim; i++) {
                    mean[i] = (float) (sum[i] * scale);
                }
            }
            ret.add(new Centroid(mean, 0));
        }
        return ret;
    }

    /**
     * This clusterer carries no parameter object.
     *
     * @return always {@code null}
     */
    @Override
    public RPHashObject getParam() {
        return null;
    }

    /**
     * Small smoke test: clusters generated data and prints the value returned by
     * {@code StatTests.WCSSECentroidsFloat} for the resulting centroids.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        GenerateData gen = new GenerateData(3, 5000, 2, .1f);
        KMeansPlusPlus<DoublePoint> kk = new KMeansPlusPlus<>(gen.data(), 3);
        System.out.println(StatTests.WCSSECentroidsFloat(kk.getCentroids(), gen.getData()));
        System.gc();
    }

    /** Weights are not used by this clusterer; the call is a no-op. */
    @Override
    public void setWeights(List<Float> counts) {
    }

    /**
     * Sets the number of clusters to produce.
     *
     * @param getk the new number of clusters
     */
    @Override
    public void setK(int getk) {
        this.k = getk;
    }

    /**
     * Replaces the points to cluster with the given raw vectors, re-creates the random
     * generator and discards any stored centroids. {@link #dim} is not updated, so the
     * vectors should have the same length as those given to the constructor.
     *
     * @param centroids the vectors to cluster
     */
    @Override
    @SuppressWarnings("unchecked")
    public void setRawData(List<float[]> centroids) {
        random = new Random();
        points = new ArrayList<>();
        for (float[] f : centroids) {
            DoublePoint p = new DoublePoint(f);
            points.add((T) p);
        }
        // The parameter shadows the field: qualify with this to actually clear the field.
        this.centroids = null;
    }

    /**
     * Replaces the points to cluster with the vectors of the given centroids, re-creates
     * the random generator and discards any stored centroids. {@link #dim} is not
     * updated, so the vectors should have the same length as those given to the
     * constructor.
     *
     * @param centroids the centroids whose vectors are to be clustered
     */
    @Override
    @SuppressWarnings("unchecked")
    public void setData(List<Centroid> centroids) {
        random = new Random();
        points = new ArrayList<>();
        for (Centroid f : centroids) {
            DoublePoint p = new DoublePoint(f.centroid());
            points.add((T) p);
        }
        // The parameter shadows the field: qualify with this to actually clear the field.
        this.centroids = null;
    }

    /**
     * Re-seeds the random generator and discards any stored centroids.
     *
     * @param randomseed seed for the new random generator
     */
    @Override
    public void reset(int randomseed) {
        random = new Random(randomseed);
        centroids = null;
    }

    /**
     * Multi-run mode is not supported.
     *
     * @param runs requested number of runs (ignored)
     * @return always {@code false}
     */
    @Override
    public boolean setMultiRun(int runs) {
        return false;
    }
}