com.clust4j.algo.NNHSTests.java Source code

Java tutorial

Introduction

Here is the source code for com.clust4j.algo.NNHSTests.java

Source

/*******************************************************************************
 *    Copyright 2015, 2016 Taylor G Smith
 *
 *    Licensed under the Apache License, Version 2.0 (the "License");
 *    you may not use this file except in compliance with the License.
 *    You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 *    Unless required by applicable law or agreed to in writing, software
 *    distributed under the License is distributed on an "AS IS" BASIS,
 *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *    See the License for the specific language governing permissions and
 *    limitations under the License.
 *******************************************************************************/
package com.clust4j.algo;

import static org.junit.Assert.*;

import java.util.ArrayList;
import java.util.Random;

import org.apache.commons.lang3.tuple.Triple;
import org.apache.commons.math3.exception.DimensionMismatchException;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.util.FastMath;
import org.apache.commons.math3.util.Precision;
import org.junit.Test;

import com.clust4j.TestSuite;
import com.clust4j.algo.BallTree;
import com.clust4j.algo.KDTree;
import com.clust4j.algo.NearestNeighborHeapSearch.Heap;
import com.clust4j.algo.NearestNeighborHeapSearch.NodeHeap.NodeHeapData;
import com.clust4j.algo.NearestNeighborHeapSearch.MutableDouble;
import com.clust4j.algo.NearestNeighborHeapSearch.NeighborsHeap;
import com.clust4j.algo.NearestNeighborHeapSearch.NodeData;
import com.clust4j.algo.NearestNeighborHeapSearch.NodeHeap;
import com.clust4j.algo.NearestNeighborHeapSearch.PartialKernelDensity;
import com.clust4j.algo.Neighborhood;
import com.clust4j.log.Loggable;
import com.clust4j.metrics.pairwise.Distance;
import com.clust4j.metrics.pairwise.DistanceMetric;
import com.clust4j.utils.MatUtils;
import com.clust4j.utils.QuadTup;
import com.clust4j.utils.VecUtils;
import com.clust4j.utils.Series.Inequality;
import com.clust4j.utils.VecUtils.DoubleSeries;
import com.clust4j.utils.VecUtils.IntSeries;

public class NNHSTests {
    /** Matrix returned by {@code TestSuite.IRIS_DATASET.getData()}; shared by the tests in this class. */
    final public static Array2DRowRealMatrix IRIS = TestSuite.IRIS_DATASET.getData();

    /** Small fixture of three rows with four columns each, used by the tree and constructor tests. */
    final static double[][] a = new double[][] { new double[] { 0, 1, 0, 2 }, new double[] { 0, 0, 1, 2 },
            new double[] { 5, 6, 7, 4 } };

    /**
     * Builds a {@link KDTree} over the three-row fixture {@code a} and verifies the
     * arrays, the tree statistics and the first node that the tree reports.
     */
    @Test
    public void testKD1() {
        final KDTree tree = new KDTree(new Array2DRowRealMatrix(a, false));
        final QuadTup<double[][], int[], NodeData[], double[][][]> treeArrays = tree.getArrays();

        // The first array must match the fixture exactly; the second must be 0, 1, 2.
        final double[][] storedData = treeArrays.getFirst();
        assertTrue(MatUtils.equalsExactly(storedData, a));
        final int[] storedIndices = treeArrays.getSecond();
        assertTrue(VecUtils.equalsExactly(new int[] { 0, 1, 2 }, storedIndices));

        // All three tree statistics are expected to be zero for this tiny input.
        final Triple<Integer, Integer, Integer> treeStats = tree.getTreeStats();
        assertTrue(0 == treeStats.getLeft());
        assertTrue(0 == treeStats.getMiddle());
        assertTrue(0 == treeStats.getRight());

        // The first node is a leaf spanning indices [0, 3) with radius 1.
        final NodeData firstNode = treeArrays.getThird()[0];
        assertTrue(0 == firstNode.idx_start);
        assertTrue(3 == firstNode.idx_end);
        assertTrue(firstNode.is_leaf);
        assertTrue(1 == firstNode.radius);
    }

    /**
     * Builds a {@link BallTree} over the three-row fixture {@code a} and verifies the
     * arrays, the tree statistics, the first node and the fourth array that the tree reports.
     */
    @Test
    public void testBall1() {
        final BallTree tree = new BallTree(new Array2DRowRealMatrix(a, false));
        final QuadTup<double[][], int[], NodeData[], double[][][]> treeArrays = tree.getArrays();

        // The first array must match the fixture exactly; the second must be 0, 1, 2.
        final double[][] storedData = treeArrays.getFirst();
        assertTrue(MatUtils.equalsExactly(storedData, a));
        final int[] storedIndices = treeArrays.getSecond();
        assertTrue(VecUtils.equalsExactly(new int[] { 0, 1, 2 }, storedIndices));

        // All three tree statistics are expected to be zero for this tiny input.
        final Triple<Integer, Integer, Integer> treeStats = tree.getTreeStats();
        assertTrue(0 == treeStats.getLeft());
        assertTrue(0 == treeStats.getMiddle());
        assertTrue(0 == treeStats.getRight());

        // The first node is a leaf spanning indices [0, 3).
        final NodeData firstNode = treeArrays.getThird()[0];
        assertTrue(0 == firstNode.idx_start);
        assertTrue(3 == firstNode.idx_end);
        assertTrue(firstNode.is_leaf);
        assertTrue(6.716480559869961 == firstNode.radius);

        // The fourth array holds a single slab whose first row equals the column means of `a`
        // (5/3, 7/3, 8/3, 8/3).
        final double[][][] bounds = treeArrays.getFourth();
        assertTrue(bounds.length == 1);
        final double[] wantRow = new double[] { 1.6666666666666667, 2.3333333333333333, 2.6666666666666667,
                2.6666666666666667 };
        for (int col = 0; col < wantRow.length; col++) {
            assertTrue(bounds[0][0][col] == wantRow[col]);
        }
    }

    /**
     * Compares {@code PartialKernelDensity.getDensity} and
     * {@code NearestNeighborHeapSearch.logKernelNorm} against hard-coded reference
     * values using exact {@code double} equality.
     */
    @Test
    public void testKernelDensitiesAndNorms() {
        // Kernels under test; each expected-value table below is listed in this same order.
        final PartialKernelDensity[] kernels = new PartialKernelDensity[] { PartialKernelDensity.LOG_GAUSSIAN,
                PartialKernelDensity.LOG_TOPHAT, PartialKernelDensity.LOG_EPANECHNIKOV,
                PartialKernelDensity.LOG_EXPONENTIAL, PartialKernelDensity.LOG_LINEAR,
                PartialKernelDensity.LOG_COSINE };

        /*
         * The expected numbers are the ones the sklearn code produces
         * (though numpy rounds up more than java)
         */

        // Case 1: dist > h
        double dist = 5.0;
        double h = 1.3;
        final double[] wantWhenDistExceedsH = new double[] { -7.396449704142011, Double.NEGATIVE_INFINITY,
                Double.NEGATIVE_INFINITY, -3.846153846153846, Double.NEGATIVE_INFINITY,
                Double.NEGATIVE_INFINITY };
        for (int i = 0; i < kernels.length; i++) {
            assertTrue(kernels[i].getDensity(dist, h) == wantWhenDistExceedsH[i]);
        }

        // Case 2: dist < h
        dist = 1.3;
        h = 5.0;
        final double[] wantWhenDistBelowH = new double[] { -0.033800000000000004, 0.0, -0.06999337182053497,
                -0.26, -0.3011050927839216, -0.08582521637384073 };
        for (int i = 0; i < kernels.length; i++) {
            assertTrue(kernels[i].getDensity(dist, h) == wantWhenDistBelowH[i]);
        }

        /*
         * Kernel norms for h = 1.3 and d = 5
         */
        h = 1.3;
        final int d = 5;
        final double[] wantNorms = new double[] { -5.906513988360818, -2.972672434613881, -1.7199094661185133,
                -7.760164177395928, -1.1809129653858264, -1.588674327991151 };
        for (int i = 0; i < kernels.length; i++) {
            assertTrue(NearestNeighborHeapSearch.logKernelNorm(h, d, kernels[i]) == wantNorms[i]);
        }
    }

    /**
     * Runs {@code KDTree.kernelDensity} on the Iris data (queried against a tree built from the
     * same data) for each {@link PartialKernelDensity} kernel and compares the results with
     * hard-coded reference vectors to within {@code 1e-8}.
     *
     * <p>Each kernel is evaluated twice: with the last argument {@code false} (stored in
     * {@code exp}) and {@code true} (stored in {@code log}). The {@code expected_log} tables hold
     * the natural logarithms of the matching {@code expected_exp} entries
     * (e.g. ln(12.28012213) is approximately 2.50798187).
     *
     * <p>NOTE(review): {@code at} and {@code rt} are presumably the absolute and relative
     * tolerances of the density estimate, and the reference vectors presumably come from
     * sklearn -- confirm both.
     */
    @Test
    public void testEstimateKernelDensity() {
        final double h = 0.5, at = 0.0, rt = 1e-8;
        final KDTree k = new KDTree(IRIS);

        /*
         *  GAUSSIAN
         */
        double[] exp = k.kernelDensity(IRIS.getData(), h, PartialKernelDensity.LOG_GAUSSIAN, at, rt, false);
        double[] log = k.kernelDensity(IRIS.getData(), h, PartialKernelDensity.LOG_GAUSSIAN, at, rt, true);
        double[] expected_exp = new double[] { 12.28012213, 10.45762713, 10.79173934, 9.91100044, 11.83821592,
                7.37523119, 10.15806662, 12.78705716, 6.73568504, 11.16047533, 9.59242696, 11.78397871, 10.05861109,
                5.13144797, 3.79949718, 2.7569215, 7.40977567, 12.2664965, 5.67837233, 10.22914483, 9.25843299,
                10.81676261, 6.22015249, 10.04296168, 8.6977775, 10.0373372, 11.93847251, 11.9609578, 11.85992468,
                10.87558964, 10.90385856, 9.69118458, 6.39025706, 4.89462441, 11.16047533, 10.66063678, 8.4050554,
                11.16047533, 7.16013574, 12.57084638, 11.81003215, 2.28644394, 7.79668271, 9.71624306, 7.34342451,
                10.21286369, 9.97460948, 10.33002923, 10.38942623, 12.55079235, 4.49638551, 8.32967062, 6.11506094,
                5.93506098, 9.26420516, 9.14402071, 8.32269348, 2.2884424, 7.84879076, 5.06544992, 2.1216892,
                8.87592108, 4.46667403, 10.6051664, 4.99393933, 6.30282393, 8.16600372, 7.6139944, 5.50766354,
                6.93374452, 8.06491113, 7.72412736, 8.39118297, 9.05573515, 8.0914415, 7.38763938, 6.64913168,
                8.64158676, 10.68749816, 4.44914895, 5.90974638, 5.2120019, 7.77505348, 8.99693593, 6.34485692,
                6.61439069, 7.6741988, 5.84235461, 7.6827928, 7.16366372, 7.25664652, 10.39958075, 8.17091265,
                2.4916162, 8.7106951, 8.17098773, 8.98842272, 9.23702581, 1.86565956, 8.88946089, 3.52157191,
                7.4583668, 5.2160347, 7.55914067, 6.82334224, 2.29047669, 2.06620082, 2.98708186, 4.23827541,
                2.29040748, 8.65940973, 8.90651813, 7.86138621, 5.52498516, 4.34094775, 7.39080672, 8.51477518,
                1.15631644, 1.23170302, 5.02624612, 6.22942732, 5.97238892, 1.87082373, 10.01809726, 6.82327907,
                3.89462823, 10.46733863, 10.30262734, 7.61239032, 3.63114801, 3.19895026, 1.15852059, 7.23468987,
                9.47646535, 4.19472305, 2.29812615, 5.17062191, 8.23661115, 9.89631121, 7.19670567, 6.7522757,
                5.4892544, 7.4583668, 5.75394048, 5.42296292, 7.0839312, 7.93392041, 9.57929288, 5.50956848,
                8.60444823 };

        double[] expected_log = new double[] { 2.50798187, 2.34733158, 2.37878097, 2.2936453, 2.47133294,
                1.99812725, 2.31826813, 2.5484335, 1.90741952, 2.41237855, 2.26097393, 2.46674087, 2.30842909,
                1.63538788, 1.33486874, 1.01411466, 2.00280016, 2.50687168, 1.73666463, 2.32524098, 2.22553481,
                2.38109702, 1.82779442, 2.30687206, 2.16306753, 2.30631186, 2.47976617, 2.48164783, 2.47316504,
                2.38652079, 2.38911672, 2.27121667, 1.8547745, 1.58813754, 2.41237855, 2.36655815, 2.12883336,
                2.41237855, 1.96852894, 2.53138035, 2.46894935, 0.82699775, 2.05369835, 2.27379903, 1.99380529,
                2.32364807, 2.30004281, 2.33505511, 2.34078858, 2.5297838, 1.50327385, 2.11982391, 1.81075473,
                1.7808773, 2.22615807, 2.21310019, 2.11898594, 0.82787141, 2.06035948, 1.62244296, 0.75221256,
                2.18334211, 1.49664407, 2.36134128, 1.60822504, 1.84099778, 2.09997965, 2.02998792, 1.70614049,
                1.9364, 2.08752269, 2.04434885, 2.12718151, 2.20339828, 2.0908069, 1.99980825, 1.89448627,
                2.15658622, 2.36907466, 1.49271283, 1.77660292, 1.65096402, 2.05092034, 2.19688407, 1.84764455,
                1.88924768, 2.0378639, 1.7651339, 2.03898313, 1.96902154, 1.98191781, 2.34176549, 2.10058061,
                0.91293158, 2.16455159, 2.1005898, 2.19593738, 2.22321995, 0.62361464, 2.1848664, 1.25890746,
                2.00933646, 1.65173748, 2.02275752, 1.92034942, 0.82875996, 0.72571157, 1.09429695, 1.44415644,
                0.82872974, 2.15864656, 2.18678338, 2.06196295, 1.70928056, 1.4680927, 2.00023689, 2.14180291,
                0.14523947, 0.20839778, 1.61467341, 1.82928441, 1.787147, 0.62637883, 2.30439318, 1.92034016,
                1.35959823, 2.3482598, 2.33239894, 2.02977723, 1.28954886, 1.16282271, 0.14714384, 1.97888749,
                2.24881139, 1.43382732, 0.83209407, 1.64299297, 2.10858899, 2.29216208, 1.97362338, 1.90987959,
                1.70279244, 2.00933646, 1.74988492, 1.69064233, 1.95782901, 2.07114729, 2.25960378, 1.7064863,
                2.15227931 };

        assertTrue(VecUtils.equalsWithTolerance(exp, expected_exp, 1e-8));
        assertTrue(VecUtils.equalsWithTolerance(log, expected_log, 1e-8));

        /*
         * TOPHAT
         */
        exp = k.kernelDensity(IRIS.getData(), h, PartialKernelDensity.LOG_TOPHAT, at, rt, false);
        log = k.kernelDensity(IRIS.getData(), h, PartialKernelDensity.LOG_TOPHAT, at, rt, true);
        expected_exp = new double[] { 81.05694691, 61.60327965, 71.33011328, 58.36100178, 61.60327965, 35.66505664,
                64.84555753, 106.99516993, 35.66505664, 74.57239116, 58.36100178, 97.2683363, 64.84555753,
                16.21138938, 9.72683363, 6.48455575, 35.66505664, 81.05694691, 12.96911151, 61.60327965, 42.1496124,
                71.33011328, 6.48455575, 58.36100178, 25.93822301, 64.84555753, 94.02605842, 71.33011328,
                84.29922479, 71.33011328, 71.33011328, 55.1187239, 25.93822301, 19.45366726, 74.57239116,
                71.33011328, 35.66505664, 74.57239116, 35.66505664, 100.51061417, 77.81466904, 3.24227788,
                35.66505664, 51.87644602, 22.69594514, 64.84555753, 61.60327965, 68.08783541, 71.33011328,
                94.02605842, 16.21138938, 38.90733452, 16.21138938, 22.69594514, 45.39189027, 45.39189027,
                29.18050089, 12.96911151, 29.18050089, 6.48455575, 9.72683363, 38.90733452, 6.48455575, 45.39189027,
                9.72683363, 25.93822301, 29.18050089, 38.90733452, 6.48455575, 35.66505664, 19.45366726,
                32.42277877, 25.93822301, 19.45366726, 32.42277877, 29.18050089, 25.93822301, 19.45366726,
                38.90733452, 19.45366726, 25.93822301, 19.45366726, 45.39189027, 38.90733452, 16.21138938,
                16.21138938, 35.66505664, 6.48455575, 38.90733452, 42.1496124, 25.93822301, 42.1496124, 45.39189027,
                12.96911151, 48.63416815, 42.1496124, 51.87644602, 38.90733452, 9.72683363, 45.39189027, 6.48455575,
                25.93822301, 16.21138938, 22.69594514, 29.18050089, 6.48455575, 3.24227788, 9.72683363, 3.24227788,
                3.24227788, 22.69594514, 32.42277877, 42.1496124, 12.96911151, 6.48455575, 22.69594514, 29.18050089,
                6.48455575, 6.48455575, 6.48455575, 29.18050089, 19.45366726, 9.72683363, 32.42277877, 25.93822301,
                16.21138938, 38.90733452, 38.90733452, 29.18050089, 6.48455575, 12.96911151, 6.48455575,
                29.18050089, 25.93822301, 3.24227788, 3.24227788, 16.21138938, 22.69594514, 42.1496124, 25.93822301,
                35.66505664, 12.96911151, 25.93822301, 25.93822301, 22.69594514, 25.93822301, 19.45366726,
                42.1496124, 9.72683363, 29.18050089 };

        expected_log = new double[] { 4.39515196, 4.12071511, 4.26731858, 4.06664789, 4.12071511, 3.5741714,
                4.1720084, 4.67278369, 3.5741714, 4.31177035, 4.06664789, 4.57747351, 4.1720084, 2.78571404,
                2.27488842, 1.86942331, 3.5741714, 4.39515196, 2.56257049, 4.12071511, 3.74122549, 4.26731858,
                1.86942331, 4.06664789, 3.25571767, 4.1720084, 4.54357196, 4.26731858, 4.43437267, 4.26731858,
                4.26731858, 4.00948948, 3.25571767, 2.9680356, 4.31177035, 4.26731858, 3.5741714, 4.31177035,
                3.5741714, 4.61026334, 4.35432996, 1.17627613, 3.5741714, 3.94886485, 3.12218628, 4.1720084,
                4.12071511, 4.22079857, 4.26731858, 4.54357196, 2.78571404, 3.66118278, 2.78571404, 3.12218628,
                3.81533346, 3.81533346, 3.37350071, 2.56257049, 3.37350071, 1.86942331, 2.27488842, 3.66118278,
                1.86942331, 3.81533346, 2.27488842, 3.25571767, 3.37350071, 3.66118278, 1.86942331, 3.5741714,
                2.9680356, 3.47886122, 3.25571767, 2.9680356, 3.47886122, 3.37350071, 3.25571767, 2.9680356,
                3.66118278, 2.9680356, 3.25571767, 2.9680356, 3.81533346, 3.66118278, 2.78571404, 2.78571404,
                3.5741714, 1.86942331, 3.66118278, 3.74122549, 3.25571767, 3.74122549, 3.81533346, 2.56257049,
                3.88432633, 3.74122549, 3.94886485, 3.66118278, 2.27488842, 3.81533346, 1.86942331, 3.25571767,
                2.78571404, 3.12218628, 3.37350071, 1.86942331, 1.17627613, 2.27488842, 1.17627613, 1.17627613,
                3.12218628, 3.47886122, 3.74122549, 2.56257049, 1.86942331, 3.12218628, 3.37350071, 1.86942331,
                1.86942331, 1.86942331, 3.37350071, 2.9680356, 2.27488842, 3.47886122, 3.25571767, 2.78571404,
                3.66118278, 3.66118278, 3.37350071, 1.86942331, 2.56257049, 1.86942331, 3.37350071, 3.25571767,
                1.17627613, 1.17627613, 2.78571404, 3.12218628, 3.74122549, 3.25571767, 3.5741714, 2.56257049,
                3.25571767, 3.25571767, 3.12218628, 3.25571767, 2.9680356, 3.74122549, 2.27488842, 3.37350071 };

        assertTrue(VecUtils.equalsWithTolerance(exp, expected_exp, 1e-8));
        assertTrue(VecUtils.equalsWithTolerance(log, expected_log, 1e-8));

        /*
         * EPANECHNIKOV
         */
        exp = k.kernelDensity(IRIS.getData(), h, PartialKernelDensity.LOG_EPANECHNIKOV, at, rt, false);
        log = k.kernelDensity(IRIS.getData(), h, PartialKernelDensity.LOG_EPANECHNIKOV, at, rt, true);
        expected_exp = new double[] { 136.56474416, 107.77331662, 118.27829694, 112.05312341, 114.77663683,
                45.91065473, 69.25505544, 161.07636491, 49.80138818, 126.44883719, 90.26501608, 124.50347046,
                110.88590338, 28.79142754, 14.00664043, 14.39571377, 38.90733452, 136.95381751, 20.23181395,
                96.10111626, 53.69212164, 109.71868334, 11.28312701, 57.58285509, 28.79142754, 89.87594274,
                121.77995704, 135.39752412, 128.39420391, 119.83459032, 122.94717708, 72.3676422, 31.51494096,
                27.23513416, 126.44883719, 91.04316277, 48.63416815, 126.44883719, 59.13914847, 152.12767797,
                109.32961, 9.72683363, 60.69544185, 52.91397495, 25.67884078, 107.77331662, 84.81798925,
                113.22034345, 108.94053665, 157.5747048, 26.06791413, 48.2450948, 34.23845438, 36.1838211,
                48.63416815, 55.24841502, 29.18050089, 24.1225474, 49.41231484, 13.61756708, 15.95200715,
                49.41231484, 10.11590697, 48.63416815, 14.39571377, 44.35436135, 39.29640786, 50.57953487,
                16.73015384, 64.5861753, 32.29308765, 36.96196779, 30.34772092, 31.51494096, 45.13250804,
                52.5249016, 34.62752772, 28.4023542, 51.74675491, 22.95532737, 47.07787477, 35.40567441,
                66.53154203, 42.01992128, 22.17718068, 19.45366726, 57.97192843, 16.73015384, 53.69212164,
                61.47358854, 31.90401431, 53.69212164, 67.69876206, 27.23513416, 77.42559569, 59.52822181,
                78.20374238, 47.07787477, 17.50830053, 82.87262252, 12.45034705, 43.96528801, 21.78810733,
                36.57289445, 38.51826117, 16.73015384, 9.72683363, 19.06459391, 9.72683363, 9.72683363, 28.79142754,
                38.90733452, 48.2450948, 29.18050089, 10.11590697, 34.23845438, 42.79806797, 12.83942039,
                12.83942039, 12.0612737, 50.57953487, 28.79142754, 19.8427406, 48.2450948, 40.85270124, 22.17718068,
                48.63416815, 52.5249016, 43.57621466, 14.78478712, 19.45366726, 12.83942039, 33.84938103,
                30.34772092, 9.72683363, 9.72683363, 26.06791413, 35.01660107, 50.96860822, 39.29640786,
                52.13582826, 22.95532737, 43.96528801, 41.63084793, 35.79474776, 40.4636279, 29.18050089,
                47.46694811, 23.34440071, 44.7434347 };

        expected_log = new double[] { 4.91679882, 4.6800301, 4.7730403, 4.71897308, 4.74298795, 3.82669722,
                4.23779615, 5.08187857, 3.90804286, 4.83983778, 4.50274997, 4.82433359, 4.70850178, 3.36007769,
                2.63953153, 2.66693051, 3.66118278, 4.91964377, 3.00725631, 4.56540093, 3.98326628, 4.69791967,
                2.42330842, 4.05322487, 3.36007769, 4.49843031, 4.80221579, 4.90821507, 4.85510525, 4.78611238,
                4.81175481, 4.28175927, 3.45046175, 3.30450784, 4.83983778, 4.51133371, 3.88432633, 4.83983778,
                4.07989312, 5.02472015, 4.69436726, 2.27488842, 4.1058686, 3.96866748, 3.24566734, 4.6800301,
                4.44050766, 4.72933586, 4.6908022, 5.05989966, 3.26070521, 3.87629416, 3.53334941, 3.58861209,
                3.88432633, 4.01183965, 3.37350071, 3.18314698, 3.90019968, 2.61136066, 2.76958466, 3.90019968,
                2.31410913, 3.88432633, 2.66693051, 3.79221104, 3.67113311, 3.92354705, 2.81721271, 4.16800038,
                3.4748532, 3.60988949, 3.41272142, 3.45046175, 3.80960279, 3.96128737, 3.54464896, 3.34647204,
                3.94636172, 3.13355004, 3.85180314, 3.5668721, 4.19767615, 3.73814382, 3.09906386, 2.9680356,
                4.0599589, 2.81721271, 3.98326628, 4.11860763, 3.46273184, 3.98326628, 4.21506789, 3.30450784,
                4.34931742, 4.08645052, 4.3593175, 3.85180314, 2.86267508, 4.41730476, 2.5217485, 3.78340041,
                3.08136429, 3.59930738, 3.65113245, 2.81721271, 2.27488842, 2.94783289, 2.27488842, 2.27488842,
                3.36007769, 3.66118278, 3.87629416, 3.37350071, 2.31410913, 3.53334941, 3.75649296, 2.55252016,
                2.55252016, 2.4899998, 3.92354705, 3.36007769, 2.98783823, 3.87629416, 3.70997295, 3.09906386,
                3.88432633, 3.96128737, 3.77451147, 2.69359875, 2.9680356, 2.55252016, 3.52192071, 3.41272142,
                2.27488842, 2.27488842, 3.26070521, 3.55582227, 3.93120992, 3.67113311, 3.95385239, 3.13355004,
                3.78340041, 3.72884143, 3.57780117, 3.70040349, 3.37350071, 3.86003364, 3.15035716, 3.80094472 };

        assertTrue(VecUtils.equalsWithTolerance(exp, expected_exp, 1e-8));
        assertTrue(VecUtils.equalsWithTolerance(log, expected_log, 1e-8));

        /*
         * EXPONENTIAL
         */
        exp = k.kernelDensity(IRIS.getData(), h, PartialKernelDensity.LOG_EXPONENTIAL, at, rt, false);
        log = k.kernelDensity(IRIS.getData(), h, PartialKernelDensity.LOG_EXPONENTIAL, at, rt, true);
        expected_exp = new double[] { 2.81678649, 2.44420675, 2.46085198, 2.32334735, 2.66174388, 1.75833401,
                2.26258436, 2.90604716, 1.66101735, 2.65197964, 2.21241444, 2.61487079, 2.36662588, 1.3429302,
                1.08582846, 0.87810901, 1.74938498, 2.80594815, 1.43966768, 2.33798563, 2.09649275, 2.44929755,
                1.52013833, 2.24654141, 1.97405298, 2.3113064, 2.66193891, 2.73534441, 2.69575639, 2.49807964,
                2.5396906, 2.18998074, 1.57014777, 1.29623897, 2.65197964, 2.36627864, 1.93596341, 2.65197964,
                1.76027591, 2.87279821, 2.63973029, 0.80072927, 1.86171852, 2.17588605, 1.74625427, 2.36335638,
                2.27009911, 2.39545832, 2.38680074, 2.84340518, 1.47476553, 2.20254506, 1.80134266, 1.64905292,
                2.37559262, 2.3070159, 2.1904533, 0.84587259, 2.13461142, 1.43116099, 0.76391881, 2.25300808,
                1.36079879, 2.61769873, 1.43585157, 1.84406234, 2.13951774, 1.98781592, 1.67412091, 1.86701684,
                2.16352523, 2.00735168, 2.20720179, 2.29847941, 2.13683403, 2.06054352, 1.89680547, 2.22597632,
                2.60765944, 1.31116214, 1.63714073, 1.46722224, 2.04579549, 2.32608004, 1.78748257, 1.86802593,
                2.10754762, 1.70910096, 2.04791726, 1.91072784, 1.92581394, 2.58256346, 2.1324481, 0.88450342,
                2.25258946, 2.15077007, 2.33978932, 2.32637351, 0.77329659, 2.31090614, 1.12911703, 2.11289612,
                1.45396034, 2.02929272, 1.81955277, 0.73616377, 0.91562102, 0.97918172, 1.36387393, 0.82200445,
                2.21816716, 2.26983463, 2.0555581, 1.70486348, 1.42975657, 1.96151767, 2.21193959, 0.46586307,
                0.46478071, 1.59243667, 1.70298342, 1.7806412, 0.63135066, 2.52095548, 1.81997945, 1.210823,
                2.6026903, 2.58550811, 2.02807087, 1.21502588, 1.03827884, 0.47923644, 1.94235239, 2.39078077,
                1.42544878, 0.78684905, 1.50742558, 2.15781815, 2.5164454, 1.93006564, 1.81774451, 1.62710122,
                2.11289612, 1.58247195, 1.52460124, 1.92625489, 2.12969908, 2.39554574, 1.60979498, 2.25995104 };

        expected_log = new double[] { 1.03559669, 0.89372063, 0.90050762, 0.84300897, 0.9789815, 0.56436677,
                0.81650768, 1.06679379, 0.50743028, 0.9753064, 0.79408443, 0.96121468, 0.86146526, 0.29485394,
                0.08234326, -0.12998453, 0.55926429, 1.0317415, 0.36441231, 0.84928972, 0.74026583, 0.89580127,
                0.41880134, 0.80939188, 0.68008878, 0.8378129, 0.97905477, 1.00625736, 0.99167883, 0.91552229,
                0.93204226, 0.78389275, 0.45116974, 0.25946697, 0.9753064, 0.86131853, 0.66060509, 0.9753064,
                0.56547056, 1.05528654, 0.97067675, -0.22223238, 0.6215, 0.77743596, 0.55747307, 0.8600828,
                0.81982349, 0.87357458, 0.86995387, 1.04500234, 0.38849901, 0.78961354, 0.58853231, 0.50020114,
                0.86524693, 0.83595487, 0.78410851, -0.16738653, 0.75828462, 0.358486, -0.26929377, 0.81226625,
                0.30807188, 0.96229558, 0.3617581, 0.61197093, 0.76058045, 0.68703651, 0.5152882, 0.62434188,
                0.77173894, 0.69681628, 0.79172556, 0.83224778, 0.75932531, 0.72296979, 0.64017114, 0.80019562,
                0.95845305, 0.27091387, 0.49295126, 0.38337098, 0.7157867, 0.84418446, 0.58080824, 0.62488222,
                0.74552501, 0.53596748, 0.7168233, 0.64748424, 0.6553487, 0.94878249, 0.75727066, -0.12272889,
                0.81208043, 0.76582595, 0.85006089, 0.84431062, -0.25709261, 0.83763972, 0.12143594, 0.74805957,
                0.3742911, 0.70768732, 0.59859074, -0.30630268, -0.08815273, -0.02103803, 0.31032913, -0.19600947,
                0.79668125, 0.81970698, 0.72054739, 0.53348504, 0.3575042, 0.6737185, 0.79386977, -0.76386354,
                -0.76618957, 0.46526534, 0.53238167, 0.57697352, -0.45989384, 0.92463799, 0.59882521, 0.1913003,
                0.95654564, 0.94992205, 0.70708503, 0.19476537, 0.03756438, -0.7355612, 0.66389981, 0.87162,
                0.3544867, -0.23971885, 0.41040328, 0.76909759, 0.92284735, 0.65755401, 0.59759645, 0.48680004,
                0.74805957, 0.45898815, 0.4217329, 0.65557765, 0.75598069, 0.87361107, 0.47610683, 0.81534315 };

        assertTrue(VecUtils.equalsWithTolerance(exp, expected_exp, 1e-8));
        assertTrue(VecUtils.equalsWithTolerance(log, expected_log, 1e-8));

        /*
         * LINEAR
         */
        exp = k.kernelDensity(IRIS.getData(), h, PartialKernelDensity.LOG_LINEAR, at, rt, false);
        log = k.kernelDensity(IRIS.getData(), h, PartialKernelDensity.LOG_LINEAR, at, rt, true);
        expected_exp = new double[] { 160.9282694, 127.53018651, 130.14235947, 127.06395876, 132.09672819,
                50.65924756, 75.57647598, 179.90963654, 58.10252844, 154.52998575, 98.87756863, 131.51181617,
                129.0150588, 36.08103914, 20.05828275, 20.73257963, 43.47854467, 160.11680757, 26.32721888,
                107.17710239, 58.7694493, 120.17478496, 17.56479497, 63.23096165, 34.44799898, 101.58528645,
                131.42827689, 156.00513214, 146.21446149, 134.0154519, 142.01816604, 77.61149955, 36.76669062,
                32.87159631, 154.52998575, 96.51714225, 54.07440652, 154.52998575, 70.24275874, 173.00716974,
                124.08736307, 16.21138938, 70.20707524, 58.95239275, 31.16071744, 121.51159135, 93.91208357,
                128.14529958, 120.81621787, 174.91680235, 32.73452138, 53.63978727, 41.82357635, 43.89184638,
                53.92778303, 60.38366728, 35.08006041, 32.84501217, 57.6022944, 19.86549055, 22.08598522,
                54.33870995, 16.53892596, 58.11028608, 20.37844064, 52.92242602, 46.16945232, 56.92813616,
                23.84451782, 74.68419345, 39.63789851, 41.95799374, 35.34634228, 38.74351531, 52.36152552,
                63.04288352, 40.34582251, 34.0039662, 59.29449178, 28.4406359, 58.0248369, 44.60769876, 75.06894691,
                46.66720883, 29.77086576, 25.10401887, 64.17972879, 23.84451782, 64.25783636, 69.54657107,
                38.15913986, 62.91758961, 77.74249696, 36.01279683, 87.08271188, 70.25392351, 91.88226395,
                53.6381145, 23.51959171, 95.98586896, 18.66695873, 57.31741601, 27.30430924, 44.31795605,
                44.22810104, 23.84451782, 16.21138938, 25.92314559, 16.21138938, 16.21138938, 35.49515485,
                43.55055473, 54.76906451, 36.93561646, 16.53892596, 39.87542221, 51.72322503, 19.05452461,
                19.05452461, 18.29001715, 58.62218765, 35.26711472, 26.68765305, 56.23803193, 46.55365872,
                27.9276858, 56.76518959, 61.99328508, 52.76666665, 21.19119874, 26.20168155, 19.05452461,
                43.21052816, 35.70999588, 16.21138938, 16.21138938, 32.66915064, 44.19686106, 60.37715009,
                46.61929426, 58.6886308, 30.00580075, 57.31741601, 48.89422033, 42.90669119, 46.59416438,
                35.71995611, 53.43249648, 30.96540812, 51.08482935 };

        expected_log = new double[] { 5.08095873, 4.84835309, 4.86862892, 4.84469057, 4.88353444, 3.92512179,
                4.32514507, 5.19245471, 4.06220918, 5.04038816, 4.5938824, 4.8790967, 4.85992913, 3.5857675,
                2.99864217, 3.03170636, 3.77226759, 5.0759036, 3.27060334, 4.67448263, 4.07362215, 4.78894722,
                2.86589661, 4.14679408, 3.53945091, 4.62089871, 4.87846128, 5.04988891, 4.98507446, 4.89795511,
                4.95595498, 4.35171561, 3.60459229, 3.49260895, 5.04038816, 4.56972063, 3.990361, 5.04038816,
                4.25195722, 5.15333304, 4.82098586, 2.78571404, 4.25144909, 4.07673022, 3.43915825, 4.80000966,
                4.54235906, 4.85316477, 4.79427053, 5.16431045, 3.48843022, 3.98229109, 3.73346021, 3.78172857,
                3.9876458, 4.10071866, 3.55763289, 3.4917999, 4.0535624, 2.98898408, 3.09494325, 3.99523686,
                2.80571675, 4.06234269, 3.01447751, 3.96882718, 3.83231837, 4.0417897, 3.17155433, 4.31326847,
                3.67978569, 3.73666897, 3.56519492, 3.65696339, 3.95817208, 4.14381519, 3.69748786, 3.52647717,
                4.08251641, 3.34781896, 4.06087114, 3.79790646, 4.31840698, 3.84304175, 3.39353026, 3.22302795,
                4.16168741, 3.17155433, 4.16290368, 4.24199662, 3.64176531, 4.14182577, 4.35340204, 3.58387434,
                4.46685838, 4.25211616, 4.52050802, 3.98225991, 3.15783376, 4.56420098, 2.92675505, 4.04860452,
                3.30704454, 3.79138992, 3.78936036, 3.17155433, 2.78571404, 3.25513622, 2.78571404, 2.78571404,
                3.5693962, 3.77392244, 4.00312552, 3.6091763, 2.80571675, 3.68576015, 3.94590691, 2.94730459,
                2.94730459, 2.9063554, 4.07111325, 3.56295094, 3.28420103, 4.02959325, 3.8406056, 3.32961852,
                4.03892328, 4.12702607, 3.96587968, 3.05358594, 3.26582359, 2.94730459, 3.76608417, 3.57543065,
                2.78571404, 2.78571404, 3.48643123, 3.78865377, 4.10061072, 3.8420145, 4.07224602, 3.40139072,
                4.04860452, 3.8896592, 3.75902779, 3.84147531, 3.57570953, 3.97841911, 3.43287071, 3.93348757 };

        assertTrue(VecUtils.equalsWithTolerance(exp, expected_exp, 1e-8));
        assertTrue(VecUtils.equalsWithTolerance(log, expected_log, 1e-8));
        // NOTE(review): with plain IEEE-754 comparison, `x != NaN` is true for every x, so this
        // check only adds value if DoubleSeries/Inequality special-case NaN -- confirm.
        assertTrue(new DoubleSeries(log, Inequality.NOT_EQUAL_TO, Double.NaN).all());

        /*
         * COSINE
         */
        exp = k.kernelDensity(IRIS.getData(), h, PartialKernelDensity.LOG_COSINE, at, rt, false);
        log = k.kernelDensity(IRIS.getData(), h, PartialKernelDensity.LOG_COSINE, at, rt, true);

        // all should be nan...
        // NOTE(review): with plain IEEE-754 comparison, `x == NaN` is false for every x, so these
        // assertions can only pass if DoubleSeries/Inequality special-case NaN -- confirm.
        assertTrue(new DoubleSeries(exp, Inequality.EQUAL_TO, Double.NaN).all());
        assertTrue(new DoubleSeries(log, Inequality.EQUAL_TO, Double.NaN).all());
    }

    // ================== constructor tests
    /**
     * Exercises the {@link KDTree} and {@link BallTree} constructor overloads used in this
     * class, checking the reported leaf size and the (null) logger where an overload sets them.
     */
    @Test
    public void testConst1() {
        final Array2DRowRealMatrix data = new Array2DRowRealMatrix(a);
        // Typed as Loggable so that the (matrix, Loggable) overload is the one selected.
        final Loggable nullLogger = null;

        // KDTree overloads, constructed in the same order as always
        new KDTree(data);
        new KDTree(data, 5);
        final KDTree kdWithLeafSize = new KDTree(data, 5, Distance.EUCLIDEAN);
        assertTrue(5 == kdWithLeafSize.getLeafSize());
        new KDTree(data, Distance.EUCLIDEAN);
        final KDTree kdWithLogger = new KDTree(data, nullLogger);
        assertTrue(null == kdWithLogger.logger);
        new KDTree(data, 5, Distance.EUCLIDEAN, null);

        // BallTree overloads
        new BallTree(data);
        new BallTree(data, 5);
        final BallTree ballWithLeafSize = new BallTree(data, 5, Distance.EUCLIDEAN);
        assertTrue(ballWithLeafSize.getLeafSize() == 5);
        new BallTree(data, Distance.EUCLIDEAN);
        final BallTree ballWithLogger = new BallTree(data, nullLogger);
        assertTrue(null == ballWithLogger.logger);
        new BallTree(data, 5, Distance.EUCLIDEAN, null);
    }

    /**
     * Building a {@link KDTree} with a leaf size of zero is expected to raise
     * an {@link IllegalArgumentException}.
     */
    @Test(expected = IllegalArgumentException.class)
    public void testConstIAE1() {
        final int illegalLeafSize = 0;
        final Array2DRowRealMatrix mat = new Array2DRowRealMatrix(a);
        new KDTree(mat, illegalLeafSize);
    }

    /**
     * Hands a {@link KDTree} an ad-hoc, anonymous {@link DistanceMetric} and
     * asserts that the tree nevertheless reports {@link Distance#EUCLIDEAN}
     * as its metric (presumably the tree falls back when it does not accept
     * the supplied metric -- confirm against the KDTree constructor).
     */
    @Test
    public void testConst2() {
        // Stub metric: every distance is zero, conversions are the identity
        final DistanceMetric stubMetric = new DistanceMetric() {
            private static final long serialVersionUID = 6792348831585297421L;

            @Override
            public String getName() {
                return "Test anonymous DistanceMetric";
            }

            @Override
            public double getP() {
                return 0.0;
            }

            @Override
            public double getDistance(final double[] first, final double[] second) {
                return 0.0;
            }

            @Override
            public double getPartialDistance(final double[] first, final double[] second) {
                return getDistance(first, second);
            }

            @Override
            public double distanceToPartialDistance(double dist) {
                return dist;
            }

            @Override
            public double partialDistanceToDistance(double partial) {
                return partial;
            }
        };

        final Array2DRowRealMatrix mat = new Array2DRowRealMatrix(a);
        final KDTree tree = new KDTree(mat, stubMetric);

        assertTrue(tree.getMetric().equals(Distance.EUCLIDEAN));
    }

    /**
     * Overwrites the value held by {@code md} with {@code x}. Because the
     * holder object itself is mutated, the change is visible to the caller.
     *
     * @param md the mutable holder whose {@code value} field is reassigned
     * @param x  the new value to store
     */
    private static void passByRef(MutableDouble md, double x) {
        md.value = x;
    }

    /**
     * Checks that a {@link MutableDouble} can be updated through a method
     * call, that it compares greater than a smaller value, and that the
     * no-arg constructor starts at zero.
     */
    @Test
    public void testMutableDouble() {
        final MutableDouble holder = new MutableDouble(145d);

        // Mutation performed inside the helper must be visible here
        passByRef(holder, 15d);
        assertTrue(15d == holder.value);

        // 15 compared against 14 yields 1
        assertTrue(1 == holder.compareTo(14d));

        final MutableDouble defaulted = new MutableDouble();
        assertTrue(0d == defaulted.value);
    }

    /**
     * Covers the {@link NodeData} container: default construction, the
     * four-argument constructor, and independence of {@code copy()} from the
     * node it was copied from.
     */
    @Test
    public void testNodeDataContainerClass() {
        // Default constructor: zeroed indices/radius, not a leaf
        final NodeData blank = new NodeData();
        assertTrue(0 == blank.idx_start);
        assertTrue(0 == blank.idx_end);
        assertFalse(blank.is_leaf);
        assertTrue(0.0 == blank.radius);

        // Four-argument constructor populates every field
        final NodeData source = new NodeData(1, 2, true, 5.9);
        assertTrue(1 == source.idx_start);
        assertTrue(2 == source.idx_end);
        assertTrue(source.is_leaf);
        assertTrue(5.9 == source.radius);

        // Mutating the copy must leave the source node untouched
        final NodeData duplicate = source.copy();
        duplicate.idx_start = 15;
        duplicate.idx_end = 67;
        duplicate.is_leaf = false;
        duplicate.radius = 5.6;
        assertTrue(1 == source.start());
        assertTrue(2 == source.end());
        assertTrue(source.isLeaf());
        assertTrue(5.9 == source.radius());

        // toString() must not throw
        source.toString();
    }

    /**
     * Verifies the difference between the plain getters and their
     * {@code *Ref} counterparts on a {@link KDTree}: a value obtained from a
     * plain getter is taken first, the tree state is then modified through
     * the {@code *Ref} getter, and the two must no longer be equal.
     */
    @Test
    public void testGetterRefMutability() {
        final KDTree tree = new KDTree(new Array2DRowRealMatrix(a));

        // Data matrix
        final double[][] dataSnapshot = tree.getData();
        final double[][] liveData = tree.getDataRef();
        liveData[0][0] = 150d;
        assertFalse(MatUtils.equalsExactly(tree.getDataRef(), dataSnapshot));

        // Node bounds
        final double[][][] boundsSnapshot = tree.getNodeBounds();
        final double[][][] liveBounds = tree.getNodeBoundsRef();
        liveBounds[0][0][0] = 150;
        assertFalse(MatUtils.equalsExactly(tree.getNodeBoundsRef()[0], boundsSnapshot[0]));

        // Index array
        final int[] indexSnapshot = tree.getIndexArray();
        final int[] liveIndex = tree.getIndexArrayRef();
        liveIndex[0] = 150;
        assertFalse(VecUtils.equalsExactly(tree.getIndexArrayRef(), indexSnapshot));

        // Node data
        final NodeData[] nodeSnapshot = tree.getNodeData();
        final NodeData[] liveNodes = tree.getNodeDataRef();
        liveNodes[0].idx_end = 150;
        assertFalse(tree.getNodeDataRef()[0].idx_end == nodeSnapshot[0].idx_end);
    }

    /**
     * Checks that the tree's {@code dist}, {@code rDist} and
     * {@code rDistToDist} instance methods agree with the corresponding
     * {@link Distance#EUCLIDEAN} methods, and that the call counter reads 4
     * afterwards and 0 after a reset.
     */
    @Test
    public void testInstanceMethod() {
        final KDTree tree = new KDTree(new Array2DRowRealMatrix(a));

        final double[] p = { 0, 1, 2 };
        final double[] q = { 3, 4, 5 };

        // NOTE: the number of dist/rDist invocations below is deliberate --
        // the counter assertion further down depends on it.
        final double treeDist = tree.dist(p, q);
        assertTrue(treeDist == Distance.EUCLIDEAN.getDistance(p, q));

        final double treeRDist = tree.rDist(p, q);
        assertTrue(treeRDist == Distance.EUCLIDEAN.getPartialDistance(p, q));

        final double viaTree = tree.rDistToDist(tree.rDist(p, q));
        final double viaMetric = Distance.EUCLIDEAN.partialDistanceToDistance(tree.rDist(p, q));
        assertTrue(viaTree == viaMetric);

        // one dist + three rDist calls were made above
        assertTrue(4 == tree.getNumCalls());

        tree.resetNumCalls();
        assertTrue(0 == tree.getNumCalls());
    }

    /**
     * {@code KDTree.findNodeSplitDim} applied to the fixture {@code a} and
     * the index array of a tree built on it is expected to return 2.
     */
    @Test
    public void testNodeFind() {
        final KDTree tree = new KDTree(new Array2DRowRealMatrix(a));
        final int splitDim = KDTree.findNodeSplitDim(a, tree.idx_array);
        assertTrue(2 == splitDim);
    }

    /**
     * {@code KDTree.swap} must exchange the two addressed positions of an
     * int array in place and leave the rest alone.
     */
    @Test
    public void testSwap() {
        final int[] values = { 0, 1, 2 };
        final int[] afterSwap = { 1, 0, 2 };

        KDTree.swap(values, 0, 1);
        assertTrue(VecUtils.equalsExactly(values, afterSwap));
    }

    /**
     * Checks the initial state of a {@link NodeHeap} and that popping an
     * empty heap raises an {@link IllegalStateException}.
     *
     * <p>The exception is checked with an explicit try/catch around the
     * {@code pop()} call only. A method-wide {@code @Test(expected = ...)}
     * would also pass if any of the earlier statements threw the same
     * exception type, silently skipping the assertions that follow it.
     */
    @Test
    public void testNodeHeap1() {
        NodeHeap nh1 = new NodeHeap(0);
        assertTrue(nh1.data.length == 1); // picks max (size, 1)
        nh1 = new NodeHeap(2);
        assertTrue(nh1.data.length == 2);
        assertTrue(nh1.n == 0);
        nh1.clear();
        assertTrue(nh1.n == 0);
        assertTrue(null == nh1.peek());

        try {
            nh1.pop(); // must throw on an empty heap
            fail("Expected IllegalStateException when popping an empty NodeHeap");
        } catch (IllegalStateException expected) {
            // expected: the heap holds no elements
        }
    }

    /**
     * Pushes four entries into a {@link NodeHeap} created with size 3 and
     * checks the resulting backing-array length, the stored order of
     * {@code val}/{@code i1}/{@code i2}, {@code NodeHeapData.equals}, and the
     * layout after popping the smallest entry.
     */
    @Test
    public void testNodeHeapPushesPops() {
        NodeHeap heap = new NodeHeap(3);

        NodeHeapData h = new NodeHeapData(1.0, 0, 0);
        heap.push(new NodeHeapData(12.0, 1, 2));
        heap.push(new NodeHeapData(9.0, 4, 5));
        heap.push(new NodeHeapData(11.0, 9, -1));
        heap.push(h);

        // Four pushes into a size-3 heap: the backing array is expected to be 8 long
        assertTrue(heap.data.length == 8);
        assertTrue(heap.data[0].val == 1.0);
        assertTrue(heap.data[1].val == 9.0);
        assertTrue(heap.data[2].val == 11.0);
        assertTrue(heap.data[3].val == 12.0);

        assertTrue(heap.data[0].i1 == 0);
        assertTrue(heap.data[1].i1 == 4);
        assertTrue(heap.data[2].i1 == 9);
        assertTrue(heap.data[3].i1 == 1);

        assertTrue(heap.data[0].i2 == 0);
        assertTrue(heap.data[1].i2 == 5);
        assertTrue(heap.data[2].i2 == -1);
        assertTrue(heap.data[3].i2 == 2);

        assertTrue(heap.data[0].equals(new NodeHeapData(1.0, 0, 0)));
        assertTrue(heap.data[0].equals(h));
        // equals() must reject an object of an unrelated type. Integer.valueOf
        // replaces the deprecated Integer(int) constructor.
        assertFalse(heap.data[0].equals(Integer.valueOf(1)));
        assertTrue(heap.n == 4);

        // Popping returns the smallest entry and vacates the last used slot
        assertTrue(heap.pop().equals(h));
        assertTrue(heap.data[0].val == 9.0);
        assertTrue(heap.data[1].val == 12.0);
        assertTrue(heap.data[2].val == 11.0);
        assertTrue(null == heap.data[3]);

        // Ensure no NPE
        heap.toString();
    }

    /**
     * {@code Heap.dualSwap} must exchange the same two positions in both the
     * double array and the int array.
     */
    @Test
    public void testDualSwap() {
        final double[] values = { 0, 1, 2 };
        final int[] indices = { 3, 4, 5 };

        Heap.dualSwap(values, indices, 0, 1);

        final double[] swappedValues = { 1, 0, 2 };
        final int[] swappedIndices = { 4, 3, 5 };
        assertTrue(VecUtils.equalsExactly(values, swappedValues));
        assertTrue(VecUtils.equalsExactly(indices, swappedIndices));
    }

    /**
     * Builds a {@link KDTree} on the iris data and compares its
     * {@code idx_array} element-for-element with a fixed reference ordering.
     */
    @Test
    public void testBigKD() {
        final Array2DRowRealMatrix iris = new Array2DRowRealMatrix(IRIS.getData(), false);
        final KDTree tree = new KDTree(iris);

        // Reference index ordering for the 150 iris rows
        final int[] expectedOrder = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
                14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38,
                39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 53, 57, 59, 60, 62, 64, 69, 71, 79, 80, 81, 82, 89, 92,
                93, 98, 99, 88, 67, 61, 94, 95, 96, 74, 97, 90, 87, 65, 75, 106, 86, 68, 54, 55, 73, 91, 56, 63, 78,
                51, 58, 66, 50, 84, 85, 138, 76, 70, 52, 121, 123, 126, 127, 72, 146, 77, 113, 119, 149, 109, 110,
                111, 112, 100, 114, 115, 116, 117, 118, 101, 120, 102, 122, 103, 124, 125, 104, 105, 128, 129, 130,
                131, 132, 133, 134, 135, 136, 137, 83, 139, 140, 141, 142, 143, 144, 145, 107, 147, 148, 108 };

        assertTrue(VecUtils.equalsExactly(tree.idx_array, expectedOrder));
    }

    /**
     * Builds a {@link KDTree} on the small fixture {@code a}, checks its index
     * array and node bounds, then runs a 1-nearest-neighbor query for two
     * points under every combination of the {@code dualTree} and {@code sort}
     * flags, expecting identical results each time.
     */
    @Test
    public void testQuerySmall() {
        final KDTree tree = new KDTree(new Array2DRowRealMatrix(a));

        // Tree structure
        final int[] expectedOrder = { 0, 1, 2 };
        assertTrue(VecUtils.equalsExactly(tree.idx_array, expectedOrder));
        assertTrue(2 == tree.node_bounds.length);
        assertTrue(1 == tree.node_bounds[0].length);
        assertTrue(1 == tree.node_bounds[1].length);
        final double[] firstBound = { 0.0, 0.0, 0.0, 2.0 };
        final double[] secondBound = { 5.0, 6.0, 7.0, 4.0 };
        assertTrue(VecUtils.equalsExactly(tree.node_bounds[0][0], firstBound));
        assertTrue(VecUtils.equalsExactly(tree.node_bounds[1][0], secondBound));

        // Query expectations: each point's nearest neighbor is at distance 0
        final double[][] expectedDists = { { 0.0 }, { 0.0 } };
        final int[][] expectedIndices = { { 0 }, { 1 } };

        final boolean[] flags = { true, false };
        for (final boolean dualTree : flags) {
            for (final boolean sort : flags) {
                // a fresh query matrix is built for every run
                final double[][] points = { { 0, 1, 0, 2 }, { 0, 0, 1, 2 } };
                final Neighborhood result = new Neighborhood(tree.query(points, 1, dualTree, sort));

                assertTrue(MatUtils.equalsExactly(expectedDists, result.getDistances()));
                assertTrue(MatUtils.equalsExactly(expectedIndices, result.getIndices()));
            }
        }
    }

    /**
     * Table-driven check of {@code NeighborsHeap.simultaneous_sort}: for each
     * case the distance array must end up ascending and the index array must
     * have been permuted in lock-step. Covers lengths 10, 2, 1 and 3.
     */
    @Test
    public void testSimultaneousSort() {
        // Inputs (sorted in place by the method under test)
        final double[][] distCases = {
                { 3.69675274, 2.89351805, 1.79065633, 0.44375205, 7.77409946, 7.08011014, 8.41547227, 5.57512117,
                        8.85578907, 2.60367035 },
                { 0.7, 0.1 },
                { 0.7 },
                { 0.7, 0.1, 0.3 },
                { 0.3, 0.7, 0.1 } };
        final int[][] indexCases = {
                { 4, 1, 0, 7, 6, 5, 8, 2, 3, 9 },
                { 2, 1 },
                { 2 },
                { 2, 1, 3 },
                { 2, 1, 3 } };

        // Expected state after sorting
        final double[][] sortedDists = {
                { 0.44375205, 1.79065633, 2.60367035, 2.89351805, 3.69675274, 5.57512117, 7.08011014, 7.77409946,
                        8.41547227, 8.85578907 },
                { 0.1, 0.7 },
                { 0.7 },
                { 0.1, 0.3, 0.7 },
                { 0.1, 0.3, 0.7 } };
        final int[][] sortedIndices = {
                { 7, 0, 9, 1, 4, 2, 5, 6, 8, 3 },
                { 1, 2 },
                { 2 },
                { 1, 3, 2 },
                { 3, 2, 1 } };

        for (int c = 0; c < distCases.length; c++) {
            final double[] dists = distCases[c];
            final int[] indices = indexCases[c];

            NeighborsHeap.simultaneous_sort(dists, indices, dists.length);

            assertTrue(VecUtils.equalsExactly(dists, sortedDists[c]));
            assertTrue(VecUtils.equalsExactly(indices, sortedIndices[c]));
        }
    }

    /**
     * Pushes every cell of a 3x3 matrix into a {@link NeighborsHeap} with
     * {@code k = 1} (row index as heap row, column index as the pushed index)
     * and expects each row to retain its smallest value together with that
     * value's column.
     */
    @Test
    public void testNeighborsHeap() {
        final double[][] values = {
                { 0.15464338, -0.26063195, -0.48111094 },
                { -0.95392127, 0.72765662, 0.46466226 },
                { 0.57011545, -1.53581033, 0.52009414 } };

        final NeighborsHeap heap = new NeighborsHeap(values.length, 1);
        for (int row = 0; row < values.length; row++) {
            for (int col = 0; col < values[0].length; col++) {
                heap.push(row, values[row][col], col);
            }
        }

        final Neighborhood result = new Neighborhood(heap.getArrays(true));
        final double[][] dists = result.getDistances();
        final int[][] inds = result.getIndices();

        final double[][] smallestPerRow = { { -0.48111094 }, { -0.95392127 }, { -1.53581033 } };
        final int[][] columnOfSmallest = { { 2 }, { 0 }, { 1 } };
        assertTrue(MatUtils.equalsExactly(dists, smallestPerRow));
        assertTrue(MatUtils.equalsExactly(inds, columnOfSmallest));
    }

    /**
     * Pushes every cell of a 3x5 matrix into a {@link NeighborsHeap} with
     * {@code k = 3} and expects the sorted output to hold, per row, the three
     * smallest values in ascending order alongside their column indices.
     */
    @Test
    public void testNeighborHeapOrderInPlace() {
        final double[][] values = {
                { 0.15464338, -0.26063195, -0.48111094, 0.0002354, 1.12345 },
                { -0.95392127, 0.72765662, 0.46466226, -0.9128421, 5.12345 },
                { 0.57011545, -1.53581033, 0.52009414, 0.1958271, -4.3918 } };

        final NeighborsHeap heap = new NeighborsHeap(values.length, 3);
        for (int row = 0; row < values.length; row++) {
            for (int col = 0; col < values[0].length; col++) {
                heap.push(row, values[row][col], col);
            }
        }

        final Neighborhood result = new Neighborhood(heap.getArrays(true));
        final double[][] dists = result.getDistances();
        final int[][] inds = result.getIndices();

        final double[][] threeSmallest = {
                { -0.48111094, -0.26063195, 0.0002354 },
                { -0.95392127, -0.9128421, 0.46466226 },
                { -4.3918, -1.53581033, 0.1958271 } };
        final int[][] theirColumns = { { 2, 1, 3 }, { 0, 3, 2 }, { 4, 1, 3 } };

        assertTrue(MatUtils.equalsExactly(dists, threeSmallest));
        assertTrue(MatUtils.equalsExactly(inds, theirColumns));
    }

    /**
     * Runs the push-everything scenario on a 3x2 matrix for {@code k = 1}
     * and then {@code k = 2}, checking the sorted distances and indices
     * returned for each heap size.
     */
    @Test
    public void testNeighborHeapTwoAndLessLen() {
        final double[][] values = {
                { 0.15464338, -0.26063195 },
                { -0.95392127, 0.72765662 },
                { 0.57011545, -1.53581033 } };

        // Expected results, indexed by (k - 1)
        final double[][][] expectedDists = {
                { { -0.26063195 }, { -0.95392127 }, { -1.53581033 } },
                { { -0.26063195, 0.15464338 }, { -0.95392127, 0.72765662 }, { -1.53581033, 0.57011545 } } };
        final int[][][] expectedInds = {
                { { 1 }, { 0 }, { 1 } },
                { { 1, 0 }, { 0, 1 }, { 1, 0 } } };

        for (int k = 1; k <= 2; k++) {
            final NeighborsHeap heap = new NeighborsHeap(values.length, k);
            for (int row = 0; row < values.length; row++) {
                for (int col = 0; col < values[0].length; col++) {
                    heap.push(row, values[row][col], col);
                }
            }

            final Neighborhood result = new Neighborhood(heap.getArrays(true));
            final double[][] dists = result.getDistances();
            final int[][] inds = result.getIndices();

            assertTrue(MatUtils.equalsExactly(dists, expectedDists[k - 1]));
            assertTrue(MatUtils.equalsExactly(inds, expectedInds[k - 1]));
        }
    }

    /**
     * With sorting disabled in {@code getArrays}, checks that
     * {@code heap.largest(row)} equals the maximum of the distances returned
     * for that row.
     */
    @Test
    public void testNeighborHeapNoSortAndLargest() {
        final double[][] values = {
                { 0.15464338, -0.26063195, -0.48111094, 0.0002354, 1.12345 },
                { -0.95392127, 0.72765662, 0.46466226, -0.9128421, 5.12345 },
                { 0.57011545, -1.53581033, 0.52009414, 0.1958271, -4.3918 } };

        final NeighborsHeap heap = new NeighborsHeap(values.length, 3);
        for (int row = 0; row < values.length; row++) {
            for (int col = 0; col < values[0].length; col++) {
                heap.push(row, values[row][col], col);
            }
        }

        final Neighborhood unsorted = new Neighborhood(heap.getArrays(false));
        final double[][] dists = unsorted.getDistances();

        for (int row = 0; row < dists.length; row++) {
            final double rowMax = VecUtils.max(dists[row]);
            // evaluate largest() first, as the left operand
            assertTrue(heap.largest(row) == rowMax);
        }
    }

    /**
     * Checks the relationship between {@code dist}, {@code rDist} and
     * {@code rDistToDist} on a {@link KDTree} built from the iris data.
     *
     * <p>All four comparisons use the same {@code 1e-8} tolerance. Two of
     * them previously relied on exact {@code ==} between floating-point
     * results, which is fragile against last-bit differences and was
     * inconsistent with the tolerant comparisons in the same method.
     */
    @Test
    public void testDistToRDist() {
        final double tol = 1e-8;
        final double[] p = new double[] { 5, 0, 0 };
        final double[] q = new double[] { 0, 0, 1 };
        final KDTree kd = new KDTree(IRIS);

        // Full distance between p and q is sqrt(26)
        assertTrue(Precision.equals(kd.dist(p, q), 5.0990195135927845, tol));
        // Converting the reduced distance back must give the full distance
        assertTrue(Precision.equals(kd.rDistToDist(25.999999999999996), kd.dist(p, q), tol));
        assertTrue(Precision.equals(kd.rDist(p, q), 25.999999999999996, tol));
        assertTrue(Precision.equals(kd.rDistToDist(kd.rDist(p, q)), kd.dist(p, q), tol));
    }

    /**
     * Exercises {@code minRDistDual} between a tree built on the full iris
     * data and a tree built on its first ten rows, for both {@link KDTree}
     * and {@link BallTree}.
     */
    @Test
    public void testMinRDistDual() {
        final Array2DRowRealMatrix full = IRIS;

        // Second data set: the leading ten rows of the array returned by getData()
        final double[][] irisRows = IRIS.getData();
        final double[][] leading = new double[10][];
        for (int r = 0; r < leading.length && r < irisRows.length; r++) {
            leading[r] = irisRows[r];
        }
        final Array2DRowRealMatrix subset = new Array2DRowRealMatrix(leading, false);

        // KD trees
        NearestNeighborHeapSearch fullTree = new KDTree(full);
        NearestNeighborHeapSearch subsetTree = new KDTree(subset);

        double reduced = fullTree.minRDistDual(fullTree, 0, subsetTree, 0);
        assertTrue(reduced == 0.0);
        reduced = fullTree.minRDistDual(fullTree, 2, subsetTree, 0);
        assertTrue(reduced == 7.930000000000001);

        // Ball trees
        fullTree = new BallTree(full);
        subsetTree = new BallTree(subset);

        reduced = fullTree.minRDistDual(fullTree, 0, subsetTree, 0);
        assertTrue(reduced == 0.0);

        // TODO: assertion -- for now this call is only checked for not throwing
        fullTree.minRDistDual(fullTree, 2, subsetTree, 0);
    }

    /**
     * For two query points, checks {@code minRDist} on a {@link KDTree} built
     * from the iris data: node 1 is expected to yield 0 and node 2 is
     * expected to yield 10.000000000000004.
     */
    @Test
    public void testMinRDist() {
        final Array2DRowRealMatrix iris = IRIS;
        final NearestNeighborHeapSearch tree = new KDTree(iris);

        final double[][] points = {
                { 5.1, 3.5, 1.4, 0.2 },
                { 4.9, 3.0, 1.4, 0.2 } };

        for (final double[] pt : points) {
            assertTrue(0 == tree.minRDist(tree, 1, pt));
            assertTrue(10.000000000000004 == tree.minRDist(tree, 2, pt));
        }
    }

    /**
     * Further {@link NodeHeap} checks: the element count must survive
     * resizing up and down, and after ten more pushes with values 10.0 down
     * to 1.0 followed by one pop, ten elements remain with 1.0 at the top.
     */
    @Test
    public void moreNodeHeapTests() {
        final NodeHeap heap = new NodeHeap(10);
        assertTrue(0 == heap.n);

        heap.push(new NodeHeapData());
        assertTrue(1 == heap.n);

        // Resizing in either direction must not change the element count
        heap.resize(15);
        assertTrue(1 == heap.n);
        heap.resize(2);
        assertTrue(1 == heap.n);

        // Push ten entries with values 10.0 down to 1.0; the two int fields
        // come from a fixed-seed generator
        final Random rng = new Random(5);
        for (int step = 0; step < 10; step++) {
            final double val = 10.0 - step;
            final int first = rng.nextInt(5);
            final int second = rng.nextInt(100);
            heap.push(new NodeHeapData(val, first, second));
        }

        assertTrue(11 == heap.n);
        heap.pop();
        assertTrue(10 == heap.n);
        assertTrue(1.0 == heap.peek().val);

        heap.toString(); // Ensure does not throw NPE
    }

    /**
     * Resizing a {@link NodeHeap} to zero is expected to raise an
     * {@link IllegalArgumentException}.
     */
    @Test(expected = IllegalArgumentException.class)
    public void nodeHeapResizeUnder1() {
        final NodeHeap heap = new NodeHeap(10);
        final int illegalSize = 0;
        heap.resize(illegalSize); // expected to throw
    }

    /**
     * Builds a {@link KDTree} on the iris data, compares its first three
     * {@code node_data} entries with reference values, and runs a
     * 3-nearest-neighbor query for the first ten rows in both single-tree
     * and dual-tree mode, expecting the same sorted distances and indices.
     */
    @Test
    public void testQueryBig() {
        final NearestNeighborHeapSearch kdTree = new KDTree(IRIS);

        // Query points: the leading ten rows of the array returned by getData()
        final double[][] irisRows = IRIS.getData();
        final double[][] query = new double[10][];
        for (int r = 0; r < query.length && r < irisRows.length; r++) {
            query[r] = irisRows[r];
        }

        // Node data must match the reference values
        final NodeData[] expectedNodeData = { new NodeData(0, 150, false, 10.29635857961444),
                new NodeData(0, 75, true, 3.5263295365010903), new NodeData(75, 150, true, 4.506106967216822) };

        for (int i = 0; i < expectedNodeData.length; i++) {
            final NodeData actual = kdTree.node_data[i];
            actual.toString(); // Just to make sure toString() doesn't create NPE
            assertTrue(actual.equals(expectedNodeData[i]));
        }

        // Reference neighbor distances and indices for k = 3
        final double[][] expectedDists = { { 0., 0.1, 0.14142136 },
                { 0., 0.14142136, 0.14142136 }, { 0., 0.14142136, 0.24494897 },
                { 0., 0.14142136, 0.17320508 }, { 0., 0.14142136, 0.17320508 },
                { 0., 0.33166248, 0.34641016 }, { 0., 0.2236068, 0.26457513 },
                { 0., 0.1, 0.14142136 }, { 0., 0.14142136, 0.3 },
                { 0., 0., 0. } };

        final int[][] expectedIndices = { { 0, 17, 4 }, { 1, 45, 12 },
                { 2, 47, 3 }, { 3, 47, 29 }, { 4, 0, 17 }, { 5, 18, 10 },
                { 6, 47, 2 }, { 7, 39, 49 }, { 8, 38, 3 }, { 37, 9, 34 } };

        for (final boolean dualTree : new boolean[] { false, true }) {
            final Neighborhood result = kdTree.query(query, 3, dualTree, true);

            assertTrue(MatUtils.equalsWithTolerance(expectedDists, result.getDistances(), 1e-8));
            assertTrue(MatUtils.equalsExactly(expectedIndices, result.getIndices()));
        }
    }

    /**
     * Runs an unsorted radius query (radius 0.8) for the first ten iris rows
     * against a {@link KDTree} built on the full iris data and compares the
     * result with reference distances and indices.
     *
     * <p>The comparison is deliberately loose: the value returned by
     * {@code absDiffInMatrices} for the distances must be below 1, and for
     * each row the value returned by {@code differenceInIdxArrays} must be
     * at most 2. Both helpers are defined elsewhere in this class.
     */
    @Test
    public void testQueryRadiusNoSort() {
        NearestNeighborHeapSearch tree = new KDTree(IRIS);
        double[][] query = new double[10][];

        // Fill the query matrix with the leading rows of the iris data
        int idx = 0;
        for (double[] row : IRIS.getData()) {
            if (idx == query.length)
                break;
            query[idx++] = row; // copied implicitly
        }

        // Reference distances, one (ragged) row per query point
        double[][] expectedNonSortedDists = new double[][] {
                new double[] { 0., 0.53851648, 0.50990195, 0.64807407, 0.14142136, 0.6164414, 0.51961524,
                        0.17320508, 0.46904158, 0.37416574, 0.37416574, 0.59160798, 0.54772256, 0.1, 0.74161985,
                        0.33166248, 0.43588989, 0.3, 0.64807407, 0.46904158, 0.59160798, 0.54772256, 0.31622777,
                        0.14142136, 0.14142136, 0.53851648, 0.53851648, 0.38729833, 0.6244998, 0.46904158,
                        0.37416574, 0.41231056, 0.46904158, 0.14142136, 0.17320508, 0.76811457, 0.45825757,
                        0.6164414, 0.59160798, 0.36055513, 0.58309519, 0.3, 0.2236068 },

                new double[] { 0.53851648, 0., 0.3, 0.33166248, 0.60827625, 0.50990195, 0.42426407, 0.50990195,
                        0.17320508, 0.45825757, 0.14142136, 0.678233, 0.54772256, 0.70710678, 0.76157731,
                        0.78102497, 0.55677644, 0.64807407, 0.2236068, 0.5, 0.59160798, 0.5, 0.34641016, 0.24494897,
                        0.678233, 0.17320508, 0.3, 0.78740079, 0.17320508, 0.50990195, 0.45825757, 0.52915026,
                        0.54772256, 0.678233, 0.14142136, 0.36055513, 0.31622777 },

                new double[] { 0.50990195, 0.3, 0., 0.24494897, 0.50990195, 0.26457513, 0.41231056, 0.43588989,
                        0.31622777, 0.37416574, 0.26457513, 0.5, 0.51961524, 0.75498344, 0.7, 0.50990195,
                        0.64807407, 0.64031242, 0.46904158, 0.50990195, 0.6164414, 0.54772256, 0.3, 0.33166248,
                        0.78102497, 0.31622777, 0.31622777, 0.31622777, 0.36055513, 0.48989795, 0.43588989, 0.3,
                        0.65574385, 0.26457513, 0.78102497, 0.14142136, 0.33166248 },

                new double[] { 0.64807407, 0.33166248, 0.24494897, 0., 0.64807407, 0.33166248, 0.5, 0.3, 0.31622777,
                        0.37416574, 0.26457513, 0.51961524, 0.65574385, 0.70710678, 0.64807407, 0.53851648,
                        0.42426407, 0.54772256, 0.72111026, 0.678233, 0.17320508, 0.2236068, 0.31622777, 0.50990195,
                        0.31622777, 0.3, 0.58309519, 0.60827625, 0.3, 0.7, 0.26457513, 0.14142136, 0.45825757 },

                new double[] { 0.14142136, 0.60827625, 0.50990195, 0.64807407, 0., 0.6164414, 0.45825757, 0.2236068,
                        0.52915026, 0.42426407, 0.34641016, 0.64031242, 0.54772256, 0.17320508, 0.79372539,
                        0.26457513, 0.53851648, 0.26457513, 0.56568542, 0.52915026, 0.57445626, 0.63245553,
                        0.34641016, 0.24494897, 0.28284271, 0.53851648, 0.57445626, 0.5, 0.55677644, 0.78102497,
                        0.52915026, 0.4472136, 0.51961524, 0.52915026, 0.24494897, 0.17320508, 0.72801099,
                        0.45825757, 0.58309519, 0.64031242, 0.3, 0.56568542, 0.33166248, 0.3 },

                new double[] { 0.6164414, 0.6164414, 0., 0.7, 0.34641016, 0.678233, 0.6164414, 0.4, 0.59160798,
                        0.33166248, 0.38729833, 0.53851648, 0.41231056, 0.678233, 0.64807407, 0.52915026,
                        0.64807407, 0.53851648, 0.45825757, 0.47958315, 0.60827625, 0.64807407, 0.7, 0.60827625,
                        0.37416574, 0.38729833, 0.36055513 },

                new double[] { 0.51961524, 0.50990195, 0.26457513, 0.33166248, 0.45825757, 0., 0.42426407,
                        0.54772256, 0.47958315, 0.3, 0.48989795, 0.6164414, 0.50990195, 0.64807407, 0.6, 0.45825757,
                        0.6244998, 0.54772256, 0.60827625, 0.45825757, 0.6244998, 0.60827625, 0.31622777,
                        0.42426407, 0.47958315, 0.5, 0.47958315, 0.46904158, 0.51961524, 0.42426407, 0.31622777,
                        0.54772256, 0.4472136, 0.678233, 0.2236068, 0.77459667, 0.42426407 },

                new double[] { 0.17320508, 0.42426407, 0.41231056, 0.5, 0.2236068, 0.7, 0.42426407, 0., 0.78740079,
                        0.33166248, 0.5, 0.2236068, 0.46904158, 0.7, 0.2, 0.42426407, 0.4472136, 0.37416574,
                        0.67082039, 0.38729833, 0.4472136, 0.41231056, 0.2236068, 0.2236068, 0.2236068, 0.37416574,
                        0.37416574, 0.4472136, 0.73484692, 0.33166248, 0.36055513, 0.54772256, 0.33166248,
                        0.74833148, 0.1, 0.24494897, 0.66332496, 0.42426407, 0.60827625, 0.46904158, 0.42426407,
                        0.45825757, 0.42426407, 0.14142136 },

                new double[] { 0.50990195, 0.43588989, 0.3, 0.54772256, 0.78740079, 0., 0.55677644, 0.67082039,
                        0.42426407, 0.34641016, 0.64031242, 0.46904158, 0.48989795, 0.55677644, 0.7, 0.55677644,
                        0.14142136, 0.6244998, 0.31622777, 0.42426407, 0.36055513, 0.72111026 },

                new double[] { 0.46904158, 0.17320508, 0.31622777, 0.31622777, 0.52915026, 0.47958315, 0.33166248,
                        0.55677644, 0., 0.78740079, 0.34641016, 0.17320508, 0.72801099, 0.5, 0.75498344, 0.6244998,
                        0.7, 0.77459667, 0.52915026, 0.51961524, 0.2, 0.4472136, 0.50990195, 0.4472136, 0.26457513,
                        0.17320508, 0.65574385, 0., 0.34641016, 0.75498344, 0., 0.55677644, 0.37416574, 0.5,
                        0.55677644, 0.65574385, 0.26457513, 0.74161985, 0.34641016, 0.72801099, 0.26457513 } };

        // Reference neighbor indices, one (ragged) row per query point
        int[][] expectedNonSortedIndices = new int[][] {
                new int[] { 0, 1, 2, 3, 4, 5, 6, 7, 9, 10, 11, 12, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27,
                        28, 29, 30, 31, 32, 34, 35, 36, 37, 39, 40, 42, 43, 44, 45, 46, 47, 48, 49 },

                new int[] { 0, 1, 2, 3, 4, 6, 7, 8, 9, 11, 12, 13, 17, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
                        31, 34, 35, 36, 37, 38, 39, 40, 42, 43, 45, 47, 49 },

                new int[] { 0, 1, 2, 3, 4, 6, 7, 8, 9, 11, 12, 13, 17, 19, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
                        31, 34, 35, 37, 38, 39, 40, 42, 43, 45, 46, 47, 49 },

                new int[] { 0, 1, 2, 3, 4, 6, 7, 8, 9, 11, 12, 13, 17, 22, 23, 24, 25, 26, 27, 28, 29, 30, 34, 35,
                        37, 38, 39, 40, 42, 43, 45, 47, 49 },

                new int[] { 0, 1, 2, 3, 4, 5, 6, 7, 9, 10, 11, 12, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27,
                        28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 39, 40, 42, 43, 44, 45, 46, 47, 48, 49 },

                new int[] { 0, 4, 5, 7, 10, 14, 15, 16, 17, 18, 19, 20, 21, 23, 26, 27, 28, 31, 32, 33, 36, 39, 40,
                        43, 44, 46, 48 },

                new int[] { 0, 1, 2, 3, 4, 6, 7, 8, 9, 11, 12, 13, 17, 19, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
                        34, 35, 37, 38, 39, 40, 42, 43, 45, 46, 47, 48, 49 },

                new int[] { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 16, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27,
                        28, 29, 30, 31, 32, 34, 35, 36, 37, 38, 39, 40, 42, 43, 44, 45, 46, 47, 48, 49 },

                new int[] { 1, 2, 3, 6, 7, 8, 9, 11, 12, 13, 25, 29, 30, 34, 35, 37, 38, 41, 42, 45, 47, 49 },

                new int[] { 0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
                        29, 30, 31, 34, 35, 36, 37, 38, 39, 40, 42, 43, 45, 46, 47, 48, 49 } };

        // Radius query with sorting disabled (third argument false)
        Neighborhood neighb = tree.queryRadius(query, 0.8, false);

        // Just want to know that the total diff in matrices generated from sklearn and clust4j
        // is less than some arbitrarily low number, say one (rounding error).
        assertTrue(absDiffInMatrices(expectedNonSortedDists, neighb.getDistances()) < 1);

        // Indices are compared row by row through the helper, allowing a value of up to 2
        int[][] indices = neighb.getIndices();
        for (int i = 0; i < expectedNonSortedIndices.length; i++)
            assertTrue(differenceInIdxArrays(expectedNonSortedIndices[i], indices[i]) <= 2);
    }

    @Test
    public void testQueryRadiusWithSort() {
        NearestNeighborHeapSearch tree = new KDTree(IRIS);
        double[][] query = new double[10][];

        int idx = 0;
        for (double[] row : IRIS.getData()) {
            if (idx == query.length)
                break;
            query[idx++] = row; // copied implicitly
        }

        double[][] expectedSortedDists = new double[][] {
                new double[] { 0., 0.1, 0.14142136, 0.14142136, 0.14142136, 0.14142136, 0.17320508, 0.17320508,
                        0.2236068, 0.3, 0.3, 0.31622777, 0.33166248, 0.36055513, 0.37416574, 0.37416574, 0.37416574,
                        0.38729833, 0.41231056, 0.43588989, 0.45825757, 0.46904158, 0.46904158, 0.46904158,
                        0.46904158, 0.50990195, 0.51961524, 0.53851648, 0.53851648, 0.53851648, 0.54772256,
                        0.54772256, 0.58309519, 0.59160798, 0.59160798, 0.59160798, 0.6164414, 0.6164414, 0.6244998,
                        0.64807407, 0.64807407, 0.74161985, 0.76811457 },
                new double[] { 0., 0.14142136, 0.14142136, 0.17320508, 0.17320508, 0.17320508, 0.2236068,
                        0.24494897, 0.3, 0.3, 0.31622777, 0.33166248, 0.34641016, 0.36055513, 0.42426407,
                        0.45825757, 0.45825757, 0.5, 0.5, 0.50990195, 0.50990195, 0.50990195, 0.52915026,
                        0.53851648, 0.54772256, 0.54772256, 0.55677644, 0.59160798, 0.60827625, 0.64807407,
                        0.678233, 0.678233, 0.678233, 0.70710678, 0.76157731, 0.78102497, 0.78740079 },
                new double[] { 0., 0.14142136, 0.24494897, 0.26457513, 0.26457513, 0.26457513, 0.3, 0.3, 0.3,
                        0.31622777, 0.31622777, 0.31622777, 0.31622777, 0.33166248, 0.33166248, 0.36055513,
                        0.37416574, 0.41231056, 0.43588989, 0.43588989, 0.46904158, 0.48989795, 0.5, 0.50990195,
                        0.50990195, 0.50990195, 0.50990195, 0.51961524, 0.54772256, 0.6164414, 0.64031242,
                        0.64807407, 0.65574385, 0.7, 0.75498344, 0.78102497, 0.78102497 },
                new double[] { 0., 0.14142136, 0.17320508, 0.2236068, 0.24494897, 0.26457513, 0.26457513, 0.3, 0.3,
                        0.3, 0.31622777, 0.31622777, 0.31622777, 0.33166248, 0.33166248, 0.37416574, 0.42426407,
                        0.45825757, 0.5, 0.50990195, 0.51961524, 0.53851648, 0.54772256, 0.58309519, 0.60827625,
                        0.64807407, 0.64807407, 0.64807407, 0.65574385, 0.678233, 0.7, 0.70710678, 0.72111026 },
                new double[] { 0., 0.14142136, 0.17320508, 0.17320508, 0.2236068, 0.24494897, 0.24494897,
                        0.26457513, 0.26457513, 0.28284271, 0.3, 0.3, 0.33166248, 0.34641016, 0.34641016,
                        0.42426407, 0.4472136, 0.45825757, 0.45825757, 0.5, 0.50990195, 0.51961524, 0.52915026,
                        0.52915026, 0.52915026, 0.52915026, 0.53851648, 0.53851648, 0.54772256, 0.55677644,
                        0.56568542, 0.56568542, 0.57445626, 0.57445626, 0.58309519, 0.60827625, 0.6164414,
                        0.63245553, 0.64031242, 0.64031242, 0.64807407, 0.72801099, 0.78102497, 0.79372539 },
                new double[] { 0., 0.33166248, 0.34641016, 0.36055513, 0.37416574, 0.38729833, 0.38729833, 0.4,
                        0.41231056, 0.45825757, 0.47958315, 0.52915026, 0.53851648, 0.53851648, 0.59160798,
                        0.60827625, 0.60827625, 0.6164414, 0.6164414, 0.6164414, 0.64807407, 0.64807407, 0.64807407,
                        0.678233, 0.678233, 0.7, 0.7 },
                new double[] { 0., 0.2236068, 0.26457513, 0.3, 0.31622777, 0.31622777, 0.33166248, 0.42426407,
                        0.42426407, 0.42426407, 0.42426407, 0.4472136, 0.45825757, 0.45825757, 0.45825757,
                        0.46904158, 0.47958315, 0.47958315, 0.47958315, 0.48989795, 0.5, 0.50990195, 0.50990195,
                        0.51961524, 0.51961524, 0.54772256, 0.54772256, 0.54772256, 0.6, 0.60827625, 0.60827625,
                        0.6164414, 0.6244998, 0.6244998, 0.64807407, 0.678233, 0.77459667 },
                new double[] { 0., 0.1, 0.14142136, 0.17320508, 0.2, 0.2236068, 0.2236068, 0.2236068, 0.2236068,
                        0.2236068, 0.24494897, 0.33166248, 0.33166248, 0.33166248, 0.36055513, 0.37416574,
                        0.37416574, 0.37416574, 0.38729833, 0.41231056, 0.41231056, 0.42426407, 0.42426407,
                        0.42426407, 0.42426407, 0.42426407, 0.42426407, 0.4472136, 0.4472136, 0.4472136, 0.45825757,
                        0.46904158, 0.46904158, 0.5, 0.5, 0.54772256, 0.60827625, 0.66332496, 0.67082039, 0.7, 0.7,
                        0.73484692, 0.74833148, 0.78740079 },
                new double[] { 0., 0.14142136, 0.3, 0.31622777, 0.34641016, 0.36055513, 0.42426407, 0.42426407,
                        0.43588989, 0.46904158, 0.48989795, 0.50990195, 0.54772256, 0.55677644, 0.55677644,
                        0.55677644, 0.6244998, 0.64031242, 0.67082039, 0.7, 0.72111026, 0.78740079 },
                new double[] { 0., 0., 0., 0.17320508, 0.17320508, 0.17320508, 0.2, 0.26457513, 0.26457513,
                        0.26457513, 0.31622777, 0.31622777, 0.33166248, 0.34641016, 0.34641016, 0.34641016,
                        0.37416574, 0.4472136, 0.4472136, 0.46904158, 0.47958315, 0.5, 0.5, 0.50990195, 0.51961524,
                        0.52915026, 0.52915026, 0.55677644, 0.55677644, 0.55677644, 0.6244998, 0.65574385,
                        0.65574385, 0.7, 0.72801099, 0.72801099, 0.74161985, 0.75498344, 0.75498344, 0.77459667,
                        0.78740079 } };

        int[][] expectedSortedIndices = new int[][] {
                new int[] { 0, 17, 4, 39, 27, 28, 40, 7, 49, 21, 48, 26, 19, 46, 35, 11, 10, 31, 36, 20, 43, 9, 34,
                        37, 23, 2, 6, 29, 1, 30, 25, 16, 47, 24, 12, 45, 44, 5, 32, 3, 22, 18, 42 },
                new int[] { 1, 45, 12, 37, 34, 9, 25, 30, 35, 2, 49, 3, 29, 47, 7, 39, 11, 28, 26, 38, 8, 6, 40, 0,
                        17, 42, 23, 27, 4, 24, 31, 43, 13, 20, 21, 22, 36 },
                new int[] { 2, 47, 3, 45, 12, 6, 42, 29, 1, 35, 37, 34, 9, 49, 30, 38, 11, 7, 40, 8, 25, 39, 13, 0,
                        26, 4, 22, 17, 28, 27, 24, 23, 43, 21, 19, 46, 31 },
                new int[] { 3, 47, 29, 30, 2, 12, 45, 42, 38, 8, 9, 34, 37, 6, 1, 11, 25, 49, 7, 35, 13, 24, 26, 39,
                        40, 0, 23, 4, 17, 28, 43, 22, 27 },
                new int[] { 4, 0, 17, 40, 7, 39, 27, 19, 21, 28, 46, 49, 48, 26, 11, 10, 35, 43, 6, 31, 2, 36, 34,
                        37, 9, 23, 29, 20, 16, 32, 22, 47, 24, 30, 44, 1, 5, 25, 45, 12, 3, 42, 33, 18 },
                new int[] { 5, 18, 10, 48, 44, 46, 19, 16, 21, 32, 33, 27, 31, 20, 17, 36, 43, 0, 15, 4, 28, 26, 39,
                        14, 23, 40, 7 },
                new int[] { 6, 47, 2, 11, 42, 29, 3, 30, 49, 7, 40, 45, 22, 4, 26, 38, 37, 9, 34, 12, 35, 17, 1, 0,
                        39, 8, 24, 43, 21, 25, 28, 13, 23, 27, 19, 46, 48 },
                new int[] { 7, 39, 49, 0, 17, 26, 28, 27, 11, 4, 40, 9, 34, 37, 35, 29, 30, 21, 23, 2, 25, 19, 1,
                        46, 48, 43, 6, 24, 20, 31, 47, 45, 12, 3, 10, 36, 44, 42, 22, 5, 16, 32, 38, 8 },
                new int[] { 8, 38, 3, 42, 13, 47, 12, 45, 2, 29, 30, 1, 6, 9, 37, 34, 41, 25, 11, 35, 49, 7 },
                new int[] { 34, 37, 9, 1, 30, 12, 25, 49, 29, 45, 2, 3, 7, 35, 11, 47, 39, 28, 26, 0, 6, 17, 40, 27,
                        24, 23, 4, 42, 38, 8, 20, 43, 31, 21, 48, 13, 46, 19, 36, 22, 10 } };

        Neighborhood neighb = tree.queryRadius(query, 0.8, true);

        // ensure doesn't throw NPE
        assertTrue(null != neighb.toString());

        // ensure doesn't throw NPE
        assertTrue(null != neighb.copy());

        // Just want to know that the total diff in matrices generated from sklearn and clust4j
        // is less than some arbitrarily low number, say one (rounding error).
        assertTrue(absDiffInMatrices(expectedSortedDists, neighb.getDistances()) < 1);

        int[][] indices = neighb.getIndices();
        for (int i = 0; i < expectedSortedIndices.length; i++)
            assertTrue(differenceInIdxArrays(expectedSortedIndices[i], indices[i]) <= 2);
    }

    /**
     * Compares two jagged matrices by their total absolute mass. For every row index
     * of {@code expected}, the absolute values of that row in both matrices are summed
     * into two running totals; the result is the absolute difference of those totals.
     * Note that this is NOT an element-wise difference, and rows of {@code actual}
     * beyond {@code expected.length} are never read.
     *
     * @param expected the reference matrix; its row count drives the iteration
     * @param actual the matrix under test; must have at least as many rows as {@code expected}
     * @return {@code |sum(|expected|) - sum(|actual|)|}
     */
    private static double absDiffInMatrices(double[][] expected, double[][] actual) {
        double expectedTotal = 0;
        double actualTotal = 0;

        int row = 0;
        while (row < expected.length) {
            expectedTotal += VecUtils.sum(VecUtils.abs(expected[row]));
            actualTotal += VecUtils.sum(VecUtils.abs(actual[row]));
            row++;
        }

        final double gap = expectedTotal - actualTotal;
        return FastMath.abs(gap);
    }

    /**
     * Counts the entries of the longer index array that do not occur anywhere in the
     * shorter one. When both arrays have the same length, {@code actual} plays the
     * role of the longer array. Entries that appear only in the shorter array are
     * not counted, and every occurrence of a shared value is removed (duplicates
     * included).
     *
     * @param expected the reference indices
     * @param actual the indices under test
     * @return the number of entries left in the longer array after removing every
     *         value that is present in the shorter array
     */
    private static int differenceInIdxArrays(int[] expected, int[] actual) {
        final ArrayList<Integer> expectedIdcs = new ArrayList<Integer>(expected.length);
        for (int idx : expected)
            expectedIdcs.add(idx);

        final ArrayList<Integer> actualIdcs = new ArrayList<Integer>(actual.length);
        for (int idx : actual)
            actualIdcs.add(idx);

        // Decide the roles once, by size. The previous code picked the "smaller" list
        // via List.equals (content equality), which only produced the right answer
        // because a list ended up removing itself when both had equal content.
        final boolean expectedIsLonger = expectedIdcs.size() > actualIdcs.size();
        final ArrayList<Integer> larger = expectedIsLonger ? expectedIdcs : actualIdcs;
        final ArrayList<Integer> smaller = expectedIsLonger ? actualIdcs : expectedIdcs;

        larger.removeAll(smaller);
        return larger.size();
    }

    /**
     * Increments the wrapped value of the given {@link MutableDouble} in place, so
     * the caller observes the change through its own reference.
     *
     * @param d the holder whose {@code value} is bumped by one
     */
    private void addOne(MutableDouble d) {
        d.value += 1;
    }

    /**
     * A default-constructed {@link MutableDouble} passed through {@code addOne}
     * must hold 1 afterwards, i.e. the mutation is visible to the caller.
     */
    @Test
    public void testMutDouble2() {
        final MutableDouble holder = new MutableDouble();
        addOne(holder);

        final boolean incremented = holder.value == 1;
        assertTrue(incremented);
    }

    /**
     * Exercises {@code twoPointCorrelation} on a KDTree built from IRIS: for a set of
     * radii, the single-tree and dual-tree variants must agree with each other and
     * with a fixed expected count. Then runs the full data set through the remaining
     * overloads (no exception expected) and finally checks a BallTree with a
     * negative radius.
     */
    @Test
    public void testTwoPointCorrelation() {
        NearestNeighborHeapSearch tree = new KDTree(IRIS);

        // The query is the first rows of IRIS (at most nQuery of them).
        final int nQuery = 10;
        final double[][] query = new double[nQuery][];
        final double[][] irisRows = IRIS.getData(); // copied implicitly
        for (int i = 0; i < nQuery && i < irisRows.length; i++)
            query[i] = irisRows[i];

        // Radius -> count expected in every slot of the result.
        final double[] radii = new double[] { 2.5, 1.5, 25, 0, -1 };
        final int[] expectedCounts = new int[] { 542, 489, 1500, 12, 0 };

        for (int i = 0; i < radii.length; i++) {
            final int[] corSingle = tree.twoPointCorrelation(query, radii[i], false);
            final int[] corDual = tree.twoPointCorrelation(query, radii[i], true);

            // Single- and dual-tree traversals must agree...
            assertTrue(VecUtils.equalsExactly(corSingle, corDual));
            // ...and match the reference count.
            assertTrue(VecUtils.equalsExactly(corSingle, VecUtils.repInt(expectedCounts[i], nQuery)));
        }

        // Test a big query now, just to ensure no exceptions are thrown...
        final double[][] X = IRIS.getData();
        tree.twoPointCorrelation(X, -1, false);
        tree.twoPointCorrelation(X, -1, true);
        tree.twoPointCorrelation(X, -1.0);
        tree.twoPointCorrelation(X, new double[] { 1, 2 });

        // Make it a ball tree, now.
        final KMeans loggable = new KMeans(IRIS, 3); // don't fit, just using the logger...
        tree = new BallTree(IRIS, Distance.EUCLIDEAN, loggable);

        final int[] ballCorr = tree.twoPointCorrelation(X, -1, false);
        final int[] allZero = VecUtils.repInt(0, X.length);
        assertTrue(VecUtils.equalsExactly(ballCorr, allZero));
    }

    /**
     * A two-column query with an array of radii against the IRIS KDTree must raise
     * a {@link DimensionMismatchException}.
     */
    @Test(expected = DimensionMismatchException.class)
    public void testTwoPointCorrelationExcept1() {
        final NearestNeighborHeapSearch kd = new KDTree(IRIS);
        final double[][] twoColumnQuery = new double[][] { { 1, 2 } };
        final double[] radii = new double[] { 1.5 };

        kd.twoPointCorrelation(twoColumnQuery, radii);
    }

    /**
     * A two-column query with a scalar radius against the IRIS KDTree must raise
     * a {@link DimensionMismatchException}.
     */
    @Test(expected = DimensionMismatchException.class)
    public void testTwoPointCorrelationExcept2() {
        final NearestNeighborHeapSearch kd = new KDTree(IRIS);
        final double[][] twoColumnQuery = new double[][] { { 1, 2 } };

        kd.twoPointCorrelation(twoColumnQuery, 1.5);
    }

    /**
     * A radius query whose single row has only two columns must raise a
     * {@link DimensionMismatchException} on the IRIS KDTree.
     */
    @Test(expected = DimensionMismatchException.class)
    public void radiusQueryTestDimException() {
        final NearestNeighborHeapSearch kd = new KDTree(IRIS);
        final double[][] twoColumnQuery = new double[][] { { 1, 2 } };

        kd.queryRadius(twoColumnQuery, 150.0, true);
    }

    /**
     * Smoke test: a four-column query with a very large radius (150) must complete
     * without throwing. The result itself is not inspected.
     */
    @Test
    public void radiusQueryTestAllInRadius() {
        final NearestNeighborHeapSearch kd = new KDTree(IRIS);
        final double[][] singlePoint = new double[][] { { 2.5, 2.5, 2.5, 2.5 } };

        kd.queryRadius(singlePoint, 150.0, true);
    }

    /**
     * One query row paired with five radii must raise a
     * {@link DimensionMismatchException} (presumably because the number of radii
     * has to match the number of query rows).
     */
    @Test(expected = DimensionMismatchException.class)
    public void radiusQueryTestMPrimeDimMismatch1() {
        final NearestNeighborHeapSearch kd = new KDTree(IRIS);
        final double[][] singlePoint = new double[][] { { 2.5, 2.5, 2.5, 2.5 } };
        final double[] fiveRadii = new double[] { 1, 2, 3, 4, 5 };

        kd.queryRadius(singlePoint, fiveRadii, true);
    }

    /**
     * A three-column query row with a single matching radius must raise a
     * {@link DimensionMismatchException} on the IRIS KDTree.
     */
    @Test(expected = DimensionMismatchException.class)
    public void radiusQueryTestNDimMismatch2() {
        final NearestNeighborHeapSearch kd = new KDTree(IRIS);
        final double[][] threeColumnQuery = new double[][] { { 2.5, 2.5, 2.5 } };
        final double[] oneRadius = new double[] { 5 };

        kd.queryRadius(threeColumnQuery, oneRadius, true);
    }

    /**
     * A k-neighbors query whose single row has only two columns must raise a
     * {@link DimensionMismatchException} on the IRIS KDTree.
     */
    @Test(expected = DimensionMismatchException.class)
    public void queryNDimMismatch1() {
        final NearestNeighborHeapSearch kd = new KDTree(IRIS);
        final double[][] twoColumnQuery = new double[][] { { 1, 2 } };

        kd.query(twoColumnQuery, 2, true, true);
    }

    /**
     * A kernel density request whose single query row has only one column must
     * raise a {@link DimensionMismatchException} on the IRIS KDTree.
     */
    @Test(expected = DimensionMismatchException.class)
    public void testKernelDimMismatch() {
        final NearestNeighborHeapSearch kd = new KDTree(IRIS);
        final double[][] oneColumnQuery = new double[][] { { 1.0 } };

        kd.kernelDensity(oneColumnQuery, 1.0, PartialKernelDensity.LOG_COSINE, 0.0, 1e-8, false);
    }

    /**
     * Checks {@code NodeData.equals}: an instance equals itself, differs from an
     * instance built with other constructor arguments, and differs from an object
     * of an unrelated type.
     */
    @Test
    public void testNodeDataEquals() {
        final NodeData defaults = new NodeData();
        final NodeData populated = new NodeData(1, 2, true, 1.9);

        // Reflexive
        assertTrue(defaults.equals(defaults));

        // Different contents
        assertFalse(defaults.equals(populated));

        // Different type altogether
        final Object notNodeData = new String();
        assertFalse(defaults.equals(notNodeData));
    }

    /**
     * Builds a Chebyshev KDTree over a 3x3 matrix and verifies that the default
     * query matches the explicit {@code (k=1, dualTree=false, sort=true)} call, that
     * each row's nearest neighbor is itself at distance zero, that the two-point
     * correlation for radii 1, 2, 3 yields {3, 3, 7}, and that the tree flags the
     * metric as an infinity distance.
     */
    @Test
    public void testInfDist() {
        final double[][] grid = MatUtils.reshape(new double[] { 1, 2, 3, 4, 5, 6, 7, 8, 9 }, 3, 3);
        final Array2DRowRealMatrix mat = new Array2DRowRealMatrix(grid, false);

        final KDTree chebyTree = new KDTree(mat, Distance.CHEBYSHEV);
        final Neighborhood viaDefaults = chebyTree.query(mat.getDataRef());
        final Neighborhood viaExplicit = chebyTree.query(mat.getDataRef(), 1, false, true);

        // equals(): same result, reflexive, and false for a foreign type
        assertTrue(viaDefaults.equals(viaExplicit));
        assertTrue(viaDefaults.equals(viaDefaults));
        assertFalse(viaDefaults.equals("asdf"));

        // Every row should find itself at distance zero
        final double[][] zeroDists = new double[][] { { 0.0 }, { 0.0 }, { 0.0 } };
        final int[][] selfIdcs = new int[][] { { 0 }, { 1 }, { 2 } };
        final Neighborhood res = new Neighborhood(zeroDists, selfIdcs);
        assertTrue(viaDefaults.equals(res));

        final double[] radii = new double[] { 1, 2, 3 };
        final int[] corr = chebyTree.twoPointCorrelation(mat.getDataRef(), radii);
        final int[] expectedCorr = new int[] { 3, 3, 7 };
        assertTrue(VecUtils.equalsExactly(corr, expectedCorr));

        assertTrue(chebyTree.infinity_dist);
    }

    /**
     * Constructing a KDTree with {@code Distance.HAVERSINE.MI} and a KMeans
     * instance as its logger must leave that logger with at least one warning.
     */
    @Test
    public void testWarn() {
        final double[][] grid = MatUtils.reshape(new double[] { 1, 2, 3, 4, 5, 6, 7, 8, 9 }, 3, 3);
        final Array2DRowRealMatrix mat = new Array2DRowRealMatrix(grid, false);

        final KDTree haversineTree = new KDTree(mat, Distance.HAVERSINE.MI, new KMeans(mat, 1));

        final boolean warned = haversineTree.logger.hasWarnings();
        assertTrue(warned);
    }

    /**
     * Querying a KDTree must not modify the array handed to it: the input is
     * snapshotted, wrapped by reference in a matrix, queried, and then compared
     * against the snapshot.
     */
    @Test
    public void testImmutability() {
        final double[][] data = MatUtils.reshape(new double[] { 1, 2, 3, 4, 5, 6, 7, 8, 9 }, 3, 3);
        final double[][] snapshot = MatUtils.copy(data);

        // 'false' => the matrix references 'data' rather than copying it
        final Array2DRowRealMatrix mat = new Array2DRowRealMatrix(data, false);

        final KDTree tree = new KDTree(mat, Distance.EUCLIDEAN);
        tree.query(data);

        // assert immutability
        final boolean untouched = MatUtils.equalsExactly(snapshot, data);
        assertTrue(untouched);
    }

    /**
     * End-to-end check of the Chebyshev (infinity) distance on IRIS for both
     * KDTree and BallTree: nearest-neighbor queries, radius queries with a tiny
     * radius, and two-point correlation.
     *
     * The expected results are generated from two groups of rows that the reference
     * output treats as coincident (zero distance apart): {9, 34, 37} and {101, 142}.
     * Every other row is expected to match only itself.
     */
    @Test
    public void testInfDistQuery() {
        // Number of rows in the expected results (one per query row).
        final int nRows = 150;

        // Rows that the expected output reports at distance zero from one another.
        final int[][] coincident = new int[][] { { 9, 34, 37 }, { 101, 142 } };

        // Nearest-neighbor index per row: itself, or the first member of its group.
        final int[] expected = new int[nRows];
        // Radius-query indices per row: itself, or every member of its group.
        final int[][] expected_idcs = new int[nRows][];
        for (int i = 0; i < nRows; i++) {
            expected[i] = i;
            expected_idcs[i] = new int[] { i };
        }
        for (int[] group : coincident) {
            for (int member : group) {
                expected[member] = group[0];
                expected_idcs[member] = group.clone();
            }
        }

        // Radius-query distances: all zeros, shaped like the index rows.
        final double[][] expected_dists = new double[nRows][];
        for (int i = 0; i < nRows; i++)
            expected_dists[i] = new double[expected_idcs[i].length];

        /*
         * Test KDTree query
         */
        final KDTree kd = new KDTree(IRIS, Distance.CHEBYSHEV);
        Neighborhood result = kd.query(IRIS.getData());
        double[][] dists = result.getDistances();

        // Assert all are zero:
        assertTrue(new MatUtils.MatSeries(dists, Inequality.EQUAL_TO, 0.0).all());

        // Assert the indices equal to expected
        assertTrue(VecUtils.equalsExactly(MatUtils.flatten(result.getIndices()), expected));

        /*
         * Test BallTree query
         */
        final BallTree ball = new BallTree(IRIS, Distance.CHEBYSHEV);
        result = ball.query(IRIS.getData());
        dists = result.getDistances();

        // Assert all are zero:
        assertTrue(new MatUtils.MatSeries(dists, Inequality.EQUAL_TO, 0.0).all());

        // Assert the indices equal to expected
        assertTrue(VecUtils.equalsExactly(MatUtils.flatten(result.getIndices()), expected));

        // The radius used for the radius queries and the two-point correlation
        final double radius = 0.05;

        /*
         * Test KDTree query radius
         */
        result = kd.queryRadius(IRIS.getData(), radius, true);
        dists = result.getDistances();
        assertTrue(MatUtils.equalsExactly(dists, expected_dists));
        assertTrue(MatUtils.equalsExactly(result.getIndices(), expected_idcs));

        /*
         * Test BallTree query radius
         */
        result = ball.queryRadius(IRIS.getData(), radius, true);
        dists = result.getDistances();
        assertTrue(MatUtils.equalsExactly(dists, expected_dists));
        assertTrue(MatUtils.equalsExactly(result.getIndices(), expected_idcs));

        assertTrue(kd.infinity_dist);
        assertTrue(ball.infinity_dist);

        /*
         * TEST TWO POINT CORRELATION WITH INF DIST
         *
         * The expected count of 158 is consistent with 150 self-pairs plus the
         * 6 + 2 ordered pairs contributed by the two coincident groups.
         */
        final int[] kdCounts = kd.twoPointCorrelation(IRIS.getDataRef(), radius);
        assertTrue(kdCounts.length == IRIS.getRowDimension());
        assertTrue(new IntSeries(kdCounts, Inequality.EQUAL_TO, 158).all());
        // Just so any() gets some coverage love...
        assertFalse(new IntSeries(kdCounts, Inequality.EQUAL_TO, 0).any());

        final int[] ballCounts = ball.twoPointCorrelation(IRIS.getDataRef(), radius);
        assertTrue(ballCounts.length == IRIS.getRowDimension());
        assertTrue(new IntSeries(ballCounts, Inequality.EQUAL_TO, 158).all());

        /*
         * Assert both trees produced the same counts
         */
        final IntSeries vis = new IntSeries(kdCounts, Inequality.EQUAL_TO, ballCounts);
        assertTrue(vis.all());

        // Coverage love, test refs: mutating the array from get() must make it
        // differ from the array returned by getRef().
        final boolean[] byRef = vis.getRef();
        final boolean[] byCopy = vis.get();
        byCopy[0] = false;

        assertFalse(VecUtils.equalsExactly(byRef, byCopy));
    }

    /**
     * Asking for k = 1500 neighbors (presumably more than the tree holds) must be
     * rejected with an {@link IllegalArgumentException}.
     *
     * Uses {@code @Test(expected = ...)} like the other exception tests in this
     * class. The earlier try/catch/finally form asserted inside {@code finally},
     * so any other exception type was masked by an {@code AssertionError}.
     */
    @Test(expected = IllegalArgumentException.class)
    public void testTooHighKQuery() {
        KDTree k = new KDTree(IRIS);
        k.query(IRIS.getData(), 1500, true, true);
    }

    /**
     * The matrix/per-row-radii overload of {@code queryRadius} must produce the
     * same neighborhood as the array/scalar-radius overload when every per-row
     * radius equals the scalar (1.5).
     */
    @Test
    public void testQueryRadiusManyArgs() {
        final KDTree tree = new KDTree(IRIS);

        // make sure returns same neighborhood...
        final Neighborhood viaRadiusArray = tree.queryRadius(IRIS, VecUtils.rep(1.5, IRIS.getRowDimension()), true);
        final Neighborhood viaScalarRadius = tree.queryRadius(IRIS.getData(), 1.5, true);

        assertEquals(viaRadiusArray, viaScalarRadius);
    }
}