ch.epfl.lsir.xin.algorithm.core.UserBasedCF.java Source code

Java tutorial

Introduction

Here is the source code for ch.epfl.lsir.xin.algorithm.core.UserBasedCF.java

Source

//Copyright (C) 2014  Xin Liu
//
//RecMe: a lightweight recommendation algorithm library
//
//RecMe is free software; you can redistribute it and/or
//modify it under the terms of the GNU General Public License
//as published by the Free Software Foundation; either version 2
//of the License, or (at your option) any later version.
//
//This program is distributed in the hope that it will be useful,
//but WITHOUT ANY WARRANTY; without even the implied warranty of
//MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
//GNU General Public License for more details.
//
//You should have received a copy of the GNU General Public License
//along with this program; if not, write to the Free Software
//Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

/**
 * This class implements the user-based collaborative filtering (CF) algorithm.
 * 
 * @author Xin Liu
 * */

package ch.epfl.lsir.xin.algorithm.core;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Map;
import java.util.StringTokenizer;

import org.apache.commons.configuration.ConfigurationException;
import org.apache.commons.configuration.PropertiesConfiguration;

import ch.epfl.lsir.xin.algorithm.IAlgorithm;
import ch.epfl.lsir.xin.datatype.RatingMatrix;
import ch.epfl.lsir.xin.evaluation.ResultUnit;
import ch.epfl.lsir.xin.model.SUser;
import ch.epfl.lsir.xin.util.SimilarityCalculator;

/**
 * User-based collaborative filtering recommender.
 *
 * <p>{@link #build()} fills a symmetric user-user similarity matrix using the measure named by the
 * {@code SIMILARITY} property ({@code "pcc"} or {@code "cosine"}). {@link #predict(int, int, boolean)}
 * then combines the ratings of the most similar users (at most {@code NEIGHBOUR_SIZE} of them) who
 * rated the target item. It produces either a mean-centred weighted average (rating prediction) or a
 * summed-similarity score (ranking).
 *
 * <p>Parameters are read from {@code .//conf//UserBasedCF.properties}. No synchronization is
 * performed on the mutable state of an instance.
 */
public class UserBasedCF implements IAlgorithm {

    /** Location of the properties file holding the algorithm parameters. */
    private static final String CONFIGURATION_FILE = ".//conf//UserBasedCF.properties";

    /** The training rating matrix: one rating map per user, keyed by item index. */
    private RatingMatrix ratingMatrix = null;

    /** Symmetric user-user similarity matrix, indexed {@code [user][user]}. */
    private double[][] similarityMatrix = null;

    /** Optional sink for diagnostic messages; when {@code null}, messages go to standard error. */
    private PrintWriter logger = null;

    /** Configuration file for parameter setting. */
    public PropertiesConfiguration config = new PropertiesConfiguration();

    /** Length of the lists produced by {@link #getRecommendationList(int)}; {@code <= 0} means no limit. */
    private int topN = -1;

    /** Upper bound of the rating scale (property {@code MAX_RATING}). */
    private int maxRating = -1;

    /** Lower bound of the rating scale (property {@code MIN_RATING}). */
    private int minRating = -1;

    /** Similarity measure used by {@link #similarityMatrixCalculation()}: {@code "pcc"} or {@code "cosine"}. */
    private String similarityCalculation = null;

    /**
     * Creates the recommender and loads its parameters from the configuration file.
     *
     * @param ratingMatrix training ratings
     * @throws IllegalStateException if the configuration file cannot be loaded
     */
    public UserBasedCF(RatingMatrix ratingMatrix) {
        config.setFile(new File(CONFIGURATION_FILE));
        try {
            config.load();
        } catch (ConfigurationException e) {
            // Without the file, the properties read below are absent; fail here with the real cause.
            throw new IllegalStateException("Cannot load configuration file " + CONFIGURATION_FILE, e);
        }
        this.topN = this.config.getInt("TOP_N_RECOMMENDATION");
        this.similarityCalculation = this.config.getString("SIMILARITY");
        this.ratingMatrix = ratingMatrix;
        this.maxRating = this.config.getInt("MAX_RATING");
        this.minRating = this.config.getInt("MIN_RATING");
        int users = this.ratingMatrix.getRow();
        this.similarityMatrix = new double[users][users];
    }

    /**
     * Creates the recommender and optionally restores a similarity matrix saved by
     * {@link #saveModel(String)}.
     *
     * @param ratingMatrix training ratings
     * @param readModel whether to read a saved model
     * @param file file of the saved model (only used when {@code readModel} is true)
     * @throws IllegalStateException if the configuration file cannot be loaded
     */
    public UserBasedCF(RatingMatrix ratingMatrix, boolean readModel, String file) {
        this(ratingMatrix);
        if (readModel) {
            readModel(file);
        }
    }

    /**
     * Computes the similarity of every pair of users over their co-rated items and stores it
     * symmetrically in the similarity matrix. The diagonal is set to 1, and NaN similarities (e.g.
     * users with no items in common) are stored as 0.
     *
     * <p>If the configured similarity measure is neither {@code "pcc"} nor {@code "cosine"}, a message
     * is logged and the matrix is left untouched.
     */
    public void similarityMatrixCalculation() {
        boolean usePcc = "pcc".equals(this.similarityCalculation);
        boolean useCosine = "cosine".equals(this.similarityCalculation);
        if (!usePcc && !useCosine) {
            writeLog("Cannot determine which similarity calculation method is used for. \n");
            return;
        }
        int shrinkage = this.config.getInt("SHRINKAGE");
        int users = this.ratingMatrix.getRow();
        for (int i = 0; i < users; i++) {
            // The similarity of a user with herself is 1.
            this.similarityMatrix[i][i] = 1;
            Map<Integer, Double> ratingsI = this.ratingMatrix.getRatingMatrix().get(i);
            for (int j = i + 1; j < users; j++) {
                Map<Integer, Double> ratingsJ = this.ratingMatrix.getRatingMatrix().get(j);
                // Collect the two users' ratings on their common items, in matching order.
                ArrayList<Double> commonRatings1 = new ArrayList<Double>();
                ArrayList<Double> commonRatings2 = new ArrayList<Double>();
                for (Map.Entry<Integer, Double> entry : ratingsI.entrySet()) {
                    Double rating1 = entry.getValue();
                    Double rating2 = ratingsJ.get(entry.getKey());
                    if (rating1 != null && rating2 != null) {
                        commonRatings1.add(rating1);
                        commonRatings2.add(rating2);
                    }
                }
                double similarity = usePcc
                        ? SimilarityCalculator.getSimilarityPCC(commonRatings1, commonRatings2, shrinkage)
                        : SimilarityCalculator.getSimilarityCosine(commonRatings1, commonRatings2, shrinkage);
                if (Double.isNaN(similarity)) {
                    similarity = 0;
                }
                this.similarityMatrix[i][j] = similarity;
                this.similarityMatrix[j][i] = similarity;
            }
        }
    }

    /**
     * Predicts a user's rating of (or ranking score for) an item from the ratings of the most
     * similar other users who rated that item. Only neighbours with strictly positive similarity are
     * considered, and at most {@code NEIGHBOUR_SIZE} of them are used (all of them when the property
     * is {@code <= 0}).
     *
     * @param userIndex index of the user
     * @param itemIndex index of the item
     * @param rank if true, return a ranking score instead of a rating: the sum of the similarities of
     *     the selected neighbours, or 0 when no neighbour rated the item
     * @return the predicted rating or the ranking score. A rating prediction falls back to the user's
     *     mean rating, then to the global average, when no usable neighbour exists.
     */
    public double predict(int userIndex, int itemIndex, boolean rank) {
        int neighbors = this.config.getInt("NEIGHBOUR_SIZE");
        ArrayList<SUser> similarUsers = new ArrayList<SUser>();
        for (int i = 0; i < this.ratingMatrix.getRow(); i++) {
            if (i == userIndex) {
                // The target user is never her own neighbour (consistent with getPredictionRanking).
                continue;
            }
            Double value = this.ratingMatrix.getRatingMatrix().get(i).get(itemIndex);
            if (value != null && this.similarityMatrix[userIndex][i] > 0) {
                similarUsers.add(new SUser(i, value.doubleValue(), this.similarityMatrix[userIndex][i]));
            }
        }

        if (similarUsers.isEmpty()) {
            // No collaborative evidence. A ranking score must stay on the similarity scale, so it is 0
            // rather than a rating.
            return rank ? 0 : fallbackRating(userIndex);
        }

        // NOTE(review): walking the sorted list from the end assumes SUser orders by ascending
        // similarity, i.e. the most similar neighbours come last; confirm against SUser.compareTo.
        Collections.sort(similarUsers);
        double totalSimilarity = 0;
        double weightedDeviation = 0;
        int used = 0;
        for (int k = similarUsers.size() - 1; k >= 0; k--) {
            SUser neighbour = similarUsers.get(k);
            double similarity = neighbour.getSimilarity();
            totalSimilarity += Math.abs(similarity);
            // Neighbours vote with their deviation from their own mean rating.
            weightedDeviation += similarity
                    * (neighbour.getRating() - this.ratingMatrix.getUsersMean().get(neighbour.getUserIndex()));
            used++;
            if (used == neighbors) {
                break;
            }
        }

        if (Double.isNaN(totalSimilarity) || totalSimilarity == 0) {
            return rank ? 0 : fallbackRating(userIndex);
        }
        if (rank) {
            return totalSimilarity;
        }
        return weightedDeviation / totalSimilarity + this.ratingMatrix.getUsersMean().get(userIndex);
    }

    /**
     * Returns the user's mean rating, or the global average rating when the user's mean is NaN.
     * Used when collaborative filtering cannot produce a prediction.
     */
    private double fallbackRating(int userIndex) {
        double userMean = this.ratingMatrix.getUsersMean().get(userIndex);
        if (!Double.isNaN(userMean)) {
            return userMean;
        }
        return this.ratingMatrix.getAverageRating();
    }

    /**
     * Saves the similarity matrix as tab-separated text, one user per line. Errors are logged
     * rather than thrown.
     *
     * @param file destination file
     */
    @Override
    public void saveModel(String file) {
        PrintWriter printer = null;
        try {
            printer = new PrintWriter(file);
            for (int i = 0; i < this.similarityMatrix.length; i++) {
                for (int j = 0; j < this.similarityMatrix[i].length; j++) {
                    printer.print(this.similarityMatrix[i][j] + "\t");
                }
                printer.println();
            }
            printer.flush();
            // PrintWriter swallows I/O errors; surface them instead of silently writing a broken model.
            if (printer.checkError()) {
                writeLog("Error while writing the similarity model to " + file + "\n");
            }
        } catch (IOException e) {
            writeError("Cannot write the similarity model to " + file + "\n", e);
        } finally {
            if (printer != null) {
                printer.close();
            }
        }
    }

    /**
     * Reads a similarity matrix written by {@link #saveModel(String)} into the existing matrix.
     * I/O errors are logged rather than thrown.
     *
     * @param file file of the saved model
     * @throws IllegalArgumentException if the file holds more rows or columns than there are users
     */
    @Override
    public void readModel(String file) {
        BufferedReader reader = null;
        try {
            reader = new BufferedReader(new FileReader(file));
            String line;
            int row = 0;
            while ((line = reader.readLine()) != null) {
                if (row >= this.similarityMatrix.length) {
                    throw new IllegalArgumentException("Model file " + file + " has more than "
                            + this.similarityMatrix.length + " rows");
                }
                StringTokenizer tokens = new StringTokenizer(line.trim());
                int column = 0;
                while (tokens.hasMoreTokens()) {
                    if (column >= this.similarityMatrix[row].length) {
                        throw new IllegalArgumentException("Model file " + file + " has more than "
                                + this.similarityMatrix[row].length + " columns in row " + row);
                    }
                    this.similarityMatrix[row][column] = Double.parseDouble(tokens.nextToken());
                    column++;
                }
                row++;
            }
        } catch (IOException e) {
            writeError("Cannot read the similarity model from " + file + "\n", e);
        } finally {
            if (reader != null) {
                try {
                    reader.close();
                } catch (IOException e) {
                    writeError("Cannot close the similarity model file " + file + "\n", e);
                }
            }
        }
    }

    /** Builds the model, i.e. computes the user-user similarity matrix. */
    @Override
    public void build() {
        similarityMatrixCalculation();
    }

    /**
     * Generates a recommendation list for a user. Every item the user has not rated is scored with
     * {@link #predict(int, int, boolean)} in ranking mode, and the {@code topN} highest-scoring items
     * are returned (all of them when {@code topN <= 0}).
     *
     * @param userIndex index of the user
     * @return the recommended items, best first (assuming ResultUnit sorts by ascending prediction;
     *     TODO confirm against ResultUnit.compareTo)
     */
    public ArrayList<ResultUnit> getRecommendationList(int userIndex) {
        Map<Integer, Double> userRatings = this.ratingMatrix.getRatingMatrix().get(userIndex);
        ArrayList<ResultUnit> candidates = new ArrayList<ResultUnit>();
        for (int item = 0; item < this.ratingMatrix.getColumn(); item++) {
            // Only items the user has not rated yet are candidates.
            if (userRatings.get(item) == null) {
                candidates.add(new ResultUnit(userIndex, item, predict(userIndex, item, true)));
            }
        }

        Collections.sort(candidates);
        ArrayList<ResultUnit> result = new ArrayList<ResultUnit>();
        for (int i = candidates.size() - 1; i >= 0; i--) {
            result.add(candidates.get(i));
            if (result.size() == this.topN) {
                break;
            }
        }
        return result;
    }

    /**
     * Alternative ranking score: the fraction of the neighbourhood size ({@code NEIGHBOUR_SIZE})
     * filled by other users who rated the item. Returns 0 when nobody else rated the item, or when
     * the largest such similarity (per SUser's ordering) is 0.
     *
     * @param userIndex index of the user
     * @param itemIndex index of the item
     * @return the ranking score
     */
    public double getPredictionRanking(int userIndex, int itemIndex) {
        ArrayList<SUser> similarUsers = new ArrayList<SUser>();
        for (int i = 0; i < this.ratingMatrix.getRow(); i++) {
            if (i == userIndex) {
                continue;
            }
            Double rating = this.ratingMatrix.getRatingMatrix().get(i).get(itemIndex);
            if (rating == null) {
                continue;
            }
            similarUsers.add(new SUser(i, rating.doubleValue(), this.similarityMatrix[userIndex][i]));
        }
        Collections.sort(similarUsers);
        if (similarUsers.isEmpty() || similarUsers.get(similarUsers.size() - 1).getSimilarity() == 0) {
            return 0;
        }
        int neighbors = this.config.getInt("NEIGHBOUR_SIZE");
        // Every collected user rated the item, so the count is simply how many of them fit into the
        // neighbourhood (all of them when NEIGHBOUR_SIZE <= 0).
        int count = neighbors > 0 ? Math.min(neighbors, similarUsers.size()) : similarUsers.size();
        // NOTE(review): NEIGHBOUR_SIZE <= 0 yields a non-finite or negative score here.
        return (double) count / neighbors;
    }

    /** Appends a diagnostic message to the logger, or to standard error when no logger is set. */
    private void writeLog(String message) {
        if (this.logger != null) {
            this.logger.append(message);
        } else {
            System.err.print(message);
        }
    }

    /** Logs a message followed by the stack trace of the exception that caused it. */
    private void writeError(String message, Exception cause) {
        writeLog(message);
        if (this.logger != null) {
            cause.printStackTrace(this.logger);
        } else {
            cause.printStackTrace();
        }
    }

    /**
     * @return the ratingMatrix
     */
    public RatingMatrix getRatingMatrix() {
        return ratingMatrix;
    }

    /**
     * Replaces the rating matrix. The similarity matrix is not reallocated; callers changing the
     * number of users must also supply a matching matrix via {@link #setSimilarityMatrix(double[][])}.
     *
     * @param ratingMatrix the ratingMatrix to set
     */
    public void setRatingMatrix(RatingMatrix ratingMatrix) {
        this.ratingMatrix = ratingMatrix;
    }

    /**
     * @return the similarityMatrix (the internal array, not a copy)
     */
    public double[][] getSimilarityMatrix() {
        return similarityMatrix;
    }

    /**
     * @param similarityMatrix the similarityMatrix to set
     */
    public void setSimilarityMatrix(double[][] similarityMatrix) {
        this.similarityMatrix = similarityMatrix;
    }

    /**
     * @return the topN
     */
    public int getTopN() {
        return topN;
    }

    /**
     * @param topN the topN to set
     */
    public void setTopN(int topN) {
        this.topN = topN;
    }

    /**
     * @return the logger
     */
    public PrintWriter getLogger() {
        return logger;
    }

    /**
     * @param logger the logger to set
     */
    public void setLogger(PrintWriter logger) {
        this.logger = logger;
    }

    /**
     * @return the similarityCalculation
     */
    public String getSimilarityCalculation() {
        return similarityCalculation;
    }

    /**
     * @param similarityCalculation the similarityCalculation to set
     */
    public void setSimilarityCalculation(String similarityCalculation) {
        this.similarityCalculation = similarityCalculation;
    }

    /**
     * @return the maxRating
     */
    public int getMaxRating() {
        return maxRating;
    }

    /**
     * @param maxRating the maxRating to set
     */
    public void setMaxRating(int maxRating) {
        this.maxRating = maxRating;
    }

    /**
     * @return the minRating
     */
    public int getMinRating() {
        return minRating;
    }

    /**
     * @param minRating the minRating to set
     */
    public void setMinRating(int minRating) {
        this.minRating = minRating;
    }

}