org.apache.mahout.utils.eval.InMemoryFactorizationEvaluator.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.mahout.utils.eval.InMemoryFactorizationEvaluator.java

Source

/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.utils.eval;

import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.util.ToolRunner;
import org.apache.mahout.cf.taste.hadoop.TasteHadoopUtils;
import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
import org.apache.mahout.cf.taste.impl.model.GenericPreference;
import org.apache.mahout.cf.taste.model.Preference;
import org.apache.mahout.common.AbstractJob;
import org.apache.mahout.common.IOUtils;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.SparseMatrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.io.Writer;
import java.nio.charset.Charset;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;

/**
 * <p>Measures the root-mean-squared error (RMSE) and the mean absolute error (MAE) of a rating matrix
 * factorization against a test set.</p>
 *
 * <p>The factorization matrices are read into memory, which makes this job pretty fast. If you get
 * OutOfMemoryErrors, use {@link ParallelFactorizationEvaluator} instead.</p>
 *
 * <p>Command line arguments specific to this class are:</p>
 *
 * <ol>
 * <li>--output (path): registered as an option, but not read by this class; all results are written to
 * standard output</li>
 * <li>--pairs (path): directory whose <code>part-*</code> files contain the test ratings, each line must be
 * userID,itemID,rating</li>
 * <li>--userFeatures (path): directory whose <code>part-*</code> sequence files hold the user feature matrix</li>
 * <li>--itemFeatures (path): directory whose <code>part-*</code> sequence files hold the item feature matrix</li>
 * </ol>
 */
public class InMemoryFactorizationEvaluator extends AbstractJob {

    /**
     * Command line entry point, runs the evaluator through Hadoop's {@link ToolRunner}.
     *
     * @param args command line arguments, see the class documentation
     * @throws Exception if running the job fails
     */
    public static void main(String[] args) throws Exception {
        ToolRunner.run(new InMemoryFactorizationEvaluator(), args);
    }

    /**
     * Loads both feature matrices and all test ratings, estimates every test rating as the dot product of the
     * corresponding user and item feature rows and writes one line per rating plus a final RMSE/MAE summary line
     * to standard output.
     *
     * @param args command line arguments, see the class documentation
     * @return -1 if the arguments could not be parsed, 0 otherwise
     * @throws IllegalArgumentException if a user or item ID of a test rating does not fit into an int
     * @throws Exception if reading the input or writing the results fails
     */
    @Override
    public int run(String[] args) throws Exception {

        addOption("pairs", "p", "path containing the test ratings, each line must be userID,itemID,rating", true);
        addOption("userFeatures", "u", "path to the user feature matrix", true);
        addOption("itemFeatures", "i", "path to the item feature matrix", true);
        addOutputOption();

        Map<String, String> parsedArgs = parseArguments(args);
        if (parsedArgs == null) {
            return -1;
        }

        Path pairs = new Path(parsedArgs.get("--pairs"));
        Path userFeatures = new Path(parsedArgs.get("--userFeatures"));
        Path itemFeatures = new Path(parsedArgs.get("--itemFeatures"));

        Matrix u = readMatrix(userFeatures);
        Matrix m = readMatrix(itemFeatures);

        FullRunningAverage rmseAvg = new FullRunningAverage();
        FullRunningAverage maeAvg = new FullRunningAverage();
        // 1-based running number of the probe, only used for the report lines
        int pairsUsed = 1;
        Writer writer = new OutputStreamWriter(System.out);
        try {
            for (Preference pref : readProbePreferences(pairs)) {
                int userID = toMatrixIndex(pref.getUserID(), "user");
                int itemID = toMatrixIndex(pref.getItemID(), "item");

                double rating = pref.getValue();
                double estimate = u.getRow(userID).dot(m.getRow(itemID));
                double err = rating - estimate;
                // the squared errors are averaged here, the square root is taken once at the end
                rmseAvg.addDatum(err * err);
                maeAvg.addDatum(Math.abs(err));
                writer.write("Probe [" + pairsUsed + "], rating of user [" + userID + "] towards item [" + itemID
                        + "], " + "[" + rating + "] estimated [" + estimate + "]\n");
                pairsUsed++;
            }
            double rmse = Math.sqrt(rmseAvg.getAverage());
            double mae = maeAvg.getAverage();
            writer.write("RMSE: " + rmse + ", MAE: " + mae + "\n");
        } finally {
            // Only flush: the writer wraps System.out, closing it would close standard output for the
            // whole JVM and silence everything that prints after this job.
            writer.flush();
        }
        return 0;
    }

    /**
     * Narrows a preference ID to an int row index of the feature matrices.
     *
     * @param id the user or item ID as read from the test ratings
     * @param kind "user" or "item", only used in the error message
     * @return the ID as int
     * @throws IllegalArgumentException if the ID is outside the int range; a plain cast would silently wrap
     *         around and address an unrelated row
     */
    private static int toMatrixIndex(long id, String kind) {
        if (id < Integer.MIN_VALUE || id > Integer.MAX_VALUE) {
            throw new IllegalArgumentException("ID [" + id + "] of " + kind + " does not fit into an int and "
                    + "cannot be used as a row index of the feature matrix");
        }
        return (int) id;
    }

    /**
     * Lists the <code>part-*</code> files directly below the given directory.
     *
     * @param fs the filesystem the directory lives on
     * @param dir the directory to search
     * @return the matching files, an empty array if there are none (never null)
     * @throws IOException if the filesystem cannot be queried
     */
    private static FileStatus[] listPartFiles(FileSystem fs, Path dir) throws IOException {
        // FileSystem.globStatus is documented to return null in some no-match situations, normalize that
        // to an empty array so that callers can iterate without a NullPointerException
        FileStatus[] partFiles = fs.globStatus(new Path(dir, "part-*"));
        return partFiles == null ? new FileStatus[0] : partFiles;
    }

    /**
     * Reads a feature matrix from the <code>part-*</code> sequence files of a directory. Every record is
     * read as an {@link IntWritable} key holding the row index and a {@link VectorWritable} value holding the
     * row; only the non-zero elements of each row are copied into the result.
     *
     * @param dir directory containing the sequence files
     * @return a sparse in-memory matrix with all rows that were found
     * @throws IOException if a file cannot be read
     */
    private Matrix readMatrix(Path dir) throws IOException {

        // the real dimensions are not known upfront, so the sparse matrix is declared with the maximum size
        Matrix matrix = new SparseMatrix(new int[] { Integer.MAX_VALUE, Integer.MAX_VALUE });

        FileSystem fs = dir.getFileSystem(getConf());
        for (FileStatus seqFile : listPartFiles(fs, dir)) {
            Path path = seqFile.getPath();
            SequenceFile.Reader reader = null;
            try {
                reader = new SequenceFile.Reader(fs, path, getConf());
                IntWritable key = new IntWritable();
                VectorWritable value = new VectorWritable();
                while (reader.next(key, value)) {
                    int row = key.get();
                    Iterator<Vector.Element> elementsIterator = value.get().iterateNonZero();
                    while (elementsIterator.hasNext()) {
                        Vector.Element element = elementsIterator.next();
                        matrix.set(row, element.index(), element.get());
                    }
                }
            } finally {
                IOUtils.quietClose(reader);
            }
        }
        return matrix;
    }

    /**
     * Reads the test ratings from the <code>part-*</code> text files of a directory. The files are decoded as
     * UTF-8, every non-blank line is split with {@link TasteHadoopUtils#splitPrefTokens} and its first three
     * tokens are parsed as user ID, item ID and rating.
     *
     * @param dir directory containing the text files
     * @return all ratings in the order in which they were read
     * @throws IOException if a file cannot be read
     */
    private List<Preference> readProbePreferences(Path dir) throws IOException {

        List<Preference> preferences = new LinkedList<Preference>();
        FileSystem fs = dir.getFileSystem(getConf());
        for (FileStatus seqFile : listPartFiles(fs, dir)) {
            Path path = seqFile.getPath();
            InputStream in = null;
            try {
                in = fs.open(path);
                BufferedReader reader = new BufferedReader(new InputStreamReader(in, Charset.forName("UTF-8")));
                String line;
                while ((line = reader.readLine()) != null) {
                    // tolerate blank lines (e.g. a trailing newline) instead of failing while parsing them
                    if (line.trim().length() == 0) {
                        continue;
                    }
                    String[] tokens = TasteHadoopUtils.splitPrefTokens(line);
                    long userID = Long.parseLong(tokens[0]);
                    long itemID = Long.parseLong(tokens[1]);
                    float value = Float.parseFloat(tokens[2]);
                    preferences.add(new GenericPreference(userID, itemID, value));
                }
            } finally {
                // closing the underlying stream is sufficient, the reader holds no other resources
                IOUtils.quietClose(in);
            }
        }
        return preferences;
    }
}