Java tutorial
/** * Copyright (c) Acroquest Technology Co, Ltd. All Rights Reserved. * Please read the associated COPYRIGHTS file for more details. * * THE SOFTWARE IS PROVIDED BY Acroquest Technolog Co., Ltd., * WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING * BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDER BE LIABLE FOR ANY * CLAIM, DAMAGES SUFFERED BY LICENSEE AS A RESULT OF USING, MODIFYING * OR DISTRIBUTING THIS SOFTWARE OR ITS DERIVATIVES. */ package acromusashi.stream.ml.clustering.kmeans; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Map.Entry; import java.util.Random; import java.util.Set; import java.util.TreeMap; import org.apache.commons.collections.ComparatorUtils; import org.apache.commons.math.util.MathUtils; import acromusashi.stream.ml.clustering.kmeans.entity.CentroidMapping; import acromusashi.stream.ml.clustering.kmeans.entity.CentroidsComparator; import acromusashi.stream.ml.clustering.kmeans.entity.KmeansDataSet; import acromusashi.stream.ml.clustering.kmeans.entity.KmeansPoint; import acromusashi.stream.ml.clustering.kmeans.entity.KmeansResult; import com.google.common.collect.Lists; import com.google.common.collect.Maps; /** * KMeans??? * * @author kimura */ public class KmeansCalculator { /** BinarySearch???????--1???????? */ private static final int COMPENSATE_INDEX = -2; /** * ???? */ private KmeansCalculator() { } /** * ??????<br> * * @param pointList ? * @param clusterNum * @param maxIteration * @param convergenceThres ?????? * @return ??? */ public static KmeansDataSet createDataModel(List<KmeansPoint> pointList, int clusterNum, int maxIteration, double convergenceThres) { // ???????null???? if (pointList.size() < clusterNum) { return null; } // ???? List<KmeansPoint> centroids = createInitialCentroids(pointList, clusterNum); long[] clusteredNum = new long[clusterNum]; // ????? for (int exeIndex = 0; exeIndex < maxIteration; exeIndex++) { Map<Integer, List<KmeansPoint>> assignments = Maps.newHashMap(); for (int centroidIndex = 0; centroidIndex < clusterNum; centroidIndex++) { assignments.put(centroidIndex, Lists.<KmeansPoint>newArrayList()); } for (KmeansPoint targetPoint : pointList) { KmeansResult result = nearestCentroid(targetPoint, centroids); assignments.get(result.getCentroidIndex()).add(targetPoint); } List<KmeansPoint> newCentroids = Lists.newArrayList(); for (Map.Entry<Integer, List<KmeansPoint>> entry : assignments.entrySet()) { if (entry.getValue().isEmpty()) { newCentroids.add(centroids.get(entry.getKey())); } else { newCentroids.add(calculateCentroid(entry.getValue())); } clusteredNum[entry.getKey()] = entry.getValue().size(); } boolean isConvergenced = isConvergenced(centroids, newCentroids, convergenceThres); centroids = newCentroids; if (isConvergenced == true) { break; } } double[][] centroidPoints = new double[clusterNum][]; for (int centroidIndex = 0; centroidIndex < clusterNum; centroidIndex++) { centroidPoints[centroidIndex] = centroids.get(centroidIndex).getDataPoint(); } KmeansDataSet createdModel = new KmeansDataSet(); createdModel.setCentroids(centroidPoints); createdModel.setClusteredNum(clusteredNum); return createdModel; } /** * ????? * * @param basePoints * @return */ public static KmeansPoint calculateCentroid(List<KmeansPoint> basePoints) { double[] firstDataPoint = basePoints.get(0).getDataPoint(); double[] centroidSum = Arrays.copyOf(firstDataPoint, firstDataPoint.length); for (int pointIndex = 1; pointIndex < basePoints.size(); pointIndex++) { for (int coordinateIndex = 0; coordinateIndex < centroidSum.length; coordinateIndex++) { centroidSum[coordinateIndex] = centroidSum[coordinateIndex] + basePoints.get(pointIndex).getDataPoint()[coordinateIndex]; } } double[] centroidPoints = sub(centroidSum, basePoints.size()); KmeansPoint centroid = new KmeansPoint(); centroid.setDataPoint(centroidPoints); return centroid; } /** * ??????? * * @param basePoints ?? * @param newPoints ? * @param convergenceThres ?? * @return ?????true?????????false */ public static boolean isConvergenced(List<KmeansPoint> basePoints, List<KmeansPoint> newPoints, double convergenceThres) { boolean result = true; for (int pointIndex = 0; pointIndex < basePoints.size(); pointIndex++) { double distance = MathUtils.distance(basePoints.get(pointIndex).getDataPoint(), newPoints.get(pointIndex).getDataPoint()); if (distance > convergenceThres) { result = false; break; } } return result; } /** * ???????????? * * @param targetPoint ? * @param centroids ? * @return Kmeans? */ public static KmeansResult nearestCentroid(double[] targetPoint, double[][] centroids) { int nearestCentroidIndex = 0; Double minDistance = Double.MAX_VALUE; double[] currentCentroid = null; Double currentDistance; for (int index = 0; index < centroids.length; index++) { currentCentroid = centroids[index]; if (currentCentroid != null) { currentDistance = MathUtils.distance(targetPoint, currentCentroid); if (currentDistance < minDistance) { minDistance = currentDistance; nearestCentroidIndex = index; } } } currentCentroid = centroids[nearestCentroidIndex]; KmeansResult result = new KmeansResult(); result.setDataPoint(targetPoint); result.setCentroidIndex(nearestCentroidIndex); result.setCentroid(currentCentroid); result.setDistance(minDistance); return result; } /** * ???????????? * * @param targetPoint ? * @param centroids * @return Kmeans? */ public static KmeansResult nearestCentroid(KmeansPoint targetPoint, List<KmeansPoint> centroids) { int nearestCentroidIndex = 0; Double minDistance = Double.MAX_VALUE; KmeansPoint currentCentroid = null; Double currentDistance; for (int index = 0; index < centroids.size(); index++) { currentCentroid = centroids.get(index); if (currentCentroid != null && currentCentroid.getDataPoint() != null) { currentDistance = MathUtils.distance(targetPoint.getDataPoint(), currentCentroid.getDataPoint()); if (currentDistance < minDistance) { minDistance = currentDistance; nearestCentroidIndex = index; } } } currentCentroid = centroids.get(nearestCentroidIndex); KmeansResult result = new KmeansResult(); result.setDataPoint(targetPoint.getDataPoint()); result.setCentroidIndex(nearestCentroidIndex); result.setCentroid(currentCentroid.getDataPoint()); result.setDistance(minDistance); return result; } /** * ?????????? * * @param targetPoint ? * @param dataSet * @return ?????? */ public static KmeansResult classify(KmeansPoint targetPoint, KmeansDataSet dataSet) { // KMean? int nearestCentroidIndex = 0; Double minDistance = Double.MAX_VALUE; double[] currentCentroid = null; Double currentDistance; for (int index = 0; index < dataSet.getCentroids().length; index++) { currentCentroid = dataSet.getCentroids()[index]; if (currentCentroid != null) { currentDistance = MathUtils.distance(targetPoint.getDataPoint(), currentCentroid); if (currentDistance < minDistance) { minDistance = currentDistance; nearestCentroidIndex = index; } } } currentCentroid = dataSet.getCentroids()[nearestCentroidIndex]; KmeansResult result = new KmeansResult(); result.setDataPoint(targetPoint.getDataPoint()); result.setCentroidIndex(nearestCentroidIndex); result.setCentroid(currentCentroid); result.setDistance(minDistance); return result; } /** * KMeans++????? * * @param basePoints ?? * @param clusterNum * @return */ public static List<KmeansPoint> createInitialCentroids(List<KmeansPoint> basePoints, int clusterNum) { Random random = new Random(); List<KmeansPoint> resultList = Lists.newArrayList(); // ?????????? List<KmeansPoint> pointList = Lists.newArrayList(basePoints); KmeansPoint firstCentroid = pointList.remove(random.nextInt(pointList.size())); resultList.add(firstCentroid); double[] dxs; // KMeans++?????? // ??1????????1???? for (int centroidIndex = 1; centroidIndex < clusterNum; centroidIndex++) { // ????????????? dxs = computeDxs(pointList, resultList); // ?????????? double r = random.nextDouble() * dxs[dxs.length - 1]; int next = Arrays.binarySearch(dxs, r); int index = 0; if (next > 0) { index = next - 1; } else if (next < 0) { index = COMPENSATE_INDEX - next; } while (index > 0 && resultList.contains(pointList.get(index))) { index = index - 1; } resultList.add(pointList.get(index)); } return resultList; } /** * ???????? * * @param basePoints ?? * @param centroids ?? * @return ????? */ public static double[] computeDxs(List<KmeansPoint> basePoints, List<KmeansPoint> centroids) { double[] dxs = new double[basePoints.size()]; double sum = 0.0d; double[] nearestCentroid; for (int pointIndex = 0; pointIndex < basePoints.size(); pointIndex++) { // ??????(dx)???????? KmeansPoint targetPoint = basePoints.get(pointIndex); KmeansResult kmeanResult = KmeansCalculator.nearestCentroid(targetPoint, centroids); nearestCentroid = kmeanResult.getCentroid(); double dx = MathUtils.distance(targetPoint.getDataPoint(), nearestCentroid); double probabilityDist = Math.pow(dx, 2); sum += probabilityDist; dxs[pointIndex] = sum; } return dxs; } /** * Kmeans??<br> * ???<br> * <ol> * <li>????????(?n????n?????)</li> * <li>n?????????????????????????????</li> * <li>???????</li> * </ol> * * @param baseKmeans Kmeans * @param targetKmeans Kmeans * @return ? */ public static final KmeansDataSet mergeKmeans(KmeansDataSet baseKmeans, KmeansDataSet targetKmeans) { KmeansDataSet merged = new KmeansDataSet(); int centroidNum = (int) ComparatorUtils.min(baseKmeans.getCentroids().length, targetKmeans.getCentroids().length, ComparatorUtils.NATURAL_COMPARATOR); // ??????? List<CentroidMapping> allDistance = calculateDistances(baseKmeans.getCentroids(), targetKmeans.getCentroids(), centroidNum); // n????????????????? Collections.sort(allDistance, new CentroidsComparator()); Map<Integer, Integer> resultMapping = createCentroidMappings(centroidNum, allDistance); // ?? double[][] mergedCentroids = mergeCentroids(baseKmeans.getCentroids(), targetKmeans.getCentroids(), resultMapping); merged.setCentroids(mergedCentroids); return merged; } /** * ?Counts?? * * @param baseCounts Counts * @param targetCounts Counts * @param resultMapping ?? * @return ?Counts */ protected static List<Long> mergeCounts(List<Long> baseCounts, List<Long> targetCounts, Map<Integer, Integer> resultMapping) { int countNum = resultMapping.size(); List<Long> mergedCounts = new ArrayList<>(countNum); for (int count = 0; count < countNum; count++) { mergedCounts.add(0L); } for (Entry<Integer, Integer> resultEntry : resultMapping.entrySet()) { mergedCounts.set(resultEntry.getKey(), baseCounts.get(resultEntry.getKey()) + targetCounts.get(resultEntry.getValue())); } return mergedCounts; } /** * ????? * * @param basePoints ?? * @param targetPoints ?? * @return ?Counts */ protected static List<double[]> mergeInitPoints(List<double[]> basePoints, List<double[]> targetPoints) { List<double[]> mergedFeatures = new ArrayList<>(); mergedFeatures.addAll(basePoints); mergedFeatures.addAll(targetPoints); return mergedFeatures; } /** * ????????????<br> * ??????????<br> * * @param baseCentroids ? * @param targetCentroids ? * @param resultMapping ?? * @return ?? */ public static double[][] mergeCentroids(double[][] baseCentroids, double[][] targetCentroids, Map<Integer, Integer> resultMapping) { // ?????? double[][] mergedCentroids = new double[resultMapping.size()][]; for (Map.Entry<Integer, Integer> targetEntry : resultMapping.entrySet()) { double[] baseCentroid = baseCentroids[targetEntry.getKey()]; double[] targetCentroid = targetCentroids[targetEntry.getValue()]; mergedCentroids[targetEntry.getKey()] = average(baseCentroid, targetCentroid); } return mergedCentroids; } /** * ??? * * @param centroidNum * @param allDistance ? * @return */ protected static Map<Integer, Integer> createCentroidMappings(int centroidNum, List<CentroidMapping> allDistance) { Set<Integer> baseSet = new HashSet<>(); Set<Integer> targetSet = new HashSet<>(); Map<Integer, Integer> resultMapping = new TreeMap<>(); int mappingNum = 0; // ????? for (CentroidMapping targetDistance : allDistance) { // ????????? if (baseSet.contains(targetDistance.getBaseIndex()) || targetSet.contains(targetDistance.getTargetIndex())) { continue; } baseSet.add(targetDistance.getBaseIndex()); targetSet.add(targetDistance.getTargetIndex()); resultMapping.put(targetDistance.getBaseIndex(), targetDistance.getTargetIndex()); mappingNum++; // ???????? if (mappingNum >= centroidNum) { break; } } return resultMapping; } /** * ???????? * * @param baseCentroids ? * @param targetCentroids ? * @param centroidNum * @return ? */ protected static List<CentroidMapping> calculateDistances(double[][] baseCentroids, double[][] targetCentroids, int centroidNum) { // ??????? List<CentroidMapping> allDistance = new ArrayList<>(); for (int baseIndex = 0; baseIndex < centroidNum; baseIndex++) { for (int targetIndex = 0; targetIndex < centroidNum; targetIndex++) { CentroidMapping centroidMapping = new CentroidMapping(); centroidMapping.setBaseIndex(baseIndex); centroidMapping.setTargetIndex(targetIndex); double distance = MathUtils.distance(baseCentroids[baseIndex], targetCentroids[targetIndex]); centroidMapping.setEuclideanDistance(distance); allDistance.add(centroidMapping); } } return allDistance; } /** * ???? * * @param base ? * @param target ?? * @return ?? */ protected static double[] average(double[] base, double[] target) { int dataNum = base.length; double[] average = new double[dataNum]; for (int index = 0; index < dataNum; index++) { average[index] = (base[index] + target[index]) / 2.0; } return average; } /** * double???? * * @param base ? * @param subNumber ? * @return ? */ protected static double[] sub(double[] base, double subNumber) { int dataNum = base.length; double[] result = new double[dataNum]; for (int index = 0; index < dataNum; index++) { result[index] = base[index] / subNumber; } return result; } }