org.apache.mahout.math.AbstractVector.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.mahout.math.AbstractVector.java

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.math;

import java.util.Iterator;

import com.google.common.base.Preconditions;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.function.DoubleDoubleFunction;
import org.apache.mahout.math.function.DoubleFunction;
import org.apache.mahout.math.function.Functions;

/** Implementations of generic capabilities like sum of elements and dot products */
public abstract class AbstractVector implements Vector, LengthCachingVector {

    private int size;
    protected double lengthSquared = -1.0;

    protected AbstractVector(int size) {
        this.size = size;
    }

    @Override
    public Iterable<Element> all() {
        return new Iterable<Element>() {
            @Override
            public Iterator<Element> iterator() {
                return AbstractVector.this.iterator();
            }
        };
    }

    @Override
    public Iterable<Element> nonZeroes() {
        return new Iterable<Element>() {
            @Override
            public Iterator<Element> iterator() {
                return iterateNonZero();
            }
        };
    }

    /**
     * Iterates over all elements <p>
     * NOTE: Implementations may choose to reuse the Element returned for performance
     * reasons, so if you need a copy of it, you should call {@link #getElement(int)} for the given index
     *
     * @return An {@link Iterator} over all elements
     */
    protected abstract Iterator<Element> iterator();

    /**
     * Iterates over all non-zero elements. <p>
     * NOTE: Implementations may choose to reuse the Element returned for
     * performance reasons, so if you need a copy of it, you should call {@link #getElement(int)} for the given index
     *
     * @return An {@link Iterator} over all non-zero elements
     */
    protected abstract Iterator<Element> iterateNonZero();

    /**
     * Aggregates a vector by applying a mapping function fm(x) to every component and aggregating
     * the results with an aggregating function fa(x, y).
     *
     * @param aggregator used to combine the current value of the aggregation with the result of map.apply(nextValue)
     * @param map a function to apply to each element of the vector in turn before passing to the aggregator
     * @return the result of the aggregation
     */
    @Override
    public double aggregate(DoubleDoubleFunction aggregator, DoubleFunction map) {
        if (size == 0) {
            return 0;
        }

        // If the aggregator is associative and commutative and it's likeLeftMult (fa(0, y) = 0), and there is
        // at least one zero in the vector (size > getNumNondefaultElements) and applying fm(0) = 0, the result
        // gets cascaded through the aggregation and the final result will be 0.
        if (aggregator.isAssociativeAndCommutative() && aggregator.isLikeLeftMult()
                && size > getNumNondefaultElements() && !map.isDensifying()) {
            return 0;
        }

        double result;
        if (isSequentialAccess() || aggregator.isAssociativeAndCommutative()) {
            Iterator<Element> iterator;
            // If fm(0) = 0 and fa(x, 0) = x, we can skip all zero values.
            if (!map.isDensifying() && aggregator.isLikeRightPlus()) {
                iterator = iterateNonZero();
                if (!iterator.hasNext()) {
                    return 0;
                }
            } else {
                iterator = iterator();
            }
            Element element = iterator.next();
            result = map.apply(element.get());
            while (iterator.hasNext()) {
                element = iterator.next();
                result = aggregator.apply(result, map.apply(element.get()));
            }
        } else {
            result = map.apply(getQuick(0));
            for (int i = 1; i < size; i++) {
                result = aggregator.apply(result, map.apply(getQuick(i)));
            }
        }

        return result;
    }

    @Override
    public double aggregate(Vector other, DoubleDoubleFunction aggregator, DoubleDoubleFunction combiner) {
        Preconditions.checkArgument(size == other.size(), "Vector sizes differ");
        if (size == 0) {
            return 0;
        }
        return VectorBinaryAggregate.aggregateBest(this, other, aggregator, combiner);
    }

    /**
     * Subclasses must override to return an appropriately sparse or dense result
     *
     * @param rows    the row cardinality
     * @param columns the column cardinality
     * @return a Matrix
     */
    protected abstract Matrix matrixLike(int rows, int columns);

    @Override
    public Vector viewPart(int offset, int length) {
        if (offset < 0) {
            throw new IndexException(offset, size);
        }
        if (offset + length > size) {
            throw new IndexException(offset + length, size);
        }
        return new VectorView(this, offset, length);
    }

    @SuppressWarnings("CloneDoesntDeclareCloneNotSupportedException")
    @Override
    public Vector clone() {
        try {
            AbstractVector r = (AbstractVector) super.clone();
            r.size = size;
            r.lengthSquared = lengthSquared;
            return r;
        } catch (CloneNotSupportedException e) {
            throw new IllegalStateException("Can't happen");
        }
    }

    @Override
    public Vector divide(double x) {
        if (x == 1.0) {
            return clone();
        }
        Vector result = createOptimizedCopy();
        for (Element element : result.nonZeroes()) {
            element.set(element.get() / x);
        }
        return result;
    }

    @Override
    public double dot(Vector x) {
        if (size != x.size()) {
            throw new CardinalityException(size, x.size());
        }
        if (this == x) {
            return getLengthSquared();
        }
        return aggregate(x, Functions.PLUS, Functions.MULT);
    }

    protected double dotSelf() {
        return aggregate(Functions.PLUS, Functions.pow(2));
    }

    @Override
    public double get(int index) {
        if (index < 0 || index >= size) {
            throw new IndexException(index, size);
        }
        return getQuick(index);
    }

    @Override
    public Element getElement(int index) {
        return new LocalElement(index);
    }

    @Override
    public Vector normalize() {
        return divide(Math.sqrt(getLengthSquared()));
    }

    @Override
    public Vector normalize(double power) {
        return divide(norm(power));
    }

    @Override
    public Vector logNormalize() {
        return logNormalize(2.0, Math.sqrt(getLengthSquared()));
    }

    @Override
    public Vector logNormalize(double power) {
        return logNormalize(power, norm(power));
    }

    public Vector logNormalize(double power, double normLength) {
        // we can special case certain powers
        if (Double.isInfinite(power) || power <= 1.0) {
            throw new IllegalArgumentException("Power must be > 1 and < infinity");
        } else {
            double denominator = normLength * Math.log(power);
            Vector result = createOptimizedCopy();
            for (Element element : result.nonZeroes()) {
                element.set(Math.log1p(element.get()) / denominator);
            }
            return result;
        }
    }

    @Override
    public double norm(double power) {
        if (power < 0.0) {
            throw new IllegalArgumentException("Power must be >= 0");
        }
        // We can special case certain powers.
        if (Double.isInfinite(power)) {
            return aggregate(Functions.MAX, Functions.ABS);
        } else if (power == 2.0) {
            return Math.sqrt(getLengthSquared());
        } else if (power == 1.0) {
            double result = 0.0;
            Iterator<Element> iterator = this.iterateNonZero();
            while (iterator.hasNext()) {
                result += Math.abs(iterator.next().get());
            }
            return result;
            // TODO: this should ideally be used, but it's slower.
            // return aggregate(Functions.PLUS, Functions.ABS);
        } else if (power == 0.0) {
            return getNumNonZeroElements();
        } else {
            return Math.pow(aggregate(Functions.PLUS, Functions.pow(power)), 1.0 / power);
        }
    }

    @Override
    public double getLengthSquared() {
        if (lengthSquared >= 0.0) {
            return lengthSquared;
        }
        return lengthSquared = dotSelf();
    }

    @Override
    public void invalidateCachedLength() {
        lengthSquared = -1;
    }

    @Override
    public double getDistanceSquared(Vector that) {
        if (size != that.size()) {
            throw new CardinalityException(size, that.size());
        }
        double thisLength = getLengthSquared();
        double thatLength = that.getLengthSquared();
        double dot = dot(that);
        double distanceEstimate = thisLength + thatLength - 2 * dot;
        if (distanceEstimate > 1.0e-3 * (thisLength + thatLength)) {
            // The vectors are far enough from each other that the formula is accurate.
            return Math.max(distanceEstimate, 0);
        } else {
            return aggregate(that, Functions.PLUS, Functions.MINUS_SQUARED);
        }
    }

    @Override
    public double maxValue() {
        if (size == 0) {
            return Double.NEGATIVE_INFINITY;
        }
        return aggregate(Functions.MAX, Functions.IDENTITY);
    }

    @Override
    public int maxValueIndex() {
        int result = -1;
        double max = Double.NEGATIVE_INFINITY;
        int nonZeroElements = 0;
        Iterator<Element> iter = this.iterateNonZero();
        while (iter.hasNext()) {
            nonZeroElements++;
            Element element = iter.next();
            double tmp = element.get();
            if (tmp > max) {
                max = tmp;
                result = element.index();
            }
        }
        // if the maxElement is negative and the vector is sparse then any
        // unfilled element(0.0) could be the maxValue hence we need to
        // find one of those elements
        if (nonZeroElements < size && max < 0.0) {
            for (Element element : all()) {
                if (element.get() == 0.0) {
                    return element.index();
                }
            }
        }
        return result;
    }

    @Override
    public double minValue() {
        if (size == 0) {
            return Double.POSITIVE_INFINITY;
        }
        return aggregate(Functions.MIN, Functions.IDENTITY);
    }

    @Override
    public int minValueIndex() {
        int result = -1;
        double min = Double.POSITIVE_INFINITY;
        int nonZeroElements = 0;
        Iterator<Element> iter = this.iterateNonZero();
        while (iter.hasNext()) {
            nonZeroElements++;
            Element element = iter.next();
            double tmp = element.get();
            if (tmp < min) {
                min = tmp;
                result = element.index();
            }
        }
        // if the maxElement is positive and the vector is sparse then any
        // unfilled element(0.0) could be the maxValue hence we need to
        // find one of those elements
        if (nonZeroElements < size && min > 0.0) {
            for (Element element : all()) {
                if (element.get() == 0.0) {
                    return element.index();
                }
            }
        }
        return result;
    }

    @Override
    public Vector plus(double x) {
        Vector result = createOptimizedCopy();
        if (x == 0.0) {
            return result;
        }
        return result.assign(Functions.plus(x));
    }

    @Override
    public Vector plus(Vector that) {
        if (size != that.size()) {
            throw new CardinalityException(size, that.size());
        }
        return createOptimizedCopy().assign(that, Functions.PLUS);
    }

    @Override
    public Vector minus(Vector that) {
        if (size != that.size()) {
            throw new CardinalityException(size, that.size());
        }
        return createOptimizedCopy().assign(that, Functions.MINUS);
    }

    @Override
    public void set(int index, double value) {
        if (index < 0 || index >= size) {
            throw new IndexException(index, size);
        }
        setQuick(index, value);
    }

    @Override
    public void incrementQuick(int index, double increment) {
        setQuick(index, getQuick(index) + increment);
    }

    @Override
    public Vector times(double x) {
        if (x == 0.0) {
            return like();
        }
        return createOptimizedCopy().assign(Functions.mult(x));
    }

    /**
     * Copy the current vector in the most optimum fashion. Used by immutable methods like plus(), minus().
     * Use this instead of vector.like().assign(vector). Sub-class can choose to override this method.
     *
     * @return a copy of the current vector.
     */
    protected Vector createOptimizedCopy() {
        return createOptimizedCopy(this);
    }

    private static Vector createOptimizedCopy(Vector vector) {
        Vector result;
        if (vector.isDense()) {
            result = vector.like().assign(vector, Functions.SECOND_LEFT_ZERO);
        } else {
            result = vector.clone();
        }
        return result;
    }

    @Override
    public Vector times(Vector that) {
        if (size != that.size()) {
            throw new CardinalityException(size, that.size());
        }

        if (this.getNumNondefaultElements() <= that.getNumNondefaultElements()) {
            return createOptimizedCopy(this).assign(that, Functions.MULT);
        } else {
            return createOptimizedCopy(that).assign(this, Functions.MULT);
        }
    }

    @Override
    public double zSum() {
        return aggregate(Functions.PLUS, Functions.IDENTITY);
    }

    @Override
    public int getNumNonZeroElements() {
        int count = 0;
        Iterator<Element> it = iterateNonZero();
        while (it.hasNext()) {
            if (it.next().get() != 0.0) {
                count++;
            }
        }
        return count;
    }

    @Override
    public Vector assign(double value) {
        Iterator<Element> it;
        if (value == 0.0) {
            // Make all the non-zero values 0.
            it = iterateNonZero();
            while (it.hasNext()) {
                it.next().set(value);
            }
        } else {
            if (isSequentialAccess() && !isAddConstantTime()) {
                // Update all the non-zero values and queue the updates for the zero vaues.
                // The vector will become dense.
                it = iterator();
                OrderedIntDoubleMapping updates = new OrderedIntDoubleMapping();
                while (it.hasNext()) {
                    Element element = it.next();
                    if (element.get() == 0.0) {
                        updates.set(element.index(), value);
                    } else {
                        element.set(value);
                    }
                }
                mergeUpdates(updates);
            } else {
                for (int i = 0; i < size; ++i) {
                    setQuick(i, value);
                }
            }
        }
        invalidateCachedLength();
        return this;
    }

    @Override
    public Vector assign(double[] values) {
        if (size != values.length) {
            throw new CardinalityException(size, values.length);
        }
        if (isSequentialAccess() && !isAddConstantTime()) {
            OrderedIntDoubleMapping updates = new OrderedIntDoubleMapping();
            Iterator<Element> it = iterator();
            while (it.hasNext()) {
                Element element = it.next();
                int index = element.index();
                if (element.get() == 0.0) {
                    updates.set(index, values[index]);
                } else {
                    element.set(values[index]);
                }
            }
            mergeUpdates(updates);
        } else {
            for (int i = 0; i < size; ++i) {
                setQuick(i, values[i]);
            }
        }
        invalidateCachedLength();
        return this;
    }

    @Override
    public Vector assign(Vector other) {
        return assign(other, Functions.SECOND);
    }

    @Override
    public Vector assign(DoubleDoubleFunction f, double y) {
        Iterator<Element> iterator = f.apply(0, y) == 0 ? iterateNonZero() : iterator();
        while (iterator.hasNext()) {
            Element element = iterator.next();
            element.set(f.apply(element.get(), y));
        }
        invalidateCachedLength();
        return this;
    }

    @Override
    public Vector assign(DoubleFunction f) {
        Iterator<Element> iterator = !f.isDensifying() ? iterateNonZero() : iterator();
        while (iterator.hasNext()) {
            Element element = iterator.next();
            element.set(f.apply(element.get()));
        }
        invalidateCachedLength();
        return this;
    }

    @Override
    public Vector assign(Vector other, DoubleDoubleFunction function) {
        if (size != other.size()) {
            throw new CardinalityException(size, other.size());
        }
        VectorBinaryAssign.assignBest(this, other, function);
        invalidateCachedLength();
        return this;
    }

    @Override
    public Matrix cross(Vector other) {
        Matrix result = matrixLike(size, other.size());
        Iterator<Vector.Element> it = iterateNonZero();
        while (it.hasNext()) {
            Vector.Element e = it.next();
            int row = e.index();
            result.assignRow(row, other.times(getQuick(row)));
        }
        return result;
    }

    @Override
    public final int size() {
        return size;
    }

    @Override
    public String asFormatString() {
        return toString();
    }

    @Override
    public int hashCode() {
        int result = size;
        Iterator<Element> iter = iterateNonZero();
        while (iter.hasNext()) {
            Element ele = iter.next();
            result += ele.index() * RandomUtils.hashDouble(ele.get());
        }
        return result;
    }

    /**
     * Determines whether this {@link Vector} represents the same logical vector as another
     * object. Two {@link Vector}s are equal (regardless of implementation) if the value at
     * each index is the same, and the cardinalities are the same.
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (!(o instanceof Vector)) {
            return false;
        }
        Vector that = (Vector) o;
        return size == that.size() && aggregate(that, Functions.PLUS, Functions.MINUS_ABS) == 0.0;
    }

    @Override
    public String toString() {
        return toString(null);
    }

    public String toString(String[] dictionary) {
        StringBuilder result = new StringBuilder();
        result.append('{');
        for (int index = 0; index < size; index++) {
            double value = getQuick(index);
            if (value != 0.0) {
                result.append(dictionary != null && dictionary.length > index ? dictionary[index] : index);
                result.append(':');
                result.append(value);
                result.append(',');
            }
        }
        if (result.length() > 1) {
            result.setCharAt(result.length() - 1, '}');
        } else {
            result.append('}');
        }
        return result.toString();
    }

    /**
     * toString() implementation for sparse vectors via {@link #nonZeroes()} method
     * @return String representation of the vector
     */
    public String sparseVectorToString() {
        Iterator<Element> it = iterateNonZero();
        if (!it.hasNext()) {
            return "{}";
        } else {
            StringBuilder result = new StringBuilder();
            result.append('{');
            while (it.hasNext()) {
                Vector.Element e = it.next();
                result.append(e.index());
                result.append(':');
                result.append(e.get());
                result.append(',');
            }
            result.setCharAt(result.length() - 1, '}');
            return result.toString();
        }
    }

    protected final class LocalElement implements Element {
        int index;

        LocalElement(int index) {
            this.index = index;
        }

        @Override
        public double get() {
            return getQuick(index);
        }

        @Override
        public int index() {
            return index;
        }

        @Override
        public void set(double value) {
            setQuick(index, value);
        }
    }
}