com.clust4j.algo.NearestNeighborHeapSearch.java Source code

Introduction

Here is the source code for com.clust4j.algo.NearestNeighborHeapSearch.java

Source

/*******************************************************************************
 *    Copyright 2015, 2016 Taylor G Smith
 *
 *    Licensed under the Apache License, Version 2.0 (the "License");
 *    you may not use this file except in compliance with the License.
 *    You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 *    Unless required by applicable law or agreed to in writing, software
 *    distributed under the License is distributed on an "AS IS" BASIS,
 *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *    See the License for the specific language governing permissions and
 *    limitations under the License.
 *******************************************************************************/
package com.clust4j.algo;

import org.apache.commons.lang3.tuple.ImmutableTriple;
import org.apache.commons.lang3.tuple.Triple;
import org.apache.commons.math3.exception.DimensionMismatchException;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.util.FastMath;

import com.clust4j.log.Loggable;
import com.clust4j.metrics.pairwise.Distance;
import com.clust4j.metrics.pairwise.DistanceMetric;
import com.clust4j.metrics.pairwise.GeometricallySeparable;
import com.clust4j.utils.DeepCloneable;
import com.clust4j.utils.MatUtils;
import com.clust4j.utils.QuadTup;
import com.clust4j.utils.VecUtils;

import static com.clust4j.GlobalState.Mathematics.*;

import java.util.Arrays;

/**
 * A data structure for optimized high-dimensional k-neighbors and radius
 * searches. Based on sklearn's BinaryTree class.
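 *
 * <p>
 * A minimal usage sketch (assuming a concrete subclass such as <tt>KDTree</tt>;
 * <tt>X</tt> is a RealMatrix of training data and <tt>Q</tt> is a
 * double[][] of query rows):
 * <pre>{@code
 * NearestNeighborHeapSearch tree = new KDTree(X);    // hypothetical subclass
 * Neighborhood hood = tree.query(Q, 3, false, true); // 3-NN, single-tree, sorted
 * double[][] dists = hood.getKey();  // neighbor distances per query row
 * int[][] idcs = hood.getValue();    // neighbor indices per query row
 * }</pre>
 *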
 * @author Taylor G Smith
 * @see <a href="https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/neighbors/binary_tree.pxi">sklearn BinaryTree</a>
 */
abstract class NearestNeighborHeapSearch implements java.io.Serializable {
    private static final long serialVersionUID = -5617532034886067210L;

    public static final int DEF_LEAF_SIZE = 40;
    public static final DistanceMetric DEF_DIST = Distance.EUCLIDEAN;
    static final String MEM_ERR = "Internal: memory layout is flawed: not enough nodes allocated";

    double[][] data_arr;
    int[] idx_array;
    NodeData[] node_data;
    double[][][] node_bounds;

    /** If a logger is provided, warnings will be issued through it */
    final Loggable logger;
    /** Constrained to Dist, not Sim due to nearest neighbor requirements */
    final DistanceMetric dist_metric;
    int n_trims, n_leaves, n_splits, n_calls, leaf_size, n_levels, n_nodes;
    final int N_SAMPLES, N_FEATURES;
    /** Whether or not the algorithm uses the Inf distance, {@link Distance#CHEBYSHEV} */
    final boolean infinity_dist;

    /**
     * Ensure valid metric
     */
    abstract boolean checkValidDistMet(GeometricallySeparable dist);

    public NearestNeighborHeapSearch(final RealMatrix X) {
        this(X, DEF_LEAF_SIZE, DEF_DIST);
    }

    public NearestNeighborHeapSearch(final RealMatrix X, int leaf_size) {
        this(X, leaf_size, DEF_DIST);
    }

    public NearestNeighborHeapSearch(final RealMatrix X, DistanceMetric dist) {
        this(X, DEF_LEAF_SIZE, dist);
    }

    public NearestNeighborHeapSearch(final RealMatrix X, Loggable logger) {
        this(X, DEF_LEAF_SIZE, DEF_DIST, logger);
    }

    /**
     * Constructor without a logger object
     * @param X
     * @param leaf_size
     * @param dist
     */
    public NearestNeighborHeapSearch(final RealMatrix X, int leaf_size, DistanceMetric dist) {
        this(X, leaf_size, dist, null);
    }

    /**
     * Constructor with logger and distance metric
     * @param X
     * @param dist
     * @param logger
     */
    public NearestNeighborHeapSearch(final RealMatrix X, DistanceMetric dist, Loggable logger) {
        this(X, DEF_LEAF_SIZE, dist, logger);
    }

    /**
     * Constructor with logger object
     * @param X
     * @param leaf_size
     * @param dist
     * @param logger
     */
    public NearestNeighborHeapSearch(final RealMatrix X, int leaf_size, DistanceMetric dist, Loggable logger) {
        this(X.getData(), leaf_size, dist, logger);
    }

    /**
     * Protected constructor taking the raw double[][] data array
     * @param X
     * @param leaf_size
     * @param dist
     * @param logger
     */
    protected NearestNeighborHeapSearch(final double[][] X, int leaf_size, DistanceMetric dist, Loggable logger) {
        this.data_arr = MatUtils.copy(X);
        this.leaf_size = leaf_size;
        this.logger = logger;

        if (leaf_size < 1)
            throw new IllegalArgumentException("illegal leaf size: " + leaf_size);

        if (!checkValidDistMet(dist)) {
            if (null != logger)
                logger.warn(dist + " is not valid for " + this.getClass() + ". Reverting to " + DEF_DIST);
            this.dist_metric = DEF_DIST;
        } else {
            this.dist_metric = dist;
        }

        // Whether the algorithm is using the infinity distance (Chebyshev)
        this.infinity_dist = Double.isInfinite(this.dist_metric.getP());

        // determine number of levels in the tree, and from this
        // the number of nodes in the tree.  This results in leaf nodes
        // with numbers of points between leaf_size and 2 * leaf_size
        MatUtils.checkDims(this.data_arr);
        N_SAMPLES = data_arr.length;
        N_FEATURES = X[0].length;

        /*
        // Should round up or always take floor function?...
        double nlev = FastMath.log(2, FastMath.max(1, (N_SAMPLES-1)/leaf_size)) + 1;
        this.n_levels = (int)FastMath.round(nlev);
        this.n_nodes = (int)(FastMath.pow(2, nlev) - 1);
        */

        this.n_levels = (int) (FastMath.log(2, FastMath.max(1, (N_SAMPLES - 1) / leaf_size)) + 1);
        this.n_nodes = (int) (FastMath.pow(2, n_levels) - 1);

        // allocate arrays for storage
        this.idx_array = VecUtils.arange(N_SAMPLES);

        // Add new NodeData objs to node_data arr
        this.node_data = new NodeData[n_nodes];
        for (int i = 0; i < node_data.length; i++)
            node_data[i] = new NodeData();

        // allocate tree specific data
        allocateData(this, n_nodes, N_FEATURES);
        recursiveBuild(0, 0, N_SAMPLES);
    }

    // ========================== Inner classes ==========================

    interface Density {
        double getDensity(double dist, double h);

        double getNorm(double h, int d);
    }

    /**
     * Provides efficient, reduced kernel log-density approximations that are
     * faster and simpler than the {@link Kernel} class methods.
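     * <p>
     * For example (a minimal sketch), the log-density contribution of a point
     * at distance 0.5 under the Gaussian kernel with bandwidth 1.0:
     * <pre>{@code
     * double logDens = PartialKernelDensity.LOG_GAUSSIAN.getDensity(0.5, 1.0);
     * // -0.5 * (0.5 * 0.5) / (1.0 * 1.0) == -0.125
     * }</pre>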
     * @author Taylor G Smith
     */
    public enum PartialKernelDensity implements Density, java.io.Serializable {
        LOG_COSINE {
            @Override
            public double getDensity(double dist, double h) {
                return dist < h ? FastMath.log(FastMath.cos(0.5 * Math.PI * dist / h)) : Double.NEGATIVE_INFINITY;
            }

            @Override
            public double getNorm(double h, int d) {
                double factor = 0;
                double tmp = 2d / Math.PI;

                for (int k = 1; k < d + 1; k += 2) {
                    factor += tmp;
                    tmp *= -(d - k) * (d - k - 1) * FastMath.pow((2.0 / Math.PI), 2);
                }

                return FastMath.log(factor) + logSn(d - 1);
            }
        },

        LOG_EPANECHNIKOV {
            @Override
            public double getDensity(double dist, double h) {
                return dist < h ? FastMath.log(1.0 - (dist * dist) / (h * h)) : Double.NEGATIVE_INFINITY;
            }

            @Override
            public double getNorm(double h, int d) {
                return logVn(d) + FastMath.log(2.0 / (d + 2.0));
            }
        },

        LOG_EXPONENTIAL {
            @Override
            public double getDensity(double dist, double h) {
                return -dist / h;
            }

            @Override
            public double getNorm(double h, int d) {
                return logSn(d - 1) + lgamma(d);
            }
        },

        LOG_GAUSSIAN {
            @Override
            public double getDensity(double dist, double h) {
                return -0.5 * (dist * dist) / (h * h);
            }

            @Override
            public double getNorm(double h, int d) {
                return 0.5 * d * LOG_2PI;
            }
        },

        LOG_LINEAR {
            @Override
            public double getDensity(double dist, double h) {
                return dist < h ? FastMath.log(1 - dist / h) : Double.NEGATIVE_INFINITY;
            }

            @Override
            public double getNorm(double h, int d) {
                return logVn(d) - FastMath.log(d + 1.0);
            }
        },

        LOG_TOPHAT {
            @Override
            public double getDensity(double dist, double h) {
                return dist < h ? 0 : Double.NEGATIVE_INFINITY;
            }

            @Override
            public double getNorm(double h, int d) {
                return logVn(d);
            }
        }
    }

    /**
     * A hacky container for passing double references...
     * Allows us to modify the value of a double as if
     * we had passed a pointer. Since much of this code
     * is adapted from Pyrex, Cython and C code, it
     * leans heavily on passing pointers.
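     * <p>
     * A minimal sketch of the pass-by-reference idiom (the callee below is
     * hypothetical):
     * <pre>{@code
     * MutableDouble d = new MutableDouble(1.0);
     * mutateInPlace(d);        // hypothetical method that assigns d.value
     * double result = d.value; // reflects the callee's assignment
     * }</pre>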
     * @author Taylor G Smith
     */
    // Tested: passing
    public static class MutableDouble implements Comparable<Double>, java.io.Serializable {
        private static final long serialVersionUID = -4636023903600763877L;
        public Double value = 0.0;

        MutableDouble() {
        }

        MutableDouble(Double value) {
            this.value = value;
        }

        @Override
        public int compareTo(final Double n) {
            return value.compareTo(n);
        }
    }

    /**
     * Node data container
     * @author Taylor G Smith
     */
    // Tested: passing
    public static class NodeData implements DeepCloneable, java.io.Serializable {
        private static final long serialVersionUID = -2469826821608908612L;
        int idx_start, idx_end;
        boolean is_leaf;
        double radius;

        public NodeData() {
        }

        public NodeData(int st, int ed, boolean is, double rad) {
            idx_start = st;
            idx_end = ed;
            is_leaf = is;
            radius = rad;
        }

        @Override
        public String toString() {
            return "NodeData: [" + idx_start + ", " + idx_end + ", " + is_leaf + ", " + radius + "]";
        }

        @Override
        public NodeData copy() {
            return new NodeData(idx_start, idx_end, is_leaf, radius);
        }

        @Override
        public boolean equals(Object o) {
            if (this == o)
                return true;
            if (o instanceof NodeData) {
                NodeData nd = (NodeData) o;
                return nd.idx_start == this.idx_start && nd.idx_end == this.idx_end && nd.is_leaf == this.is_leaf
                        && nd.radius == this.radius;
            }

            return false;
        }

        public boolean isLeaf() {
            return is_leaf;
        }

        public int end() {
            return idx_end;
        }

        public double radius() {
            return radius;
        }

        public int start() {
            return idx_start;
        }
    }

    /**
     * Abstract superclass for the NodeHeap and
     * NeighborsHeap classes
     * @author Taylor G Smith
     */
    abstract static class Heap implements java.io.Serializable {
        private static final long serialVersionUID = 8073174366388667577L;

        abstract static class Node {
            double val;
            int i1;
            int i2;

            Node() {
            }

            Node(double val, int i1, int i2) {
                this.val = val;
                this.i1 = i1;
                this.i2 = i2;
            }
        }

        Heap() {
        }

        static void swapNodes(Node[] arr, int i1, int i2) {
            Node tmp = arr[i1];
            arr[i1] = arr[i2];
            arr[i2] = tmp;
        }

        static void dualSwap(double[] darr, int[] iarr, int i1, int i2) {
            final double dtmp = darr[i1];
            darr[i1] = darr[i2];
            darr[i2] = dtmp;

            final int itmp = iarr[i1];
            iarr[i1] = iarr[i2];
            iarr[i2] = itmp;
        }
    }

    /**
     * A max-heap structure to keep track of distances/indices of neighbors.
     * This is based on the sklearn.neighbors.binary_tree module's NeighborsHeap class.
     *
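     * <p>
     * A minimal usage sketch (values are reduced distances, as used elsewhere
     * in this file):
     * <pre>{@code
     * NeighborsHeap heap = new NeighborsHeap(1, 3); // 1 query point, k = 3
     * heap.push(0, 2.0, 7); // (row, distance, neighbor index)
     * heap.push(0, 1.0, 4);
     * heap.push(0, 3.0, 9);
     * Neighborhood sorted = heap.getArrays(true); // sorted ascending by distance
     * }</pre>
     *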
     * @author Taylor G Smith, adapted from sklearn
     * @see <a href="https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/neighbors/binary_tree.pxi">sklearn NeighborsHeap</a>
     */
    static class NeighborsHeap extends Heap {
        private static final long serialVersionUID = 3065531260075044616L;
        double[][] distances;
        int[][] indices;

        NeighborsHeap(int nPts, int k) {
            super();
            distances = MatUtils.rep(Double.POSITIVE_INFINITY, nPts, k);
            indices = new int[nPts][k];
        }

        Neighborhood getArrays(boolean sort) {
            if (sort)
                this.sort();
            return new Neighborhood(distances, indices);
        }

        int push(int row, double val, int i_val) {
            int i, ic1, ic2, i_swap, size = distances[0].length;
            double[] dist_arr = distances[row];
            int[] ind_arr = indices[row];

            if (val > dist_arr[0])
                return 0;

            // Insert at pos 0
            dist_arr[0] = val;
            ind_arr[0] = i_val;

            // Descend heap, swap vals until max heap criteria met
            i = 0;
            while (true) {
                ic1 = 2 * i + 1;
                ic2 = ic1 + 1;

                if (ic1 >= size)
                    break;
                else if (ic2 >= size) {
                    if (dist_arr[ic1] > val)
                        i_swap = ic1;
                    else
                        break;
                } else if (dist_arr[ic1] >= dist_arr[ic2]) {
                    if (val < dist_arr[ic1])
                        i_swap = ic1;
                    else
                        break;
                } else {
                    if (val < dist_arr[ic2])
                        i_swap = ic2;
                    else
                        break;
                }

                dist_arr[i] = dist_arr[i_swap];
                ind_arr[i] = ind_arr[i_swap];

                i = i_swap;
            }

            dist_arr[i] = val;
            ind_arr[i] = i_val;

            return 0;
        }

        int sort() {
            for (int row = 0; row < distances.length; row++) {
                simultaneous_sort(this.distances[row], this.indices[row], distances[row].length);
            }

            return 0;
        }

        double largest(int row) {
            return distances[row][0];
        }

        static int simultaneous_sort(double[] dist, int[] idx, int size) {
            int pivot_idx, i, store_idx;
            double pivot_val;

            if (size <= 1) { // pass
            }

            else if (size == 2) {
                if (dist[0] > dist[1])
                    dualSwap(dist, idx, 0, 1);
            }

            /*
            else {
               int[] order = VecUtils.argSort(dist);
               dualOrderInPlace(dist, idx, order);
            }
            */

            else if (size == 3) {
                if (dist[0] > dist[1])
                    dualSwap(dist, idx, 0, 1);

                if (dist[1] > dist[2]) {
                    dualSwap(dist, idx, 1, 2);
                    if (dist[0] > dist[1])
                        dualSwap(dist, idx, 0, 1);
                }
            }

            else {
                pivot_idx = size / 2;
                if (dist[0] > dist[size - 1])
                    dualSwap(dist, idx, 0, size - 1);

                if (dist[size - 1] > dist[pivot_idx]) {
                    dualSwap(dist, idx, size - 1, pivot_idx);
                    if (dist[0] > dist[size - 1])
                        dualSwap(dist, idx, 0, size - 1);
                }
                pivot_val = dist[size - 1];

                store_idx = 0;
                for (i = 0; i < size - 1; i++) {
                    if (dist[i] < pivot_val) {
                        dualSwap(dist, idx, i, store_idx);
                        store_idx++;
                    }
                }

                dualSwap(dist, idx, store_idx, size - 1);
                pivot_idx = store_idx;

                if (pivot_idx > 1)
                    simultaneous_sort(dist, idx, pivot_idx);

                if (pivot_idx + 2 < size) {
                    // Can't pass reference to middle of array, so sort copy
                    // and then iterate over sorted, replacing in place
                    final int sliceStart = pivot_idx + 1;
                    final int sliceEnd = dist.length;

                    final int newLen = sliceEnd - sliceStart;
                    double[] a = new double[newLen];
                    int[] b = new int[newLen];

                    System.arraycopy(dist, sliceStart, a, 0, newLen);
                    System.arraycopy(idx, sliceStart, b, 0, newLen);

                    simultaneous_sort(a, b, size - pivot_idx - 1);

                    // Now iter over and replace...
                    for (int k = 0, p = sliceStart; p < sliceEnd; k++, p++) {
                        dist[p] = a[k];
                        idx[p] = b[k];
                    }
                }
            }

            return 0;
        }
    }

    /**
     * A min-heap implementation for keeping track of nodes
     * during a breadth-first search. This is based on the
     * sklearn.neighbors.binary_tree module's NodeHeap class.
     * 
     * <p>
     * Internally, the data is stored in a simple binary
     * heap which meets the min-heap condition:
     * 
     * <p>
     * <tt>heap[i].val < min(heap[2 * i + 1].val, heap[2 * i + 2].val)</tt>
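     *
     * <p>
     * A minimal usage sketch:
     * <pre>{@code
     * NodeHeap nh = new NodeHeap(2);
     * nh.push(new NodeHeapData(2.0, 0, 1));
     * nh.push(new NodeHeapData(0.5, 3, 4));
     * NodeHeapData min = nh.pop(); // {0.5, 3, 4} -- the smallest val first
     * }</pre>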
     * 
     * @author Taylor G Smith, adapted from sklearn
     * @see <a href="https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/neighbors/binary_tree.pxi">sklearn NodeHeap</a>
     */
    static class NodeHeap extends Heap {
        private static final long serialVersionUID = 5621403002445703132L;
        NodeHeapData[] data;
        int n;

        /** Node class. */
        static class NodeHeapData extends Node {
            NodeHeapData() {
                super();
            }

            NodeHeapData(double val, int i1, int i2) {
                super(val, i1, i2);
            }

            @Override
            public boolean equals(Object o) {
                if (o == this)
                    return true;
                if (o instanceof NodeHeapData) {
                    NodeHeapData n = (NodeHeapData) o;
                    return n.val == this.val && n.i1 == this.i1 && n.i2 == this.i2;
                }

                return false;
            }

            @Override
            public String toString() {
                return "{" + val + ", " + i1 + ", " + i2 + "}";
            }
        }

        NodeHeap(int size) {
            super();
            size = FastMath.max(size, 1);
            data = new NodeHeapData[size];

            //n = size;
            clear();
        }

        void clear() {
            n = 0;
        }

        NodeHeapData peek() {
            return data[0];
        }

        /**
         * Remove and return first element in heap
         * @return
         */
        NodeHeapData pop() {
            if (this.n == 0)
                throw new IllegalStateException("cannot pop an empty heap");

            int i, i_child1, i_child2, i_swap;
            NodeHeapData popped_element = this.data[0];

            // pop off the first element, move the last element to the front,
            // and then perform swaps until the heap is back in order
            this.data[0] = this.data[this.n - 1];

            // Omitted from sklearn, but added here; make last element null again...
            this.data[this.n - 1] = null;
            this.n--;

            i = 0;

            while (i < this.n) {
                i_child1 = 2 * i + 1;
                i_child2 = 2 * i + 2;
                i_swap = 0;

                if (i_child2 < this.n) {
                    if (this.data[i_child1].val <= this.data[i_child2].val)
                        i_swap = i_child1;
                    else
                        i_swap = i_child2;
                } else if (i_child1 < this.n) {
                    i_swap = i_child1;
                } else {
                    break;
                }

                if (i_swap > 0 && this.data[i_swap].val <= this.data[i].val) {
                    swapNodes(this.data, i, i_swap);
                    i = i_swap;
                } else {
                    break;
                }
            }

            return popped_element;
        }

        int push(NodeHeapData node) {
            // Add to the heap
            int i;
            this.n++;

            // If the new n exceeds current length,
            // double the size of the data array
            if (this.n > this.data.length)
                resize(2 * this.n);

            // Put new element at end, perform swaps
            i = this.n - 1;
            this.data[i] = node;
            reorderFromPush(i);

            return 0;
        }

        private void reorderFromPush(int i) {
            int i_parent;
            while (i > 0) {
                i_parent = (i - 1) / 2;
                if (this.data[i_parent].val <= this.data[i].val)
                    break;
                else {
                    swapNodes(this.data, i, i_parent);
                    i = i_parent;
                }
            }
        }

        int resize(int new_size) {
            if (new_size < 1)
                throw new IllegalArgumentException(
                        "cannot resize heap to size less than 1 (" + new_size + ")");

            // Resize larger or smaller
            int size = this.data.length;
            final int oldN = n;
            NodeHeapData[] newData = new NodeHeapData[new_size];

            // Original sklearn line included if clause, but due to our
            // new IAE check, we can skip it and enter for loop automatically:
            // if(size > 0 && new_size > 0)

            for (int i = 0; i < FastMath.min(size, new_size); i++)
                newData[i] = this.data[i];

            // The original sklearn line appears buggy: n is supposed to count
            // the objects inside the heap, but as written it would set n to
            // the total allocated size of the heap.
            /*
            if(new_size < size)
               this.n = new_size;
            */

            // New line that accounts for the above corner case:
            if (new_size < size)
                this.n = FastMath.min(new_size, oldN);

            this.data = newData;
            return 0;
        }

        @Override
        public String toString() {
            return Arrays.toString(this.data);
        }
    }

    // ========================== Getters ==========================
    public double[][] getData() {
        return MatUtils.copy(data_arr);
    }

    double[][] getDataRef() {
        return data_arr;
    }

    public int getLeafSize() {
        return leaf_size;
    }

    public DistanceMetric getMetric() {
        return dist_metric;
    }

    public double[][][] getNodeBounds() {
        int m = node_bounds.length;

        double[][][] out = new double[m][][];
        for (int i = 0; i < m; i++)
            out[i] = MatUtils.copy(node_bounds[i]);

        return out;
    }

    double[][][] getNodeBoundsRef() {
        return node_bounds;
    }

    public int[] getIndexArray() {
        return VecUtils.copy(idx_array);
    }

    int[] getIndexArrayRef() {
        return idx_array;
    }

    public NodeData[] getNodeData() {
        NodeData[] copy = new NodeData[node_data.length];
        for (int i = 0; i < copy.length; i++)
            copy[i] = node_data[i].copy();
        return copy;
    }

    NodeData[] getNodeDataRef() {
        return node_data;
    }

    // ========================== Instance methods ==========================
    double dist(final double[] a, final double[] b) {
        n_calls++;
        return dist_metric.getDistance(a, b);
    }

    public int getNumCalls() {
        return n_calls;
    }

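    /**
     * The "reduced" distance: a cheaper, monotonic surrogate for the true
     * distance (e.g., the squared distance for the Euclidean metric), used
     * internally for pruning comparisons.
     */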
    double rDist(final double[] a, final double[] b) {
        n_calls++;
        return dist_metric.getPartialDistance(a, b);
    }

    double rDistToDist(final double d) {
        return dist_metric.partialDistanceToDistance(d);
    }

    private void rDistToDistInPlace(final double[][] d) {
        final int m = d.length, n = d[0].length;
        for (int i = 0; i < m; i++)
            for (int j = 0; j < n; j++)
                d[i][j] = rDistToDist(d[i][j]);
    }

    private void estimateKernelDensitySingleDepthFirst(int i_node, double[] pt, PartialKernelDensity kern, double h,
            double logKNorm, double logAbsTol, double logRelTol, double localLogMinBound,
            double localLogBoundSpread, MutableDouble globalLogMinBound, MutableDouble globalLogBoundSpread) {

        int i, i1, i2, N1, N2;
        double[][] data = this.data_arr;
        NodeData nodeInfo = this.node_data[i_node];
        double dist_pt, logDensContribution;

        double child1LogMinBound, child2LogMinBound, child1LogBoundSpread, child2LogBoundSpread;
        MutableDouble dist_UB = new MutableDouble(), dist_LB = new MutableDouble();

        N1 = nodeInfo.idx_end - nodeInfo.idx_start;
        N2 = N_SAMPLES;
        double logN1 = FastMath.log(N1), logN2 = FastMath.log(N2);

        // If the local bounds are equal to within the tolerances
        if (logKNorm + localLogBoundSpread - logN1 + logN2 <= logAddExp(logAbsTol,
                (logRelTol + logKNorm + localLogMinBound))) {
            return;
        }

        // If global bounds are within rel tol & abs tol
        else if (logKNorm + globalLogBoundSpread.value <= logAddExp(logAbsTol,
                (logRelTol + logKNorm + globalLogMinBound.value))) {
            return;
        }

        // node is a leaf
        else if (nodeInfo.is_leaf) {
            globalLogMinBound.value = logSubExp(globalLogMinBound.value, localLogMinBound);
            globalLogBoundSpread.value = logSubExp(globalLogBoundSpread.value, localLogBoundSpread);

            for (i = nodeInfo.idx_start; i < nodeInfo.idx_end; i++) {
                dist_pt = this.dist(pt, data[idx_array[i]]);
                logDensContribution = kern.getDensity(dist_pt, h);
                globalLogMinBound.value = logAddExp(globalLogMinBound.value, logDensContribution);
            }
        }

        // Split and query
        else {
            i1 = 2 * i_node + 1;
            i2 = 2 * i_node + 2;

            N1 = this.node_data[i1].idx_end - this.node_data[i1].idx_start;
            N2 = this.node_data[i2].idx_end - this.node_data[i2].idx_start;
            logN1 = FastMath.log(N1);
            logN2 = FastMath.log(N2);

            // Mutates distLB & distUB internally
            minMaxDist(this, i1, pt, dist_LB, dist_UB);
            child1LogMinBound = logN1 + kern.getDensity(dist_UB.value, h);
            child1LogBoundSpread = logSubExp(logN1 + kern.getDensity(dist_LB.value, h), child1LogMinBound);

            // Mutates distLB & distUB internally
            minMaxDist(this, i2, pt, dist_LB, dist_UB);
            child2LogMinBound = logN2 + kern.getDensity(dist_UB.value, h);
            child2LogBoundSpread = logSubExp(logN2 + kern.getDensity(dist_LB.value, h), child2LogMinBound);

            // Update log min bound
            globalLogMinBound.value = logSubExp(globalLogMinBound.value, localLogMinBound);
            globalLogMinBound.value = logAddExp(globalLogMinBound.value, child1LogMinBound);
            globalLogMinBound.value = logAddExp(globalLogMinBound.value, child2LogMinBound);

            // Update log bound spread
            globalLogBoundSpread.value = logSubExp(globalLogBoundSpread.value, localLogBoundSpread);
            globalLogBoundSpread.value = logAddExp(globalLogBoundSpread.value, child1LogBoundSpread);
            globalLogBoundSpread.value = logAddExp(globalLogBoundSpread.value, child2LogBoundSpread);

            // Recurse left then right
            estimateKernelDensitySingleDepthFirst(i1, pt, kern, h, logKNorm, logAbsTol, logRelTol,
                    child1LogMinBound, child1LogBoundSpread, globalLogMinBound, globalLogBoundSpread);

            estimateKernelDensitySingleDepthFirst(i2, pt, kern, h, logKNorm, logAbsTol, logRelTol,
                    child2LogMinBound, child2LogBoundSpread, globalLogMinBound, globalLogBoundSpread);
        }
    }

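    /**
     * Find the dimension along which the indexed points are most spread out,
     * i.e., the argmax over columns of (column max - column min). For example
     * (a minimal sketch), for rows {0, 1} of data = {{0, 10}, {1, 0}}, the
     * spreads are 1 and 10, so the returned split dimension is 1.
     */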
    // Tested: passing
    public static int findNodeSplitDim(double[][] data, int[] idcs) {
        // Gets the difference between the vector of column
        // maxes and the vector of column mins, then finds the
        // arg max.

        // computes equivalent of (sklearn): 
        // j_max = np.argmax(np.max(data, 0) - np.min(data, 0))
        int n = data[0].length, idx, argMax = -1;
        double[] maxVec = VecUtils.rep(Double.NEGATIVE_INFINITY, n),
                minVec = VecUtils.rep(Double.POSITIVE_INFINITY, n), current;
        double diff, maxDiff = Double.NEGATIVE_INFINITY;

        // Optimized to one KxN pass
        for (int i = 0; i < idcs.length; i++) {
            idx = idcs[i];
            current = data[idx];

            for (int j = 0; j < n; j++) {
                if (current[j] > maxVec[j])
                    maxVec[j] = current[j];
                if (current[j] < minVec[j])
                    minVec[j] = current[j];

                // On the last iteration, we can compute the difference immediately
                if (i == idcs.length - 1) {
                    diff = maxVec[j] - minVec[j];
                    if (diff > maxDiff) {
                        maxDiff = diff;
                        argMax = j;
                    }
                }
            }
        }

        return argMax;
    }

    /**
     * Returns a QuadTup with references to the arrays
     * @return
     */
    public QuadTup<double[][], int[], NodeData[], double[][][]> getArrays() {
        return new QuadTup<>(data_arr, idx_array, node_data, node_bounds);
    }

    public Triple<Integer, Integer, Integer> getTreeStats() {
        return new ImmutableTriple<>(n_trims, n_leaves, n_splits);
    }

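    /**
     * Estimate the kernel density at each row of <tt>X</tt> via a depth-first
     * traversal with absolute/relative tolerance pruning in log space. A
     * minimal sketch (<tt>tree</tt> and <tt>Q</tt> are caller-supplied):
     * <pre>{@code
     * double[] dens = tree.kernelDensity(Q, 1.0,
     *     PartialKernelDensity.LOG_GAUSSIAN, 1e-8, 1e-8, false);
     * }</pre>
     */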
    public double[] kernelDensity(double[][] X, double bandwidth, PartialKernelDensity kern, double absTol,
            double relTol, boolean returnLog) {

        double b_c = bandwidth, logAbsTol = FastMath.log(absTol), logRelTol = FastMath.log(relTol);

        MutableDouble logMinBound = new MutableDouble(), logMaxBound = new MutableDouble(),
                logBoundSpread = new MutableDouble();
        MutableDouble dist_LB = new MutableDouble(), dist_UB = new MutableDouble();
        int m = data_arr.length, n = data_arr[0].length, i;

        // Ensure X col dim matches training data col dim
        MatUtils.checkDims(X);
        if (X[0].length != n)
            throw new DimensionMismatchException(n, X[0].length);

        final double logKNorm = logKernelNorm(b_c, n, kern), logM = FastMath.log(m), log2 = FastMath.log(2);
        double[][] Xarr = MatUtils.copy(X);
        double[] logDensity = new double[Xarr.length], pt;

        for (i = 0; i < Xarr.length; i++) {
            pt = Xarr[i];

            minMaxDist(this, 0, pt, dist_LB, dist_UB);
            logMinBound.value = logM + kern.getDensity(dist_UB.value, b_c);
            logMaxBound.value = logM + kern.getDensity(dist_LB.value, b_c);
            logBoundSpread.value = logSubExp(logMaxBound.value, logMinBound.value);

            estimateKernelDensitySingleDepthFirst(0, pt, kern, b_c, logKNorm, logAbsTol, logRelTol,
                    logMinBound.value, logBoundSpread.value, logMinBound, logBoundSpread);

            logDensity[i] = logAddExp(logMinBound.value, logBoundSpread.value - log2);
        }

        // Norm results
        for (i = 0; i < logDensity.length; i++)
            logDensity[i] += logKNorm;

        return returnLog ? logDensity : VecUtils.exp(logDensity);
    }

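    /**
     * Numerically stable computation of <tt>log(exp(x1) + exp(x2))</tt>;
     * e.g., logAddExp(log(2), log(3)) == log(5).
     */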
    private double logAddExp(double x1, double x2) {
        final double a = FastMath.max(x1, x2);
        if (Double.NEGATIVE_INFINITY == a)
            return a;
        return a + FastMath.log(FastMath.exp(x1 - a) + FastMath.exp(x2 - a));
    }

    static double logKernelNorm(double h, int i, PartialKernelDensity kern) {
        return -kern.getNorm(h, i) - i * FastMath.log(h);
    }

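    /** Log surface area of the unit n-sphere: S_n = 2 * pi^((n+1)/2) / Gamma((n+1)/2) */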
    static double logSn(int n) {
        return LOG_2PI + logVn(n - 1);
    }

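    /**
     * Numerically stable computation of <tt>log(exp(x1) - exp(x2))</tt>,
     * returning negative infinity when x1 is at most x2.
     */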
    private double logSubExp(double x1, double x2) {
        if (x1 <= x2)
            return Double.NEGATIVE_INFINITY;
        return x1 + FastMath.log(1 - FastMath.exp(x2 - x1));
    }

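    /** Log volume of the unit n-ball: V_n = pi^(n/2) / Gamma(n/2 + 1) */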
    static double logVn(int n) {
        return 0.5 * n * LOG_PI - lgamma(0.5 * n + 1);
    }

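    /**
     * Quickselect-style partition of the node's indices along
     * <tt>splitDim</tt>: on return, the value at position <tt>splitIndex</tt>
     * is in its sorted position, with no larger value before it and no
     * smaller value after it. For example (a minimal sketch), with column
     * values {5, 1, 4, 2} and splitIndex = 2, the value 4 lands at index 2,
     * with {1, 2} (in some order) to its left and {5} to its right.
     */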
    public static void partitionNodeIndices(double[][] data, int[] nodeIndices, int splitDim, int splitIndex,
            int nFeatures, int nPoints) {

        int left = 0;
        int right = nPoints - 1;
        double d1, d2;

        while (true) {
            int midindex = left;

            for (int i = left; i < right; i++) {
                d1 = data[nodeIndices[i]][splitDim];
                d2 = data[nodeIndices[right]][splitDim];

                if (d1 < d2) {
                    swap(nodeIndices, i, midindex);
                    midindex++;
                }
            }

            swap(nodeIndices, midindex, right);
            if (midindex == splitIndex) {
                break;
            } else if (midindex < splitIndex) {
                left = midindex + 1;
            } else {
                right = midindex - 1;
            }
        }
    }

    void resetNumCalls() {
        n_calls = 0;
    }

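    /**
     * Recursively build the tree over <tt>[idx_start, idx_end)</tt>: a node
     * with no room for children becomes a leaf; otherwise the points are
     * partitioned at their midpoint along the dimension of greatest spread,
     * and the two halves are built recursively.
     */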
    void recursiveBuild(int i_node, int idx_start, int idx_end) {
        int i_max, n_points = idx_end - idx_start, n_mid = n_points / 2;
        initNode(this, i_node, idx_start, idx_end);

        if (2 * i_node + 1 >= this.n_nodes) {
            node_data[i_node].is_leaf = true;

            if (idx_end - idx_start > 2 * leaf_size) {
                if (null != logger)
                    logger.warn(MEM_ERR);
            } else {
                // expected case: leaf fits within memory bounds; nothing to do
            }

        } else if (idx_end - idx_start < 2) {
            if (null != logger)
                logger.warn(MEM_ERR);
            node_data[i_node].is_leaf = true;
        } else {
            // split node and recursively build child nodes
            node_data[i_node].is_leaf = false;
            i_max = findNodeSplitDim(data_arr, idx_array);
            partitionNodeIndices(data_arr, idx_array, i_max, n_mid, N_FEATURES, n_points);

            recursiveBuild(2 * i_node + 1, idx_start, idx_start + n_mid);
            recursiveBuild(2 * i_node + 2, idx_start + n_mid, idx_end);
        }
    }

    /**
     * Swap two indices in place
     * @param idcs
     * @param i1
     * @param i2
     */
    static void swap(int[] idcs, int i1, int i2) {
        int tmp = idcs[i1];
        idcs[i1] = idcs[i2];
        idcs[i2] = tmp;
    }

    /**
     * Default query, which calls {@link #query(double[][], int, boolean, boolean)}
     * <tt>(X, 1, false, true)</tt>
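     * <p>
     * For example (a minimal sketch; <tt>tree</tt> and <tt>Q</tt> are
     * caller-supplied):
     * <pre>{@code
     * Neighborhood nn = tree.query(Q); // 1-NN of each row of Q, sorted
     * double[][] d = nn.getKey();      // d[i][0] = distance to nearest point
     * int[][] idx = nn.getValue();     // idx[i][0] = index of nearest point
     * }</pre>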
     * @param X
     * @return the neighborhood
     */
    public Neighborhood query(double[][] X) {
        return query(X, 1, false, true);
    }

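    /**
     * Query the tree for the <tt>k</tt> nearest neighbors of each row of
     * <tt>X</tt>, optionally using a dual-tree traversal (a second tree is
     * built over <tt>X</tt>) and optionally sorting each neighbor list by
     * ascending distance.
     */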
    public Neighborhood query(double[][] X, int k, boolean dualTree, boolean sort) {
        MatUtils.checkDims(X);

        final int n = data_arr[0].length, mPrime = X.length;

        if (n != X[0].length)
            throw new DimensionMismatchException(n, X[0].length);
        if (this.N_SAMPLES < k)
            throw new IllegalArgumentException(k + " is greater than the number of rows in the data");
        if (k < 1)
            throw new IllegalArgumentException("k must be positive (got " + k + ")");

        double[][] Xarr = X;

        // Initialize neighbor heap
        NeighborsHeap heap = new NeighborsHeap(mPrime, k);

        double[] bounds, pt;
        double reduced_dist_LB;

        this.n_trims = 0;
        this.n_leaves = 0;
        this.n_splits = 0;

        if (dualTree) {
            NearestNeighborHeapSearch other = newInstance(Xarr, leaf_size, dist_metric, logger);

            reduced_dist_LB = minRDistDual(this, 0, other, 0);
            bounds = VecUtils.rep(Double.POSITIVE_INFINITY, this.N_SAMPLES);
            queryDualDepthFirst(0, other, 0, bounds, heap, reduced_dist_LB);
        } else {
            int i;

            for (i = 0; i < mPrime; i++) {
                pt = Xarr[i];
                reduced_dist_LB = minRDist(this, 0, pt);
                querySingleDepthFirst(0, pt, i, heap, reduced_dist_LB);
            }
        }

        Neighborhood distances_indices = heap.getArrays(sort);
        int[][] indices = distances_indices.getValue();
        double[][] distances = distances_indices.getKey();
        rDistToDistInPlace(distances); // set back to dist

        return new Neighborhood(distances, indices);
    }

    private void queryDualDepthFirst(int i_node1, NearestNeighborHeapSearch other, int i_node2, double[] bounds,
            NeighborsHeap heap, double reduced_dist_LB) {
        NodeData node_info1 = this.node_data[i_node1], node_info2 = other.node_data[i_node2];
        double[][] data1 = this.data_arr, data2 = other.data_arr;
        int i1, i2, i_pt, i_parent;
        double bound_max, dist_pt, reduced_dist_LB1, reduced_dist_LB2;

        // If nodes are farther apart than current bound
        if (reduced_dist_LB > bounds[i_node2]) { // Pass here
        }

        // If both nodes are leaves
        else if (node_info1.is_leaf && node_info2.is_leaf) {
            bounds[i_node2] = 0;

            for (i2 = node_info2.idx_start; i2 < node_info2.idx_end; i2++) {
                i_pt = other.idx_array[i2];

                if (heap.largest(i_pt) <= reduced_dist_LB)
                    continue;

                for (i1 = node_info1.idx_start; i1 < node_info1.idx_end; i1++) {

                    // sklearn line:
                    // data1 + n_features * self.idx_array[i1],
                    // data2 + n_features * i_pt
                    dist_pt = rDist(data1[idx_array[i1]], data2[i_pt]);
                    if (dist_pt < heap.largest(i_pt))
                        heap.push(i_pt, dist_pt, idx_array[i1]);
                }

                // Keep track of node bound
                bounds[i_node2] = FastMath.max(bounds[i_node2], heap.largest(i_pt));
            }

            // Update bounds
            while (i_node2 > 0) {
                i_parent = (i_node2 - 1) / 2;
                bound_max = FastMath.max(bounds[2 * i_parent + 1], bounds[2 * i_parent + 2]);
                if (bound_max < bounds[i_parent]) {
                    bounds[i_parent] = bound_max;
                    i_node2 = i_parent;
                } else
                    break;
            }
        }

        // When node 1 is a leaf or is smaller
        else if (node_info1.is_leaf || (!node_info2.is_leaf && node_info2.radius > node_info1.radius)) {

            reduced_dist_LB1 = minRDistDual(this, i_node1, other, 2 * i_node2 + 1);
            reduced_dist_LB2 = minRDistDual(this, i_node1, other, 2 * i_node2 + 2);

            if (reduced_dist_LB1 < reduced_dist_LB2) {
                queryDualDepthFirst(i_node1, other, 2 * i_node2 + 1, bounds, heap, reduced_dist_LB1);
                queryDualDepthFirst(i_node1, other, 2 * i_node2 + 2, bounds, heap, reduced_dist_LB2);
            } else {
                // Do it in the opposite order...
                queryDualDepthFirst(i_node1, other, 2 * i_node2 + 2, bounds, heap, reduced_dist_LB2);
                queryDualDepthFirst(i_node1, other, 2 * i_node2 + 1, bounds, heap, reduced_dist_LB1);
            }
        }

        // Otherwise node 2 is a leaf or is smaller
        else {
            reduced_dist_LB1 = minRDistDual(this, 2 * i_node1 + 1, other, i_node2);
            reduced_dist_LB2 = minRDistDual(this, 2 * i_node1 + 2, other, i_node2);

            if (reduced_dist_LB1 < reduced_dist_LB2) {
                queryDualDepthFirst(2 * i_node1 + 1, other, i_node2, bounds, heap, reduced_dist_LB1);
                queryDualDepthFirst(2 * i_node1 + 2, other, i_node2, bounds, heap, reduced_dist_LB2);
            } else {
                // Do it in the opposite order...
                queryDualDepthFirst(2 * i_node1 + 2, other, i_node2, bounds, heap, reduced_dist_LB2);
                queryDualDepthFirst(2 * i_node1 + 1, other, i_node2, bounds, heap, reduced_dist_LB1);
            }
        }
    }

    private void ensurePositiveRadius(final double radius) {
        RadiusNeighbors.validateRadius(radius);
    }

    public Neighborhood queryRadius(final RealMatrix X, double[] radius, boolean sort) {
        return queryRadius(X.getData(), radius, sort);
    }

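    /**
     * Radius query with a per-row radius. A minimal sketch (<tt>tree</tt>
     * and <tt>Q</tt> are caller-supplied; here <tt>Q</tt> has two rows):
     * <pre>{@code
     * // neighbors of Q[0] within 1.5 and of Q[1] within 0.75, each sorted
     * Neighborhood nbrs = tree.queryRadius(Q, new double[]{1.5, 0.75}, true);
     * }</pre>
     */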
    public Neighborhood queryRadius(double[][] X, double[] radius, boolean sort) {
        int i, m_prime = X.length;
        int[] idx_arr_i, counts_arr;
        double[] dist_arr_i, pt;

        // Assumes non-jagged rows; jagged input is caught in the dist ops...
        MatUtils.checkDims(X);
        if (X[0].length != N_FEATURES)
            throw new DimensionMismatchException(X[0].length, N_FEATURES);

        VecUtils.checkDims(radius);
        if (m_prime != radius.length)
            throw new DimensionMismatchException(m_prime, radius.length);

        for (double rad : radius)
            ensurePositiveRadius(rad);

        // Prepare for iter
        int[][] indices = new int[m_prime][];
        double[][] dists = new double[m_prime][];

        idx_arr_i = new int[N_SAMPLES];
        dist_arr_i = new double[N_SAMPLES];
        counts_arr = new int[m_prime];

        // For each row in X
        for (i = 0; i < m_prime; i++) {
            // The current row
            pt = X[i];

            counts_arr[i] = queryRadiusSingle(0, pt, radius[i], idx_arr_i, dist_arr_i, 0, true);

            if (sort)
                NeighborsHeap.simultaneous_sort(dist_arr_i, idx_arr_i, counts_arr[i]);

            // The count can be zero if there are no neighbors within the radius...
            indices[i] = counts_arr[i] == 0 ? new int[] {} : VecUtils.slice(idx_arr_i, 0, counts_arr[i]);
            dists[i] = counts_arr[i] == 0 ? new double[] {} : VecUtils.slice(dist_arr_i, 0, counts_arr[i]);
        }

        return new Neighborhood(dists, indices);
    }

    public Neighborhood queryRadius(double[][] X, double radius, boolean sort) {
        MatUtils.checkDims(X);
        ensurePositiveRadius(radius);

        int n = X[0].length;
        if (n != N_FEATURES)
            throw new DimensionMismatchException(n, N_FEATURES);

        return queryRadius(X, VecUtils.rep(radius, X.length), sort);
    }

    private int queryRadiusSingle(final int i_node, final double[] pt, final double r, final int[] indices,
            final double[] distances, int count, final boolean returnDists) {

        double[][] data = this.data_arr;
        NodeData nodeInfo = node_data[i_node];

        int i;
        double reduced_r, dist_pt;

        // Lower bound (min)
        MutableDouble dist_LB = new MutableDouble(0.0);

        // Upper bound (max)
        MutableDouble dist_UB = new MutableDouble(0.0);

        // Find min dist and max dist from pts
        minMaxDist(this, i_node, pt, dist_LB, dist_UB);

        // If min dist is greater than radius, then pass
        if (dist_LB.value > r) {
        } // pass

        // All points within radius
        else if (dist_UB.value <= r) {
            for (i = nodeInfo.idx_start; i < nodeInfo.idx_end; i++) {
                /*// can't really happen?
                if(count < 0 || count >= N_SAMPLES) {
                   String err = "count is too big; this should not happen";
                   if(null != logger)
                      logger.error(err);
                   throw new IllegalStateException(err);
                }
                */

                indices[count] = idx_array[i];
                if (returnDists)
                    distances[count] = this.dist(pt, data[idx_array[i]]);

                count++;
            }
        }

        // this is a leaf node; check every point
        else if (nodeInfo.is_leaf) {
            reduced_r = this.dist_metric.distanceToPartialDistance(r);

            for (i = nodeInfo.idx_start; i < nodeInfo.idx_end; i++) {
                dist_pt = this.rDist(pt, data[idx_array[i]]);

                if (dist_pt <= reduced_r) {
                    /*// can't really happen?
                    if(count < 0 || count >= N_SAMPLES) {
                       String err = "count is too big; this should not happen";
                       if(null != logger)
                          logger.error(err);
                       throw new IllegalStateException(err);
                    }
                    */

                    indices[count] = idx_array[i];
                    if (returnDists)
                        distances[count] = this.dist_metric.partialDistanceToDistance(dist_pt);

                    count++;
                }
            }
        }

        // Otherwise node is not a leaf. Recursively check subnodes
        else {
            count = this.queryRadiusSingle(2 * i_node + 1, pt, r, indices, distances, count, returnDists);

            count = this.queryRadiusSingle(2 * i_node + 2, pt, r, indices, distances, count, returnDists);
        }

        return count;
    }

    private void querySingleDepthFirst(int i_node, double[] pt, int i_pt, NeighborsHeap heap,
            double reduced_dist_LB) {
        NodeData nodeInfo = this.node_data[i_node];

        double dist_pt, reduced_dist_LB_1, reduced_dist_LB_2;
        int i, i1, i2;

        // Query point is outside node radius
        if (reduced_dist_LB > heap.largest(i_pt))
            this.n_trims++;

        // This is a leaf node
        else if (nodeInfo.is_leaf) {
            this.n_leaves++;
            for (i = nodeInfo.idx_start; i < nodeInfo.idx_end; i++) {
                dist_pt = rDist(pt, this.data_arr[idx_array[i]]);

                if (dist_pt < heap.largest(i_pt)) { // in radius
                    heap.push(i_pt, dist_pt, idx_array[i]);
                }
            }
        }

        // Node is not a leaf
        else {
            this.n_splits++;
            i1 = 2 * i_node + 1;
            i2 = i1 + 1;

            reduced_dist_LB_1 = minRDist(this, i1, pt);
            reduced_dist_LB_2 = minRDist(this, i2, pt);

            // Recurse
            if (reduced_dist_LB_1 <= reduced_dist_LB_2) {
                querySingleDepthFirst(i1, pt, i_pt, heap, reduced_dist_LB_1);
                querySingleDepthFirst(i2, pt, i_pt, heap, reduced_dist_LB_2);

            } else { // opposite order

                querySingleDepthFirst(i2, pt, i_pt, heap, reduced_dist_LB_2);
                querySingleDepthFirst(i1, pt, i_pt, heap, reduced_dist_LB_1);
            }
        }
    }

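    /**
     * Two-point correlation: for each radius, count the (query point, tree
     * point) pairs whose distance is within that radius. Note the counts
     * correspond to the radii in ascending order. A minimal sketch
     * (<tt>tree</tt> and <tt>Q</tt> are caller-supplied):
     * <pre>{@code
     * int[] counts = tree.twoPointCorrelation(Q, new double[]{0.5, 1.0, 2.0});
     * // counts[j] = number of pairs within distance r[j]
     * }</pre>
     */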
    public int[] twoPointCorrelation(double[][] X, double r) {
        return twoPointCorrelation(X, r, false);
    }

    public int[] twoPointCorrelation(double[][] X, double r, boolean dual) {
        return twoPointCorrelation(X, VecUtils.rep(r, X.length), dual);
    }

    public int[] twoPointCorrelation(double[][] X, double[] r) {
        return twoPointCorrelation(X, r, false);
    }

    public int[] twoPointCorrelation(double[][] X, double[] r, boolean dual) {
        int i;

        MatUtils.checkDims(X);
        if (X[0].length != N_FEATURES)
            throw new DimensionMismatchException(X[0].length, N_FEATURES);

        double[][] Xarr = MatUtils.copy(X);
        double[] rarr = VecUtils.reorder(r, VecUtils.argSort(r));

        // count array
        int[] carr = new int[r.length];

        if (dual) {
            NearestNeighborHeapSearch other = newInstance(Xarr, leaf_size, dist_metric, logger);
            this.twoPointDual(0, other, 0, rarr, carr, 0, rarr.length);
        } else {
            for (i = 0; i < Xarr.length; i++)
                this.twoPointSingle(0, Xarr[i], rarr, carr, 0, rarr.length);
        }

        return carr;
    }

    private void twoPointDual(int i_node1, NearestNeighborHeapSearch other, int i_node2, double[] r, int[] count,
            int i_min, int i_max) {

        double[][] data1 = this.data_arr;
        double[][] data2 = other.data_arr;

        int[] idx_array1 = this.idx_array;
        int[] idx_array2 = other.idx_array;

        NodeData nodeInfo1 = this.node_data[i_node1];
        NodeData nodeInfo2 = other.node_data[i_node2];

        int i1, i2, j, Npts;
        double dist_pt;
        double dist_LB, dist_UB;

        dist_LB = minDistDual(this, i_node1, other, i_node2);
        dist_UB = maxDistDual(this, i_node1, other, i_node2);

        // Check for cuts
        while (i_min < i_max) {
            if (dist_LB > r[i_min])
                i_min++;
            else
                break;
        }

        while (i_max > i_min) {
            Npts = ((nodeInfo1.idx_end - nodeInfo1.idx_start) * (nodeInfo2.idx_end - nodeInfo2.idx_start));
            if (dist_UB <= r[i_max - 1]) {
                count[i_max - 1] += Npts;
                i_max--;
            } else
                break;
        }

        if (i_min < i_max) {
            if (nodeInfo1.is_leaf && nodeInfo2.is_leaf) {
                for (i1 = nodeInfo1.idx_start; i1 < nodeInfo1.idx_end; i1++) {
                    for (i2 = nodeInfo2.idx_start; i2 < nodeInfo2.idx_end; i2++) {

                        dist_pt = this.dist(data1[idx_array1[i1]], data2[idx_array2[i2]]);
                        j = i_max - 1;

                        while (j >= i_min && dist_pt <= r[j])
                            count[j--]++;
                    }
                }

            } else if (nodeInfo1.is_leaf) {
                for (i2 = 2 * i_node2 + 1; i2 < 2 * i_node2 + 3; i2++)
                    this.twoPointDual(i_node1, other, i2, r, count, i_min, i_max);

            } else if (nodeInfo2.is_leaf) {
                for (i1 = 2 * i_node1 + 1; i1 < 2 * i_node1 + 3; i1++)
                    this.twoPointDual(i1, other, i_node2, r, count, i_min, i_max);

            } else {
                for (i1 = 2 * i_node1 + 1; i1 < 2 * i_node1 + 3; i1++)
                    for (i2 = 2 * i_node2 + 1; i2 < 2 * i_node2 + 3; i2++)
                        this.twoPointDual(i1, other, i2, r, count, i_min, i_max);
            }
        }
    }

    private void twoPointSingle(int i_node, double[] pt, double[] r, int[] count, int i_min, int i_max) {
        double[][] data = this.data_arr;
        NodeData nodeInfo = node_data[i_node];

        int i, j, Npts;
        double dist_pt;

        MutableDouble dist_LB = new MutableDouble(0.0), dist_UB = new MutableDouble(0.0);
        minMaxDist(this, i_node, pt, dist_LB, dist_UB);

        while (i_min < i_max) {
            if (dist_LB.value > r[i_min])
                i_min++;
            else
                break;
        }

        while (i_max > i_min) {
            Npts = nodeInfo.idx_end - nodeInfo.idx_start;
            if (dist_UB.value <= r[i_max - 1]) {
                count[i_max - 1] += Npts;
                i_max--;
            } else
                break;

        }

        if (i_min < i_max) {
            if (nodeInfo.is_leaf) {
                for (i = nodeInfo.idx_start; i < nodeInfo.idx_end; i++) {
                    dist_pt = this.dist(pt, data[idx_array[i]]);
                    j = i_max - 1;
                    while (j >= i_min && dist_pt <= r[j])
                        count[j--]++;
                    // same as count[j]++; j--;
                }
            } else {
                this.twoPointSingle(2 * i_node + 1, pt, r, count, i_min, i_max);
                this.twoPointSingle(2 * i_node + 2, pt, r, count, i_min, i_max);
            }
        }
    }

    // Init functions
    abstract void allocateData(NearestNeighborHeapSearch tree, int n_nodes, int n_features);

    abstract void initNode(NearestNeighborHeapSearch tree, int i_node, int idx_start, int idx_end);

    // Dist functions
    //abstract double maxDist      (NearestNeighborHeapSearch tree, int i_node, double[] pt);
    abstract double minDist(NearestNeighborHeapSearch tree, int i_node, double[] pt);

    abstract double maxDistDual(NearestNeighborHeapSearch tree1, int iNode1, NearestNeighborHeapSearch tree2,
            int iNode2);

    abstract double minDistDual(NearestNeighborHeapSearch tree1, int iNode1, NearestNeighborHeapSearch tree2,
            int iNode2);

    abstract void minMaxDist(NearestNeighborHeapSearch tree, int i_node, double[] pt, MutableDouble minDist,
            MutableDouble maxDist);

    //abstract double maxRDist   (NearestNeighborHeapSearch tree, int i_node, double[] pt);
    abstract double minRDist(NearestNeighborHeapSearch tree, int i_node, double[] pt);

    abstract double maxRDistDual(NearestNeighborHeapSearch tree1, int iNode1, NearestNeighborHeapSearch tree2,
            int iNode2);

    abstract double minRDistDual(NearestNeighborHeapSearch tree1, int iNode1, NearestNeighborHeapSearch tree2,
            int iNode2);

    // Hack for new instance functions
    abstract NearestNeighborHeapSearch newInstance(double[][] arr, int leaf, DistanceMetric dist, Loggable logger);
}