org.deeplearning4j.clustering.kdtree.KDTreeTest.java Source code

Java tutorial

Introduction

Here is the source code for org.deeplearning4j.clustering.kdtree.KDTreeTest.java

Source

/*-
 *
 *  * Copyright 2015 Skymind,Inc.
 *  *
 *  *    Licensed under the Apache License, Version 2.0 (the "License");
 *  *    you may not use this file except in compliance with the License.
 *  *    You may obtain a copy of the License at
 *  *
 *  *        http://www.apache.org/licenses/LICENSE-2.0
 *  *
 *  *    Unless required by applicable law or agreed to in writing, software
 *  *    distributed under the License is distributed on an "AS IS" BASIS,
 *  *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  *    See the License for the specific language governing permissions and
 *  *    limitations under the License.
 *
 */

package org.deeplearning4j.clustering.kdtree;

import com.google.common.primitives.Doubles;
import org.apache.commons.lang3.tuple.Pair;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

/**
 * Unit tests for {@link KDTree}: size bookkeeping on insertion and
 * nearest-neighbour lookup.
 *
 * Created by agibsonccc on 1/1/15.
 */
public class KDTreeTest {

    /** Fixed seed so the random query point used by {@link #testNN()} is reproducible. */
    private static final long SEED = 12345L;

    /**
     * Inserts two 2-d points and expects the nearest neighbour of a query
     * equal to the first point to be that first point.
     */
    @Test
    public void testTree() {
        KDTree tree = new KDTree(2);
        INDArray half = vector(0.5, 0.5);
        INDArray one = vector(1.0, 1.0);
        tree.insert(half);
        tree.insert(one);

        Pair<Double, INDArray> pair = tree.nn(vector(0.5, 0.5));
        assertEquals(half, pair.getValue());
    }

    /**
     * Inserts ten 4-d points (a base vector shifted by 0..9) and checks that
     * the reported tree size grows by exactly one after every insertion.
     */
    @Test
    public void testInsert() {
        int elements = 10;
        List<Double> digits = Arrays.asList(1.0, 0.0, 2.0, 3.0);

        KDTree kdTree = new KDTree(digits.size());
        for (int i = 0; i < elements; i++) {
            // The i-th point is the base vector with i added to every coordinate.
            List<Double> shifted = new ArrayList<>(digits.size());
            for (Double digit : digits) {
                shifted.add(digit + i);
            }
            kdTree.insert(vector(Doubles.toArray(shifted)));
            assertEquals(i + 1, kdTree.size());
        }
    }

    /**
     * Fills a 10-d tree with the origin and the ten unit vectors, then queries
     * the nearest neighbour of a pseudo-random point whose coordinates all lie
     * in [0, 1) and checks that a finite-looking distance is reported.
     */
    @Test
    public void testNN() {
        int n = 10;

        // make a KD-tree of dimension n
        KDTree kdTree = new KDTree(n);
        // i = -1 matches no coordinate index, so that iteration inserts the origin;
        // i = 0..n-1 insert the unit vector along dimension i.
        for (int i = -1; i < n; i++) {
            List<Double> vec = new ArrayList<>(n);
            for (int k = 0; k < n; k++) {
                vec.add((k == i) ? 1.0 : 0.0);
            }
            kdTree.insert(vector(Doubles.toArray(vec)));
        }

        // Seeded generator: the same query point is used on every run, so a
        // failure can be reproduced.
        Random rand = new Random(SEED);

        // random point in the unit hypercube
        List<Double> pt = new ArrayList<>(n);
        for (int k = 0; k < n; k++) {
            pt.add(rand.nextDouble());
        }
        Pair<Double, INDArray> result = kdTree.nn(vector(Doubles.toArray(pt)));

        // Expected to hold for any point in the unit hypercube.
        // NOTE(review): if nn() reports a Euclidean distance, Math.sqrt(n) (the
        // farthest any such point can be from the origin) would be a much
        // tighter bound than Double.MAX_VALUE - confirm against KDTree.
        assertTrue(result.getKey() < Double.MAX_VALUE);
    }

    /** Wraps the given coordinates in an {@link INDArray} backed by a freshly created buffer. */
    private static INDArray vector(double... values) {
        return Nd4j.create(Nd4j.createBuffer(values));
    }

}