com.cloudera.oryx.app.serving.als.model.ALSServingModel.java Source code

Java tutorial

Introduction

Here is the source code for com.cloudera.oryx.app.serving.als.model.ALSServingModel.java

Source

/*
 * Copyright (c) 2014, Cloudera, Inc. and Intel Corp. All Rights Reserved.
 *
 * Cloudera, Inc. licenses this file to you under the Apache License,
 * Version 2.0 (the "License"). You may not use this file except in
 * compliance with the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
 * CONDITIONS OF ANY KIND, either express or implied. See the License for
 * the specific language governing permissions and limitations under the
 * License.
 */

package com.cloudera.oryx.app.serving.als.model;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Predicate;
import java.util.stream.Stream;

import com.google.common.base.Preconditions;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import net.openhft.koloboke.collect.map.ObjIntMap;
import net.openhft.koloboke.collect.map.ObjObjMap;
import net.openhft.koloboke.collect.map.hash.HashObjIntMaps;
import net.openhft.koloboke.collect.map.hash.HashObjObjMaps;
import net.openhft.koloboke.collect.set.ObjSet;
import net.openhft.koloboke.collect.set.hash.HashObjSets;
import net.openhft.koloboke.function.ObjDoubleToDoubleFunction;
import org.apache.commons.math3.linear.RealMatrix;

import com.cloudera.oryx.api.serving.ServingModel;
import com.cloudera.oryx.app.als.FeatureVectors;
import com.cloudera.oryx.app.als.RescorerProvider;
import com.cloudera.oryx.app.serving.als.CosineDistanceSensitiveFunction;
import com.cloudera.oryx.common.collection.Pair;
import com.cloudera.oryx.common.collection.Pairs;
import com.cloudera.oryx.common.lang.AutoLock;
import com.cloudera.oryx.common.lang.AutoReadWriteLock;
import com.cloudera.oryx.common.lang.LoggingCallable;
import com.cloudera.oryx.common.math.LinearSystemSolver;
import com.cloudera.oryx.common.math.Solver;

/**
 * Contains all data structures needed to serve real-time requests for an ALS-based recommender.
 */
public final class ALSServingModel implements ServingModel {

    /**
     * Shared pool used by {@link #topN} to scan several candidate item partitions in parallel.
     * It is sized to the number of available processors. Its threads are daemons, so they do not
     * keep the JVM alive.
     */
    private static final ExecutorService executor = Executors.newFixedThreadPool(
            Runtime.getRuntime().availableProcessors(),
            new ThreadFactoryBuilder().setDaemon(true).setNameFormat("ALSServingModel-%d").build());

    /**
     * Assigns each item vector to a partition of {@link #Y}.
     * Also chooses the candidate partitions to scan for a query vector.
     */
    private final LocalitySensitiveHash lsh;
    /** User-feature matrix. */
    private final FeatureVectors X;
    /** Item-feature matrix. This is partitioned into several maps for parallel access. */
    private final FeatureVectors[] Y;
    /** Maps item IDs to their existing partition, if any. */
    private final ObjIntMap<String> yPartitionMap;
    /** Controls access to yPartitionMap. */
    private final AutoReadWriteLock yPartitionMapLock;
    /** Remembers items that each user has interacted with. Each per-user set is guarded by its own monitor. */
    private final ObjObjMap<String, ObjSet<String>> knownItems; // Right now no corresponding "knownUsers" object
    /** Controls access to the knownItems map itself (not to the per-user sets). */
    private final AutoReadWriteLock knownItemsLock;
    /** User IDs expected in upcoming model updates that are not loaded yet. */
    private final ObjSet<String> expectedUserIDs;
    private final AutoReadWriteLock expectedUserIDsLock;
    /** Item IDs expected in upcoming model updates that are not loaded yet. */
    private final ObjSet<String> expectedItemIDs;
    private final AutoReadWriteLock expectedItemIDsLock;
    /** Lazily computed solver for Y^T * Y. It is reset to null whenever an item vector is set. */
    private final AtomicReference<Solver> cachedYTYSolver;
    /** Number of features used in the model. */
    private final int features;
    /** Whether model uses implicit feedback. */
    private final boolean implicit;
    private final RescorerProvider rescorerProvider;

    /**
     * Creates an empty model.
     *
     * @param features number of features expected for user/item feature vectors
     * @param implicit whether model implements implicit feedback
     * @param sampleRate consider only approximately this fraction of all items when making recommendations.
     *  Candidates are chosen intelligently with locality sensitive hashing.
     * @param rescorerProvider optional instance of a {@link RescorerProvider}
     * @throws IllegalArgumentException if {@code features <= 0} or {@code sampleRate} is not in (0,1]
     */
    ALSServingModel(int features, boolean implicit, double sampleRate, RescorerProvider rescorerProvider) {
        Preconditions.checkArgument(features > 0);
        Preconditions.checkArgument(sampleRate > 0.0 && sampleRate <= 1.0);

        lsh = new LocalitySensitiveHash(sampleRate, features);

        X = new FeatureVectors();
        Y = new FeatureVectors[lsh.getNumPartitions()];
        for (int i = 0; i < Y.length; i++) {
            Y[i] = new FeatureVectors();
        }
        yPartitionMap = HashObjIntMaps.newMutableMap();
        yPartitionMapLock = new AutoReadWriteLock();

        knownItems = HashObjObjMaps.newMutableMap();
        knownItemsLock = new AutoReadWriteLock();

        expectedUserIDs = HashObjSets.newMutableSet();
        expectedUserIDsLock = new AutoReadWriteLock();
        expectedItemIDs = HashObjSets.newMutableSet();
        expectedItemIDsLock = new AutoReadWriteLock();

        cachedYTYSolver = new AtomicReference<>();

        this.features = features;
        this.implicit = implicit;
        this.rescorerProvider = rescorerProvider;
    }

    /**
     * @return number of features in each user/item vector
     */
    public int getFeatures() {
        return features;
    }

    /**
     * @return whether the model uses implicit feedback
     */
    public boolean isImplicit() {
        return implicit;
    }

    /**
     * @return the {@link RescorerProvider} given at construction, which may be {@code null}
     */
    public RescorerProvider getRescorerProvider() {
        return rescorerProvider;
    }

    /**
     * @param user user ID
     * @return the user's feature vector; {@code null} if the user is not known to the model
     */
    public float[] getUserVector(String user) {
        return X.getVector(user);
    }

    /**
     * Looks up the item's partition under the read lock, then reads the vector from that partition.
     *
     * @param item item ID
     * @return the item's feature vector, or {@code null} if the item is not assigned to any partition.
     *  NOTE(review): the vector is read after the lock is released, so a concurrent move to another
     *  partition can presumably make this return {@code null} briefly.
     */
    public float[] getItemVector(String item) {
        int partition;
        try (AutoLock al = yPartitionMapLock.autoReadLock()) {
            partition = yPartitionMap.getOrDefault(item, Integer.MIN_VALUE);
        }
        if (partition < 0) {
            return null;
        }
        return Y[partition].getVector(item);
    }

    /**
     * Sets a user's vector and marks the user as no longer "expected".
     *
     * @param user user ID
     * @param vector new feature vector
     * @throws IllegalArgumentException if {@code vector.length != features}
     */
    void setUserVector(String user, float[] vector) {
        Preconditions.checkArgument(vector.length == features);
        X.setVector(user, vector);
        try (AutoLock al = expectedUserIDsLock.autoWriteLock()) {
            expectedUserIDs.remove(user);
        }
    }

    /**
     * Sets an item's vector and stores it in the partition chosen by the LSH for that vector.
     * If the item previously lived in a different partition, it is moved. This also marks the
     * item as no longer "expected" and invalidates the cached Y^T * Y solver.
     *
     * @param item item ID
     * @param vector new feature vector
     * @throws IllegalArgumentException if {@code vector.length != features}
     */
    void setItemVector(String item, float[] vector) {
        Preconditions.checkArgument(vector.length == features);
        int newPartition = lsh.getIndexFor(vector);
        // Exclusive update to mapping -- careful since other locks are acquired inside here
        try (AutoLock al = yPartitionMapLock.autoWriteLock()) {
            int existingPartition = yPartitionMap.getOrDefault(item, Integer.MIN_VALUE);
            if (existingPartition >= 0 && existingPartition != newPartition) {
                // Move from one to the other partition, so first remove old entry
                Y[existingPartition].removeVector(item);
                // Note that it's conceivable that a recommendation call sees *no* copy of this
                // item here in this brief window
            }
            // Then regardless put in new partition
            Y[newPartition].setVector(item, vector);
            yPartitionMap.put(item, newPartition);
        }
        try (AutoLock al = expectedItemIDsLock.autoWriteLock()) {
            expectedItemIDs.remove(item);
        }
        // Not clear if it's too inefficient to clear and recompute YtY solver every time any bit
        // of Y changes, but it's the most correct
        cachedYTYSolver.set(null);
    }

    /**
     * @param user user to get known items for
     * @return an immutable snapshot of the user's known items, which is safe to share across threads;
     *  an empty set if there are none
     */
    public Set<String> getKnownItems(String user) {
        ObjSet<String> knownItems = doGetKnownItems(user);
        if (knownItems == null) {
            return Collections.emptySet();
        }
        synchronized (knownItems) {
            if (knownItems.isEmpty()) {
                return Collections.emptySet();
            }
            // Copy while holding the set's monitor: the original set is mutated by other threads
            // under that same monitor, so it must not escape to callers
            return HashObjSets.newImmutableSet(knownItems);
        }
    }

    /**
     * @return the live (mutable, monitor-guarded) known-item set for the user, or {@code null}
     */
    private ObjSet<String> doGetKnownItems(String user) {
        try (AutoLock al = knownItemsLock.autoReadLock()) {
            return knownItems.get(user);
        }
    }

    /**
     * @return mapping of user IDs to count of items the user has interacted with
     */
    public Map<String, Integer> getUserCounts() {
        ObjIntMap<String> counts = HashObjIntMaps.newUpdatableMap();
        try (AutoLock al = knownItemsLock.autoReadLock()) {
            knownItems.forEach((userID, ids) -> {
                int numItems;
                synchronized (ids) {
                    numItems = ids.size();
                }
                counts.addValue(userID, numItems);
            });
        }
        return counts;
    }

    /**
     * @return mapping of item IDs to count of users that have interacted with that item
     */
    public Map<String, Integer> getItemCounts() {
        ObjIntMap<String> counts = HashObjIntMaps.newUpdatableMap();
        try (AutoLock al = knownItemsLock.autoReadLock()) {
            knownItems.values().forEach(ids -> {
                synchronized (ids) {
                    ids.forEach(id -> counts.addValue(id, 1));
                }
            });
        }
        return counts;
    }

    /**
     * Adds items to the user's known-item set, creating the set on first use.
     * The set is first looked up under the read lock. The write lock is taken, and the lookup
     * repeated, only when no set exists yet.
     *
     * @param user user ID
     * @param items items the user has interacted with
     */
    void addKnownItems(String user, Collection<String> items) {
        ObjSet<String> knownItemsForUser = doGetKnownItems(user);

        if (knownItemsForUser == null) {
            try (AutoLock al = knownItemsLock.autoWriteLock()) {
                // Check again
                knownItemsForUser = knownItems.get(user);
                if (knownItemsForUser == null) {
                    knownItemsForUser = HashObjSets.newMutableSet();
                    knownItems.put(user, knownItemsForUser);
                }
            }
        }

        synchronized (knownItemsForUser) {
            knownItemsForUser.addAll(items);
        }
    }

    /**
     * @param user user to get known item vectors for
     * @return {@code null} if the user is not known to the model, or if there are no known items for the user
     */
    public List<Pair<String, float[]>> getKnownItemVectorsForUser(String user) {
        float[] userVector = getUserVector(user);
        if (userVector == null) {
            return null;
        }
        Collection<String> knownItems = doGetKnownItems(user);
        if (knownItems == null) {
            return null;
        }
        synchronized (knownItems) {
            int size = knownItems.size();
            if (size == 0) {
                return null;
            }
            List<Pair<String, float[]>> idVectors = new ArrayList<>(size);
            for (String itemID : knownItems) {
                float[] vector = getItemVector(itemID);
                if (vector != null) {
                    idVectors.add(new Pair<>(itemID, vector));
                }
            }
            return idVectors.isEmpty() ? null : idVectors;
        }
    }

    /**
     * Finds the highest-scoring items for a query. Only the item partitions that the LSH selects as
     * candidates for {@code scoreFn}'s target vector are scanned. Each non-empty candidate partition
     * is scanned by its own {@link TopNConsumer}. Several partitions are scanned in parallel on
     * {@link #executor}.
     *
     * @param scoreFn scoring function; its target vector also selects the candidate partitions
     * @param rescoreFn rescoring function passed to each partition's {@link TopNConsumer}
     * @param howMany maximum number of results to return
     * @param allowedPredicate item-ID filter passed to each partition's {@link TopNConsumer}
     * @return (item ID, score) pairs sorted by descending score, at most {@code howMany} of them
     * @throws IllegalStateException if scanning a partition fails, or the calling thread is interrupted
     *  while waiting (its interrupt status is then restored)
     */
    public Stream<Pair<String, Double>> topN(CosineDistanceSensitiveFunction scoreFn,
            ObjDoubleToDoubleFunction<String> rescoreFn, int howMany, Predicate<String> allowedPredicate) {

        int[] candidateIndices = lsh.getCandidateIndices(scoreFn.getTargetVector());
        List<Callable<Stream<Pair<String, Double>>>> tasks = new ArrayList<>(candidateIndices.length);
        for (int partition : candidateIndices) {
            // Empty partitions have nothing to score; don't schedule work for them
            if (Y[partition].size() > 0) {
                tasks.add(LoggingCallable.log(() -> {
                    TopNConsumer consumer = new TopNConsumer(howMany, scoreFn, rescoreFn, allowedPredicate);
                    Y[partition].forEach(consumer);
                    return consumer.getTopN();
                }));
            }
        }

        int numTasks = tasks.size();
        if (numTasks == 0) {
            return Stream.empty();
        }

        Stream<Pair<String, Double>> stream;
        if (numTasks == 1) {
            // A single partition is cheaper to scan on the calling thread than via the executor
            try {
                stream = tasks.get(0).call();
            } catch (Exception e) {
                throw new IllegalStateException(e);
            }
        } else {
            stream = invokeAllAndMerge(tasks);
        }
        return stream.sorted(Pairs.orderBySecond(Pairs.SortOrder.DESCENDING)).limit(howMany);
    }

    /**
     * Runs all tasks on {@link #executor}, waits for them, and merges their results into one stream.
     * Results are collected eagerly, so any failure surfaces here rather than when the returned
     * stream is consumed. They are merged with one {@code flatMap} instead of repeated
     * {@link Stream#concat}: a left-fold of {@code concat} nests once per partition, and deeply
     * concatenated streams can overflow the stack.
     *
     * @param tasks per-partition scoring tasks, at least one
     * @return all tasks' results as a single stream
     * @throws IllegalStateException if a task failed (wrapping its cause) or the thread was interrupted
     */
    private static Stream<Pair<String, Double>> invokeAllAndMerge(
            List<Callable<Stream<Pair<String, Double>>>> tasks) {
        List<Future<Stream<Pair<String, Double>>>> futures;
        try {
            futures = executor.invokeAll(tasks);
        } catch (InterruptedException e) {
            // Restore the interrupt status so callers further up can still observe it
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
        List<Stream<Pair<String, Double>>> partialResults = new ArrayList<>(futures.size());
        for (Future<Stream<Pair<String, Double>>> future : futures) {
            try {
                partialResults.add(future.get());
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new IllegalStateException(e);
            } catch (ExecutionException e) {
                throw new IllegalStateException(e.getCause());
            }
        }
        return partialResults.stream().flatMap(partial -> partial);
    }

    /**
     * @return all user IDs in the model, as a newly created set
     */
    public Collection<String> getAllUserIDs() {
        Collection<String> allUserIDs = HashObjSets.newMutableSet();
        X.addAllIDsTo(allUserIDs);
        return allUserIDs;
    }

    /**
     * @return all item IDs in the model, across every partition, as a newly created set
     */
    public Collection<String> getAllItemIDs() {
        Collection<String> allItemIDs = HashObjSets.newMutableSet();
        for (FeatureVectors yPartition : Y) {
            yPartition.addAllIDsTo(allItemIDs);
        }
        return allItemIDs;
    }

    /**
     * Returns a solver for Y^T * Y. The matrix is the sum of each partition's V^T * V.
     * The solver is computed on demand and cached until an item vector is next set.
     * NOTE(review): if no partition yields a V^T * V, {@code null} is passed to
     * {@link LinearSystemSolver#getSolver} — confirm that case is handled there.
     *
     * @return solver for Y^T * Y
     */
    public Solver getYTYSolver() {
        Solver cached = cachedYTYSolver.get();
        if (cached != null) {
            return cached;
        }
        RealMatrix YTY = null;
        for (FeatureVectors yPartition : Y) {
            RealMatrix YTYpartial = yPartition.getVTV();
            if (YTYpartial != null) {
                YTY = YTY == null ? YTYpartial : YTY.add(YTYpartial);
            }
        }
        // Possible to compute this twice, but not a big deal
        Solver newYTYSolver = LinearSystemSolver.getSolver(YTY);
        cachedYTYSolver.set(newYTYSolver);
        return newYTYSolver;
    }

    /**
     * Retains only users that are expected to appear
     * in the upcoming model updates, or, that have arrived recently. This also clears the
     * recent known users data structure.
     *
     * @param users users that should be retained, which are coming in the new model updates
     */
    void retainRecentAndUserIDs(Collection<String> users) {
        X.retainRecentAndIDs(users);
        try (AutoLock al = expectedUserIDsLock.autoWriteLock()) {
            expectedUserIDs.clear();
            expectedUserIDs.addAll(users);
            // Whatever is already loaded is no longer "expected"
            X.removeAllIDsFrom(expectedUserIDs);
        }
    }

    /**
     * Retains only items that are expected to appear
     * in the upcoming model updates, or, that have arrived recently. This also clears the
     * recent known items data structure.
     *
     * @param items items that should be retained, which are coming in the new model updates
     */
    void retainRecentAndItemIDs(Collection<String> items) {
        for (FeatureVectors yPartition : Y) {
            yPartition.retainRecentAndIDs(items);
        }
        try (AutoLock al = expectedItemIDsLock.autoWriteLock()) {
            expectedItemIDs.clear();
            expectedItemIDs.addAll(items);
            // Whatever is already loaded in any partition is no longer "expected"
            for (FeatureVectors yPartition : Y) {
                yPartition.removeAllIDsFrom(expectedItemIDs);
            }
        }
    }

    /**
     * Like {@link #retainRecentAndUserIDs(Collection)} and {@link #retainRecentAndItemIDs(Collection)}
     * but affects the known-items data structure.
     *
     * @param users users that should be retained, which are coming in the new model updates
     * @param items items that should be retained, which are coming in the new model updates
     */
    void retainRecentAndKnownItems(Collection<String> users, Collection<String> items) {
        // Keep all users in the new model, or, that have been added since last model
        Collection<String> recentUserIDs = HashObjSets.newMutableSet();
        X.addAllRecentTo(recentUserIDs);
        try (AutoLock al = knownItemsLock.autoWriteLock()) {
            knownItems.removeIf((key, value) -> !users.contains(key) && !recentUserIDs.contains(key));
        }

        // This will be easier to quickly copy the whole (smallish) set rather than
        // deal with locks below
        Collection<String> allRecentKnownItems = HashObjSets.newMutableSet();
        for (FeatureVectors yPartition : Y) {
            yPartition.addAllRecentTo(allRecentKnownItems);
        }

        Predicate<String> notKeptOrRecent = value -> !items.contains(value) && !allRecentKnownItems.contains(value);
        // Read lock suffices: the map's structure is not changed here, and each per-user set
        // is modified under its own monitor
        try (AutoLock al = knownItemsLock.autoReadLock()) {
            knownItems.values().forEach(knownItemsForUser -> {
                synchronized (knownItemsForUser) {
                    knownItemsForUser.removeIf(notKeptOrRecent);
                }
            });
        }
    }

    /**
     * @return number of users in the model
     */
    public int getNumUsers() {
        return X.size();
    }

    /**
     * @return number of items in the model
     */
    public int getNumItems() {
        int total = 0;
        for (FeatureVectors yPartition : Y) {
            total += yPartition.size();
        }
        return total;
    }

    /**
     * @return loaded users+items divided by (loaded + still-expected users+items), which is in [0,1];
     *  1 when nothing is expected
     */
    @Override
    public float getFractionLoaded() {
        int expected = 0;
        try (AutoLock al = expectedUserIDsLock.autoReadLock()) {
            expected += expectedUserIDs.size();
        }
        try (AutoLock al = expectedItemIDsLock.autoReadLock()) {
            expected += expectedItemIDs.size();
        }
        if (expected == 0) {
            return 1.0f;
        }
        float loaded = (float) getNumUsers() + getNumItems();
        return loaded / (loaded + expected);
    }

    /**
     * Summarizes the model. Shows at most 16 non-empty partition sizes as {@code index:size};
     * when more partitions are non-empty, a trailing {@code "..."} entry is added.
     */
    @Override
    public String toString() {
        int maxSize = 16;
        List<String> partitionSizes = new ArrayList<>(maxSize + 1);
        for (int i = 0; i < Y.length; i++) {
            int size = Y[i].size();
            if (size > 0) {
                // Only mark truncation when there really is another non-empty partition to omit
                if (partitionSizes.size() == maxSize) {
                    partitionSizes.add("...");
                    break;
                }
                partitionSizes.add(i + ":" + size);
            }
        }
        return "ALSServingModel[features:" + features + ", implicit:" + implicit + ", X:(" + getNumUsers()
                + " users), Y:(" + getNumItems() + " items, partitions: " + partitionSizes + "), fractionLoaded:"
                + getFractionLoaded() + "]";
    }

}