org.apache.mahout.clustering.streaming.mapreduce.StreamingKMeansTestMR.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.mahout.clustering.streaming.mapreduce.StreamingKMeansTestMR.java

Source

/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.clustering.streaming.mapreduce;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;

import com.google.common.base.Function;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.mrunit.mapreduce.MapDriver;
import org.apache.hadoop.mrunit.mapreduce.MapReduceDriver;
import org.apache.hadoop.mrunit.mapreduce.ReduceDriver;
import org.apache.mahout.clustering.ClusteringUtils;
import org.apache.mahout.clustering.streaming.cluster.DataUtils;
import org.apache.mahout.clustering.streaming.cluster.StreamingKMeans;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.common.commandline.DefaultOptionCreator;
import org.apache.mahout.common.distance.SquaredEuclideanDistanceMeasure;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
import org.apache.mahout.math.Centroid;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.neighborhood.BruteSearch;
import org.apache.mahout.math.neighborhood.FastProjectionSearch;
import org.apache.mahout.math.neighborhood.LocalitySensitiveHashSearch;
import org.apache.mahout.math.neighborhood.ProjectionSearch;
import org.apache.mahout.math.random.WeightedThing;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

/**
 * Tests the MapReduce flavour of StreamingKMeans ({@link StreamingKMeansMapper},
 * {@link StreamingKMeansReducer} and the sequential path of {@link StreamingKMeansDriver}) using
 * MRUnit drivers.
 *
 * <p>The suite is parameterized over the searcher class name and the distance measure class name
 * (see {@link #generateData()}); every test runs once per parameter row against freshly sampled
 * synthetic data.
 */
@RunWith(Parameterized.class)
public class StreamingKMeansTestMR extends MahoutTestCase {
    private static final int NUM_DATA_POINTS = 1 << 15;
    private static final int NUM_DIMENSIONS = 8;
    private static final int NUM_PROJECTIONS = 3;
    private static final int SEARCH_SIZE = 5;
    private static final int MAX_NUM_ITERATIONS = 10;
    private static final double DISTANCE_CUTOFF = 1.0e-6;

    /** Number of final clusters requested from the reducer and expected in its output. */
    private static final int NUM_CLUSTERS = 1 << NUM_DIMENSIONS;

    /**
     * Estimated number of intermediate clusters handed to the mapper and to the local clusterers.
     *
     * <p>The {@code (int)} cast binds to {@code Math.log(...)} alone, so the logarithm is truncated
     * to an integer before the multiplication. This is kept exactly as the tests always computed it
     * so that the tuned cost-ratio bounds below stay valid.
     */
    private static final int ESTIMATED_MAP_CLUSTERS = NUM_CLUSTERS * (int) Math.log(NUM_DATA_POINTS);

    /**
     * Synthetic input: the first element holds the data points, the second the means they were
     * sampled around. Regenerated before every test, hence an instance field rather than a static.
     */
    private Pair<List<Centroid>, List<Centroid>> syntheticData;

    private final String searcherClassName;
    private final String distanceMeasureClassName;

    /**
     * Invoked by the {@link Parameterized} runner with one row of {@link #generateData()}.
     *
     * @param searcherClassName fully qualified class name of the searcher to configure
     * @param distanceMeasureClassName fully qualified class name of the distance measure to configure
     */
    public StreamingKMeansTestMR(String searcherClassName, String distanceMeasureClassName) {
        this.searcherClassName = searcherClassName;
        this.distanceMeasureClassName = distanceMeasureClassName;
    }

    /** Supplies the (searcher class, distance measure class) combinations to run the suite with. */
    @Parameterized.Parameters
    public static List<Object[]> generateData() {
        return Arrays.asList(new Object[][] {
                { ProjectionSearch.class.getName(), SquaredEuclideanDistanceMeasure.class.getName() },
                { FastProjectionSearch.class.getName(), SquaredEuclideanDistanceMeasure.class.getName() },
                { LocalitySensitiveHashSearch.class.getName(), SquaredEuclideanDistanceMeasure.class.getName() },
        });
    }

    /** Resets the random seed and samples a fresh synthetic data set before each test. */
    @Before
    public void setUp() {
        // NOTE(review): if MahoutTestCase declares its own setUp(), this method overrides it without
        // calling super.setUp() - confirm whether the parent's set-up needs to run as well.
        RandomUtils.useTestSeed();
        syntheticData = DataUtils.sampleMultiNormalHypercube(NUM_DIMENSIONS, NUM_DATA_POINTS, 1.0e-4);
    }

    /**
     * Writes the StreamingKMeans options shared by all tests into the given configuration.
     *
     * @param configuration the (driver) configuration to populate
     */
    private void configure(Configuration configuration) {
        configuration.set(DefaultOptionCreator.DISTANCE_MEASURE_OPTION, distanceMeasureClassName);
        configuration.setInt(StreamingKMeansDriver.SEARCH_SIZE_OPTION, SEARCH_SIZE);
        configuration.setInt(StreamingKMeansDriver.NUM_PROJECTIONS_OPTION, NUM_PROJECTIONS);
        configuration.set(StreamingKMeansDriver.SEARCHER_CLASS_OPTION, searcherClassName);
        configuration.setInt(DefaultOptionCreator.NUM_CLUSTERS_OPTION, NUM_CLUSTERS);
        configuration.setInt(StreamingKMeansDriver.ESTIMATED_NUM_MAP_CLUSTERS, ESTIMATED_MAP_CLUSTERS);
        configuration.setFloat(StreamingKMeansDriver.ESTIMATED_DISTANCE_CUTOFF, (float) DISTANCE_CUTOFF);
        configuration.setInt(StreamingKMeansDriver.MAX_NUM_ITERATIONS, MAX_NUM_ITERATIONS);

        // Collapse the Centroids in the reducer.
        configuration.setBoolean(StreamingKMeansDriver.REDUCE_STREAMING_KMEANS, true);
    }

    /**
     * Runs the mapper over the whole data set and checks that every true mean has a mapper
     * centroid within a squared Euclidean distance of 0.5.
     */
    @Test
    public void testHypercubeMapper() throws IOException {
        MapDriver<Writable, VectorWritable, IntWritable, CentroidWritable> mapDriver = MapDriver
                .newMapDriver(new StreamingKMeansMapper());
        Configuration configuration = mapDriver.getConfiguration();
        configure(configuration);
        System.out.printf("%s mapper test\n", configuration.get(StreamingKMeansDriver.SEARCHER_CLASS_OPTION));

        for (Centroid datapoint : syntheticData.getFirst()) {
            mapDriver.addInput(new IntWritable(0), new VectorWritable(datapoint));
        }
        List<org.apache.hadoop.mrunit.types.Pair<IntWritable, CentroidWritable>> results = mapDriver.run();

        // Index the mapper output with an exact searcher so the nearest-centroid lookup is reliable.
        BruteSearch resultSearcher = new BruteSearch(new SquaredEuclideanDistanceMeasure());
        for (org.apache.hadoop.mrunit.types.Pair<IntWritable, CentroidWritable> result : results) {
            resultSearcher.add(result.getSecond().getCentroid());
        }
        System.out.printf("Clustered the data into %d clusters\n", results.size());

        for (Vector mean : syntheticData.getSecond()) {
            WeightedThing<Vector> closest = resultSearcher.search(mean, 1).get(0);
            assertTrue("Weight " + closest.getWeight() + " not less than 0.5", closest.getWeight() < 0.5);
        }
    }

    /**
     * Clusters the same data with the mapper, with a local batch StreamingKMeans and with a local
     * point-by-point StreamingKMeans, then checks that the total clustering costs are comparable.
     */
    @Test
    public void testMapperVsLocal() throws IOException {
        // Clusters the data using the StreamingKMeansMapper.
        MapDriver<Writable, VectorWritable, IntWritable, CentroidWritable> mapDriver = MapDriver
                .newMapDriver(new StreamingKMeansMapper());
        Configuration configuration = mapDriver.getConfiguration();
        configure(configuration);
        System.out.printf("%s mapper vs local test\n",
                configuration.get(StreamingKMeansDriver.SEARCHER_CLASS_OPTION));

        for (Centroid datapoint : syntheticData.getFirst()) {
            mapDriver.addInput(new IntWritable(0), new VectorWritable(datapoint));
        }
        List<Centroid> mapperCentroids = Lists.newArrayList();
        for (org.apache.hadoop.mrunit.types.Pair<IntWritable, CentroidWritable> pair : mapDriver.run()) {
            mapperCentroids.add(pair.getSecond().getCentroid());
        }

        // Clusters the data using local batch StreamingKMeans. The estimate is read back through the
        // same key configure() wrote it under, so both clusterers are guaranteed to agree.
        StreamingKMeans batchClusterer = new StreamingKMeans(
                StreamingKMeansUtilsMR.searcherFromConfiguration(configuration),
                configuration.getInt(StreamingKMeansDriver.ESTIMATED_NUM_MAP_CLUSTERS, -1), DISTANCE_CUTOFF);
        batchClusterer.cluster(syntheticData.getFirst());
        List<Centroid> batchCentroids = copyCentroids(batchClusterer);

        // Clusters the data using point by point StreamingKMeans.
        StreamingKMeans perPointClusterer = new StreamingKMeans(
                StreamingKMeansUtilsMR.searcherFromConfiguration(configuration),
                ESTIMATED_MAP_CLUSTERS, DISTANCE_CUTOFF);
        for (Centroid datapoint : syntheticData.getFirst()) {
            perPointClusterer.cluster(datapoint);
        }
        List<Centroid> perPointCentroids = copyCentroids(perPointClusterer);

        // Computes the cost (total sum of distances) of these different clusterings.
        double mapperCost = ClusteringUtils.totalClusterCost(syntheticData.getFirst(), mapperCentroids);
        double localCost = ClusteringUtils.totalClusterCost(syntheticData.getFirst(), batchCentroids);
        double perPointCost = ClusteringUtils.totalClusterCost(syntheticData.getFirst(), perPointCentroids);
        System.out.printf(
                "[Total cost] Mapper %f [%d] Local %f [%d] Perpoint local %f [%d];"
                        + "[ratio m-vs-l %f] [ratio pp-vs-l %f]\n",
                mapperCost, mapperCentroids.size(), localCost, batchCentroids.size(), perPointCost,
                perPointCentroids.size(), mapperCost / localCost, perPointCost / localCost);

        // These ratios should be close to 1.0 and have been observed to go as low as 0.6 and as high
        // as 1.5. A buffer of [0.2, 1.8] seems appropriate.
        assertEquals("Mapper StreamingKMeans / Batch local StreamingKMeans total cost ratio too far from 1", 1.0,
                mapperCost / localCost, 0.8);
        assertEquals(
                "One by one local StreamingKMeans / Batch local StreamingKMeans total cost ratio too far from 1",
                1.0, perPointCost / localCost, 0.8);
    }

    /**
     * Produces intermediate centroids with a local StreamingKMeans, feeds them to the reducer and
     * validates the reducer output against the total weight of those intermediate centroids.
     */
    @Test
    public void testHypercubeReducer() throws IOException {
        ReduceDriver<IntWritable, CentroidWritable, IntWritable, CentroidWritable> reduceDriver = ReduceDriver
                .newReduceDriver(new StreamingKMeansReducer());
        Configuration configuration = reduceDriver.getConfiguration();
        configure(configuration);

        System.out.printf("%s reducer test\n", configuration.get(StreamingKMeansDriver.SEARCHER_CLASS_OPTION));
        StreamingKMeans clusterer = new StreamingKMeans(
                StreamingKMeansUtilsMR.searcherFromConfiguration(configuration),
                ESTIMATED_MAP_CLUSTERS, DISTANCE_CUTOFF);

        long start = System.currentTimeMillis();
        clusterer.cluster(syntheticData.getFirst());
        long end = System.currentTimeMillis();

        System.out.printf("%f [s]\n", (end - start) / 1000.0);
        List<CentroidWritable> reducerInputs = Lists.newArrayList();
        // Accumulate in double: weights are floating point, and "int += double" would silently
        // truncate the running total on every single addition.
        double postMapperTotalWeight = 0.0;
        for (Centroid intermediateCentroid : clusterer) {
            reducerInputs.add(new CentroidWritable(intermediateCentroid));
            postMapperTotalWeight += intermediateCentroid.getWeight();
        }

        reduceDriver.addInput(new IntWritable(0), reducerInputs);
        List<org.apache.hadoop.mrunit.types.Pair<IntWritable, CentroidWritable>> results = reduceDriver.run();
        testReducerResults((int) Math.round(postMapperTotalWeight), results);
    }

    /** Runs mapper and reducer together through MRUnit and validates the final centroids. */
    @Test
    public void testHypercubeMapReduce() throws IOException {
        MapReduceDriver<Writable, VectorWritable, IntWritable, CentroidWritable, IntWritable, CentroidWritable> mapReduceDriver = new MapReduceDriver<Writable, VectorWritable, IntWritable, CentroidWritable, IntWritable, CentroidWritable>(
                new StreamingKMeansMapper(), new StreamingKMeansReducer());
        Configuration configuration = mapReduceDriver.getConfiguration();
        configure(configuration);

        System.out.printf("%s full test\n", configuration.get(StreamingKMeansDriver.SEARCHER_CLASS_OPTION));
        for (Centroid datapoint : syntheticData.getFirst()) {
            mapReduceDriver.addInput(new IntWritable(0), new VectorWritable(datapoint));
        }
        List<org.apache.hadoop.mrunit.types.Pair<IntWritable, CentroidWritable>> results = mapReduceDriver.run();
        testReducerResults(syntheticData.getFirst().size(), results);
    }

    /**
     * Runs the driver with the sequential method on sequence files and validates the centroids it
     * writes. The input and output paths are removed afterwards so that repeated (parameterized)
     * runs do not leave files behind or see each other's output.
     */
    @Test
    public void testHypercubeMapReduceRunSequentially() throws Exception {
        Configuration configuration = getConfiguration();
        configure(configuration);
        configuration.set(DefaultOptionCreator.METHOD_OPTION, DefaultOptionCreator.SEQUENTIAL_METHOD);

        Path inputPath = new Path("testInput");
        Path outputPath = new Path("testOutput");
        try {
            StreamingKMeansUtilsMR.writeVectorsToSequenceFile(syntheticData.getFirst(), inputPath, configuration);

            StreamingKMeansDriver.run(configuration, inputPath, outputPath);

            // The output is copied into a list eagerly, so it is fully read before the clean-up below.
            testReducerResults(syntheticData.getFirst().size(), Lists.newArrayList(Iterables.transform(
                    new SequenceFileIterable<IntWritable, CentroidWritable>(outputPath, configuration),
                    new Function<Pair<IntWritable, CentroidWritable>, org.apache.hadoop.mrunit.types.Pair<IntWritable, CentroidWritable>>() {
                        @Override
                        public org.apache.hadoop.mrunit.types.Pair<IntWritable, CentroidWritable> apply(
                                org.apache.mahout.common.Pair<IntWritable, CentroidWritable> input) {
                            return new org.apache.hadoop.mrunit.types.Pair<IntWritable, CentroidWritable>(
                                    input.getFirst(), input.getSecond());
                        }
                    })));
        } finally {
            inputPath.getFileSystem(configuration).delete(inputPath, true);
            outputPath.getFileSystem(configuration).delete(outputPath, true);
        }
    }

    /**
     * Copies the centroids produced by a clusterer into a new list.
     *
     * @param centroids the centroids to copy, in iteration order
     * @return a new mutable list holding the same centroid instances
     */
    private static List<Centroid> copyCentroids(Iterable<? extends Centroid> centroids) {
        List<Centroid> copy = Lists.newArrayList();
        for (Centroid centroid : centroids) {
            copy.add(centroid);
        }
        return copy;
    }

    /**
     * Validates reducer output: keys must be consecutive indices starting at 0, the centroid
     * weights must add up to {@code totalWeight}, and there must be exactly {@link #NUM_CLUSTERS}
     * centroids. Centroids whose weight differs from the perfectly balanced weight are only
     * reported on standard output; they do not fail the test.
     *
     * @param totalWeight the total weight that went into the reducer
     * @param results the (index, centroid) pairs emitted by the reducer
     */
    private static void testReducerResults(int totalWeight,
            List<org.apache.hadoop.mrunit.types.Pair<IntWritable, CentroidWritable>> results) {
        double expectedWeight = (double) totalWeight / NUM_CLUSTERS;
        int numClusters = 0;
        int numUnbalancedClusters = 0;
        // Summed as double and rounded once at the end instead of truncating on every addition.
        double totalReducerWeight = 0.0;
        for (org.apache.hadoop.mrunit.types.Pair<IntWritable, CentroidWritable> result : results) {
            Centroid centroid = result.getSecond().getCentroid();
            if (centroid.getWeight() != expectedWeight) {
                System.out.printf("Unbalanced weight %f in centroid %d\n", centroid.getWeight(),
                        centroid.getIndex());
                ++numUnbalancedClusters;
            }
            assertEquals("Final centroid index is invalid", numClusters, result.getFirst().get());
            totalReducerWeight += centroid.getWeight();
            ++numClusters;
        }
        System.out.printf("%d clusters are unbalanced\n", numUnbalancedClusters);
        assertEquals("Invalid total weight", totalWeight, Math.round(totalReducerWeight));
        assertEquals("Invalid number of clusters", NUM_CLUSTERS, numClusters);
    }

}