Java tutorial
/* * Copyright Myrrix Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package net.myrrix.online.factorizer.als; import java.util.Collection; import java.util.List; import java.util.concurrent.Callable; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; import com.google.common.base.Preconditions; import com.google.common.collect.Lists; import com.google.common.util.concurrent.ThreadFactoryBuilder; import org.apache.commons.math3.linear.Array2DRowRealMatrix; import org.apache.commons.math3.linear.RealMatrix; import org.apache.commons.math3.random.RandomGenerator; import org.apache.commons.math3.util.FastMath; import org.apache.commons.math3.util.Pair; import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import net.myrrix.common.parallel.ExecutorUtils; import net.myrrix.common.math.SimpleVectorMath; import net.myrrix.common.random.RandomManager; import net.myrrix.common.random.RandomUtils; import net.myrrix.common.stats.DoubleWeightedMean; import net.myrrix.common.stats.JVMEnvironment; import net.myrrix.common.LangUtils; import net.myrrix.common.collection.FastByIDFloatMap; import net.myrrix.common.collection.FastByIDMap; import net.myrrix.common.math.MatrixUtils; import net.myrrix.online.factorizer.MatrixFactorizer; /** * <p>Implements the Alternating Least Squares algorithm described in * <a href="http://www2.research.att.com/~yifanhu/PUB/cf.pdf">"Collaborative Filtering for Implicit Feedback Datasets"</a> * by Yifan Hu, Yehuda Koren, and Chris Volinsky.</p> * * <p>This implementation varies in some small details; it does not use the same mechanism for explaining ratings * for example and seeds the initial Y differently.</p> * * <p>Note that in this implementation, matrices are sparse and are implemented with a {@link FastByIDMap} of * {@link FastByIDFloatMap} so as to be able to use {@code long} keys. In many cases, a tall, skinny matrix is * needed (sparse rows, dense columns). This is represented with {@link FastByIDMap} of {@code float[]}.</p> * * @author Sean Owen * @since 1.0 */ public final class AlternatingLeastSquares implements MatrixFactorizer { private static final Logger log = LoggerFactory.getLogger(AlternatingLeastSquares.class); /** Default alpha from the ALS algorithm. */ public static final double DEFAULT_ALPHA = 1.0; /** Default lambda factor; this is multiplied by alpha. */ public static final double DEFAULT_LAMBDA = 0.1; public static final double DEFAULT_CONVERGENCE_THRESHOLD = 0.001; public static final int DEFAULT_MAX_ITERATIONS = 30; private static final int WORK_UNIT_SIZE = 100; private static final int NUM_USER_ITEMS_TO_TEST_CONVERGENCE = 100; private static final long LOG_INTERVAL = 100000; private static final int MAX_FAR_FROM_VECTORS = 100000; // This will cause the ALS algorithm to reconstruction the input matrix R, rather than the // matrix P = R > 0 . Don't use this unless you understand it! private static final boolean RECONSTRUCT_R_MATRIX = Boolean .parseBoolean(System.getProperty("model.reconstructRMatrix", "false")); // Causes the loss function to exclude entries for any input pairs that do not appear in the // input and are implicitly 0 // Likewise, don't touch this for now unless you know what it does. private static final boolean LOSS_IGNORES_UNSPECIFIED = Boolean .parseBoolean(System.getProperty("model.lossIgnoresUnspecified", "false")); private final FastByIDMap<FastByIDFloatMap> RbyRow; private final FastByIDMap<FastByIDFloatMap> RbyColumn; private final int features; private final double estimateErrorConvergenceThreshold; private final int maxIterations; private FastByIDMap<float[]> X; private FastByIDMap<float[]> Y; private FastByIDMap<float[]> previousY; /** * Uses default number of feature and convergence threshold. * * @param RbyRow the input R matrix, indexed by row * @param RbyColumn the input R matrix, indexed by column */ public AlternatingLeastSquares(FastByIDMap<FastByIDFloatMap> RbyRow, FastByIDMap<FastByIDFloatMap> RbyColumn) { this(RbyRow, RbyColumn, DEFAULT_FEATURES, DEFAULT_CONVERGENCE_THRESHOLD, DEFAULT_MAX_ITERATIONS); } /** * @param RbyRow the input R matrix, indexed by row * @param RbyColumn the input R matrix, indexed by column * @param features number of features, must be positive */ public AlternatingLeastSquares(FastByIDMap<FastByIDFloatMap> RbyRow, FastByIDMap<FastByIDFloatMap> RbyColumn, int features) { this(RbyRow, RbyColumn, features, DEFAULT_CONVERGENCE_THRESHOLD, DEFAULT_MAX_ITERATIONS); } /** * @param RbyRow the input R matrix, indexed by row * @param RbyColumn the input R matrix, indexed by column * @param features number of features, must be positive * @param estimateErrorConvergenceThreshold when the average absolute difference in estimated user-item * scores falls below this threshold between iterations, iterations will stop * @param maxIterations caps the number of iterations run. If non-positive, there is no cap. */ public AlternatingLeastSquares(FastByIDMap<FastByIDFloatMap> RbyRow, FastByIDMap<FastByIDFloatMap> RbyColumn, int features, double estimateErrorConvergenceThreshold, int maxIterations) { Preconditions.checkNotNull(RbyRow); Preconditions.checkNotNull(RbyColumn); Preconditions.checkArgument(features > 0, "features must be positive: %s", features); Preconditions.checkArgument( estimateErrorConvergenceThreshold > 0.0 && estimateErrorConvergenceThreshold < 1.0, "threshold must be in (0,1): %s", estimateErrorConvergenceThreshold); this.RbyRow = RbyRow; this.RbyColumn = RbyColumn; this.features = features; this.estimateErrorConvergenceThreshold = estimateErrorConvergenceThreshold; this.maxIterations = maxIterations; } @Override public FastByIDMap<float[]> getX() { return X; } @Override public FastByIDMap<float[]> getY() { return Y; } /** * Does nothing. */ @Override public void setPreviousX(FastByIDMap<float[]> previousX) { // do nothing } /** * Sets the initial state of Y used in the computation, typically the Y from a previous * computation. Call before {@link #call()}. */ @Override public void setPreviousY(FastByIDMap<float[]> previousY) { this.previousY = previousY; } @Override public Void call() throws ExecutionException, InterruptedException { X = new FastByIDMap<float[]>(RbyRow.size()); boolean randomY = previousY == null || previousY.isEmpty(); Y = constructInitialY(previousY); // This will be used to compute rows/columns in parallel during iteration String threadsString = System.getProperty("model.threads"); int numThreads = threadsString == null ? Runtime.getRuntime().availableProcessors() : Integer.parseInt(threadsString); ExecutorService executor = Executors.newFixedThreadPool(numThreads, new ThreadFactoryBuilder().setNameFormat("ALS-%d").setDaemon(true).build()); log.info("Iterating using {} threads", numThreads); // Only of any use if using a Y matrix that was specially constructed and fixed ahead of time if (!Boolean.parseBoolean(System.getProperty("model.als.iterate", "true"))) { // Just figure X from Y and stop try { iterateXFromY(executor); } finally { ExecutorUtils.shutdownNowAndAwait(executor); } return null; } RandomGenerator random = RandomManager.getRandom(); long[] testUserIDs = RandomUtils.chooseAboutNFromStream(NUM_USER_ITEMS_TO_TEST_CONVERGENCE, RbyRow.keySetIterator(), RbyRow.size(), random); long[] testItemIDs = RandomUtils.chooseAboutNFromStream(NUM_USER_ITEMS_TO_TEST_CONVERGENCE, RbyColumn.keySetIterator(), RbyColumn.size(), random); double[][] estimates = new double[testUserIDs.length][testItemIDs.length]; if (!X.isEmpty()) { for (int i = 0; i < testUserIDs.length; i++) { for (int j = 0; j < testItemIDs.length; j++) { estimates[i][j] = SimpleVectorMath.dot(X.get(testUserIDs[i]), Y.get(testItemIDs[j])); } } } // Otherwise X is empty because it's the first ever iteration. Estimates can be left at initial 0 value try { int iterationNumber = 0; while (true) { iterateXFromY(executor); iterateYFromX(executor); DoubleWeightedMean averageAbsoluteEstimateDiff = new DoubleWeightedMean(); for (int i = 0; i < testUserIDs.length; i++) { for (int j = 0; j < testItemIDs.length; j++) { double newValue = SimpleVectorMath.dot(X.get(testUserIDs[i]), Y.get(testItemIDs[j])); double oldValue = estimates[i][j]; estimates[i][j] = newValue; averageAbsoluteEstimateDiff.increment(FastMath.abs(newValue - oldValue), FastMath.max(0.0, newValue)); } } iterationNumber++; log.info("Finished iteration {}", iterationNumber); if (maxIterations > 0 && iterationNumber >= maxIterations) { log.info("Reached iteration limit"); break; } log.info("Avg absolute difference in estimate vs prior iteration: {}", averageAbsoluteEstimateDiff); double convergenceValue = averageAbsoluteEstimateDiff.getResult(); if (!LangUtils.isFinite(convergenceValue)) { log.warn("Invalid convergence value, aborting iteration! {}", convergenceValue); break; } // Don't converge after 1 iteration if starting from a random point if (!(randomY && iterationNumber == 1) && convergenceValue < estimateErrorConvergenceThreshold) { log.info("Converged"); break; } } } finally { ExecutorUtils.shutdownNowAndAwait(executor); } return null; } private FastByIDMap<float[]> constructInitialY(FastByIDMap<float[]> previousY) { RandomGenerator random = RandomManager.getRandom(); FastByIDMap<float[]> randomY; if (previousY == null || previousY.isEmpty()) { // Common case: have to start from scratch log.info("Starting from new, random Y matrix"); randomY = new FastByIDMap<float[]>(RbyColumn.size()); } else { int oldFeatureCount = previousY.entrySet().iterator().next().getValue().length; if (oldFeatureCount > features) { // Fewer features, use some dimensions from prior larger number of features as-is log.info("Feature count has decreased to {}, projecting down previous generation's Y matrix", features); randomY = new FastByIDMap<float[]>(previousY.size()); for (FastByIDMap.MapEntry<float[]> entry : previousY.entrySet()) { float[] oldLargerVector = entry.getValue(); float[] newSmallerVector = new float[features]; System.arraycopy(oldLargerVector, 0, newSmallerVector, 0, newSmallerVector.length); SimpleVectorMath.normalize(newSmallerVector); randomY.put(entry.getKey(), newSmallerVector); } } else if (oldFeatureCount < features) { log.info("Feature count has increased to {}, using previous generation's Y matrix as subspace", features); randomY = new FastByIDMap<float[]>(previousY.size()); for (FastByIDMap.MapEntry<float[]> entry : previousY.entrySet()) { float[] oldSmallerVector = entry.getValue(); float[] newLargerVector = new float[features]; System.arraycopy(oldSmallerVector, 0, newLargerVector, 0, oldSmallerVector.length); // Fill in new dimensions with random values for (int i = oldSmallerVector.length; i < newLargerVector.length; i++) { newLargerVector[i] = (float) random.nextGaussian(); } SimpleVectorMath.normalize(newLargerVector); randomY.put(entry.getKey(), newLargerVector); } } else { // Common case: previous generation is same number of features log.info("Starting from previous generation's Y matrix"); randomY = previousY; } } List<float[]> recentVectors = Lists.newArrayList(); for (FastByIDMap.MapEntry<float[]> entry : randomY.entrySet()) { if (recentVectors.size() >= MAX_FAR_FROM_VECTORS) { break; } recentVectors.add(entry.getValue()); } LongPrimitiveIterator it = RbyColumn.keySetIterator(); long count = 0; while (it.hasNext()) { long id = it.nextLong(); if (!randomY.containsKey(id)) { float[] vector = RandomUtils.randomUnitVectorFarFrom(features, recentVectors, random); randomY.put(id, vector); if (recentVectors.size() < MAX_FAR_FROM_VECTORS) { recentVectors.add(vector); } } if (++count % LOG_INTERVAL == 0) { log.info("Computed {} initial Y rows", count); } } log.info("Constructed initial Y"); return randomY; } /** * Runs one iteration to compute X from Y. */ private void iterateXFromY(ExecutorService executor) throws ExecutionException, InterruptedException { RealMatrix YTY = MatrixUtils.transposeTimesSelf(Y); Collection<Future<?>> futures = Lists.newArrayList(); addWorkers(RbyRow, Y, YTY, X, executor, futures); int count = 0; long total = 0; for (Future<?> f : futures) { f.get(); count += WORK_UNIT_SIZE; if (count >= LOG_INTERVAL) { total += count; JVMEnvironment env = new JVMEnvironment(); log.info("{} X/tag rows computed ({}MB heap)", total, env.getUsedMemoryMB()); if (env.getPercentUsedMemory() > 95) { log.warn( "Memory is low. Increase heap size with -Xmx, decrease new generation size with larger " + "-XX:NewRatio value, and/or use -XX:+UseCompressedOops"); } count = 0; } } } /** * Runs one iteration to compute Y from X. */ private void iterateYFromX(ExecutorService executor) throws ExecutionException, InterruptedException { RealMatrix XTX = MatrixUtils.transposeTimesSelf(X); Collection<Future<?>> futures = Lists.newArrayList(); addWorkers(RbyColumn, X, XTX, Y, executor, futures); int count = 0; long total = 0; for (Future<?> f : futures) { f.get(); count += WORK_UNIT_SIZE; if (count >= LOG_INTERVAL) { total += count; JVMEnvironment env = new JVMEnvironment(); log.info("{} Y/tag rows computed ({}MB heap)", total, env.getUsedMemoryMB()); if (env.getPercentUsedMemory() > 95) { log.warn( "Memory is low. Increase heap size with -Xmx, decrease new generation size with larger " + "-XX:NewRatio value, and/or use -XX:+UseCompressedOops"); } count = 0; } } } private void addWorkers(FastByIDMap<FastByIDFloatMap> R, FastByIDMap<float[]> M, RealMatrix MTM, FastByIDMap<float[]> MTags, ExecutorService executor, Collection<Future<?>> futures) { if (R != null) { List<Pair<Long, FastByIDFloatMap>> workUnit = Lists.newArrayListWithCapacity(WORK_UNIT_SIZE); for (FastByIDMap.MapEntry<FastByIDFloatMap> entry : R.entrySet()) { workUnit.add(new Pair<Long, FastByIDFloatMap>(entry.getKey(), entry.getValue())); if (workUnit.size() == WORK_UNIT_SIZE) { futures.add(executor.submit(new Worker(features, M, MTM, MTags, workUnit))); workUnit = Lists.newArrayListWithCapacity(WORK_UNIT_SIZE); } } if (!workUnit.isEmpty()) { futures.add(executor.submit(new Worker(features, M, MTM, MTags, workUnit))); } } } private static final class Worker implements Callable<Void> { private final int features; private final FastByIDMap<float[]> Y; private final RealMatrix YTY; private final FastByIDMap<float[]> X; private final Iterable<Pair<Long, FastByIDFloatMap>> workUnit; private Worker(int features, FastByIDMap<float[]> Y, RealMatrix YTY, FastByIDMap<float[]> X, Iterable<Pair<Long, FastByIDFloatMap>> workUnit) { this.features = features; this.Y = Y; this.YTY = YTY; this.X = X; this.workUnit = workUnit; } @Override public Void call() { double alpha = getAlpha(); double lambda = getLambda() * alpha; int features = this.features; // Each worker has a batch of rows to compute: for (Pair<Long, FastByIDFloatMap> work : workUnit) { // Row (column) in original R matrix containing total association value. For simplicity we will // talk about users and rows only in the comments and variables. It's symmetric for columns / items. // This is Ru: FastByIDFloatMap ru = work.getSecond(); // Start computing Wu = (YT*Cu*Y + lambda*I) = (YT*Y + YT*(Cu-I)*Y + lambda*I), // by first starting with a copy of YT * Y. Or, a variant on YT * Y, if LOSS_IGNORES_UNSPECIFIED is set RealMatrix Wu = LOSS_IGNORES_UNSPECIFIED ? partialTransposeTimesSelf(Y, YTY.getRowDimension(), ru.keySetIterator()) : YTY.copy(); double[][] WuData = MatrixUtils.accessMatrixDataDirectly(Wu); double[] YTCupu = new double[features]; for (FastByIDFloatMap.MapEntry entry : ru.entrySet()) { double xu = entry.getValue(); float[] vector = Y.get(entry.getKey()); if (vector == null) { log.warn("No vector for {}. This should not happen. Continuing...", entry.getKey()); continue; } // Wu and YTCupu if (RECONSTRUCT_R_MATRIX) { for (int row = 0; row < features; row++) { YTCupu[row] += xu * vector[row]; } } else { double cu = 1.0 + alpha * FastMath.abs(xu); for (int row = 0; row < features; row++) { float vectorAtRow = vector[row]; double rowValue = vectorAtRow * (cu - 1.0); double[] WuDataRow = WuData[row]; for (int col = 0; col < features; col++) { WuDataRow[col] += rowValue * vector[col]; //Wu.addToEntry(row, col, rowValue * vector[col]); } if (xu > 0.0) { YTCupu[row] += vectorAtRow * cu; } } } } double lambdaTimesCount = lambda * ru.size(); for (int x = 0; x < features; x++) { WuData[x][x] += lambdaTimesCount; //Wu.addToEntry(x, x, lambdaTimesCount); } float[] xu = MatrixUtils.getSolver(Wu).solveDToF(YTCupu); // Store result: synchronized (X) { X.put(work.getFirst(), xu); } // Process is identical for computing Y from X. Swap X in for Y, Y for X, i for u, etc. } return null; } private static double getAlpha() { String alphaProperty = System.getProperty("model.als.alpha"); return alphaProperty == null ? DEFAULT_ALPHA : LangUtils.parseDouble(alphaProperty); } private static double getLambda() { String lambdaProperty = System.getProperty("model.als.lambda"); return lambdaProperty == null ? DEFAULT_LAMBDA : LangUtils.parseDouble(lambdaProperty); } /** * Like {@link MatrixUtils#transposeTimesSelf(FastByIDMap)}, but instead of computing MT * M, * it computes MT * C * M, where C is a diagonal matrix of 1s and 0s. This is like pretending some * rows of M are 0. * * @see MatrixUtils#transposeTimesSelf(FastByIDMap) * @see #LOSS_IGNORES_UNSPECIFIED */ private static RealMatrix partialTransposeTimesSelf(FastByIDMap<float[]> M, int dimension, LongPrimitiveIterator keys) { RealMatrix result = new Array2DRowRealMatrix(dimension, dimension); while (keys.hasNext()) { long key = keys.next(); float[] vector = M.get(key); for (int row = 0; row < dimension; row++) { float rowValue = vector[row]; for (int col = 0; col < dimension; col++) { result.addToEntry(row, col, rowValue * vector[col]); } } } return result; } } }