com.twitter.algebra.nmf.ErrDMJ.java Source code

Java tutorial

Introduction

Here is the source code for com.twitter.algebra.nmf.ErrDMJ.java

Source

/*
Copyright 2014 Twitter, Inc.
    
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    
http://www.apache.org/licenses/LICENSE-2.0
    
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package com.twitter.algebra.nmf;

import java.io.IOException;
import java.util.Iterator;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.io.WritableComparable;
import org.apache.hadoop.mapreduce.Counters;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.input.MultipleInputs;
import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.mahout.common.AbstractJob;
import org.apache.mahout.math.CardinalityException;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.SequentialAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.Vector.Element;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.hadoop.DistributedRowMatrix;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.twitter.algebra.AlgebraCommon;
import com.twitter.algebra.matrix.format.MapDir;
import com.twitter.algebra.matrix.format.MatrixOutputFormat;

/**
 * Computes the reconstruction error {@code | X - A * Y |} of a factorization as a single MapReduce job.
 *
 * <p>Each mapper:
 * <ul>
 * <li>reads rows of {@code A};</li>
 * <li>looks up the row of {@code X} with the same index from a {@link MapDir};</li>
 * <li>multiplies the {@code A} row with {@code Yt} (loaded in memory in {@code setup}) via
 * {@code AlgebraCommon.vectorTimesMatrixTranspose};</li>
 * <li>emits the element-wise squared difference to the {@code X} row.</li>
 * </ul>
 *
 * <p>All mapper outputs use the same key (0), so the combiner/reducer sums them into one vector.
 * Entry {@code j} of that vector is the summed squared error of column {@code j} of {@code X}.
 *
 * @author myabandeh
 */
public class ErrDMJ extends AbstractJob {
    private static final Logger log = LoggerFactory.getLogger(ErrDMJ.class);

    public static final String MAPDIRMATRIXX = "mapDirMatrixX";
    public static final String MAPDIRMATRIXYT = "mapDirMatrixYt";
    public static final String YTROWS = "YtRows";
    public static final String YTCOLS = "YtCols";

    /**
     * Runs the error job (unless its output directory already exists) and logs per-column error
     * statistics computed from the job output.
     *
     * @param conf Hadoop configuration
     * @param X the original matrix
     * @param xColSumVec per-column normalizers for the column errors; presumably the sum of squared
     *        values of each column of X (the log labels it "sum(val^2)") -- TODO confirm with caller
     * @param A left factor; must have as many rows as X
     * @param Yt transpose of the right factor; must have as many rows as X has columns and as many
     *        columns as A
     * @param label name of the output directory created under A's temp path
     * @return the job's "Result"/"sumAbs" counter, i.e. the sum over mappers of the integer part of
     *         each mapper's total squared error; -1 if the output already existed and the job was
     *         skipped
     * @throws CardinalityException if the matrix dimensions are incompatible
     * @throws IOException if the job fails or the file system cannot be accessed
     */
    public static long run(Configuration conf, DistributedRowMatrix X, Vector xColSumVec, DistributedRowMatrix A,
            DistributedRowMatrix Yt, String label)
            throws IOException, InterruptedException, ClassNotFoundException {
        log.info("running " + ErrDMJ.class.getName());
        if (X.numRows() != A.numRows()) {
            throw new CardinalityException(X.numRows(), A.numRows());
        }
        if (A.numCols() != Yt.numCols()) {
            throw new CardinalityException(A.numCols(), Yt.numCols());
        }
        if (X.numCols() != Yt.numRows()) {
            throw new CardinalityException(X.numCols(), Yt.numRows());
        }
        Path outPath = new Path(A.getOutputTempPath(), label);
        FileSystem fs = FileSystem.get(outPath.toUri(), conf);
        long totalErr = -1;
        if (!fs.exists(outPath)) {
            ErrDMJ job = new ErrDMJ();
            Job hJob = job.run(conf, X.getRowPath(), A.getRowPath(), Yt.getRowPath(), outPath, A.numRows(),
                    Yt.numRows(), Yt.numCols());
            Counters counters = hJob.getCounters();
            totalErr = counters.findCounter("Result", "sumAbs").getValue();
            log.info("FINAL ERR is " + totalErr);
        } else {
            // A previous run already produced the output: reuse it, but no counters are available.
            log.warn("----------- Skip already exists: " + outPath);
        }
        Vector sumErrVec = AlgebraCommon.mapDirToSparseVector(outPath, 1, X.numCols(), conf);
        logColumnErrors(sumErrVec, xColSumVec);
        return totalErr;
    }

    /**
     * Logs the relative error {@code sqrt(sumErr / normalizer)} of every column with a non-zero summed
     * squared error, then the maximum and the average over those columns.
     *
     * <p>A column whose normalizer is 0 yields an infinite error, which is logged as is. With no
     * such columns, both statistics are reported as 0.
     */
    private static void logColumnErrors(Vector sumErrVec, Vector xColSumVec) {
        // Column errors are square roots, hence never negative; 0 is a safe starting maximum.
        double maxColErr = 0;
        double sumColErr = 0;
        int cntColErr = 0;
        for (Vector.Element el : sumErrVec.nonZeroes()) {
            double errP2 = el.get();
            double origP2 = xColSumVec.get(el.index());
            double colErr = Math.sqrt(errP2 / origP2);
            log.info("col: " + el.index() + " sum(err^2): " + errP2 + " sum(val^2): " + origP2 + " colErr: "
                    + colErr);
            maxColErr = Math.max(colErr, maxColErr);
            sumColErr += colErr;
            cntColErr++;
        }
        double avgColErr = cntColErr == 0 ? 0 : sumColErr / cntColErr;
        log.info(" Max Col Err: " + maxColErr);
        log.info(" Avg Col Err: " + avgColErr);
    }

    /**
     * Configures and runs the MapReduce job, blocking until it finishes.
     *
     * @param conf base configuration; copied, not modified
     * @param xPath MapDir holding the rows of X
     * @param matrixAInputPath sequence-file input holding the rows of A
     * @param ytPath MapDir holding the rows of Yt
     * @param outPath output directory
     * @param aRows number of rows of A (currently unused)
     * @param ytRows number of rows of Yt
     * @param ytCols number of columns of Yt
     * @return the completed job, so callers can read its counters
     * @throws IOException if the job fails
     */
    public Job run(Configuration conf, Path xPath, Path matrixAInputPath, Path ytPath, Path outPath, int aRows,
            int ytRows, int ytCols) throws IOException, InterruptedException, ClassNotFoundException {
        conf = new Configuration(conf);

        conf.set(MAPDIRMATRIXX, xPath.toString());
        conf.set(MAPDIRMATRIXYT, ytPath.toString());
        conf.setInt(YTROWS, ytRows);
        conf.setInt(YTCOLS, ytCols);
        FileSystem fs = FileSystem.get(outPath.toUri(), conf);
        NMFCommon.setNumberOfMapSlots(conf, fs, matrixAInputPath, "err");

        @SuppressWarnings("deprecation")
        Job job = new Job(conf);
        job.setJarByClass(ErrDMJ.class);
        job.setJobName(ErrDMJ.class.getSimpleName() + "-" + outPath.getName());

        matrixAInputPath = fs.makeQualified(matrixAInputPath);
        MultipleInputs.addInputPath(job, matrixAInputPath, SequenceFileInputFormat.class);

        outPath = fs.makeQualified(outPath);
        FileOutputFormat.setOutputPath(job, outPath);
        job.setMapperClass(MyMapper.class);
        job.setMapOutputKeyClass(IntWritable.class);
        job.setMapOutputValueClass(VectorWritable.class);

        // Every mapper emits under the same key, so a single reducer produces the single sum vector.
        int numReducers = 1;
        job.setNumReduceTasks(numReducers);
        job.setCombinerClass(SumVectorsReducer.class);
        job.setReducerClass(SumVectorsReducer.class);

        job.setOutputFormatClass(MatrixOutputFormat.class);
        job.setOutputKeyClass(IntWritable.class);
        job.setOutputValueClass(VectorWritable.class);
        // waitForCompletion submits the job itself when it has not been submitted yet.
        if (!job.waitForCompletion(true)) {
            throw new IOException("Job failed! ");
        }
        return job;
    }

    /**
     * For each row {@code i} of A, emits the element-wise squared difference between the product of
     * that row with {@code Yt} and row {@code i} of X, under the constant key 0.
     *
     * <p>Also accumulates the mapper's total squared error and publishes it as counters on cleanup.
     */
    public static class MyMapper extends Mapper<IntWritable, VectorWritable, IntWritable, VectorWritable> {
        private MapDir xMapDir;
        private Matrix ytMatrix;
        private VectorWritable xVectorw = new VectorWritable();
        private VectorWritable outvw = new VectorWritable();
        // Constant output key: all error vectors are summed into one row by the reducer.
        private IntWritable iw = new IntWritable(0);
        // Running sum of squared errors over all rows processed by this mapper.
        double totalDiff = 0;
        // Reused product buffer, allocated lazily on the first row.
        // NOTE(review): assumes vectorTimesMatrixTranspose overwrites/clears it on each call -- confirm.
        Vector resVector = null;

        /** Opens the X MapDir and loads the whole Yt matrix into memory. */
        @Override
        public void setup(Context context) throws IOException {
            Configuration conf = context.getConfiguration();
            Path mapDirMatrixXPath = new Path(conf.get(MAPDIRMATRIXX));
            xMapDir = new MapDir(conf, mapDirMatrixXPath);
            Path mapDirMatrixYtPath = new Path(conf.get(MAPDIRMATRIXYT));
            int ytRows = conf.getInt(YTROWS, 0);
            int ytCols = conf.getInt(YTCOLS, 0);
            ytMatrix = AlgebraCommon.mapDirToSparseMatrix(mapDirMatrixYtPath, ytRows, ytCols, conf);
        }

        @Override
        public void map(IntWritable index, VectorWritable avw, Context context)
                throws IOException, InterruptedException {
            Vector av = avw.get();
            Writable xvw = xMapDir.get(index, xVectorw);
            if (xvw == null) {
                // too many nulls could indicate a bug, good to check
                context.getCounter("MapDir", "nullValues").increment(1);
                return;
            }
            Vector xv = xVectorw.get();
            if (resVector == null) {
                resVector = new RandomAccessSparseVector(ytMatrix.numRows());
            }
            AlgebraCommon.vectorTimesMatrixTranspose(av, ytMatrix, resVector);
            Vector errVector = resVector.minus(xv);
            // Square each non-zero difference in place; zero differences remain zero.
            for (Vector.Element el : errVector.nonZeroes()) {
                double diff = el.get();
                errVector.set(el.index(), diff * diff);
            }
            // Every entry is a square, so the sum is already non-negative.
            totalDiff += errVector.zSum();
            outvw.set(errVector);
            context.write(iw, outvw);
        }

        /**
         * Closes the X MapDir and publishes the mapper's total squared error.
         *
         * <p>Counters are integral, so the total is published at three precisions: micro, milli and
         * units.
         */
        @Override
        public void cleanup(Context context) throws IOException {
            xMapDir.close();
            // long, not int: an int micro counter overflows once totalDiff exceeds ~2147.
            long microDiff = (long) (totalDiff * 1000 * 1000);
            log.info("totalDiff " + totalDiff + " microDiff " + microDiff);
            context.getCounter("Result", "sumAbsMicro").increment(microDiff);
            context.getCounter("Result", "sumAbsMilli").increment((long) (totalDiff * 1000));
            context.getCounter("Result", "sumAbs").increment((long) totalDiff);
        }
    }

    /** Sums all vectors sharing a key; used both as the combiner and as the reducer. */
    static public class SumVectorsReducer
            extends Reducer<WritableComparable<?>, VectorWritable, WritableComparable<?>, VectorWritable> {
        @Override
        public void reduce(WritableComparable<?> key, Iterable<VectorWritable> vectors, Context context)
                throws IOException, InterruptedException {
            Vector merged = sumToVector(vectors.iterator());
            context.write(key, new VectorWritable(new SequentialAccessSparseVector(merged)));
        }

        /**
         * Adds up the vectors returned by {@code vectors}, which must yield at least one element.
         *
         * <p>The first vector is copied into a {@link RandomAccessSparseVector}. That keeps inserting at
         * new indices cheap even when the inputs are {@link SequentialAccessSparseVector}s (as emitted
         * by the combiner). Inserting into a sequential vector shifts its arrays on every new index.
         */
        Vector sumToVector(Iterator<VectorWritable> vectors) {
            Vector accumulator = new RandomAccessSparseVector(vectors.next().get());
            while (vectors.hasNext()) {
                VectorWritable v = vectors.next();
                if (v != null) {
                    for (Element nonZeroElement : v.get().nonZeroes()) {
                        int i = nonZeroElement.index();
                        double preVal = accumulator.get(i);
                        accumulator.setQuick(i, preVal + nonZeroElement.get());
                    }
                }
            }
            return accumulator;
        }
    }

    /** Command-line entry point; not supported -- use {@link #run(Configuration, DistributedRowMatrix, Vector, DistributedRowMatrix, DistributedRowMatrix, String)}. */
    @Override
    public int run(String[] args) throws Exception {
        throw new UnsupportedOperationException("Not implemented yet");
    }
}