// Java tutorial
/*- * * * Copyright 2015 Skymind,Inc. * * * * Licensed under the Apache License, Version 2.0 (the "License"); * * you may not use this file except in compliance with the License. * * You may obtain a copy of the License at * * * * http://www.apache.org/licenses/LICENSE-2.0 * * * * Unless required by applicable law or agreed to in writing, software * * distributed under the License is distributed on an "AS IS" BASIS, * * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * * See the License for the specific language governing permissions and * * limitations under the License. * */ package org.deeplearning4j.clustering.kdtree; import org.apache.commons.lang3.tuple.Pair; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.accum.distances.EuclideanDistance; import org.nd4j.linalg.factory.Nd4j; import java.io.Serializable; import java.util.ArrayList; import java.util.Collections; import java.util.Comparator; import java.util.List; /** * KDTree based on: https://github.com/nicky-zs/kdtree-python/blob/master/kdtree.py * * @author Adam Gibson */ public class KDTree implements Serializable { private KDNode root; private int dims = 100; public final static int GREATER = 1; public final static int LESS = 0; private int size = 0; private HyperRect rect; public KDTree(int dims) { this.dims = dims; } /** * Insert a point in to the tree * @param point the point to insert */ public void insert(INDArray point) { if (!point.isVector() || point.length() != dims) throw new IllegalArgumentException("Point must be a vector of length " + dims); if (root == null) { root = new KDNode(point); rect = new HyperRect(HyperRect.point(point)); } else { int disc = 0; KDNode node = root; KDNode insert = new KDNode(point); int successor; while (true) { //exactly equal if (node.getPoint().neq(point).sum(Integer.MAX_VALUE).getDouble(0) == 0) { return; } else { successor = successor(node, point, disc); KDNode child; if (successor < 1) child = 
node.getLeft(); else child = node.getRight(); if (child == null) break; disc = (disc + 1) % dims; node = child; } } if (successor < 1) node.setLeft(insert); else node.setRight(insert); rect.enlargeTo(point); insert.setParent(node); } size++; } public KDNode delete(INDArray point) { KDNode node = root; int _disc = 0; while (node != null) { if (node.point == point) break; int successor = successor(node, point, _disc); if (successor < 1) node = node.getLeft(); else node = node.getRight(); _disc = (_disc + 1) % dims; } if (node != null) { if (node == root) { root = delete(root, _disc); } else node = delete(node, _disc); size--; if (size == 1) { rect = new HyperRect(HyperRect.point(point)); } else rect = null; } return node; } public List<Pair<Double, INDArray>> knn(INDArray point, double distance) { List<Pair<Double, INDArray>> best = new ArrayList<>(); knn(root, point, rect, distance, best, 0); Collections.sort(best, new Comparator<Pair<Double, INDArray>>() { @Override public int compare(Pair<Double, INDArray> o1, Pair<Double, INDArray> o2) { return Double.compare(o1.getKey(), o2.getKey()); } }); return best; } private void knn(KDNode node, INDArray point, HyperRect rect, double dist, List<Pair<Double, INDArray>> best, int _disc) { if (node == null || rect.minDistance(point) > dist) return; int _discNext = (_disc + 1) % dims; double distance = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(point)).getFinalResult() .doubleValue(); if (distance <= dist) { best.add(Pair.of(distance, node.getPoint())); } HyperRect lower = rect.getLower(point, _disc); HyperRect upper = rect.getUpper(point, _disc); knn(node.getLeft(), point, lower, dist, best, _discNext); knn(node.getRight(), point, upper, dist, best, _discNext); } /** * Query for nearest neighbor. 
Returns the distance and point * @param point the point to query for * @return */ public Pair<Double, INDArray> nn(INDArray point) { return nn(root, point, rect, Double.POSITIVE_INFINITY, null, 0); } private Pair<Double, INDArray> nn(KDNode node, INDArray point, HyperRect rect, double dist, INDArray best, int _disc) { if (node == null || rect.minDistance(point) > dist) return Pair.of(Double.POSITIVE_INFINITY, null); int _discNext = (_disc + 1) % dims; double dist2 = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(point)).getFinalResult() .doubleValue(); if (dist2 < dist) { best = node.getPoint(); dist = dist2; } HyperRect lower = rect.getLower(node.point, _disc); HyperRect upper = rect.getUpper(node.point, _disc); if (point.getDouble(_disc) < node.point.getDouble(_disc)) { Pair<Double, INDArray> left = nn(node.getLeft(), point, lower, dist, best, _discNext); Pair<Double, INDArray> right = nn(node.getRight(), point, upper, dist, best, _discNext); if (left.getKey() < dist) return left; else if (right.getKey() < dist) return right; } else { Pair<Double, INDArray> left = nn(node.getRight(), point, upper, dist, best, _discNext); Pair<Double, INDArray> right = nn(node.getLeft(), point, lower, dist, best, _discNext); if (left.getKey() < dist) return left; else if (right.getKey() < dist) return right; } return Pair.of(dist, best); } private KDNode delete(KDNode delete, int _disc) { if (delete.getLeft() != null && delete.getRight() != null) { if (delete.getParent() != null) { if (delete.getParent().getLeft() == delete) delete.getParent().setLeft(null); else delete.getParent().setRight(null); } return null; } int disc = _disc; _disc = (_disc + 1) % dims; Pair<KDNode, Integer> qd = null; if (delete.getRight() != null) { qd = min(delete.getRight(), disc, _disc); } else if (delete.getLeft() != null) qd = max(delete.getLeft(), disc, _disc); delete.point = qd.getKey().point; KDNode qFather = qd.getKey().getParent(); if (qFather.getLeft() == qd.getKey()) { 
qFather.setLeft(delete(qd.getKey(), disc)); } else if (qFather.getRight() == qd.getKey()) { qFather.setRight(delete(qd.getKey(), disc)); } return delete; } private Pair<KDNode, Integer> max(KDNode node, int disc, int _disc) { int discNext = (_disc + 1) % dims; if (_disc == disc) { KDNode child = node.getLeft(); if (child != null) { return max(child, disc, discNext); } } else if (node.getLeft() != null || node.getRight() != null) { Pair<KDNode, Integer> left = null, right = null; if (node.getLeft() != null) left = max(node.getLeft(), disc, discNext); if (node.getRight() != null) right = max(node.getRight(), disc, discNext); if (left != null && right != null) { double pointLeft = left.getKey().getPoint().getDouble(disc); double pointRight = right.getKey().getPoint().getDouble(disc); if (pointLeft > pointRight) return left; else return right; } else if (left != null) return left; else return right; } return Pair.of(node, _disc); } private Pair<KDNode, Integer> min(KDNode node, int disc, int _disc) { int discNext = (_disc + 1) % dims; if (_disc == disc) { KDNode child = node.getLeft(); if (child != null) { return min(child, disc, discNext); } } else if (node.getLeft() != null || node.getRight() != null) { Pair<KDNode, Integer> left = null, right = null; if (node.getLeft() != null) left = min(node.getLeft(), disc, discNext); if (node.getRight() != null) right = min(node.getRight(), disc, discNext); if (left != null && right != null) { double pointLeft = left.getKey().getPoint().getDouble(disc); double pointRight = right.getKey().getPoint().getDouble(disc); if (pointLeft < pointRight) return left; else return right; } else if (left != null) return left; else return right; } return Pair.of(node, _disc); } /** * The number of elements in the tree * @return the number of elements in the tree */ public int size() { return size; } private int successor(KDNode node, INDArray point, int disc) { for (int i = disc; i < dims; i++) { double pointI = point.getDouble(i); double 
nodePointI = node.getPoint().getDouble(i); if (pointI < nodePointI) return LESS; else if (pointI > nodePointI) return GREATER; } throw new IllegalStateException("Point is equal!"); } public static class KDNode { private INDArray point; private KDNode left, right, parent; public KDNode(INDArray point) { this.point = point; } public INDArray getPoint() { return point; } public KDNode getLeft() { return left; } public void setLeft(KDNode left) { this.left = left; } public KDNode getRight() { return right; } public void setRight(KDNode right) { this.right = right; } public KDNode getParent() { return parent; } public void setParent(KDNode parent) { this.parent = parent; } } }