// Java tutorial
/* * Copyright Myrrix Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package net.myrrix.online.eval; import java.io.File; import java.util.Collection; import java.util.concurrent.ExecutionException; import java.util.concurrent.atomic.AtomicInteger; import com.google.common.collect.Multimap; import org.apache.commons.math3.random.RandomGenerator; import org.apache.mahout.cf.taste.common.NoSuchItemException; import org.apache.mahout.cf.taste.common.NoSuchUserException; import org.apache.mahout.cf.taste.common.TasteException; import org.apache.mahout.cf.taste.recommender.RecommendedItem; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import net.myrrix.common.MyrrixRecommender; import net.myrrix.common.collection.FastByIDMap; import net.myrrix.common.collection.FastIDSet; import net.myrrix.common.parallel.Paralleler; import net.myrrix.common.parallel.Processor; import net.myrrix.common.random.RandomManager; import net.myrrix.common.random.RandomUtils; import net.myrrix.online.RescorerProvider; /** * <p>This implementation calculates Area under curve (AUC), which may be understood as the probability * that a random "good" recommendation is ranked higher than a random "bad" recommendation.</p> * * <p>This class can be run as a Java program; the single argument is a directory containing test data. 
* The {@link EvaluationResult} is printed to standard out.</p> * * @author Sean Owen * @since 1.0 */ public final class AUCEvaluator extends AbstractEvaluator { private static final Logger log = LoggerFactory.getLogger(AUCEvaluator.class); @Override protected boolean isSplitTestByPrefValue() { return true; } @Override public EvaluationResult evaluate(MyrrixRecommender recommender, RescorerProvider provider, // ignored Multimap<Long, RecommendedItem> testData) throws TasteException { FastByIDMap<FastIDSet> converted = new FastByIDMap<FastIDSet>(testData.size()); for (long userID : testData.keySet()) { Collection<RecommendedItem> userTestData = testData.get(userID); FastIDSet itemIDs = new FastIDSet(userTestData.size()); converted.put(userID, itemIDs); for (RecommendedItem datum : userTestData) { itemIDs.add(datum.getItemID()); } } return evaluate(recommender, converted); } public EvaluationResult evaluate(final MyrrixRecommender recommender, final FastByIDMap<FastIDSet> testData) throws TasteException { final AtomicInteger underCurve = new AtomicInteger(0); final AtomicInteger total = new AtomicInteger(0); final long[] allItemIDs = recommender.getAllItemIDs().toArray(); Processor<Long> processor = new Processor<Long>() { private final RandomGenerator random = RandomManager.getRandom(); @Override public void process(Long userID, long count) throws ExecutionException { FastIDSet testItemIDs = testData.get(userID); int numTest = testItemIDs.size(); for (int i = 0; i < numTest; i++) { long randomTestItemID; long randomTrainingItemID; synchronized (random) { randomTestItemID = RandomUtils.randomFrom(testItemIDs, random); do { randomTrainingItemID = allItemIDs[random.nextInt(allItemIDs.length)]; } while (testItemIDs.contains(randomTrainingItemID)); } float relevantEstimate; float nonRelevantEstimate; try { relevantEstimate = recommender.estimatePreference(userID, randomTestItemID); nonRelevantEstimate = recommender.estimatePreference(userID, randomTrainingItemID); } catch 
(NoSuchItemException nsie) { // OK; it's possible item only showed up in test split continue; } catch (NoSuchUserException nsie) { // OK; it's possible user only showed up in test split continue; } catch (TasteException te) { throw new ExecutionException(te); } if (relevantEstimate > nonRelevantEstimate) { underCurve.incrementAndGet(); } total.incrementAndGet(); if (count % 100000 == 0) { log.info("AUC: {}", (double) underCurve.get() / total.get()); } } } }; try { new Paralleler<Long>(testData.keySetIterator(), processor, "AUCEval").runInParallel(); } catch (InterruptedException ie) { throw new TasteException(ie); } catch (ExecutionException e) { throw new TasteException(e.getCause()); } double score = (double) underCurve.get() / total.get(); log.info("AUC: {}", score); return new EvaluationResultImpl(score); } public static void main(String[] args) throws Exception { EvaluationResult result = new AUCEvaluator().evaluate(new File(args[0])); log.info(String.valueOf(result)); } }