ch.epfl.lsir.xin.algorithm.core.ItemBasedCF.java Source code

Java tutorial

Introduction

Here is the source code for ch.epfl.lsir.xin.algorithm.core.ItemBasedCF.java

Source

//Copyright (C) 2014  Xin Liu
//
//RecMe: a lightweight recommendation algorithm library
//
//RecMe is free software; you can redistribute it and/or
//modify it under the terms of the GNU General Public License
//as published by the Free Software Foundation; either version 2
//of the License, or (at your option) any later version.
//
//This program is distributed in the hope that it will be useful,
//but WITHOUT ANY WARRANTY; without even the implied warranty of
//MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
//GNU General Public License for more details.
//
//You should have received a copy of the GNU General Public License
//along with this program; if not, write to the Free Software
//Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

/**
 * This class implements the item based collaborative filtering algorithm
 * 
 * @author Xin Liu
 * */

package ch.epfl.lsir.xin.algorithm.core;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.StringTokenizer;

import org.apache.commons.configuration.ConfigurationException;
import org.apache.commons.configuration.PropertiesConfiguration;

import ch.epfl.lsir.xin.algorithm.IAlgorithm;
import ch.epfl.lsir.xin.datatype.RatingMatrix;
import ch.epfl.lsir.xin.evaluation.ResultUnit;
import ch.epfl.lsir.xin.model.SItem;
import ch.epfl.lsir.xin.util.SimilarityCalculator;

public class ItemBasedCF implements IAlgorithm {

    /**
     * the rating matrix
     * */
    private RatingMatrix ratingMatrix = null;

    /**
     * user similarity matrix
     * */
    private double[][] similarityMatrix = null;

    /**
     * logger of the system
     * */
    private PrintWriter logger = null;

    /**
     * Configuration file for parameter setting.
     * */
    public PropertiesConfiguration config = new PropertiesConfiguration();

    /**
     * Top N recommendation
     * */
    private int topN = -1;

    private int maxRating = -1;
    private int minRating = -1;

    /**
     * similarity calculation method
     * */
    private String similarityCalculation = null;

    /**
     * constructor
     * @param training ratings
     * */
    public ItemBasedCF(RatingMatrix ratingMatrix) {
        //set configuration file for parameter setting.
        config.setFile(new File(".//conf//ItemBasedCF.properties"));
        try {
            config.load();
        } catch (ConfigurationException e) {
            // TODO Auto-generated catch block
            e.printStackTrace();
        }
        this.topN = this.config.getInt("TOP_N_RECOMMENDATION");
        this.similarityCalculation = this.config.getString("SIMILARITY");
        this.ratingMatrix = ratingMatrix;
        this.maxRating = this.config.getInt("MAX_RATING");
        this.minRating = this.config.getInt("MIN_RATING");
        this.similarityMatrix = new double[this.ratingMatrix.getColumn()][this.ratingMatrix.getColumn()];
        similarityMatrixCalculation();

        //display similarity matrix
        //      try {
        //         PrintWriter printer = new PrintWriter("matrix");
        //         for( int i = 0 ; i < this.similarityMatrix.length ; i++ )
        //         {
        //            for( int j = 0 ; j < this.similarityMatrix[i].length ; j++ )
        //            {
        //               printer.print(this.similarityMatrix[i][j] + "  ");
        //            }
        //            printer.println();
        //         }
        //         printer.flush();
        //         printer.close();
        //      } catch (FileNotFoundException e) {
        //         // TODO Auto-generated catch block
        //         e.printStackTrace();
        //      }
    }

    /**
     * constructor
     * @param: training ratings
     * @param: read a saved model or not
     * @param: file of a saved model 
     * */
    public ItemBasedCF(RatingMatrix ratingMatrix, boolean readModel, String file) {
        config.setFile(new File(".//conf//ItemBasedCF.properties"));
        try {
            config.load();
        } catch (ConfigurationException e) {
            // TODO Auto-generated catch block
            e.printStackTrace();
        }
        this.topN = this.config.getInt("TOP_N_RECOMMENDATION");
        this.similarityCalculation = this.config.getString("SIMILARITY");
        this.ratingMatrix = ratingMatrix;
        this.maxRating = this.config.getInt("MAX_RATING");
        this.minRating = this.config.getInt("MIN_RATING");
        this.similarityMatrix = new double[this.ratingMatrix.getColumn()][this.ratingMatrix.getColumn()];

        if (readModel) {
            readModel(file);
        } else {
            similarityMatrixCalculation();
        }
    }

    /**
     * This function calculates the similarity matrix for items
     * */
    public void similarityMatrixCalculation() {
        for (int i = 0; i < this.ratingMatrix.getColumn(); i++) {
            for (int j = i; j < this.ratingMatrix.getColumn(); j++) {
                if (i == j) //the similarity with herself is 1 
                {
                    this.similarityMatrix[i][j] = 1;
                } else {
                    ArrayList<Double> commonRatings1 = new ArrayList<Double>();
                    ArrayList<Double> commonRatings2 = new ArrayList<Double>();
                    //find common ratings for the two items
                    for (int i1 = 0; i1 < this.ratingMatrix.getRow(); i1++) {
                        if (this.ratingMatrix.getRatingMatrix().get(i1).get(i) != null
                                && this.ratingMatrix.getRatingMatrix().get(i1).get(j) != null) {
                            commonRatings1.add(this.ratingMatrix.getRatingMatrix().get(i1).get(i));
                            commonRatings2.add(this.ratingMatrix.getRatingMatrix().get(i1).get(j));
                        }
                    }
                    double similarity = Double.NaN;
                    if (this.similarityCalculation.equals("pcc")) {
                        similarity = SimilarityCalculator.getSimilarityPCC(commonRatings1, commonRatings2,
                                this.config.getInt("SHRINKAGE"));
                    } else if (this.similarityCalculation.equals("cosine")) {
                        similarity = SimilarityCalculator.getSimilarityCosine(commonRatings1, commonRatings2,
                                this.config.getInt("SHRINKAGE"));
                    } else {
                        logger.append("Cannot determine which similarity calculation method is used for. \n");
                        return;
                    }

                    if (Double.isNaN(similarity)) {
                        similarity = 0;
                    }
                    this.similarityMatrix[i][j] = similarity;
                    this.similarityMatrix[j][i] = similarity;
                }
            }
        }
    }

    @Override
    public void saveModel(String file) {
        // TODO Auto-generated method stub
        //save the similarity matrix
        try {
            PrintWriter printer = new PrintWriter(file);
            for (int i = 0; i < this.similarityMatrix.length; i++) {
                for (int j = 0; j < this.similarityMatrix[0].length; j++) {
                    printer.print(this.similarityMatrix[i][j] + "\t");
                }
                printer.println();
            }
            printer.flush();
            printer.close();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    @Override
    public void readModel(String file) {
        // TODO Auto-generated method stub
        try {
            BufferedReader reader = new BufferedReader(new FileReader(file));
            String line = null;
            int u1 = 0;
            while ((line = reader.readLine()) != null) {
                StringTokenizer tokens = new StringTokenizer(line.trim());
                int u2 = 0;
                while (tokens.hasMoreElements()) {
                    this.similarityMatrix[u1][u2] = Double.parseDouble(tokens.nextToken());
                    u2++;
                }
                u1++;
            }
            reader.close();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    @Override
    public void build() {
        // TODO Auto-generated method stub
        this.similarityMatrixCalculation();
    }

    /**
     * This function predicts a user's rating to an item
     * @param: index of the user
     * @param: index of the item
     * @param: if the prediction is for ranking
     * @return: the predicted rating
     * */
    public double predict(int userIndex, int itemIndex, boolean rank) {
        ArrayList<SItem> similarItems = new ArrayList<SItem>();
        int neighbors = this.config.getInt("NEIGHBOUR_SIZE");
        //find the similar items
        for (int i = 0; i < this.ratingMatrix.getColumn(); i++) {
            Double value = this.ratingMatrix.getRatingMatrix().get(userIndex).get(i);
            //this item is also rated by the target user
            if (value != null && this.similarityMatrix[itemIndex][i] > 0) {
                //get the similarity information
                SItem si = new SItem(i, value.doubleValue(), this.similarityMatrix[itemIndex][i]);
                similarItems.add(si);
            }
        }

        //collaborative filtering cannot work: first user-average and then global-average
        if (similarItems.size() < 1 /*|| similarItems.get(similarItems.size()-1).getSimilarity() == 0*/ ) {
            //         System.out.println("Cannot be predicted by UserCF");
            double itemM = this.ratingMatrix.getItemMean(itemIndex);
            if (!Double.isNaN(itemM)) {
                return itemM;
            } else {
                return this.ratingMatrix.getAverageRating();
            }
        }
        Collections.sort(similarItems);

        double totalSimilarity = 0;
        double prediction = 0;
        int counter = 0;
        double simRanking = 0;
        for (int i = similarItems.size() - 1; i >= 0; i--) {
            totalSimilarity = totalSimilarity + Math.abs(similarItems.get(i).getSimilarity());
            simRanking += similarItems.get(i).getSimilarity();
            prediction = prediction + similarItems.get(i).getSimilarity() * (similarItems.get(i).getRating()
                    - this.ratingMatrix.getItemsMean().get(similarItems.get(i).getItemIndex()));
            //         if(  Double.isNaN(totalSimilarity) || totalSimilarity == 0  )
            //         {
            //            System.out.println(similarItems.get(i).getSimilarity() + " ,  "
            //                  + similarItems.get(i).getRating() + " pred: " + prediction);
            //         }
            counter++;
            if (counter == neighbors)
                break;
        }
        if (Double.isNaN(totalSimilarity) || totalSimilarity == 0) {
            double itemM = this.ratingMatrix.getItemMean(itemIndex);
            if (!Double.isNaN(itemM)) {
                return itemM;
            } else {
                return this.ratingMatrix.getAverageRating();
            }
        }
        if (rank) {
            return totalSimilarity;
            //         return simRanking;
        } else {
            prediction = prediction / totalSimilarity;
            prediction = prediction + this.ratingMatrix.getItemMean(itemIndex);
        }

        return prediction;
    }

    /**
     * This function generates a recommendation list for a given user
     * @param: user
     * */
    public ArrayList<ResultUnit> getRecommendationList(int userIndex) {
        ArrayList<ResultUnit> recommendationList = new ArrayList<ResultUnit>();
        //find all item candidate list (items that are not rated by the user)
        for (int i = 0; i < this.ratingMatrix.getColumn(); i++) {
            if (this.ratingMatrix.getRatingMatrix().get(userIndex).get(i) == null) {
                //this item has not been rated by the user
                double prediction = predict(userIndex, i, true);
                ResultUnit unit = new ResultUnit(userIndex, i, prediction);
                recommendationList.add(unit);
            }
        }
        //sort the recommendation list
        Collections.sort(recommendationList);
        ArrayList<ResultUnit> result = new ArrayList<ResultUnit>();
        int top = 0;
        for (int i = recommendationList.size() - 1; i >= 0; i--) {
            result.add(recommendationList.get(i));
            top++;
            //         System.out.print(recommendationList.get(i).getPrediciton() + " , ");
            if (top == this.topN)
                break;
        }
        //      System.out.println();
        return result;
    }

    /**
     * @return the ratingMatrix
     */
    public RatingMatrix getRatingMatrix() {
        return ratingMatrix;
    }

    /**
     * @param ratingMatrix the ratingMatrix to set
     */
    public void setRatingMatrix(RatingMatrix ratingMatrix) {
        this.ratingMatrix = ratingMatrix;
    }

    /**
     * @return the similarityMatrix
     */
    public double[][] getSimilarityMatrix() {
        return similarityMatrix;
    }

    /**
     * @param similarityMatrix the similarityMatrix to set
     */
    public void setSimilarityMatrix(double[][] similarityMatrix) {
        this.similarityMatrix = similarityMatrix;
    }

    /**
     * @return the topN
     */
    public int getTopN() {
        return topN;
    }

    /**
     * @param topN the topN to set
     */
    public void setTopN(int topN) {
        this.topN = topN;
    }

    /**
     * @return the similarityCalculation
     */
    public String getSimilarityCalculation() {
        return similarityCalculation;
    }

    /**
     * @param similarityCalculation the similarityCalculation to set
     */
    public void setSimilarityCalculation(String similarityCalculation) {
        this.similarityCalculation = similarityCalculation;
    }

    /**
     * @return the logger
     */
    public PrintWriter getLogger() {
        return logger;
    }

    /**
     * @param logger the logger to set
     */
    public void setLogger(PrintWriter logger) {
        this.logger = logger;
    }

    /**
     * @return the maxRating
     */
    public int getMaxRating() {
        return maxRating;
    }

    /**
     * @param maxRating the maxRating to set
     */
    public void setMaxRating(int maxRating) {
        this.maxRating = maxRating;
    }

    /**
     * @return the minRating
     */
    public int getMinRating() {
        return minRating;
    }

    /**
     * @param minRating the minRating to set
     */
    public void setMinRating(int minRating) {
        this.minRating = minRating;
    }

}