com.anhth12.lambda.app.serving.als.model.ALSServingModel.java Source code

Java tutorial

Introduction

Here is the source code for com.anhth12.lambda.app.serving.als.model.ALSServingModel.java

Source

/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package com.anhth12.lambda.app.serving.als.model;

import com.anhth12.lambda.app.als.RescorerProvider;
import com.anhth12.lambda.common.collection.AndPredicate;
import com.anhth12.lambda.common.collection.KeyOnlyBiPredicate;
import com.anhth12.lambda.common.collection.NotContainsPredicate;
import com.anhth12.lambda.common.collection.Pair;
import com.anhth12.lambda.common.collection.PairComparators;
import com.anhth12.lambda.common.lang.AutoLock;
import com.anhth12.lambda.common.lang.LoggingCallable;
import com.anhth12.lambda.common.math.LinearSystemSolver;
import com.anhth12.lambda.common.math.Solver;
import com.anhth12.lambda.common.math.VectorMath;
import com.google.common.base.Preconditions;
import com.google.common.collect.Iterables;
import com.google.common.collect.Ordering;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Queue;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import net.openhft.koloboke.collect.map.ObjIntMap;
import net.openhft.koloboke.collect.map.ObjObjMap;
import net.openhft.koloboke.collect.map.hash.HashObjIntMaps;
import net.openhft.koloboke.collect.map.hash.HashObjObjMap;
import net.openhft.koloboke.collect.map.hash.HashObjObjMaps;
import net.openhft.koloboke.collect.set.ObjSet;
import net.openhft.koloboke.collect.set.hash.HashObjSet;
import net.openhft.koloboke.collect.set.hash.HashObjSets;
import net.openhft.koloboke.function.BiConsumer;
import net.openhft.koloboke.function.ObjDoubleToDoubleFunction;
import net.openhft.koloboke.function.Predicate;
import net.openhft.koloboke.function.ToDoubleFunction;
import org.apache.commons.math3.linear.RealMatrix;

/**
 * In-memory model used at serving time for ALS recommendations. It holds the
 * user-feature matrix {@code X}, the item-feature matrix {@code Y} (split into
 * {@code PARTITIONS} hash partitions so that scans can run in parallel), and,
 * per user, the set of items the user is already known to have interacted with.
 *
 * <p>Locking conventions used by this class:
 * <ul>
 * <li>{@code xLock} guards {@code X}, {@code recentNewUsers} and the
 * {@code knownItems} map itself;</li>
 * <li>{@code yLocks[p]} guards {@code Y[p]} and {@code recentNewItems[p]};</li>
 * <li>each per-user known-items set is guarded by synchronizing on the set.</li>
 * </ul>
 *
 * @author Tong Hoang Anh
 */
public final class ALSServingModel {

    /** Number of partitions of Y; one per available processor. */
    private static final int PARTITIONS = Runtime.getRuntime().availableProcessors();

    /**
     * Shared daemon thread pool used to scan the partitions of Y in parallel.
     * It is {@code null} when there is only one partition, in which case scans
     * run on the calling thread.
     */
    private static final ExecutorService executor = PARTITIONS <= 1 ? null
            : Executors.newFixedThreadPool(PARTITIONS,
                    new ThreadFactoryBuilder().setDaemon(true).setNameFormat("ALSServingModel-%d").build());

    //USER-FEATURE matrix
    private final ObjObjMap<String, float[]> X;
    //ITEM-FEATURE matrix, partitioned by item ID hash
    private final ObjObjMap<String, float[]>[] Y;

    //Users added to X since the last call to pruneX
    private final Collection<String> recentNewUsers;

    //Items added to each partition of Y since the last call to pruneY
    private final Collection<String>[] recentNewItems;

    //For each user, the set of item IDs the user is known to be associated with
    private final ObjObjMap<String, ObjSet<String>> knownItems;
    //Controls access to X
    private final ReadWriteLock xLock;
    //Controls access to partitions of Y
    private final ReadWriteLock[] yLocks;

    private final int features;

    private final boolean implicit;

    private final RescorerProvider rescorerProvider;

    /**
     * Creates an empty model.
     *
     * @param features number of features in every user/item vector; must be positive
     * @param implicit value later reported by {@link #isImplicit()}
     * @param rescoreProvider value later reported by {@link #getRescorerProvider()}
     * @throws IllegalArgumentException if {@code features} is not positive
     */
    // Generic arrays cannot be instantiated directly, so they are created reflectively
    // and cast. The casts are safe because the arrays never escape this class and only
    // elements of the declared type are ever stored in them.
    @SuppressWarnings("unchecked")
    ALSServingModel(int features, boolean implicit, RescorerProvider rescoreProvider) {
        Preconditions.checkArgument(features > 0);

        X = HashObjObjMaps.newMutableMap();
        Y = (ObjObjMap<String, float[]>[]) Array.newInstance(ObjObjMap.class, PARTITIONS);

        for (int i = 0; i < Y.length; i++) {
            Y[i] = HashObjObjMaps.newMutableMap();
        }

        recentNewUsers = new HashSet<>();
        recentNewItems = (Collection<String>[]) Array.newInstance(HashSet.class, PARTITIONS);

        for (int i = 0; i < recentNewItems.length; i++) {
            recentNewItems[i] = new HashSet<>();
        }

        knownItems = HashObjObjMaps.newMutableMap();

        xLock = new ReentrantReadWriteLock();
        yLocks = new ReadWriteLock[Y.length];
        for (int i = 0; i < yLocks.length; i++) {
            yLocks[i] = new ReentrantReadWriteLock();
        }

        this.features = features;
        this.implicit = implicit;
        this.rescorerProvider = rescoreProvider;
    }

    /**
     * @return the number of features each user/item vector must have
     */
    public int getFeatures() {
        return features;
    }

    /**
     * @return the {@code implicit} flag this model was constructed with
     */
    public boolean isImplicit() {
        return implicit;
    }

    /**
     * @return the rescorer provider this model was constructed with; may be {@code null}
     * since the constructor does not validate it
     */
    public RescorerProvider getRescorerProvider() {
        return rescorerProvider;
    }

    /**
     * Maps an object to the index of the Y partition responsible for it.
     *
     * @param o object to partition; must not be {@code null}
     * @return an index in {@code [0, PARTITIONS)}
     */
    private static int partition(Object o) {
        // Mask off the sign bit so that the remainder is never negative.
        return (o.hashCode() & 0x7FFFFFFF) % PARTITIONS;
    }

    /**
     * @param user user ID
     * @return the feature vector stored for the user, or {@code null} if there is none.
     * The returned array is the stored instance, not a copy.
     */
    public float[] getUserVector(String user) {
        try (AutoLock al = new AutoLock(xLock.readLock())) {
            return X.get(user);
        }
    }

    /**
     * @param item item ID
     * @return the feature vector stored for the item, or {@code null} if there is none.
     * The returned array is the stored instance, not a copy.
     */
    public float[] getItemVector(String item) {
        int partition = partition(item);
        try (AutoLock al = new AutoLock(yLocks[partition].readLock())) {
            return Y[partition].get(item);
        }
    }

    /**
     * Stores (or replaces) a user's feature vector. A user that had no vector before
     * is also recorded as "recently new" so that the next {@link #pruneX} keeps it.
     *
     * @param user user ID; must not be {@code null}
     * @param vector feature vector; must not be {@code null} and must have
     * {@link #getFeatures()} elements
     * @throws NullPointerException if either argument is {@code null}
     * @throws IllegalArgumentException if the vector has the wrong length
     */
    void setUserVector(String user, float[] vector) {
        Preconditions.checkNotNull(user);
        Preconditions.checkNotNull(vector);
        Preconditions.checkArgument(vector.length == features);
        try (AutoLock al = new AutoLock(xLock.writeLock())) {
            if (X.put(user, vector) == null) {
                recentNewUsers.add(user);
            }
        }
    }

    /**
     * Stores (or replaces) an item's feature vector. An item that had no vector before
     * is also recorded as "recently new" so that the next {@link #pruneY} keeps it.
     *
     * @param item item ID; must not be {@code null}
     * @param vector feature vector; must not be {@code null} and must have
     * {@link #getFeatures()} elements
     * @throws NullPointerException if either argument is {@code null}
     * @throws IllegalArgumentException if the vector has the wrong length
     */
    void setItemVector(String item, float[] vector) {
        Preconditions.checkNotNull(item);
        Preconditions.checkNotNull(vector);
        Preconditions.checkArgument(vector.length == features);
        int partition = partition(item);
        try (AutoLock al = new AutoLock(yLocks[partition].writeLock())) {
            if (Y[partition].put(item, vector) == null) {
                recentNewItems[partition].add(item);
            }
        }
    }

    /**
     * Returns the live, internal set of item IDs known for a user. This class guards
     * that set by synchronizing on it, so callers that iterate over it should do the same.
     *
     * @param user user ID
     * @return the user's known-items set, or {@code null} if none is recorded
     */
    public Collection<String> getKnownItems(String user) {
        return doGetKnownItems(user);
    }

    /**
     * Looks up a user's known-items set under the X read lock.
     *
     * @param user user ID
     * @return the user's known-items set, or {@code null} if none is recorded
     */
    private ObjSet<String> doGetKnownItems(String user) {
        try (AutoLock al = new AutoLock((xLock.readLock()))) {
            return knownItems.get(user);
        }
    }

    /**
     * @return a new map from each user ID with a known-items entry to the number of
     * items in that entry
     */
    public Map<String, Integer> getUserCounts() {
        ObjIntMap<String> count = HashObjIntMaps.newUpdatableMap();
        try (AutoLock al = new AutoLock(xLock.readLock())) {
            for (Map.Entry<String, ObjSet<String>> entry : knownItems.entrySet()) {
                String userID = entry.getKey();
                Collection<?> ids = entry.getValue();
                int numItems;
                synchronized (ids) {
                    numItems = ids.size();
                }
                count.addValue(userID, numItems);
            }
        }
        return count;
    }

    /**
     * @return a new map from each item ID appearing in any user's known-items set to
     * the number of users whose set contains it
     */
    public Map<String, Integer> getItemCounts() {
        ObjIntMap<String> count = HashObjIntMaps.newUpdatableMap();
        try (AutoLock al = new AutoLock(xLock.readLock())) {
            for (Collection<String> ids : knownItems.values()) {
                synchronized (ids) {
                    for (String id : ids) {
                        count.addValue(id, 1);
                    }
                }
            }
        }
        return count;
    }

    /**
     * Adds items to a user's known-items set, creating the set if the user has none yet.
     *
     * @param user user ID
     * @param items item IDs to add; must not be {@code null}
     */
    void addKnownItems(String user, Collection<String> items) {

        // Optimistic lookup under the read lock; most calls find an existing set.
        ObjSet<String> knownItemsForUser = doGetKnownItems(user);

        if (knownItemsForUser == null) {
            try (AutoLock al = new AutoLock(xLock.writeLock())) {
                // Check again under the write lock: another thread may have created
                // the set between releasing the read lock and acquiring this one.
                knownItemsForUser = knownItems.get(user);
                if (knownItemsForUser == null) {
                    knownItemsForUser = HashObjSets.newMutableSet();
                    knownItems.put(user, knownItemsForUser);
                }
            }
        }
        synchronized (knownItemsForUser) {
            knownItemsForUser.addAll(items);
        }
    }

    /**
     * Collects the item vectors for all items known for a user.
     *
     * @param user user ID
     * @return pairs of (item ID, item vector) for each of the user's known items, or
     * {@code null} if the user has no vector or no known items. The vector in a pair is
     * {@code null} when the item has no entry in Y.
     */
    public List<Pair<String, float[]>> getKnowItemVectorsForUser(String user) {
        float[] userVector = getUserVector(user);

        if (userVector == null) {
            return null;
        }

        Collection<String> knownItems = getKnownItems(user);

        if (knownItems == null || knownItems.isEmpty()) {
            return null;
        }

        List<Pair<String, float[]>> idVectors = new ArrayList<>(knownItems.size());

        synchronized (knownItems) {
            for (String itemID : knownItems) {
                int partition = partition(itemID);
                float[] vector;
                try (AutoLock al = new AutoLock(yLocks[partition].readLock())) {
                    vector = Y[partition].get(itemID);
                }
                idVectors.add(new Pair<>(itemID, vector));
            }
        }
        return idVectors;
    }

    /**
     * Convenience overload of
     * {@link #topN(ToDoubleFunction, ObjDoubleToDoubleFunction, int, Predicate)}
     * with no rescoring and no filtering.
     *
     * @param scoreFn function that scores an item vector
     * @param howmany maximum number of results
     * @return the top-scoring (item ID, score) pairs
     */
    public List<Pair<String, Double>> topN(ToDoubleFunction<float[]> scoreFn, int howmany) {
        return topN(scoreFn, null, howmany, null);
    }

    /**
     * Scores every allowed item and returns the highest-scoring ones. Each partition of
     * Y is scanned by its own task (in parallel on the shared executor when there are
     * at least two partitions), and the per-partition results are then merged.
     *
     * @param scoreFn function that scores an item vector
     * @param rescoreFn optional function mapping (item ID, score) to an adjusted score;
     * {@code null} to skip rescoring
     * @param howMany maximum number of results
     * @param allowedPredicate optional filter on item ID; {@code null} allows every item
     * @return up to {@code howMany} (item ID, score) pairs, as selected by
     * {@code Ordering.greatestOf} using {@code PairComparators.bySecond()}
     * @throws IllegalStateException if a scan task fails or the calling thread is
     * interrupted while waiting (in which case its interrupt status is restored)
     */
    public List<Pair<String, Double>> topN(final ToDoubleFunction<float[]> scoreFn,
            final ObjDoubleToDoubleFunction<String> rescoreFn, final int howMany,
            final Predicate<String> allowedPredicate) {

        List<Callable<Iterable<Pair<String, Double>>>> tasks = new ArrayList<>(Y.length);
        for (int partition = 0; partition < Y.length; partition++) {
            final int thePartition = partition;
            tasks.add(new LoggingCallable<Iterable<Pair<String, Double>>>() {

                @Override
                public Iterable<Pair<String, Double>> doCall() throws Exception {
                    Queue<Pair<String, Double>> topN = new PriorityQueue<>(howMany + 1,
                            PairComparators.<Double>bySecond());
                    TopNConsumer topNPoc = new TopNConsumer(topN, howMany, scoreFn, rescoreFn, allowedPredicate);
                    try (AutoLock al = new AutoLock(yLocks[thePartition].readLock())) {
                        Y[thePartition].forEach(topNPoc);
                    }

                    return topN;
                }
            });
        }

        List<Iterable<Pair<String, Double>>> iterables = new ArrayList<>(Y.length);

        // The executor only exists when there are at least two partitions.
        if (Y.length >= 2) {
            try {
                for (Future<Iterable<Pair<String, Double>>> future : executor.invokeAll(tasks)) {
                    iterables.add(future.get());
                }
            } catch (InterruptedException e) {
                // Restore the interrupt status so that callers further up can observe it.
                Thread.currentThread().interrupt();
                throw new IllegalStateException(e);
            } catch (ExecutionException e) {
                throw new IllegalStateException(e);
            }
        } else {
            try {
                iterables.add(tasks.get(0).call());
            } catch (Exception e) {
                throw new IllegalStateException(e);
            }
        }

        return Ordering.from(PairComparators.<Double>bySecond()).greatestOf(Iterables.concat(iterables), howMany);

    }

    /**
     * @return a snapshot copy of all user IDs that currently have a vector in X
     */
    public Collection<String> getAllUserIDs() {
        Collection<String> userList;
        try (AutoLock al = new AutoLock(xLock.readLock())) {
            userList = new ArrayList<>(X.keySet());
        }
        return userList;
    }

    /**
     * @return a copy of all item IDs that currently have a vector in Y. Partitions are
     * read one at a time, so the result is not an atomic snapshot across partitions.
     */
    public Collection<String> getAllItemIDs() {
        Collection<String> itemsList = new ArrayList<>();
        for (int partition = 0; partition < Y.length; partition++) {
            try (AutoLock al = new AutoLock(yLocks[partition].readLock())) {
                itemsList.addAll(Y[partition].keySet());
            }
        }
        return itemsList;
    }

    /**
     * Builds a solver from the sum, over all partitions, of
     * {@code VectorMath.transposeTimesSelf} applied to that partition's item vectors.
     * Partitions for which that call returns {@code null} are skipped.
     *
     * @return the result of {@code LinearSystemSolver.getSolver} on the summed matrix
     */
    public Solver getYTYSolver() {
        RealMatrix YTY = null;

        for (int partition = 0; partition < Y.length; partition++) {
            RealMatrix YTYpartial;
            try (AutoLock al = new AutoLock(yLocks[partition].readLock())) {
                YTYpartial = VectorMath.transposeTimesSelf(Y[partition].values());
            }

            if (YTYpartial != null) {
                YTY = YTY == null ? YTYpartial : YTY.add(YTYpartial);
            }
        }

        // NOTE(review): YTY is still null here if every partial was null (presumably
        // when Y is empty) - confirm that getSolver tolerates a null matrix.
        return new LinearSystemSolver().getSolver(YTY);
    }

    /**
     * Prunes X against a collection of users to retain, then clears the record of
     * recently added users. Judging by the predicate names, an entry is removed only if
     * its user is in neither {@code users} nor the recently-added set.
     *
     * @param users user IDs to retain
     */
    void pruneX(Collection<String> users) {
        try (AutoLock al = new AutoLock(xLock.writeLock())) {

            X.removeIf(new KeyOnlyBiPredicate<>(new AndPredicate<>(new NotContainsPredicate<>(users),
                    new NotContainsPredicate<>(recentNewUsers))));

            recentNewUsers.clear();

        }
    }

    /**
     * Prunes every partition of Y against a collection of items to retain, then clears
     * that partition's record of recently added items. Judging by the predicate names,
     * an entry is removed only if its item is in neither {@code items} nor the
     * partition's recently-added set.
     *
     * @param items item IDs to retain
     */
    void pruneY(Collection<String> items) {
        for (int partition = 0; partition < Y.length; partition++) {
            try (AutoLock al = new AutoLock(yLocks[partition].writeLock())) {
                Y[partition].removeIf(new KeyOnlyBiPredicate<>(new AndPredicate<>(new NotContainsPredicate<>(items),
                        new NotContainsPredicate<>(recentNewItems[partition]))));
                recentNewItems[partition].clear();
            }
        }
    }

    /**
     * Prunes the known-items data in two passes: first whole user entries, then
     * individual items inside each remaining user's set. Recently added users and items
     * are consulted in addition to the given collections, so this relies on
     * {@code recentNewUsers} / {@code recentNewItems} not having been cleared yet.
     *
     * @param users user IDs to retain
     * @param items item IDs to retain
     */
    void pruneKnownItems(Collection<String> users, final Collection<String> items) {
        // Pass 1: drop user entries (same predicate composition as pruneX).
        try (AutoLock al = new AutoLock(xLock.writeLock())) {
            knownItems.removeIf(new KeyOnlyBiPredicate<>(new AndPredicate<>(new NotContainsPredicate<>(users),
                    new NotContainsPredicate<>(recentNewUsers))));
        }

        // Snapshot the recently added items from all partitions.
        final Collection<String> allRecentKnownItems = new HashSet<>();
        for (int partition = 0; partition < Y.length; partition++) {
            try (AutoLock al = new AutoLock(yLocks[partition].writeLock())) {
                allRecentKnownItems.addAll(recentNewItems[partition]);
            }
        }

        // Pass 2: within each remaining user's set, drop items that are in neither
        // the retained items nor the recent-items snapshot.
        try (AutoLock al = new AutoLock(xLock.readLock())) {
            for (ObjSet<String> knownItemsForUser : knownItems.values()) {
                synchronized (knownItemsForUser) {
                    knownItemsForUser.removeIf(new Predicate<String>() {
                        @Override
                        public boolean test(String value) {
                            return !items.contains(value) && !allRecentKnownItems.contains(value);
                        }
                    });
                }
            }
        }

    }

    /**
     * Summarizes the model's configuration and sizes. Sizes are read without taking
     * any lock.
     */
    @Override
    public String toString() {
        int numItems = 0;

        for (Map<?, ?> partition : Y) {
            numItems += partition.size();
        }

        return "ALSServingModel[features:" + features + ", implicit:" + implicit + ",X(" + X.size() + " users), Y:("
                + numItems + "items)]";
    }

    /**
     * Consumer applied to every (item ID, vector) entry of one Y partition. It scores
     * allowed items and keeps at most {@code howMany} of them in the supplied queue,
     * evicting the queue head whenever a higher-scoring item arrives.
     */
    private static final class TopNConsumer implements BiConsumer<String, float[]> {

        private final Queue<Pair<String, Double>> topN;
        private final int howMany;
        private final ToDoubleFunction<float[]> scoreFn;
        private final ObjDoubleToDoubleFunction<String> rescoreFn;
        private final Predicate<String> allowedPredicate;

        // Cached score that the queue head has been observed to reach; scores at or
        // below it are rejected without touching (or locking) the queue.
        private double topScoreLowerBound;
        // Set once the queue has been observed to hold at least howMany entries.
        private boolean full;

        /**
         * @param topN queue that accumulates the results; its head is treated by this
         * class as the lowest-scoring retained entry
         * @param howMany number of entries at which the queue is considered full
         * @param scoreFn function that scores an item vector
         * @param rescoreFn optional (item ID, score) adjustment; may be {@code null}
         * @param allowedPredicate optional filter on item ID; may be {@code null}
         */
        public TopNConsumer(Queue<Pair<String, Double>> topN, int howMany, ToDoubleFunction<float[]> scoreFn,
                ObjDoubleToDoubleFunction<String> rescoreFn, Predicate<String> allowedPredicate) {
            this.topN = topN;
            this.howMany = howMany;
            this.scoreFn = scoreFn;
            this.rescoreFn = rescoreFn;
            this.allowedPredicate = allowedPredicate;
            topScoreLowerBound = Double.NEGATIVE_INFINITY;
            full = false;
        }

        /**
         * Scores one item and offers it to the result queue.
         *
         * @param key item ID
         * @param value item feature vector
         */
        @Override
        public void accept(String key, float[] value) {
            if (allowedPredicate == null || allowedPredicate.test(key)) {
                double score = scoreFn.applyAsDouble(value);
                if (rescoreFn != null) {
                    score = rescoreFn.applyAsDouble(key, score);
                }

                if (full) {
                    // Cheap pre-check against the cached bound before taking the lock.
                    if (score > topScoreLowerBound) {
                        double peek;
                        synchronized (topN) {
                            peek = topN.peek().getSecond();
                            if (score > peek) {
                                topN.poll();
                                topN.add(new Pair<>(key, score));
                            }
                        }
                        // Raise the cached bound to the head score just observed.
                        if (peek > topScoreLowerBound) {
                            topScoreLowerBound = peek;
                        }
                    }
                } else {
                    int newSize;
                    synchronized (topN) {
                        topN.add(new Pair<>(key, score));
                        newSize = topN.size();
                    }
                    if (newSize >= howMany) {
                        full = true;
                    }
                }
            }
        }
    }

}