net.myrrix.online.eval.PrecisionRecallEvaluator.java Source code

Java tutorial

Introduction

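Here is the source code for the class net.myrrix.online.eval.PrecisionRecallEvaluator (file PrecisionRecallEvaluator.java), which evaluates a Myrrix recommender by computing precision, recall, mean average precision and related statistics against held-out test data.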

Source

/*
 * Copyright Myrrix Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package net.myrrix.online.eval;

import java.io.File;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.ExecutionException;

import com.google.common.collect.Multimap;
import com.google.common.collect.Sets;
import org.apache.commons.math3.stat.descriptive.moment.Mean;
import org.apache.mahout.cf.taste.common.NoSuchUserException;
import org.apache.mahout.cf.taste.common.TasteException;
import org.apache.mahout.cf.taste.recommender.IDRescorer;
import org.apache.mahout.cf.taste.recommender.RecommendedItem;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import net.myrrix.common.MyrrixRecommender;
import net.myrrix.common.parallel.Paralleler;
import net.myrrix.common.parallel.Processor;
import net.myrrix.online.RescorerProvider;

/**
 * <p>A simple evaluation framework for a recommender, which calculates precision, recall, F1,
 * normalized discounted cumulative gain (nDCG), mean average precision, and other basic statistics.</p>
 *
 * <p>For each user in the test data, the recommender is asked for as many recommendations as the user
 * has distinct held-out items, and those recommendations are scored against the held-out items.
 * Per-user scores are averaged over every user that could be evaluated.</p>
 *
 * <p>This class can be run as a Java program; the single argument is a directory containing test data.
 * The resulting {@link EvaluationResult} is logged at INFO level.</p>
 *
 * @author Sean Owen
 * @since 1.0
 */
public final class PrecisionRecallEvaluator extends AbstractEvaluator {

    private static final Logger log = LoggerFactory.getLogger(PrecisionRecallEvaluator.class);

    /** Natural log of 2; dividing by a natural log and multiplying by this yields a base-2 log. */
    private static final double LN2 = Math.log(2.0);

    /** Running statistics are logged whenever the processed-user count is a multiple of this value. */
    private static final int LOG_INTERVAL = 10000;

    /**
     * Always returns {@code true}. How the flag is used is defined by {@link AbstractEvaluator};
     * presumably it makes the train/test split depend on preference values — NOTE(review): confirm there.
     */
    @Override
    protected boolean isSplitTestByPrefValue() {
        return true;
    }

    /**
     * Evaluates {@code recommender} against held-out {@code testData}.
     *
     * <p>For each test user, up to as many recommendations as the user has distinct test items are
     * requested, and the following per-user metrics are averaged over all evaluated users:</p>
     * <ul>
     *   <li>precision: fraction of returned recommendations that are test items (0 if none returned)</li>
     *   <li>recall: fraction of the user's distinct test items that were recommended</li>
     *   <li>nDCG: discounted gain of the hits, divided by the gain had every returned recommendation
     *       been a hit (0 if none returned)</li>
     *   <li>average precision: sum of precision@k over hit ranks k, divided by the number of distinct
     *       test items</li>
     * </ul>
     *
     * <p>Users unknown to the recommender, or for whom recommendation throws a {@link TasteException},
     * are logged and skipped. Users are processed in parallel when the system property
     * {@code eval.parallel} is unset or equal to {@code "true"} (ignoring case); otherwise serially.</p>
     *
     * @param recommender recommender to evaluate
     * @param provider optional source of per-user rescorers; may be {@code null}
     * @param testData held-out items, keyed by user ID
     * @return aggregated statistics, or {@code null} if no user could be evaluated
     * @throws TasteException if evaluation is interrupted (the interrupt status is restored) or a worker fails
     */
    @Override
    public EvaluationResult evaluate(final MyrrixRecommender recommender, final RescorerProvider provider,
            final Multimap<Long, RecommendedItem> testData) throws TasteException {

        // Running per-user averages. All four are updated together under the lock on 'precision'
        // because the processor below may be invoked from several threads at once.
        final Mean precision = new Mean();
        final Mean recall = new Mean();
        final Mean ndcg = new Mean();
        final Mean meanAveragePrecision = new Mean();

        Processor<Long> processor = new Processor<Long>() {
            @Override
            public void process(Long userID, long count) {

                Collection<RecommendedItem> values = testData.get(userID);

                // Hits are counted against distinct item IDs, so the recall / AP denominators (and the
                // number of recommendations requested) must also count distinct IDs; otherwise repeated
                // test entries for the same item would make perfect recall unreachable.
                Collection<Long> valueIDs = Sets.newHashSetWithExpectedSize(values.size());
                for (RecommendedItem rec : values) {
                    valueIDs.add(rec.getItemID());
                }
                int numValues = valueIDs.size();
                if (numValues == 0) {
                    return;
                }

                IDRescorer rescorer = provider == null ? null
                        : provider.getRecommendRescorer(new long[] { userID }, recommender);

                List<RecommendedItem> recs;
                try {
                    recs = recommender.recommend(userID, numValues, rescorer);
                } catch (NoSuchUserException nsue) {
                    // Probably OK, just removed all data for this user from training
                    log.warn("User only in test data: {}", userID);
                    return;
                } catch (TasteException te) {
                    log.warn("Unexpected exception", te);
                    return;
                }
                int numRecs = recs.size();

                int intersectionSize = 0;
                double score = 0.0;     // DCG of the hits among the returned recommendations
                double maxScore = 0.0;  // DCG if every returned recommendation had been a hit
                double averagePrecision = 0.0;

                for (int i = 0; i < numRecs; i++) {
                    double value = LN2 / Math.log(2.0 + i); // 1 / log_2(1 + (i+1))
                    if (valueIDs.contains(recs.get(i).getItemID())) {
                        intersectionSize++;
                        score += value;
                        // precision@(i+1): hits so far over recommendations so far, summed at hit ranks only
                        averagePrecision += (double) intersectionSize / (i + 1);
                    }
                    maxScore += value;
                }
                averagePrecision /= numValues;

                synchronized (precision) {
                    precision.increment(numRecs == 0 ? 0.0 : (double) intersectionSize / numRecs);
                    recall.increment((double) intersectionSize / numValues);
                    ndcg.increment(maxScore == 0.0 ? 0.0 : score / maxScore);
                    meanAveragePrecision.increment(averagePrecision);
                    if (count % LOG_INTERVAL == 0) {
                        log.info("{}", new IRStatisticsImpl(precision.getResult(), recall.getResult(),
                                ndcg.getResult(), meanAveragePrecision.getResult()));
                    }
                }
            }
        };

        Paralleler<Long> paralleler = new Paralleler<Long>(testData.keySet().iterator(), processor, "PREval");
        try {
            if (Boolean.parseBoolean(System.getProperty("eval.parallel", "true"))) {
                paralleler.runInParallel();
            } else {
                paralleler.runInSerial();
            }
        } catch (InterruptedException ie) {
            // Restore the interrupt status so code further up the stack can still observe it
            Thread.currentThread().interrupt();
            throw new TasteException(ie);
        } catch (ExecutionException e) {
            throw new TasteException(e.getCause());
        }

        // getN() == 0 means no user produced a score (all skipped or no test data)
        EvaluationResult result = precision.getN() > 0
                ? new IRStatisticsImpl(precision.getResult(), recall.getResult(), ndcg.getResult(),
                        meanAveragePrecision.getResult())
                : null;
        log.info("{}", result);
        return result;
    }

    /**
     * Command-line entry point: evaluates this evaluator against test data and logs the result.
     *
     * @param args a single argument, the directory containing test data; further arguments are ignored
     * @throws IllegalArgumentException if no argument is given
     * @throws Exception if the evaluation fails
     */
    public static void main(String[] args) throws Exception {
        if (args.length < 1) {
            throw new IllegalArgumentException("Usage: PrecisionRecallEvaluator [test data directory]");
        }
        EvaluationResult result = new PrecisionRecallEvaluator().evaluate(new File(args[0]));
        log.info("{}", result);
    }

}