Java tutorial
/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package com.anhth12.lambda.app.serving.als.model;

import com.anhth12.lambda.app.als.RescorerProvider;
import com.anhth12.lambda.common.collection.AndPredicate;
import com.anhth12.lambda.common.collection.KeyOnlyBiPredicate;
import com.anhth12.lambda.common.collection.NotContainsPredicate;
import com.anhth12.lambda.common.collection.Pair;
import com.anhth12.lambda.common.collection.PairComparators;
import com.anhth12.lambda.common.lang.AutoLock;
import com.anhth12.lambda.common.lang.LoggingCallable;
import com.anhth12.lambda.common.math.LinearSystemSolver;
import com.anhth12.lambda.common.math.Solver;
import com.anhth12.lambda.common.math.VectorMath;
import com.google.common.base.Preconditions;
import com.google.common.collect.Iterables;
import com.google.common.collect.Ordering;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Queue;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import net.openhft.koloboke.collect.map.ObjIntMap;
import net.openhft.koloboke.collect.map.ObjObjMap;
import net.openhft.koloboke.collect.map.hash.HashObjIntMaps;
import net.openhft.koloboke.collect.map.hash.HashObjObjMap;
import net.openhft.koloboke.collect.map.hash.HashObjObjMaps;
import net.openhft.koloboke.collect.set.ObjSet;
import net.openhft.koloboke.collect.set.hash.HashObjSet;
import net.openhft.koloboke.collect.set.hash.HashObjSets;
import net.openhft.koloboke.function.BiConsumer;
import net.openhft.koloboke.function.ObjDoubleToDoubleFunction;
import net.openhft.koloboke.function.Predicate;
import net.openhft.koloboke.function.ToDoubleFunction;
import org.apache.commons.math3.linear.RealMatrix;

/**
 * In-memory ALS model used for serving: a user-feature map {@code X}, an item-feature map
 * {@code Y} split into {@link #PARTITIONS} partitions, and the set of known items per user.
 *
 * <p>Locking as implemented here: {@code xLock} guards {@code X}, {@code recentNewUsers} and the
 * {@code knownItems} map itself; {@code yLocks[i]} guards {@code Y[i]} and
 * {@code recentNewItems[i]}; each per-user known-item set is additionally guarded by its own
 * monitor ({@code synchronized (set)}).
 *
 * @author Tong Hoang Anh
 */
public final class ALSServingModel {

    /** Number of partitions of Y; one per available processor. */
    private static final int PARTITIONS = Runtime.getRuntime().availableProcessors();

    /** Shared pool for per-partition top-N scans; {@code null} when there is a single partition. */
    private static final ExecutorService executor = PARTITIONS <= 1
            ? null
            : Executors.newFixedThreadPool(PARTITIONS,
                    new ThreadFactoryBuilder().setDaemon(true).setNameFormat("ALSServingModel-%d").build());

    //USER-FEATURE matrix
    private final ObjObjMap<String, float[]> X;
    //ITEM-FEATURE matrix
    private final ObjObjMap<String, float[]>[] Y;
    /** Users whose vector was first set since the last {@link #pruneX(Collection)}. */
    private final Collection<String> recentNewUsers;
    /** Per partition: items whose vector was first set since the last {@link #pruneY(Collection)}. */
    private final Collection<String>[] recentNewItems;
    /** User ID to the set of item IDs known for that user. */
    private final ObjObjMap<String, ObjSet<String>> knownItems;
    //Controls access to X
    private final ReadWriteLock xLock;
    //Controls access to partitions of Y
    private final ReadWriteLock[] yLocks;
    private final int features;
    private final boolean implicit;
    private final RescorerProvider rescorerProvider;

    /**
     * Creates an empty model.
     *
     * @param features number of features per vector; must be positive
     * @param implicit value later reported by {@link #isImplicit()}
     * @param rescoreProvider value later reported by {@link #getRescorerProvider()}
     * @throws IllegalArgumentException if {@code features <= 0}
     */
    // Generic arrays cannot be created directly; the reflective arrays below are only ever
    // populated with correctly-typed elements by this constructor, so the casts are safe.
    @SuppressWarnings("unchecked")
    ALSServingModel(int features, boolean implicit, RescorerProvider rescoreProvider) {
        Preconditions.checkArgument(features > 0);

        X = HashObjObjMaps.newMutableMap();
        Y = (ObjObjMap<String, float[]>[]) Array.newInstance(ObjObjMap.class, PARTITIONS);
        for (int i = 0; i < Y.length; i++) {
            Y[i] = HashObjObjMaps.newMutableMap();
        }

        recentNewUsers = new HashSet<>();
        recentNewItems = (Collection<String>[]) Array.newInstance(HashSet.class, PARTITIONS);
        for (int i = 0; i < recentNewItems.length; i++) {
            recentNewItems[i] = new HashSet<>();
        }

        knownItems = HashObjObjMaps.newMutableMap();

        xLock = new ReentrantReadWriteLock();
        yLocks = new ReadWriteLock[Y.length];
        for (int i = 0; i < yLocks.length; i++) {
            yLocks[i] = new ReentrantReadWriteLock();
        }

        this.features = features;
        this.implicit = implicit;
        this.rescorerProvider = rescoreProvider;
    }

    /** @return the number of features each user/item vector must have */
    public int getFeatures() {
        return features;
    }

    /** @return the {@code implicit} flag this model was constructed with */
    public boolean isImplicit() {
        return implicit;
    }

    /** @return the rescorer provider this model was constructed with (may be {@code null}) */
    public RescorerProvider getRescorerProvider() {
        return rescorerProvider;
    }

    /**
     * Maps an object to a partition index in {@code [0, PARTITIONS)}.
     * The sign bit of the hash code is masked off so the modulus is never negative.
     */
    private static int partition(Object o) {
        return (o.hashCode() & 0x7FFFFFFF) % PARTITIONS;
    }

    /**
     * @param user user ID
     * @return the feature vector stored for {@code user}, or {@code null} if none is stored
     */
    public float[] getUserVector(String user) {
        try (AutoLock al = new AutoLock(xLock.readLock())) {
            return X.get(user);
        }
    }

    /**
     * @param item item ID
     * @return the feature vector stored for {@code item}, or {@code null} if none is stored
     */
    public float[] getItemVector(String item) {
        int partition = partition(item);
        try (AutoLock al = new AutoLock(yLocks[partition].readLock())) {
            return Y[partition].get(item);
        }
    }

    /**
     * Stores (or replaces) the feature vector of a user. A user that had no vector before is
     * also recorded in {@code recentNewUsers}.
     *
     * @param user user ID; must not be {@code null}
     * @param vector feature vector; must not be {@code null} and must have {@link #getFeatures()} elements
     * @throws NullPointerException if {@code user} or {@code vector} is {@code null}
     * @throws IllegalArgumentException if the vector has the wrong length
     */
    void setUserVector(String user, float[] vector) {
        Preconditions.checkNotNull(user);
        Preconditions.checkNotNull(vector);
        Preconditions.checkArgument(vector.length == features);
        try (AutoLock al = new AutoLock(xLock.writeLock())) {
            if (X.put(user, vector) == null) {
                recentNewUsers.add(user);
            }
        }
    }

    /**
     * Stores (or replaces) the feature vector of an item. An item that had no vector before is
     * also recorded in the {@code recentNewItems} set of its partition.
     *
     * @param item item ID; must not be {@code null}
     * @param vector feature vector; must not be {@code null} and must have {@link #getFeatures()} elements
     * @throws NullPointerException if {@code item} or {@code vector} is {@code null}
     * @throws IllegalArgumentException if the vector has the wrong length
     */
    void setItemVector(String item, float[] vector) {
        Preconditions.checkNotNull(item);
        Preconditions.checkNotNull(vector);
        Preconditions.checkArgument(vector.length == features);
        int partition = partition(item);
        try (AutoLock al = new AutoLock(yLocks[partition].writeLock())) {
            if (Y[partition].put(item, vector) == null) {
                recentNewItems[partition].add(item);
            }
        }
    }

    /**
     * Returns the live (not copied) set of items known for a user. Other methods of this class
     * mutate that set while holding its monitor, so callers that iterate or query it should
     * {@code synchronized} on the returned collection.
     *
     * @param user user ID
     * @return the user's known-item set, or {@code null} if none exists
     */
    public Collection<String> getKnownItems(String user) {
        return doGetKnownItems(user);
    }

    /** Looks up the user's known-item set under the X read lock; {@code null} if absent. */
    private ObjSet<String> doGetKnownItems(String user) {
        try (AutoLock al = new AutoLock((xLock.readLock()))) {
            return knownItems.get(user);
        }
    }

    /**
     * @return a new map from user ID to the number of items known for that user
     */
    public Map<String, Integer> getUserCounts() {
        ObjIntMap<String> count = HashObjIntMaps.newUpdatableMap();
        try (AutoLock al = new AutoLock(xLock.readLock())) {
            for (Map.Entry<String, ObjSet<String>> entry : knownItems.entrySet()) {
                String userID = entry.getKey();
                Collection<?> ids = entry.getValue();
                int numItems;
                synchronized (ids) {
                    numItems = ids.size();
                }
                count.addValue(userID, numItems);
            }
        }
        return count;
    }

    /**
     * @return a new map from item ID to the number of users whose known-item set contains it
     */
    public Map<String, Integer> getItemCounts() {
        ObjIntMap<String> count = HashObjIntMaps.newUpdatableMap();
        try (AutoLock al = new AutoLock(xLock.readLock())) {
            for (Collection<String> ids : knownItems.values()) {
                synchronized (ids) {
                    for (String id : ids) {
                        count.addValue(id, 1);
                    }
                }
            }
        }
        return count;
    }

    /**
     * Adds items to the known-item set of a user, creating the set if the user has none yet.
     *
     * @param user user ID
     * @param items item IDs to add
     */
    void addKnownItems(String user, Collection<String> items) {
        ObjSet<String> knownItemsForUser = doGetKnownItems(user);
        // Only when the user has no set yet do we need the write lock to create one.
        // (The original tested "!= null", which skipped creation for new users and then
        // failed with a NullPointerException in the synchronized statement below.)
        if (knownItemsForUser == null) {
            try (AutoLock al = new AutoLock(xLock.writeLock())) {
                // Check again under the write lock: another thread may have created the set
                // between releasing the read lock and acquiring the write lock.
                knownItemsForUser = knownItems.get(user);
                if (knownItemsForUser == null) {
                    knownItemsForUser = HashObjSets.newMutableSet();
                    knownItems.put(user, knownItemsForUser);
                }
            }
        }
        synchronized (knownItemsForUser) {
            knownItemsForUser.addAll(items);
        }
    }

    /**
     * Collects the item vectors of all items known for a user.
     *
     * @param user user ID
     * @return (item ID, item vector) pairs for the user's known items, or {@code null} if the
     *  user has no vector or no known items. A pair's vector is {@code null} when the item is
     *  known for the user but has no vector stored in Y.
     */
    public List<Pair<String, float[]>> getKnowItemVectorsForUser(String user) {
        float[] userVector = getUserVector(user);
        if (userVector == null) {
            return null;
        }
        Collection<String> knownItems = getKnownItems(user);
        if (knownItems == null) {
            return null;
        }
        // All reads of the shared set, including isEmpty()/size(), happen under its monitor,
        // matching how the rest of this class accesses it.
        synchronized (knownItems) {
            if (knownItems.isEmpty()) {
                return null;
            }
            List<Pair<String, float[]>> idVectors = new ArrayList<>(knownItems.size());
            for (String itemID : knownItems) {
                int partition = partition(itemID);
                float[] vector;
                try (AutoLock al = new AutoLock(yLocks[partition].readLock())) {
                    vector = Y[partition].get(itemID);
                }
                idVectors.add(new Pair<>(itemID, vector));
            }
            return idVectors;
        }
    }

    /**
     * Convenience overload of
     * {@link #topN(ToDoubleFunction, ObjDoubleToDoubleFunction, int, Predicate)} with no
     * rescoring function and no filter.
     *
     * @param scoreFn computes a score from an item vector
     * @param howmany maximum number of results
     * @return up to {@code howmany} (item ID, score) pairs
     */
    public List<Pair<String, Double>> topN(ToDoubleFunction<float[]> scoreFn, int howmany) {
        return topN(scoreFn, null, howmany, null);
    }

    /**
     * Scores every item in Y and returns the best-scoring ones. One task is built per partition;
     * with two or more partitions the tasks run on the shared executor, otherwise the single
     * task runs on the calling thread. The per-partition results are then merged.
     *
     * @param scoreFn computes a score from an item vector
     * @param rescoreFn optional; if non-null, maps (item ID, score) to the final score
     * @param howMany maximum number of results
     * @param allowedPredicate optional; if non-null, only item IDs it accepts are scored
     * @return up to {@code howMany} (item ID, score) pairs, the greatest according to
     *  {@code PairComparators.bySecond()}
     * @throws IllegalStateException if a partition task fails or the calling thread is
     *  interrupted while waiting (the interrupt status is restored before throwing)
     */
    public List<Pair<String, Double>> topN(final ToDoubleFunction<float[]> scoreFn,
            final ObjDoubleToDoubleFunction<String> rescoreFn,
            final int howMany,
            final Predicate<String> allowedPredicate) {

        List<Callable<Iterable<Pair<String, Double>>>> tasks = new ArrayList<>(Y.length);
        for (int partition = 0; partition < Y.length; partition++) {
            final int thePartition = partition;
            tasks.add(new LoggingCallable<Iterable<Pair<String, Double>>>() {
                @Override
                public Iterable<Pair<String, Double>> doCall() throws Exception {
                    Queue<Pair<String, Double>> topN
                            = new PriorityQueue<>(howMany + 1, PairComparators.<Double>bySecond());
                    TopNConsumer topNPoc
                            = new TopNConsumer(topN, howMany, scoreFn, rescoreFn, allowedPredicate);
                    try (AutoLock al = new AutoLock(yLocks[thePartition].readLock())) {
                        Y[thePartition].forEach(topNPoc);
                    }
                    return topN;
                }
            });
        }

        List<Iterable<Pair<String, Double>>> iterables = new ArrayList<>();
        if (Y.length >= 2) {
            try {
                for (Future<Iterable<Pair<String, Double>>> future : executor.invokeAll(tasks)) {
                    iterables.add(future.get());
                }
            } catch (InterruptedException e) {
                // Restore the interrupt status so code further up the stack can observe it.
                Thread.currentThread().interrupt();
                throw new IllegalStateException(e);
            } catch (ExecutionException e) {
                throw new IllegalStateException(e);
            }
        } else {
            try {
                iterables.add(tasks.get(0).call());
            } catch (Exception e) {
                throw new IllegalStateException(e);
            }
        }

        return Ordering.from(PairComparators.<Double>bySecond())
                .greatestOf(Iterables.concat(iterables), howMany);
    }

    /** @return a snapshot copy of all user IDs that currently have a vector in X */
    public Collection<String> getAllUserIDs() {
        Collection<String> userList;
        try (AutoLock al = new AutoLock(xLock.readLock())) {
            userList = new ArrayList<>(X.keySet());
        }
        return userList;
    }

    /**
     * @return a snapshot copy of all item IDs that currently have a vector in Y; each partition
     *  is copied under its own read lock, one after another
     */
    public Collection<String> getAllItemIDs() {
        Collection<String> itemsList = new ArrayList<>();
        for (int partition = 0; partition < Y.length; partition++) {
            try (AutoLock al = new AutoLock(yLocks[partition].readLock())) {
                itemsList.addAll(Y[partition].keySet());
            }
        }
        return itemsList;
    }

    /**
     * Sums {@code VectorMath.transposeTimesSelf} over the partitions of Y (skipping partitions
     * for which it returns {@code null}) and builds a solver from the sum.
     *
     * @return the solver produced by {@code LinearSystemSolver.getSolver} for the summed matrix.
     *  NOTE(review): if every partition yields {@code null}, {@code getSolver} is called with
     *  {@code null}; its behavior in that case is not visible here.
     */
    public Solver getYTYSolver() {
        RealMatrix YTY = null;
        for (int partition = 0; partition < Y.length; partition++) {
            RealMatrix YTYpartial;
            try (AutoLock al = new AutoLock(yLocks[partition].readLock())) {
                YTYpartial = VectorMath.transposeTimesSelf(Y[partition].values());
            }
            if (YTYpartial != null) {
                YTY = YTY == null ? YTYpartial : YTY.add(YTYpartial);
            }
        }
        return new LinearSystemSolver().getSolver(YTY);
    }

    /**
     * Prunes X and then clears {@code recentNewUsers}.
     * Presumably removes every user that is in neither {@code users} nor {@code recentNewUsers}
     * (depends on the project predicate classes, which are not visible here).
     *
     * @param users user IDs to retain
     */
    void pruneX(Collection<String> users) {
        try (AutoLock al = new AutoLock(xLock.writeLock())) {
            X.removeIf(new KeyOnlyBiPredicate<>(new AndPredicate<>(
                    new NotContainsPredicate<>(users),
                    new NotContainsPredicate<>(recentNewUsers))));
            recentNewUsers.clear();
        }
    }

    /**
     * Prunes each partition of Y and then clears that partition's {@code recentNewItems}.
     * Presumably removes every item that is in neither {@code items} nor the partition's
     * {@code recentNewItems} (depends on the project predicate classes, not visible here).
     *
     * @param items item IDs to retain
     */
    void pruneY(Collection<String> items) {
        for (int partition = 0; partition < Y.length; partition++) {
            try (AutoLock al = new AutoLock(yLocks[partition].writeLock())) {
                Y[partition].removeIf(new KeyOnlyBiPredicate<>(new AndPredicate<>(
                        new NotContainsPredicate<>(items),
                        new NotContainsPredicate<>(recentNewItems[partition]))));
                recentNewItems[partition].clear();
            }
        }
    }

    /**
     * Prunes the known-items data in two steps: first whole users are dropped from the map
     * (same predicate shape as {@link #pruneX(Collection)}, but without clearing
     * {@code recentNewUsers}); then, in every remaining user's set, each item that is in neither
     * {@code items} nor any partition's {@code recentNewItems} is removed.
     *
     * @param users user IDs to retain
     * @param items item IDs to retain
     */
    void pruneKnownItems(Collection<String> users, final Collection<String> items) {
        try (AutoLock al = new AutoLock(xLock.writeLock())) {
            knownItems.removeIf(new KeyOnlyBiPredicate<>(new AndPredicate<>(
                    new NotContainsPredicate<>(users),
                    new NotContainsPredicate<>(recentNewUsers))));
        }

        // Copy all recent items into one set up front so the per-user pass below does not
        // have to touch the Y partition locks.
        final Collection<String> allRecentKnownItems = new HashSet<>();
        for (int partition = 0; partition < Y.length; partition++) {
            try (AutoLock al = new AutoLock(yLocks[partition].writeLock())) {
                allRecentKnownItems.addAll(recentNewItems[partition]);
            }
        }

        try (AutoLock al = new AutoLock(xLock.readLock())) {
            for (ObjSet<String> knownItemsForUser : knownItems.values()) {
                synchronized (knownItemsForUser) {
                    knownItemsForUser.removeIf(new Predicate<String>() {
                        @Override
                        public boolean test(String value) {
                            return !items.contains(value) && !allRecentKnownItems.contains(value);
                        }
                    });
                }
            }
        }
    }

    /**
     * Summarizes the model: feature count, implicit flag and the current number of users and
     * items. Sizes are read without taking the locks.
     */
    @Override
    public String toString() {
        int numItems = 0;
        for (Map<?, ?> partition : Y) {
            numItems += partition.size();
        }
        return "ALSServingModel[features:" + features + ", implicit:" + implicit
                + ",X(" + X.size() + " users), Y:(" + numItems + "items)]";
    }

    /**
     * Visitor applied to every (item ID, vector) entry of one Y partition. It keeps roughly the
     * {@code howMany} best-scoring entries in the supplied queue.
     *
     * <p>NOTE(review): this relies on the queue's head being its lowest-scoring element, i.e.
     * on {@code PairComparators.bySecond()} ordering ascending by score - confirm.
     */
    private static final class TopNConsumer implements BiConsumer<String, float[]> {

        private final Queue<Pair<String, Double>> topN;
        private final int howMany;
        private final ToDoubleFunction<float[]> scoreFn;
        private final ObjDoubleToDoubleFunction<String> rescoreFn;
        private final Predicate<String> allowedPredicate;
        /** Cached lower bound on the score needed to enter the queue; avoids peeking for low scores. */
        private double topScoreLowerBound;
        /** Becomes true once the queue has reached {@code howMany} elements. */
        private boolean full;

        /**
         * @param topN queue receiving the results
         * @param howMany queue size at which elements start being replaced instead of added
         * @param scoreFn computes a score from an item vector
         * @param rescoreFn optional (item ID, score) to score adjustment; may be {@code null}
         * @param allowedPredicate optional item ID filter; may be {@code null}
         */
        public TopNConsumer(Queue<Pair<String, Double>> topN,
                int howMany,
                ToDoubleFunction<float[]> scoreFn,
                ObjDoubleToDoubleFunction<String> rescoreFn,
                Predicate<String> allowedPredicate) {
            this.topN = topN;
            this.howMany = howMany;
            this.scoreFn = scoreFn;
            this.rescoreFn = rescoreFn;
            this.allowedPredicate = allowedPredicate;
            topScoreLowerBound = Double.NEGATIVE_INFINITY;
            full = false;
        }

        /**
         * Scores one item and offers it to the queue.
         *
         * @param key item ID
         * @param value item feature vector
         */
        @Override
        public void accept(String key, float[] value) {
            if (allowedPredicate == null || allowedPredicate.test(key)) {
                double score = scoreFn.applyAsDouble(value);
                if (rescoreFn != null) {
                    score = rescoreFn.applyAsDouble(key, score);
                }
                if (full) {
                    // Cheap pre-check against the cached bound before touching the queue.
                    if (score > topScoreLowerBound) {
                        double peek;
                        synchronized (topN) {
                            peek = topN.peek().getSecond();
                            if (score > peek) {
                                // Replace the current head with the better-scoring entry.
                                topN.poll();
                                topN.add(new Pair<>(key, score));
                            }
                        }
                        if (peek > topScoreLowerBound) {
                            topScoreLowerBound = peek;
                        }
                    }
                } else {
                    int newSize;
                    synchronized (topN) {
                        topN.add(new Pair<>(key, score));
                        newSize = topN.size();
                    }
                    if (newSize >= howMany) {
                        full = true;
                    }
                }
            }
        }
    }
}