com.cloudera.science.ml.core.vectors.Centers.java Source code

Java tutorial

Introduction

Here is the source code for com.cloudera.science.ml.core.vectors.Centers.java

Source

/**
 * Copyright (c) 2012, Cloudera, Inc. All Rights Reserved.
 *
 * Cloudera, Inc. licenses this file to you under the Apache License,
 * Version 2.0 (the "License"). You may not use this file except in
 * compliance with the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
 * CONDITIONS OF ANY KIND, either express or implied. See the License for
 * the specific language governing permissions and limitations under the
 * License.
 */
package com.cloudera.science.ml.core.vectors;

import java.util.AbstractList;
import java.util.Arrays;
import java.util.List;

import org.apache.mahout.math.Vector;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import com.google.common.collect.Sets;

/**
 * Represents a collection of {@code Vector} instances that act as the centers of
 * a set of clusters, as in a k-means model.
 */
public class Centers extends AbstractList<Vector> {
    // The distinct cluster centers, in first-seen order. The list is an
    // immutable copy, so it never changes after construction.
    private final List<Vector> centers;

    /**
     * Create a new instance from the given points. Any duplicate
     * points in the arg list will be removed, keeping the first occurrence.
     *
     * <p>No emptiness check is performed: calling this constructor without
     * any points produces an empty instance.
     *
     * @param points The points
     */
    public Centers(Vector... points) {
        this(Arrays.asList(points));
    }

    /**
     * Create a new instance from the given points. Any duplicate
     * points in the {@code Iterable} instance will be removed, keeping the
     * first occurrence, and the remaining points are copied into an
     * immutable list.
     *
     * <p>No emptiness check is performed: an empty {@code Iterable} produces
     * an empty instance.
     *
     * @param points The points
     */
    public Centers(Iterable<Vector> points) {
        // The linked hash set drops duplicates while preserving encounter order.
        this.centers = ImmutableList.copyOf(Sets.newLinkedHashSet(points));
    }

    /**
     * Returns the number of points in this instance.
     */
    @Override
    public int size() {
        return centers.size();
    }

    /**
     * Returns the {@code Vector} at the given index.
     */
    @Override
    public Vector get(int index) {
        return centers.get(index);
    }

    /**
     * Construct a new {@code Centers} object made up of the points contained
     * in this instance followed by the given {@code Vector}. This instance is
     * not modified.
     *
     * @param point The new point
     * @return A new {@code Centers} instance
     */
    public Centers extendWith(Vector point) {
        return new Centers(Iterables.concat(centers, ImmutableList.of(point)));
    }

    /**
     * Construct a new {@code Centers} object made up of the points contained
     * in this instance followed by the given points. This instance is not
     * modified.
     *
     * @param points The new points
     * @return A new {@code Centers} instance
     */
    public Centers extendWith(Iterable<Vector> points) {
        return new Centers(Iterables.concat(centers, points));
    }

    /**
     * Returns the minimum squared Euclidean distance between the given
     * {@code Vector} and a point contained in this instance.
     *
     * @param point The point
     * @return The minimum squared Euclidean distance from the point, or
     *     {@code Double.POSITIVE_INFINITY} if this instance is empty
     */
    public double getDistanceSquared(Vector point) {
        double min = Double.POSITIVE_INFINITY;
        for (Vector c : centers) {
            min = Math.min(min, c.getDistanceSquared(point));
        }
        return min;
    }

    /**
     * Returns the index of the {@code Vector} within this instance that is
     * closest to the given {@code Vector}. When several centers are equally
     * close, the lowest index is returned.
     *
     * @param point The point
     * @return The index of the closest {@code Vector} to the given point, or
     *     {@code -1} if no center has a distance below positive infinity
     *     (in particular, when this instance is empty)
     */
    public int indexOfClosest(Vector point) {
        int index = -1;
        double min = Double.POSITIVE_INFINITY;
        for (int i = 0; i < centers.size(); i++) {
            double d = centers.get(i).getDistanceSquared(point);
            // Strict comparison so that the first of several tied centers wins.
            if (d < min) {
                min = d;
                index = i;
            }
        }
        return index;
    }

    /**
     * Calculate the sum of the element-wise squared distances between this
     * instance and the given {@code Centers}.
     *
     * @param other The other {@code Centers} instance
     * @return The sum of the squared distances on a point-by-point basis
     * @throws IllegalArgumentException if the given {@code Centers} object is not
     *     the same size as this one
     */
    public double getSumOfSquaredDistances(Centers other) {
        // Use the lazy message template so that no string is formatted when
        // the sizes match, which is the common case.
        Preconditions.checkArgument(size() == other.size(),
                "Expected %s but found %s", size(), other.size());
        double sum = 0.0;
        for (int i = 0; i < centers.size(); i++) {
            sum += centers.get(i).getDistanceSquared(other.centers.get(i));
        }
        return sum;
    }

    /**
     * Compares this instance with another object. Two {@code Centers} are
     * equal when each contains every point of the other, regardless of the
     * order of the points. Note that this differs from the order-sensitive
     * comparison specified by {@link List#equals(Object)}, and that objects
     * that are not {@code Centers} are never considered equal.
     *
     * @param other The object to compare against
     * @return true if {@code other} is a {@code Centers} holding the same points
     */
    @Override
    public boolean equals(Object other) {
        if (this == other) {
            return true;
        }
        if (!(other instanceof Centers)) {
            return false;
        }
        Centers c = (Centers) other;
        return centers.containsAll(c.centers) && c.centers.containsAll(centers);
    }

    /**
     * Returns the sum of the hash codes of the contained points. The sum does
     * not depend on the order of the points, matching {@link #equals(Object)}.
     */
    @Override
    public int hashCode() {
        int hc = 0;
        for (Vector center : centers) {
            hc += center.hashCode();
        }
        return hc;
    }

    /**
     * Returns the string form of the underlying list of points.
     */
    @Override
    public String toString() {
        return centers.toString();
    }
}