Java tutorial
/*- * * * Copyright 2015 Skymind,Inc. * * * * Licensed under the Apache License, Version 2.0 (the "License"); * * you may not use this file except in compliance with the License. * * You may obtain a copy of the License at * * * * http://www.apache.org/licenses/LICENSE-2.0 * * * * Unless required by applicable law or agreed to in writing, software * * distributed under the License is distributed on an "AS IS" BASIS, * * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * * See the License for the specific language governing permissions and * * limitations under the License. * */ package org.deeplearning4j.clustering.cluster; import org.apache.commons.lang3.ArrayUtils; import org.deeplearning4j.clustering.algorithm.optimisation.ClusteringOptimizationType; import org.deeplearning4j.clustering.algorithm.strategy.OptimisationStrategy; import org.deeplearning4j.clustering.cluster.info.ClusterInfo; import org.deeplearning4j.clustering.cluster.info.ClusterSetInfo; import org.deeplearning4j.util.MathUtils; import org.deeplearning4j.util.MultiThreadUtils; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import java.util.*; import java.util.concurrent.ExecutorService; /** * * Basic cluster utilities */ public class ClusterUtils { private ClusterUtils() { } /** Classify the set of points base on cluster centers. This also adds each point to the ClusterSet */ public static ClusterSetInfo classifyPoints(final ClusterSet clusterSet, List<Point> points, ExecutorService executorService) { final ClusterSetInfo clusterSetInfo = ClusterSetInfo.initialize(clusterSet, true); List<Runnable> tasks = new ArrayList<>(); for (final Point point : points) { tasks.add(new Runnable() { public void run() { try { PointClassification result = classifyPoint(clusterSet, point); if (result.isNewLocation()) clusterSetInfo.getPointLocationChange().incrementAndGet(); clusterSetInfo.getClusterInfo(result.getCluster().getId()).getPointDistancesFromCenter() .put(point.getId(), result.getDistanceFromCenter()); } catch (Exception e) { e.printStackTrace(); } } }); } MultiThreadUtils.parallelTasks(tasks, executorService); return clusterSetInfo; } public static PointClassification classifyPoint(ClusterSet clusterSet, Point point) { return clusterSet.classifyPoint(point, false); } public static void refreshClustersCenters(final ClusterSet clusterSet, final ClusterSetInfo clusterSetInfo, ExecutorService executorService) { List<Runnable> tasks = new ArrayList<>(); int nClusters = clusterSet.getClusterCount(); for (int i = 0; i < nClusters; i++) { final Cluster cluster = clusterSet.getClusters().get(i); tasks.add(new Runnable() { public void run() { try { final ClusterInfo clusterInfo = clusterSetInfo.getClusterInfo(cluster.getId()); refreshClusterCenter(cluster, clusterInfo); deriveClusterInfoDistanceStatistics(clusterInfo); } catch (Exception e) { e.printStackTrace(); } } }); } MultiThreadUtils.parallelTasks(tasks, executorService); } public static void refreshClusterCenter(Cluster cluster, ClusterInfo clusterInfo) { int pointsCount = cluster.getPoints().size(); if (pointsCount == 0) return; Point center = new Point(Nd4j.create(cluster.getPoints().get(0).getArray().length())); for (Point point : cluster.getPoints()) { INDArray arr = point.getArray(); center.getArray().addi(arr); } center.getArray().divi(pointsCount); cluster.setCenter(center); } public static void deriveClusterInfoDistanceStatistics(ClusterInfo info) { int pointCount = info.getPointDistancesFromCenter().size(); if (pointCount == 0) return; double[] distances = ArrayUtils .toPrimitive(info.getPointDistancesFromCenter().values().toArray(new Double[] {})); double max = MathUtils.max(distances); double total = MathUtils.sum(distances); info.setMaxPointDistanceFromCenter(max); info.setTotalPointDistanceFromCenter(total); info.setAveragePointDistanceFromCenter(total / pointCount); info.setPointDistanceFromCenterVariance(MathUtils.variance(distances)); } public static INDArray computeSquareDistancesFromNearestCluster(final ClusterSet clusterSet, final List<Point> points, INDArray previousDxs, ExecutorService executorService) { final int pointsCount = points.size(); final INDArray dxs = Nd4j.create(pointsCount); final Cluster newCluster = clusterSet.getClusters().get(clusterSet.getClusters().size() - 1); List<Runnable> tasks = new ArrayList<>(); for (int i = 0; i < pointsCount; i++) { final int i2 = i; tasks.add(new Runnable() { public void run() { Point point = points.get(i2); dxs.putScalar(i2, Math.pow(newCluster.getDistanceToCenter(point), 2)); } }); } MultiThreadUtils.parallelTasks(tasks, executorService); for (int i = 0; i < pointsCount; i++) { double previousMinDistance = previousDxs.getDouble(i); if (dxs.getDouble(i) > previousMinDistance) dxs.putScalar(i, previousMinDistance); } return dxs; } public static ClusterSetInfo computeClusterSetInfo(ClusterSet clusterSet) { ExecutorService executor = MultiThreadUtils.newExecutorService(); ClusterSetInfo info = computeClusterSetInfo(clusterSet, executor); executor.shutdownNow(); return info; } public static ClusterSetInfo computeClusterSetInfo(final ClusterSet clusterSet, ExecutorService executorService) { final ClusterSetInfo info = new ClusterSetInfo(true); int clusterCount = clusterSet.getClusterCount(); List<Runnable> tasks = new ArrayList<>(); for (int i = 0; i < clusterCount; i++) { final Cluster cluster = clusterSet.getClusters().get(i); tasks.add(new Runnable() { public void run() { info.getClustersInfos().put(cluster.getId(), computeClusterInfos(cluster, clusterSet.getAccumulation())); } }); } MultiThreadUtils.parallelTasks(tasks, executorService); tasks = new ArrayList<>(); for (int i = 0; i < clusterCount; i++) { final int clusterIdx = i; final Cluster fromCluster = clusterSet.getClusters().get(i); tasks.add(new Runnable() { public void run() { try { for (int k = clusterIdx + 1, l = clusterSet.getClusterCount(); k < l; k++) { Cluster toCluster = clusterSet.getClusters().get(k); double distance = Nd4j.getExecutioner() .execAndReturn(Nd4j.getOpFactory().createAccum(clusterSet.getAccumulation(), fromCluster.getCenter().getArray(), toCluster.getCenter().getArray())) .getFinalResult().doubleValue(); info.getDistancesBetweenClustersCenters().put(fromCluster.getId(), toCluster.getId(), distance); } } catch (Exception e) { e.printStackTrace(); } } }); } MultiThreadUtils.parallelTasks(tasks, executorService); return info; } public static ClusterInfo computeClusterInfos(Cluster cluster, String distanceFunction) { ClusterInfo info = new ClusterInfo(true); for (int i = 0, j = cluster.getPoints().size(); i < j; i++) { Point point = cluster.getPoints().get(i); double distance = Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createAccum(distanceFunction, cluster.getCenter().getArray(), point.getArray())).getFinalResult().doubleValue(); info.getPointDistancesFromCenter().put(point.getId(), distance); info.setTotalPointDistanceFromCenter(info.getTotalPointDistanceFromCenter() + distance); } if (!cluster.getPoints().isEmpty()) info.setAveragePointDistanceFromCenter( info.getTotalPointDistanceFromCenter() / cluster.getPoints().size()); return info; } public static boolean applyOptimization(OptimisationStrategy optimization, ClusterSet clusterSet, ClusterSetInfo clusterSetInfo, ExecutorService executor) { if (optimization.isClusteringOptimizationType( ClusteringOptimizationType.MINIMIZE_AVERAGE_POINT_TO_CENTER_DISTANCE)) { int splitCount = ClusterUtils.splitClustersWhereAverageDistanceFromCenterGreaterThan(clusterSet, clusterSetInfo, optimization.getClusteringOptimizationValue(), executor); return splitCount > 0; } if (optimization.isClusteringOptimizationType( ClusteringOptimizationType.MINIMIZE_MAXIMUM_POINT_TO_CENTER_DISTANCE)) { int splitCount = ClusterUtils.splitClustersWhereMaximumDistanceFromCenterGreaterThan(clusterSet, clusterSetInfo, optimization.getClusteringOptimizationValue(), executor); return splitCount > 0; } return false; } public static List<Cluster> getMostSpreadOutClusters(final ClusterSet clusterSet, final ClusterSetInfo info, int count) { List<Cluster> clusters = new ArrayList<>(clusterSet.getClusters()); Collections.sort(clusters, new Comparator<Cluster>() { public int compare(Cluster o1, Cluster o2) { Double o1TotalDistance = info.getClusterInfo(o1.getId()).getTotalPointDistanceFromCenter(); Double o2TotalDistance = info.getClusterInfo(o2.getId()).getTotalPointDistanceFromCenter(); return -o1TotalDistance.compareTo(o2TotalDistance); } }); return clusters.subList(0, count); } public static List<Cluster> getClustersWhereAverageDistanceFromCenterGreaterThan(final ClusterSet clusterSet, final ClusterSetInfo info, double maximumAverageDistance) { List<Cluster> clusters = new ArrayList<>(); for (Cluster cluster : clusterSet.getClusters()) { ClusterInfo clusterInfo = info.getClusterInfo(cluster.getId()); if (clusterInfo != null && clusterInfo.getAveragePointDistanceFromCenter() > maximumAverageDistance) clusters.add(cluster); } return clusters; } public static List<Cluster> getClustersWhereMaximumDistanceFromCenterGreaterThan(final ClusterSet clusterSet, final ClusterSetInfo info, double maximumDistance) { List<Cluster> clusters = new ArrayList<>(); for (Cluster cluster : clusterSet.getClusters()) { ClusterInfo clusterInfo = info.getClusterInfo(cluster.getId()); if (clusterInfo != null && clusterInfo.getMaxPointDistanceFromCenter() > maximumDistance) clusters.add(cluster); } return clusters; } public static int splitMostSpreadOutClusters(ClusterSet clusterSet, ClusterSetInfo clusterSetInfo, int count, ExecutorService executorService) { List<Cluster> clustersToSplit = getMostSpreadOutClusters(clusterSet, clusterSetInfo, count); splitClusters(clusterSet, clusterSetInfo, clustersToSplit, executorService); return clustersToSplit.size(); } public static int splitClustersWhereAverageDistanceFromCenterGreaterThan(ClusterSet clusterSet, ClusterSetInfo clusterSetInfo, double maxWithinClusterDistance, ExecutorService executorService) { List<Cluster> clustersToSplit = getClustersWhereAverageDistanceFromCenterGreaterThan(clusterSet, clusterSetInfo, maxWithinClusterDistance); splitClusters(clusterSet, clusterSetInfo, clustersToSplit, maxWithinClusterDistance, executorService); return clustersToSplit.size(); } public static int splitClustersWhereMaximumDistanceFromCenterGreaterThan(ClusterSet clusterSet, ClusterSetInfo clusterSetInfo, double maxWithinClusterDistance, ExecutorService executorService) { List<Cluster> clustersToSplit = getClustersWhereMaximumDistanceFromCenterGreaterThan(clusterSet, clusterSetInfo, maxWithinClusterDistance); splitClusters(clusterSet, clusterSetInfo, clustersToSplit, maxWithinClusterDistance, executorService); return clustersToSplit.size(); } public static void splitMostPopulatedClusters(ClusterSet clusterSet, ClusterSetInfo clusterSetInfo, int count, ExecutorService executorService) { List<Cluster> clustersToSplit = clusterSet.getMostPopulatedClusters(count); splitClusters(clusterSet, clusterSetInfo, clustersToSplit, executorService); } public static void splitClusters(final ClusterSet clusterSet, final ClusterSetInfo clusterSetInfo, List<Cluster> clusters, final double maxDistance, ExecutorService executorService) { final Random random = new Random(); List<Runnable> tasks = new ArrayList<>(); for (final Cluster cluster : clusters) { tasks.add(new Runnable() { public void run() { try { ClusterInfo clusterInfo = clusterSetInfo.getClusterInfo(cluster.getId()); List<String> fartherPoints = clusterInfo.getPointsFartherFromCenterThan(maxDistance); int rank = Math.min(fartherPoints.size(), 3); String pointId = fartherPoints.get(random.nextInt(rank)); Point point = cluster.removePoint(pointId); clusterSet.addNewClusterWithCenter(point); } catch (Exception e) { e.printStackTrace(); } } }); } MultiThreadUtils.parallelTasks(tasks, executorService); } public static void splitClusters(final ClusterSet clusterSet, final ClusterSetInfo clusterSetInfo, List<Cluster> clusters, ExecutorService executorService) { final Random random = new Random(); List<Runnable> tasks = new ArrayList<>(); for (final Cluster cluster : clusters) { tasks.add(new Runnable() { public void run() { Point point = cluster.getPoints().remove(random.nextInt(cluster.getPoints().size())); clusterSet.addNewClusterWithCenter(point); } }); } MultiThreadUtils.parallelTasks(tasks, executorService); } }