de.tuberlin.dima.recsys.ssnmm.ratingprediction.UserItemBaseline.java Source code

Java tutorial

Introduction

Here is the source code for de.tuberlin.dima.recsys.ssnmm.ratingprediction.UserItemBaseline.java

Source

/*
 * Copyright (C) 2012 Sebastian Schelter <sebastian.schelter [at] tu-berlin.de>
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
 * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations under the License.
 */

package de.tuberlin.dima.recsys.ssnmm.ratingprediction;

import com.google.common.base.Charsets;
import com.google.common.io.Closeables;
import com.google.common.io.Files;
import de.tuberlin.dima.recsys.ssnmm.Rating;
import de.tuberlin.dima.recsys.ssnmm.RatingsIterable;
import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
import org.apache.mahout.cf.taste.impl.common.RunningAverage;

import java.io.BufferedWriter;
import java.io.File;
import java.io.IOException;
import java.util.Arrays;

/**
 * Java port of the "UserItemBaseline" rating predictor from "mymedialite" https://github.com/zenogantner/MyMediaLite/
 *
 * <p>A rating is predicted as {@code globalAverage + userBiases[user] + itemBiases[item]}. The biases are fitted by
 * alternating passes over the training file: every call to {@link #train()} first re-estimates all item biases while
 * the user biases are held fixed, and then re-estimates all user biases while the item biases are held fixed.</p>
 *
 * <p>User and item ids are used directly as array indices, so they have to lie in {@code [0, numUsers)} and
 * {@code [0, numItems)} respectively.</p>
 */
public class UserItemBaseline {

    /**
     * Trains the baseline on a fixed training file for a fixed number of iterations, prints MAE and RMSE on the
     * holdout file and writes the learned biases to the output directory.
     *
     * @param args ignored, all settings are hard-coded below
     * @throws IOException if reading the ratings or writing the biases fails
     */
    public static void main(String[] args) throws IOException {

        File trainingFile = new File("/home/ssc/Entwicklung/datasets/yahoo-songs/songs.tsv");
        File testFile = new File("/home/ssc/Entwicklung/datasets/yahoo-songs/holdout.tsv");
        File outputDir = new File("/home/ssc/Desktop/yahoo/");

        int numUsers = 1823179;
        int numItems = 136736;
        double mu = 3.157255412010664;

        int numIterations = 3;

        UserItemBaseline baseline = new UserItemBaseline(trainingFile, testFile, 0.5, 0, numUsers, numItems, mu);

        for (int n = 0; n < numIterations; n++) {
            baseline.train();
        }

        baseline.test();

        baseline.persistBiases(outputDir);
    }

    /** Per-user offset from the global average, indexed by user id. */
    private final double[] userBiases;
    /** Per-item offset from the global average, indexed by item id. */
    private final double[] itemBiases;

    /** Global rating average that every prediction starts from. */
    private final double globalAverage;

    private final File ratings;
    private final File tests;

    /** Regularization term added to an item's rating count when its bias is averaged. */
    private final double regI;
    /** Regularization term added to a user's rating count when its bias is averaged. */
    private final double regU;

    /**
     * Creates an untrained predictor; all biases start at zero.
     *
     * @param ratings  file with the training ratings
     * @param tests    file with the held-out ratings used by {@link #test()}
     * @param regU     regularization for the user biases
     * @param regI     regularization for the item biases
     * @param numUsers size of the user bias array, user ids must be smaller than this
     * @param numItems size of the item bias array, item ids must be smaller than this
     * @param mu       global rating average
     */
    public UserItemBaseline(File ratings, File tests, double regU, double regI, int numUsers, int numItems,
            double mu) {
        this.ratings = ratings;
        this.tests = tests;
        this.regU = regU;
        this.regI = regI;

        globalAverage = mu;

        // freshly allocated arrays are already zero-filled
        userBiases = new double[numUsers];
        itemBiases = new double[numItems];
    }

    /**
     * Predicts every rating of the test file and prints the mean absolute error and the root mean squared error.
     *
     * @throws IOException if the test file cannot be read
     */
    void test() throws IOException {

        RunningAverage rmse = new FullRunningAverage();
        RunningAverage mae = new FullRunningAverage();

        System.out.println("Calculating predictions");
        for (Rating rating : new RatingsIterable(tests)) {

            double error = Math.abs(rating.rating() - baselineEstimate(rating.user(), rating.item()));

            mae.addDatum(error);
            rmse.addDatum(error * error);
        }

        System.out.println("MAE " + mae.getAverage() + ", RMSE: " + Math.sqrt(rmse.getAverage()));
    }

    /**
     * @param user user id, used as index into the user biases
     * @param item item id, used as index into the item biases
     * @return the global average plus the biases of the given user and item
     */
    double baselineEstimate(int user, int item) {
        return globalAverage + userBiases[user] + itemBiases[item];
    }

    /**
     * Runs one training iteration: item biases first, then user biases.
     *
     * @throws IOException if the training file cannot be read
     */
    void train() throws IOException {
        optimizeItemBiases();
        optimizeUserBiases();
    }

    /**
     * Re-estimates every item bias as the regularized mean of {@code rating - globalAverage - userBias} over all
     * training ratings of that item. Items without any rating keep a bias of zero.
     *
     * @throws IOException if the training file cannot be read
     */
    void optimizeItemBiases() throws IOException {

        System.out.println("Optimizing item biases...");

        // The biases double as accumulators for the residual sums, so they have to restart from zero on every
        // pass; otherwise the previous iteration's bias would leak into the new sum.
        Arrays.fill(itemBiases, 0);
        int[] itemRatingsCount = new int[itemBiases.length];

        // long, so that the progress counter cannot overflow on very large rating files
        long ratingsProcessed = 0;
        for (Rating rating : new RatingsIterable(ratings)) {

            itemBiases[rating.item()] += rating.rating() - globalAverage - userBiases[rating.user()];
            itemRatingsCount[rating.item()]++;

            if (++ratingsProcessed % 10000000 == 0) {
                System.out.println((ratingsProcessed / 1000000) + "M ratings processed");
            }
        }

        for (int item = 0; item < itemBiases.length; item++) {
            // skipping unrated items also avoids a division by zero when regI is 0
            if (itemRatingsCount[item] != 0) {
                itemBiases[item] /= regI + itemRatingsCount[item];
            }
        }

    }

    /**
     * Re-estimates every user bias as the regularized mean of {@code rating - globalAverage - itemBias} over all
     * training ratings of that user. Users without any rating keep a bias of zero.
     *
     * @throws IOException if the training file cannot be read
     */
    void optimizeUserBiases() throws IOException {

        System.out.println("Optimizing user biases...");

        // The biases double as accumulators for the residual sums, so they have to restart from zero on every
        // pass; otherwise the previous iteration's bias would leak into the new sum.
        Arrays.fill(userBiases, 0);
        int[] userRatingsCount = new int[userBiases.length];

        // long, so that the progress counter cannot overflow on very large rating files
        long ratingsProcessed = 0;
        for (Rating rating : new RatingsIterable(ratings)) {

            userBiases[rating.user()] += rating.rating() - globalAverage - itemBiases[rating.item()];
            userRatingsCount[rating.user()]++;

            if (++ratingsProcessed % 10000000 == 0) {
                System.out.println((ratingsProcessed / 1000000) + "M ratings processed");
            }
        }

        for (int user = 0; user < userBiases.length; user++) {
            // skipping users without ratings also avoids a division by zero when regU is 0
            if (userRatingsCount[user] != 0) {
                userBiases[user] /= regU + userRatingsCount[user];
            }
        }
    }

    /**
     * Writes the user biases to {@code userBiases.tsv} and the item biases to {@code itemBiases.tsv} in the given
     * directory.
     *
     * @param dir directory the two files are created in
     * @throws IOException if one of the files cannot be written
     */
    void persistBiases(File dir) throws IOException {
        persist(new File(dir, "userBiases.tsv"), userBiases);
        persist(new File(dir, "itemBiases.tsv"), itemBiases);
    }

    /**
     * Writes one {@code index<TAB>bias} line per array entry, encoded as UTF-8.
     *
     * @param file   file to (over)write
     * @param biases values to write, the array index becomes the first column
     * @throws IOException if writing or closing the file fails
     */
    private void persist(File file, double[] biases) throws IOException {
        BufferedWriter writer = Files.newWriter(file, Charsets.UTF_8);
        // Closing a buffered writer flushes it, so a failure on close means lost data and must be reported.
        // It is only swallowed when the loop already threw, so that the original exception is not masked.
        boolean threw = true;
        try {
            for (int index = 0; index < biases.length; index++) {
                writer.append(String.valueOf(index));
                writer.append("\t");
                writer.append(String.valueOf(biases[index]));
                writer.append("\n");
            }
            threw = false;
        } finally {
            Closeables.close(writer, threw);
        }
    }

}