com.anhth12.lambda.app.ml.als.ALSUpdate.java Source code

Introduction

Here is the source code for com.anhth12.lambda.app.ml.als.ALSUpdate.java

Source

/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package com.anhth12.lambda.app.ml.als;

import com.anhth12.lambda.TopicProducer;
import com.anhth12.lambda.app.common.fn.MLFunctions;
import com.anhth12.lambda.app.pmml.AppPMMLUtils;
import com.anhth12.lambda.common.pmml.PMMLUtils;
import com.anhth12.lambda.common.text.TextUtils;
import com.anhth12.lambda.fn.Functions;
import com.anhth12.lambda.ml.MLUpdate;
import com.anhth12.lambda.ml.param.HyperParamValues;
import com.anhth12.lambda.ml.param.HyperParams;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.base.Preconditions;
import com.google.common.hash.HashFunction;
import com.google.common.hash.Hashing;
import com.typesafe.config.Config;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.regex.Pattern;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.compress.GzipCodec;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.mllib.recommendation.ALS;
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
import org.apache.spark.mllib.recommendation.Rating;
import org.apache.spark.rdd.RDD;
import org.apache.spark.storage.StorageLevel;
import org.dmg.pmml.PMML;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Tuple2;
import scala.reflect.ClassTag$;

/**
 * Builds ALS (alternating least squares) matrix factorization models with
 * Spark MLlib and serializes each candidate model to PMML.
 *
 * @author Tong Hoang Anh
 */
public final class ALSUpdate extends MLUpdate<String> {

    private static final Logger log = LoggerFactory.getLogger(ALSUpdate.class);

    private static final ObjectMapper MAPPER = new ObjectMapper();
    private static final HashFunction HASH = Hashing.md5();
    private static final Pattern MOST_INTS_PATTERN = Pattern.compile("(0|-?[1-9][0-9]{0,9})");

    private final int iterations;
    private final boolean implicit;
    private final List<HyperParamValues<?>> hyperParamValues;
    private final boolean noKnownItems;
    private final double decayFactor;
    private final double decayZeroThreshold;

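    /**
     * Reads ALS settings from the supplied configuration. The keys consumed
     * here, all under {@code lambda.als.}, are {@code iterations},
     * {@code implicit}, {@code hyperparams.features},
     * {@code hyperparams.lambda}, {@code hyperparams.alpha},
     * {@code no-known-items}, {@code decay.factor} and
     * {@code decay.zero-threshold}.
     */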
    public ALSUpdate(Config config) {
        super(config);

        iterations = config.getInt("lambda.als.iterations");
        implicit = config.getBoolean("lambda.als.implicit");
        Preconditions.checkArgument(iterations > 0);
        hyperParamValues = Arrays.asList(HyperParams.fromConfig(config, "lambda.als.hyperparams.features"),
                HyperParams.fromConfig(config, "lambda.als.hyperparams.lambda"),
                HyperParams.fromConfig(config, "lambda.als.hyperparams.alpha"));

        noKnownItems = config.getBoolean("lambda.als.no-known-items");
        decayFactor = config.getDouble("lambda.als.decay.factor");
        decayZeroThreshold = config.getDouble("lambda.als.decay.zero-threshold");
        Preconditions.checkArgument(decayFactor > 0.0 && decayFactor <= 1.0);
        Preconditions.checkArgument(decayZeroThreshold >= 0.0);
    }

    @Override
    public List<HyperParamValues<?>> getHyperParamValues() {
        return hyperParamValues;
    }

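    /**
     * Trains a matrix factorization model on the parsed input using the given
     * hyperparameters (features, lambda, alpha), then encodes it as PMML,
     * saving the user and item feature matrices under the candidate path.
     */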
    @Override
    public PMML buildModel(JavaSparkContext sparkContext, JavaRDD<String> trainData, List<?> hyperParameters,
            Path candidatePath) {
        int features = (Integer) hyperParameters.get(0);
        double lambda = (Double) hyperParameters.get(1);
        double alpha = (Double) hyperParameters.get(2);
        Preconditions.checkArgument(features > 0);
        Preconditions.checkArgument(lambda >= 0.0);
        Preconditions.checkArgument(alpha > 0.0);

        JavaRDD<String[]> parsedRDD = trainData.map(MLFunctions.PARSE_FN);

        JavaRDD<Rating> trainingRatingData = parseToRating(parsedRDD);
        trainingRatingData = aggregatedScores(trainingRatingData);

        MatrixFactorizationModel model;
        if (implicit) {
            model = ALS.trainImplicit(trainingRatingData.rdd(), features, iterations, lambda, alpha);
        } else {
            model = ALS.train(trainingRatingData.rdd(), features, iterations, lambda);
        }
        }

        Map<Integer, String> reverseIdLookUp = parsedRDD.flatMapToPair(new ToReverseLookupFn())
                .reduceByKey(Functions.<String>last()).collectAsMap();
        // Copy into a regular HashMap: the map returned by collectAsMap() has serialization problems
        reverseIdLookUp = new HashMap<>(reverseIdLookUp);

        PMML pmml = mfModelToPMML(model, features, lambda, alpha, implicit, candidatePath, reverseIdLookUp);
        unpersist(model);
        return pmml;
    }

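    /**
     * Scores a candidate model on held-out data: area under the curve for
     * implicit feedback, otherwise 1/RMSE, so that a higher value is better
     * in both cases.
     */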
    @Override
    public double evaluate(JavaSparkContext sparkContext, PMML model, Path modelParentPath,
            JavaRDD<String> testData, JavaRDD<String> trainData) {
        JavaRDD<Rating> testRatingData = parseToRating(testData.map(MLFunctions.PARSE_FN));
        testRatingData = aggregatedScores(testRatingData);
        MatrixFactorizationModel mfModel = pmmlToMFModel(sparkContext, model, modelParentPath);

        double eval;
        if (implicit) {
            double auc = Evaluation.areaUnderCurve(sparkContext, mfModel, testRatingData);
            log.info("Area Under Curve: {}", auc);
            eval = auc;
        } else {
            double rmse = Evaluation.rmse(mfModel, testRatingData);
            log.info("RMSE: {}", rmse);
            eval = 1.0 / rmse;
        }

        unpersist(mfModel);
        return eval;
    }

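    /**
     * No additional model data is published by this implementation.
     */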
    @Override
    public void publishAdditionalModelData(JavaSparkContext sparkContext, PMML pmml, JavaRDD<String> newData,
            JavaRDD<String> pastData, Path modelParentPath, TopicProducer<String, String> modelUpdateTopic) {

    }

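    /**
     * Converts [user, item, score, timestamp] token arrays into Ratings
     * ordered by timestamp. When decayFactor is below 1, scores decay
     * exponentially with age: with a factor of 0.95, for example, a score of
     * 4.0 that is 7 days old becomes 4.0 * 0.95^7, roughly 2.79. If
     * decayZeroThreshold is positive, ratings at or below it are dropped.
     */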
    private JavaRDD<Rating> parseToRating(JavaRDD<String[]> parsedRDD) {
        JavaPairRDD<Long, Rating> timestampRatingRDD = parsedRDD.mapToPair(new ParseRatingFn());

        if (decayFactor < 1.0) {
            final double factor = decayFactor;
            final long now = System.currentTimeMillis();
            timestampRatingRDD = timestampRatingRDD
                    .mapToPair(new PairFunction<Tuple2<Long, Rating>, Long, Rating>() {

                        @Override
                        public Tuple2<Long, Rating> call(Tuple2<Long, Rating> timestampRating) throws Exception {
                            long timestamp = timestampRating._1;
                            Rating rating = timestampRating._2;

                            double newRating;
                            if (timestamp >= now) {
                                newRating = rating.rating();
                            } else {
                                double days = (now - timestamp) / 86400000.0;
                                newRating = rating.rating() * Math.pow(factor, days);
                            }
                            return new Tuple2<>(timestamp, new Rating(rating.user(), rating.product(), newRating));
                        }

                    });
        }

        if (decayZeroThreshold > 0.0) {
            final double theThreshold = decayZeroThreshold;
            timestampRatingRDD = timestampRatingRDD.filter(new Function<Tuple2<Long, Rating>, Boolean>() {

                @Override
                public Boolean call(Tuple2<Long, Rating> t1) throws Exception {
                    return t1._2.rating() > theThreshold;
                }
            });
        }
        return timestampRatingRDD.sortByKey().values();
    }

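    /**
     * Collapses duplicate (user, item) pairs to one score: the sum for
     * implicit feedback, otherwise the last value seen (the most recent,
     * since the input is timestamp-ordered). Pairs whose aggregate is NaN
     * are dropped.
     */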
    private JavaRDD<Rating> aggregatedScores(JavaRDD<Rating> original) {
        JavaPairRDD<Tuple2<Integer, Integer>, Double> tuples = original.mapToPair(new RatingToTupleDouble());

        JavaPairRDD<Tuple2<Integer, Integer>, Double> aggregated;

        if (implicit) {
            aggregated = tuples.groupByKey().mapValues(MLFunctions.SUM_WITH_NAN);
        } else {
            aggregated = tuples.foldByKey(Double.NaN, Functions.<Double>last());
        }

        return aggregated.filter(MLFunctions.<Tuple2<Integer, Integer>>notNaNValue())
                .map(new Function<Tuple2<Tuple2<Integer, Integer>, Double>, Rating>() {
                    @Override
                    public Rating call(Tuple2<Tuple2<Integer, Integer>, Double> userProductScore) throws Exception {
                        Tuple2<Integer, Integer> userProduct = userProductScore._1();
                        return new Rating(userProduct._1(), userProduct._2(), userProductScore._2());
                    }
                });
    }

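    /**
     * Converts a trained model to PMML: the user (X) and item (Y) feature
     * matrices are saved under the candidate path, while the matrix
     * locations, hyperparameters, and ID lists are recorded as PMML
     * extensions.
     */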
    private PMML mfModelToPMML(MatrixFactorizationModel model, int features, double lambda, double alpha,
            boolean implicit, Path candidatePath, Map<Integer, String> reverseIdLookUp) {
        JavaPairRDD<Integer, double[]> userFeaturesRDD = massageToIntKey(model.userFeatures());
        JavaPairRDD<Integer, double[]> itemFeaturesRDD = massageToIntKey(model.productFeatures());

        saveFeaturesRDD(userFeaturesRDD, new Path(candidatePath, "X"), reverseIdLookUp);
        saveFeaturesRDD(itemFeaturesRDD, new Path(candidatePath, "Y"), reverseIdLookUp);

        PMML pmml = PMMLUtils.buildSkeletonPMML();
        AppPMMLUtils.addExtension(pmml, "X", "X/");
        AppPMMLUtils.addExtension(pmml, "Y", "Y/");
        AppPMMLUtils.addExtension(pmml, "features", features);
        AppPMMLUtils.addExtension(pmml, "lambda", alpha);
        AppPMMLUtils.addExtension(pmml, "implicit", implicit);
        if (implicit) {
            AppPMMLUtils.addExtension(pmml, "alpha", alpha);
        }

        addIDsExtentions(pmml, "XIDs", userFeaturesRDD, reverseIdLookUp);
        addIDsExtentions(pmml, "YIDs", itemFeaturesRDD, reverseIdLookUp);
        return pmml;
    }

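    // The incoming keys are MLlib's int user/item IDs, so the double cast
    // merely re-types the RDD before wrapping it as a JavaPairRDD.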
    private static <A, B> JavaPairRDD<Integer, B> massageToIntKey(RDD<Tuple2<A, B>> in) {
        JavaPairRDD<Integer, B> javaRDD = fromRDD((RDD<Tuple2<Integer, B>>) (RDD<?>) in);
        return javaRDD;
    }

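    /**
     * Releases the cached user and product feature RDDs backing a model.
     */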
    private void unpersist(MatrixFactorizationModel model) {
        model.userFeatures().unpersist(false);
        model.productFeatures().unpersist(false);
    }

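    /**
     * Inverse of mfModelToPMML: loads the X and Y feature matrices referenced
     * by the PMML extensions and rebuilds a MatrixFactorizationModel, taking
     * the rank from the length of the first user feature vector.
     */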
    private static MatrixFactorizationModel pmmlToMFModel(JavaSparkContext sparkContext, PMML model,
            Path modelParentPath) {
        String xPathString = AppPMMLUtils.getExtensionValue(model, "X");
        String yPathString = AppPMMLUtils.getExtensionValue(model, "Y");
        JavaPairRDD<String, double[]> userRDD = readFeaturesRDD(sparkContext,
                new Path(modelParentPath, xPathString));
        JavaPairRDD<String, double[]> productRDD = readFeaturesRDD(sparkContext,
                new Path(modelParentPath, yPathString));

        int rank = userRDD.first()._2.length;
        return new MatrixFactorizationModel(rank, readAndConvertFeatureRDD(userRDD),
                readAndConvertFeatureRDD(productRDD));
    }

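    /**
     * Re-keys string IDs back to ints via parseOrHashInt, persists the
     * result, and erases the key type to the shape that
     * MatrixFactorizationModel's constructor expects.
     */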
    private static RDD<Tuple2<Object, double[]>> readAndConvertFeatureRDD(JavaPairRDD<String, double[]> javaRDD) {
        RDD<Tuple2<Integer, double[]>> scalaRDD = javaRDD
                .mapToPair(new PairFunction<Tuple2<String, double[]>, Integer, double[]>() {

                    @Override
                    public Tuple2<Integer, double[]> call(Tuple2<String, double[]> t) throws Exception {
                        return new Tuple2<>(parseOrHashInt(t._1()), t._2());
                    }
                }).rdd();

        scalaRDD.persist(StorageLevel.MEMORY_AND_DISK());

        RDD<Tuple2<Object, double[]>> objKeyRDD = (RDD<Tuple2<Object, double[]>>) (RDD<?>) scalaRDD;
        return objKeyRDD;
    }

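    /**
     * Reads feature files written by saveFeaturesRDD: one JSON array per
     * line of the form ["id", [v1, v2, ...]], returned as (id, vector) pairs.
     */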
    private static JavaPairRDD<String, double[]> readFeaturesRDD(JavaSparkContext sparkContext, Path path) {
        log.info("Loading features RDD from {}", path);

        JavaRDD<String> featureLines = sparkContext.textFile(path.toString());

        return featureLines.mapToPair(new PairFunction<String, String, double[]>() {

            @Override
            public Tuple2<String, double[]> call(String t) throws Exception {
                List<?> update = MAPPER.readValue(t, List.class);
                String key = update.get(0).toString();
                double[] vector = MAPPER.convertValue(update.get(1), double[].class);
                return new Tuple2<>(key, vector);
            }
        });
    }

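    /**
     * For the user and item tokens of each input line, emits a
     * (hashedID, originalString) pair for any token that is not a plain int,
     * so hashed IDs can later be translated back. Tokens that parse as ints
     * need no mapping and are skipped.
     */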
    private static final class ToReverseLookupFn implements PairFlatMapFunction<String[], Integer, String> {

        @Override
        public Iterable<Tuple2<Integer, String>> call(String[] tokens) throws Exception {
            List<Tuple2<Integer, String>> results = new ArrayList<>(2);
            for (int i = 0; i < 2; i++) {
                String s = tokens[i];
                if (MOST_INTS_PATTERN.matcher(s).matches()) {
                    try {
                        Integer.parseInt(s);
                        continue;
                    } catch (NumberFormatException nfe) {
                        // Matched the int pattern but overflowed; fall through and add a hash mapping
                    }
                }
                results.add(new Tuple2<>(hash(s), s));
            }
            return results;
        }

    }

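    /**
     * Maps a [user, item, score, timestamp] token array to a
     * (timestamp, Rating) pair. User and item tokens are parsed as ints or
     * hashed; an empty score becomes NaN, which the aggregation step treats
     * as a removal.
     */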
    private static final class ParseRatingFn implements PairFunction<String[], Long, Rating> {

        @Override
        public Tuple2<Long, Rating> call(String[] tokens) throws Exception {
            return new Tuple2<>(Long.valueOf(tokens[3]), new Rating(parseOrHashInt(tokens[0]),
                    parseOrHashInt(tokens[1]), tokens[2].isEmpty() ? Double.NaN : Double.parseDouble(tokens[2])));
        }
    }

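    /**
     * Returns the token parsed as an int when possible, otherwise a
     * non-negative hash of it. The regex cheaply screens out most non-int
     * strings before the parse is attempted.
     */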
    private static int parseOrHashInt(String s) {
        if (MOST_INTS_PATTERN.matcher(s).matches()) {
            try {
                return Integer.parseInt(s);
            } catch (NumberFormatException e) {
                // Matched the int pattern but overflowed; fall through to hashing
            }
        }
        return hash(s);
    }

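    // MD5-hash the string and clear the sign bit so hashed IDs are
    // always non-negative ints.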
    private static int hash(String s) {
        return HASH.hashString(s).asInt() & 0x7FFFFFFF;
    }

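    // Bridges a Scala pair RDD to a JavaPairRDD using erased Object class tags.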
    private static <K, V> JavaPairRDD<K, V> fromRDD(RDD<Tuple2<K, V>> rdd) {
        return JavaPairRDD.fromRDD(rdd, ClassTag$.MODULE$.<K>apply(Object.class),
                ClassTag$.MODULE$.<V>apply(Object.class));
    }

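    /**
     * Writes each (id, vector) pair as a gzip-compressed JSON array line,
     * substituting the original string ID where the int key was a hash.
     */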
    private static void saveFeaturesRDD(JavaPairRDD<Integer, double[]> features, Path path,
            final Map<Integer, String> reverseIDMapping) {
        log.info("Saving features RDD to {}", path);
        features.map(new Function<Tuple2<Integer, double[]>, String>() {

            @Override
            public String call(Tuple2<Integer, double[]> keyAndVector) throws Exception {
                Integer id = keyAndVector._1();
                String originalKey = reverseIDMapping.get(id);
                Object key = originalKey == null ? id : originalKey;
                double[] vector = keyAndVector._2();
                return TextUtils.joinJSON(Arrays.asList(key, vector));
            }
        }).saveAsTextFile(path.toString(), GzipCodec.class);
    }

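    /**
     * Records the IDs present in a feature matrix as PMML extension content,
     * translated back to their original string form where known.
     */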
    private static void addIDsExtensions(PMML pmml, String key, JavaPairRDD<Integer, double[]> feature,
            Map<Integer, String> reverseIDMapping) {
        List<Integer> hashedIDs = feature.keys().collect();
        List<String> ids = new ArrayList<>(hashedIDs.size());

        for (Integer hashedID : hashedIDs) {
            String originalId = reverseIDMapping.get(hashedID);
            ids.add(originalId == null ? hashedID.toString() : originalId);
        }

        AppPMMLUtils.addExtensionContent(pmml, key, ids);
    }

}
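
Usage example

The sketch below shows one way this class might be instantiated; it is illustrative only. It assumes Typesafe Config's standard ConfigFactory API. The lambda.als.* key names mirror those read in the constructor above, but the values are made up, the value syntax accepted by HyperParams.fromConfig is not visible in this file (a single number per key is assumed), and the MLUpdate superclass may require additional configuration keys not shown here.

import com.anhth12.lambda.app.ml.als.ALSUpdate;
import com.typesafe.config.Config;
import com.typesafe.config.ConfigFactory;

// Hypothetical driver class, for illustration only.
public class ALSUpdateExample {

    public static void main(String[] args) {
        // Key names mirror ALSUpdate's constructor; values are illustrative.
        Config config = ConfigFactory.parseString(
                "lambda.als.iterations = 10\n"
                + "lambda.als.implicit = true\n"
                + "lambda.als.hyperparams.features = 20\n"
                + "lambda.als.hyperparams.lambda = 0.01\n"
                + "lambda.als.hyperparams.alpha = 1.0\n"
                + "lambda.als.no-known-items = false\n"
                + "lambda.als.decay.factor = 0.95\n"
                + "lambda.als.decay.zero-threshold = 0.0")
                // Fall back to application.conf for any keys the MLUpdate
                // superclass reads (an assumption; not visible in this file).
                .withFallback(ConfigFactory.load());

        ALSUpdate update = new ALSUpdate(config);
        System.out.println(update.getHyperParamValues());
    }
}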