com.clust4j.algo.BoruvkaAlgorithm.java Source code

Java tutorial

Introduction

Here is the source code for com.clust4j.algo.BoruvkaAlgorithm.java

Source

/*******************************************************************************
 *    Copyright 2015, 2016 Taylor G Smith
 *
 *    Licensed under the Apache License, Version 2.0 (the "License");
 *    you may not use this file except in compliance with the License.
 *    You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 *    Unless required by applicable law or agreed to in writing, software
 *    distributed under the License is distributed on an "AS IS" BASIS,
 *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *    See the License for the specific language governing permissions and
 *    limitations under the License.
 *******************************************************************************/
package com.clust4j.algo;

import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.util.FastMath;

import com.clust4j.algo.Neighborhood;
import com.clust4j.algo.NearestNeighborHeapSearch.NodeData;
import com.clust4j.log.LogTimer;
import com.clust4j.log.Loggable;
import com.clust4j.metrics.pairwise.DistanceMetric;
import com.clust4j.metrics.pairwise.Pairwise;
import com.clust4j.utils.VecUtils;

/**
 * A graph traversal algorithm used in identifying the minimum spanning tree 
 * in a graph for which all edge weights are distinct. Used in conjunction with
 * {@link HDBSCAN}, and adapted from the <a href="https://github.com/lmcinnes/hdbscan">HDBSCAN python package</a>.
 * 
 * @see <a href="https://en.wikipedia.org/wiki/Bor%C5%AFvka%27s_algorithm">Boruvka's algorithm</a>
 * @author Taylor G Smith
 */
class BoruvkaAlgorithm implements java.io.Serializable {
    private static final long serialVersionUID = 3935595821188876442L;

    // The concrete solver (KD-tree or ball-tree flavour), selected in the
    // constructor. Constructing it builds a fresh tree over the data and runs
    // the initial neighbor search, i.e. the initialization reorganizes the trees.
    // NOTE(review): Boruvka does not implement java.io.Serializable in this file
    // and this field is not transient, so Java serialization of this object
    // looks like it would fail -- confirm whether serialization is ever used.
    final protected Boruvka alg;

    // Tree supplied by the caller; its data reference is what the solver's own tree is built from.
    private final NearestNeighborHeapSearch outer_tree;
    // Neighbor count used when deriving each point's core distance.
    private final int minSamples;
    // Metric used for point-to-point distances and partial-distance conversions.
    private final DistanceMetric metric;
    // When true, node bounds are only reset if a pass failed to merge any components.
    private final boolean approxMinSpanTree;
    // Leaf size forwarded to the KDTree / BallTree built by the solver.
    private final int leafSize;
    // Optional logger; may be null (it is null-checked before use).
    private final Loggable logger;
    // Divisor applied to the raw pairwise distance when forming the mutual reachability distance.
    private final double alpha;

    /**
     * Stores the configuration and then builds the concrete solver.
     *
     * @param tree the search tree whose data the solver will re-index; a
     *        {@link KDTree} selects the KD-tree solver, anything else the ball-tree solver
     * @param min_samples neighbor count used for the core distances
     * @param metric the distance metric
     * @param leafSize leaf size for the tree built by the solver
     * @param approx_min_span_tree whether the approximate (tie-tolerant) bound reset is used
     * @param alpha divisor applied to pairwise distances in the mutual reachability distance
     * @param logger optional logger, may be null
     */
    protected BoruvkaAlgorithm(NearestNeighborHeapSearch tree, int min_samples, DistanceMetric metric, int leafSize,
            boolean approx_min_span_tree, double alpha, Loggable logger) {

        // Configuration has to be in place first: the inner solver classes read
        // these fields from the enclosing instance while they are constructed.
        this.logger = logger;
        this.alpha = alpha;
        this.approxMinSpanTree = approx_min_span_tree;
        this.leafSize = leafSize;
        this.metric = metric;
        this.minSamples = min_samples;
        this.outer_tree = tree;

        // Building the solver runs the initial neighbor search; if a logger
        // is present it is told when that search completes.
        if (tree instanceof KDTree) {
            this.alg = new KDTreeBoruvAlg();
        } else {
            this.alg = new BallTreeBoruvAlg();
        }
    }

    /**
     * Union-find used to track which component each point belongs to.
     * Adds no behavior of its own on top of {@link HDBSCAN.TreeUnionFind}.
     */
    protected static class BoruvkaUnionFind extends HDBSCAN.TreeUnionFind {
        /**
         * @param N value handed unchanged to the parent constructor
         */
        BoruvkaUnionFind(final int N) {
            super(N);
        }
    }

    /**
     * Distance between the centroids of two nodes, less both node radii,
     * floored at zero.
     *
     * @param rad1 radius of the first node
     * @param rad2 radius of the second node
     * @param node1 row index into {@code centroidDist}
     * @param node2 column index into {@code centroidDist}
     * @param centroidDist matrix read at {@code [node1][node2]}
     * @return {@code max(0, centroidDist[node1][node2] - rad1 - rad2)}
     */
    protected static double ballTreeMinDistDual(double rad1, double rad2, int node1, int node2,
            double[][] centroidDist) {
        final double gap = centroidDist[node1][node2] - rad1 - rad2;
        return FastMath.max(0, gap);
    }

    /*
     * Similar to {@link KDTree}<tt>.minRDistDual(...)</tt> but
     * uses one node bounds array instead of two instances of
     * {@link NearestNeighborHeapSearch}
     * @param metric
     * @param node1
     * @param node2
     * @param nodeBounds
     * @param n
     * @return
     *
    static double kdTreeMinDistDual(DistanceMetric metric, int node1, int node2, double[][][] nodeBounds, int n) {
       return metric.partialDistanceToDistance(kdTreeMinRDistDual(metric, node1, node2, nodeBounds, n));
    }
    */

    /**
     * Reduced ("partial") minimum distance between two nodes, computed from a
     * single node-bounds array.
     *
     * <p>For every feature {@code j} the two signed differences
     * {@code nodeBounds[0][a][j] - nodeBounds[1][b][j]} are taken in both
     * directions; negative values are clipped to zero (via {@code x + |x|}),
     * and the halved sum is accumulated either as a running maximum (when the
     * metric's {@code p} is infinite) or as a sum of {@code p}-th powers.
     *
     * @param metric supplies the exponent {@code p}
     * @param node1 index of the first node
     * @param node2 index of the second node
     * @param nodeBounds bounds indexed as {@code [0|1][node][feature]}
     *        (presumably 0 = lower, 1 = upper bound -- confirm against the tree implementation)
     * @param n number of features to scan
     * @return the accumulated reduced distance; {@code 0.0} when {@code n <= 0}
     */
    protected static double kdTreeMinRDistDual(DistanceMetric metric, int node1, int node2, double[][][] nodeBounds,
            int n) {
        // The exponent does not change during the scan, so read it once
        // instead of calling the metric on every feature.
        final double p = metric.getP();
        final boolean inf = p == Double.POSITIVE_INFINITY;

        double rdist = 0.0;
        for (int j = 0; j < n; j++) {
            final double d1 = nodeBounds[0][node1][j] - nodeBounds[1][node2][j];
            final double d2 = nodeBounds[0][node2][j] - nodeBounds[1][node1][j];

            // (x + |x|) equals 2 * max(x, 0); the factor of two is undone by the 0.5
            final double gap = 0.5 * ((d1 + FastMath.abs(d1)) + (d2 + FastMath.abs(d2)));

            rdist = inf ? FastMath.max(rdist, gap) : rdist + FastMath.pow(gap, p);
        }

        return rdist;
    }

    /**
     * The {@link NearestNeighborHeapSearch} 
     * tree traversal algorithm
     * @author Taylor G Smith
     */
    protected abstract class Boruvka {
        // Sentinel stored in candidatePoint / candidateNeighbors when a
        // component currently has no candidate edge.
        final static int INIT_VAL = -1;

        // The tree handed to the enclosing BoruvkaAlgorithm (not read elsewhere in this class).
        final NearestNeighborHeapSearch coreDistTree = outer_tree;
        // The tree this solver traverses; supplied by the subclass constructor.
        final NearestNeighborHeapSearch TREE;
        // Tracks which points have been merged into the same component.
        final BoruvkaUnionFind componentUnionFind;

        // References obtained from TREE in the constructor.
        final double[][] tree_data_ref;   // one row per point
        final double[][][] node_bounds;   // per-node bounds, indexed [k][node][feature]
        final int[] idx_array;            // maps a position within a node's [start, end) range to a point index
        final NodeData[] node_data_ref;   // one entry per tree node
        // When true, candidate distances are converted with
        // metric.partialDistanceToDistance before being stored as edge weights.
        final boolean partialDistTransform;

        // numEdges counts the rows of `edges` filled so far.
        int numPoints, numFeatures, numNodes, numEdges;
        // Per-node pruning bound compared against the node-to-node distance during traversal.
        double[] bounds;
        // components: labels of the current components (initially 0..numPoints-1).
        // componentOfPoint: union-find root of each point, refreshed by updateComponents().
        // componentOfNode: component label of a node whose subtree is entirely in one
        //                  component, otherwise the negative placeholder -(node + 1).
        // candidateNeighbors / candidatePoint: per component label, the best edge found
        //                  so far (sink and source point respectively), or INIT_VAL.
        int[] components, componentOfPoint, componentOfNode, candidateNeighbors, candidatePoint;
        // Per component label, the weight of the best candidate edge (Double.MAX_VALUE when none).
        double[] candidateDistance;
        // Output edge list: numPoints - 1 rows of {source, sink, weight}.
        double[][] edges;
        // Per-point core distance; assigned by computeBounds().
        double[] coreDistance;

        /**
         * Captures the tree's internal arrays, allocates all working storage and
         * runs the initial neighbor search.
         *
         * <p>Note that {@link #computeBounds()} is abstract and is invoked from
         * this constructor, i.e. before the subclass constructor body has run.
         *
         * @param partialTrans whether candidate distances must be converted with
         *        {@code metric.partialDistanceToDistance} when edges are recorded
         * @param TREE the tree to traverse
         */
        Boruvka(boolean partialTrans, NearestNeighborHeapSearch TREE) {
            this.TREE = TREE;
            this.tree_data_ref = TREE.getDataRef();
            this.node_bounds = TREE.getNodeBoundsRef();
            this.idx_array = TREE.getIndexArrayRef();
            this.node_data_ref = TREE.getNodeDataRef();

            // Problem dimensions
            final int nPoints = this.tree_data_ref.length;
            final int nNodes = this.node_data_ref.length;
            this.numPoints = nPoints;
            this.numFeatures = this.tree_data_ref[0].length;
            this.numNodes = nNodes;

            // Working storage: point-sized arrays ...
            this.components = VecUtils.arange(nPoints);
            this.componentOfPoint = new int[nPoints];
            this.candidateNeighbors = new int[nPoints];
            this.candidatePoint = new int[nPoints];
            this.candidateDistance = new double[nPoints];

            // ... node-sized arrays ...
            this.bounds = new double[nNodes];
            this.componentOfNode = new int[nNodes];

            // ... and the result / bookkeeping structures
            this.edges = new double[nPoints - 1][3];
            this.componentUnionFind = new BoruvkaUnionFind(nPoints);

            final LogTimer timer = new LogTimer();
            this.partialDistTransform = partialTrans;

            this.initComponents();
            this.computeBounds();

            if (logger != null) {
                logger.info("completed Boruvka nearest neighbor search in " + timer.toString());
            }
        }

        /**
         * Puts every point in its own component, clears all candidate edges and
         * marks every node as "not yet owned by a single component".
         */
        final void initComponents() {
            final int nPoints = this.numPoints;

            // No candidate edge is known for any component yet
            java.util.Arrays.fill(this.candidateNeighbors, 0, nPoints, INIT_VAL);
            java.util.Arrays.fill(this.candidatePoint, 0, nPoints, INIT_VAL);
            java.util.Arrays.fill(this.candidateDistance, 0, nPoints, Double.MAX_VALUE);

            // Each point starts out as its own component
            for (int pt = 0; pt < nPoints; pt++)
                this.componentOfPoint[pt] = pt;

            // Distinct negative placeholders: no two nodes compare equal, and
            // none passes a ">= 0" (valid component) check.
            for (int node = 0; node < numNodes; node++)
                this.componentOfNode[node] = -node - 1;
        }

        /**
         * Alternates a full dual-tree traversal from the root with a component
         * merge until at most one component is left.
         *
         * @return the edge list, one {source, sink, weight} row per edge
         */
        final double[][] spanningTree() {
            for (int remaining = this.tree_data_ref.length; remaining > 1; remaining = this.updateComponents())
                this.dualTreeTraversal(0, 0);

            return this.edges;
        }

        /**
         * Converts the current per-component candidates into edges, merges the
         * joined components and refreshes the point/node component lookups.
         *
         * @return the number of components remaining after the merge
         */
        final int updateComponents() {
            // Each component should carry:
            //   - a candidate point (a point inside the component)
            //   - a candidate neighbor (the point to join with)
            //   - a candidate distance (the distance from point to neighbor)
            //
            // Add an edge to the edge list for each of these, then union the
            // two points together in the union-find structure.
            for (final int comp : this.components) {
                final int src = this.candidatePoint[comp];
                final int dst = this.candidateNeighbors[comp];

                // Source or sink is undefined, nothing to join
                if (src == INIT_VAL || dst == INIT_VAL)
                    continue;

                final int srcRoot = this.componentUnionFind.find(src);
                final int dstRoot = this.componentUnionFind.find(dst);

                // The two sides were already joined, so drop this candidate
                if (srcRoot == dstRoot) {
                    this.candidatePoint[comp] = INIT_VAL;
                    this.candidateNeighbors[comp] = INIT_VAL;
                    this.candidateDistance[comp] = Double.MAX_VALUE;
                    continue;
                }

                // Record the edge
                final double[] edge = this.edges[numEdges];
                edge[0] = src;
                edge[1] = dst;
                if (this.partialDistTransform)
                    edge[2] = metric.partialDistanceToDistance(this.candidateDistance[comp]);
                else
                    edge[2] = this.candidateDistance[comp];
                this.numEdges++;

                // Join the two components
                this.componentUnionFind.union(src, dst);

                // Reset the candidate weight and stop as soon as the tree is complete
                this.candidateDistance[comp] = Double.MAX_VALUE;
                if (this.numEdges == this.numPoints - 1) {
                    this.components = this.componentUnionFind.components();
                    return this.components.length;
                }
            }

            // Cache the component of every point; this makes the pruning
            // checks during traversal a plain array lookup.
            final int nPoints = this.tree_data_ref.length;
            for (int pt = 0; pt < nPoints; pt++)
                this.componentOfPoint[pt] = this.componentUnionFind.find(pt);

            // Walk the nodes from last to first so that children are labelled
            // before the parent that inspects them.
            for (int node = this.node_data_ref.length - 1; node >= 0; node--) {
                final NodeData info = this.node_data_ref[node];

                // Internal node: inherits the label when both children agree
                if (!info.isLeaf()) {
                    final int childA = 2 * node + 1;
                    final int childB = 2 * node + 2;

                    if (this.componentOfNode[childA] == this.componentOfNode[childB])
                        this.componentOfNode[node] = this.componentOfNode[childA];
                    continue;
                }

                // Leaf: labelled only if every point it holds is in the same component
                final int firstLabel = this.componentOfPoint[idx_array[info.start()]];
                boolean uniform = true;
                for (int i = info.start() + 1; uniform && i < info.end(); i++)
                    uniform = componentOfPoint[idx_array[i]] == firstLabel;

                if (uniform)
                    this.componentOfNode[node] = firstLabel;
            }

            // Tie-breaking: in approximate mode the bounds are kept unless the
            // pass made no progress (same number of components as before);
            // otherwise they are reset on every pass.
            final int previousCount = this.components.length;
            this.components = this.componentUnionFind.components();

            if (!approxMinSpanTree || this.components.length == previousCount)
                java.util.Arrays.fill(this.bounds, 0, numNodes, Double.MAX_VALUE);

            return this.components.length;
        }

        /**
         * Computes the per-point core distances, seeds the candidate arrays
         * and resets the node bounds. Invoked once from the {@code Boruvka}
         * constructor.
         */
        abstract void computeBounds();

        /**
         * Recursively compares the subtree rooted at {@code node1} against the
         * subtree rooted at {@code node2}, updating the per-component candidate
         * edges and the node bounds along the way.
         *
         * @param node1 index of the query node
         * @param node2 index of the reference node
         * @return both implementations in this file always return 0
         */
        abstract int dualTreeTraversal(int node1, int node2);
    }

    protected class KDTreeBoruvAlg extends Boruvka {
        /**
         * Builds a {@link KDTree} over the outer tree's data (the backing array
         * is passed with the copy flag set to {@code false}) and initialises the
         * solver with partial-distance conversion enabled.
         */
        KDTreeBoruvAlg() {
            super(
                    /* partialTrans */ true,
                    new KDTree(
                            new Array2DRowRealMatrix(outer_tree.getDataRef(), false),
                            leafSize,
                            metric,
                            logger));
        }

        /**
         * Queries the {@code minSamples + 1} nearest neighbors of every point,
         * stores each point's core distance (converted to a partial distance)
         * and seeds a candidate edge towards the first listed neighbor whose
         * core distance does not exceed the point's own.
         */
        @Override
        void computeBounds() {
            // The python code offers a breadth-first search here; that option
            // was dropped in favor of depth-first for the time being.
            final Neighborhood knn = TREE.query(tree_data_ref, minSamples + 1, true, true);
            final double[][] nbrDist = knn.getDistances();
            final int[][] nbrIdx = knn.getIndices();

            // Core distance = distance in column minSamples, as a partial distance
            this.coreDistance = new double[nbrDist.length];
            final double[] core = this.coreDistance;
            for (int row = 0; row < core.length; row++)
                core[row] = metric.distanceToPartialDistance(nbrDist[row][minSamples]);

            // Column 0 is skipped; the first qualifying neighbor wins
            for (int pt = 0; pt < numPoints; pt++) {
                for (int rank = 1; rank <= minSamples; rank++) {
                    final int nbr = nbrIdx[pt][rank];

                    if (core[nbr] <= core[pt]) {
                        this.candidatePoint[pt] = pt;
                        this.candidateNeighbors[pt] = nbr;
                        this.candidateDistance[pt] = core[pt];
                        break;
                    }
                }
            }

            // Merge on the seeded candidates, then start from unconstrained bounds
            this.updateComponents();
            java.util.Arrays.fill(this.bounds, 0, numNodes, Double.MAX_VALUE);
        }

        /**
         * KD-tree flavour of the dual-tree traversal.
         *
         * <p>Prunes the pair when the reduced node-to-node distance is not below
         * the query node's bound, or when both nodes are already known to lie in
         * the same component. Otherwise either compares the points of two leaves
         * directly (updating per-component candidate edges and tightening the
         * bounds up the tree) or recurses into the children of one of the nodes,
         * visiting the closer child first.
         *
         * @param node1 index of the query node
         * @param node2 index of the reference node
         * @return always 0
         */
        @Override
        int dualTreeTraversal(int node1, int node2) {
            int[] pointIndices1, pointIndices2;
            int i, j, p, q, parent;

            double nodeDist, d, mrDist, newBound, newUpperBound, newLowerBound, leftDist, rightDist;

            NodeData node1Info = node_data_ref[node1], node2Info = node_data_ref[node2];

            int component1, component2, left, right;

            // Distance btwn query and ref nodes
            nodeDist = kdTreeMinRDistDual(metric, node1, node2, this.node_bounds, this.numFeatures);

            // If dist < current bound and nodes are not in the
            // same component, we continue.
            // componentOfNode holds a negative placeholder until a node is known
            // to sit entirely inside one component, hence the ">= 0" check.
            if (nodeDist < this.bounds[node1]) {
                if (this.componentOfNode[node1] == this.componentOfNode[node2] && this.componentOfNode[node1] >= 0)
                    return 0;
                else {
                    /*
                     * Pass. This is the only condition in which
                     * the method will continue without exiting early
                     */
                }
            } else
                return 0;

            // If both nodes are leaves
            if (node1Info.isLeaf() && node2Info.isLeaf()) {
                newUpperBound = 0.0;
                newLowerBound = Double.MAX_VALUE;

                // Build the indices
                pointIndices1 = new int[node1Info.end() - node1Info.start()];
                pointIndices2 = new int[node2Info.end() - node2Info.start()];

                // Populate the indices
                for (i = node1Info.start(), j = 0; i < node1Info.end(); i++, j++)
                    pointIndices1[j] = this.idx_array[i];
                for (i = node2Info.start(), j = 0; i < node2Info.end(); i++, j++)
                    pointIndices2[j] = this.idx_array[i];

                for (i = 0; i < pointIndices1.length; i++) {
                    p = pointIndices1[i];
                    component1 = this.componentOfPoint[p];

                    // The mutual reachability distance is at least p's core
                    // distance, so p cannot improve on the current candidate
                    if (this.coreDistance[p] > this.candidateDistance[component1])
                        continue;

                    for (j = 0; j < pointIndices2.length; j++) {
                        q = pointIndices2[j];
                        component2 = this.componentOfPoint[q];

                        // Same argument for q's core distance
                        if (this.coreDistance[q] > this.candidateDistance[component1])
                            continue;

                        // They belong to different components
                        if (component1 != component2) {

                            // Partial distance, consistent with the partial core distances
                            d = metric.getPartialDistance(this.tree_data_ref[p], this.tree_data_ref[q]);

                            // Mutual reachability: max of scaled distance and both core distances
                            mrDist = FastMath.max(
                                    // Avoid repeated division overhead
                                    (alpha == 1.0 ? d : d / alpha),

                                    // Nested max
                                    FastMath.max(this.coreDistance[p], this.coreDistance[q]));

                            // Keep the best edge leaving component1
                            if (mrDist < this.candidateDistance[component1]) {
                                this.candidateDistance[component1] = mrDist;
                                this.candidateNeighbors[component1] = q;
                                this.candidatePoint[component1] = p;
                            }
                        }
                    } // end for j

                    // Track largest / smallest candidate distance seen for this leaf
                    newUpperBound = FastMath.max(newUpperBound, this.candidateDistance[component1]);
                    newLowerBound = FastMath.min(newLowerBound, this.candidateDistance[component1]);
                } // end for i

                // Calc new bound
                newBound = FastMath.min(newUpperBound, newLowerBound + 2 * node1Info.radius());

                // Reassign new bound to min bounds[node1]
                if (newBound < this.bounds[node1]) {
                    this.bounds[node1] = newBound;

                    // propagate bounds up...
                    // (children of node k are stored at 2k+1 and 2k+2)
                    while (node1 > 0) {
                        parent = (node1 - 1) / 2;
                        left = 2 * parent + 1;
                        right = 2 * parent + 2;

                        // A parent's bound is the larger of its children's bounds;
                        // stop once an ancestor is not tightened any further
                        newBound = FastMath.max(this.bounds[left], this.bounds[right]);
                        if (newBound < this.bounds[parent]) {
                            this.bounds[parent] = newBound;
                            node1 = parent;
                        } else
                            break;
                    } // end while
                } // end if inner
            } // end case 1 if

            // If node is a leaf or smaller than ref node: descend into node2's children
            else if (node1Info.isLeaf() || (!node2Info.isLeaf() && node2Info.radius() > node1Info.radius())) {

                left = 2 * node2 + 1;
                right = 2 * node2 + 2;

                // NOTE: the node2Info reassignments below are not read again
                // before this method returns.
                node2Info = this.node_data_ref[left];
                leftDist = kdTreeMinRDistDual(metric, node1, left, node_bounds, this.numFeatures);

                node2Info = this.node_data_ref[right];
                rightDist = kdTreeMinRDistDual(metric, node1, right, node_bounds, this.numFeatures);

                // Visit the closer child first
                if (leftDist < rightDist) {
                    this.dualTreeTraversal(node1, left);
                    this.dualTreeTraversal(node1, right);

                } else { // Navigate in opposite order
                    this.dualTreeTraversal(node1, right);
                    this.dualTreeTraversal(node1, left);
                }
            } // end case 2 if

            // Node is leaf or smaller than query node: descend into node1's children
            else {
                left = 2 * node1 + 1;
                right = 2 * node1 + 2;

                // NOTE: the node1Info reassignments below are not read again
                // before this method returns.
                node1Info = this.node_data_ref[left];
                leftDist = kdTreeMinRDistDual(metric, left, node2, node_bounds, this.numFeatures);

                node1Info = this.node_data_ref[right];
                rightDist = kdTreeMinRDistDual(metric, right, node2, node_bounds, this.numFeatures);

                // Visit the closer child first
                if (leftDist < rightDist) {
                    this.dualTreeTraversal(left, node2);
                    this.dualTreeTraversal(right, node2);

                } else {
                    this.dualTreeTraversal(right, node2);
                    this.dualTreeTraversal(left, node2);
                }
            }

            return 0;
        }
    }

    protected class BallTreeBoruvAlg extends Boruvka {
        final double[][] centroidDistances;

        BallTreeBoruvAlg() {
            super(false, new BallTree(new Array2DRowRealMatrix(outer_tree.getDataRef(), false), leafSize, metric,
                    logger));

            // Compute pairwise dist matrix for node_bounds
            centroidDistances = Pairwise.getDistance(node_bounds[0], metric, false, false);
        }

        /**
         * Queries the {@code minSamples} nearest neighbors of every point,
         * stores each point's core distance (the last column of the distance
         * matrix) and seeds a candidate edge towards the lowest-ranked neighbor
         * (rank 0 is skipped) whose core distance does not exceed the point's own.
         *
         * <p>The neighbor scan runs forward and stops at the first match. This
         * selects the same neighbor as scanning backwards over every rank and
         * letting later matches overwrite earlier ones, but without the
         * redundant iterations, and it mirrors the KD-tree variant.
         */
        @Override
        void computeBounds() {
            // No longer doing breadth-first searches
            final Neighborhood queryResult = TREE.query(tree_data_ref, minSamples, true, true);

            final double[][] knnDist = queryResult.getDistances();
            final int[][] knnIndices = queryResult.getIndices();

            // Assign the core distance array...
            this.coreDistance = new double[knnDist.length];
            for (int i = 0; i < coreDistance.length; i++)
                coreDistance[i] = knnDist[i][minSamples - 1];

            for (int n = 0; n < numPoints; n++) {
                for (int i = 1; i < minSamples; i++) {
                    final int m = knnIndices[n][i];

                    if (this.coreDistance[m] <= this.coreDistance[n]) {
                        this.candidatePoint[n] = n;
                        this.candidateNeighbors[n] = m;
                        this.candidateDistance[n] = this.coreDistance[n];
                        break; // lowest qualifying rank found
                    }
                }
            }

            // Merge on the seeded candidates
            updateComponents();

            // Start the first traversal with unconstrained bounds
            for (int n = 0; n < numNodes; n++)
                this.bounds[n] = Double.MAX_VALUE;
        }

        /**
         * Ball-tree flavour of the dual-tree traversal.
         *
         * <p>Prunes the pair when the node-to-node distance (centroid distance
         * less both radii) is not below the query node's bound, or when both
         * nodes are already known to lie in the same component. Otherwise either
         * compares the points of two leaves directly (updating per-component
         * candidate edges and tightening the bounds up the tree) or recurses
         * into the children of one of the nodes, visiting the closer child first.
         *
         * @param node1 index of the query node
         * @param node2 index of the reference node
         * @return always 0
         */
        @Override
        int dualTreeTraversal(int node1, int node2) {
            int[] pointIndices1, pointIndices2;
            int i, j, p, q, parent //,child1, child2
            ;

            double nodeDist, d, mrDist, newBound, newUpperBound, newLowerBound, boundMax, boundMin, leftDist,
                    rightDist;

            NodeData node1Info = node_data_ref[node1], node2Info = node_data_ref[node2], parentInfo, leftInfo,
                    rightInfo;

            int component1, component2, left, right;

            // Distance btwn query and ref nodes
            nodeDist = ballTreeMinDistDual(node1Info.radius(), node2Info.radius(), node1, node2,
                    this.centroidDistances);

            // If dist < current bound and nodes are not in the
            // same component, we continue.
            // componentOfNode holds a negative placeholder until a node is known
            // to sit entirely inside one component, hence the ">= 0" check.
            if (nodeDist < this.bounds[node1]) {
                if (this.componentOfNode[node1] == this.componentOfNode[node2] && this.componentOfNode[node1] >= 0)
                    return 0;
                else {
                    /*
                     * Pass. This is the only condition in which
                     * the method will continue without exiting early
                     */
                }
            } else
                return 0;

            // If both nodes are leaves
            if (node1Info.isLeaf() && node2Info.isLeaf()) {
                newUpperBound = Double.NEGATIVE_INFINITY;
                newLowerBound = Double.MAX_VALUE;
                // NOTE: this initial value is overwritten before it is read
                newBound = 0.0;

                // Build the indices
                pointIndices1 = new int[node1Info.end() - node1Info.start()];
                pointIndices2 = new int[node2Info.end() - node2Info.start()];

                // Populate the indices
                for (i = node1Info.start(), j = 0; i < node1Info.end(); i++, j++)
                    pointIndices1[j] = this.idx_array[i];
                for (i = node2Info.start(), j = 0; i < node2Info.end(); i++, j++)
                    pointIndices2[j] = this.idx_array[i];

                for (i = 0; i < pointIndices1.length; i++) {
                    p = pointIndices1[i];
                    component1 = this.componentOfPoint[p];

                    // The mutual reachability distance is at least p's core
                    // distance, so p cannot improve on the current candidate
                    if (this.coreDistance[p] > this.candidateDistance[component1])
                        continue;

                    for (j = 0; j < pointIndices2.length; j++) {
                        q = pointIndices2[j];
                        component2 = this.componentOfPoint[q];

                        // Same argument for q's core distance
                        if (this.coreDistance[q] > this.candidateDistance[component1])
                            continue;

                        // They belong to different components
                        if (component1 != component2) {
                            // Full (not partial) distance in the ball-tree variant
                            d = metric.getDistance(this.tree_data_ref[p], this.tree_data_ref[q]);

                            // Mutual reachability: max of scaled distance and both core distances
                            mrDist = FastMath.max(
                                    // Avoid repeated division overhead
                                    (alpha == 1.0 ? d : d / alpha),

                                    // Nested max
                                    FastMath.max(this.coreDistance[p], this.coreDistance[q]));

                            // Keep the best edge leaving component1
                            if (mrDist < this.candidateDistance[component1]) {
                                this.candidateDistance[component1] = mrDist;
                                this.candidateNeighbors[component1] = q;
                                this.candidatePoint[component1] = p;
                            }
                        }
                    } // end for j

                    // Track largest / smallest candidate distance seen for this leaf
                    newUpperBound = FastMath.max(newUpperBound, this.candidateDistance[component1]);
                    newLowerBound = FastMath.min(newLowerBound, this.candidateDistance[component1]);
                } // end for i

                // Calc new bound
                newBound = FastMath.min(newUpperBound, newLowerBound + 2 * node1Info.radius());

                // Reassign new bound to min bounds[node1]
                if (newBound < this.bounds[node1]) {
                    this.bounds[node1] = newBound;

                    // propagate bounds up...
                    // (children of node k are stored at 2k+1 and 2k+2)
                    while (node1 > 0) {
                        parent = (node1 - 1) / 2;
                        left = 2 * parent + 1;
                        right = 2 * parent + 2;

                        parentInfo = this.node_data_ref[parent];
                        leftInfo = this.node_data_ref[left];
                        rightInfo = this.node_data_ref[right];

                        // Two candidate bounds for the parent: the larger child bound,
                        // and the smaller child bound widened by twice the radius difference
                        boundMax = FastMath.max(this.bounds[left], this.bounds[right]);
                        boundMin = FastMath.min(this.bounds[left] + 2 * (parentInfo.radius() - leftInfo.radius()),
                                this.bounds[right] + 2 * (parentInfo.radius() - rightInfo.radius()));

                        // The radius-based bound is only used when it is positive
                        if (boundMin > 0)
                            newBound = FastMath.min(boundMax, boundMin);
                        else
                            newBound = boundMax;

                        // Stop once an ancestor is not tightened any further
                        if (newBound < this.bounds[parent]) {
                            this.bounds[parent] = newBound;
                            node1 = parent;
                        } else
                            break;
                    } // end while
                } // end if inner
            } // end case 1 if

            // If node is a leaf or smaller than ref node: descend into node2's children
            else if (node1Info.isLeaf() || (!node2Info.isLeaf() && node2Info.radius() > node1Info.radius())) {
                left = 2 * node2 + 1;
                right = 2 * node2 + 2;

                // node2Info is re-pointed at each child so its radius feeds the distance
                node2Info = this.node_data_ref[left];
                leftDist = ballTreeMinDistDual(node1Info.radius(), node2Info.radius(), node1, left,
                        this.centroidDistances);

                node2Info = this.node_data_ref[right];
                rightDist = ballTreeMinDistDual(node1Info.radius(), node2Info.radius(), node1, right,
                        this.centroidDistances);

                // Visit the closer child first
                if (leftDist < rightDist) {
                    this.dualTreeTraversal(node1, left);
                    this.dualTreeTraversal(node1, right);

                } else { // Navigate in opposite order
                    this.dualTreeTraversal(node1, right);
                    this.dualTreeTraversal(node1, left);
                }
            } // end case 2 if

            // Node is leaf or smaller than query node: descend into node1's children
            else {
                left = 2 * node1 + 1;
                right = 2 * node1 + 2;

                // node1Info is re-pointed at each child so its radius feeds the distance
                node1Info = this.node_data_ref[left];
                leftDist = ballTreeMinDistDual(node1Info.radius(), node2Info.radius(), left, node2,
                        this.centroidDistances);

                node1Info = this.node_data_ref[right];
                rightDist = ballTreeMinDistDual(node1Info.radius(), node2Info.radius(), right, node2,
                        this.centroidDistances);

                // Visit the closer child first
                if (leftDist < rightDist) {
                    this.dualTreeTraversal(left, node2);
                    this.dualTreeTraversal(right, node2);

                } else {
                    this.dualTreeTraversal(right, node2);
                    this.dualTreeTraversal(left, node2);
                }
            }

            return 0;
        }
    }

    /**
     * Runs the solver chosen at construction time.
     *
     * @return the edge list produced by the solver, one
     *         {source, sink, weight} row per edge
     */
    protected final double[][] spanningTree() {
        final double[][] mstEdges = this.alg.spanningTree();
        return mstEdges;
    }
}