io.bfscan.clueweb12.LMRetrieval.java Source code

Java tutorial

Introduction

Here is the source code for io.bfscan.clueweb12.LMRetrieval.java

Source

/*
 * ClueWeb Tools: Hadoop tools for manipulating ClueWeb collections
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you
 * may not use this file except in compliance with the License. You may
 * obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied. See the License for the specific language governing
 * permissions and limitations under the License.
 */

package io.bfscan.clueweb12;

import io.bfscan.data.PForDocVector;
import io.bfscan.data.TermStatistics;
import io.bfscan.dictionary.DefaultFrequencySortedDictionary;
import io.bfscan.util.AnalyzerFactory;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Set;

import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.CommandLineParser;
import org.apache.commons.cli.GnuParser;
import org.apache.commons.cli.HelpFormatter;
import org.apache.commons.cli.OptionBuilder;
import org.apache.commons.cli.Options;
import org.apache.commons.cli.ParseException;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.conf.Configured;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.FloatWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Partitioner;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.util.Tool;
import org.apache.hadoop.util.ToolRunner;
import org.apache.log4j.Logger;
import org.apache.lucene.analysis.Analyzer;

import tl.lin.data.array.IntArrayWritable;
import tl.lin.data.pair.PairOfIntString;
import tl.lin.data.pair.PairOfStringFloat;
import tl.lin.lucene.AnalyzerUtils;

import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;

/**
 * <p>Implementation of language modeling. Retrieval parameter <i>smoothing</i> determines the type: 
 * smoothing<=1 means Jelineck-Mercer and smoothing>1 means Dirichlet.</p>
 *
 * <p>Approach:</p>
 *
 * <ol>
 * <li> read the queries and convert into termids (based on the dictionary);
 *      make sure to use the same Lucene Analyzer as in ComputeTermStatistics.java
 * 
 * <li> MyMapper: walk over all document vectors
 *   <ul>
 *      <li> determine all queries which have at least one query term is occurring in the document
 *      <li> for each such query, compute the LM score and emit composite key: (qid,docid), value: (score)
 *   </ul>
 *
 * <li> MyPartitioner: ensure that all keys (qid,docid) with the same qid end up in the same reducer
 * 
 * <li> MyReducer: for each query
 *   <ul>
 *      <li> create a priority queue (minimum heap): we only need to keep the topk highest probability scores
 *      <li> once all key/values are processed, "sort" the doc/score elements in the priority queue (already semi-done in heap)
 *      <li> output the results in TREC result file format
 *   </ul>
 * </ol>
 *
 * @author Claudia Hauff
 */
public class LMRetrieval extends Configured implements Tool {
    private static final Logger LOG = Logger.getLogger(LMRetrieval.class);

    /*
     * Partitioner: all keys with the same qid go to the same reducer
     */
    private static class MyPartitioner extends Partitioner<PairOfIntString, FloatWritable> {

        @Override
        public int getPartition(PairOfIntString arg0, FloatWritable arg1, int numPartitions) {
            return arg0.getLeftElement() % numPartitions;
        }
    }

    /*
     * comparator for the priority queue: elements (docid,score) are sorted by score
     */
    private static class CustomComparator implements Comparator<PairOfStringFloat> {
        @Override
        public int compare(PairOfStringFloat o1, PairOfStringFloat o2) {

            if (o1.getRightElement() == o2.getRightElement()) {
                return 0;
            }
            if (o1.getRightElement() > o2.getRightElement()) {
                return 1;
            }
            return -1;
        }
    }

    /*
     * Mapper outKey: (qid,docid), value: probability score
     */
    private static class MyMapper extends Mapper<Text, IntArrayWritable, PairOfIntString, FloatWritable> {

        private static final PForDocVector DOC = new PForDocVector();
        private DefaultFrequencySortedDictionary dictionary;
        private TermStatistics stats;
        private double smoothingParam;

        private static Analyzer ANALYZER;

        /*
         * for quick access store the queries in two hashmaps: 1. key: termid, value: list of queries in
         * which the termid occurs 2. key: qid, value: list of termids that occur in the query
         */
        private Map<Integer, Set<Integer>> termidQuerySet;
        private Map<Integer, Set<Integer>> queryTermidSet;

        // complex key: (qid,docid)
        private static final PairOfIntString keyOut = new PairOfIntString();
        // value: float; probability score log(P(q|d))
        private static final FloatWritable valueOut = new FloatWritable();

        @Override
        public void setup(Context context) throws IOException {
            FileSystem fs = FileSystem.get(context.getConfiguration());
            String path = context.getConfiguration().get(DICTIONARY_OPTION);
            dictionary = new DefaultFrequencySortedDictionary(path, fs);
            stats = new TermStatistics(new Path(path), fs);

            smoothingParam = context.getConfiguration().getFloat(SMOOTHING, 1000f);
            LOG.info("Smoothing set to " + smoothingParam);

            String analyzerType = context.getConfiguration().get(PREPROCESSING);
            ANALYZER = AnalyzerFactory.getAnalyzer(analyzerType);
            if (ANALYZER == null) {
                LOG.error("Error: proprocessing type not recognized. Abort " + this.getClass().getName());
                return;
            }

            // read the queries from file
            termidQuerySet = Maps.newHashMap();
            queryTermidSet = Maps.newHashMap();
            FSDataInputStream fsin = fs.open(new Path(context.getConfiguration().get(QUERIES_OPTION)));
            BufferedReader br = new BufferedReader(new InputStreamReader(fsin));
            String line;
            while ((line = br.readLine()) != null) {
                int index = line.indexOf(':');
                if (index < 0) {
                    LOG.info("Query file line in incorrect format, expecting <num>:<term> <term>...\nInstead got:\n"
                            + line);
                    continue;
                }
                int qid = Integer.parseInt(line.substring(0, index));
                Set<Integer> termidSet = Sets.newHashSet();

                LOG.info("Parsing query line " + line);

                // normalize the terms (same way as the documents)
                for (String term : AnalyzerUtils.parse(ANALYZER, line.substring(index + 1))) {
                    int termid = dictionary.getId(term);
                    LOG.info("parsed term [" + term + "] has termid " + termid);

                    if (termid < 0) {
                        continue;
                    }

                    termidSet.add(termid);

                    if (termidQuerySet.containsKey(termid)) {
                        termidQuerySet.get(termid).add(qid);
                    } else {
                        Set<Integer> qids = Sets.newHashSet();
                        qids.add(qid);
                        termidQuerySet.put(termid, (HashSet<Integer>) qids);
                    }
                }

                queryTermidSet.put(qid, termidSet);
            }
            br.close();
            fsin.close();
        }

        @Override
        public void map(Text key, IntArrayWritable ints, Context context) throws IOException, InterruptedException {
            PForDocVector.fromIntArrayWritable(ints, DOC);

            // determine which queries we care about for this document
            HashSet<Integer> queriesToDo = Sets.newHashSet();

            // tfMap of the document
            HashMap<Integer, Integer> tfMap = Maps.newHashMap();
            for (int termid : DOC.getTermIds()) {
                int tf = 1;
                if (tfMap.containsKey(termid))
                    tf += tfMap.get(termid);
                tfMap.put(termid, tf);

                if (termidQuerySet.containsKey(termid)) {
                    for (int qid : termidQuerySet.get(termid)) {
                        queriesToDo.add(qid);
                    }
                }
            }

            // for each of the interesting queries, compute log(P(q|d))
            for (int qid : queriesToDo) {
                double score = 0.0;

                for (int termid : queryTermidSet.get(qid)) {
                    double tf = 0.0;
                    if (tfMap.containsKey(termid))
                        tf = tfMap.get(termid);
                    double df = stats.getDf(termid);

                    double mlProb = tf / (double) DOC.getLength();
                    double colProb = df / (double) stats.getCollectionSize();

                    double prob = 0.0;

                    // JM smoothing
                    if (smoothingParam <= 1.0) {
                        prob = smoothingParam * mlProb + (1.0 - smoothingParam) * colProb;
                    }
                    // Dirichlet smoothing
                    else {
                        prob = (double) (tf + smoothingParam * colProb)
                                / (double) (DOC.getLength() + smoothingParam);
                    }

                    score += (float) Math.log(prob);
                }

                keyOut.set(qid, key.toString());
                valueOut.set((float) score);
                context.write(keyOut, valueOut);
            }
        }
    }

    private static class MyReducer extends Reducer<PairOfIntString, FloatWritable, NullWritable, Text> {
        private int topk;
        // PairOfStringFloat is (docid,score)
        private Map<Integer, PriorityQueue<PairOfStringFloat>> queueMap;
        private static final NullWritable nullKey = NullWritable.get();
        private static final Text valueOut = new Text();

        public void setup(Context context) throws IOException {
            topk = context.getConfiguration().getInt(TOPK, 1000);
            LOG.info("topk parameter set to " + topk);

            queueMap = Maps.newHashMap();
        }

        @Override
        public void reduce(PairOfIntString key, Iterable<FloatWritable> values, Context context)
                throws IOException, InterruptedException {
            int qid = key.getLeftElement();

            PriorityQueue<PairOfStringFloat> queue = null;

            if (queueMap.containsKey(qid)) {
                queue = queueMap.get(qid);
            } else {
                queue = new PriorityQueue<PairOfStringFloat>(topk + 1, new CustomComparator());
                queueMap.put(qid, queue);
            }

            // actually, it should only be a single element
            float scoreSum = 0f;
            for (FloatWritable v : values) {
                scoreSum += v.get();
            }

            // if there are less than topk elements, add the new (docid, score) to the queue
            if (queue.size() < topk) {
                queue.add(new PairOfStringFloat(key.getRightElement(), scoreSum));
            }
            // if we have topk elements in the queue, we need to check if the queue's current minimum is
            // smaller than the incoming score; if yes, "exchange" the (docid,score) elements
            else if (queue.peek().getRightElement() < scoreSum) {
                queue.remove();
                queue.add(new PairOfStringFloat(key.getRightElement(), scoreSum));
            }
        }

        // emit the scores for all queries
        public void cleanup(Context context) throws IOException, InterruptedException {
            for (int qid : queueMap.keySet()) {
                PriorityQueue<PairOfStringFloat> queue = queueMap.get(qid);

                if (queue.size() == 0) {
                    continue;
                }

                List<PairOfStringFloat> orderedList = Lists.newArrayList();
                while (queue.size() > 0) {
                    orderedList.add(queue.remove());
                }

                for (int i = orderedList.size(); i > 0; i--) {
                    PairOfStringFloat p = orderedList.get(i - 1);
                    valueOut.set(qid + " Q0 " + p.getLeftElement() + " " + (orderedList.size() - i + 1) + " "
                            + p.getRightElement() + " lmretrieval");
                    context.write(nullKey, valueOut);
                }
            }
        }
    }

    public static final String DOCVECTOR_OPTION = "docvector";
    public static final String OUTPUT_OPTION = "output";
    public static final String DICTIONARY_OPTION = "dictionary";
    public static final String QUERIES_OPTION = "queries";
    public static final String SMOOTHING = "smoothing";
    public static final String TOPK = "topk";
    public static final String PREPROCESSING = "preprocessing";

    /**
     * Runs this tool.
     */
    @SuppressWarnings({ "static-access" })
    public int run(String[] args) throws Exception {
        Options options = new Options();

        options.addOption(OptionBuilder.withArgName("path").hasArg()
                .withDescription("input path (pfor format expected, add * to retrieve files)")
                .create(DOCVECTOR_OPTION));
        options.addOption(
                OptionBuilder.withArgName("path").hasArg().withDescription("output path").create(OUTPUT_OPTION));
        options.addOption(
                OptionBuilder.withArgName("path").hasArg().withDescription("dictionary").create(DICTIONARY_OPTION));
        options.addOption(
                OptionBuilder.withArgName("path").hasArg().withDescription("queries").create(QUERIES_OPTION));
        options.addOption(
                OptionBuilder.withArgName("float").hasArg().withDescription("smoothing").create(SMOOTHING));
        options.addOption(OptionBuilder.withArgName("int").hasArg().withDescription("topk").create(TOPK));
        options.addOption(OptionBuilder.withArgName("string " + AnalyzerFactory.getOptions()).hasArg()
                .withDescription("preprocessing").create(PREPROCESSING));

        CommandLine cmdline;
        CommandLineParser parser = new GnuParser();
        try {
            cmdline = parser.parse(options, args);
        } catch (ParseException exp) {
            HelpFormatter formatter = new HelpFormatter();
            formatter.printHelp(this.getClass().getName(), options);
            ToolRunner.printGenericCommandUsage(System.out);
            System.err.println("Error parsing command line: " + exp.getMessage());
            return -1;
        }

        if (!cmdline.hasOption(DOCVECTOR_OPTION) || !cmdline.hasOption(OUTPUT_OPTION)
                || !cmdline.hasOption(DICTIONARY_OPTION) || !cmdline.hasOption(QUERIES_OPTION)
                || !cmdline.hasOption(SMOOTHING) || !cmdline.hasOption(TOPK) || !cmdline.hasOption(PREPROCESSING)) {
            HelpFormatter formatter = new HelpFormatter();
            formatter.printHelp(this.getClass().getName(), options);
            ToolRunner.printGenericCommandUsage(System.out);
            return -1;
        }

        String docvector = cmdline.getOptionValue(DOCVECTOR_OPTION);
        String output = cmdline.getOptionValue(OUTPUT_OPTION);
        String dictionary = cmdline.getOptionValue(DICTIONARY_OPTION);
        String queries = cmdline.getOptionValue(QUERIES_OPTION);
        String smoothing = cmdline.getOptionValue(SMOOTHING);
        String topk = cmdline.getOptionValue(TOPK);
        String preprocessing = cmdline.getOptionValue(PREPROCESSING);

        LOG.info("Tool name: " + LMRetrieval.class.getSimpleName());
        LOG.info(" - docvector: " + docvector);
        LOG.info(" - output: " + output);
        LOG.info(" - dictionary: " + dictionary);
        LOG.info(" - queries: " + queries);
        LOG.info(" - smoothing: " + smoothing);
        LOG.info(" - topk: " + topk);
        LOG.info(" - preprocessing: " + preprocessing);

        Configuration conf = getConf();
        conf.set(DICTIONARY_OPTION, dictionary);
        conf.set(QUERIES_OPTION, queries);
        conf.setFloat(SMOOTHING, Float.parseFloat(smoothing));
        conf.setInt(TOPK, Integer.parseInt(topk));
        conf.set(PREPROCESSING, preprocessing);

        conf.set("mapreduce.map.memory.mb", "10048");
        conf.set("mapreduce.map.java.opts", "-Xmx10048m");
        conf.set("mapreduce.reduce.memory.mb", "10048");
        conf.set("mapreduce.reduce.java.opts", "-Xmx10048m");
        conf.set("mapred.task.timeout", "6000000"); // default is 600000

        FileSystem fs = FileSystem.get(conf);
        if (fs.exists(new Path(output))) {
            fs.delete(new Path(output), true);
        }

        Job job = new Job(conf, LMRetrieval.class.getSimpleName() + ":" + docvector);
        job.setJarByClass(LMRetrieval.class);

        FileInputFormat.setInputPaths(job, docvector);
        FileOutputFormat.setOutputPath(job, new Path(output));

        job.setInputFormatClass(SequenceFileInputFormat.class);

        job.setMapOutputKeyClass(PairOfIntString.class);
        job.setMapOutputValueClass(FloatWritable.class);
        job.setOutputKeyClass(NullWritable.class);
        job.setOutputValueClass(Text.class);

        job.setMapperClass(MyMapper.class);
        job.setPartitionerClass(MyPartitioner.class);
        job.setReducerClass(MyReducer.class);

        long startTime = System.currentTimeMillis();
        job.waitForCompletion(true);
        LOG.info("Job Finished in " + (System.currentTimeMillis() - startTime) / 1000.0 + " seconds");
        return 0;
    }

    /**
     * Dispatches command-line arguments to the tool via the <code>ToolRunner</code>.
     */
    public static void main(String[] args) throws Exception {
        LOG.info("Running " + LMRetrieval.class.getCanonicalName() + " with args " + Arrays.toString(args));
        ToolRunner.run(new LMRetrieval(), args);
    }
}