org.apache.mahout.math.hadoop.similarity.cooccurrence.RowSimilarityJob.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.mahout.math.hadoop.similarity.cooccurrence.RowSimilarityJob.java

Source

/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.math.hadoop.similarity.cooccurrence;

import com.google.common.base.Preconditions;
import com.google.common.primitives.Ints;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.util.ToolRunner;
import org.apache.mahout.common.AbstractJob;
import org.apache.mahout.common.ClassUtils;
import org.apache.mahout.common.HadoopUtil;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.common.commandline.DefaultOptionCreator;
import org.apache.mahout.common.mapreduce.VectorSumCombiner;
import org.apache.mahout.common.mapreduce.VectorSumReducer;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.Vector.Element;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.hadoop.similarity.cooccurrence.measures.VectorSimilarityMeasures;
import org.apache.mahout.math.hadoop.similarity.cooccurrence.measures.VectorSimilarityMeasure;
import org.apache.mahout.math.map.OpenIntIntHashMap;

import java.io.IOException;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.atomic.AtomicInteger;

public class RowSimilarityJob extends AbstractJob {

    public static final double NO_THRESHOLD = Double.MIN_VALUE;
    public static final long NO_FIXED_RANDOM_SEED = Long.MIN_VALUE;

    private static final String SIMILARITY_CLASSNAME = RowSimilarityJob.class + ".distributedSimilarityClassname";
    private static final String NUMBER_OF_COLUMNS = RowSimilarityJob.class + ".numberOfColumns";
    private static final String MAX_SIMILARITIES_PER_ROW = RowSimilarityJob.class + ".maxSimilaritiesPerRow";
    private static final String EXCLUDE_SELF_SIMILARITY = RowSimilarityJob.class + ".excludeSelfSimilarity";

    private static final String THRESHOLD = RowSimilarityJob.class + ".threshold";
    private static final String NORMS_PATH = RowSimilarityJob.class + ".normsPath";
    private static final String MAXVALUES_PATH = RowSimilarityJob.class + ".maxWeightsPath";

    private static final String NUM_NON_ZERO_ENTRIES_PATH = RowSimilarityJob.class + ".nonZeroEntriesPath";
    private static final int DEFAULT_MAX_SIMILARITIES_PER_ROW = 100;

    private static final String OBSERVATIONS_PER_COLUMN_PATH = RowSimilarityJob.class
            + ".observationsPerColumnPath";

    private static final String MAX_OBSERVATIONS_PER_ROW = RowSimilarityJob.class + ".maxObservationsPerRow";
    private static final String MAX_OBSERVATIONS_PER_COLUMN = RowSimilarityJob.class + ".maxObservationsPerColumn";
    private static final String RANDOM_SEED = RowSimilarityJob.class + ".randomSeed";

    private static final int DEFAULT_MAX_OBSERVATIONS_PER_ROW = 500;
    private static final int DEFAULT_MAX_OBSERVATIONS_PER_COLUMN = 500;

    private static final int NORM_VECTOR_MARKER = Integer.MIN_VALUE;
    private static final int MAXVALUE_VECTOR_MARKER = Integer.MIN_VALUE + 1;
    private static final int NUM_NON_ZERO_ENTRIES_VECTOR_MARKER = Integer.MIN_VALUE + 2;

    enum Counters {
        ROWS, USED_OBSERVATIONS, NEGLECTED_OBSERVATIONS, COOCCURRENCES, PRUNED_COOCCURRENCES
    }

    public static void main(String[] args) throws Exception {
        ToolRunner.run(new RowSimilarityJob(), args);
    }

    @Override
    public int run(String[] args) throws Exception {

        addInputOption();
        addOutputOption();
        addOption("numberOfColumns", "r", "Number of columns in the input matrix", false);
        addOption("similarityClassname", "s",
                "Name of distributed similarity class to instantiate, alternatively use "
                        + "one of the predefined similarities (" + VectorSimilarityMeasures.list() + ')');
        addOption("maxSimilaritiesPerRow", "m",
                "Number of maximum similarities per row (default: " + DEFAULT_MAX_SIMILARITIES_PER_ROW + ')',
                String.valueOf(DEFAULT_MAX_SIMILARITIES_PER_ROW));
        addOption("excludeSelfSimilarity", "ess", "compute similarity of rows to themselves?",
                String.valueOf(false));
        addOption("threshold", "tr", "discard row pairs with a similarity value below this", false);
        addOption("maxObservationsPerRow", null, "sample rows down to this number of entries",
                String.valueOf(DEFAULT_MAX_OBSERVATIONS_PER_ROW));
        addOption("maxObservationsPerColumn", null, "sample columns down to this number of entries",
                String.valueOf(DEFAULT_MAX_OBSERVATIONS_PER_COLUMN));
        addOption("randomSeed", null, "use this seed for sampling", false);
        addOption(DefaultOptionCreator.overwriteOption().create());

        Map<String, List<String>> parsedArgs = parseArguments(args);
        if (parsedArgs == null) {
            return -1;
        }

        int numberOfColumns;

        if (hasOption("numberOfColumns")) {
            // Number of columns explicitly specified via CLI
            numberOfColumns = Integer.parseInt(getOption("numberOfColumns"));
        } else {
            // else get the number of columns by determining the cardinality of a vector in the input matrix
            numberOfColumns = getDimensions(getInputPath());
        }

        String similarityClassnameArg = getOption("similarityClassname");
        String similarityClassname;
        try {
            similarityClassname = VectorSimilarityMeasures.valueOf(similarityClassnameArg).getClassname();
        } catch (IllegalArgumentException iae) {
            similarityClassname = similarityClassnameArg;
        }

        // Clear the output and temp paths if the overwrite option has been set
        if (hasOption(DefaultOptionCreator.OVERWRITE_OPTION)) {
            // Clear the temp path
            HadoopUtil.delete(getConf(), getTempPath());
            // Clear the output path
            HadoopUtil.delete(getConf(), getOutputPath());
        }

        int maxSimilaritiesPerRow = Integer.parseInt(getOption("maxSimilaritiesPerRow"));
        boolean excludeSelfSimilarity = Boolean.parseBoolean(getOption("excludeSelfSimilarity"));
        double threshold = hasOption("threshold") ? Double.parseDouble(getOption("threshold")) : NO_THRESHOLD;
        long randomSeed = hasOption("randomSeed") ? Long.parseLong(getOption("randomSeed")) : NO_FIXED_RANDOM_SEED;

        int maxObservationsPerRow = Integer.parseInt(getOption("maxObservationsPerRow"));
        int maxObservationsPerColumn = Integer.parseInt(getOption("maxObservationsPerColumn"));

        Path weightsPath = getTempPath("weights");
        Path normsPath = getTempPath("norms.bin");
        Path numNonZeroEntriesPath = getTempPath("numNonZeroEntries.bin");
        Path maxValuesPath = getTempPath("maxValues.bin");
        Path pairwiseSimilarityPath = getTempPath("pairwiseSimilarity");

        Path observationsPerColumnPath = getTempPath("observationsPerColumn.bin");

        AtomicInteger currentPhase = new AtomicInteger();

        Job countObservations = prepareJob(getInputPath(), getTempPath("notUsed"), CountObservationsMapper.class,
                NullWritable.class, VectorWritable.class, SumObservationsReducer.class, NullWritable.class,
                VectorWritable.class);
        countObservations.setCombinerClass(VectorSumCombiner.class);
        countObservations.getConfiguration().set(OBSERVATIONS_PER_COLUMN_PATH,
                observationsPerColumnPath.toString());
        countObservations.setNumReduceTasks(1);
        countObservations.waitForCompletion(true);

        if (shouldRunNextPhase(parsedArgs, currentPhase)) {
            Job normsAndTranspose = prepareJob(getInputPath(), weightsPath, VectorNormMapper.class,
                    IntWritable.class, VectorWritable.class, MergeVectorsReducer.class, IntWritable.class,
                    VectorWritable.class);
            normsAndTranspose.setCombinerClass(MergeVectorsCombiner.class);
            Configuration normsAndTransposeConf = normsAndTranspose.getConfiguration();
            normsAndTransposeConf.set(THRESHOLD, String.valueOf(threshold));
            normsAndTransposeConf.set(NORMS_PATH, normsPath.toString());
            normsAndTransposeConf.set(NUM_NON_ZERO_ENTRIES_PATH, numNonZeroEntriesPath.toString());
            normsAndTransposeConf.set(MAXVALUES_PATH, maxValuesPath.toString());
            normsAndTransposeConf.set(SIMILARITY_CLASSNAME, similarityClassname);
            normsAndTransposeConf.set(OBSERVATIONS_PER_COLUMN_PATH, observationsPerColumnPath.toString());
            normsAndTransposeConf.set(MAX_OBSERVATIONS_PER_ROW, String.valueOf(maxObservationsPerRow));
            normsAndTransposeConf.set(MAX_OBSERVATIONS_PER_COLUMN, String.valueOf(maxObservationsPerColumn));
            normsAndTransposeConf.set(RANDOM_SEED, String.valueOf(randomSeed));

            boolean succeeded = normsAndTranspose.waitForCompletion(true);
            if (!succeeded) {
                return -1;
            }
        }

        if (shouldRunNextPhase(parsedArgs, currentPhase)) {
            Job pairwiseSimilarity = prepareJob(weightsPath, pairwiseSimilarityPath, CooccurrencesMapper.class,
                    IntWritable.class, VectorWritable.class, SimilarityReducer.class, IntWritable.class,
                    VectorWritable.class);
            pairwiseSimilarity.setCombinerClass(VectorSumReducer.class);
            Configuration pairwiseConf = pairwiseSimilarity.getConfiguration();
            pairwiseConf.set(THRESHOLD, String.valueOf(threshold));
            pairwiseConf.set(NORMS_PATH, normsPath.toString());
            pairwiseConf.set(NUM_NON_ZERO_ENTRIES_PATH, numNonZeroEntriesPath.toString());
            pairwiseConf.set(MAXVALUES_PATH, maxValuesPath.toString());
            pairwiseConf.set(SIMILARITY_CLASSNAME, similarityClassname);
            pairwiseConf.setInt(NUMBER_OF_COLUMNS, numberOfColumns);
            pairwiseConf.setBoolean(EXCLUDE_SELF_SIMILARITY, excludeSelfSimilarity);
            boolean succeeded = pairwiseSimilarity.waitForCompletion(true);
            if (!succeeded) {
                return -1;
            }
        }

        if (shouldRunNextPhase(parsedArgs, currentPhase)) {
            Job asMatrix = prepareJob(pairwiseSimilarityPath, getOutputPath(), UnsymmetrifyMapper.class,
                    IntWritable.class, VectorWritable.class, MergeToTopKSimilaritiesReducer.class,
                    IntWritable.class, VectorWritable.class);
            asMatrix.setCombinerClass(MergeToTopKSimilaritiesReducer.class);
            asMatrix.getConfiguration().setInt(MAX_SIMILARITIES_PER_ROW, maxSimilaritiesPerRow);
            boolean succeeded = asMatrix.waitForCompletion(true);
            if (!succeeded) {
                return -1;
            }
        }

        return 0;
    }

    public static class CountObservationsMapper
            extends Mapper<IntWritable, VectorWritable, NullWritable, VectorWritable> {

        private Vector columnCounts = new RandomAccessSparseVector(Integer.MAX_VALUE);

        @Override
        protected void map(IntWritable rowIndex, VectorWritable rowVectorWritable, Context ctx)
                throws IOException, InterruptedException {

            Vector row = rowVectorWritable.get();
            for (Vector.Element elem : row.nonZeroes()) {
                columnCounts.setQuick(elem.index(), columnCounts.getQuick(elem.index()) + 1);
            }
        }

        @Override
        protected void cleanup(Context ctx) throws IOException, InterruptedException {
            ctx.write(NullWritable.get(), new VectorWritable(columnCounts));
        }
    }

    public static class SumObservationsReducer
            extends Reducer<NullWritable, VectorWritable, NullWritable, VectorWritable> {
        @Override
        protected void reduce(NullWritable nullWritable, Iterable<VectorWritable> partialVectors, Context ctx)
                throws IOException, InterruptedException {
            Vector counts = Vectors.sum(partialVectors.iterator());
            Vectors.write(counts, new Path(ctx.getConfiguration().get(OBSERVATIONS_PER_COLUMN_PATH)),
                    ctx.getConfiguration());
        }
    }

    public static class VectorNormMapper extends Mapper<IntWritable, VectorWritable, IntWritable, VectorWritable> {

        private VectorSimilarityMeasure similarity;
        private Vector norms;
        private Vector nonZeroEntries;
        private Vector maxValues;
        private double threshold;

        private OpenIntIntHashMap observationsPerColumn;
        private int maxObservationsPerRow;
        private int maxObservationsPerColumn;

        private Random random;

        @Override
        protected void setup(Context ctx) throws IOException, InterruptedException {

            Configuration conf = ctx.getConfiguration();

            similarity = ClassUtils.instantiateAs(conf.get(SIMILARITY_CLASSNAME), VectorSimilarityMeasure.class);
            norms = new RandomAccessSparseVector(Integer.MAX_VALUE);
            nonZeroEntries = new RandomAccessSparseVector(Integer.MAX_VALUE);
            maxValues = new RandomAccessSparseVector(Integer.MAX_VALUE);
            threshold = Double.parseDouble(conf.get(THRESHOLD));

            observationsPerColumn = Vectors.readAsIntMap(new Path(conf.get(OBSERVATIONS_PER_COLUMN_PATH)), conf);
            maxObservationsPerRow = conf.getInt(MAX_OBSERVATIONS_PER_ROW, DEFAULT_MAX_OBSERVATIONS_PER_ROW);
            maxObservationsPerColumn = conf.getInt(MAX_OBSERVATIONS_PER_COLUMN,
                    DEFAULT_MAX_OBSERVATIONS_PER_COLUMN);

            long seed = Long.parseLong(conf.get(RANDOM_SEED));
            if (seed == NO_FIXED_RANDOM_SEED) {
                random = RandomUtils.getRandom();
            } else {
                random = RandomUtils.getRandom(seed);
            }
        }

        private Vector sampleDown(Vector rowVector, Context ctx) {

            int observationsPerRow = rowVector.getNumNondefaultElements();
            double rowSampleRate = (double) Math.min(maxObservationsPerRow, observationsPerRow)
                    / (double) observationsPerRow;

            Vector downsampledRow = rowVector.like();
            long usedObservations = 0;
            long neglectedObservations = 0;

            for (Vector.Element elem : rowVector.nonZeroes()) {

                int columnCount = observationsPerColumn.get(elem.index());
                double columnSampleRate = (double) Math.min(maxObservationsPerColumn, columnCount)
                        / (double) columnCount;

                if (random.nextDouble() <= Math.min(rowSampleRate, columnSampleRate)) {
                    downsampledRow.setQuick(elem.index(), elem.get());
                    usedObservations++;
                } else {
                    neglectedObservations++;
                }

            }

            ctx.getCounter(Counters.USED_OBSERVATIONS).increment(usedObservations);
            ctx.getCounter(Counters.NEGLECTED_OBSERVATIONS).increment(neglectedObservations);

            return downsampledRow;
        }

        @Override
        protected void map(IntWritable row, VectorWritable vectorWritable, Context ctx)
                throws IOException, InterruptedException {

            Vector sampledRowVector = sampleDown(vectorWritable.get(), ctx);

            Vector rowVector = similarity.normalize(sampledRowVector);

            int numNonZeroEntries = 0;
            double maxValue = Double.MIN_VALUE;

            for (Vector.Element element : rowVector.nonZeroes()) {
                RandomAccessSparseVector partialColumnVector = new RandomAccessSparseVector(Integer.MAX_VALUE);
                partialColumnVector.setQuick(row.get(), element.get());
                ctx.write(new IntWritable(element.index()), new VectorWritable(partialColumnVector));

                numNonZeroEntries++;
                if (maxValue < element.get()) {
                    maxValue = element.get();
                }
            }

            if (threshold != NO_THRESHOLD) {
                nonZeroEntries.setQuick(row.get(), numNonZeroEntries);
                maxValues.setQuick(row.get(), maxValue);
            }
            norms.setQuick(row.get(), similarity.norm(rowVector));

            ctx.getCounter(Counters.ROWS).increment(1);
        }

        @Override
        protected void cleanup(Context ctx) throws IOException, InterruptedException {
            ctx.write(new IntWritable(NORM_VECTOR_MARKER), new VectorWritable(norms));
            ctx.write(new IntWritable(NUM_NON_ZERO_ENTRIES_VECTOR_MARKER), new VectorWritable(nonZeroEntries));
            ctx.write(new IntWritable(MAXVALUE_VECTOR_MARKER), new VectorWritable(maxValues));
        }
    }

    private static class MergeVectorsCombiner
            extends Reducer<IntWritable, VectorWritable, IntWritable, VectorWritable> {
        @Override
        protected void reduce(IntWritable row, Iterable<VectorWritable> partialVectors, Context ctx)
                throws IOException, InterruptedException {
            ctx.write(row, new VectorWritable(Vectors.merge(partialVectors)));
        }
    }

    public static class MergeVectorsReducer
            extends Reducer<IntWritable, VectorWritable, IntWritable, VectorWritable> {

        private Path normsPath;
        private Path numNonZeroEntriesPath;
        private Path maxValuesPath;

        @Override
        protected void setup(Context ctx) throws IOException, InterruptedException {
            normsPath = new Path(ctx.getConfiguration().get(NORMS_PATH));
            numNonZeroEntriesPath = new Path(ctx.getConfiguration().get(NUM_NON_ZERO_ENTRIES_PATH));
            maxValuesPath = new Path(ctx.getConfiguration().get(MAXVALUES_PATH));
        }

        @Override
        protected void reduce(IntWritable row, Iterable<VectorWritable> partialVectors, Context ctx)
                throws IOException, InterruptedException {
            Vector partialVector = Vectors.merge(partialVectors);

            if (row.get() == NORM_VECTOR_MARKER) {
                Vectors.write(partialVector, normsPath, ctx.getConfiguration());
            } else if (row.get() == MAXVALUE_VECTOR_MARKER) {
                Vectors.write(partialVector, maxValuesPath, ctx.getConfiguration());
            } else if (row.get() == NUM_NON_ZERO_ENTRIES_VECTOR_MARKER) {
                Vectors.write(partialVector, numNonZeroEntriesPath, ctx.getConfiguration(), true);
            } else {
                ctx.write(row, new VectorWritable(partialVector));
            }
        }
    }

    public static class CooccurrencesMapper
            extends Mapper<IntWritable, VectorWritable, IntWritable, VectorWritable> {

        private VectorSimilarityMeasure similarity;

        private OpenIntIntHashMap numNonZeroEntries;
        private Vector maxValues;
        private double threshold;

        private static final Comparator<Vector.Element> BY_INDEX = new Comparator<Vector.Element>() {
            @Override
            public int compare(Vector.Element one, Vector.Element two) {
                return Ints.compare(one.index(), two.index());
            }
        };

        @Override
        protected void setup(Context ctx) throws IOException, InterruptedException {
            similarity = ClassUtils.instantiateAs(ctx.getConfiguration().get(SIMILARITY_CLASSNAME),
                    VectorSimilarityMeasure.class);
            numNonZeroEntries = Vectors.readAsIntMap(
                    new Path(ctx.getConfiguration().get(NUM_NON_ZERO_ENTRIES_PATH)), ctx.getConfiguration());
            maxValues = Vectors.read(new Path(ctx.getConfiguration().get(MAXVALUES_PATH)), ctx.getConfiguration());
            threshold = Double.parseDouble(ctx.getConfiguration().get(THRESHOLD));
        }

        private boolean consider(Vector.Element occurrenceA, Vector.Element occurrenceB) {
            int numNonZeroEntriesA = numNonZeroEntries.get(occurrenceA.index());
            int numNonZeroEntriesB = numNonZeroEntries.get(occurrenceB.index());

            double maxValueA = maxValues.get(occurrenceA.index());
            double maxValueB = maxValues.get(occurrenceB.index());

            return similarity.consider(numNonZeroEntriesA, numNonZeroEntriesB, maxValueA, maxValueB, threshold);
        }

        @Override
        protected void map(IntWritable column, VectorWritable occurrenceVector, Context ctx)
                throws IOException, InterruptedException {
            Vector.Element[] occurrences = Vectors.toArray(occurrenceVector);
            Arrays.sort(occurrences, BY_INDEX);

            int cooccurrences = 0;
            int prunedCooccurrences = 0;
            for (int n = 0; n < occurrences.length; n++) {
                Vector.Element occurrenceA = occurrences[n];
                Vector dots = new RandomAccessSparseVector(Integer.MAX_VALUE);
                for (int m = n; m < occurrences.length; m++) {
                    Vector.Element occurrenceB = occurrences[m];
                    if (threshold == NO_THRESHOLD || consider(occurrenceA, occurrenceB)) {
                        dots.setQuick(occurrenceB.index(),
                                similarity.aggregate(occurrenceA.get(), occurrenceB.get()));
                        cooccurrences++;
                    } else {
                        prunedCooccurrences++;
                    }
                }
                ctx.write(new IntWritable(occurrenceA.index()), new VectorWritable(dots));
            }
            ctx.getCounter(Counters.COOCCURRENCES).increment(cooccurrences);
            ctx.getCounter(Counters.PRUNED_COOCCURRENCES).increment(prunedCooccurrences);
        }
    }

    public static class SimilarityReducer
            extends Reducer<IntWritable, VectorWritable, IntWritable, VectorWritable> {

        private VectorSimilarityMeasure similarity;
        private int numberOfColumns;
        private boolean excludeSelfSimilarity;
        private Vector norms;
        private double treshold;

        @Override
        protected void setup(Context ctx) throws IOException, InterruptedException {
            similarity = ClassUtils.instantiateAs(ctx.getConfiguration().get(SIMILARITY_CLASSNAME),
                    VectorSimilarityMeasure.class);
            numberOfColumns = ctx.getConfiguration().getInt(NUMBER_OF_COLUMNS, -1);
            Preconditions.checkArgument(numberOfColumns > 0,
                    "Number of columns must be greater then 0! But numberOfColumns = " + numberOfColumns);
            excludeSelfSimilarity = ctx.getConfiguration().getBoolean(EXCLUDE_SELF_SIMILARITY, false);
            norms = Vectors.read(new Path(ctx.getConfiguration().get(NORMS_PATH)), ctx.getConfiguration());
            treshold = Double.parseDouble(ctx.getConfiguration().get(THRESHOLD));
        }

        @Override
        protected void reduce(IntWritable row, Iterable<VectorWritable> partialDots, Context ctx)
                throws IOException, InterruptedException {
            Iterator<VectorWritable> partialDotsIterator = partialDots.iterator();
            Vector dots = partialDotsIterator.next().get();
            while (partialDotsIterator.hasNext()) {
                Vector toAdd = partialDotsIterator.next().get();
                for (Element nonZeroElement : toAdd.nonZeroes()) {
                    dots.setQuick(nonZeroElement.index(),
                            dots.getQuick(nonZeroElement.index()) + nonZeroElement.get());
                }
            }

            Vector similarities = dots.like();
            double normA = norms.getQuick(row.get());
            for (Element b : dots.nonZeroes()) {
                double similarityValue = similarity.similarity(b.get(), normA, norms.getQuick(b.index()),
                        numberOfColumns);
                if (similarityValue >= treshold) {
                    similarities.set(b.index(), similarityValue);
                }
            }
            if (excludeSelfSimilarity) {
                similarities.setQuick(row.get(), 0);
            }
            ctx.write(row, new VectorWritable(similarities));
        }
    }

    public static class UnsymmetrifyMapper
            extends Mapper<IntWritable, VectorWritable, IntWritable, VectorWritable> {

        private int maxSimilaritiesPerRow;

        @Override
        protected void setup(Mapper.Context ctx) throws IOException, InterruptedException {
            maxSimilaritiesPerRow = ctx.getConfiguration().getInt(MAX_SIMILARITIES_PER_ROW, 0);
            Preconditions.checkArgument(maxSimilaritiesPerRow > 0,
                    "Maximum number of similarities per row must be greater then 0!");
        }

        @Override
        protected void map(IntWritable row, VectorWritable similaritiesWritable, Context ctx)
                throws IOException, InterruptedException {
            Vector similarities = similaritiesWritable.get();
            // For performance, the creation of transposedPartial is moved out of the while loop and it is reused inside
            Vector transposedPartial = new RandomAccessSparseVector(similarities.size(), 1);
            TopElementsQueue topKQueue = new TopElementsQueue(maxSimilaritiesPerRow);
            for (Element nonZeroElement : similarities.nonZeroes()) {
                MutableElement top = topKQueue.top();
                double candidateValue = nonZeroElement.get();
                if (candidateValue > top.get()) {
                    top.setIndex(nonZeroElement.index());
                    top.set(candidateValue);
                    topKQueue.updateTop();
                }

                transposedPartial.setQuick(row.get(), candidateValue);
                ctx.write(new IntWritable(nonZeroElement.index()), new VectorWritable(transposedPartial));
                transposedPartial.setQuick(row.get(), 0.0);
            }
            Vector topKSimilarities = new RandomAccessSparseVector(similarities.size(), maxSimilaritiesPerRow);
            for (Vector.Element topKSimilarity : topKQueue.getTopElements()) {
                topKSimilarities.setQuick(topKSimilarity.index(), topKSimilarity.get());
            }
            ctx.write(row, new VectorWritable(topKSimilarities));
        }
    }

    public static class MergeToTopKSimilaritiesReducer
            extends Reducer<IntWritable, VectorWritable, IntWritable, VectorWritable> {

        private int maxSimilaritiesPerRow;

        @Override
        protected void setup(Context ctx) throws IOException, InterruptedException {
            maxSimilaritiesPerRow = ctx.getConfiguration().getInt(MAX_SIMILARITIES_PER_ROW, 0);
            Preconditions.checkArgument(maxSimilaritiesPerRow > 0,
                    "Maximum number of similarities per row must be greater then 0!");
        }

        @Override
        protected void reduce(IntWritable row, Iterable<VectorWritable> partials, Context ctx)
                throws IOException, InterruptedException {
            Vector allSimilarities = Vectors.merge(partials);
            Vector topKSimilarities = Vectors.topKElements(maxSimilaritiesPerRow, allSimilarities);
            ctx.write(row, new VectorWritable(topKSimilarities));
        }
    }

}