Java tutorial
/******************************************************************************* * Copyright 2015, 2016 Taylor G Smith * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. *******************************************************************************/ package com.clust4j.algo; import java.util.ArrayList; import java.util.HashSet; import java.util.Iterator; import java.util.Random; import java.util.TreeSet; import java.util.concurrent.ConcurrentLinkedDeque; import java.util.concurrent.ConcurrentSkipListSet; import java.util.concurrent.RejectedExecutionException; import org.apache.commons.math3.exception.DimensionMismatchException; import org.apache.commons.math3.linear.RealMatrix; import org.apache.commons.math3.linear.Array2DRowRealMatrix; import org.apache.commons.math3.util.FastMath; import com.clust4j.algo.NearestNeighborsParameters; import com.clust4j.algo.Neighborhood; import com.clust4j.algo.RadiusNeighborsParameters; import com.clust4j.except.IllegalClusterStateException; import com.clust4j.except.ModelNotFitException; import com.clust4j.kernel.RadialBasisKernel; import com.clust4j.kernel.GaussianKernel; import com.clust4j.log.LogTimer; import com.clust4j.log.Log.Tag.Algo; import com.clust4j.log.Loggable; import com.clust4j.metrics.pairwise.GeometricallySeparable; import com.clust4j.metrics.pairwise.SimilarityMetric; import com.clust4j.utils.EntryPair; import com.clust4j.utils.MatUtils; import com.clust4j.utils.VecUtils; /** * Mean shift is a procedure for locating the maxima of a density function given discrete * data sampled from that function. It is useful for detecting the modes of this density. * This is an iterative method, and we start with an initial estimate <i>x</i> . Let a * {@link RadialBasisKernel} function be given. This function determines the weight of nearby * points for re-estimation of the mean. Typically a {@link GaussianKernel} kernel on the * distance to the current estimate is used. * * @see <a href="https://en.wikipedia.org/wiki/Mean_shift">Mean shift on Wikipedia</a> * @author Taylor G Smith <tgsmith61591@gmail.com>, adapted from <a href="https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/cluster/mean_shift_.py">sklearn implementation</a> */ final public class MeanShift extends AbstractDensityClusterer implements CentroidLearner, Convergeable, NoiseyClusterer { /** * */ private static final long serialVersionUID = 4423672142693334046L; final public static double DEF_BANDWIDTH = 5.0; final public static int DEF_MAX_ITER = 300; final public static int DEF_MIN_BIN_FREQ = 1; final static double incrementAmt = 0.25; final public static HashSet<Class<? extends GeometricallySeparable>> UNSUPPORTED_METRICS; /** * Static initializer */ static { UNSUPPORTED_METRICS = new HashSet<>(); // Add metrics here if necessary... already vetoes any // similarity metrics, so this might be sufficient... } @Override final public boolean isValidMetric(GeometricallySeparable geo) { return !UNSUPPORTED_METRICS.contains(geo.getClass()) && !(geo instanceof SimilarityMetric); } /** The max iterations */ private final int maxIter; /** Min change convergence criteria */ private final double tolerance; /** The kernel bandwidth (volatile because can change in sync method) */ volatile private double bandwidth; /** Class labels */ volatile private int[] labels = null; /** The M x N seeds to be used as initial kernel points */ private double[][] seeds; /** Num rows, cols */ private final int n; /** Whether bandwidth is auto-estimated */ private final boolean autoEstimate; /** Track convergence */ private volatile boolean converged = false; /** The centroid records */ private volatile ArrayList<double[]> centroids; private volatile int numClusters; private volatile int numNoisey; /** Count iterations */ private volatile int itersElapsed = 0; /** * Default constructor * @param data * @param bandwidth */ protected MeanShift(RealMatrix data, final double bandwidth) { this(data, new MeanShiftParameters(bandwidth)); } /** * Default constructor for auto bandwidth estimation * @param data * @param bandwidth */ protected MeanShift(RealMatrix data) { this(data, new MeanShiftParameters()); } /** * Constructor with custom MeanShiftPlanner * @param data * @param planner */ protected MeanShift(RealMatrix data, MeanShiftParameters planner) { super(data, planner); // Check bandwidth... if (planner.getBandwidth() <= 0.0) error(new IllegalArgumentException("bandwidth " + "must be greater than 0.0")); // Check seeds dimension if (null != planner.getSeeds()) { if (planner.getSeeds().length == 0) error(new IllegalArgumentException("seeds " + "length must be greater than 0")); // Throws NonUniformMatrixException if non uniform... MatUtils.checkDimsForUniformity(planner.getSeeds()); if (planner.getSeeds()[0].length != (n = this.data.getColumnDimension())) error(new DimensionMismatchException(planner.getSeeds()[0].length, n)); if (planner.getSeeds().length > this.data.getRowDimension()) error(new IllegalArgumentException("seeds " + "length cannot exceed number of datapoints")); info("initializing kernels from given seeds"); // Handle the copying in the planner seeds = planner.getSeeds(); } else { // Default = all*/ info("no seeds provided; defaulting to all datapoints"); seeds = this.data.getData(); // use THIS as it's already scaled... n = this.data.getColumnDimension(); } /* * Check metric for validity */ if (!isValidMetric(this.dist_metric)) { warn(this.dist_metric.getName() + " is not valid for " + getName() + ". " + "Falling back to default Euclidean dist"); setSeparabilityMetric(DEF_DIST); } this.maxIter = planner.getMaxIter(); this.tolerance = planner.getConvergenceTolerance(); this.autoEstimate = planner.getAutoEstimate(); final LogTimer aeTimer = new LogTimer(); /* * Assign bandwidth */ this.bandwidth = /* if all singular, just pick a number... */ this.singular_value ? 0.5 : /* Otherwise if we're auto-estimating, estimate it */ autoEstimate ? autoEstimateBW(this, planner.getAutoEstimationQuantile()) : planner.getBandwidth(); /* * Give auto-estimation timer update */ if (autoEstimate && !this.singular_value) info("bandwidth auto-estimated in " + (parallel ? "parallel in " : "") + aeTimer.toString()); logModelSummary(); } @Override final protected ModelSummary modelSummary() { return new ModelSummary( new Object[] { "Num Rows", "Num Cols", "Metric", "Bandwidth", "Allow Par.", "Max Iter.", "Tolerance" }, new Object[] { data.getRowDimension(), data.getColumnDimension(), getSeparabilityMetric(), (autoEstimate ? "(auto) " : "") + bandwidth, parallel, maxIter, tolerance }); } /** * For testing... * @param data * @param quantile * @param sep * @param seed * @param parallel * @return */ final protected static double autoEstimateBW(Array2DRowRealMatrix data, double quantile, GeometricallySeparable sep, Random seed, boolean parallel) { return autoEstimateBW( new NearestNeighbors(data, new NearestNeighborsParameters((int) (data.getRowDimension() * quantile)).setSeed(seed) .setForceParallel(parallel)).fit(), data.getDataRef(), quantile, sep, seed, parallel, null); } /** * Actually called internally * @param caller * @param quantile * @return */ final protected static double autoEstimateBW(MeanShift caller, double quantile) { LogTimer timer = new LogTimer(); NearestNeighbors nn = new NearestNeighbors(caller, new NearestNeighborsParameters((int) (caller.data.getRowDimension() * quantile)) .setForceParallel(caller.parallel)).fit(); caller.info("fit nearest neighbors model for auto-bandwidth automation in " + timer.toString()); return autoEstimateBW(nn, caller.data.getDataRef(), quantile, caller.getSeparabilityMetric(), caller.getSeed(), caller.parallel, caller); } final protected static double autoEstimateBW(NearestNeighbors nn, double[][] data, double quantile, GeometricallySeparable sep, Random seed, boolean parallel, Loggable logger) { if (quantile <= 0 || quantile > 1) throw new IllegalArgumentException("illegal quantile"); final int m = data.length; double bw = 0.0; final double[][] X = nn.data.getDataRef(); final int minsize = ParallelChunkingTask.ChunkingStrategy.DEF_CHUNK_SIZE; final int chunkSize = X.length < minsize ? minsize : X.length / 5; final int numChunks = ParallelChunkingTask.ChunkingStrategy.getNumChunks(chunkSize, m); Neighborhood neighb; if (!parallel) { /* * For each chunk of 500, get the neighbors and then compute the * sum of the row maxes of the distance matrix. */ int chunkStart, nextChunk; for (int chunk = 0; chunk < numChunks; chunk++) { chunkStart = chunk * chunkSize; nextChunk = chunk == numChunks - 1 ? m : chunkStart + chunkSize; double[][] nextMatrix = new double[nextChunk - chunkStart][]; for (int i = chunkStart, j = 0; i < nextChunk; i++, j++) nextMatrix[j] = X[i]; neighb = nn.getNeighbors(nextMatrix); for (double[] distRow : neighb.getDistances()) { //bw += VecUtils.max(distRow); bw += distRow[distRow.length - 1]; // it's sorted! } } } else { // Estimate bandwidth in parallel bw = ParallelBandwidthEstimator.doAll(X, nn); } return bw / (double) m; } /** * Estimates the bandwidth of the model in parallel for scalability * @author Taylor G Smith */ static class ParallelBandwidthEstimator extends ParallelChunkingTask<Double> implements java.io.Serializable { private static final long serialVersionUID = 1171269106158790138L; final NearestNeighbors nn; final int high; final int low; ParallelBandwidthEstimator(double[][] X, NearestNeighbors nn) { // Use the SimpleChunker super(X); this.nn = nn; this.low = 0; this.high = strategy.getNumChunks(X); } ParallelBandwidthEstimator(ParallelBandwidthEstimator task, int low, int high) { super(task); this.nn = task.nn; this.low = low; this.high = high; } @Override protected Double compute() { if (high - low <= 1) { // generally should equal one... return reduce(chunks.get(low)); } else { int mid = this.low + (this.high - this.low) / 2; ParallelBandwidthEstimator left = new ParallelBandwidthEstimator(this, low, mid); ParallelBandwidthEstimator right = new ParallelBandwidthEstimator(this, mid, high); left.fork(); Double l = right.compute(); Double r = left.join(); return l + r; } } @Override public Double reduce(Chunk chunk) { double bw = 0.0; Neighborhood neighb = nn.getNeighbors(chunk.get(), false); for (double[] distRow : neighb.getDistances()) { //bw += VecUtils.max(distRow); bw += distRow[distRow.length - 1]; // it's sorted! } return bw; } static double doAll(double[][] X, NearestNeighbors nn) { return getThreadPool().invoke(new ParallelBandwidthEstimator(X, nn)); } } /** * Handles the output for the {@link #singleSeed(double[], RadiusNeighbors, double[][], int)} * method. Implements comparable to be sorted by the value in the entry pair. * @author Taylor G Smith */ protected static class MeanShiftSeed implements Comparable<MeanShiftSeed> { final double[] dists; /** The number of points in the bandwidth */ final Integer count; final int iterations; MeanShiftSeed(final double[] dists, final int count, int iterations) { this.dists = dists; this.count = count; this.iterations = iterations; } /* * we don't need these methods in the actual algo, and they just * create more need for testing to get good coverage, so we can * just omit them * @Override public boolean equals(Object o) { if(this == o) return true; if(o instanceof MeanShiftSeed) { MeanShiftSeed m = (MeanShiftSeed)o; return VecUtils.equalsExactly(dists, m.dists) && count.intValue() == m.count.intValue(); } return false; } @Override public String toString() { return "{" + Arrays.toString(dists) + " : " + count + "}"; } @Override public int hashCode() { int h = 31; for(double d: dists) h ^= (int)d; return h ^ count; } */ EntryPair<double[], Integer> getPair() { return new EntryPair<>(dists, count); } @Override public int compareTo(MeanShiftSeed o2) { int comp = count.compareTo(o2.count); if (comp == 0) { final double[] d2 = o2.dists; for (int i = 0; i < dists.length; i++) { int c = Double.valueOf(dists[i]).compareTo(d2[i]); if (c != 0) return -c; } } return -comp; } } /** * Light struct to hold summary info * @author Taylor G Smith */ static class SummaryLite { final String name; final int iters; final String fmtTime; final String wallTime; boolean retained = false; SummaryLite(final String nm, final int iter, final String fmt, final String wall) { this.name = nm; this.iters = iter; this.fmtTime = fmt; this.wallTime = wall; } Object[] toArray() { return new Object[] { name, iters, fmtTime, wallTime, retained }; } } /** * The superclass for parallelized MeanShift tasks * @author Taylor G Smith * @param <T> */ abstract static class ParallelMSTask<T> extends ParallelChunkingTask<T> { private static final long serialVersionUID = 2139716909891672022L; final ConcurrentLinkedDeque<SummaryLite> summaries; final double[][] X; ParallelMSTask(double[][] X, ConcurrentLinkedDeque<SummaryLite> summaries) { super(X); this.summaries = summaries; this.X = X; } ParallelMSTask(ParallelMSTask<T> task) { super(task); this.summaries = task.summaries; this.X = task.X; } public String formatName(String str) { StringBuilder sb = new StringBuilder(); boolean hyphen = false; // have we hit the hyphen yet? boolean started_worker = false; boolean seen_k = false; boolean finished_worker = false; for (char c : str.toCharArray()) { if (hyphen || Character.isUpperCase(c)) { if (started_worker && !finished_worker) { if (c == 'k') { // past first 'r'... seen_k = true; continue; } // in the middle of the word "worker" if (c != 'r') continue; else if (!seen_k) continue; // At the last char in 'worker' finished_worker = true; sb.append("Kernel"); } else if (!started_worker && c == 'w') { started_worker = true; } else { sb.append(c); } } else if ('-' == c) { hyphen = true; sb.append(c); } } return sb.toString(); } } /** * Class that handles construction of the center intensity object * @author Taylor G Smith */ static abstract class CenterIntensity implements java.io.Serializable, Iterable<MeanShiftSeed> { private static final long serialVersionUID = -6535787295158719610L; abstract int getIters(); abstract boolean isEmpty(); abstract ArrayList<SummaryLite> getSummaries(); abstract int size(); } /** * A class that utilizes a {@link java.util.concurrent.ForkJoinPool} * as parallel executors to run many tasks across multiple cores. * @author Taylor G Smith */ static class ParallelSeedExecutor extends ParallelMSTask<ConcurrentSkipListSet<MeanShiftSeed>> { private static final long serialVersionUID = 632871644265502894L; final int maxIter; final RadiusNeighbors nbrs; final ConcurrentSkipListSet<MeanShiftSeed> computedSeeds; final int high, low; ParallelSeedExecutor(int maxIter, double[][] X, RadiusNeighbors nbrs, ConcurrentLinkedDeque<SummaryLite> summaries) { /** * Pass summaries reference to super */ super(X, summaries); this.maxIter = maxIter; this.nbrs = nbrs; this.computedSeeds = new ConcurrentSkipListSet<>(); this.low = 0; this.high = strategy.getNumChunks(X); } ParallelSeedExecutor(ParallelSeedExecutor task, int low, int high) { super(task); this.maxIter = task.maxIter; this.nbrs = task.nbrs; this.computedSeeds = task.computedSeeds; this.high = high; this.low = low; } @Override protected ConcurrentSkipListSet<MeanShiftSeed> compute() { if (high - low <= 1) { // generally should equal one... return reduce(chunks.get(low)); } else { int mid = this.low + (this.high - this.low) / 2; ParallelSeedExecutor left = new ParallelSeedExecutor(this, low, mid); ParallelSeedExecutor right = new ParallelSeedExecutor(this, mid, high); left.fork(); right.compute(); left.join(); return computedSeeds; } } @Override public ConcurrentSkipListSet<MeanShiftSeed> reduce(Chunk chunk) { for (double[] seed : chunk.get()) { MeanShiftSeed ms = singleSeed(seed, nbrs, X, maxIter); if (null == ms) continue; computedSeeds.add(ms); String nm = getName(); summaries.add(new SummaryLite(nm, ms.iterations, timer.formatTime(), timer.wallTime())); } return computedSeeds; } static ConcurrentSkipListSet<MeanShiftSeed> doAll(int maxIter, double[][] X, RadiusNeighbors nbrs, ConcurrentLinkedDeque<SummaryLite> summaries) { return getThreadPool().invoke(new ParallelSeedExecutor(maxIter, X, nbrs, summaries)); } } class ParallelCenterIntensity extends CenterIntensity { private static final long serialVersionUID = 4392163493242956320L; final ConcurrentSkipListSet<Integer> itrz = new ConcurrentSkipListSet<>(); final ConcurrentSkipListSet<MeanShiftSeed> computedSeeds; /** Serves as a reference for passing to parallel job */ final ConcurrentLinkedDeque<SummaryLite> summaries = new ConcurrentLinkedDeque<>(); final LogTimer timer; final RadiusNeighbors nbrs; ParallelCenterIntensity(RadiusNeighbors nbrs) { this.nbrs = nbrs; this.timer = new LogTimer(); // Execute forkjoinpool this.computedSeeds = ParallelSeedExecutor.doAll(maxIter, seeds, nbrs, summaries); for (MeanShiftSeed sd : computedSeeds) itrz.add(sd.iterations); } @Override public int getIters() { return itrz.last(); } @Override public ArrayList<SummaryLite> getSummaries() { return new ArrayList<>(summaries); } @Override public boolean isEmpty() { return computedSeeds.isEmpty(); } @Override public Iterator<MeanShiftSeed> iterator() { return computedSeeds.iterator(); } @Override public int size() { return computedSeeds.size(); } } /** * Compute the center intensity entry pairs serially and call the * {@link MeanShift#singleSeed(double[], RadiusNeighbors, double[][], int)} method * @author Taylor G Smith */ class SerialCenterIntensity extends CenterIntensity { private static final long serialVersionUID = -1117327079708746405L; int itrz = 0; final TreeSet<MeanShiftSeed> computedSeeds; final ArrayList<SummaryLite> summaries = new ArrayList<>(); SerialCenterIntensity(RadiusNeighbors nbrs) { LogTimer timer; // Now get single seed members MeanShiftSeed sd; this.computedSeeds = new TreeSet<>(); final double[][] X = data.getData(); int idx = 0; for (double[] seed : seeds) { idx++; timer = new LogTimer(); sd = singleSeed(seed, nbrs, X, maxIter); if (null == sd) continue; computedSeeds.add(sd); itrz = FastMath.max(itrz, sd.iterations); // If it actually converged, add the summary summaries.add(new SummaryLite("Kernel " + (idx - 1), sd.iterations, timer.formatTime(), timer.wallTime())); } } @Override public int getIters() { return itrz; } @Override public ArrayList<SummaryLite> getSummaries() { return summaries; } @Override public boolean isEmpty() { return computedSeeds.isEmpty(); } @Override public Iterator<MeanShiftSeed> iterator() { return computedSeeds.iterator(); } @Override public int size() { return computedSeeds.size(); } } /** * Get the kernel bandwidth * @return kernel bandwidth */ public double getBandwidth() { return bandwidth; } /** {@inheritDoc} */ @Override public boolean didConverge() { return converged; } /** {@inheritDoc} */ @Override public int itersElapsed() { return itersElapsed; } /** * Returns a copy of the seeds matrix * @return */ public double[][] getKernelSeeds() { return MatUtils.copy(seeds); } /** {@inheritDoc} */ @Override public int getMaxIter() { return maxIter; } /** {@inheritDoc} */ @Override public double getConvergenceTolerance() { return tolerance; } @Override public String getName() { return "MeanShift"; } @Override public Algo getLoggerTag() { return com.clust4j.log.Log.Tag.Algo.MEANSHIFT; } @Override protected MeanShift fit() { synchronized (fitLock) { if (null != labels) // Already fit this model return this; // Put the results into a Map (hash because tree imposes comparable casting) final LogTimer timer = new LogTimer(); centroids = new ArrayList<double[]>(); /* * Get the neighborhoods and center intensity object. Will iterate until * either the centers are found, or the max try count is exceeded. For each * iteration, will increase bandwidth. */ RadiusNeighbors nbrs = new RadiusNeighbors(this, bandwidth).fit(); // Compute the seeds and center intensity // If parallelism is permitted, try it. CenterIntensity intensity = null; if (parallel) { try { intensity = new ParallelCenterIntensity(nbrs); } catch (RejectedExecutionException e) { // Shouldn't happen... warn("parallel search failed; falling back to serial"); } } // Gets here if serial or if parallel failed... if (null == intensity) intensity = new SerialCenterIntensity(nbrs); // Check for points all too far from seeds if (intensity.isEmpty()) { error(new IllegalClusterStateException("No point " + "was within bandwidth=" + bandwidth + " of any seed; try increasing bandwidth")); } else { converged = true; itersElapsed = intensity.getIters(); // max iters elapsed } // Extract the centroids int idx = 0, m_prime = intensity.size(); final Array2DRowRealMatrix sorted_centers = new Array2DRowRealMatrix(m_prime, n); for (MeanShiftSeed entry : intensity) sorted_centers.setRow(idx++, entry.getPair().getKey()); // Fit the new neighbors model nbrs = new RadiusNeighbors(sorted_centers, new RadiusNeighborsParameters(bandwidth) .setSeed(this.random_state).setMetric(this.dist_metric).setForceParallel(parallel), true).fit(); // Post-processing. Remove near duplicate seeds // If dist btwn two kernels is less than bandwidth, remove one w fewer pts // Create a boolean mask, init true final boolean[] unique = new boolean[m_prime]; for (int i = 0; i < unique.length; i++) unique[i] = true; // Pre-filtered summaries... ArrayList<SummaryLite> allSummary = intensity.getSummaries(); // Iterate over sorted centers and query radii int redundant_ct = 0; int[] indcs; double[] center; for (int i = 0; i < m_prime; i++) { if (unique[i]) { center = sorted_centers.getRow(i); indcs = nbrs.getNeighbors(new double[][] { center }, bandwidth, false).getIndices()[0]; for (int id : indcs) unique[id] = false; unique[i] = true; // Keep this as true } } // Now assign the centroids... SummaryLite summ; for (int i = 0; i < unique.length; i++) { summ = allSummary.get(i); if (unique[i]) { summ.retained = true; centroids.add(sorted_centers.getRow(i)); } fitSummary.add(summ.toArray()); } // calc redundant ct redundant_ct = unique.length - centroids.size(); // also put the centroids into a matrix. We have to // wait to perform this op, because we have to know // the size of centroids first... Array2DRowRealMatrix centers = new Array2DRowRealMatrix(centroids.size(), n); for (int i = 0; i < centroids.size(); i++) centers.setRow(i, centroids.get(i)); // Build yet another neighbors model... NearestNeighbors nn = new NearestNeighbors(centers, new NearestNeighborsParameters(1) .setSeed(this.random_state).setMetric(this.dist_metric).setForceParallel(false), true).fit(); info((numClusters = centroids.size()) + " optimal kernel" + (numClusters != 1 ? "s" : "") + " identified"); info(redundant_ct + " nearly-identical kernel" + (redundant_ct != 1 ? "s" : "") + " removed"); // Get the nearest... final LogTimer clustTimer = new LogTimer(); Neighborhood knrst = nn.getNeighbors(data.getDataRef()); labels = MatUtils.flatten(knrst.getIndices()); // order the labels.. /* * Reduce labels to a sorted, gapless, list * sklearn line: cluster_centers_indices = np.unique(labels) */ ArrayList<Integer> centroidIndices = new ArrayList<Integer>(numClusters); for (Integer i : labels) // force autobox if (!centroidIndices.contains(i)) // Not race condition because synchronized centroidIndices.add(i); /* * final label assignment... * sklearn line: labels = np.searchsorted(cluster_centers_indices, labels) */ for (int i = 0; i < labels.length; i++) labels[i] = centroidIndices.indexOf(labels[i]); // Wrap up... // Count missing numNoisey = 0; for (int lab : labels) if (lab == NOISE_CLASS) numNoisey++; info(numNoisey + " record" + (numNoisey != 1 ? "s" : "") + " classified noise"); info("completed cluster labeling in " + clustTimer.toString()); sayBye(timer); return this; } } // End train @Override public ArrayList<double[]> getCentroids() { if (null != centroids) { final ArrayList<double[]> cent = new ArrayList<double[]>(); for (double[] d : centroids) cent.add(VecUtils.copy(d)); return cent; } else { error(new ModelNotFitException("model has not yet been fit")); return null; // can't happen } } @Override public int[] getLabels() { return super.handleLabelCopy(labels); } static MeanShiftSeed singleSeed(double[] seed, RadiusNeighbors rn, double[][] X, int maxIter) { final double bandwidth = rn.getRadius(), tolerance = 1e-3; final int n = X[0].length; // we know X is uniform int completed_iterations = 0; double norm, diff; while (true) { Neighborhood nbrs = rn.getNeighbors(new double[][] { seed }, bandwidth, false); int[] i_nbrs = nbrs.getIndices()[0]; // Check if exit if (i_nbrs.length == 0) break; // Save the old seed final double[] oldSeed = seed; // Get the points inside and simultaneously calc new seed final double[] newSeed = new double[n]; norm = 0; diff = 0; for (int i = 0; i < i_nbrs.length; i++) { final double[] record = X[i_nbrs[i]]; for (int j = 0; j < n; j++) { newSeed[j] += record[j]; // Last iter hack, go ahead and compute means simultaneously if (i == i_nbrs.length - 1) { newSeed[j] /= (double) i_nbrs.length; diff = newSeed[j] - oldSeed[j]; norm += diff * diff; } } } // Assign the new seed seed = newSeed; norm = FastMath.sqrt(norm); // Check stopping criteria if (completed_iterations++ == maxIter || norm < tolerance) return new MeanShiftSeed(seed, i_nbrs.length, completed_iterations); } // Default... shouldn't get here though return null; } @Override final protected Object[] getModelFitSummaryHeaders() { return new Object[] { "Seed ID", "Iterations", "Iter. Time", "Wall", "Retained" }; } @Override public int getNumberOfIdentifiedClusters() { return numClusters; } @Override public int getNumberOfNoisePoints() { return numNoisey; } /** {@inheritDoc} */ @Override public int[] predict(RealMatrix newData) { return CentroidUtils.predict(this, newData); } }