edu.indiana.d2i.htrc.skmeans.StreamingKMeansAdapterTest.java Source code

Java tutorial

Introduction

Here is the source code for edu.indiana.d2i.htrc.skmeans.StreamingKMeansAdapterTest.java

Source

/*
#
# Copyright 2012 The Trustees of Indiana University
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# -----------------------------------------------------------------
#
# Project: knn
# File:  StreamingKMeansAdapterTest.java
# Description:  JUnit test that clusters synthetic data with StreamingKMeansAdapter
#
# -----------------------------------------------------------------
# 
 */

package edu.indiana.d2i.htrc.skmeans;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

import java.util.List;

import org.apache.hadoop.conf.Configuration;
import org.apache.mahout.common.ClassUtils;
import org.apache.mahout.common.distance.DistanceMeasure;
import org.apache.mahout.common.distance.EuclideanDistanceMeasure;
import org.apache.mahout.knn.WeightedVector;
import org.apache.mahout.knn.generate.MultiNormal;
import org.apache.mahout.knn.means.StreamingKmeans;
import org.apache.mahout.knn.means.StreamingKmeans.CentroidFactory;
import org.apache.mahout.knn.search.ProjectionSearch;
import org.apache.mahout.knn.search.Searcher;
import org.apache.mahout.knn.search.UpdatableSearcher;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.MatrixSlice;
import org.junit.Test;

import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;

/**
 * Exercises {@link StreamingKMeansAdapter} on synthetic data: points are
 * sampled around a handful of randomly placed centers, streamed through the
 * adapter one vector at a time, and the resulting centroids are checked for
 * weight preservation and for proximity to the original centers.
 *
 * <p>The scenario can be run either through JUnit or directly through
 * {@link #main(String[])}.
 */
public class StreamingKMeansAdapterTest {

    /** Number of distinct centers the synthetic data is drawn around. */
    private static final int CLUSTER_COUNT = 8;

    /** Number of synthetic data points streamed through the clusterer. */
    private static final int SAMPLE_COUNT = 10000;

    /** Dimensionality of every generated vector. */
    private static final int DIMENSION = 500;

    /**
     * Sums the weight carried by every row of {@code data}.
     *
     * @param data rows to total up
     * @return the sum of {@link WeightedVector#getWeight()} over rows backed by
     *         a {@link WeightedVector}, plus 1 for every other row
     */
    private static double totalWeight(Iterable<MatrixSlice> data) {
        double sum = 0;
        for (MatrixSlice row : data) {
            if (row.vector() instanceof WeightedVector) {
                sum += ((WeightedVector) row.vector()).getWeight();
            } else {
                // a plain (unweighted) vector counts as a single point
                sum++;
            }
        }
        return sum;
    }

    /**
     * JUnit entry point for the clustering scenario.
     *
     * <p>JUnit 4 refuses to run {@code @Test} methods that are static, so the
     * annotation lives on this instance method, which simply delegates to the
     * static {@link #testCluster()} used by {@link #main(String[])}.
     */
    @Test
    public void clusterPreservesWeightAndFindsCentroids() {
        testCluster();
    }

    /**
     * Runs the full clustering scenario and asserts on its outcome.
     *
     * <p>Steps: build {@value #CLUSTER_COUNT} samplers around random centers,
     * draw {@value #SAMPLE_COUNT} points round-robin from them, stream every
     * point through a {@link StreamingKMeansAdapter}, then verify that the total
     * weight of the centroids equals the total weight of the input and that
     * each original center has a nearby centroid. Timing information is printed
     * to standard output.
     *
     * @throws AssertionError if either verification fails
     */
    public static void testCluster() {
        int dimension = DIMENSION;

        // construct data samplers centered on random points whose coordinates
        // are drawn uniformly from [0, 1)
        Matrix mean = new DenseMatrix(CLUSTER_COUNT, dimension);
        List<MultiNormal> rowSamplers = Lists.newArrayList();
        for (int i = 0; i < CLUSTER_COUNT; i++) {
            double[] random = new double[dimension];
            for (int j = 0; j < random.length; j++) {
                random[j] = Math.random();
            }
            mean.viewRow(i).assign(random);
            // NOTE(review): 0.01 is presumably the spread of the samples around
            // the center — confirm against MultiNormal's constructor
            rowSamplers.add(new MultiNormal(0.01, mean.viewRow(i)));
        }

        // sample a bunch of data points, cycling through the samplers by row index
        Matrix data = new DenseMatrix(SAMPLE_COUNT, dimension);
        for (MatrixSlice row : data) {
            row.vector().assign(rowSamplers.get(row.index() % CLUSTER_COUNT).sample());
        }

        // cluster the data; the timed region covers cutoff estimation,
        // configuration, streaming every row and fetching the centroids
        long t0 = System.currentTimeMillis();

        double cutoff = StreamingKMeansAdapter.estimateCutoff(data, 100);
        Configuration conf = new Configuration();
        conf.setInt(StreamingKMeansConfigKeys.MAXCLUSTER, 1000);
        conf.setFloat(StreamingKMeansConfigKeys.CUTOFF, (float) cutoff);
        conf.setClass(StreamingKMeansConfigKeys.DIST_MEASUREMENT, EuclideanDistanceMeasure.class,
                DistanceMeasure.class);
        conf.setInt(StreamingKMeansConfigKeys.VECTOR_DIMENSION, dimension);
        StreamingKMeansAdapter skmeans = new StreamingKMeansAdapter(conf);
        for (MatrixSlice row : data) {
            skmeans.cluster(row.vector());
        }

        // validate
        Searcher centroids = skmeans.getCentroids();

        long t1 = System.currentTimeMillis();

        assertEquals("Total weight not preserved", totalWeight(data), totalWeight(centroids), 1e-9);

        // and verify that each original center has a centroid very nearby
        // NOTE(review): the weight of a search result is presumably its distance
        // from the query vector — confirm against Searcher.search
        for (MatrixSlice row : mean) {
            WeightedVector v = centroids.search(row.vector(), 1).get(0);
            assertTrue("No centroid close to center " + row.index(), v.getWeight() < 0.05);
        }

        // elapsed seconds, then microseconds spent per input row
        System.out.printf("%.2f for clustering\n%.1f us per row\n", (t1 - t0) / 1000.0,
                (t1 - t0) / 1000.0 / data.rowSize() * 1e6);

        System.out.println("Done??");
    }

    /**
     * Runs the clustering scenario without a JUnit runner.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        testCluster();
    }
}