ch.epfl.lsir.xin.algorithm.core.SocialReg.java Source code

Java tutorial

Introduction

Here is the source code for ch.epfl.lsir.xin.algorithm.core.SocialReg.java

Source

//Copyright (C) 2014  Xin Liu
//
//RecMe: a lightweight recommendation algorithm library
//
//RecMe is free software; you can redistribute it and/or
//modify it under the terms of the GNU General Public License
//as published by the Free Software Foundation; either version 2
//of the License, or (at your option) any later version.
//
//This program is distributed in the hope that it will be useful,
//but WITHOUT ANY WARRANTY; without even the implied warranty of
//MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
//GNU General Public License for more details.
//
//You should have received a copy of the GNU General Public License
//along with this program; if not, write to the Free Software
//Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

/**
 * This class implements matrix factorization (MF) with social regularization,
 * following the paper "Recommender Systems with Social Regularization" (WSDM 2011).
 * 
 * @author Xin Liu
 * */

package ch.epfl.lsir.xin.algorithm.core;

import java.io.File;
import java.io.IOException;
import java.io.PrintWriter;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Date;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;

import org.apache.commons.configuration.ConfigurationException;
import org.apache.commons.configuration.PropertiesConfiguration;

import ch.epfl.lsir.xin.algorithm.IAlgorithm;
import ch.epfl.lsir.xin.datatype.LatentMatrix;
import ch.epfl.lsir.xin.datatype.MatrixEntry2D;
import ch.epfl.lsir.xin.datatype.RatingMatrix;
import ch.epfl.lsir.xin.util.SimilarityCalculator;

public class SocialReg implements IAlgorithm {

    /**
     * User-item rating matrix passed to the constructor. It supplies the training
     * entries in {@link #buildSGD()} and the co-rated items compared in
     * {@link #calculateSimilarityMatrix(String)}.
     * */
    private RatingMatrix ratingMatrix = null;

    /**
     * Users' social information matrix passed to the constructor. Every existing
     * (user, friend) entry is overwritten in place with the computed similarity by
     * {@link #calculateSimilarityMatrix(String)}.
     * */
    private RatingMatrix socialMatrix = null;

    // A separate user similarity matrix is not kept; similarities are stored
    // directly in socialMatrix. The former declaration is left for reference:
    //   private ArrayList<HashMap<Integer , Double>> similarityMatrix = null;

    /**
     * Users' latent factor matrix, created in the constructor with
     * ratingMatrix.getRow() rows and latentFactors columns.
     * */
    private LatentMatrix userMatrix = null;

    /**
     * Snapshot of userMatrix taken by updateLearningRate; restored when the
     * bold-driver strategy rolls an iteration back.
     * */
    private LatentMatrix userMatrixPrevious = null;

    /**
     * Items' latent factor matrix, created in the constructor with
     * ratingMatrix.getColumn() rows and latentFactors columns.
     * */
    private LatentMatrix itemMatrix = null;

    /**
     * Snapshot of itemMatrix taken by updateLearningRate; restored when the
     * bold-driver strategy rolls an iteration back.
     * */
    private LatentMatrix itemMatrixPrevious = null;

    /**
     * Logger of the system. It stays null until setLogger is called; the training
     * methods write to it without a null check.
     * */
    private PrintWriter logger = null;

    /**
     * Configuration file for parameter setting. Loaded from
     * .//conf//SocialReg.properties in the constructor.
     * */
    public PropertiesConfiguration config = new PropertiesConfiguration();

    /**
     * Latent factor initialization method (key INITIALIZATION); forwarded to
     * LatentMatrix.setInitialization for both latent matrices.
     * */
    private int initialization = -1;

    /**
     * SGD parameters, all read from the configuration in the constructor:
     * latentFactors (LATENT_FACTORS) - number of latent dimensions;
     * Iterations (ITERATIONS) - upper bound on training iterations;
     * learningRate (LEARNING_RATE) - step size, adjusted by updateLearningRate;
     * regUser / regItem (REGULARIZATION_USER / REGULARIZATION_ITEM) - regularization weights;
     * convergence (CONVERGENCE) - threshold on the error change between two iterations.
     * */
    private int latentFactors = -1;
    private int Iterations = -1;
    private double learningRate = -1;
    private double regUser = -1;
    private double regItem = -1;
    private double convergence = -1;

    /**
     * Optimization method indicator (key OPTIMIZATION_METHOD). build() recognizes
     * "SGD" and "ALS"; the ALS branch is currently commented out.
     * */
    private String optimization = null;

    /**
     * Value of TOP_N_RECOMMENDATION. Within this class it is only exposed through
     * its getter and setter.
     * */
    private int topN = -1;

    /**
     * Rating bounds (MAX_RATING / MIN_RATING); predictions are clamped to this
     * range while training.
     * */
    private int maxRating = -1;
    private int minRating = -1;

    /**
     * Social regularization weight (key SOCIAL_REGULARIZATION) applied to the
     * social term of the user factor update.
     * */
    private double socialReg = -1;

    /**
     * Creates a social-regularization model for the given rating and social matrices.
     *
     * <p>All parameters are read from {@code .//conf//SocialReg.properties}. Both latent
     * matrices are allocated and initialized here. Unless {@code readModel} is set, the
     * entries of the social matrix are then replaced by rating-based user similarities
     * (see {@link #calculateSimilarityMatrix(String)}).
     *
     * <p>Note: the logger is still {@code null} while the constructor runs, because it can
     * only be supplied afterwards through {@link #setLogger(PrintWriter)}.
     *
     * @param ratings   user-item rating matrix used for training
     * @param social    users' social information matrix
     * @param readModel if {@code true}, {@link #readModel(String)} is called instead of
     *                  computing the similarities; that method is currently an empty stub
     * @param modelFile file handed to {@link #readModel(String)} when {@code readModel} is set
     * @throws IllegalStateException if the configuration file cannot be loaded
     */
    public SocialReg(RatingMatrix ratings, RatingMatrix social, boolean readModel, String modelFile) {
        // Set the configuration file for parameter setting.
        config.setFile(new File(".//conf//SocialReg.properties"));
        try {
            config.load();
        } catch (ConfigurationException e) {
            // Without a loaded configuration the lookups below cannot succeed, so stop
            // here and keep the real cause instead of printing it and failing later
            // with an unrelated "key not found" error.
            throw new IllegalStateException(
                    "Unable to load configuration file .//conf//SocialReg.properties", e);
        }

        this.ratingMatrix = ratings;
        this.socialMatrix = social;

        this.initialization = this.config.getInt("INITIALIZATION");
        this.latentFactors = this.config.getInt("LATENT_FACTORS");
        this.Iterations = this.config.getInt("ITERATIONS");
        this.learningRate = this.config.getDouble("LEARNING_RATE");
        this.regUser = this.config.getDouble("REGULARIZATION_USER");
        this.regItem = this.config.getDouble("REGULARIZATION_ITEM");
        this.convergence = this.config.getDouble("CONVERGENCE");
        this.optimization = this.config.getString("OPTIMIZATION_METHOD");

        // Latent matrices plus the snapshots used for learning-rate rollback.
        this.userMatrix = new LatentMatrix(this.ratingMatrix.getRow(), this.latentFactors);
        this.userMatrix.setInitialization(this.initialization);
        this.userMatrix.valueInitialization();
        this.userMatrixPrevious = this.userMatrix.clone();
        this.itemMatrix = new LatentMatrix(this.ratingMatrix.getColumn(), this.latentFactors);
        this.itemMatrix.setInitialization(this.initialization);
        this.itemMatrix.valueInitialization();
        this.itemMatrixPrevious = this.itemMatrix.clone();

        this.topN = this.config.getInt("TOP_N_RECOMMENDATION");
        this.socialReg = this.config.getDouble("SOCIAL_REGULARIZATION");
        this.maxRating = this.config.getInt("MAX_RATING");
        this.minRating = this.config.getInt("MIN_RATING");

        if (readModel) {
            this.readModel(modelFile);
        } else {
            // Replace the social matrix entries by user-user similarities.
            this.calculateSimilarityMatrix(this.config.getString("USER_SIMILARITY"));
        }
    }

    /**
     * Calculates the similarity between every user and each of their social neighbours
     * and stores it in place of the corresponding social matrix entry.
     *
     * <p>For a pair (user, friend) the ratings both gave to the same items are collected
     * and passed to {@code SimilarityCalculator} together with the configured
     * {@code SHRINKAGE}. A PCC result {@code s} is mapped to {@code (1 + s) / 2}. A
     * {@code NaN} result, or an unsupported method, yields similarity 0. The self entry
     * (user == friend) is always set to 1.
     *
     * <p>This method is invoked from the constructor, at which point no logger has been
     * set yet; the diagnostic message is therefore only written when a logger exists,
     * and it is written once rather than once per pair.
     *
     * @param method similarity calculation method, {@code "PCC"} or {@code "COSINE"}
     */
    public void calculateSimilarityMatrix(String method) {
        // Constant-first comparison: a missing USER_SIMILARITY key gives a null method.
        final boolean usePcc = "PCC".equals(method);
        final boolean useCosine = "COSINE".equals(method);
        if (!usePcc && !useCosine && this.logger != null) {
            this.logger.println("Similarity calculation method is not set correctly!");
        }
        // Loop invariant; only looked up when a supported method will actually use it.
        final int shrinkage = (usePcc || useCosine) ? this.config.getInt("SHRINKAGE") : 0;

        for (int i = 0; i < this.socialMatrix.getRatingMatrix().size(); i++) {
            for (Map.Entry<Integer, Double> entry : this.socialMatrix.getRatingMatrix().get(i).entrySet()) {
                if (i == entry.getKey()) {
                    // Write through the entry instead of calling put() on the map
                    // that is currently being iterated.
                    entry.setValue(1.0);
                    continue;
                }

                ArrayList<Double> commonRatings1 = new ArrayList<Double>();
                ArrayList<Double> commonRatings2 = new ArrayList<Double>();
                // Find the items rated by both users.
                for (Map.Entry<Integer, Double> element : this.ratingMatrix.getRatingMatrix().get(i)
                        .entrySet()) {
                    Double rating = this.ratingMatrix.getRatingMatrix().get(entry.getKey())
                            .get(element.getKey());
                    if (rating != null) {
                        commonRatings1.add(element.getValue());
                        commonRatings2.add(rating);
                    }
                }

                double similarity = Double.NaN;
                if (usePcc) {
                    similarity = SimilarityCalculator.getSimilarityPCC(commonRatings1, commonRatings2,
                            shrinkage);
                    // normalize for PCC
                    similarity = (1 + similarity) / 2;
                } else if (useCosine) {
                    similarity = SimilarityCalculator.getSimilarityCosine(commonRatings1, commonRatings2,
                            shrinkage);
                }
                if (Double.isNaN(similarity)) {
                    similarity = 0;
                }
                entry.setValue(similarity);
            }
        }
    }

    /**
     * Not implemented: the body is empty, so calling this method has no effect.
     * Presumably meant to persist the trained model - TODO confirm against IAlgorithm.
     *
     * @param file model file (currently ignored)
     */
    @Override
    public void saveModel(String file) {
        // TODO: not implemented yet.

    }

    /**
     * Not implemented: the body is empty, so calling this method has no effect.
     * The constructor calls it when its readModel flag is set and, in that case,
     * skips the similarity computation; the model then keeps its freshly
     * initialized latent matrices and the unmodified social matrix.
     *
     * @param file model file (currently ignored)
     */
    @Override
    public void readModel(String file) {
        // TODO: not implemented yet.

    }

    /**
     * Trains the model with the optimization method configured under
     * {@code OPTIMIZATION_METHOD}.
     *
     * <p>Only {@code "SGD"} performs any training. {@code "ALS"} is recognized but not
     * implemented, and any other (or missing) value is reported as a configuration
     * problem; in both cases the model is left untouched and a message is logged.
     * A logger must have been set with {@link #setLogger(PrintWriter)} beforehand.
     */
    @Override
    public void build() {
        // Constant-first comparisons so that a missing configuration value (null)
        // reaches the diagnostic branch instead of raising a NullPointerException.
        if ("SGD".equals(this.optimization)) {
            // Announce the optimizer before the (potentially long) training run.
            logger.println("SGD is used to train the model.");
            buildSGD();
        } else if ("ALS".equals(this.optimization)) {
            // ALS training is not available; say so rather than returning silently.
            logger.println("ALS is not implemented; the model was not trained.");
        } else {
            logger.println("Optimization method is not set properly.");
        }
    }

    /**
     * Learns the matrix factorization model using Stochastic Gradient Descent.
     *
     * <p>Each iteration visits every valid rating once in random order and updates the
     * user and item factors of that rating. The user update contains a social term whose
     * form depends on the {@code MODE} configuration value:
     * <ul>
     * <li>{@code AVERAGE}: pulls the user factor towards the similarity-weighted average
     * of the friends' factors (friends with similarity 0 are ignored);</li>
     * <li>{@code INDIVIDUAL}: sums the similarity-weighted differences to every friend.</li>
     * </ul>
     * After each iteration the summed absolute error over all valid ratings (with
     * predictions clamped to the rating range) is logged, convergence is checked and the
     * learning rate is adjusted through {@link #updateLearningRate(double, double)}.
     *
     * <p>If {@code MODE} is missing or unsupported, a message is logged and no training
     * takes place, so the latent factors are never overwritten with placeholder values.
     *
     * <p>NOTE(review): the item update uses {@code 2 * difference} while the user update
     * uses {@code difference}; this asymmetry is kept as found - confirm it is intended.
     */
    public void buildSGD() {
        // MODE is loop invariant: resolve and validate it once up front.
        final String mode = this.config.getString("MODE");
        final boolean averageMode = "AVERAGE".equals(mode);
        final boolean individualMode = "INDIVIDUAL".equals(mode);
        if (!averageMode && !individualMode) {
            logger.println("MODE is not set correctly.");
            logger.flush();
            return;
        }
        final int learningRateUpdate = this.config.getInt("LEARNING_RATE_UPDATE");
        final Random random = new Random();

        double preError = Double.MAX_VALUE;
        for (int i = 0; i < this.Iterations; i++) {
            System.out.println("Iteration: " + i);
            ArrayList<MatrixEntry2D> entries = this.ratingMatrix.getValidEntries();
            System.out.println("Ratings: " + entries.size());
            while (!entries.isEmpty()) {
                // Draw a random remaining entry and remove it in O(1) by moving the
                // last element into its slot (the order of the remainder is irrelevant).
                final int r = random.nextInt(entries.size());
                final MatrixEntry2D entry = entries.get(r);
                final int last = entries.size() - 1;
                entries.set(r, entries.get(last));
                entries.remove(last);

                final int user = entry.getRowIndex();
                final int item = entry.getColumnIndex();
                final double difference = entry.getValue() - clampToRatingRange(predict(user, item));
                final ArrayList<Integer> friendIDs = this.socialMatrix.getRatedItemIndex(user);

                for (int l = 0; l < this.latentFactors; l++) {
                    final double userFactor = this.userMatrix.get(user, l);
                    final double itemFactor = this.itemMatrix.get(item, l);

                    // Social part of the user gradient; a modified version of the original paper.
                    double socialTerm;
                    if (averageMode) {
                        double simSum = 0;
                        double friendSum = 0;
                        for (int j = 0; j < friendIDs.size(); j++) {
                            final Integer friend = friendIDs.get(j);
                            final double sim = this.socialMatrix.getRatingMatrix().get(user).get(friend);
                            if (sim == 0) {
                                continue;
                            }
                            simSum = simSum + sim;
                            friendSum = friendSum + sim * this.userMatrix.get(friend, l);
                        }
                        if (simSum != 0) {
                            friendSum = friendSum / simSum;
                        } else {
                            friendSum = 0;
                        }
                        socialTerm = userFactor - friendSum;
                    } else {
                        socialTerm = 0;
                        for (int j = 0; j < friendIDs.size(); j++) {
                            final Integer friend = friendIDs.get(j);
                            socialTerm += this.socialMatrix.getRatingMatrix().get(user).get(friend)
                                    * (userFactor - this.userMatrix.get(friend, l));
                        }
                    }

                    // Both new values are computed from the old factors before either is stored.
                    final double tempU = userFactor
                            + this.learningRate * (difference * itemFactor
                                    - this.regUser * userFactor
                                    - this.socialReg * socialTerm);
                    final double tempI = itemFactor
                            + this.learningRate * (2 * difference * userFactor
                                    - this.regItem * itemFactor);
                    this.userMatrix.set(user, l, tempU);
                    this.itemMatrix.set(item, l, tempI);
                }
            }

            // Overall error of this iteration: summed absolute error on all valid ratings.
            double error = 0;
            entries = this.ratingMatrix.getValidEntries();
            for (int k = 0; k < entries.size(); k++) {
                final MatrixEntry2D entry = entries.get(k);
                final double prediction = clampToRatingRange(
                        predict(entry.getRowIndex(), entry.getColumnIndex()));
                error = error + Math.abs(entry.getValue() - prediction);
            }
            this.logger.println(new SimpleDateFormat("yyyy-MM-dd HH:mm:ss").format(new Date()) + " Iteration " + i
                    + " : Error ~ " + error);
            this.logger.flush();

            // Check for convergence.
            if (Math.abs(error - preError) <= this.convergence && error <= preError) {
                logger.println("The algorithm convergences.");
                this.logger.flush();
                break;
            }

            // With the bold-driver strategy (2) a worse error makes updateLearningRate
            // restore the previous matrices. The restored model still has the previous
            // error, so preError must not be replaced by the rejected iteration's error.
            final boolean rolledBack = learningRateUpdate == 2 && Math.abs(error) > Math.abs(preError);

            // Learning rate update strategy.
            updateLearningRate(error, preError);

            if (!rolledBack) {
                preError = error;
            }
            logger.flush();
        }
    }

    /**
     * Limits a raw prediction to the configured rating range: values above maxRating
     * become maxRating, then values below minRating become minRating.
     */
    private double clampToRatingRange(double prediction) {
        double clamped = prediction;
        if (clamped > this.maxRating) {
            clamped = this.maxRating;
        }
        if (clamped < this.minRating) {
            clamped = this.minRating;
        }
        return clamped;
    }

    /**
     * Adjusts the learning rate after an iteration, according to the strategy stored
     * under the configuration key {@code LEARNING_RATE_UPDATE}:
     * <ul>
     * <li>1 - keep the learning rate unchanged;</li>
     * <li>2 - bold driver: on a smaller error raise the rate by 5% and snapshot both
     * latent matrices; on a larger error halve the rate and restore the snapshots;</li>
     * <li>3 - multiply the rate by {@code LEARNING_RATING_DECAY} and snapshot both
     * latent matrices.</li>
     * </ul>
     * Any other value leaves everything untouched.
     *
     * @param error    error in the current iteration
     * @param preError error in the previous iteration
     */
    public void updateLearningRate(double error, double preError) {
        final int strategy = this.config.getInt("LEARNING_RATE_UPDATE");
        switch (strategy) {
        case 2: {
            // bold driver learning rate update algorithm
            final double currentMagnitude = Math.abs(error);
            final double previousMagnitude = Math.abs(preError);
            if (currentMagnitude < previousMagnitude) {
                this.learningRate = (1 + 0.05) * this.learningRate;
                logger.println("Increase learning rate by 5%.");
                this.userMatrixPrevious = this.userMatrix.clone();
                this.itemMatrixPrevious = this.itemMatrix.clone();
            } else if (currentMagnitude > previousMagnitude) {
                this.learningRate = 0.5 * this.learningRate;
                // roll back to the state of the previous iteration
                this.userMatrix = this.userMatrixPrevious.clone();
                this.itemMatrix = this.itemMatrixPrevious.clone();
                logger.println("Decrease learning rate by 50%.");
            }
            break;
        }
        case 3: {
            // decaying learning rate by a constant rate
            final double decay = this.config.getDouble("LEARNING_RATING_DECAY");
            this.learningRate = this.learningRate * decay;
            this.userMatrixPrevious = this.userMatrix.clone();
            this.itemMatrixPrevious = this.itemMatrix.clone();
            break;
        }
        default:
            // strategy 1 (and any unknown value): the learning rate stays as it is
            break;
        }
    }

    /**
     * Predicts the rating of a user for an item as the dot product of the user's and
     * the item's latent factor rows. The result is not clamped to the rating range.
     *
     * @param user index of the user
     * @param item index of the item
     * @return the raw predicted rating
     */
    public double predict(int user, int item) {
        double dotProduct = 0;
        for (int factor = 0; factor < this.latentFactors; factor++) {
            dotProduct += this.userMatrix.getValues()[user][factor]
                    * this.itemMatrix.getValues()[item][factor];
        }
        return dotProduct;
    }

    /**
     * Returns the writer that receives progress and diagnostic messages.
     *
     * @return the current logger, or {@code null} if none has been set
     */
    public PrintWriter getLogger() {
        return this.logger;
    }

    /**
     * Sets the writer that receives progress and diagnostic messages. The logger is
     * null after construction, and build(), buildSGD() and updateLearningRate()
     * write to it without a null check, so set it before training.
     *
     * @param logger the logger to set
     */
    public void setLogger(PrintWriter logger) {
        this.logger = logger;
    }

    /**
     * Returns the top-N setting, initially read from {@code TOP_N_RECOMMENDATION}.
     *
     * @return the topN
     */
    public int getTopN() {
        return this.topN;
    }

    /**
     * Overrides the top-N setting that the constructor read from
     * TOP_N_RECOMMENDATION.
     *
     * @param topN the topN to set
     */
    public void setTopN(int topN) {
        this.topN = topN;
    }

    /**
     * Returns the upper rating bound to which predictions are clamped during training.
     *
     * @return the maxRating
     */
    public int getMaxRating() {
        return this.maxRating;
    }

    /**
     * Overrides the upper rating bound that the constructor read from MAX_RATING.
     * Predictions above this value are clamped to it during training.
     *
     * @param maxRating the maxRating to set
     */
    public void setMaxRating(int maxRating) {
        this.maxRating = maxRating;
    }

    /**
     * Returns the lower rating bound to which predictions are clamped during training.
     *
     * @return the minRating
     */
    public int getMinRating() {
        return this.minRating;
    }

    /**
     * Overrides the lower rating bound that the constructor read from MIN_RATING.
     * Predictions below this value are clamped to it during training.
     *
     * @param minRating the minRating to set
     */
    public void setMinRating(int minRating) {
        this.minRating = minRating;
    }

}