com.anhth12.lambda.app.speed.ALSSpeedModelManager.java Source code

Java tutorial

Introduction

Here is the source code for com.anhth12.lambda.app.speed.ALSSpeedModelManager.java

Source

/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package com.anhth12.lambda.app.speed;

import com.anhth12.lambda.KeyMessage;
import com.anhth12.lambda.app.common.fn.MLFunctions;
import com.anhth12.lambda.common.math.Solver;
import com.anhth12.lambda.common.math.VectorMath;
import com.anhth12.lambda.common.pmml.PMMLUtils;
import com.anhth12.lambda.common.text.TextUtils;
import com.anhth12.lambda.fn.Functions;
import com.anhth12.lambda.speed.SpeedModelManager;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.typesafe.config.Config;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import javax.xml.bind.JAXBException;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFunction;
import org.dmg.pmml.PMML;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Tuple2;

/**
 * Speed-layer model manager for an ALS (alternating least squares) recommender.
 *
 * <p>It maintains an in-memory {@link ALSSpeedModel} from the stream of model
 * messages passed to {@link #consume(Iterator)}, and turns batches of new
 * user/item/strength input into incremental user-vector ("X") and item-vector
 * ("Y") update messages in {@link #buildUpdates(JavaPairRDD)}.</p>
 *
 * @author Tong Hoang Anh
 */
public class ALSSpeedModelManager implements SpeedModelManager<String, String, String> {

    private static final Logger log = LoggerFactory.getLogger(ALSSpeedModelManager.class);
    private static final ObjectMapper MAPPER = new ObjectMapper();

    /** Current model; {@code null} until the first "MODEL" message has been consumed. */
    private ALSSpeedModel model;
    /** Whether input strengths are treated as implicit feedback (summed) rather than explicit ratings. */
    private final boolean implicit;
    /** When true, update messages omit the trailing list holding the "other" ID. */
    private final boolean noKnownItems;

    /**
     * @param config configuration supplying {@code lambda.als.implicit} and
     *               {@code lambda.als.no-known-items}
     */
    public ALSSpeedModelManager(Config config) {
        implicit = config.getBoolean("lambda.als.implicit");
        noKnownItems = config.getBoolean("lambda.als.no-known-items");
    }

    /**
     * Applies a stream of keyed messages to the in-memory model.
     *
     * <p>Key {@code "UP"} carries a JSON array {@code [matrix, id, vector, ...]}
     * where matrix is {@code "X"} (user) or {@code "Y"} (item); it is ignored
     * while no model exists. Key {@code "MODEL"} carries a PMML document and
     * creates or replaces the model.</p>
     *
     * @param updateIterator messages to process, in order
     * @throws IOException           if a message cannot be parsed
     * @throws IllegalStateException on an unknown key or an unknown matrix name
     */
    @Override
    public void consume(Iterator<KeyMessage<String, String>> updateIterator) throws IOException {
        while (updateIterator.hasNext()) {
            KeyMessage<String, String> km = updateIterator.next();
            String key = km.getKey();
            String message = km.getMessage();

            switch (key) {
            case "UP":
                applyVectorUpdate(message);
                break;
            case "MODEL":
                applyModel(message);
                break;
            default:
                throw new IllegalStateException("Unexpected key " + key);
            }
        }
    }

    /** Applies one "UP" message (a single user or item vector) to the current model, if any. */
    private void applyVectorUpdate(String message) throws IOException {
        if (model == null) {
            return; // no model to interpret the update against yet, so skip it
        }
        List<?> update = MAPPER.readValue(message, List.class);
        String id = update.get(1).toString();
        float[] vector = MAPPER.convertValue(update.get(2), float[].class);
        switch (update.get(0).toString()) {
        case "X":
            model.setUserVector(id, vector);
            break;
        case "Y":
            model.setItemVector(id, vector);
            break;
        default:
            throw new IllegalStateException("Bad update " + message);
        }
    }

    /** Handles one "MODEL" message: validates the PMML and creates/replaces the model as needed. */
    private void applyModel(String message) throws IOException {
        // Parsing also serves as validation: malformed PMML surfaces as an IOException.
        PMML pmml;
        try {
            pmml = PMMLUtils.fromString(message);
        } catch (JAXBException e) {
            throw new IOException(e);
        }
        // NOTE(review): the parsed PMML is never consulted and the feature count is
        // hard-coded to 0. Presumably it should be read from the PMML document;
        // confirm against the batch layer's model format and fix.
        int features = 0;
        if (model == null) {
            log.info("No previous model; create new model");
            model = new ALSSpeedModel(features);
        } else if (features != model.getFeatures()) {
            log.warn("Feature count changed from {} to {}; replacing current model",
                    model.getFeatures(), features);
            model = new ALSSpeedModel(features);
        } else {
            log.info("Updating current model");
        }
    }

    /**
     * Computes incremental vector updates for a batch of new input.
     *
     * <p>Input lines are ordered by timestamp, keyed by (user, item), and
     * aggregated: summed in implicit mode, last value wins otherwise. Pairs
     * whose aggregated value is NaN are dropped. For each remaining pair a new
     * user vector and a new item vector are solved for and serialized.</p>
     *
     * @param newData new input; only the values are used
     * @return JSON update messages; empty if there is no model yet or the
     *         model's solvers cannot be obtained
     * @throws IOException declared by the interface
     */
    @Override
    public Iterable<String> buildUpdates(JavaPairRDD<String, String> newData) throws IOException {
        if (model == null) {
            return Collections.emptyList();
        }

        JavaRDD<String> sortedValues = newData.values().sortBy(MLFunctions.TO_TIMESTAMP_FN, true,
                newData.partitions().size());
        JavaPairRDD<Tuple2<String, String>, Double> tuples = sortedValues.mapToPair(TO_TUPLE_FN);

        JavaPairRDD<Tuple2<String, String>, Double> aggregated;
        if (implicit) {
            aggregated = tuples.groupByKey().mapValues(MLFunctions.SUM_WITH_NAN);
        } else {
            aggregated = tuples.foldByKey(Double.NaN, Functions.<Double>last());
        }

        Collection<UserItemStrength> input = aggregated
                .filter(MLFunctions.<Tuple2<String, String>>notNaNValue())
                .map(TO_UIS_FN)
                .collect();

        Solver xtxSolver;
        Solver ytySolver;
        try {
            xtxSolver = model.getXTXSolver();
            ytySolver = model.getYTYSolver();
        } catch (Exception e) {
            // Best effort: without solvers no update can be computed for this batch.
            // Report why instead of failing silently.
            log.warn("Could not obtain solvers from current model; producing no updates", e);
            return Collections.emptyList();
        }

        List<String> result = new ArrayList<>();
        for (UserItemStrength uis : input) {
            String user = uis.getUser();
            String item = uis.getItem();
            double value = uis.getStrength();

            float[] userVector = model.getUserVector(user);
            // The item vector must be looked up by item, not taken from the user again.
            float[] itemVector = model.getItemVector(item);

            double[] newUserVector = newVector(ytySolver, value, userVector, itemVector);
            // Same computation with the roles of user and item swapped.
            double[] newItemVector = newVector(xtxSolver, value, itemVector, userVector);

            if (newUserVector != null) {
                result.add(toUpdateJSON("X", user, newUserVector, item));
            }
            if (newItemVector != null) {
                result.add(toUpdateJSON("Y", item, newItemVector, user));
            }
        }

        return result;
    }

    /**
     * Solves for a replacement of vector {@code Xu} given the opposite-side
     * vector {@code Yi} and a new interaction strength.
     *
     * @param solver solver used to turn the scaled {@code Yi} into the new vector
     * @param value  new interaction strength
     * @param Xu     current vector being updated; may be {@code null} if unknown
     * @param Yi     current vector on the other side; may be {@code null} if unknown
     * @return the new vector, or {@code null} if {@code Yi} is unknown or the
     *         target value indicates no change
     */
    private double[] newVector(Solver solver, double value, float[] Xu, float[] Yi) {
        if (Yi == null) {
            return null;
        }

        // Qui = Xu * (Yi)^t is the current estimate of the interaction;
        // 0.5 stands in when Xu is not known yet.
        double currentValue = Xu == null ? 0.5 : VectorMath.dot(Xu, Yi);
        double targetQui = computeTargetQui(value, currentValue);
        if (Double.isNaN(targetQui)) {
            return null; // no change requested
        }

        // Scale Yi by the target value; the solver must be given this scaled
        // vector, otherwise the target has no influence on the result.
        float[] scaledYi = Yi.clone();
        for (int i = 0; i < scaledYi.length; i++) {
            scaledYi[i] *= targetQui;
        }
        return solver.solveFToD(scaledYi);
    }

    /**
     * Serializes one vector update as JSON.
     *
     * @param matrix  {@code "X"} for a user vector, {@code "Y"} for an item vector
     * @param id      ID of the user or item the vector belongs to
     * @param vector  the new vector
     * @param otherId ID on the other side of the interaction; appended as a
     *                single-element list unless {@code noKnownItems} is set
     * @return the JSON text produced by {@code TextUtils.joinJSON}
     */
    private String toUpdateJSON(String matrix, String id, double[] vector, String otherId) {
        List<?> args;
        if (noKnownItems) {
            args = Arrays.asList(matrix, id, vector);
        } else {
            args = Arrays.asList(matrix, id, vector, Collections.singletonList(otherId));
        }
        return TextUtils.joinJSON(args);
    }

    /**
     * Determines the value the user-item estimate should move to.
     *
     * @param value        new interaction strength
     * @param currentValue current estimate of the interaction
     * @return in explicit mode, {@code value} itself; in implicit mode, a value
     *         between the current estimate and 1 (positive strength) or 0
     *         (negative strength), or {@code NaN} when no change is warranted
     */
    private double computeTargetQui(double value, double currentValue) {
        if (!implicit) {
            return value;
        }
        // Target is really 1, or 0, depending on whether value is positive or negative.
        // This wouldn't account for the strength though. Instead the target is a function
        // of the current value and strength. If the current value is c, and value is positive
        // then the target is somewhere between c and 1 depending on the strength. If current
        // value is already >= 1, there's no effect. Similarly for negative values.
        if (value > 0.0 && currentValue < 1.0) {
            double diff = 1.0 - Math.max(0.0, currentValue);
            return currentValue + (1.0 - 1.0 / (1.0 + value)) * diff;
        }
        if (value < 0.0 && currentValue > 0.0) {
            double diff = -Math.min(1.0, currentValue);
            return currentValue + (1.0 - 1.0 / (1.0 - value)) * diff;
        }
        // No change
        return Double.NaN;
    }

    /** No resources are held, so there is nothing to release. */
    @Override
    public void close() {
        //do nothing
    }

    /**
     * Parses an input line into a ((user, item), strength) pair, using the first
     * three tokens produced by {@code MLFunctions.PARSE_FN}.
     */
    private static final PairFunction<String, Tuple2<String, String>, Double> TO_TUPLE_FN = new PairFunction<String, Tuple2<String, String>, Double>() {

        @Override
        public Tuple2<Tuple2<String, String>, Double> call(String line) throws Exception {
            String[] tokens = MLFunctions.PARSE_FN.call(line);
            String user = tokens[0];
            String item = tokens[1];
            Double strength = Double.valueOf(tokens[2]);

            return new Tuple2<>(new Tuple2<>(user, item), strength);
        }
    };

    /** Converts an aggregated ((user, item), strength) pair into a {@code UserItemStrength}. */
    private static final Function<Tuple2<Tuple2<String, String>, Double>, UserItemStrength> TO_UIS_FN = new Function<Tuple2<Tuple2<String, String>, Double>, UserItemStrength>() {

        @Override
        public UserItemStrength call(Tuple2<Tuple2<String, String>, Double> tuple) throws Exception {
            Tuple2<String, String> userItem = tuple._1();
            return new UserItemStrength(userItem._1(), userItem._2(), tuple._2().floatValue());
        }
    };

}