Java tutorial
/* * Copyright (c) 2014, Cloudera, Inc. and Intel Corp. All Rights Reserved. * * Cloudera, Inc. licenses this file to you under the Apache License, * Version 2.0 (the "License"). You may not use this file except in * compliance with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR * CONDITIONS OF ANY KIND, either express or implied. See the License for * the specific language governing permissions and limitations under the * License. */ package com.cloudera.oryx.app.serving.als.model; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.Callable; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Predicate; import java.util.stream.Stream; import com.google.common.base.Preconditions; import com.google.common.util.concurrent.ThreadFactoryBuilder; import net.openhft.koloboke.collect.map.ObjIntMap; import net.openhft.koloboke.collect.map.ObjObjMap; import net.openhft.koloboke.collect.map.hash.HashObjIntMaps; import net.openhft.koloboke.collect.map.hash.HashObjObjMaps; import net.openhft.koloboke.collect.set.ObjSet; import net.openhft.koloboke.collect.set.hash.HashObjSets; import net.openhft.koloboke.function.ObjDoubleToDoubleFunction; import org.apache.commons.math3.linear.RealMatrix; import com.cloudera.oryx.api.serving.ServingModel; import com.cloudera.oryx.app.als.FeatureVectors; import com.cloudera.oryx.app.als.RescorerProvider; import com.cloudera.oryx.app.serving.als.CosineDistanceSensitiveFunction; import com.cloudera.oryx.common.collection.Pair; import com.cloudera.oryx.common.collection.Pairs; import com.cloudera.oryx.common.lang.AutoLock; import com.cloudera.oryx.common.lang.AutoReadWriteLock; import com.cloudera.oryx.common.lang.LoggingCallable; import com.cloudera.oryx.common.math.LinearSystemSolver; import com.cloudera.oryx.common.math.Solver; /** * Contains all data structures needed to serve real-time requests for an ALS-based recommender. */ public final class ALSServingModel implements ServingModel { /** Number of partitions for items data structures. */ private static final ExecutorService executor = Executors.newFixedThreadPool( Runtime.getRuntime().availableProcessors(), new ThreadFactoryBuilder().setDaemon(true).setNameFormat("ALSServingModel-%d").build()); private final LocalitySensitiveHash lsh; /** User-feature matrix. */ private final FeatureVectors X; /** Item-feature matrix. This is partitioned into several maps for parallel access. */ private final FeatureVectors[] Y; /** Maps item IDs to their existing partition, if any */ private final ObjIntMap<String> yPartitionMap; /** Controls access to yPartitionMap. */ private final AutoReadWriteLock yPartitionMapLock; /** Remembers items that each user has interacted with*/ private final ObjObjMap<String, ObjSet<String>> knownItems; // Right now no corresponding "knownUsers" object private final AutoReadWriteLock knownItemsLock; private final ObjSet<String> expectedUserIDs; private final AutoReadWriteLock expectedUserIDsLock; private final ObjSet<String> expectedItemIDs; private final AutoReadWriteLock expectedItemIDsLock; private final AtomicReference<Solver> cachedYTYSolver; /** Number of features used in the model. */ private final int features; /** Whether model uses implicit feedback. */ private final boolean implicit; private final RescorerProvider rescorerProvider; /** * Creates an empty model. * * @param features number of features expected for user/item feature vectors * @param implicit whether model implements implicit feedback * @param sampleRate consider only approximately this fraction of all items when making recommendations. * Candidates are chosen intelligently with locality sensitive hashing. * @param rescorerProvider optional instance of a {@link RescorerProvider} */ ALSServingModel(int features, boolean implicit, double sampleRate, RescorerProvider rescorerProvider) { Preconditions.checkArgument(features > 0); Preconditions.checkArgument(sampleRate > 0.0 && sampleRate <= 1.0); lsh = new LocalitySensitiveHash(sampleRate, features); X = new FeatureVectors(); Y = new FeatureVectors[lsh.getNumPartitions()]; for (int i = 0; i < Y.length; i++) { Y[i] = new FeatureVectors(); } yPartitionMap = HashObjIntMaps.newMutableMap(); yPartitionMapLock = new AutoReadWriteLock(); knownItems = HashObjObjMaps.newMutableMap(); knownItemsLock = new AutoReadWriteLock(); expectedUserIDs = HashObjSets.newMutableSet(); expectedUserIDsLock = new AutoReadWriteLock(); expectedItemIDs = HashObjSets.newMutableSet(); expectedItemIDsLock = new AutoReadWriteLock(); cachedYTYSolver = new AtomicReference<>(); this.features = features; this.implicit = implicit; this.rescorerProvider = rescorerProvider; } public int getFeatures() { return features; } public boolean isImplicit() { return implicit; } public RescorerProvider getRescorerProvider() { return rescorerProvider; } public float[] getUserVector(String user) { return X.getVector(user); } public float[] getItemVector(String item) { int partition; try (AutoLock al = yPartitionMapLock.autoReadLock()) { partition = yPartitionMap.getOrDefault(item, Integer.MIN_VALUE); } if (partition < 0) { return null; } return Y[partition].getVector(item); } void setUserVector(String user, float[] vector) { Preconditions.checkArgument(vector.length == features); X.setVector(user, vector); try (AutoLock al = expectedUserIDsLock.autoWriteLock()) { expectedUserIDs.remove(user); } } void setItemVector(String item, float[] vector) { Preconditions.checkArgument(vector.length == features); int newPartition = lsh.getIndexFor(vector); // Exclusive update to mapping -- careful since other locks are acquired inside here try (AutoLock al = yPartitionMapLock.autoWriteLock()) { int existingPartition = yPartitionMap.getOrDefault(item, Integer.MIN_VALUE); if (existingPartition >= 0 && existingPartition != newPartition) { // Move from one to the other partition, so first remove old entry Y[existingPartition].removeVector(item); // Note that it's conceivable that a recommendation call sees *no* copy of this // item here in this brief window } // Then regardless put in new partition Y[newPartition].setVector(item, vector); yPartitionMap.put(item, newPartition); } try (AutoLock al = expectedItemIDsLock.autoWriteLock()) { expectedItemIDs.remove(item); } // Not clear if it's too inefficient to clear and recompute YtY solver every time any bit // of Y changes, but it's the most correct cachedYTYSolver.set(null); } /** * @param user user to get known items for * @return set of known items for the user (immutable, but thread-safe) */ public Set<String> getKnownItems(String user) { ObjSet<String> knownItems = doGetKnownItems(user); if (knownItems == null) { return Collections.emptySet(); } synchronized (knownItems) { if (knownItems.isEmpty()) { return Collections.emptySet(); } // Must copy since the original object is synchronized return HashObjSets.newImmutableSet(knownItems); } } private ObjSet<String> doGetKnownItems(String user) { try (AutoLock al = knownItemsLock.autoReadLock()) { return knownItems.get(user); } } /** * @return mapping of user IDs to count of items the user has interacted with */ public Map<String, Integer> getUserCounts() { ObjIntMap<String> counts = HashObjIntMaps.newUpdatableMap(); try (AutoLock al = knownItemsLock.autoReadLock()) { knownItems.forEach((userID, ids) -> { int numItems; synchronized (ids) { numItems = ids.size(); } counts.addValue(userID, numItems); }); } return counts; } /** * @return mapping of item IDs to count of users that have interacted with that item */ public Map<String, Integer> getItemCounts() { ObjIntMap<String> counts = HashObjIntMaps.newUpdatableMap(); try (AutoLock al = knownItemsLock.autoReadLock()) { knownItems.values().forEach(ids -> { synchronized (ids) { ids.forEach(id -> counts.addValue(id, 1)); } }); } return counts; } void addKnownItems(String user, Collection<String> items) { ObjSet<String> knownItemsForUser = doGetKnownItems(user); if (knownItemsForUser == null) { try (AutoLock al = knownItemsLock.autoWriteLock()) { // Check again knownItemsForUser = knownItems.get(user); if (knownItemsForUser == null) { knownItemsForUser = HashObjSets.newMutableSet(); knownItems.put(user, knownItemsForUser); } } } synchronized (knownItemsForUser) { knownItemsForUser.addAll(items); } } /** * @param user user to get known item vectors for * @return {@code null} if the user is not known to the model, or if there are no known items for the user */ public List<Pair<String, float[]>> getKnownItemVectorsForUser(String user) { float[] userVector = getUserVector(user); if (userVector == null) { return null; } Collection<String> knownItems = doGetKnownItems(user); if (knownItems == null) { return null; } synchronized (knownItems) { int size = knownItems.size(); if (size == 0) { return null; } List<Pair<String, float[]>> idVectors = new ArrayList<>(size); for (String itemID : knownItems) { float[] vector = getItemVector(itemID); if (vector != null) { idVectors.add(new Pair<>(itemID, vector)); } } return idVectors.isEmpty() ? null : idVectors; } } public Stream<Pair<String, Double>> topN(CosineDistanceSensitiveFunction scoreFn, ObjDoubleToDoubleFunction<String> rescoreFn, int howMany, Predicate<String> allowedPredicate) { int[] candidateIndices = lsh.getCandidateIndices(scoreFn.getTargetVector()); List<Callable<Stream<Pair<String, Double>>>> tasks = new ArrayList<>(candidateIndices.length); for (int partition : candidateIndices) { if (Y[partition].size() > 0) { tasks.add(LoggingCallable.log(() -> { TopNConsumer consumer = new TopNConsumer(howMany, scoreFn, rescoreFn, allowedPredicate); Y[partition].forEach(consumer); return consumer.getTopN(); })); } } int numTasks = tasks.size(); if (numTasks == 0) { return Stream.empty(); } Stream<Pair<String, Double>> stream; if (numTasks == 1) { try { stream = tasks.get(0).call(); } catch (Exception e) { throw new IllegalStateException(e); } } else { try { stream = executor.invokeAll(tasks).stream().map(future -> { try { return future.get(); } catch (InterruptedException e) { throw new IllegalStateException(e); } catch (ExecutionException e) { throw new IllegalStateException(e.getCause()); } }).reduce(Stream::concat).get(); } catch (InterruptedException e) { throw new IllegalStateException(e); } } return stream.sorted(Pairs.orderBySecond(Pairs.SortOrder.DESCENDING)).limit(howMany); } /** * @return all user IDs in the model */ public Collection<String> getAllUserIDs() { Collection<String> allUserIDs = HashObjSets.newMutableSet(); X.addAllIDsTo(allUserIDs); return allUserIDs; } /** * @return all item IDs in the model */ public Collection<String> getAllItemIDs() { Collection<String> allItemIDs = HashObjSets.newMutableSet(); for (FeatureVectors yPartition : Y) { yPartition.addAllIDsTo(allItemIDs); } return allItemIDs; } public Solver getYTYSolver() { Solver cached = cachedYTYSolver.get(); if (cached != null) { return cached; } RealMatrix YTY = null; for (FeatureVectors yPartition : Y) { RealMatrix YTYpartial = yPartition.getVTV(); if (YTYpartial != null) { YTY = YTY == null ? YTYpartial : YTY.add(YTYpartial); } } // Possible to compute this twice, but not a big deal Solver newYTYSolver = LinearSystemSolver.getSolver(YTY); cachedYTYSolver.set(newYTYSolver); return newYTYSolver; } /** * Retains only users that are expected to appear * in the upcoming model updates, or, that have arrived recently. This also clears the * recent known users data structure. * * @param users users that should be retained, which are coming in the new model updates */ void retainRecentAndUserIDs(Collection<String> users) { X.retainRecentAndIDs(users); try (AutoLock al = expectedUserIDsLock.autoWriteLock()) { expectedUserIDs.clear(); expectedUserIDs.addAll(users); X.removeAllIDsFrom(expectedUserIDs); } } /** * Retains only items that are expected to appear * in the upcoming model updates, or, that have arrived recently. This also clears the * recent known items data structure. * * @param items items that should be retained, which are coming in the new model updates */ void retainRecentAndItemIDs(Collection<String> items) { for (FeatureVectors yPartition : Y) { yPartition.retainRecentAndIDs(items); } try (AutoLock al = expectedItemIDsLock.autoWriteLock()) { expectedItemIDs.clear(); expectedItemIDs.addAll(items); for (FeatureVectors yPartition : Y) { yPartition.removeAllIDsFrom(expectedItemIDs); } } } /** * Like {@link #retainRecentAndUserIDs(Collection)} and {@link #retainRecentAndItemIDs(Collection)} * but affects the known-items data structure. * * @param users users that should be retained, which are coming in the new model updates * @param items items that should be retained, which are coming in the new model updates */ void retainRecentAndKnownItems(Collection<String> users, Collection<String> items) { // Keep all users in the new model, or, that have been added since last model Collection<String> recentUserIDs = HashObjSets.newMutableSet(); X.addAllRecentTo(recentUserIDs); try (AutoLock al = knownItemsLock.autoWriteLock()) { knownItems.removeIf((key, value) -> !users.contains(key) && !recentUserIDs.contains(key)); } // This will be easier to quickly copy the whole (smallish) set rather than // deal with locks below Collection<String> allRecentKnownItems = HashObjSets.newMutableSet(); for (FeatureVectors yPartition : Y) { yPartition.addAllRecentTo(allRecentKnownItems); } Predicate<String> notKeptOrRecent = value -> !items.contains(value) && !allRecentKnownItems.contains(value); try (AutoLock al = knownItemsLock.autoReadLock()) { knownItems.values().forEach(knownItemsForUser -> { synchronized (knownItemsForUser) { knownItemsForUser.removeIf(notKeptOrRecent); } }); } } /** * @return number of users in the model */ public int getNumUsers() { return X.size(); } /** * @return number of items in the model */ public int getNumItems() { int total = 0; for (FeatureVectors yPartition : Y) { total += yPartition.size(); } return total; } @Override public float getFractionLoaded() { int expected = 0; try (AutoLock al = expectedUserIDsLock.autoReadLock()) { expected += expectedUserIDs.size(); } try (AutoLock al = expectedItemIDsLock.autoReadLock()) { expected += expectedItemIDs.size(); } if (expected == 0) { return 1.0f; } float loaded = (float) getNumUsers() + getNumItems(); return loaded / (loaded + expected); } @Override public String toString() { int maxSize = 16; List<String> partitionSizes = new ArrayList<>(maxSize); for (int i = 0; i < Y.length; i++) { int size = Y[i].size(); if (size > 0) { partitionSizes.add(i + ":" + size); if (partitionSizes.size() == maxSize) { partitionSizes.add("..."); break; } } } return "ALSServingModel[features:" + features + ", implicit:" + implicit + ", X:(" + getNumUsers() + " users), Y:(" + getNumItems() + " items, partitions: " + partitionSizes + "...), fractionLoaded:" + getFractionLoaded() + "]"; } }