zx.soft.mahout.knn.search.AbstractSearchTest.java Source code

Java tutorial

Introduction

Here is the source code for zx.soft.mahout.knn.search.AbstractSearchTest.java

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package zx.soft.mahout.knn.search;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;

import java.util.List;

import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.MatrixSlice;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.random.MultiNormal;
import org.apache.mahout.math.random.WeightedThing;
import org.junit.Test;

import zx.soft.mahout.knn.search.Searcher;
import zx.soft.mahout.knn.search.UpdatableSearcher;

import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;

public abstract class AbstractSearchTest {

    protected static Matrix randomData() {
        Matrix data = new DenseMatrix(1000, 20);
        MultiNormal gen = new MultiNormal(20);
        for (MatrixSlice slice : data) {
            slice.vector().assign(gen.sample());
        }
        return data;
    }

    public abstract Iterable<MatrixSlice> testData();

    /**
     * Gets a searcher whose search size is n.
     * @param n
     * @return
     */
    public abstract Searcher getSearch(int n);

    @Test
    public void testExactMatch() {
        Iterable<MatrixSlice> data = testData();

        final Iterable<MatrixSlice> batch1 = Iterables.limit(data, 300);
        List<MatrixSlice> queries = Lists.newArrayList(Iterables.limit(batch1, 100));
        Searcher s = getSearch(20);

        // adding the data in multiple batches triggers special code in some searchers
        s.addAllMatrixSlices(batch1);
        assertEquals(300, s.size());

        Vector q = Iterables.get(data, 0).vector();
        List<WeightedThing<Vector>> r = s.search(q, 2);
        assertEquals(0, r.get(0).getValue().minus(q).norm(1), 1e-8);

        final Iterable<MatrixSlice> batch2 = Iterables.limit(Iterables.skip(data, 300), 10);
        s.addAllMatrixSlices(batch2);
        assertEquals(310, s.size());

        q = Iterables.get(data, 302).vector();
        r = s.search(q, 2);
        assertEquals(0, r.get(0).getValue().minus(q).norm(1), 1e-8);

        s.addAllMatrixSlices(Iterables.skip(data, 310));
        assertEquals(Iterables.size(testData()), s.size());

        for (MatrixSlice query : queries) {
            r = s.search(query.vector(), 2);
            assertEquals("Distance has to be about zero", 0, r.get(0).getWeight(), 1e-6);
            assertEquals("Answer must be substantially the same as query", 0,
                    r.get(0).getValue().minus(query.vector()).norm(1), 1e-8);
            assertTrue("Wrong answer must have non-zero distance", r.get(1).getWeight() > r.get(0).getWeight());
        }
    }

    @Test
    public void testNearMatch() {
        List<MatrixSlice> queries = Lists.newArrayList(Iterables.limit(testData(), 100));
        Searcher s = getSearch(20);
        s.addAllMatrixSlicesAsWeightedVectors(testData());

        MultiNormal noise = new MultiNormal(0.01, new DenseVector(20));
        for (MatrixSlice slice : queries) {
            Vector query = slice.vector();
            final Vector epsilon = noise.sample();
            //         List<WeightedThing<Vector>> r0 = s.search(query, 2);
            query = query.plus(epsilon);
            List<WeightedThing<Vector>> r = s.search(query, 2);
            r = s.search(query, 2);
            assertEquals("Distance has to be small", epsilon.norm(2), r.get(0).getWeight(), 1e-5);
            assertEquals("Answer must be substantially the same as query", epsilon.norm(2),
                    r.get(0).getValue().minus(query).norm(2), 1e-5);
            assertTrue("Wrong answer must be further away", r.get(1).getWeight() > r.get(0).getWeight());
        }
    }

    @Test
    public void testOrdering() {
        Matrix queries = new DenseMatrix(100, 20);
        MultiNormal gen = new MultiNormal(20);
        for (int i = 0; i < 100; i++) {
            queries.viewRow(i).assign(gen.sample());
        }

        Searcher s = getSearch(20);
        // s.setSearchSize(200);
        s.addAllMatrixSlices(testData());

        for (MatrixSlice query : queries) {
            List<WeightedThing<Vector>> r = s.search(query.vector(), 200);
            double x = 0;
            for (WeightedThing<Vector> thing : r) {
                assertTrue("Scores must be monotonic increasing", thing.getWeight() > x);
                x = thing.getWeight();
            }
        }
    }

    @Test
    public void testSmallSearch() {
        Matrix m = new DenseMatrix(8, 3);
        for (int i = 0; i < 8; i++) {
            m.viewRow(i).assign(new double[] { 0.125 * (i & 4), i & 2, i & 1 });
        }

        Searcher s = getSearch(3);
        s.addAllMatrixSlices(m);
        for (MatrixSlice row : m) {
            final List<WeightedThing<Vector>> r = s.search(row.vector(), 3);
            assertEquals(0, r.get(0).getWeight(), 1e-8);
            assertEquals(0, r.get(1).getWeight(), 0.5);
            assertEquals(0, r.get(2).getWeight(), 1);
        }
    }

    @Test
    public void testRemoval() {
        Searcher s = getSearch(20);
        s.addAllMatrixSlices(testData());
        if (s instanceof UpdatableSearcher) {
            List<Vector> x = Lists.newArrayList(Iterables.limit(s, 2));
            int size0 = s.size();

            List<WeightedThing<Vector>> r0 = s.search(x.get(0), 2);

            s.remove(x.get(0), 1e-7);
            assertEquals(size0 - 1, s.size());

            List<WeightedThing<Vector>> r = s.search(x.get(0), 1);
            assertTrue("Vector should be gone", r.get(0).getWeight() > 0);
            assertEquals("Previous second neighbor should be first", 0,
                    r.get(0).getValue().minus(r0.get(1).getValue()).norm(1), 1e-8);

            s.remove(x.get(1), 1e-7);
            assertEquals(size0 - 2, s.size());

            r = s.search(x.get(1), 1);
            assertTrue("Vector should be gone", r.get(0).getWeight() > 0);

            // vectors don't show up in iterator
            for (Vector v : s) {
                assertTrue(x.get(0).minus(v).norm(1) > 1e-8);
                assertTrue(x.get(1).minus(v).norm(1) > 1e-8);
            }
        } else {
            try {
                List<Vector> x = Lists.newArrayList(Iterables.limit(s, 2));
                s.remove(x.get(0), 1e-7);
                fail("Shouldn't be able to delete from " + s.getClass().getName());
            } catch (UnsupportedOperationException e) {
                // good enough that UOE is thrown
            }
        }
    }

    /*
    public List<Vector> subset(Iterable<Vector> data, int n) {
      List<Vector> r = Lists.newArrayList();
      Random gen = RandomUtils.getRandom();
        
      int i = 0;
      for (Vector row : data) {
        if (r.size() < n) {
       r.add(row);
        } else {
     int k = gen.nextInt(row.getIndex() + 1);
     if (k < r.size()) {
         r.set(k, new WeightedVector(row.getVector(), 1, i++));
     }
     i++;
        }
      }
        
      return r;
    }
    */
}