org.apache.hama.ml.kmeans.KMeansBSP.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.hama.ml.kmeans.KMeansBSP.java

Source

/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.hama.ml.kmeans;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Random;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.SequenceFile.CompressionType;
import org.apache.hadoop.io.SequenceFile.Writer;
import org.apache.hama.HamaConfiguration;
import org.apache.hama.bsp.BSP;
import org.apache.hama.bsp.BSPJob;
import org.apache.hama.bsp.BSPPeer;
import org.apache.hama.bsp.sync.SyncException;
import org.apache.hama.commons.io.VectorWritable;
import org.apache.hama.commons.math.DenseDoubleVector;
import org.apache.hama.commons.math.DoubleVector;
import org.apache.hama.commons.math.NamedDoubleVector;
import org.apache.hama.ml.distance.DistanceMeasurer;
import org.apache.hama.ml.distance.EuclidianDistance;
import org.apache.hama.util.ReflectionUtils;

import com.google.common.base.Preconditions;

/**
 * K-Means in BSP that reads a bunch of vectors from input system and a given
 * centroid path that contains initial centers.
 * 
 */
/**
 * K-Means clustering in BSP. Reads a bunch of vectors from the input path and a
 * given centroid path that contains the initial centers, then iterates
 * assignment/update supersteps until the centers converge (or the configured
 * maximum number of iterations is reached) and writes the final assignments.
 */
public final class KMeansBSP extends BSP<VectorWritable, NullWritable, IntWritable, VectorWritable, CenterMessage> {

    /** Configuration key: path the converged centers are written to (first task only). */
    public static final String CENTER_OUT_PATH = "center.out.path";
    /** Configuration key: maximum number of iterations; values <= 0 disable the bound. */
    public static final String MAX_ITERATIONS_KEY = "k.means.max.iterations";
    /** Configuration key: if true, cache the input vectors on the heap instead of re-reading from disk. */
    public static final String CACHING_ENABLED_KEY = "k.means.caching.enabled";
    /** Configuration key: fully qualified class name of the {@link DistanceMeasurer} to use. */
    public static final String DISTANCE_MEASURE_CLASS = "distance.measure.class";
    /** Configuration key: path of the sequence file that contains the initial centers. */
    public static final String CENTER_IN_PATH = "center.in.path";

    private static final Log LOG = LogFactory.getLog(KMeansBSP.class);
    // a task local copy of our cluster centers
    private DoubleVector[] centers;
    // simple cache to speed up computation, because the algorithm is disk based
    private List<DoubleVector> cache;
    // maximum number of iterations to do (<= 0 means no bound)
    private int maxIterations;
    // our distance measurement
    private DistanceMeasurer distanceMeasurer;
    private Configuration conf;

    /**
     * Reads the initial centers from {@link #CENTER_IN_PATH}, instantiates the
     * configured distance measure (defaulting to {@link EuclidianDistance}) and
     * optionally enables the in-memory vector cache.
     *
     * @throws IOException if the centers file cannot be read.
     */
    @Override
    public final void setup(BSPPeer<VectorWritable, NullWritable, IntWritable, VectorWritable, CenterMessage> peer)
            throws IOException, InterruptedException {
        conf = peer.getConfiguration();

        Path centroids = new Path(conf.get(CENTER_IN_PATH));
        FileSystem fs = FileSystem.get(conf);
        final ArrayList<DoubleVector> centerList = new ArrayList<DoubleVector>();
        SequenceFile.Reader reader = null;
        try {
            reader = new SequenceFile.Reader(fs, centroids, conf);
            VectorWritable key = new VectorWritable();
            NullWritable value = NullWritable.get();
            // an IOException here propagates directly; the method declares it anyway,
            // so there is no need to wrap it into a RuntimeException
            while (reader.next(key, value)) {
                centerList.add(key.getVector());
            }
        } finally {
            if (reader != null) {
                reader.close();
            }
        }

        Preconditions.checkArgument(centerList.size() > 0, "Centers file must contain at least a single center!");
        this.centers = centerList.toArray(new DoubleVector[centerList.size()]);

        String distanceClass = conf.get(DISTANCE_MEASURE_CLASS);
        if (distanceClass != null) {
            try {
                distanceMeasurer = ReflectionUtils.newInstance(distanceClass);
            } catch (ClassNotFoundException e) {
                // preserve the cause so the original classloading problem is visible
                throw new RuntimeException("Wrong DistanceMeasurer implementation " + distanceClass + " provided", e);
            }
        } else {
            distanceMeasurer = new EuclidianDistance();
        }

        maxIterations = conf.getInt(MAX_ITERATIONS_KEY, -1);
        // normally we want to rely on OS caching, but if not, we can cache in heap
        if (conf.getBoolean(CACHING_ENABLED_KEY, false)) {
            cache = new ArrayList<DoubleVector>();
        }
    }

    /**
     * Main BSP loop: assign vectors to their nearest centers, synchronize,
     * aggregate the updates, and repeat until no center moved anymore or the
     * iteration bound is hit. Finally writes the assignment of each vector.
     */
    @Override
    public final void bsp(BSPPeer<VectorWritable, NullWritable, IntWritable, VectorWritable, CenterMessage> peer)
            throws IOException, InterruptedException, SyncException {
        long converged;
        while (true) {
            assignCenters(peer);
            peer.sync();
            converged = updateCenters(peer);
            peer.reopenInput();
            if (converged == 0)
                break;
            if (maxIterations > 0 && maxIterations < peer.getSuperstepCount())
                break;
        }
        LOG.info("Finished! Writing the assignments...");
        recalculateAssignmentsAndWrite(peer);
        LOG.info("Done.");
    }

    /**
     * Update step: sums the partial center contributions received from all
     * peers, averages them, and replaces every center that actually moved.
     *
     * @return the number of centers that changed in this superstep; 0 means
     *         the algorithm has converged.
     */
    private long updateCenters(
            BSPPeer<VectorWritable, NullWritable, IntWritable, VectorWritable, CenterMessage> peer)
            throws IOException {
        // this is the update step
        DoubleVector[] msgCenters = new DoubleVector[centers.length];
        int[] incrementSum = new int[centers.length];
        CenterMessage msg;
        // basically just summing incoming vectors
        while ((msg = peer.getCurrentMessage()) != null) {
            DoubleVector oldCenter = msgCenters[msg.getCenterIndex()];
            DoubleVector newCenter = msg.getData();
            incrementSum[msg.getCenterIndex()] += msg.getIncrementCounter();
            if (oldCenter == null) {
                msgCenters[msg.getCenterIndex()] = newCenter;
            } else {
                msgCenters[msg.getCenterIndex()] = oldCenter.addUnsafe(newCenter);
            }
        }
        // divide by how often we globally summed vectors
        for (int i = 0; i < msgCenters.length; i++) {
            // and only if we really have an update for c
            if (msgCenters[i] != null) {
                msgCenters[i] = msgCenters[i].divide(incrementSum[i]);
            }
        }
        // finally check for convergence by the absolute difference; note that the
        // counter counts centers that still MOVED, so 0 signals convergence
        long convergedCounter = 0L;
        for (int i = 0; i < msgCenters.length; i++) {
            final DoubleVector oldCenter = centers[i];
            if (msgCenters[i] != null) {
                double calculateError = oldCenter.subtractUnsafe(msgCenters[i]).abs().sum();
                if (calculateError > 0.0d) {
                    centers[i] = msgCenters[i];
                    convergedCounter++;
                }
            }
        }
        return convergedCounter;
    }

    /**
     * Assignment step: assigns every local input vector to its nearest center,
     * locally sums the vectors per center and broadcasts the partial sums
     * (plus their counts) to every peer.
     */
    private void assignCenters(
            BSPPeer<VectorWritable, NullWritable, IntWritable, VectorWritable, CenterMessage> peer)
            throws IOException {
        // each task has all the centers, if a center has been updated it
        // needs to be broadcasted.
        final DoubleVector[] newCenterArray = new DoubleVector[centers.length];
        final int[] summationCount = new int[centers.length];

        // if our cache is not enabled, iterate over the disk items
        if (cache == null) {
            // we have an assignment step
            final NullWritable value = NullWritable.get();
            final VectorWritable key = new VectorWritable();
            while (peer.readNext(key, value)) {
                // deep copy, because the writable is reused on every readNext call
                assignCentersInternal(newCenterArray, summationCount, key.getVector().deepCopy());
            }
        } else {
            // if our cache is enabled but empty, we have to read it from disk first
            if (cache.isEmpty()) {
                final NullWritable value = NullWritable.get();
                final VectorWritable key = new VectorWritable();
                while (peer.readNext(key, value)) {
                    DoubleVector deepCopy = key.getVector().deepCopy();
                    cache.add(deepCopy);
                    // but do the assignment directly
                    assignCentersInternal(newCenterArray, summationCount, deepCopy);
                }
            } else {
                // now we can iterate in memory and check against the centers
                for (DoubleVector v : cache) {
                    assignCentersInternal(newCenterArray, summationCount, v);
                }
            }
        }

        // now send messages about the local updates to each other peer
        for (int i = 0; i < newCenterArray.length; i++) {
            if (newCenterArray[i] != null) {
                for (String peerName : peer.getAllPeerNames()) {
                    peer.send(peerName, new CenterMessage(i, summationCount[i], newCenterArray[i]));
                }
            }
        }
    }

    /**
     * Adds the given vector to the running sum of its nearest center and
     * increments that center's assignment count.
     */
    private void assignCentersInternal(final DoubleVector[] newCenterArray, final int[] summationCount,
            final DoubleVector key) {
        final int lowestDistantCenter = getNearestCenter(key);
        final DoubleVector clusterCenter = newCenterArray[lowestDistantCenter];

        if (clusterCenter == null) {
            newCenterArray[lowestDistantCenter] = key;
        } else {
            // add the vector to the center
            newCenterArray[lowestDistantCenter] = newCenterArray[lowestDistantCenter].addUnsafe(key);
        }
        summationCount[lowestDistantCenter]++;
    }

    /**
     * @return the index of the center with the smallest distance to the given
     *         vector according to the configured {@link DistanceMeasurer}.
     */
    private int getNearestCenter(DoubleVector key) {
        int lowestDistantCenter = 0;
        double lowestDistance = Double.MAX_VALUE;

        for (int i = 0; i < centers.length; i++) {
            final double estimatedDistance = distanceMeasurer.measureDistance(centers[i], key);
            // check if we can assign a new center, because we got a lower distance
            if (estimatedDistance < lowestDistance) {
                lowestDistance = estimatedDistance;
                lowestDistantCenter = i;
            }
        }
        return lowestDistantCenter;
    }

    /**
     * Writes the final (centerIndex, vector) assignment for every input vector
     * and, on the first task only, dumps the converged centers to
     * {@link #CENTER_OUT_PATH} (if configured).
     */
    private void recalculateAssignmentsAndWrite(
            BSPPeer<VectorWritable, NullWritable, IntWritable, VectorWritable, CenterMessage> peer)
            throws IOException {
        final NullWritable value = NullWritable.get();
        // also use our cache to speed up the final writes if it exists
        if (cache == null) {
            final VectorWritable key = new VectorWritable();
            IntWritable keyWrite = new IntWritable();
            while (peer.readNext(key, value)) {
                final int lowestDistantCenter = getNearestCenter(key.getVector());
                keyWrite.set(lowestDistantCenter);
                peer.write(keyWrite, key);
            }
        } else {
            IntWritable keyWrite = new IntWritable();
            for (DoubleVector v : cache) {
                final int lowestDistantCenter = getNearestCenter(v);
                keyWrite.set(lowestDistantCenter);
                peer.write(keyWrite, new VectorWritable(v));
            }
        }
        // just on the first task write the centers to filesystem to prevent
        // collisions
        if (peer.getPeerName().equals(peer.getPeerName(0))) {
            String pathString = conf.get(CENTER_OUT_PATH);
            if (pathString != null) {
                final SequenceFile.Writer dataWriter = SequenceFile.createWriter(FileSystem.get(conf), conf,
                        new Path(pathString), VectorWritable.class, NullWritable.class, CompressionType.NONE);
                try {
                    for (DoubleVector center : centers) {
                        dataWriter.append(new VectorWritable(center), value);
                    }
                } finally {
                    // close even if an append fails, so the writer is not leaked
                    dataWriter.close();
                }
            }
        }
    }

    /**
     * Creates a basic job with sequencefiles as in and output.
     *
     * @param cnf     base configuration to wrap into a {@link HamaConfiguration}.
     * @param in      input path containing the vector sequence file(s).
     * @param out     output path for the assignments.
     * @param textOut true to emit text output instead of a sequence file.
     * @return the configured, not yet submitted job.
     */
    public static BSPJob createJob(Configuration cnf, Path in, Path out, boolean textOut) throws IOException {
        HamaConfiguration conf = new HamaConfiguration(cnf);
        BSPJob job = new BSPJob(conf, KMeansBSP.class);
        job.setJobName("KMeans Clustering");
        job.setJarByClass(KMeansBSP.class);
        job.setBspClass(KMeansBSP.class);
        job.setInputPath(in);
        job.setOutputPath(out);
        job.setInputFormat(org.apache.hama.bsp.SequenceFileInputFormat.class);
        if (textOut)
            job.setOutputFormat(org.apache.hama.bsp.TextOutputFormat.class);
        else
            job.setOutputFormat(org.apache.hama.bsp.SequenceFileOutputFormat.class);
        job.setOutputKeyClass(IntWritable.class);
        job.setOutputValueClass(VectorWritable.class);
        return job;
    }

    /**
     * Command line entry point: generates random input/centers and runs the
     * clustering job. See the usage string for the expected arguments.
     */
    public static void main(String[] args) throws IOException, ClassNotFoundException, InterruptedException {

        if (args.length < 6) {
            LOG.info(
                    "USAGE: <INPUT_PATH> <OUTPUT_PATH> <COUNT> <K> <DIMENSION OF VECTORS> <MAXITERATIONS> <optional: num of tasks>");
            return;
        }

        Configuration conf = new Configuration();
        int count = Integer.parseInt(args[2]);
        int k = Integer.parseInt(args[3]);
        int dimension = Integer.parseInt(args[4]);
        int iterations = Integer.parseInt(args[5]);
        conf.setInt(MAX_ITERATIONS_KEY, iterations);

        Path in = new Path(args[0]);
        Path out = new Path(args[1]);
        Path center = new Path(in, "center/cen.seq");
        Path centerOut = new Path(out, "center/center_output.seq");

        conf.set(CENTER_IN_PATH, center.toString());
        conf.set(CENTER_OUT_PATH, centerOut.toString());
        // if you're in local mode, you can increase this to match your core sizes
        conf.set("bsp.local.tasks.maximum", "" + Runtime.getRuntime().availableProcessors());
        // deactivate (set to false) if you want to iterate over disk, else it will
        // cache the input vectors in memory
        conf.setBoolean(CACHING_ENABLED_KEY, true);
        BSPJob job = createJob(conf, in, out, false);

        LOG.info("N: " + count + " k: " + k + " Dimension: " + dimension + " Iterations: " + iterations);

        FileSystem fs = FileSystem.get(conf);
        // prepare the input, like deleting old versions and creating centers
        prepareInput(count, k, dimension, conf, in, center, out, fs);
        if (args.length == 7) {
            job.setNumBspTask(Integer.parseInt(args[6]));
        }

        // just submit the job
        job.waitForCompletion(true);
    }

    /**
     * Reads the cluster centers.
     *
     * @return an index on the key dimension, and a cluster center on the value.
     * @throws IOException if the centers file cannot be read.
     */
    public static HashMap<Integer, DoubleVector> readClusterCenters(Configuration conf, Path out, Path centerPath,
            FileSystem fs) throws IOException {
        HashMap<Integer, DoubleVector> centerMap = new HashMap<Integer, DoubleVector>();
        SequenceFile.Reader centerReader = new SequenceFile.Reader(fs, centerPath, conf);
        try {
            int index = 0;
            VectorWritable center = new VectorWritable();
            while (centerReader.next(center, NullWritable.get())) {
                centerMap.put(index++, center.getVector());
            }
        } finally {
            // close on the exceptional path too, not only on success
            centerReader.close();
        }
        return centerMap;
    }

    /**
     * Reads output. The list of output records can be restricted to maxlines.
     *
     * @param conf     job configuration.
     * @param outPath  output directory containing the part-* files.
     * @param fs       filesystem the output resides on.
     * @param maxlines maximum number of records to return.
     * @return the list of output records.
     * @throws IOException if reading any part file fails.
     */
    public static List<String> readOutput(Configuration conf, Path outPath, FileSystem fs, int maxlines)
            throws IOException {
        List<String> output = new ArrayList<String>();

        FileStatus[] globStatus = fs.globStatus(new Path(outPath + "/part-*"));
        for (FileStatus fts : globStatus) {
            BufferedReader reader = new BufferedReader(new InputStreamReader(fs.open(fts.getPath())));
            try {
                String line = null;
                while ((line = reader.readLine()) != null) {
                    String[] split = line.split("\t");
                    output.add(split[1] + " belongs to cluster " + split[0]);

                    // the early return used to leak the reader; the finally block
                    // now guarantees it is closed on every exit path
                    if (output.size() >= maxlines)
                        return output;
                }
            } finally {
                reader.close();
            }
        }

        return output;
    }

    /**
     * Reads input text files and writes it to a sequencefile.
     *
     * @param k      number of leading vectors used as initial centers.
     * @param conf   job configuration.
     * @param txtIn  tab-separated input text file.
     * @param center path the initial centers are written to.
     * @param out    output path (deleted if it already exists).
     * @param fs     filesystem to operate on.
     * @param hasKey true if first column is required to be the key.
     * @return the path of a sequencefile.
     * @throws IOException if reading or writing fails.
     */
    public static Path prepareInputText(int k, Configuration conf, Path txtIn, Path center, Path out, FileSystem fs,
            boolean hasKey) throws IOException {

        Path in;
        if (fs.isFile(txtIn)) {
            in = new Path(txtIn.getParent(), "textinput/in.seq");
        } else {
            in = new Path(txtIn, "textinput/in.seq");
        }

        if (fs.exists(out))
            fs.delete(out, true);

        if (fs.exists(center))
            fs.delete(center, true);

        if (fs.exists(in))
            fs.delete(in, true);

        final NullWritable value = NullWritable.get();

        BufferedReader br = null;
        Writer centerWriter = null;
        SequenceFile.Writer dataWriter = null;
        try {
            // use the createWriter factory instead of the deprecated Writer
            // constructor, consistent with the rest of this class
            centerWriter = SequenceFile.createWriter(fs, conf, center, VectorWritable.class, NullWritable.class,
                    CompressionType.NONE);
            dataWriter = SequenceFile.createWriter(fs, conf, in, VectorWritable.class, NullWritable.class,
                    CompressionType.NONE);

            int i = 0;

            br = new BufferedReader(new InputStreamReader(fs.open(txtIn)));
            String line;
            while ((line = br.readLine()) != null) {
                String[] split = line.split("\t");
                int columnLength = split.length;
                int indexPos = 0;
                if (hasKey) {
                    // first column is the vector's name, not a coordinate
                    columnLength = columnLength - 1;
                    indexPos++;
                }

                DenseDoubleVector vec = new DenseDoubleVector(columnLength);
                for (int j = 0; j < columnLength; j++) {
                    vec.set(j, Double.parseDouble(split[j + indexPos]));
                }

                VectorWritable vector;
                if (hasKey) {
                    NamedDoubleVector named = new NamedDoubleVector(split[0], vec);
                    vector = new VectorWritable(named);
                } else {
                    vector = new VectorWritable(vec);
                }

                dataWriter.append(vector, value);
                // the first k vectors double as the initial centers
                if (k > i) {
                    centerWriter.append(vector, value);
                }
                i++;
            }
        } finally {
            // close all resources even if parsing or writing failed midway
            if (br != null) {
                br.close();
            }
            if (centerWriter != null) {
                centerWriter.close();
            }
            if (dataWriter != null) {
                dataWriter.close();
            }
        }
        return in;
    }

    /**
     * Create some random vectors as input and assign the first k vectors as
     * initial centers.
     */
    public static void prepareInput(int count, int k, int dimension, Configuration conf, Path in, Path center,
            Path out, FileSystem fs) throws IOException {
        if (fs.exists(out))
            fs.delete(out, true);

        if (fs.exists(center))
            fs.delete(center, true);

        if (fs.exists(in))
            fs.delete(in, true);

        final NullWritable value = NullWritable.get();

        SequenceFile.Writer centerWriter = null;
        SequenceFile.Writer dataWriter = null;
        try {
            centerWriter = SequenceFile.createWriter(fs, conf, center, VectorWritable.class, NullWritable.class,
                    CompressionType.NONE);
            dataWriter = SequenceFile.createWriter(fs, conf, in, VectorWritable.class, NullWritable.class,
                    CompressionType.NONE);

            Random r = new Random();
            for (int i = 0; i < count; i++) {

                double[] arr = new double[dimension];
                for (int d = 0; d < dimension; d++) {
                    arr[d] = r.nextInt(count);
                }
                VectorWritable vector = new VectorWritable(new DenseDoubleVector(arr));
                dataWriter.append(vector, value);
                // the first k generated vectors double as the initial centers
                if (k > i) {
                    centerWriter.append(vector, value);
                }
            }
        } finally {
            // close both writers even if an append failed midway
            if (centerWriter != null) {
                centerWriter.close();
            }
            if (dataWriter != null) {
                dataWriter.close();
            }
        }
    }
}