com.innometrics.integration.app.recommender.ml.als.AlternatingLeastSquaresSolver.java Source code

Java tutorial

Introduction

Here is the source code for com.innometrics.integration.app.recommender.ml.als.AlternatingLeastSquaresSolver.java

Source

/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.innometrics.integration.app.recommender.ml.als;

import com.google.common.base.Preconditions;
import com.google.common.collect.Iterables;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.QRDecomposition;
import org.apache.mahout.math.Vector;

/**
 * Static utility that solves the regularized least-squares system of a single alternating least
 * squares (ALS) step: {@code Ai * ui = Vi}, where {@code Ai = MiIi * t(MiIi) + lambda * nui * E}
 * ({@code E} being the identity matrix) and {@code Vi = MiIi * t(R(i,Ii))}.
 *
 * <p>See
 * <a href="http://www.hpl.hp.com/personal/Robert_Schreiber/papers/2008%20AAIM%20Netflix/netflix_aaim08(submitted).pdf">
 * this paper.</a>
 */
public final class AlternatingLeastSquaresSolver {

    /** Not instantiable: the class only exposes static methods. */
    private AlternatingLeastSquaresSolver() {
    }

    //TODO make feature vectors a simple array
    /**
     * Computes the feature vector {@code ui} that solves {@code Ai * ui = Vi} for the given
     * feature vectors and ratings.
     *
     * @param featureVectors feature vectors, one per rating; iterated more than once, so the
     *                       iterable must be re-iterable and yield the same elements each time
     * @param ratingVector   ratings; must report sequential access and hold as many non-default
     *                       elements as there are feature vectors
     * @param lambda         regularization factor; multiplied by the number of ratings and added to
     *                       the diagonal of {@code MiIi * t(MiIi)}
     * @param numFeatures    number of leading entries read from every feature vector; must be positive
     * @return the first column of the solution of {@code Ai * ui = Vi}
     * @throws NullPointerException     if {@code featureVectors} or {@code ratingVector} is null
     * @throws IllegalArgumentException if either input is empty, if their sizes differ, if
     *                                  {@code numFeatures} is not positive, or if the rating vector
     *                                  is not sequential-access
     */
    public static Vector solve(Iterable<Vector> featureVectors, Vector ratingVector, double lambda,
            int numFeatures) {

        Preconditions.checkNotNull(featureVectors, "Feature Vectors cannot be null");
        Preconditions.checkArgument(!Iterables.isEmpty(featureVectors), "Feature Vectors cannot be empty");
        Preconditions.checkNotNull(ratingVector, "Rating Vector cannot be null");

        // Read the rating count once; it is used for validation and for the regularization term.
        int nui = ratingVector.getNumNondefaultElements();
        Preconditions.checkArgument(nui > 0, "Rating Vector cannot be empty");

        int numFeatureVectors = Iterables.size(featureVectors);
        Preconditions.checkArgument(numFeatureVectors == nui,
                "Number of feature vectors (%s) must match the number of ratings (%s)", numFeatureVectors, nui);

        // Fail with a clear message instead of an array error while building the matrices below.
        Preconditions.checkArgument(numFeatures > 0, "Number of features must be positive but was %s",
                numFeatures);

        Matrix MiIi = createMiIi(featureVectors, numFeatures);
        Matrix RiIiMaybeTransposed = createRiIiMaybeTransposed(ratingVector);

        /* compute Ai = MiIi * t(MiIi) + lambda * nui * E */
        Matrix Ai = miTimesMiTransposePlusLambdaTimesNuiTimesE(MiIi, lambda, nui);
        /* compute Vi = MiIi * t(R(i,Ii)) */
        Matrix Vi = MiIi.times(RiIiMaybeTransposed);
        /* compute Ai * ui = Vi */
        return solve(Ai, Vi);
    }

    /**
     * Solves {@code Ai * x = Vi} through a QR decomposition of {@code Ai}.
     *
     * @param Ai coefficient matrix
     * @param Vi right-hand side
     * @return column 0 of the solution matrix
     */
    private static Vector solve(Matrix Ai, Matrix Vi) {
        return new QRDecomposition(Ai).solve(Vi).viewColumn(0);
    }

    /**
     * Adds {@code lambda * nui} to every diagonal entry of the given square matrix.
     *
     * <p>The matrix is modified in place and the same instance is returned. This method is not
     * called by {@code solve} in this class; {@code miTimesMiTransposePlusLambdaTimesNuiTimesE}
     * folds the same diagonal term into its own computation.
     *
     * @param matrix square matrix to modify
     * @param lambda regularization factor
     * @param nui    number of ratings
     * @return {@code matrix}, after modification
     * @throws IllegalArgumentException if {@code matrix} is not square
     */
    static Matrix addLambdaTimesNuiTimesE(Matrix matrix, double lambda, int nui) {
        Preconditions.checkArgument(matrix.numCols() == matrix.numRows(), "Must be a Square Matrix");
        double lambdaTimesNui = lambda * nui;
        int numCols = matrix.numCols();
        for (int n = 0; n < numCols; n++) {
            matrix.setQuick(n, n, matrix.getQuick(n, n) + lambdaTimesNui);
        }
        return matrix;
    }

    /**
     * Builds {@code MiIi * t(MiIi) + lambda * nui * E} as a new matrix without materializing the
     * transpose: entry {@code (i, j)} is the dot product of rows {@code i} and {@code j} of
     * {@code MiIi}, and {@code lambda * nui} is added on the diagonal.
     *
     * @param MiIi   matrix whose rows are multiplied with each other
     * @param lambda regularization factor
     * @param nui    number of ratings
     * @return a new square matrix with as many rows as {@code MiIi}
     */
    private static Matrix miTimesMiTransposePlusLambdaTimesNuiTimesE(Matrix MiIi, double lambda, int nui) {

        double lambdaTimesNui = lambda * nui;
        int rows = MiIi.numRows();

        double[][] result = new double[rows][rows];

        for (int i = 0; i < rows; i++) {
            // Row i does not change during the inner loop, so its view is obtained once.
            Vector rowI = MiIi.viewRow(i);
            // The product is symmetric: compute the upper triangle only and mirror it.
            for (int j = i; j < rows; j++) {
                double dot = rowI.dot(MiIi.viewRow(j));
                if (i != j) {
                    result[i][j] = dot;
                    result[j][i] = dot;
                } else {
                    result[i][i] = dot + lambdaTimesNui;
                }
            }
        }
        return new DenseMatrix(result, true);
    }

    /**
     * Copies the first {@code numFeatures} entries of every feature vector into a matrix with one
     * row per feature and one column per feature vector, in iteration order.
     *
     * @param featureVectors feature vectors; iterated twice (once for the count, once for the values)
     * @param numFeatures    number of leading entries read from each vector via {@code getQuick}
     * @return a {@code numFeatures x (number of feature vectors)} matrix
     */
    static Matrix createMiIi(Iterable<Vector> featureVectors, int numFeatures) {
        double[][] MiIi = new double[numFeatures][Iterables.size(featureVectors)];
        int n = 0;
        for (Vector featureVector : featureVectors) {
            for (int m = 0; m < numFeatures; m++) {
                MiIi[m][n] = featureVector.getQuick(m);
            }
            n++;
        }
        return new DenseMatrix(MiIi, true);
    }

    /**
     * Copies the non-zero entries of the rating vector, in iteration order, into a single-column
     * matrix.
     *
     * @param ratingVector ratings; must report sequential access
     * @return a matrix with one column and {@code getNumNondefaultElements()} rows
     * @throws IllegalArgumentException if {@code ratingVector} is not sequential-access
     */
    static Matrix createRiIiMaybeTransposed(Vector ratingVector) {
        Preconditions.checkArgument(ratingVector.isSequentialAccess(),
                "Ratings should be iterable in Index or Sequential Order");

        // NOTE(review): the array is sized by getNumNondefaultElements() but filled from nonZeroes().
        // If a vector can report more non-default elements than it yields non-zero ones, the
        // trailing rows stay 0.0 and ratings shift relative to the feature columns - confirm.
        double[][] RiIiMaybeTransposed = new double[ratingVector.getNumNondefaultElements()][1];
        int index = 0;
        for (Vector.Element elem : ratingVector.nonZeroes()) {
            RiIiMaybeTransposed[index++][0] = elem.get();
        }
        return new DenseMatrix(RiIiMaybeTransposed, true);
    }
}