com.cloudera.science.ml.kmeans.parallel.CentersIndex.java Source code

Java tutorial

Introduction

Here is the source code for com.cloudera.science.ml.kmeans.parallel.CentersIndex.java

Source

/**
 * Copyright (c) 2013, Cloudera, Inc. All Rights Reserved.
 *
 * Cloudera, Inc. licenses this file to you under the Apache License,
 * Version 2.0 (the "License"). You may not use this file except in
 * compliance with the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
 * CONDITIONS OF ANY KIND, either express or implied. See the License for
 * the specific language governing permissions and limitations under the
 * License.
 */
package com.cloudera.science.ml.kmeans.parallel;

import com.cloudera.science.ml.core.vectors.Centers;
import com.cloudera.science.ml.core.vectors.Vectors;
import com.cloudera.science.ml.core.vectors.Weighted;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;

import java.io.Serializable;
import java.util.BitSet;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.SortedSet;

import org.apache.mahout.math.Vector;

/**
 * An internal data structure that manages the locations of the current centers during
 * k-means|| processing.
 *
 * <p>The index holds a fixed number of clusterings. For every clustering it stores each added
 * point as a {@code double[]} together with its squared length, so that the squared Euclidean
 * distance to a query vector can be computed as {@code |v|^2 + |p|^2 - 2 * (v . p)}.
 *
 * <p>Two lookup modes are offered by {@link #getDistances(Vector, boolean)}:
 * <ul>
 *   <li><b>exact</b>: every point of every clustering is compared against the query;</li>
 *   <li><b>approximate</b>: every point is reduced to a bit signature holding the signs of
 *       {@code projectionBits} random Gaussian projections; only the {@code projectionSamples}
 *       points whose signatures have the smallest Hamming distance to the query's signature
 *       are compared exactly.</li>
 * </ul>
 *
 * <p>Signatures are rebuilt lazily: {@link #add(Vector, int)} only marks the index as stale and
 * the next approximate lookup rebuilds it. Instances perform no synchronization.
 */
class CentersIndex implements Serializable {
    /** Default number of random projections (signature bits) per point. */
    private static final int DEFAULT_PROJECTION_BITS = 128;
    /** Default number of candidates examined exactly per clustering in approximate mode. */
    private static final int DEFAULT_PROJECTION_SAMPLES = 10;
    /** Default seed of the generator that draws the projection matrix. */
    private static final long DEFAULT_SEED = 1729L;

    /** Number of points added so far to each clustering. */
    private final int[] pointsPerCenter;
    /** Per clustering, the bit signature of each point; rebuilt by {@link #buildIndices()}. */
    private final List<List<BitSet>> indices;
    /** Per clustering, the coordinates of each point in insertion order. */
    private final List<List<double[]>> points;
    /** Per clustering, the squared length of each point, parallel to {@link #points}. */
    private final List<List<Double>> lengthSquared;
    /** Number of coordinates covered by each projection direction. */
    private final int dimensions;
    private final int projectionBits;
    private final int projectionSamples;
    private final long seed;

    /**
     * Projection matrix of {@code projectionBits} rows by {@code dimensions} columns stored
     * row-major; created on the first call to {@link #buildIndices()}.
     */
    private double[] projection;
    /** True when points were added since the signatures were last rebuilt. */
    private boolean updated;

    /**
     * Result of a distance lookup; both arrays have one entry per clustering.
     */
    public static class Distances {
        /**
         * Smallest squared distance found in each clustering, or
         * {@link Double#POSITIVE_INFINITY} when no point of that clustering was examined.
         */
        public double[] clusterDistances;
        /**
         * Index, within its clustering, of the point that achieved the corresponding entry of
         * {@link #clusterDistances}; stays 0 when no point was examined.
         */
        public int[] closestPoints;

        public Distances(double[] clusterDistances, int[] closestPoints) {
            this.clusterDistances = clusterDistances;
            this.closestPoints = closestPoints;
        }
    }

    /**
     * Creates an empty index with the default projection settings.
     *
     * @param numClusterings number of clusterings to manage
     * @param dimensions number of coordinates per point
     */
    public CentersIndex(int numClusterings, int dimensions) {
        this(numClusterings, dimensions, DEFAULT_PROJECTION_BITS, DEFAULT_PROJECTION_SAMPLES,
                DEFAULT_SEED);
    }

    /**
     * Creates an empty index.
     *
     * @param numClusterings number of clusterings to manage
     * @param dimensions number of coordinates per point
     * @param projectionBits number of random projections, i.e. bits per signature
     * @param projectionSamples number of candidates per clustering that are compared exactly
     *        during an approximate lookup
     * @param seed seed of the generator used to draw the projection matrix
     */
    public CentersIndex(int numClusterings, int dimensions, int projectionBits, int projectionSamples, long seed) {
        this.pointsPerCenter = new int[numClusterings];
        this.indices = Lists.newArrayList();
        this.points = Lists.newArrayList();
        this.lengthSquared = Lists.newArrayList();
        for (int i = 0; i < numClusterings; i++) {
            points.add(Lists.<double[]>newArrayList());
            lengthSquared.add(Lists.<Double>newArrayList());
        }
        this.dimensions = dimensions;
        this.projectionBits = projectionBits;
        this.projectionSamples = projectionSamples;
        this.seed = seed;
    }

    /**
     * Creates an index pre-populated from the given clusterings, using the default projection
     * settings.
     *
     * @param centers one entry per clustering; see
     *        {@link #CentersIndex(List, int, int, long)} for the requirements
     */
    public CentersIndex(List<Centers> centers) {
        this(centers, DEFAULT_PROJECTION_BITS, DEFAULT_PROJECTION_SAMPLES, DEFAULT_SEED);
    }

    /**
     * Creates an index pre-populated from the given clusterings.
     *
     * <p>The dimensionality is taken from the first vector of the first clustering, so the list
     * and its first entry must both be non-empty.
     *
     * @param centers one entry per clustering; every vector of entry {@code i} is added to
     *        clustering {@code i}
     * @param projectionBits number of random projections, i.e. bits per signature
     * @param projectionSamples number of candidates per clustering that are compared exactly
     *        during an approximate lookup
     * @param seed seed of the generator used to draw the projection matrix
     */
    public CentersIndex(List<Centers> centers, int projectionBits, int projectionSamples, long seed) {
        this(centers.size(), centers.get(0).get(0).size(), projectionBits, projectionSamples, seed);
        for (int centerId = 0; centerId < centers.size(); centerId++) {
            for (Vector v : centers.get(centerId)) {
                add(v, centerId);
            }
        }
    }

    /**
     * @return the number of clusterings managed by this index
     */
    public int getNumCenters() {
        return pointsPerCenter.length;
    }

    /**
     * @return the number of points added to each clustering; this is the internal array, not a
     *         copy, so later calls to {@link #add(Vector, int)} are visible through it
     */
    public int[] getPointsPerCluster() {
        return pointsPerCenter;
    }

    /**
     * Draws the projection matrix on first use and recomputes the signature of every stored
     * point, then clears the stale flag.
     */
    private void buildIndices() {
        if (projection == null) {
            Random r = new Random(seed);
            this.projection = new double[dimensions * projectionBits];
            for (int i = 0; i < projection.length; i++) {
                projection[i] = r.nextGaussian();
            }
        }
        indices.clear();
        for (int i = 0; i < points.size(); i++) {
            List<double[]> px = points.get(i);
            List<BitSet> indx = Lists.newArrayList();
            for (int j = 0; j < px.size(); j++) {
                indx.add(index(Vectors.of(px.get(j))));
            }
            indices.add(indx);
        }
        updated = false;
    }

    /**
     * Appends a point to a clustering and marks the signatures as stale.
     *
     * @param vec the point to add
     * @param centerId index of the clustering that receives the point
     */
    public void add(Vector vec, int centerId) {
        points.get(centerId).add(Vectors.toArray(vec));
        lengthSquared.get(centerId).add(vec.getLengthSquared());
        pointsPerCenter[centerId]++;
        updated = true;
    }

    /**
     * Computes the bit signature of a vector: bit {@code j} is set when the dot product of the
     * vector with row {@code j} of the projection matrix is strictly positive.
     *
     * <p>Requires {@link #projection} to have been initialized by {@link #buildIndices()}.
     */
    private BitSet index(Vector vec) {
        double[] prod = new double[projectionBits];
        if (vec.isDense()) {
            for (int i = 0; i < vec.size(); i++) {
                double v = vec.getQuick(i);
                if (v != 0.0) {
                    for (int j = 0; j < projectionBits; j++) {
                        prod[j] += v * projection[i + j * dimensions];
                    }
                }
            }
        } else {
            Iterator<Vector.Element> iter = vec.iterateNonZero();
            while (iter.hasNext()) {
                Vector.Element e = iter.next();
                double v = e.get();
                int column = e.index();
                for (int j = 0; j < projectionBits; j++) {
                    // Accumulate, exactly like the dense branch; a plain assignment here would
                    // keep only the contribution of the last non-zero element.
                    prod[j] += v * projection[column + j * dimensions];
                }
            }
        }
        BitSet bitset = new BitSet(projectionBits);
        for (int i = 0; i < projectionBits; i++) {
            if (prod[i] > 0.0) {
                bitset.set(i);
            }
        }
        return bitset;
    }

    /**
     * Finds, for every clustering, the stored point closest to {@code vec} by squared Euclidean
     * distance.
     *
     * @param vec the query vector
     * @param approx when true, only the candidates selected through the bit signatures are
     *        compared exactly; when false, every stored point is compared
     * @return the smallest squared distance and the index of the matching point for each
     *         clustering
     */
    public Distances getDistances(Vector vec, boolean approx) {
        int numCenters = pointsPerCenter.length;
        int[] closestPoints = new int[numCenters];
        double[] distances = new double[numCenters];
        // Loop-invariant: the query's squared length is the same for every comparison.
        double vecLenSq = vec.getLengthSquared();

        if (approx) {
            // Also build when nothing was ever added: otherwise the projection matrix is still
            // null and the signature lists do not exist yet.
            if (updated || projection == null) {
                buildIndices();
            }

            BitSet q = index(vec);
            for (int i = 0; i < numCenters; i++) {
                List<double[]> p = points.get(i);
                List<Double> lsq = lengthSquared.get(i);
                distances[i] = Double.POSITIVE_INFINITY;
                for (Idx idx : nearestBySignature(q, indices.get(i))) {
                    double lenSq = lsq.get(idx.index);
                    double d = vecLenSq + lenSq - 2 * dot(vec, p.get(idx.index));
                    if (d < distances[i]) {
                        distances[i] = d;
                        closestPoints[i] = idx.index;
                    }
                }
            }
        } else { // More expensive exact computation
            for (int i = 0; i < numCenters; i++) {
                distances[i] = Double.POSITIVE_INFINITY;
                List<double[]> px = points.get(i);
                List<Double> lsq = lengthSquared.get(i);
                for (int j = 0; j < px.size(); j++) {
                    double lenSq = lsq.get(j);
                    double d = vecLenSq + lenSq - 2 * dot(vec, px.get(j));
                    if (d < distances[i]) {
                        distances[i] = d;
                        closestPoints[i] = j;
                    }
                }
            }
        }

        return new Distances(distances, closestPoints);
    }

    /**
     * Selects up to {@code projectionSamples} points of one clustering whose signatures have the
     * smallest Hamming distance to the query signature.
     *
     * @param q signature of the query vector
     * @param signatures signatures of the clustering's points, in point order
     * @return the selected candidates ordered by Hamming distance, then by point index
     */
    private SortedSet<Idx> nearestBySignature(BitSet q, List<BitSet> signatures) {
        SortedSet<Idx> lookup = Sets.newTreeSet();
        for (int j = 0; j < signatures.size(); j++) {
            Idx idx = new Idx(hammingDistance(q, signatures.get(j)), j);
            if (lookup.size() < projectionSamples) {
                lookup.add(idx);
            } else if (idx.compareTo(lookup.last()) < 0) {
                // Idx ordering is total over (distance, index), so this add always succeeds and
                // evicting the current worst keeps the set at exactly projectionSamples entries.
                lookup.add(idx);
                lookup.remove(lookup.last());
            }
        }
        return lookup;
    }

    /**
     * A candidate point: its Hamming distance to the query signature and its index within the
     * clustering. Ordered by distance first and index second.
     */
    static class Idx implements Comparable<Idx> {
        int distance;
        int index;

        public Idx(int distance, int index) {
            this.distance = distance;
            this.index = index;
        }

        /**
         * Compares by Hamming distance, breaking ties by point index so that distinct points
         * with equal distance are not collapsed into one entry by a sorted set. Uses explicit
         * comparisons rather than subtraction to stay overflow-safe.
         */
        @Override
        public int compareTo(Idx idx) {
            if (distance != idx.distance) {
                return distance < idx.distance ? -1 : 1;
            }
            if (index != idx.index) {
                return index < idx.index ? -1 : 1;
            }
            return 0;
        }
    }

    /**
     * @return the number of bit positions in which the two signatures differ; neither argument
     *         is modified
     */
    private static int hammingDistance(BitSet q, BitSet idx) {
        BitSet x = new BitSet(q.size());
        x.or(q);
        x.xor(idx);
        return x.cardinality();
    }

    /**
     * Dot product of a vector with a stored point. For a dense vector the first
     * {@code p.length} coordinates are used; otherwise only the vector's non-zero elements are
     * visited.
     */
    private static double dot(Vector vec, double[] p) {
        double dot = 0;
        if (vec.isDense()) {
            for (int i = 0; i < p.length; i++) {
                dot += vec.getQuick(i) * p[i];
            }
        } else {
            Iterator<Vector.Element> iter = vec.iterateNonZero();
            while (iter.hasNext()) {
                Vector.Element e = iter.next();
                dot += e.get() * p[e.index()];
            }
        }
        return dot;
    }

    /**
     * Pairs stored points with externally supplied weights.
     *
     * @param pointCounts for each clustering {@code i}, the weight of each of its points;
     *        entry {@code j} of list {@code i} is attached to point {@code j} of clustering
     *        {@code i}, so each inner list must not be longer than the number of stored points
     * @return for each entry of {@code pointCounts}, the weighted points in the same order
     */
    public List<List<Weighted<Vector>>> getWeightedVectors(List<List<Long>> pointCounts) {
        List<List<Weighted<Vector>>> ret = Lists.newArrayList();
        for (int i = 0; i < pointCounts.size(); i++) {
            List<Long> counts = pointCounts.get(i);
            List<double[]> p = points.get(i);
            List<Weighted<Vector>> weighted = Lists.newArrayList();
            for (int j = 0; j < counts.size(); j++) {
                weighted.add(new Weighted<Vector>(Vectors.of(p.get(j)), counts.get(j)));
            }
            ret.add(weighted);
        }
        return ret;
    }

}