com.ydy.cf.solver.impl.AlternatingLeastSquaresSolver.java Source code

Java tutorial

Introduction

Here is the source code for com.ydy.cf.solver.impl.AlternatingLeastSquaresSolver.java

Source

/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.ydy.cf.solver.impl;

import java.util.Iterator;
import java.util.List;

import org.apache.mahout.cf.taste.recommender.RecommendedItem;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.QRDecomposition;
import org.apache.mahout.math.Vector;

import com.google.common.base.Preconditions;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import com.ydy.cf.common.VectorUtils;
import com.ydy.cf.model.MatrixLoader;

/**
 * Recommends items for a single user with one alternating-least-squares step: the matrix {@code Y}
 * (supplied through the superclass) is held fixed and the user's feature vector is solved for in
 * closed form.
 *
 * <p>For the user's rated items {@code Ii}, let {@code MiIi} be the {@code numFeatures x nui} matrix whose
 * columns are the rows of {@code Y} at the rated item indices, and {@code R(i,Ii)} the user's ratings of
 * those items. The user's feature vector {@code ui} is the solution of
 * <pre>
 *   (MiIi * t(MiIi) + lambda * nui * E) * ui = MiIi * t(R(i,Ii))
 * </pre>
 * where {@code nui} is the number of rated items and {@code E} is the identity matrix, i.e. the
 * regularization term is scaled by the number of ratings.
 *
 * <p>See <a href="http://www.hpl.hp.com/personal/Robert_Schreiber/papers/2008%20AAIM%20Netflix/netflix_aaim08(submitted).pdf">
 * this paper.</a>
 */
public final class AlternatingLeastSquaresSolver extends AbstractRecommendationSolver {
    /** Rows of {@code Y} for the items the user rated, aligned index-for-index with {@link #ratings}. */
    private final List<Vector> featureVectors = Lists.newArrayList();
    /** The user's non-zero ratings, in the same order as {@link #featureVectors}. */
    private final List<Double> ratings = Lists.newArrayList();
    private final double lambda;
    private final int numFeatures;
    /** Number of recommendations requested from {@code VectorUtils.buildRecommends} in {@link #solveAll()}. */
    private final int recommendationCount;

    /**
     * Creates a solver for one user and gathers the feature rows of the items that user rated.
     *
     * @param userId             identifier of the user, passed through to the superclass
     * @param userRatings        the user's ratings indexed by item; only entries returned by
     *                           {@code iterateNonZero()} are used
     * @param lambda             regularization strength
     * @param loader             passed to the superclass (presumably the source of {@code Y} — confirm there)
     * @param numRecommendations number of recommendations {@link #solveAll()} asks for
     */
    public AlternatingLeastSquaresSolver(String userId, Vector userRatings, double lambda, MatrixLoader loader,
            int numRecommendations) {
        super(userId, userRatings, loader, numRecommendations);
        this.lambda = lambda;
        this.numFeatures = this.Y.numCols();
        this.recommendationCount = numRecommendations;

        // Capture each rated item's feature row and its rating in the same pass, so the two lists are
        // aligned by construction regardless of the vector's iteration order or how it counts
        // non-default elements. Element values are read immediately because iterators may reuse them.
        Iterator<Vector.Element> iter = userRatings.iterateNonZero();
        while (iter.hasNext()) {
            Vector.Element element = iter.next();
            featureVectors.add(this.Y.viewRow(element.index()));
            ratings.add(element.get());
        }
    }

    /**
     * Solves for the user's feature vector and hands it to {@code VectorUtils.buildRecommends} for ranking.
     *
     * @return the items retrieved from {@code VectorUtils.buildRecommends(Y, userRatings, userFeatures,
     *         numRecommendations)}
     * @throws IllegalArgumentException if the user has no non-zero ratings
     */
    public List<RecommendedItem> solveAll() {
        Vector userFeatures = solve(this.featureVectors, this.ratings, this.lambda, this.numFeatures);
        return VectorUtils.buildRecommends(this.Y, this.userRatings, userFeatures, this.recommendationCount)
                .retrieve();
    }

    /**
     * Builds and solves the regularized normal equations for one user.
     *
     * @param featureVectors feature rows of the rated items, aligned with {@code ratings}
     * @param ratings        the user's ratings of those items
     * @param lambda         regularization strength
     * @param numFeatures    number of latent features (length of each feature row used)
     * @return the user's feature vector, of length {@code numFeatures}
     * @throws IllegalArgumentException if there are no ratings or the two lists differ in length
     */
    private Vector solve(List<Vector> featureVectors, List<Double> ratings, double lambda, int numFeatures) {
        Preconditions.checkArgument(!featureVectors.isEmpty(), "Rating vector cannot be empty");
        Preconditions.checkArgument(featureVectors.size() == ratings.size(),
                "Expected one feature vector per rating but got %s feature vectors and %s ratings",
                featureVectors.size(), ratings.size());

        int nui = ratings.size();

        Matrix MiIi = createMiIi(featureVectors, numFeatures);
        Matrix RiIiMaybeTransposed = createRiIiMaybeTransposed(ratings);

        /* compute Ai = MiIi * t(MiIi) + lambda * nui * E */
        Matrix Ai = addLambdaTimesNuiTimesE(MiIi.times(MiIi.transpose()), lambda, nui);
        /* compute Vi = MiIi * t(R(i,Ii)) */
        Matrix Vi = MiIi.times(RiIiMaybeTransposed);
        /* compute Ai * ui = Vi */
        return solve(Ai, Vi);
    }

    /** Solves {@code Ai * x = Vi} via QR decomposition and returns the first column of {@code x}. */
    private Vector solve(Matrix Ai, Matrix Vi) {
        return new QRDecomposition(Ai).solve(Vi).viewColumn(0);
    }

    /**
     * Adds {@code lambda * nui} to every diagonal entry of a square matrix, mutating it in place.
     *
     * @return the same (now modified) matrix instance
     * @throws IllegalArgumentException if the matrix is not square
     */
    private Matrix addLambdaTimesNuiTimesE(Matrix matrix, double lambda, int nui) {
        Preconditions.checkArgument(matrix.numCols() == matrix.numRows(), "Matrix must be square");
        for (int n = 0; n < matrix.numCols(); n++) {
            matrix.setQuick(n, n, matrix.getQuick(n, n) + lambda * nui);
        }
        return matrix;
    }

    /**
     * Lays the feature vectors out as the columns of a new {@code numFeatures x featureVectors.size()} matrix,
     * copying the first {@code numFeatures} entries of each.
     */
    private Matrix createMiIi(List<Vector> featureVectors, int numFeatures) {
        Matrix MiIi = new DenseMatrix(numFeatures, featureVectors.size());
        int n = 0;
        for (Vector featureVector : featureVectors) {
            for (int m = 0; m < numFeatures; m++) {
                MiIi.setQuick(m, n, featureVector.getQuick(m));
            }
            n++;
        }
        return MiIi;
    }

    /** Copies the ratings into a new single-column {@code ratings.size() x 1} matrix, preserving order. */
    private Matrix createRiIiMaybeTransposed(List<Double> ratings) {
        Matrix RiIiMaybeTransposed = new DenseMatrix(ratings.size(), 1);
        for (int row = 0; row < ratings.size(); row++) {
            RiIiMaybeTransposed.setQuick(row, 0, ratings.get(row));
        }
        return RiIiMaybeTransposed;
    }
}