com.webir.popcornsaver.cluster.TreeClusteringRecommender.java Source code

Java tutorial

Introduction

Here is the source code for com.webir.popcornsaver.cluster.TreeClusteringRecommender.java

Source

/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.webir.popcornsaver.cluster;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.concurrent.Callable;

import org.apache.mahout.cf.taste.common.Refreshable;
import org.apache.mahout.cf.taste.common.TasteException;
import org.apache.mahout.cf.taste.impl.common.FastByIDMap;
import org.apache.mahout.cf.taste.impl.common.FastIDSet;
import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
import org.apache.mahout.cf.taste.impl.common.RefreshHelper;
import org.apache.mahout.cf.taste.impl.common.RunningAverage;
import org.apache.mahout.cf.taste.impl.recommender.AbstractRecommender;
import org.apache.mahout.cf.taste.impl.recommender.TopItems;
import org.apache.mahout.cf.taste.model.DataModel;
import org.apache.mahout.cf.taste.recommender.IDRescorer;
import org.apache.mahout.cf.taste.recommender.RecommendedItem;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.RandomUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.google.common.base.Preconditions;

/**
 * <p>
 * A {@link org.apache.mahout.cf.taste.recommender.Recommender} that clusters users, then determines the
 * clusters' top recommendations. This implementation builds clusters by repeatedly merging clusters until
 * only a certain number remain, meaning that each cluster is sort of a tree of other clusters.
 * </p>
 *
 * <p>
 * This {@link org.apache.mahout.cf.taste.recommender.Recommender} therefore has a few properties to note:
 * </p>
 *
 * <ul>
 * <li>For all users in a cluster, recommendations will be the same</li>
 * <li>{@link #estimatePreference(long, long)} may well return {@link Double#NaN}; it does so when asked to
 * estimate preference for an item for which no preference is expressed in the users in the cluster.</li>
 * </ul>
 */
public final class TreeClusteringRecommender extends AbstractRecommender implements ClusteringRecommender {

    private static final Logger log = LoggerFactory.getLogger(TreeClusteringRecommender.class);

    private static final FastIDSet[] NO_CLUSTERS = new FastIDSet[0];
    private static final Random RANDOM = RandomUtils.getRandom();

    private final ClusterSimilarity clusterSimilarity;
    private final int numClusters;
    private final double clusteringThreshold;
    private final boolean clusteringByThreshold;
    private final double samplingRate;
    private FastByIDMap<List<RecommendedItem>> topRecsByUserID;
    private FastIDSet[] allClusters;
    private FastByIDMap<FastIDSet> clustersByUserID;
    private final RefreshHelper refreshHelper;

    /**
     * @param dataModel
     *          {@link DataModel} which provides users
     * @param clusterSimilarity
     *          {@link ClusterSimilarity} used to compute cluster similarity
     * @param numClusters
     *          desired number of clusters to create
     * @throws IllegalArgumentException
     *           if arguments are <code>null</code>, or <code>numClusters</code> is less than 2
     */
    public TreeClusteringRecommender(DataModel dataModel, ClusterSimilarity clusterSimilarity, int numClusters)
            throws TasteException {
        this(dataModel, clusterSimilarity, numClusters, 1.0);
    }

    /**
     * @param dataModel
     *          {@link DataModel} which provdes users
     * @param clusterSimilarity
     *          {@link ClusterSimilarity} used to compute cluster similarity
     * @param numClusters
     *          desired number of clusters to create
     * @param samplingRate
     *          percentage of all cluster-cluster pairs to consider when finding next-most-similar clusters.
     *          Decreasing this value from 1.0 can increase performance at the cost of accuracy
     * @throws IllegalArgumentException
     *           if arguments are <code>null</code>, or <code>numClusters</code> is less than 2, or samplingRate
     *           is {@link Double#NaN} or nonpositive or greater than 1.0
     */
    public TreeClusteringRecommender(DataModel dataModel, ClusterSimilarity clusterSimilarity, int numClusters,
            double samplingRate) throws TasteException {
        super(dataModel);
        Preconditions.checkArgument(clusterSimilarity != null, "clusterSimilarity is null");
        Preconditions.checkArgument(numClusters >= 2, "numClusters must be at least 2");
        Preconditions.checkArgument(samplingRate > 0.0 && samplingRate <= 1.0, "samplingRate is invalid: %f",
                samplingRate);
        this.clusterSimilarity = clusterSimilarity;
        this.numClusters = numClusters;
        this.clusteringThreshold = Double.NaN;
        this.clusteringByThreshold = false;
        this.samplingRate = samplingRate;
        this.refreshHelper = new RefreshHelper(new Callable<Object>() {
            @Override
            public Object call() throws TasteException {
                buildClusters();
                return null;
            }
        });
        refreshHelper.addDependency(dataModel);
        refreshHelper.addDependency(clusterSimilarity);
        buildClusters();
    }

    /**
     * @param dataModel
     *          {@link DataModel} which provdes users
     * @param clusterSimilarity
     *          {@link ClusterSimilarity} used to compute cluster similarity
     * @param clusteringThreshold
     *          clustering similarity threshold; clusters will be aggregated into larger clusters until the next
     *          two nearest clusters' similarity drops below this threshold
     * @throws IllegalArgumentException
     *           if arguments are <code>null</code>, or <code>clusteringThreshold</code> is {@link Double#NaN}
     */
    public TreeClusteringRecommender(DataModel dataModel, ClusterSimilarity clusterSimilarity,
            double clusteringThreshold) throws TasteException {
        this(dataModel, clusterSimilarity, clusteringThreshold, 1.0);
    }

    /**
     * @param dataModel
     *          {@link DataModel} which provides users
     * @param clusterSimilarity
     *          {@link ClusterSimilarity} used to compute cluster similarity
     * @param clusteringThreshold
     *          clustering similarity threshold; clusters will be aggregated into larger clusters until the next
     *          two nearest clusters' similarity drops below this threshold
     * @param samplingRate
     *          percentage of all cluster-cluster pairs to consider when finding next-most-similar clusters.
     *          Decreasing this value from 1.0 can increase performance at the cost of accuracy
     * @throws IllegalArgumentException
     *           if arguments are <code>null</code>, or <code>clusteringThreshold</code> is {@link Double#NaN},
     *           or samplingRate is {@link Double#NaN} or nonpositive or greater than 1.0
     */
    public TreeClusteringRecommender(DataModel dataModel, ClusterSimilarity clusterSimilarity,
            double clusteringThreshold, double samplingRate) throws TasteException {
        super(dataModel);
        Preconditions.checkArgument(clusterSimilarity != null, "clusterSimilarity is null");
        Preconditions.checkArgument(!Double.isNaN(clusteringThreshold), "clusteringThreshold must not be NaN");
        Preconditions.checkArgument(samplingRate > 0.0 && samplingRate <= 1.0, "samplingRate is invalid: %f",
                samplingRate);
        this.clusterSimilarity = clusterSimilarity;
        this.numClusters = Integer.MIN_VALUE;
        this.clusteringThreshold = clusteringThreshold;
        this.clusteringByThreshold = true;
        this.samplingRate = samplingRate;
        this.refreshHelper = new RefreshHelper(new Callable<Object>() {
            @Override
            public Object call() throws TasteException {
                buildClusters();
                return null;
            }
        });
        refreshHelper.addDependency(dataModel);
        refreshHelper.addDependency(clusterSimilarity);
        buildClusters();
    }

    @Override
    public List<RecommendedItem> recommend(long userID, int howMany, IDRescorer rescorer) throws TasteException {
        Preconditions.checkArgument(howMany >= 1, "howMany must be at least 1");
        buildClusters();

        log.debug("Recommending items for user ID '{}'", userID);

        List<RecommendedItem> recommended = topRecsByUserID.get(userID);
        if (recommended == null) {
            return Collections.emptyList();
        }

        DataModel dataModel = getDataModel();
        List<RecommendedItem> rescored = new ArrayList<RecommendedItem>(recommended.size());
        // Only add items the user doesn't already have a preference for.
        // And that the rescorer doesn't "reject".
        for (RecommendedItem recommendedItem : recommended) {
            long itemID = recommendedItem.getItemID();
            if ((rescorer != null) && rescorer.isFiltered(itemID)) {
                continue;
            }
            if ((dataModel.getPreferenceValue(userID, itemID) == null) && ((rescorer == null)
                    || !Double.isNaN(rescorer.rescore(itemID, recommendedItem.getValue())))) {
                rescored.add(recommendedItem);
            }
        }
        Collections.sort(rescored, new ByRescoreComparator(rescorer));

        return rescored;
    }

    @Override
    public List<RecommendedItem> recommend(long l, int i, IDRescorer idRescorer, boolean b) throws TasteException {
        return recommend(l, i, idRescorer);
    }

    @Override
    public float estimatePreference(long userID, long itemID) throws TasteException {
        DataModel model = getDataModel();
        Float actualPref = model.getPreferenceValue(userID, itemID);
        if (actualPref != null) {
            return actualPref;
        }
        buildClusters();
        List<RecommendedItem> topRecsForUser = topRecsByUserID.get(userID);
        if (topRecsForUser != null) {
            for (RecommendedItem item : topRecsForUser) {
                if (itemID == item.getItemID()) {
                    return item.getValue();
                }
            }
        }
        // Hmm, we have no idea. The item is not in the user's cluster
        return Float.NaN;
    }

    @Override
    public FastIDSet getCluster(long userID) throws TasteException {
        buildClusters();
        FastIDSet cluster = clustersByUserID.get(userID);
        return cluster == null ? new FastIDSet() : cluster;
    }

    @Override
    public FastIDSet[] getClusters() throws TasteException {
        buildClusters();
        return allClusters;
    }

    private void buildClusters() throws TasteException {
        DataModel model = getDataModel();
        int numUsers = model.getNumUsers();
        if (numUsers > 0) {
            List<FastIDSet> newClusters = new ArrayList<FastIDSet>(numUsers);
            // Begin with a cluster for each user:
            LongPrimitiveIterator it = model.getUserIDs();
            while (it.hasNext()) {
                FastIDSet newCluster = new FastIDSet();
                newCluster.add(it.nextLong());
                newClusters.add(newCluster);
            }
            if (numUsers > 1) {
                findClusters(newClusters);
            }
            topRecsByUserID = computeTopRecsPerUserID(newClusters);
            clustersByUserID = computeClustersPerUserID(newClusters);
            allClusters = newClusters.toArray(new FastIDSet[newClusters.size()]);
        } else {
            topRecsByUserID = new FastByIDMap<List<RecommendedItem>>();
            clustersByUserID = new FastByIDMap<FastIDSet>();
            allClusters = NO_CLUSTERS;
        }
    }

    private void findClusters(List<FastIDSet> newClusters) throws TasteException {
        if (clusteringByThreshold) {
            Pair<FastIDSet, FastIDSet> nearestPair = findNearestClusters(newClusters);
            if (nearestPair != null) {
                FastIDSet cluster1 = nearestPair.getFirst();
                FastIDSet cluster2 = nearestPair.getSecond();
                while (clusterSimilarity.getSimilarity(cluster1, cluster2) >= clusteringThreshold) {
                    newClusters.remove(cluster1);
                    newClusters.remove(cluster2);
                    FastIDSet merged = new FastIDSet(cluster1.size() + cluster2.size());
                    merged.addAll(cluster1);
                    merged.addAll(cluster2);
                    newClusters.add(merged);
                    nearestPair = findNearestClusters(newClusters);
                    if (nearestPair == null) {
                        break;
                    }
                    cluster1 = nearestPair.getFirst();
                    cluster2 = nearestPair.getSecond();
                }
            }
        } else {
            while (newClusters.size() > numClusters) {
                Pair<FastIDSet, FastIDSet> nearestPair = findNearestClusters(newClusters);
                if (nearestPair == null) {
                    break;
                }
                FastIDSet cluster1 = nearestPair.getFirst();
                FastIDSet cluster2 = nearestPair.getSecond();
                newClusters.remove(cluster1);
                newClusters.remove(cluster2);
                FastIDSet merged = new FastIDSet(cluster1.size() + cluster2.size());
                merged.addAll(cluster1);
                merged.addAll(cluster2);
                newClusters.add(merged);
            }
        }
    }

    private Pair<FastIDSet, FastIDSet> findNearestClusters(List<FastIDSet> clusters) throws TasteException {
        int size = clusters.size();
        Pair<FastIDSet, FastIDSet> nearestPair = null;
        double bestSimilarity = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < size; i++) {
            FastIDSet cluster1 = clusters.get(i);
            for (int j = i + 1; j < size; j++) {
                if ((samplingRate >= 1.0) || (RANDOM.nextDouble() < samplingRate)) {
                    FastIDSet cluster2 = clusters.get(j);
                    double similarity = clusterSimilarity.getSimilarity(cluster1, cluster2);
                    if (!Double.isNaN(similarity) && (similarity > bestSimilarity)) {
                        bestSimilarity = similarity;
                        nearestPair = new Pair<FastIDSet, FastIDSet>(cluster1, cluster2);
                    }
                }
            }
        }
        return nearestPair;
    }

    private FastByIDMap<List<RecommendedItem>> computeTopRecsPerUserID(Iterable<FastIDSet> clusters)
            throws TasteException {
        FastByIDMap<List<RecommendedItem>> recsPerUser = new FastByIDMap<List<RecommendedItem>>();
        for (FastIDSet cluster : clusters) {
            List<RecommendedItem> recs = computeTopRecsForCluster(cluster);
            LongPrimitiveIterator it = cluster.iterator();
            while (it.hasNext()) {
                recsPerUser.put(it.next(), recs);
            }
        }
        return recsPerUser;
    }

    private List<RecommendedItem> computeTopRecsForCluster(FastIDSet cluster) throws TasteException {
        DataModel dataModel = getDataModel();
        FastIDSet possibleItemIDs = new FastIDSet();
        LongPrimitiveIterator it = cluster.iterator();
        while (it.hasNext()) {
            possibleItemIDs.addAll(dataModel.getItemIDsFromUser(it.next()));
        }

        TopItems.Estimator<Long> estimator = new Estimator(cluster);

        List<RecommendedItem> topItems = TopItems.getTopItems(possibleItemIDs.size(), possibleItemIDs.iterator(),
                null, estimator);

        log.debug("Recommendations are: {}", topItems);
        return Collections.unmodifiableList(topItems);
    }

    private static FastByIDMap<FastIDSet> computeClustersPerUserID(Collection<FastIDSet> clusters) {
        FastByIDMap<FastIDSet> clustersPerUser = new FastByIDMap<FastIDSet>(clusters.size());
        for (FastIDSet cluster : clusters) {
            LongPrimitiveIterator it = cluster.iterator();
            while (it.hasNext()) {
                clustersPerUser.put(it.next(), cluster);
            }
        }
        return clustersPerUser;
    }

    @Override
    public void refresh(Collection<Refreshable> alreadyRefreshed) {
        refreshHelper.refresh(alreadyRefreshed);
    }

    @Override
    public String toString() {
        return "TreeClusteringRecommender[clusterSimilarity:" + clusterSimilarity + ']';
    }

    private final class Estimator implements TopItems.Estimator<Long> {

        private final FastIDSet cluster;

        private Estimator(FastIDSet cluster) {
            this.cluster = cluster;
        }

        @Override
        public double estimate(Long itemID) throws TasteException {
            DataModel dataModel = getDataModel();
            RunningAverage average = new FullRunningAverage();
            LongPrimitiveIterator it = cluster.iterator();
            while (it.hasNext()) {
                Float pref = dataModel.getPreferenceValue(it.next(), itemID);
                if (pref != null) {
                    average.addDatum(pref);
                }
            }
            return average.getAverage();
        }
    }
}