Java tutorial: building ALS recommendation models with Spark MLlib
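This tutorial walks through ALSUpdate, a batch-layer model-building class for a lambda-architecture recommender. It extends MLUpdate<String> and uses Spark MLlib's ALS to factorize a user-item rating matrix; it supports both explicit and implicit feedback, applies optional time-based decay to older ratings, and serializes the resulting model as PMML with the feature matrices stored alongside it. The full listing follows, with brief notes inline.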
package com.anhth12.lambda.app.ml.als;

import com.anhth12.lambda.TopicProducer;
import com.anhth12.lambda.app.common.fn.MLFunctions;
import com.anhth12.lambda.app.pmml.AppPMMLUtils;
import com.anhth12.lambda.common.pmml.PMMLUtils;
import com.anhth12.lambda.common.text.TextUtils;
import com.anhth12.lambda.fn.Functions;
import com.anhth12.lambda.ml.MLUpdate;
import com.anhth12.lambda.ml.param.HyperParamValues;
import com.anhth12.lambda.ml.param.HyperParams;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.base.Preconditions;
import com.google.common.hash.HashFunction;
import com.google.common.hash.Hashing;
import com.typesafe.config.Config;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.regex.Pattern;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.compress.GzipCodec;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.mllib.recommendation.ALS;
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
import org.apache.spark.mllib.recommendation.Rating;
import org.apache.spark.rdd.RDD;
import org.apache.spark.storage.StorageLevel;
import org.dmg.pmml.PMML;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Tuple2;
import scala.reflect.ClassTag$;

/**
 * Builds and evaluates ALS matrix factorization models from rating data,
 * writing the result as PMML plus feature matrices stored next to it.
 *
 * @author Tong Hoang Anh
 */
public final class ALSUpdate extends MLUpdate<String> {

    private static final Logger log = LoggerFactory.getLogger(ALSUpdate.class);
    private static final ObjectMapper MAPPER = new ObjectMapper();
    private static final HashFunction HASH = Hashing.md5();
    // Matches most strings that fit in an int (up to 10 digits); parsing can
    // still overflow, so callers also catch NumberFormatException
    private static final Pattern MOST_INTS_PATTERN = Pattern.compile("(0|-?[1-9][0-9]{0,9})");

    private final int iterations;
    private final boolean implicit;
    private final List<HyperParamValues<?>> hyperParamValues;
    private final boolean noKnownItems;
    private final double decayFactor;
    private final double decayZeroThreshold;

    public ALSUpdate(Config config) {
        super(config);
        iterations = config.getInt("lambda.als.iterations");
        implicit = config.getBoolean("lambda.als.implicit");
        hyperParamValues = Arrays.asList(
                HyperParams.fromConfig(config, "lambda.als.hyperparams.features"),
                HyperParams.fromConfig(config, "lambda.als.hyperparams.lambda"),
                HyperParams.fromConfig(config, "lambda.als.hyperparams.alpha"));
        noKnownItems = config.getBoolean("lambda.als.no-known-items");
        decayFactor = config.getDouble("lambda.als.decay.factor");
        decayZeroThreshold = config.getDouble("lambda.als.decay.zero-threshold");
        Preconditions.checkArgument(iterations > 0);
        Preconditions.checkArgument(decayFactor > 0.0 && decayFactor <= 1.0);
        Preconditions.checkArgument(decayZeroThreshold >= 0.0);
    }

    @Override
    public List<HyperParamValues<?>> getHyperParamValues() {
        return hyperParamValues;
    }

    @Override
    public PMML buildModel(JavaSparkContext sparkContext,
                           JavaRDD<String> trainData,
                           List<?> hyperParameters,
                           Path candidatePath) {
        int features = (Integer) hyperParameters.get(0);
        double lambda = (Double) hyperParameters.get(1);
        double alpha = (Double) hyperParameters.get(2);
        Preconditions.checkArgument(features > 0);
        Preconditions.checkArgument(lambda >= 0.0);
        Preconditions.checkArgument(alpha > 0.0);

        JavaRDD<String[]> parsedRDD = trainData.map(MLFunctions.PARSE_FN);
        JavaRDD<Rating> trainingRatingData = parseToRating(parsedRDD);
        trainingRatingData = aggregatedScores(trainingRatingData);

        MatrixFactorizationModel model;
        if (implicit) {
            model = ALS.trainImplicit(trainingRatingData.rdd(), features, iterations, lambda, alpha);
        } else {
            model = ALS.train(trainingRatingData.rdd(), features, iterations, lambda);
        }

        Map<Integer, String> reverseIdLookup = parsedRDD
                .flatMapToPair(new ToReverseLookupFn())
                .reduceByKey(Functions.<String>last())
                .collectAsMap();
        // Clone, due to some serialization problems with the result of collectAsMap
        reverseIdLookup = new HashMap<>(reverseIdLookup);

        PMML pmml = mfModelToPMML(model, features, lambda, alpha, implicit, candidatePath, reverseIdLookup);
        unpersist(model);
        return pmml;
    }

    @Override
    public double evaluate(JavaSparkContext sparkContext,
                           PMML model,
                           Path modelParentPath,
                           JavaRDD<String> testData,
                           JavaRDD<String> trainData) {
        JavaRDD<Rating> testRatingData = parseToRating(testData.map(MLFunctions.PARSE_FN));
        testRatingData = aggregatedScores(testRatingData);
        MatrixFactorizationModel mfModel = pmmlToMFModel(sparkContext, model, modelParentPath);
        double eval;
        if (implicit) {
            double auc = Evaluation.areaUnderCurve(sparkContext, mfModel, testRatingData);
            log.info("Area Under Curve: {}", auc);
            eval = auc;
        } else {
            double rmse = Evaluation.rmse(mfModel, testRatingData);
            log.info("RMSE: {}", rmse);
            // The framework treats larger as better, so invert RMSE
            eval = 1.0 / rmse;
        }
        unpersist(mfModel);
        return eval;
    }

    @Override
    public void publishAdditionalModelData(JavaSparkContext sparkContext,
                                           PMML pmml,
                                           JavaRDD<String> newData,
                                           JavaRDD<String> pastData,
                                           Path modelParentPath,
                                           TopicProducer<String, String> modelUpdateTopic) {
        // No additional model data is published in this implementation
    }
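    /*
     * A note on the decay settings (illustrative numbers, not from the original
     * code): with lambda.als.decay.factor = 0.9, parseToRating() below weights a
     * rating observed 3 days ago by 0.9^3 = 0.729 of its original value, and any
     * rating that decays below lambda.als.decay.zero-threshold is dropped.
     */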
    private JavaRDD<Rating> parseToRating(JavaRDD<String[]> parsedRDD) {
        JavaPairRDD<Long, Rating> timestampRatingRDD = parsedRDD.mapToPair(new ParseRatingFn());
        if (decayFactor < 1.0) {
            final double factor = decayFactor;
            final long now = System.currentTimeMillis();
            timestampRatingRDD = timestampRatingRDD
                    .mapToPair(new PairFunction<Tuple2<Long, Rating>, Long, Rating>() {
                @Override
                public Tuple2<Long, Rating> call(Tuple2<Long, Rating> timestampRating) throws Exception {
                    long timestamp = timestampRating._1;
                    Rating rating = timestampRating._2;
                    double newRating;
                    if (timestamp >= now) {
                        newRating = rating.rating();
                    } else {
                        // 86,400,000 ms per day; decay older ratings exponentially
                        double days = (now - timestamp) / 86400000.0;
                        newRating = rating.rating() * Math.pow(factor, days);
                    }
                    return new Tuple2<>(timestamp, new Rating(rating.user(), rating.product(), newRating));
                }
            });
        }
        if (decayZeroThreshold > 0.0) {
            final double theThreshold = decayZeroThreshold;
            timestampRatingRDD = timestampRatingRDD.filter(new Function<Tuple2<Long, Rating>, Boolean>() {
                @Override
                public Boolean call(Tuple2<Long, Rating> t1) throws Exception {
                    return t1._2.rating() > theThreshold;
                }
            });
        }
        return timestampRatingRDD.sortByKey().values();
    }

    private JavaRDD<Rating> aggregatedScores(JavaRDD<Rating> original) {
        JavaPairRDD<Tuple2<Integer, Integer>, Double> tuples = original.mapToPair(new RatingToTupleDouble());
        JavaPairRDD<Tuple2<Integer, Integer>, Double> aggregated;
        if (implicit) {
            // Implicit feedback: sum interaction strengths per (user, item) pair
            aggregated = tuples.groupByKey().mapValues(MLFunctions.SUM_WITH_NAN);
        } else {
            // Explicit ratings: the most recent rating wins
            aggregated = tuples.foldByKey(Double.NaN, Functions.<Double>last());
        }
        return aggregated
                .filter(MLFunctions.<Tuple2<Integer, Integer>>notNaNValue())
                .map(new Function<Tuple2<Tuple2<Integer, Integer>, Double>, Rating>() {
            @Override
            public Rating call(Tuple2<Tuple2<Integer, Integer>, Double> userProductScore) throws Exception {
                Tuple2<Integer, Integer> userProduct = userProductScore._1();
                return new Rating(userProduct._1(), userProduct._2(), userProductScore._2());
            }
        });
    }
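    /*
     * Why NaN appears above: ParseRatingFn (further below) turns an empty rating
     * field into Double.NaN. For explicit data, foldByKey keeps only the latest
     * value per (user, item) pair, so a trailing NaN acts as a "remove this
     * rating" marker that the notNaNValue() filter then discards; judging by its
     * name, MLFunctions.SUM_WITH_NAN gives implicit sums the same property.
     */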
    private PMML mfModelToPMML(MatrixFactorizationModel model,
                               int features,
                               double lambda,
                               double alpha,
                               boolean implicit,
                               Path candidatePath,
                               Map<Integer, String> reverseIdLookup) {
        JavaPairRDD<Integer, double[]> userFeaturesRDD = massageToIntKey(model.userFeatures());
        JavaPairRDD<Integer, double[]> itemFeaturesRDD = massageToIntKey(model.productFeatures());

        saveFeaturesRDD(userFeaturesRDD, new Path(candidatePath, "X"), reverseIdLookup);
        saveFeaturesRDD(itemFeaturesRDD, new Path(candidatePath, "Y"), reverseIdLookup);

        PMML pmml = PMMLUtils.buildSkeletonPMML();
        AppPMMLUtils.addExtension(pmml, "X", "X/");
        AppPMMLUtils.addExtension(pmml, "Y", "Y/");
        AppPMMLUtils.addExtension(pmml, "features", features);
        AppPMMLUtils.addExtension(pmml, "lambda", lambda);
        AppPMMLUtils.addExtension(pmml, "implicit", implicit);
        if (implicit) {
            AppPMMLUtils.addExtension(pmml, "alpha", alpha);
        }
        addIDsExtensions(pmml, "XIDs", userFeaturesRDD, reverseIdLookup);
        addIDsExtensions(pmml, "YIDs", itemFeaturesRDD, reverseIdLookup);
        return pmml;
    }

    private static <A, B> JavaPairRDD<Integer, B> massageToIntKey(RDD<Tuple2<A, B>> in) {
        // Keys are known to be ints at this point
        @SuppressWarnings("unchecked")
        JavaPairRDD<Integer, B> javaRDD = fromRDD((RDD<Tuple2<Integer, B>>) (RDD<?>) in);
        return javaRDD;
    }

    private void unpersist(MatrixFactorizationModel model) {
        model.userFeatures().unpersist(false);
        model.productFeatures().unpersist(false);
    }

    private static MatrixFactorizationModel pmmlToMFModel(JavaSparkContext sparkContext,
                                                          PMML model,
                                                          Path modelParentPath) {
        String xPathString = AppPMMLUtils.getExtensionValue(model, "X");
        String yPathString = AppPMMLUtils.getExtensionValue(model, "Y");
        JavaPairRDD<String, double[]> userRDD =
                readFeaturesRDD(sparkContext, new Path(modelParentPath, xPathString));
        JavaPairRDD<String, double[]> productRDD =
                readFeaturesRDD(sparkContext, new Path(modelParentPath, yPathString));
        int rank = userRDD.first()._2.length;
        return new MatrixFactorizationModel(rank,
                readAndConvertFeatureRDD(userRDD),
                readAndConvertFeatureRDD(productRDD));
    }

    private static RDD<Tuple2<Object, double[]>> readAndConvertFeatureRDD(JavaPairRDD<String, double[]> javaRDD) {
        RDD<Tuple2<Integer, double[]>> scalaRDD = javaRDD
                .mapToPair(new PairFunction<Tuple2<String, double[]>, Integer, double[]>() {
            @Override
            public Tuple2<Integer, double[]> call(Tuple2<String, double[]> t) throws Exception {
                return new Tuple2<>(parseOrHashInt(t._1()), t._2());
            }
        }).rdd();
        scalaRDD.persist(StorageLevel.MEMORY_AND_DISK());
        @SuppressWarnings("unchecked")
        RDD<Tuple2<Object, double[]>> objKeyRDD = (RDD<Tuple2<Object, double[]>>) (RDD<?>) scalaRDD;
        return objKeyRDD;
    }

    private static JavaPairRDD<String, double[]> readFeaturesRDD(JavaSparkContext sparkContext, Path path) {
        log.info("Loading features RDD from {}", path);
        JavaRDD<String> featureLines = sparkContext.textFile(path.toString());
        return featureLines.mapToPair(new PairFunction<String, String, double[]>() {
            @Override
            public Tuple2<String, double[]> call(String line) throws Exception {
                // Each line is a JSON array of [key, featureVector]
                List<?> update = MAPPER.readValue(line, List.class);
                String key = update.get(0).toString();
                double[] vector = MAPPER.convertValue(update.get(1), double[].class);
                return new Tuple2<>(key, vector);
            }
        });
    }
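    /*
     * How string IDs become ints: MLlib's ALS requires int user and item IDs.
     * parseOrHashInt() below passes through tokens that already look like ints
     * and MD5-hashes anything else to a stable non-negative int.
     * ToReverseLookupFn records the hash-to-string mapping so that
     * saveFeaturesRDD() and addIDsExtensions() can emit the original strings.
     */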
    private static final class ToReverseLookupFn implements PairFlatMapFunction<String[], Integer, String> {
        @Override
        public Iterable<Tuple2<Integer, String>> call(String[] tokens) throws Exception {
            List<Tuple2<Integer, String>> results = new ArrayList<>(2);
            for (int i = 0; i < 2; i++) {
                String s = tokens[i];
                if (MOST_INTS_PATTERN.matcher(s).matches()) {
                    try {
                        Integer.parseInt(s);
                        // Already a valid int ID; no reverse mapping needed
                        continue;
                    } catch (NumberFormatException nfe) {
                        // Matched the pattern but overflowed; fall through to hashing
                    }
                }
                results.add(new Tuple2<>(hash(s), s));
            }
            return results;
        }
    }

    private static final class ParseRatingFn implements PairFunction<String[], Long, Rating> {
        @Override
        public Tuple2<Long, Rating> call(String[] tokens) throws Exception {
            // tokens: user, item, rating, timestamp; an empty rating means NaN
            return new Tuple2<>(
                    Long.valueOf(tokens[3]),
                    new Rating(parseOrHashInt(tokens[0]),
                               parseOrHashInt(tokens[1]),
                               tokens[2].isEmpty() ? Double.NaN : Double.parseDouble(tokens[2])));
        }
    }

    private static int parseOrHashInt(String s) {
        if (MOST_INTS_PATTERN.matcher(s).matches()) {
            try {
                return Integer.parseInt(s);
            } catch (NumberFormatException e) {
                // Fall through to hashing
            }
        }
        return hash(s);
    }

    private static int hash(String s) {
        // Mask the sign bit so hashed IDs are non-negative
        return HASH.hashString(s).asInt() & 0x7FFFFFFF;
    }

    private static <K, V> JavaPairRDD<K, V> fromRDD(RDD<Tuple2<K, V>> rdd) {
        return JavaPairRDD.fromRDD(rdd,
                ClassTag$.MODULE$.<K>apply(Object.class),
                ClassTag$.MODULE$.<V>apply(Object.class));
    }

    private static void saveFeaturesRDD(JavaPairRDD<Integer, double[]> features,
                                        Path path,
                                        final Map<Integer, String> reverseIdMapping) {
        log.info("Saving features RDD to {}", path);
        features.map(new Function<Tuple2<Integer, double[]>, String>() {
            @Override
            public String call(Tuple2<Integer, double[]> keyAndVector) throws Exception {
                Integer id = keyAndVector._1();
                String originalKey = reverseIdMapping.get(id);
                Object key = originalKey == null ? id : originalKey;
                double[] vector = keyAndVector._2();
                return TextUtils.joinJSON(Arrays.asList(key, vector));
            }
        }).saveAsTextFile(path.toString(), GzipCodec.class);
    }

    private static void addIDsExtensions(PMML pmml,
                                         String key,
                                         JavaPairRDD<Integer, double[]> features,
                                         Map<Integer, String> reverseIdMapping) {
        List<Integer> hashedIDs = features.keys().collect();
        List<String> ids = new ArrayList<>(hashedIDs.size());
        for (Integer hashedID : hashedIDs) {
            String originalID = reverseIdMapping.get(hashedID);
            // Fall back to the hashed ID itself when no original string is known
            ids.add(originalID == null ? hashedID.toString() : originalID);
        }
        AppPMMLUtils.addExtensionContent(pmml, key, ids);
    }
}
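To see what the constructor expects, here is a minimal configuration sketch. It is hypothetical: the key names come straight from the constructor above, but the bracketed-list syntax for the hyperparameter ranges is an assumption about what HyperParams.fromConfig accepts, and MLUpdate's own constructor may require further keys that are not shown.

import com.typesafe.config.Config;
import com.typesafe.config.ConfigFactory;

public class ALSUpdateConfigSketch {
    public static void main(String[] args) {
        // Hypothetical values; the lists assume HyperParams.fromConfig accepts
        // a list of candidate values per hyperparameter
        Config config = ConfigFactory.parseString(
                "lambda.als.iterations = 10\n"
                + "lambda.als.implicit = true\n"
                + "lambda.als.hyperparams.features = [10, 50]\n"
                + "lambda.als.hyperparams.lambda = [0.001, 0.1]\n"
                + "lambda.als.hyperparams.alpha = [1.0, 40.0]\n"
                + "lambda.als.no-known-items = false\n"
                + "lambda.als.decay.factor = 1.0\n"
                + "lambda.als.decay.zero-threshold = 0.0\n");
        // MLUpdate's superclass constructor may validate additional keys
        // (an assumption); this sketch covers only the ALS-specific settings
        ALSUpdate update = new ALSUpdate(config);
        System.out.println(update.getHyperParamValues());
    }
}

Note that setting lambda.als.decay.factor to 1.0 disables the time-decay pass entirely, since parseToRating() only rewrites ratings when the factor is below 1.0.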