net.myrrix.online.factorizer.als.AlternatingLeastSquares.java Source code

Java tutorial

Introduction

Here is the source code for net.myrrix.online.factorizer.als.AlternatingLeastSquares.java

Source

/*
 * Copyright Myrrix Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package net.myrrix.online.factorizer.als;

import java.util.Collection;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.util.FastMath;
import org.apache.commons.math3.util.Pair;
import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import net.myrrix.common.parallel.ExecutorUtils;
import net.myrrix.common.math.SimpleVectorMath;
import net.myrrix.common.random.RandomManager;
import net.myrrix.common.random.RandomUtils;
import net.myrrix.common.stats.DoubleWeightedMean;
import net.myrrix.common.stats.JVMEnvironment;
import net.myrrix.common.LangUtils;
import net.myrrix.common.collection.FastByIDFloatMap;
import net.myrrix.common.collection.FastByIDMap;
import net.myrrix.common.math.MatrixUtils;
import net.myrrix.online.factorizer.MatrixFactorizer;

/**
 * <p>Implements the Alternating Least Squares algorithm described in
 * <a href="http://www2.research.att.com/~yifanhu/PUB/cf.pdf">"Collaborative Filtering for Implicit Feedback Datasets"</a>
 * by Yifan Hu, Yehuda Koren, and Chris Volinsky.</p>
 *
 * <p>This implementation varies in some small details; for example, it does not use the same mechanism for
 * explaining ratings, and it seeds the initial Y differently.</p>
 *
 * <p>Note that in this implementation, matrices are sparse and are implemented with a {@link FastByIDMap} of
 * {@link FastByIDFloatMap} so as to be able to use {@code long} keys. In many cases, a tall, skinny matrix is
 * needed (sparse rows, dense columns). This is represented with {@link FastByIDMap} of {@code float[]}.</p>
 *
 * <p>The following system properties are read:</p>
 * <ul>
 *   <li>{@code model.threads}: number of worker threads (default: number of available processors)</li>
 *   <li>{@code model.als.iterate}: if {@code false}, X is computed once from the initial Y and iteration stops</li>
 *   <li>{@code model.als.alpha}, {@code model.als.lambda}: override {@link #DEFAULT_ALPHA} and {@link #DEFAULT_LAMBDA}</li>
 *   <li>{@code model.reconstructRMatrix}, {@code model.lossIgnoresUnspecified}: experimental switches</li>
 * </ul>
 *
 * @author Sean Owen
 * @since 1.0
 */
public final class AlternatingLeastSquares implements MatrixFactorizer {

    private static final Logger log = LoggerFactory.getLogger(AlternatingLeastSquares.class);

    /** Default alpha from the ALS algorithm. */
    public static final double DEFAULT_ALPHA = 1.0;
    /** Default lambda factor; this is multiplied by alpha. */
    public static final double DEFAULT_LAMBDA = 0.1;
    /**
     * Default threshold for the weighted average absolute change in sampled user-item estimates between
     * iterations; below it, iteration stops.
     */
    public static final double DEFAULT_CONVERGENCE_THRESHOLD = 0.001;
    /** Default cap on the number of iterations. */
    public static final int DEFAULT_MAX_ITERATIONS = 30;

    /** Number of rows (or columns) of R handed to each worker task. */
    private static final int WORK_UNIT_SIZE = 100;
    /** Approximate number of users, and of items, sampled to measure convergence. */
    private static final int NUM_USER_ITEMS_TO_TEST_CONVERGENCE = 100;

    /** Progress is logged roughly every this many rows. */
    private static final long LOG_INTERVAL = 100000;
    /** Cap on the number of existing vectors that new random Y vectors are chosen to be far from. */
    private static final int MAX_FAR_FROM_VECTORS = 100000;
    /** Heap usage percentage above which a low-memory warning is logged. */
    private static final int LOW_MEMORY_PERCENT = 95;

    // This will cause the ALS algorithm to reconstruct the input matrix R, rather than the
    // matrix P = (R > 0). Don't use this unless you understand it!
    private static final boolean RECONSTRUCT_R_MATRIX = Boolean
            .parseBoolean(System.getProperty("model.reconstructRMatrix", "false"));
    // Causes the loss function to exclude entries for any input pairs that do not appear in the
    // input and are implicitly 0.
    // Likewise, don't touch this for now unless you know what it does.
    private static final boolean LOSS_IGNORES_UNSPECIFIED = Boolean
            .parseBoolean(System.getProperty("model.lossIgnoresUnspecified", "false"));

    private final FastByIDMap<FastByIDFloatMap> RbyRow;
    private final FastByIDMap<FastByIDFloatMap> RbyColumn;
    private final int features;
    private final double estimateErrorConvergenceThreshold;
    private final int maxIterations;
    private FastByIDMap<float[]> X;
    private FastByIDMap<float[]> Y;
    private FastByIDMap<float[]> previousY;

    /**
     * Uses the default number of features ({@code DEFAULT_FEATURES}), convergence threshold
     * ({@link #DEFAULT_CONVERGENCE_THRESHOLD}) and iteration cap ({@link #DEFAULT_MAX_ITERATIONS}).
     *
     * @param RbyRow the input R matrix, indexed by row
     * @param RbyColumn the input R matrix, indexed by column
     * @throws NullPointerException if either matrix is null
     */
    public AlternatingLeastSquares(FastByIDMap<FastByIDFloatMap> RbyRow, FastByIDMap<FastByIDFloatMap> RbyColumn) {
        this(RbyRow, RbyColumn, DEFAULT_FEATURES, DEFAULT_CONVERGENCE_THRESHOLD, DEFAULT_MAX_ITERATIONS);
    }

    /**
     * Uses the default convergence threshold and iteration cap.
     *
     * @param RbyRow the input R matrix, indexed by row
     * @param RbyColumn the input R matrix, indexed by column
     * @param features number of features, must be positive
     * @throws NullPointerException if either matrix is null
     * @throws IllegalArgumentException if {@code features} is not positive
     */
    public AlternatingLeastSquares(FastByIDMap<FastByIDFloatMap> RbyRow, FastByIDMap<FastByIDFloatMap> RbyColumn,
            int features) {
        this(RbyRow, RbyColumn, features, DEFAULT_CONVERGENCE_THRESHOLD, DEFAULT_MAX_ITERATIONS);
    }

    /**
     * @param RbyRow the input R matrix, indexed by row
     * @param RbyColumn the input R matrix, indexed by column
     * @param features number of features, must be positive
     * @param estimateErrorConvergenceThreshold when the average absolute difference in estimated user-item
     *   scores falls below this threshold between iterations, iterations will stop; must be in (0,1)
     * @param maxIterations caps the number of iterations run. If non-positive, there is no cap.
     * @throws NullPointerException if either matrix is null
     * @throws IllegalArgumentException if {@code features} is not positive or the threshold is not in (0,1)
     */
    public AlternatingLeastSquares(FastByIDMap<FastByIDFloatMap> RbyRow, FastByIDMap<FastByIDFloatMap> RbyColumn,
            int features, double estimateErrorConvergenceThreshold, int maxIterations) {
        Preconditions.checkNotNull(RbyRow);
        Preconditions.checkNotNull(RbyColumn);
        Preconditions.checkArgument(features > 0, "features must be positive: %s", features);
        Preconditions.checkArgument(
                estimateErrorConvergenceThreshold > 0.0 && estimateErrorConvergenceThreshold < 1.0,
                "threshold must be in (0,1): %s", estimateErrorConvergenceThreshold);
        this.RbyRow = RbyRow;
        this.RbyColumn = RbyColumn;
        this.features = features;
        this.estimateErrorConvergenceThreshold = estimateErrorConvergenceThreshold;
        this.maxIterations = maxIterations;
    }

    @Override
    public FastByIDMap<float[]> getX() {
        return X;
    }

    @Override
    public FastByIDMap<float[]> getY() {
        return Y;
    }

    /**
     * Does nothing: this implementation always computes X from Y, so a previous X is not used.
     */
    @Override
    public void setPreviousX(FastByIDMap<float[]> previousX) {
        // do nothing
    }

    /**
     * Sets the initial state of Y used in the computation, typically the Y from a previous
     * computation. Call before {@link #call()}. Note that if it already has the configured number of
     * features, this map is used directly as the new Y and is modified in place.
     */
    @Override
    public void setPreviousY(FastByIDMap<float[]> previousY) {
        this.previousY = previousY;
    }

    /**
     * Runs the factorization, after which {@link #getX()} and {@link #getY()} return the results.
     *
     * @return always {@code null}
     * @throws ExecutionException if a worker computation fails
     * @throws InterruptedException if interrupted while waiting for workers
     * @throws IllegalArgumentException if system property {@code model.threads} is not a positive integer
     */
    @Override
    public Void call() throws ExecutionException, InterruptedException {

        X = new FastByIDMap<float[]>(RbyRow.size());

        boolean randomY = previousY == null || previousY.isEmpty();
        Y = constructInitialY(previousY);

        // This will be used to compute rows/columns in parallel during iteration
        int numThreads = getNumThreads();
        ExecutorService executor = Executors.newFixedThreadPool(numThreads,
                new ThreadFactoryBuilder().setNameFormat("ALS-%d").setDaemon(true).build());

        // Everything after creating the executor is inside try so its threads are always shut down
        try {
            log.info("Iterating using {} threads", numThreads);

            // Only of any use if using a Y matrix that was specially constructed and fixed ahead of time
            if (!Boolean.parseBoolean(System.getProperty("model.als.iterate", "true"))) {
                // Just figure X from Y and stop
                iterateXFromY(executor);
                return null;
            }

            iterateUntilConvergence(executor, randomY);
        } finally {
            ExecutorUtils.shutdownNowAndAwait(executor);
        }
        return null;
    }

    /**
     * @return the worker thread count: system property {@code model.threads} if set, otherwise the
     *  number of available processors
     * @throws IllegalArgumentException if {@code model.threads} is set but is not a positive integer
     */
    private static int getNumThreads() {
        String threadsString = System.getProperty("model.threads");
        if (threadsString == null) {
            return Runtime.getRuntime().availableProcessors();
        }
        // NumberFormatException, itself an IllegalArgumentException, if not an integer
        int numThreads = Integer.parseInt(threadsString);
        Preconditions.checkArgument(numThreads > 0, "model.threads must be positive: %s", numThreads);
        return numThreads;
    }

    /**
     * Alternately recomputes X from Y and Y from X. Stops when any of these happens:
     * <ul>
     *   <li>the iteration cap is reached;</li>
     *   <li>the convergence measure becomes non-finite;</li>
     *   <li>the weighted average absolute change in estimates for a random sample of user-item pairs falls
     *   below {@code estimateErrorConvergenceThreshold}.</li>
     * </ul>
     *
     * @param executor runs the per-row computations
     * @param randomY whether Y started from random values, in which case convergence is not accepted
     *  after the first iteration
     */
    private void iterateUntilConvergence(ExecutorService executor, boolean randomY)
            throws ExecutionException, InterruptedException {

        RandomGenerator random = RandomManager.getRandom();
        long[] testUserIDs = RandomUtils.chooseAboutNFromStream(NUM_USER_ITEMS_TO_TEST_CONVERGENCE,
                RbyRow.keySetIterator(), RbyRow.size(), random);
        long[] testItemIDs = RandomUtils.chooseAboutNFromStream(NUM_USER_ITEMS_TO_TEST_CONVERGENCE,
                RbyColumn.keySetIterator(), RbyColumn.size(), random);
        // X was just created empty in call(), so there are no prior estimates; they start at 0
        double[][] estimates = new double[testUserIDs.length][testItemIDs.length];

        int iterationNumber = 0;
        while (true) {
            iterateXFromY(executor);
            iterateYFromX(executor);

            // Weighted by the (non-negative part of the) new estimate, so changes in strong estimates dominate
            DoubleWeightedMean averageAbsoluteEstimateDiff = new DoubleWeightedMean();
            for (int i = 0; i < testUserIDs.length; i++) {
                float[] userVector = X.get(testUserIDs[i]);
                for (int j = 0; j < testItemIDs.length; j++) {
                    double newValue = SimpleVectorMath.dot(userVector, Y.get(testItemIDs[j]));
                    double oldValue = estimates[i][j];
                    estimates[i][j] = newValue;
                    averageAbsoluteEstimateDiff.increment(FastMath.abs(newValue - oldValue),
                            FastMath.max(0.0, newValue));
                }
            }

            iterationNumber++;
            log.info("Finished iteration {}", iterationNumber);
            if (maxIterations > 0 && iterationNumber >= maxIterations) {
                log.info("Reached iteration limit");
                break;
            }
            log.info("Avg absolute difference in estimate vs prior iteration: {}", averageAbsoluteEstimateDiff);
            double convergenceValue = averageAbsoluteEstimateDiff.getResult();
            if (!LangUtils.isFinite(convergenceValue)) {
                log.warn("Invalid convergence value, aborting iteration! {}", convergenceValue);
                break;
            }
            // Don't converge after 1 iteration if starting from a random point
            if (!(randomY && iterationNumber == 1) && convergenceValue < estimateErrorConvergenceThreshold) {
                log.info("Converged");
                break;
            }
        }
    }

    /**
     * Builds the starting Y matrix. The starting point depends on {@code previousY} and its feature count:
     * <ul>
     *   <li>If {@code previousY} is null or empty, Y starts empty.</li>
     *   <li>If {@code previousY} has more features, each vector is truncated to the first {@code features}
     *   dimensions and normalized.</li>
     *   <li>If {@code previousY} has fewer features, each vector is extended with Gaussian random values and
     *   normalized.</li>
     *   <li>If the feature counts match, {@code previousY} itself is used and is modified in place.</li>
     * </ul>
     * Every column ID of R still missing is then given a random unit vector chosen to be far from up to
     * {@link #MAX_FAR_FROM_VECTORS} existing vectors.
     *
     * @param previousY Y from a previous computation, or null
     * @return the initial Y
     */
    private FastByIDMap<float[]> constructInitialY(FastByIDMap<float[]> previousY) {

        RandomGenerator random = RandomManager.getRandom();

        FastByIDMap<float[]> initialY;
        if (previousY == null || previousY.isEmpty()) {
            // Common case: have to start from scratch
            log.info("Starting from new, random Y matrix");
            initialY = new FastByIDMap<float[]>(RbyColumn.size());

        } else {

            int oldFeatureCount = previousY.entrySet().iterator().next().getValue().length;
            if (oldFeatureCount > features) {
                // Fewer features, use some dimensions from prior larger number of features as-is
                log.info("Feature count has decreased to {}, projecting down previous generation's Y matrix",
                        features);
                initialY = new FastByIDMap<float[]>(previousY.size());
                for (FastByIDMap.MapEntry<float[]> entry : previousY.entrySet()) {
                    float[] oldLargerVector = entry.getValue();
                    float[] newSmallerVector = new float[features];
                    System.arraycopy(oldLargerVector, 0, newSmallerVector, 0, newSmallerVector.length);
                    SimpleVectorMath.normalize(newSmallerVector);
                    initialY.put(entry.getKey(), newSmallerVector);
                }

            } else if (oldFeatureCount < features) {
                log.info("Feature count has increased to {}, using previous generation's Y matrix as subspace",
                        features);
                initialY = new FastByIDMap<float[]>(previousY.size());
                for (FastByIDMap.MapEntry<float[]> entry : previousY.entrySet()) {
                    float[] oldSmallerVector = entry.getValue();
                    float[] newLargerVector = new float[features];
                    System.arraycopy(oldSmallerVector, 0, newLargerVector, 0, oldSmallerVector.length);
                    // Fill in new dimensions with random values
                    for (int i = oldSmallerVector.length; i < newLargerVector.length; i++) {
                        newLargerVector[i] = (float) random.nextGaussian();
                    }
                    SimpleVectorMath.normalize(newLargerVector);
                    initialY.put(entry.getKey(), newLargerVector);
                }

            } else {
                // Common case: previous generation is same number of features.
                // Note this reuses (and below, adds to) the caller's map rather than copying it.
                log.info("Starting from previous generation's Y matrix");
                initialY = previousY;
            }
        }

        // New random vectors are chosen to be far from a bounded sample of the existing ones
        List<float[]> recentVectors = Lists.newArrayList();
        for (FastByIDMap.MapEntry<float[]> entry : initialY.entrySet()) {
            if (recentVectors.size() >= MAX_FAR_FROM_VECTORS) {
                break;
            }
            recentVectors.add(entry.getValue());
        }
        LongPrimitiveIterator it = RbyColumn.keySetIterator();
        long count = 0;
        while (it.hasNext()) {
            long id = it.nextLong();
            if (!initialY.containsKey(id)) {
                float[] vector = RandomUtils.randomUnitVectorFarFrom(features, recentVectors, random);
                initialY.put(id, vector);
                if (recentVectors.size() < MAX_FAR_FROM_VECTORS) {
                    recentVectors.add(vector);
                }
            }
            if (++count % LOG_INTERVAL == 0) {
                log.info("Computed {} initial Y rows", count);
            }
        }
        log.info("Constructed initial Y");
        return initialY;
    }

    /**
     * Runs one iteration to compute X from Y.
     */
    private void iterateXFromY(ExecutorService executor) throws ExecutionException, InterruptedException {
        RealMatrix YTY = MatrixUtils.transposeTimesSelf(Y);
        Collection<Future<?>> futures = Lists.newArrayList();
        addWorkers(RbyRow, Y, YTY, X, executor, futures);
        awaitWorkers(futures, "{} X/tag rows computed ({}MB heap)");
    }

    /**
     * Runs one iteration to compute Y from X.
     */
    private void iterateYFromX(ExecutorService executor) throws ExecutionException, InterruptedException {
        RealMatrix XTX = MatrixUtils.transposeTimesSelf(X);
        Collection<Future<?>> futures = Lists.newArrayList();
        addWorkers(RbyColumn, X, XTX, Y, executor, futures);
        awaitWorkers(futures, "{} Y/tag rows computed ({}MB heap)");
    }

    /**
     * Waits for every worker task to finish. About every {@link #LOG_INTERVAL} rows it logs progress
     * and heap usage, and warns if memory is nearly exhausted.
     *
     * @param futures tasks submitted by {@link #addWorkers}, each covering up to {@link #WORK_UNIT_SIZE} rows
     * @param progressMessageFormat SLF4J format taking the approximate row count so far and the heap used in MB
     * @throws ExecutionException if any worker failed
     * @throws InterruptedException if interrupted while waiting
     */
    private static void awaitWorkers(Iterable<Future<?>> futures, String progressMessageFormat)
            throws ExecutionException, InterruptedException {
        int count = 0;
        long total = 0;
        for (Future<?> f : futures) {
            f.get();
            count += WORK_UNIT_SIZE;
            if (count >= LOG_INTERVAL) {
                total += count;
                JVMEnvironment env = new JVMEnvironment();
                log.info(progressMessageFormat, total, env.getUsedMemoryMB());
                if (env.getPercentUsedMemory() > LOW_MEMORY_PERCENT) {
                    log.warn(
                            "Memory is low. Increase heap size with -Xmx, decrease new generation size with larger "
                                    + "-XX:NewRatio value, and/or use -XX:+UseCompressedOops");
                }
                count = 0;
            }
        }
    }

    /**
     * Splits the rows of {@code R} into units of at most {@link #WORK_UNIT_SIZE} rows. Each unit is
     * submitted to {@code executor} as a {@link Worker}, which writes its solved vectors into {@code MTags}.
     *
     * @param R rows to solve for, keyed by ID
     * @param M the fixed factor matrix
     * @param MTM precomputed MT * M
     * @param MTags destination for the newly computed vectors
     * @param executor runs the workers
     * @param futures receives one future per submitted worker
     */
    private void addWorkers(FastByIDMap<FastByIDFloatMap> R, FastByIDMap<float[]> M, RealMatrix MTM,
            FastByIDMap<float[]> MTags, ExecutorService executor, Collection<Future<?>> futures) {
        if (R != null) {
            List<Pair<Long, FastByIDFloatMap>> workUnit = Lists.newArrayListWithCapacity(WORK_UNIT_SIZE);
            for (FastByIDMap.MapEntry<FastByIDFloatMap> entry : R.entrySet()) {
                workUnit.add(new Pair<Long, FastByIDFloatMap>(entry.getKey(), entry.getValue()));
                if (workUnit.size() == WORK_UNIT_SIZE) {
                    futures.add(executor.submit(new Worker(features, M, MTM, MTags, workUnit)));
                    workUnit = Lists.newArrayListWithCapacity(WORK_UNIT_SIZE);
                }
            }
            if (!workUnit.isEmpty()) {
                futures.add(executor.submit(new Worker(features, M, MTM, MTags, workUnit)));
            }
        }
    }

    /**
     * Solves the regularized least-squares problem for each row in its work unit and stores the resulting
     * vector in {@code X} under the row's ID. The fixed factor matrix is {@code Y}. Writes to {@code X} are
     * synchronized on {@code X}, since several workers share it.
     */
    private static final class Worker implements Callable<Void> {

        private final int features;
        private final FastByIDMap<float[]> Y;
        private final RealMatrix YTY;
        private final FastByIDMap<float[]> X;
        private final Iterable<Pair<Long, FastByIDFloatMap>> workUnit;

        private Worker(int features, FastByIDMap<float[]> Y, RealMatrix YTY, FastByIDMap<float[]> X,
                Iterable<Pair<Long, FastByIDFloatMap>> workUnit) {
            this.features = features;
            this.Y = Y;
            this.YTY = YTY;
            this.X = X;
            this.workUnit = workUnit;
        }

        @Override
        public Void call() {
            double alpha = getAlpha();
            double lambda = getLambda() * alpha;
            int features = this.features;
            // Each worker has a batch of rows to compute:
            for (Pair<Long, FastByIDFloatMap> work : workUnit) {

                // Row (column) in original R matrix containing total association value. For simplicity we will
                // talk about users and rows only in the comments and variables. It's symmetric for columns / items.
                // This is Ru:
                FastByIDFloatMap ru = work.getSecond();

                // Start computing Wu = (YT*Cu*Y + lambda*I) = (YT*Y + YT*(Cu-I)*Y + lambda*I),
                // by first starting with a copy of YT * Y. Or, a variant on YT * Y, if LOSS_IGNORES_UNSPECIFIED is set
                RealMatrix Wu = LOSS_IGNORES_UNSPECIFIED
                        ? partialTransposeTimesSelf(Y, YTY.getRowDimension(), ru.keySetIterator())
                        : YTY.copy();

                // Updated in place through the backing array, which is much faster than addToEntry()
                double[][] WuData = MatrixUtils.accessMatrixDataDirectly(Wu);
                double[] YTCupu = new double[features];

                for (FastByIDFloatMap.MapEntry entry : ru.entrySet()) {

                    double xu = entry.getValue();

                    float[] vector = Y.get(entry.getKey());
                    if (vector == null) {
                        log.warn("No vector for {}. This should not happen. Continuing...", entry.getKey());
                        continue;
                    }

                    // Wu and YTCupu
                    if (RECONSTRUCT_R_MATRIX) {
                        for (int row = 0; row < features; row++) {
                            YTCupu[row] += xu * vector[row];
                        }
                    } else {
                        // Confidence cu grows with |xu|; preference pu is 1 only for positive xu
                        double cu = 1.0 + alpha * FastMath.abs(xu);
                        for (int row = 0; row < features; row++) {
                            float vectorAtRow = vector[row];
                            double rowValue = vectorAtRow * (cu - 1.0);
                            double[] WuDataRow = WuData[row];
                            for (int col = 0; col < features; col++) {
                                WuDataRow[col] += rowValue * vector[col];
                            }
                            if (xu > 0.0) {
                                YTCupu[row] += vectorAtRow * cu;
                            }
                        }
                    }

                }

                // Regularization is scaled by the number of entries in this row
                double lambdaTimesCount = lambda * ru.size();
                for (int x = 0; x < features; x++) {
                    WuData[x][x] += lambdaTimesCount;
                }

                float[] xu = MatrixUtils.getSolver(Wu).solveDToF(YTCupu);

                // Store result:
                synchronized (X) {
                    X.put(work.getFirst(), xu);
                }

                // Process is identical for computing Y from X. Swap X in for Y, Y for X, i for u, etc.
            }
            return null;
        }

        /** @return system property {@code model.als.alpha} if set, else {@link #DEFAULT_ALPHA} */
        private static double getAlpha() {
            String alphaProperty = System.getProperty("model.als.alpha");
            return alphaProperty == null ? DEFAULT_ALPHA : LangUtils.parseDouble(alphaProperty);
        }

        /** @return system property {@code model.als.lambda} if set, else {@link #DEFAULT_LAMBDA} */
        private static double getLambda() {
            String lambdaProperty = System.getProperty("model.als.lambda");
            return lambdaProperty == null ? DEFAULT_LAMBDA : LangUtils.parseDouble(lambdaProperty);
        }

        /**
         * Like {@link MatrixUtils#transposeTimesSelf(FastByIDMap)}, but instead of computing MT * M,
         * it computes MT * C * M, where C is a diagonal matrix of 1s and 0s. This is like pretending some
         * rows of M are 0. Only rows of M whose keys are produced by {@code keys} contribute. Keys with no
         * row in M are skipped, matching how {@code call()} skips them.
         *
         * @param M matrix whose rows are selected
         * @param dimension length of each row of M; the result is dimension x dimension
         * @param keys IDs of the rows of M to include
         * @return the partial MT * M
         * @see MatrixUtils#transposeTimesSelf(FastByIDMap)
         * @see AlternatingLeastSquares#LOSS_IGNORES_UNSPECIFIED
         */
        private static RealMatrix partialTransposeTimesSelf(FastByIDMap<float[]> M, int dimension,
                LongPrimitiveIterator keys) {
            Array2DRowRealMatrix result = new Array2DRowRealMatrix(dimension, dimension);
            double[][] resultData = result.getDataRef();
            while (keys.hasNext()) {
                long key = keys.nextLong();
                float[] vector = M.get(key);
                if (vector == null) {
                    // call() logs a warning for this same key and skips it; do the same here instead of an NPE
                    continue;
                }
                for (int row = 0; row < dimension; row++) {
                    // Promote to double so the product is computed without float rounding
                    double rowValue = vector[row];
                    double[] resultDataRow = resultData[row];
                    for (int col = 0; col < dimension; col++) {
                        resultDataRow[col] += rowValue * vector[col];
                    }
                }
            }
            return result;
        }

    }

}