com.anhth12.lambda.app.ml.als.Evaluation.java Source code

Java tutorial

Introduction

Here is the source code for com.anhth12.lambda.app.ml.als.Evaluation.java

Source

/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package com.anhth12.lambda.app.ml.als;

import com.anhth12.lambda.common.random.RandomManager;
import com.google.common.collect.Sets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.DoubleFunction;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
import org.apache.spark.mllib.recommendation.Rating;
import org.apache.spark.rdd.RDD;
import scala.Tuple2;

/**
 * Static evaluation helpers for a {@link MatrixFactorizationModel}: a root mean squared
 * error over known ratings and a per-user AUC averaged across users.
 *
 * @author Tong Hoang Anh
 */
final class Evaluation {

    private Evaluation() {
        // Holder for static helpers only; not meant to be instantiated.
    }

    /**
     * Computes the root mean squared error between the model's predictions and the
     * values found in {@code testData}.
     *
     * <p>Both the test ratings and the predictions are keyed with
     * {@code RatingToTupleDouble} (defined outside this file; presumably it yields
     * {@code ((user, product), rating)} - confirm against that class) and then joined
     * on that key, so only keys present on both sides contribute to the error.
     *
     * @param mfModel model used to predict a value for every key in {@code testData}
     * @param testData ratings holding the expected values
     * @return the square root of the mean of the squared prediction/actual differences
     */
    static double rmse(MatrixFactorizationModel mfModel, JavaRDD<Rating> testData) {
        JavaPairRDD<Tuple2<Integer, Integer>, Double> actualByKey = testData.mapToPair(new RatingToTupleDouble());

        // predict() is invoked with an RDD<Tuple2<Object, Object>>, so the typed key RDD
        // is re-cast through a wildcard to get there.
        @SuppressWarnings("unchecked")
        RDD<Tuple2<Object, Object>> keysToScore = (RDD<Tuple2<Object, Object>>) (RDD<?>) actualByKey.keys().rdd();

        JavaPairRDD<Tuple2<Integer, Integer>, Double> predictedByKey = testData
                .wrapRDD(mfModel.predict(keysToScore))
                .mapToPair(new RatingToTupleDouble());

        // After the join below, _1 is the predicted value and _2 the actual one; the sign
        // of the difference is irrelevant because it is squared.
        DoubleFunction<Tuple2<Double, Double>> squaredError = new DoubleFunction<Tuple2<Double, Double>>() {

            @Override
            public double call(Tuple2<Double, Double> predictedAndActual) throws Exception {
                double error = predictedAndActual._1() - predictedAndActual._2();
                return error * error;
            }
        };

        double meanSquaredError = predictedByKey.join(actualByKey).values().mapToDouble(squaredError).mean();
        return Math.sqrt(meanSquaredError);
    }

    /**
     * Computes AUC (area under the ROC curve) as a recommender evaluation metric.
     *
     * <p>The (user, product) pairs in {@code positiveData} are treated as positive
     * examples. For every user, negative examples are drawn at random from the distinct
     * products seen in {@code positiveData}, skipping products the user already has;
     * at most as many negatives as the user has distinct positives are kept, the number
     * of draws is capped at the number of distinct products, and the same product may be
     * drawn more than once. The per-user score is the fraction of (positive, negative)
     * prediction pairs in which the positive is rated strictly higher, and the result is
     * the mean of those scores over users that have both positive and negative
     * predictions.
     *
     * @param sparkContext context used to broadcast the list of distinct product IDs
     * @param mfModel model whose predictions are being ranked
     * @param positiveData ratings whose (user, product) pairs count as positive examples
     * @return the mean per-user fraction of correctly ordered (positive, negative) pairs
     */
    static double areaUnderCurve(JavaSparkContext sparkContext, MatrixFactorizationModel mfModel,
            JavaRDD<Rating> positiveData) {

        PairFunction<Rating, Integer, Integer> toUserProduct = new PairFunction<Rating, Integer, Integer>() {

            @Override
            public Tuple2<Integer, Integer> call(Rating rating) throws Exception {
                return new Tuple2<>(rating.user(), rating.product());
            }
        };
        JavaPairRDD<Integer, Integer> positiveUserProducts = positiveData.mapToPair(toUserProduct);

        JavaPairRDD<Integer, Iterable<Rating>> positivePredictions = predictAll(mfModel, positiveData,
                positiveUserProducts);

        // Every distinct product ID is collected and shared with the sampling function below.
        List<Integer> distinctItemIDs = positiveUserProducts.values().distinct().collect();
        final Broadcast<List<Integer>> allItemIDsBC = sparkContext.broadcast(distinctItemIDs);

        PairFlatMapFunction<Tuple2<Integer, Iterable<Integer>>, Integer, Integer> sampleNegatives
                = new PairFlatMapFunction<Tuple2<Integer, Iterable<Integer>>, Integer, Integer>() {
                    // Obtained once, when this function object is created.
                    private final RandomGenerator random = RandomManager.getRandom();

                    @Override
                    public Iterable<Tuple2<Integer, Integer>> call(
                            Tuple2<Integer, Iterable<Integer>> userAndPositives) throws Exception {
                        Integer userID = userAndPositives._1();
                        Collection<Integer> knownItemIDs = Sets.newHashSet(userAndPositives._2());
                        int wanted = knownItemIDs.size();

                        List<Integer> candidates = allItemIDsBC.value();
                        int numCandidates = candidates.size();

                        Collection<Tuple2<Integer, Integer>> sampled = new ArrayList<>(wanted);

                        // Stop once enough negatives were found or after numCandidates draws,
                        // whichever comes first, so a user who already has (nearly) every
                        // product cannot keep the loop running.
                        int draws = 0;
                        while (draws < numCandidates && sampled.size() < wanted) {
                            Integer candidate = candidates.get(random.nextInt(numCandidates));
                            if (!knownItemIDs.contains(candidate)) {
                                sampled.add(new Tuple2<>(userID, candidate));
                            }
                            draws++;
                        }

                        return sampled;
                    }
                };
        JavaPairRDD<Integer, Integer> negativeUserProducts = positiveUserProducts.groupByKey()
                .flatMapToPair(sampleNegatives);

        JavaPairRDD<Integer, Iterable<Rating>> negativePredictions = predictAll(mfModel, positiveData,
                negativeUserProducts);

        DoubleFunction<Tuple2<Iterable<Rating>, Iterable<Rating>>> perUserAuc
                = new DoubleFunction<Tuple2<Iterable<Rating>, Iterable<Rating>>>() {

                    @Override
                    public double call(Tuple2<Iterable<Rating>, Iterable<Rating>> positivesAndNegatives)
                            throws Exception {
                        return fractionRankedCorrectly(positivesAndNegatives._1(), positivesAndNegatives._2());
                    }
                };

        return positivePredictions.join(negativePredictions).values().mapToDouble(perUserAuc).mean();
    }

    /**
     * Compares every positive prediction with every negative prediction and reports the
     * share of pairs in which the positive one has the strictly higher rating.
     *
     * <p>AUC is also the probability that a random positive example ranks higher than a
     * random example at large. Here all sampled negative examples are compared to all
     * positive examples and the totals are reported as an alternative computation of AUC.
     *
     * @param positives predictions for one user's positive examples
     * @param negatives predictions for the same user's sampled negative examples
     * @return correctly ordered pairs divided by all pairs; NaN when there are no pairs
     */
    private static double fractionRankedCorrectly(Iterable<Rating> positives, Iterable<Rating> negatives) {
        long wins = 0;
        long comparisons = 0;

        for (Rating positive : positives) {
            double positiveScore = positive.rating();
            for (Rating negative : negatives) {
                comparisons++;
                if (positiveScore > negative.rating()) {
                    wins++;
                }
            }
        }

        return (double) wins / comparisons;
    }

    /**
     * Predicts a rating for each given (user, product) pair and groups the resulting
     * {@link Rating}s by user.
     *
     * @param mfModel model producing the predictions
     * @param data used only to wrap the model's output RDD back into a {@link JavaRDD}
     * @param userProducts (user, product) pairs to predict
     * @return the predictions, keyed by {@code Rating.user()}
     */
    private static JavaPairRDD<Integer, Iterable<Rating>> predictAll(MatrixFactorizationModel mfModel,
            JavaRDD<Rating> data, JavaPairRDD<Integer, Integer> userProducts) {
        // Same wildcard re-cast as in rmse(): predict() is invoked with Tuple2<Object, Object>.
        @SuppressWarnings("unchecked")
        RDD<Tuple2<Object, Object>> pairsToScore = (RDD<Tuple2<Object, Object>>) (RDD<?>) userProducts.rdd();
        JavaRDD<Rating> scored = data.wrapRDD(mfModel.predict(pairsToScore));

        Function<Rating, Integer> byUser = new Function<Rating, Integer>() {

            @Override
            public Integer call(Rating rating) throws Exception {
                return rating.user();
            }
        };

        return scored.groupBy(byUser);
    }
}