Java tutorial: mean AUC evaluation for ALS recommendations with Spark MLlib
/*
 * Copyright (c) 2014, Cloudera, Inc. All Rights Reserved.
 *
 * Cloudera, Inc. licenses this file to you under the Apache License,
 * Version 2.0 (the "License"). You may not use this file except in
 * compliance with the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
 * CONDITIONS OF ANY KIND, either express or implied. See the License for
 * the specific language governing permissions and limitations under the
 * License.
 */

package com.cloudera.oryx.ml.mllib.als;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

import com.google.common.collect.Sets;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.DoubleFunction;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
import org.apache.spark.mllib.recommendation.Rating;
import org.apache.spark.rdd.RDD;
import scala.Tuple2;

import com.cloudera.oryx.common.random.RandomManager;

/**
 * Computes AUC (area under the ROC curve) as a recommender evaluation metric.
 * Strictly speaking it computes a "mean AUC": AUC is computed per user, and the
 * per-user values are averaged.
 */
final class AUC {

  private AUC() {
  }

  static double areaUnderCurve(JavaSparkContext sparkContext,
                               MatrixFactorizationModel mfModel,
                               JavaRDD<Rating> positiveData) {

    // This does not use Spark's BinaryClassificationMetrics.areaUnderROC because it
    // is intended to operate on one large set of (score,label) pairs. The computation
    // here is really many small AUC problems, for which a much faster direct computation
    // is available.

    // Extract all positive (user,product) pairs
    JavaPairRDD<Integer,Integer> positiveUserProducts = positiveData.mapToPair(
        new PairFunction<Rating,Integer,Integer>() {
          @Override
          public Tuple2<Integer,Integer> call(Rating rating) {
            return new Tuple2<>(rating.user(), rating.product());
          }
        });

    JavaPairRDD<Integer,Iterable<Rating>> positivePredictions =
        predictAll(mfModel, positiveData, positiveUserProducts);

    // All distinct item IDs, to be broadcast
    final Broadcast<List<Integer>> allItemIDsBC =
        sparkContext.broadcast(positiveUserProducts.values().distinct().collect());

    JavaPairRDD<Integer,Integer> negativeUserProducts = positiveUserProducts.groupByKey().flatMapToPair(
        new PairFlatMapFunction<Tuple2<Integer,Iterable<Integer>>,Integer,Integer>() {
          private final RandomGenerator random = RandomManager.getRandom();
          @Override
          public Iterable<Tuple2<Integer,Integer>> call(
              Tuple2<Integer,Iterable<Integer>> userIDsAndItemIDs) {
            Integer userID = userIDsAndItemIDs._1();
            Collection<Integer> positiveItemIDs = Sets.newHashSet(userIDsAndItemIDs._2());
            int numPositive = positiveItemIDs.size();
            Collection<Tuple2<Integer,Integer>> negative = new ArrayList<>(numPositive);
            List<Integer> allItemIDs = allItemIDsBC.value();
            int numItems = allItemIDs.size();
            // Sample about as many negative examples as positive
            for (int i = 0; i < numItems && negative.size() < numPositive; i++) {
              Integer itemID = allItemIDs.get(random.nextInt(numItems));
              if (!positiveItemIDs.contains(itemID)) {
                negative.add(new Tuple2<>(userID, itemID));
              }
            }
            return negative;
          }
        });

    JavaPairRDD<Integer,Iterable<Rating>> negativePredictions =
        predictAll(mfModel, positiveData, negativeUserProducts);

    return positivePredictions.join(negativePredictions).values().mapToDouble(
        new DoubleFunction<Tuple2<Iterable<Rating>,Iterable<Rating>>>() {
          @Override
          public double call(Tuple2<Iterable<Rating>,Iterable<Rating>> t) {
            // AUC is also the probability that a random positive example scores
            // higher than a random negative example. Here we compare all sampled
            // negative examples to all positive examples for the user and report
            // the fraction of correctly ordered pairs as the user's AUC.
            long correct = 0;
            long total = 0;
            for (Rating positive : t._1()) {
              for (Rating negative : t._2()) {
                if (positive.rating() > negative.rating()) {
                  correct++;
                }
                total++;
              }
            }
            return (double) correct / total;
          }
        }).mean();
  }

  private static JavaPairRDD<Integer,Iterable<Rating>> predictAll(
      MatrixFactorizationModel mfModel,
      JavaRDD<Rating> data,
      JavaPairRDD<Integer,Integer> userProducts) {
    @SuppressWarnings("unchecked")
    RDD<Tuple2<Object,Object>> userProductsRDD =
        (RDD<Tuple2<Object,Object>>) (RDD<?>) userProducts.rdd();
    return data.wrapRDD(mfModel.predict(userProductsRDD)).groupBy(
        new Function<Rating,Integer>() {
          @Override
          public Integer call(Rating r) {
            return r.user();
          }
        });
  }

}
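To see how areaUnderCurve would be called, here is a minimal usage sketch. It is not part of the class above: the toy in-memory ratings, the local Spark master, and the ALS hyperparameters (rank, iterations, lambda, alpha) are illustrative assumptions. It trains an implicit-feedback model with MLlib's ALS and, purely for illustration, evaluates mean AUC on the same positive interactions it trained on; a real evaluation would hold out a separate set of positives. The example lives in the same package because AUC and areaUnderCurve are package-private.

// Hypothetical usage sketch; data and hyperparameters are illustrative only.
package com.cloudera.oryx.ml.mllib.als;

import java.util.Arrays;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.recommendation.ALS;
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
import org.apache.spark.mllib.recommendation.Rating;

public final class AUCExample {

  public static void main(String[] args) {
    SparkConf conf = new SparkConf().setMaster("local[2]").setAppName("AUCExample");
    JavaSparkContext sc = new JavaSparkContext(conf);

    // Toy implicit-feedback data: (user, item, strength of interaction)
    JavaRDD<Rating> positiveData = sc.parallelize(Arrays.asList(
        new Rating(1, 101, 1.0),
        new Rating(1, 102, 2.0),
        new Rating(2, 101, 1.0),
        new Rating(2, 103, 3.0)));

    // Train an implicit-feedback factorization model
    // (rank 5, 10 iterations, lambda 0.01, alpha 1.0 -- illustrative values)
    MatrixFactorizationModel model =
        ALS.trainImplicit(positiveData.rdd(), 5, 10, 0.01, 1.0);

    // Mean per-user AUC over the positive interactions
    double meanAUC = AUC.areaUnderCurve(sc, model, positiveData);
    System.out.println("Mean AUC: " + meanAUC);

    sc.stop();
  }

}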