Java tutorial
/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.mahout.knn.search; import com.google.common.base.Preconditions; import com.google.common.collect.Iterables; import com.google.common.collect.Lists; import com.google.common.collect.Ordering; import com.google.common.collect.Sets; import org.apache.mahout.common.distance.DistanceMeasure; import org.apache.mahout.math.DenseVector; import org.apache.mahout.math.Vector; import org.apache.mahout.math.function.DoubleFunction; import org.apache.mahout.math.function.Functions; import java.util.Collections; import java.util.Comparator; import java.util.List; import java.util.ArrayList; import java.util.TreeSet; import java.util.HashSet; /** * Does approximate nearest neighbor dudes search by projecting the data. */ public class ProjectionSearch { private final List<TreeSet<Vector>> vectors; private DistanceMeasure distance; public ProjectionSearch(int d, DistanceMeasure distance) { this(d, distance, 1); } public ProjectionSearch(int d, DistanceMeasure distance, int projections) { Preconditions.checkArgument(projections > 0 && projections < 100, "Unreasonable value for number of projections"); final DoubleFunction random = Functions.random(); this.distance = distance; vectors = Lists.newArrayList(); // we want to create several projections. Each is alike except for the // direction of the projection for (int i = 0; i < projections; i++) { // create a random vector to use for the basis of the projection final DenseVector projection = new DenseVector(d); projection.assign(random); projection.normalize(); // the projection is implemented by a tree set where the ordering of vectors // is based on the dot product of the vector with the projection vector TreeSet<Vector> s = Sets.newTreeSet(new Comparator<Vector>() { @Override public int compare(Vector v1, Vector v2) { int r = Double.compare(v1.dot(projection), v2.dot(projection)); if (r == 0) { return v1.hashCode() - v2.hashCode(); } else { return r; } } }); // so we have a project (s) and we need to add it to the list of projections for later vectors.add(s); } } /** * Adds a vector into the set of projections for later searching. * @param v The vector to add. */ public void add(Vector v) { // add to each projection separately for (TreeSet<Vector> s : vectors) { s.add(v); } } public static void removeDuplicate(List list) { HashSet h = new HashSet(list); list.clear(); list.addAll(h); } public List<Vector> search(final Vector query, int n, int searchSize) { List<Vector> top = Lists.newArrayList(); for (TreeSet<Vector> v : vectors) { Iterables.addAll(top, Iterables.limit(v.tailSet(query, true), searchSize)); Iterables.addAll(top, Iterables.limit(v.headSet(query, false).descendingSet(), searchSize)); } System.out.print(top.size()); removeDuplicate(top); System.out.print(" "); System.out.println(top.size()); // if searchSize * vectors.size() is small enough not to cause much memory pressure, this is probably // just as fast as a priority queue here. Collections.sort(top, byQueryDistance(query)); return top.subList(0, n); } private Ordering<Vector> byQueryDistance(final Vector query) { return new Ordering<Vector>() { @Override public int compare(Vector v1, Vector v2) { int r = Double.compare(distance.distance(query, v1), distance.distance(query, v2)); return r != 0 ? r : v1.hashCode() - v2.hashCode(); } }; } }