ch.epfl.lsir.xin.algorithm.core.MatrixFactorization.java Source code

Java tutorial

Introduction

Here is the source code for ch.epfl.lsir.xin.algorithm.core.MatrixFactorization.java

Source

//Copyright (C) 2014  Xin Liu
//
//RecMe: a lightweight recommendation algorithm library
//
//RecMe is free software; you can redistribute it and/or
//modify it under the terms of the GNU General Public License
//as published by the Free Software Foundation; either version 2
//of the License, or (at your option) any later version.
//
//This program is distributed in the hope that it will be useful,
//but WITHOUT ANY WARRANTY; without even the implied warranty of
//MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
//GNU General Public License for more details.
//
//You should have received a copy of the GNU General Public License
//along with this program; if not, write to the Free Software
//Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

/**
 * This class implements the basic matrix factorization model for recommendation
 * 
 * Ruslan Salakhutdinov and Andriy Mnih, Probabilistic Matrix Factorization, NIPS 2008.
 * 
 * @author Xin Liu
 * */

package ch.epfl.lsir.xin.algorithm.core;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.io.PrintWriter;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Date;
import java.util.Random;
import java.util.StringTokenizer;

import org.apache.commons.configuration.ConfigurationException;
import org.apache.commons.configuration.PropertiesConfiguration;

import Jama.Matrix;
import ch.epfl.lsir.xin.algorithm.IAlgorithm;
import ch.epfl.lsir.xin.datatype.LatentMatrix;
import ch.epfl.lsir.xin.datatype.MatrixEntry2D;
import ch.epfl.lsir.xin.datatype.RatingMatrix;
import ch.epfl.lsir.xin.evaluation.ResultUnit;

public class MatrixFactorization implements IAlgorithm {

    /**
     * Observed ratings used for training and for finding the items a user has
     * not rated yet; supplied through the constructor or setRatingMatrix.
     */
    private RatingMatrix ratingMatrix = null;

    /**
     * User latent factors: created in the constructor with one row per
     * rating-matrix row and latentFactors columns.
     */
    private LatentMatrix userMatrix = null;

    // Snapshot of userMatrix that updateLearningRate restores when the
    // bold-driver strategy rejects an iteration.
    private LatentMatrix userMatrixPrevious = null;

    /**
     * Item latent factors: created in the constructor with one row per
     * rating-matrix column and latentFactors columns.
     */
    private LatentMatrix itemMatrix = null;

    // Snapshot of itemMatrix, the counterpart of userMatrixPrevious.
    private LatentMatrix itemMatrixPrevious = null;

    /**
     * Destination of training log messages. The constructor leaves it null,
     * so setLogger must be called before any training method that logs.
     */
    private PrintWriter logger = null;

    /**
     * Parameter settings, loaded from ".//conf//MF.properties" by the
     * constructor and read again by updateLearningRate.
     */
    public PropertiesConfiguration config = new PropertiesConfiguration();

    /**
     * Initialization method code handed to LatentMatrix.setInitialization
     * (configuration key "INITIALIZATION").
     */
    private int initialization = -1;

    /**
     * Training parameters read from the configuration. Despite the historic
     * "SGD parameters" label, latentFactors, Iterations, regUser, regItem and
     * convergence are used by the ALS trainer as well; only learningRate is
     * specific to SGD. The value -1 marks "not configured yet".
     */
    // number of latent dimensions per user/item ("LATENT_FACTORS")
    private int latentFactors = -1;
    // maximum number of training iterations ("ITERATIONS")
    private int Iterations = -1;
    // SGD step size ("LEARNING_RATE"); may be changed by updateLearningRate
    private double learningRate = -1;
    // regularization weight for user factors ("REGULARIZATION_USER")
    private double regUser = -1;
    // regularization weight for item factors ("REGULARIZATION_ITEM")
    private double regItem = -1;
    // training stops once the error changes by at most this amount ("CONVERGENCE")
    private double convergence = -1;

    /**
     * Name of the optimizer selected in the configuration
     * ("OPTIMIZATION_METHOD"); build() recognizes "SGD" and "ALS".
     */
    private String optimization = null;

    // number of items returned by getRecommendationList ("TOP_N_RECOMMENDATION")
    private int topN = -1;

    // bounds to which predictions are clamped while training
    // ("MAX_RATING" / "MIN_RATING")
    private int maxRating = -1;
    private int minRating = -1;

    /**
     * Creates a matrix factorization model for the given ratings.
     * <p>
     * All parameters are read from ".//conf//MF.properties". The user and item
     * latent matrices are allocated and initialized here, and a snapshot of
     * each is kept for the learning-rate rollback.
     *
     * @param ratingMatrix the rating matrix to factorize; its row and column
     *                     counts size the user and item latent matrices
     * @param readModel    if true, previously saved latent factors are loaded
     *                     over the freshly initialized ones
     * @param modelFile    path prefix of the saved model (the files
     *                     modelFile + "_user" and modelFile + "_item" are read);
     *                     only used when readModel is true
     * @throws IllegalStateException if the configuration file cannot be loaded
     */
    public MatrixFactorization(RatingMatrix ratingMatrix, boolean readModel, String modelFile) {
        //set configuration file for parameter setting.
        config.setFile(new File(".//conf//MF.properties"));
        try {
            config.load();
        } catch (ConfigurationException e) {
            // Fail fast with the real cause. Continuing with an empty
            // configuration would only fail a few lines further down with an
            // unrelated-looking "key doesn't map to an existing object" error.
            throw new IllegalStateException("Cannot load configuration file .//conf//MF.properties", e);
        }
        this.ratingMatrix = ratingMatrix;
        this.initialization = this.config.getInt("INITIALIZATION");
        this.latentFactors = this.config.getInt("LATENT_FACTORS");
        this.Iterations = this.config.getInt("ITERATIONS");
        this.learningRate = this.config.getDouble("LEARNING_RATE");
        this.regUser = this.config.getDouble("REGULARIZATION_USER");
        this.regItem = this.config.getDouble("REGULARIZATION_ITEM");
        this.convergence = this.config.getDouble("CONVERGENCE");
        this.optimization = this.config.getString("OPTIMIZATION_METHOD");

        // one latent row per user (rating-matrix row)
        this.userMatrix = new LatentMatrix(this.ratingMatrix.getRow(), this.latentFactors);
        this.userMatrix.setInitialization(this.initialization);
        this.userMatrix.valueInitialization();
        this.userMatrixPrevious = this.userMatrix.clone();

        // one latent row per item (rating-matrix column)
        this.itemMatrix = new LatentMatrix(this.ratingMatrix.getColumn(), this.latentFactors);
        this.itemMatrix.setInitialization(this.initialization);
        this.itemMatrix.valueInitialization();
        this.itemMatrixPrevious = this.itemMatrix.clone();

        this.topN = this.config.getInt("TOP_N_RECOMMENDATION");
        this.maxRating = this.config.getInt("MAX_RATING");
        this.minRating = this.config.getInt("MIN_RATING");

        if (readModel) {
            this.readModel(modelFile);
        }
    }

    /**
     * Trains the model with the optimizer named by the "OPTIMIZATION_METHOD"
     * setting: "SGD" runs {@link #buildSGD()}, "ALS" runs {@link #buildALS()}.
     * Any other value, including a missing setting, trains nothing and only
     * logs a message.
     * <p>
     * The logger must have been set before this method is called.
     */
    @Override
    public void build() {
        // Constant-first equals: a missing OPTIMIZATION_METHOD leaves the field
        // null, which must end up in the "not set properly" branch instead of
        // raising a NullPointerException.
        if ("SGD".equals(this.optimization)) {
            buildSGD();
            logger.println("SGD is used to train the model.");
        } else if ("ALS".equals(this.optimization)) {
            buildALS();
            logger.println("ALS is used to train the model.");
        } else {
            logger.println("Optimization method is not set properly.");
        }
    }

    /**
     * Learns the matrix factorization model using Stochastic Gradient Descent.
     * <p>
     * Each iteration visits every observed rating exactly once in random order
     * and moves the user and item factors along the regularized error
     * gradient. Afterwards the summed absolute error over all observed ratings
     * (predictions clamped to [minRating, maxRating]) is logged. Training stops
     * early when the error did not grow and changed by at most the convergence
     * threshold; otherwise the learning rate is adapted and the next iteration
     * starts.
     * <p>
     * The logger must have been set before this method is called.
     */
    public void buildSGD() {
        double preError = Double.MAX_VALUE;
        // one generator for the whole run instead of a new one per rating
        Random random = new Random();
        for (int i = 0; i < this.Iterations; i++) {
            System.out.println("Iteration: " + i);
            ArrayList<MatrixEntry2D> entries = this.ratingMatrix.getValidEntries();
            // A single shuffle yields the same "every rating once, uniformly
            // random order" schedule as repeatedly drawing and removing a random
            // element, without the quadratic cost of ArrayList.remove(index).
            Collections.shuffle(entries, random);
            for (int k = 0; k < entries.size(); k++) {
                MatrixEntry2D entry = entries.get(k);
                int user = entry.getRowIndex();
                int item = entry.getColumnIndex();
                double prediction = predict(user, item, false);
                if (prediction > this.maxRating) {
                    prediction = this.maxRating;
                }
                if (prediction < this.minRating) {
                    prediction = this.minRating;
                }
                double difference = entry.getValue() - prediction;
                for (int l = 0; l < this.latentFactors; l++) {
                    // both updates are computed from the values before this
                    // step, so neither sees the other's new value
                    double userFactor = this.userMatrix.get(user, l);
                    double itemFactor = this.itemMatrix.get(item, l);
                    double tempU = userFactor
                            + this.learningRate * (difference * itemFactor - this.regUser * userFactor);
                    double tempI = itemFactor
                            + this.learningRate * (difference * userFactor - this.regItem * itemFactor);
                    this.userMatrix.set(user, l, tempU);
                    this.itemMatrix.set(item, l, tempI);
                }
            }

            // overall error of this iteration: summed absolute error on the
            // observed ratings (no regularization term)
            double error = 0;
            entries = this.ratingMatrix.getValidEntries();
            for (int k = 0; k < entries.size(); k++) {
                MatrixEntry2D entry = entries.get(k);
                double prediction = predict(entry.getRowIndex(), entry.getColumnIndex(), false);
                if (prediction > this.maxRating) {
                    prediction = this.maxRating;
                }
                if (prediction < this.minRating) {
                    prediction = this.minRating;
                }
                error = error + Math.abs(entry.getValue() - prediction);
            }

            this.logger.println(new SimpleDateFormat("yyyy-MM-dd HH:mm:ss").format(new Date()) + " Iteration " + i
                    + " : Error ~ " + error);
            this.logger.flush();
            //check for convergence
            if (Math.abs(error - preError) <= this.convergence && error <= preError) {
                logger.println("The algorithm convergences.");
                this.logger.flush();
                break;
            }

            // learning rate update strategy
            updateLearningRate(error, preError);

            // NOTE(review): when updateLearningRate rolls the factors back to
            // the previous iteration, preError is still overwritten with the
            // rejected iteration's error here -- confirm this is intended.
            preError = error;
            logger.flush();
        }
    }

    /**
     * Learns the matrix factorization model using Alternating Least Squares;
     * refer to "Large-scale Parallel Collaborative Filtering for the Netflix
     * Prize".
     * <p>
     * Each iteration first solves every user's factor vector with the item
     * factors fixed, then every item's factor vector with the user factors
     * fixed. For a row with n ratings the solved system is
     * (X^T X + reg * n * I) w = X^T r, where X holds the factor rows of the
     * rated counterparts and r the ratings. Users or items without any rating
     * are skipped and keep their current factors. After both sweeps the error
     * (summed absolute error plus regularization terms) is logged and compared
     * with the previous iteration to detect convergence.
     * <p>
     * The logger must have been set before this method is called.
     */
    public void buildALS() {
        double preError = Double.MAX_VALUE;
        Matrix identity = Matrix.identity(this.latentFactors, this.latentFactors);
        for (int i = 0; i < this.Iterations; i++) {
            System.out.println("Iteration: " + i);
            double error = 0; //overall error of this iteration
            //fix item matrix M, solve user matrix U
            for (int j = 0; j < this.userMatrix.getRows(); j++) {
                int n = this.ratingMatrix.getUserRatingNumber(j);//number of items rated by user j
                // A user without ratings has nothing to fit, and building a
                // Matrix from an empty array fails; skip like the item sweep does.
                if (n == 0) {
                    continue;
                }
                //latent factors of the items that are rated by user j
                // NOTE(review): rated entries are detected via NaN here but via
                // null in step 2 -- presumably the two views agree; confirm.
                double[][] m = new double[n][this.latentFactors];
                int index = 0;
                for (int k = 0; k < this.itemMatrix.getRows(); k++) {
                    if (!Double.isNaN(this.ratingMatrix.get(j, k))) {
                        m[index++] = this.itemMatrix.getRowValues(k);
                    }
                }
                Matrix M = new Matrix(m);
                //step 1: regularized normal matrix
                Matrix A = M.transpose().times(M).plus(identity.times(this.regUser).times(n));
                //step 2: right-hand side built from the ratings of this user
                double[][] r = new double[1][n];
                int index1 = 0;
                for (int k = 0; k < this.ratingMatrix.getColumn(); k++) {
                    Double rating = this.ratingMatrix.getRatingMatrix().get(j).get(k);
                    if (rating != null) {
                        r[0][index1++] = rating.doubleValue();
                    }
                }
                Matrix R = new Matrix(r);
                Matrix V = M.transpose().times(R.transpose());
                //step 3: the updated user matrix wrt user j
                Matrix uj = A.inverse().times(V);
                this.userMatrix.getValues()[j] = uj.transpose().getArray()[0];
            }
            //fix user matrix U, solve item matrix M
            for (int j = 0; j < this.itemMatrix.getRows(); j++) {
                int n = this.ratingMatrix.getItemRatingNumber(j);//number of users that rated item j
                // an item without ratings keeps its current factors
                if (n == 0) {
                    continue;
                }
                //latent factors of the users that have rated item j
                double[][] u = new double[n][this.latentFactors];
                int index = 0;
                for (int k = 0; k < this.userMatrix.getRows(); k++) {
                    if (!Double.isNaN(this.ratingMatrix.get(k, j))) {
                        u[index++] = this.userMatrix.getRowValues(k);
                    }
                }
                Matrix U = new Matrix(u);

                //step 1: regularized normal matrix
                Matrix A = U.transpose().times(U).plus(identity.times(this.regItem).times(n));

                //step 2: right-hand side built from the ratings of this item
                double[][] r = new double[1][n];
                int index1 = 0;
                for (int k = 0; k < this.ratingMatrix.getRow(); k++) {
                    Double rating = this.ratingMatrix.getRatingMatrix().get(k).get(j);
                    if (rating != null) {
                        r[0][index1++] = rating.doubleValue();
                    }
                }
                Matrix R = new Matrix(r);
                Matrix V = U.transpose().times(R.transpose());
                //step 3: the updated item matrix wrt item j
                Matrix mj = A.inverse().times(V);
                this.itemMatrix.getValues()[j] = mj.transpose().getArray()[0];
            }
            //error of this iteration: absolute error on every observed rating
            //plus the regularization terms of the factors involved
            ArrayList<MatrixEntry2D> entries = this.ratingMatrix.getValidEntries();
            for (int k = 0; k < entries.size(); k++) {
                MatrixEntry2D entry = entries.get(k);
                double prediction = predict(entry.getRowIndex(), entry.getColumnIndex(), false);
                if (prediction > this.maxRating) {
                    prediction = this.maxRating;
                }
                if (prediction < this.minRating) {
                    prediction = this.minRating;
                }
                error = error + Math.abs(entry.getValue() - prediction);
                for (int j = 0; j < this.latentFactors; j++) {
                    error = error + this.regUser / 2 * Math.pow(this.userMatrix.get(entry.getRowIndex(), j), 2)
                            + this.regItem / 2 * Math.pow(this.itemMatrix.get(entry.getColumnIndex(), j), 2);
                }
            }

            this.logger.println(new SimpleDateFormat("yyyy-MM-dd HH:mm:ss").format(new Date()) + " Iteration " + i
                    + " : Error ~ " + error);
            this.logger.flush();
            //check for convergence
            if (Math.abs(error - preError) <= this.convergence && error <= preError) {
                logger.println("The algorithm convergences.");
                this.logger.flush();
                break;
            }

            preError = error;
            logger.flush();
        }
    }

    /**
     * Adapts the learning rate after an iteration, following the strategy
     * selected by the "LEARNING_RATE_UPDATE" setting:
     * <ul>
     * <li>1 - keep the learning rate unchanged;</li>
     * <li>2 - bold driver: a smaller error raises the rate by 5% and snapshots
     * the factors, a larger error halves the rate and restores the last
     * snapshot;</li>
     * <li>3 - multiply the rate by the "LEARNING_RATING_DECAY" setting and
     * snapshot the factors.</li>
     * </ul>
     * Any other value does nothing.
     *
     * @param error    error of the current iteration
     * @param preError error of the previous iteration
     */
    public void updateLearningRate(double error, double preError) {
        final int strategy = this.config.getInt("LEARNING_RATE_UPDATE");
        switch (strategy) {
        case 2: {
            // bold driver learning rate update algorithm
            final double current = Math.abs(error);
            final double previous = Math.abs(preError);
            if (current < previous) {
                this.learningRate *= (1 + 0.05);
                logger.println("Increase learning rate by 5%.");
                this.userMatrixPrevious = this.userMatrix.clone();
                this.itemMatrixPrevious = this.itemMatrix.clone();
            } else if (current > previous) {
                this.learningRate *= 0.5;
                // roll back to the factors of the previous iteration
                this.userMatrix = this.userMatrixPrevious.clone();
                this.itemMatrix = this.itemMatrixPrevious.clone();
                logger.println("Decrease learning rate by 50%.");
            }
            break;
        }
        case 3:
            // decay the learning rate by a constant factor
            this.learningRate *= this.config.getDouble("LEARNING_RATING_DECAY");
            this.userMatrixPrevious = this.userMatrix.clone();
            this.itemMatrixPrevious = this.itemMatrix.clone();
            break;
        default:
            // strategy 1 (and any unrecognized value): nothing to update
            break;
        }
    }

    /**
     * Predicts the rating of a user for an item as the dot product of the
     * user's and the item's latent factor vectors.
     *
     * @param user index of the user
     * @param item index of the item
     * @param rank currently has no effect: the same dot product is returned
     *             for both values
     * @return the unclamped predicted rating
     */
    public double predict(int user, int item, boolean rank) {
        double dotProduct = 0;
        int factor = 0;
        while (factor < this.latentFactors) {
            double userWeight = this.userMatrix.getValues()[user][factor];
            double itemWeight = this.itemMatrix.getValues()[item][factor];
            dotProduct += userWeight * itemWeight;
            factor++;
        }
        // no user/item mean is added for ranking at the moment, so the result
        // does not depend on the rank flag
        return dotProduct;
    }

    /**
     * Generates a recommendation list for a given user: the items the user has
     * not rated, ordered from the highest to the lowest predicted score and
     * cut off after topN entries.
     * <p>
     * If fewer than topN unrated items exist, all of them are returned; a
     * non-positive topN yields an empty list.
     *
     * @param userIndex index of the user
     * @return at most topN recommendations, best first; never null
     */
    public ArrayList<ResultUnit> getRecommendationList(int userIndex) {
        ArrayList<ResultUnit> recommendationList = new ArrayList<ResultUnit>();
        //candidate list: all items that are not rated by the user
        for (int i = 0; i < this.ratingMatrix.getColumn(); i++) {
            if (this.ratingMatrix.getRatingMatrix().get(userIndex).get(i) == null) {
                //this item has not been rated by the user
                ResultUnit unit = new ResultUnit(userIndex, i, predict(userIndex, i, true));
                recommendationList.add(unit);
            }
        }

        //sort the candidates; the best ones are taken from the end of the list
        //(assumes ResultUnit's natural ordering is ascending by prediction -- confirm)
        Collections.sort(recommendationList);

        // Never take more entries than there are candidates: walking past index
        // 0 used to raise an IndexOutOfBoundsException for users with fewer
        // than topN unrated items.
        int size = recommendationList.size();
        int count = Math.min(Math.max(this.topN, 0), size);
        ArrayList<ResultUnit> result = new ArrayList<ResultUnit>();
        for (int i = size - 1; i >= size - count; i--) {
            result.add(recommendationList.get(i));
        }
        return result;
    }

    /**
     * Saves the model as two text files: file + "_user" with the user latent
     * factors and file + "_item" with the item latent factors. Each line holds
     * one matrix row, with every value followed by a tab.
     * <p>
     * I/O failures are reported on standard error and otherwise ignored; if
     * the user file cannot be written, the item file is not attempted.
     *
     * @param file path prefix of the two model files
     */
    @Override
    public void saveModel(String file) {
        try {
            //save users latent factors
            writeLatentMatrix(this.userMatrix, file + "_user");
            //save items latent factors
            writeLatentMatrix(this.itemMatrix, file + "_item");
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Writes one latent matrix to a text file, one row per line, and always
     * closes the file, even when writing fails part-way.
     *
     * @param matrix the latent matrix to write
     * @param path   the file to create or overwrite
     * @throws IOException if the file cannot be opened or a write error occurred
     */
    private void writeLatentMatrix(LatentMatrix matrix, String path) throws IOException {
        PrintWriter printer = new PrintWriter(path);
        try {
            for (int i = 0; i < matrix.getRows(); i++) {
                for (int j = 0; j < matrix.getColumns(); j++) {
                    printer.print(matrix.get(i, j) + "\t");
                }
                printer.println();
            }
            printer.flush();
            // PrintWriter swallows I/O errors; surface them so a truncated
            // model file does not go unnoticed
            if (printer.checkError()) {
                throw new IOException("Failed to write latent factors to " + path);
            }
        } finally {
            printer.close();
        }
    }

    /**
     * Loads a model written by saveModel: file + "_user" is read into the user
     * latent matrix and file + "_item" into the item latent matrix. Each line
     * supplies one matrix row as whitespace-separated numbers.
     * <p>
     * I/O failures are reported on standard error and otherwise ignored; if
     * the user file cannot be read, the item file is not attempted.
     *
     * @param file path prefix of the two model files
     */
    @Override
    public void readModel(String file) {
        try {
            //read user latent factors
            readLatentMatrix(this.userMatrix, file + "_user");
            //read item latent factors
            readLatentMatrix(this.itemMatrix, file + "_item");
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Fills a latent matrix from a text file: line r, token c is stored at
     * (r, c). The file is always closed, even when reading or number parsing
     * fails.
     *
     * @param target the latent matrix to overwrite
     * @param path   the file to read
     * @throws IOException if the file cannot be opened or read
     */
    private void readLatentMatrix(LatentMatrix target, String path) throws IOException {
        BufferedReader reader = new BufferedReader(new FileReader(path));
        try {
            String line = null;
            int row = 0;
            while ((line = reader.readLine()) != null) {
                StringTokenizer tokens = new StringTokenizer(line.trim());
                int column = 0;
                while (tokens.hasMoreTokens()) {
                    target.set(row, column, Double.parseDouble(tokens.nextToken()));
                    column++;
                }
                row++;
            }
        } finally {
            reader.close();
        }
    }

    /**
     * @return the rating matrix this model works on
     */
    public RatingMatrix getRatingMatrix() {
        return this.ratingMatrix;
    }

    /**
     * @param ratingMatrix the rating matrix to work on from now on
     */
    public void setRatingMatrix(RatingMatrix ratingMatrix) {
        this.ratingMatrix = ratingMatrix;
    }

    /**
     * @return the latent factor matrix of the users
     */
    public LatentMatrix getUserMatrix() {
        return this.userMatrix;
    }

    /**
     * @param userMatrix the latent factor matrix of the users to use
     */
    public void setUserMatrix(LatentMatrix userMatrix) {
        this.userMatrix = userMatrix;
    }

    /**
     * @return the latent factor matrix of the items
     */
    public LatentMatrix getItemMatrix() {
        return this.itemMatrix;
    }

    /**
     * @param itemMatrix the latent factor matrix of the items to use
     */
    public void setItemMatrix(LatentMatrix itemMatrix) {
        this.itemMatrix = itemMatrix;
    }

    /**
     * @return the writer that receives the training log, or null if none has
     *         been set yet
     */
    public PrintWriter getLogger() {
        return this.logger;
    }

    /**
     * @param logger the writer that should receive the training log
     */
    public void setLogger(PrintWriter logger) {
        this.logger = logger;
    }

    /**
     * @return the code of the latent factor initialization method
     */
    public int getInitialization() {
        return this.initialization;
    }

    /**
     * Stores a new initialization method code. Only the field is updated; the
     * latent matrices that already exist are not re-initialized.
     *
     * @param initialization the code of the initialization method
     */
    public void setInitialization(int initialization) {
        this.initialization = initialization;
    }

    /**
     * @return the number of latent factors
     */
    public int getLatentFactors() {
        return this.latentFactors;
    }

    /**
     * Stores a new number of latent factors. Only the field is updated; the
     * latent matrices that already exist are not resized.
     *
     * @param latentFactors the number of latent factors
     */
    public void setLatentFactors(int latentFactors) {
        this.latentFactors = latentFactors;
    }

    /**
     * @return the maximum number of training iterations
     */
    public int getIterations() {
        return this.Iterations;
    }

    /**
     * @param iterations the maximum number of training iterations
     */
    public void setIterations(int iterations) {
        this.Iterations = iterations;
    }

    /**
     * @return the current learning rate
     */
    public double getLearningRate() {
        return this.learningRate;
    }

    /**
     * @param learningRate the learning rate to continue with
     */
    public void setLearningRate(double learningRate) {
        this.learningRate = learningRate;
    }

    /**
     * @return the regularization weight of the user factors
     */
    public double getRegUser() {
        return this.regUser;
    }

    /**
     * @param regUser the regularization weight of the user factors
     */
    public void setRegUser(double regUser) {
        this.regUser = regUser;
    }

    /**
     * @return the regularization weight of the item factors
     */
    public double getRegItem() {
        return this.regItem;
    }

    /**
     * @param regItem the regularization weight of the item factors
     */
    public void setRegItem(double regItem) {
        this.regItem = regItem;
    }

    /**
     * @return the convergence threshold on the change of the training error
     */
    public double getConvergence() {
        return this.convergence;
    }

    /**
     * @param convergence the convergence threshold on the change of the
     *                    training error
     */
    public void setConvergence(double convergence) {
        this.convergence = convergence;
    }

    /**
     * @return the length of the recommendation lists
     */
    public int getTopN() {
        return this.topN;
    }

    /**
     * @param topN the length of the recommendation lists
     */
    public void setTopN(int topN) {
        this.topN = topN;
    }

    /**
     * @return the upper bound used when clamping predictions
     */
    public int getMaxRating() {
        return this.maxRating;
    }

    /**
     * @param maxRating the upper bound used when clamping predictions
     */
    public void setMaxRating(int maxRating) {
        this.maxRating = maxRating;
    }

    /**
     * @return the lower bound used when clamping predictions
     */
    public int getMinRating() {
        return this.minRating;
    }

    /**
     * @param minRating the lower bound used when clamping predictions
     */
    public void setMinRating(int minRating) {
        this.minRating = minRating;
    }

}