Java tutorial
/* * To change this license header, choose License Headers in Project Properties. * To change this template file, choose Tools | Templates * and open the template in the editor. */ package com.anhth12.lambda.app.speed; import com.anhth12.lambda.KeyMessage; import com.anhth12.lambda.app.common.fn.MLFunctions; import com.anhth12.lambda.common.math.Solver; import com.anhth12.lambda.common.math.VectorMath; import com.anhth12.lambda.common.pmml.PMMLUtils; import com.anhth12.lambda.common.text.TextUtils; import com.anhth12.lambda.fn.Functions; import com.anhth12.lambda.speed.SpeedModelManager; import com.fasterxml.jackson.databind.ObjectMapper; import com.typesafe.config.Config; import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.Iterator; import java.util.List; import javax.xml.bind.JAXBException; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.function.PairFunction; import org.dmg.pmml.PMML; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import scala.Tuple2; /** * * @author Tong Hoang Anh */ public class ALSSpeedModelManager implements SpeedModelManager<String, String, String> { private static final Logger log = LoggerFactory.getLogger(ALSSpeedModelManager.class); private static final ObjectMapper MAPPER = new ObjectMapper(); private ALSSpeedModel model; private final boolean implicit; private final boolean noKnownItems; public ALSSpeedModelManager(Config config) { implicit = config.getBoolean("lambda.als.implicit"); noKnownItems = config.getBoolean("lambda.als.no-known-items"); } @Override public void consume(Iterator<KeyMessage<String, String>> updateIterator) throws IOException { while (updateIterator.hasNext()) { KeyMessage<String, String> km = updateIterator.next(); String key = km.getKey(); String message = km.getMessage(); switch (key) { case "UP": if (model == null) { continue; //no model to interpret yet, so skip it } List<?> update = MAPPER.readValue(message, List.class); String id = update.get(1).toString(); float[] vector = MAPPER.convertValue(update.get(2), float[].class); switch (update.get(0).toString()) { case "X": model.setUserVector(id, vector); break; case "Y": model.setItemVector(id, vector); break; default: throw new IllegalStateException("Bad update " + message); } break; case "MODEL": PMML pmml; try { pmml = PMMLUtils.fromString(message); } catch (JAXBException e) { throw new IOException(e); } int features = 0; if (model == null) { log.info("No previous model; create new model"); model = new ALSSpeedModel(features); } else if (features != model.getFeatures()) { model = new ALSSpeedModel(features); } else { log.info("Updating current model"); } break; default: throw new IllegalStateException("Unexpected key " + key); } } } @Override public Iterable<String> buildUpdates(JavaPairRDD<String, String> newData) throws IOException { if (model == null) { return Collections.emptyList(); } JavaRDD<String> sortedValues = newData.values().sortBy(MLFunctions.TO_TIMESTAMP_FN, true, newData.partitions().size()); JavaPairRDD<Tuple2<String, String>, Double> tuples = sortedValues.mapToPair(TO_TUPLE_FN); JavaPairRDD<Tuple2<String, String>, Double> aggregated; if (implicit) { aggregated = tuples.groupByKey().mapValues(MLFunctions.SUM_WITH_NAN); } else { aggregated = tuples.foldByKey(Double.NaN, Functions.<Double>last()); } Collection<UserItemStrength> input = aggregated.filter(MLFunctions.<Tuple2<String, String>>notNaNValue()) .map(TO_UIS_FN).collect(); Solver XTXsolver; Solver YTYsolver; Collection<String> result = new ArrayList<>(); try { XTXsolver = model.getXTXSolver(); YTYsolver = model.getYTYSolver(); } catch (Exception e) { return Collections.emptyList(); } for (UserItemStrength uis : input) { String user = uis.getUser(); String item = uis.getItem(); double value = uis.getStrength(); float[] Xu = model.getUserVector(user); float[] Yi = model.getUserVector(user); double[] newXu = newVector(YTYsolver, value, Xu, Yi); double[] newYi = newVector(XTXsolver, value, Yi, Xu); if (newXu != null) { result.add(toUpdateJSON("X", user, newXu, item)); } if (newYi != null) { result.add(toUpdateJSON("Y", item, newYi, user)); } } return result; } private double[] newVector(Solver solver, double value, float[] Xu, float[] Yi) { double[] newXu = null; if (Yi != null) { //Qui = Xu*(Yi)^t double currentValue = Xu == null ? 0.5 : VectorMath.dot(Xu, Yi); double targetQui = computeTargetQui(value, currentValue); if (!Double.isNaN(targetQui)) { float[] QuiYi = Yi.clone(); for (int i = 0; i < QuiYi.length; i++) { QuiYi[i] *= targetQui; } newXu = solver.solveFToD(Yi); } } return newXu; } private String toUpdateJSON(String matrix, String ID, double[] vector, String otherID) { List<?> args; if (noKnownItems) { args = Arrays.asList(matrix, ID, vector); } else { args = Arrays.asList(matrix, ID, vector, Collections.singletonList(otherID)); } return TextUtils.joinJSON(args); } private double computeTargetQui(double value, double currentValue) { if (implicit) { // Target is really 1, or 0, depending on whether value is positive or negative. // This wouldn't account for the strength though. Instead the target is a function // of the current value and strength. If the current value is c, and value is positive // then the target is somewhere between c and 1 depending on the strength. If current // value is already >= 1, there's no effect. Similarly for negative values. if (value > 0.0f && currentValue < 1.0) { double diff = 1.0 - Math.max(0.0, currentValue); return currentValue + (1.0 - 1.0 / (1.0 + value)) * diff; } if (value < 0.0f && currentValue > 0.0) { double diff = -Math.min(1.0, currentValue); return currentValue + (1.0 - 1.0 / (1.0 - value)) * diff; } // No change return Double.NaN; } else { return value; } } @Override public void close() { //do nothing } private static final PairFunction<String, Tuple2<String, String>, Double> TO_TUPLE_FN = new PairFunction<String, Tuple2<String, String>, Double>() { @Override public Tuple2<Tuple2<String, String>, Double> call(String line) throws Exception { String[] tokens = MLFunctions.PARSE_FN.call(line); String user = tokens[0]; String item = tokens[1]; Double strength = Double.valueOf(tokens[2]); return new Tuple2<>(new Tuple2<>(user, item), strength); } }; private static final Function<Tuple2<Tuple2<String, String>, Double>, UserItemStrength> TO_UIS_FN = new Function<Tuple2<Tuple2<String, String>, Double>, UserItemStrength>() { @Override public UserItemStrength call(Tuple2<Tuple2<String, String>, Double> tuple) throws Exception { return new UserItemStrength(tuple._1()._1(), tuple._1()._2(), tuple._2().floatValue()); } }; }