Java tutorial
/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.utils.eval;

import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.util.ToolRunner;
import org.apache.mahout.cf.taste.hadoop.TasteHadoopUtils;
import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
import org.apache.mahout.cf.taste.impl.model.GenericPreference;
import org.apache.mahout.cf.taste.model.Preference;
import org.apache.mahout.common.AbstractJob;
import org.apache.mahout.common.IOUtils;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.SparseMatrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.io.Writer;
import java.nio.charset.Charset;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;

/**
 * <p>Measures the root-mean-squared error (and the mean absolute error) of a rating matrix
 * factorization against a test set.</p>
 *
 * <p>The factorization matrices are read into memory, which makes this job pretty fast. If you get
 * OutOfMemoryErrors, use {@link ParallelFactorizationEvaluator} instead.</p>
 *
 * <p>Command line arguments specific to this class are:</p>
 *
 * <ol>
 * <li>--output (path): path where output should go</li>
 * <li>--pairs (path): path containing the test ratings, each line must be userID,itemID,rating</li>
 * <li>--userFeatures (path): path to the user feature matrix</li>
 * <li>--itemFeatures (path): path to the item feature matrix</li>
 * </ol>
 *
 * <p>Note: the per-probe report and the final RMSE/MAE line are written to standard output. The
 * --output option is registered, but its value is not read by this class.</p>
 */
public class InMemoryFactorizationEvaluator extends AbstractJob {

  /**
   * Command line entry point; hands the arguments to {@link ToolRunner}.
   *
   * @param args command line arguments, see the class documentation
   * @throws Exception if the job fails
   */
  public static void main(String[] args) throws Exception {
    ToolRunner.run(new InMemoryFactorizationEvaluator(), args);
  }

  /**
   * Loads both feature matrices, estimates every test rating as the dot product of the matching
   * user and item rows, and prints each probe followed by the overall RMSE and MAE.
   *
   * @param args command line arguments, see the class documentation
   * @return -1 if the arguments could not be parsed, 0 otherwise
   * @throws Exception if reading the input or writing the report fails
   */
  @Override
  public int run(String[] args) throws Exception {

    addOption("pairs", "p", "path containing the test ratings, each line must be userID,itemID,rating", true);
    addOption("userFeatures", "u", "path to the user feature matrix", true);
    addOption("itemFeatures", "i", "path to the item feature matrix", true);
    addOutputOption();

    Map<String, String> parsedArgs = parseArguments(args);
    if (parsedArgs == null) {
      return -1;
    }

    Path pairs = new Path(parsedArgs.get("--pairs"));
    Path userFeatures = new Path(parsedArgs.get("--userFeatures"));
    Path itemFeatures = new Path(parsedArgs.get("--itemFeatures"));

    Matrix u = readMatrix(userFeatures);
    Matrix m = readMatrix(itemFeatures);

    // rmseAvg collects squared errors (square-rooted at the end), maeAvg collects absolute errors
    FullRunningAverage rmseAvg = new FullRunningAverage();
    FullRunningAverage maeAvg = new FullRunningAverage();

    // 1-based number of the probe being reported
    int pairsUsed = 1;
    Writer writer = new OutputStreamWriter(System.out);
    try {
      for (Preference pref : readProbePreferences(pairs)) {
        // matrix rows are addressed by int, so the long IDs are narrowed here
        int userID = (int) pref.getUserID();
        int itemID = (int) pref.getItemID();
        double rating = pref.getValue();
        double estimate = u.getRow(userID).dot(m.getRow(itemID));
        double err = rating - estimate;
        rmseAvg.addDatum(err * err);
        maeAvg.addDatum(Math.abs(err));
        writer.write("Probe [" + pairsUsed + "], rating of user [" + userID + "] towards item [" + itemID + "], "
            + "[" + rating + "] estimated [" + estimate + "]\n");
        pairsUsed++;
      }
      double rmse = Math.sqrt(rmseAvg.getAverage());
      double mae = maeAvg.getAverage();
      writer.write("RMSE: " + rmse + ", MAE: " + mae + "\n");
    } finally {
      // Flush instead of closing: closing the writer would close System.out itself and silently
      // discard everything printed afterwards by other code running in the same JVM.
      writer.flush();
    }

    return 0;
  }

  /**
   * Reads a matrix from all "part-*" sequence files found directly under the given directory.
   * Each record's {@link IntWritable} key is used as the row index and the non-zero elements of
   * its {@link VectorWritable} value are copied into that row.
   *
   * @param dir directory containing the sequence files
   * @return a sparse matrix holding all entries that were read
   * @throws IOException if the files cannot be listed or read
   */
  private Matrix readMatrix(Path dir) throws IOException {

    Matrix matrix = new SparseMatrix(new int[] { Integer.MAX_VALUE, Integer.MAX_VALUE });

    FileSystem fs = dir.getFileSystem(getConf());
    for (FileStatus seqFile : fs.globStatus(new Path(dir, "part-*"))) {
      Path path = seqFile.getPath();
      SequenceFile.Reader reader = null;
      try {
        reader = new SequenceFile.Reader(fs, path, getConf());
        IntWritable key = new IntWritable();
        VectorWritable value = new VectorWritable();
        while (reader.next(key, value)) {
          int row = key.get();
          Iterator<Vector.Element> elementsIterator = value.get().iterateNonZero();
          while (elementsIterator.hasNext()) {
            Vector.Element element = elementsIterator.next();
            matrix.set(row, element.index(), element.get());
          }
        }
      } finally {
        IOUtils.quietClose(reader);
      }
    }
    return matrix;
  }

  /**
   * Reads the test ratings from all "part-*" text files found directly under the given directory.
   * Every line is split with {@link TasteHadoopUtils#splitPrefTokens} and its first three tokens
   * are parsed as user ID, item ID and rating.
   *
   * @param dir directory containing the UTF-8 encoded text files
   * @return the preferences in the order in which they were read
   * @throws IOException if the files cannot be listed or read
   */
  private List<Preference> readProbePreferences(Path dir) throws IOException {

    List<Preference> preferences = new LinkedList<Preference>();
    FileSystem fs = dir.getFileSystem(getConf());
    for (FileStatus seqFile : fs.globStatus(new Path(dir, "part-*"))) {
      Path path = seqFile.getPath();
      InputStream in = null;
      try {
        in = fs.open(path);
        BufferedReader reader = new BufferedReader(new InputStreamReader(in, Charset.forName("UTF-8")));
        String line;
        while ((line = reader.readLine()) != null) {
          String[] tokens = TasteHadoopUtils.splitPrefTokens(line);
          long userID = Long.parseLong(tokens[0]);
          long itemID = Long.parseLong(tokens[1]);
          float value = Float.parseFloat(tokens[2]);
          preferences.add(new GenericPreference(userID, itemID, value));
        }
      } finally {
        // closing the underlying stream is sufficient, the readers hold no other resources
        IOUtils.quietClose(in);
      }
    }
    return preferences;
  }
}