io.anserini.rerank.lib.AxiomReranker.java Source code

Introduction

Here is the source code for io.anserini.rerank.lib.AxiomReranker.java

Source

/**
 * Anserini: An information retrieval toolkit built on Lucene
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.anserini.rerank.lib;

import io.anserini.index.generator.LuceneDocumentGenerator;
import io.anserini.index.generator.TweetGenerator;
import io.anserini.rerank.Reranker;
import io.anserini.rerank.RerankerContext;
import io.anserini.rerank.ScoredDocuments;
import io.anserini.search.SearchArgs;

import org.apache.commons.compress.archivers.zip.ZipArchiveInputStream;
import org.apache.commons.compress.compressors.bzip2.BZip2CompressorInputStream;
import org.apache.commons.compress.compressors.gzip.GzipCompressorInputStream;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.index.*;
import org.apache.lucene.queryparser.flexible.standard.StandardQueryParser;
import org.apache.lucene.search.*;
import org.apache.lucene.store.FSDirectory;

import java.io.*;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardOpenOption;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Random;
import java.util.Set;
import java.util.regex.Pattern;

import static io.anserini.search.SearchCollection.BREAK_SCORE_TIES_BY_DOCID;
import static io.anserini.search.SearchCollection.BREAK_SCORE_TIES_BY_TWEETID;

/*
 * Axiomatic reranking, i.e., the axiomatic semantic relevance feedback model.
 *
 * NOTE: This model supports finding expansion terms in another index, but please make sure
 * that both indexes use the same stemming rules and were built with the same generator
 * (see {@link io.anserini.index.generator.LuceneDocumentGenerator}), or the model won't work properly.
 * For example, tweets may be stemmed differently from a newswire corpus (TweetsAnalyzer vs.
 * EnglishAnalyzer); in that case it is better NOT to use a newswire index to find expansion
 * terms that are then fed back into the original tweets index.
 */
public class AxiomReranker<T> implements Reranker<T> {
    private static final Logger LOG = LogManager.getLogger(AxiomReranker.class);

    private final String field; // the field in which we look for expansion terms, e.g. "body"
    private final boolean deterministic; // whether the expansion terms are picked deterministically
    private final long seed;
    private final String externalIndexPath; // Axiomatic reranking can opt to use an
                                            // external source when searching for expansion
                                            // terms. Typically, we build another index
                                            // separately and point to it here.
    private final ScoreDoc[] internalDocidsCache; // When deterministic reranking is enabled we can cache the
                                                  // internal docids, shared across all queries
    private final List<String> externalDocidsCache; // When deterministic reranking is enabled we can opt to read
                                                    // sorted docids from a file. The file can be obtained by running
                                                    // `IndexUtils -index /path/to/index -dumpAllDocids GZ`

    private final int R; // number of top documents kept from the initial results
    private final int N; // sampling multiplier: we additionally extract (N-1)*R randomly selected documents
    private final int K = 1000; // top similar terms
    private final int M = 20; // number of expansion terms
    private final float beta; // scaling parameter
    private final boolean outputQuery;

    public AxiomReranker(String field, SearchArgs args) throws IOException {
        this.field = field;
        this.deterministic = args.axiom_deterministic;
        this.seed = args.axiom_seed;
        this.R = args.axiom_r;
        this.N = args.axiom_n;
        this.beta = args.axiom_beta;
        this.externalIndexPath = args.axiom_index;
        this.outputQuery = args.axiom_outputQuery;

        if (this.deterministic && this.N > 1) {
            if (args.axiom_docids != null) {
                this.externalDocidsCache = buildExternalDocidsCache(args);
                this.internalDocidsCache = null;
            } else {
                this.internalDocidsCache = buildInternalDocidsCache(args);
                this.externalDocidsCache = null;
            }
        } else {
            this.internalDocidsCache = null;
            this.externalDocidsCache = null;
        }
    }
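
    /*
     * Usage sketch (illustrative; not part of the original class, and the index path
     * and parameter values below are hypothetical): configuring the reranker from
     * SearchArgs.
     */
    private static AxiomReranker<Integer> exampleReranker() throws IOException {
        SearchArgs args = new SearchArgs();
        args.index = "/path/to/lucene-index"; // hypothetical; scanned to build the internal docids cache
        args.axiom_deterministic = true;      // fixed seed => reproducible expansion terms
        args.axiom_seed = 42;
        args.axiom_r = 20;                    // keep the top 20 documents of the initial ranking
        args.axiom_n = 3;                     // pool size = R * N = 60
        args.axiom_beta = 0.4f;               // scaling parameter for expansion term weights
        return new AxiomReranker<>(LuceneDocumentGenerator.FIELD_BODY, args);
    }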

    @Override
    public ScoredDocuments rerank(ScoredDocuments docs, RerankerContext<T> context) {
        assert (docs.documents.length == docs.scores.length);

        try {
            // First, search against the external index if one is configured
            docs = processExternalContext(docs, context);
            // Select R*N docs from the original ranking list as the reranking pool
            Set<Integer> usedDocs = selectDocs(docs, context);
            // Extract an inverted list from the reranking pool
            Map<String, Set<Integer>> termInvertedList = extractTerms(usedDocs, context, null);
            // Score all the terms in the reranking pool and pick the top M of them as expansion terms
            Map<String, Double> expandedTermScores = computeTermScore(termInvertedList, context);
            StringBuilder builder = new StringBuilder();
            for (Map.Entry<String, Double> termScore : expandedTermScores.entrySet()) {
                String term = termScore.getKey();
                double score = termScore.getValue();
                builder.append(term).append("^").append(score).append(" ");
            }
            String queryText = builder.toString().trim();

            if (queryText.isEmpty()) {
                LOG.info("[Empty Expanded Query]: " + context.getQueryTokens());
                queryText = context.getQueryText();
            }

            StandardQueryParser p = new StandardQueryParser();
            Query nq = p.parse(queryText, this.field);

            if (this.outputQuery) {
                LOG.info("QID: " + context.getQueryId());
                LOG.info("Original Query: " + context.getQuery().toString(this.field));
                LOG.info("Running new query: " + nq.toString(this.field));
            }

            return searchTopDocs(nq, context);
        } catch (Exception e) {
            LOG.error("Axiomatic reranking failed, returning the initial ranking", e);
            return docs;
        }
    }
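
    /*
     * Illustrative sketch (not part of the original class; the terms and weights are
     * hypothetical): the weighted query string assembled in rerank() before it is
     * handed to StandardQueryParser.
     */
    private static String exampleExpandedQuery() {
        Map<String, Double> expandedTermScores = new HashMap<>();
        expandedTermScores.put("hurricane", 1.42);
        expandedTermScores.put("storm", 0.87);
        StringBuilder builder = new StringBuilder();
        for (Map.Entry<String, Double> termScore : expandedTermScores.entrySet()) {
            // each expansion term is boosted by its weight, e.g. "hurricane^1.42"
            builder.append(termScore.getKey()).append('^').append(termScore.getValue()).append(' ');
        }
        return builder.toString().trim(); // e.g. "hurricane^1.42 storm^0.87"
    }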

    /**
     * Note that the query in the context is always the keyword query, without the filter.
     */
    private ScoredDocuments searchTopDocs(Query query, RerankerContext<T> context) throws IOException {
        IndexSearcher searcher = context.getIndexSearcher();
        Query finalQuery;
        if (query == null) { // a null query means we are searching the external index, to which we do NOT apply the filter
            finalQuery = context.getQuery();
        } else {
            if (context.getFilter() != null) {
                // If there's a filter condition, we need to add in the constraint.
                // Otherwise, just use the original query.
                BooleanQuery.Builder bqBuilder = new BooleanQuery.Builder();
                bqBuilder.add(context.getFilter(), BooleanClause.Occur.FILTER);
                bqBuilder.add(query, BooleanClause.Occur.MUST);
                finalQuery = bqBuilder.build();
            } else {
                finalQuery = query;
            }
        }

        TopDocs rs;
        // Figure out how to break the scoring ties.
        if (context.getSearchArgs().arbitraryScoreTieBreak) {
            rs = searcher.search(finalQuery, context.getSearchArgs().hits);
        } else if (context.getSearchArgs().searchtweets) {
            rs = searcher.search(finalQuery, context.getSearchArgs().hits, BREAK_SCORE_TIES_BY_TWEETID, true, true);
        } else {
            rs = searcher.search(finalQuery, context.getSearchArgs().hits, BREAK_SCORE_TIES_BY_DOCID, true, true);
        }

        return ScoredDocuments.fromTopDocs(rs, searcher);
    }

    public InputStream getReadFileStream(String path) throws IOException {
        InputStream fin = Files.newInputStream(Paths.get(path), StandardOpenOption.READ);
        BufferedInputStream in = new BufferedInputStream(fin);
        if (path.endsWith(".bz2")) {
            return new BZip2CompressorInputStream(in);
        } else if (path.endsWith(".gz")) {
            return new GzipCompressorInputStream(in);
        } else if (path.endsWith(".zip")) {
            // A zip archive is not a gzip stream: open it with ZipArchiveInputStream
            // and position the stream at the first entry before returning it.
            ZipArchiveInputStream zipIn = new ZipArchiveInputStream(in);
            zipIn.getNextEntry();
            return zipIn;
        }
        return in;
    }

    /**
     * If the result is deterministic we can cache all the external docids by reading them from a file
     */
    private List<String> buildExternalDocidsCache(SearchArgs args) throws IOException {
        try (BufferedReader bRdr = new BufferedReader(new InputStreamReader(getReadFileStream(args.axiom_docids)))) {
            return IOUtils.readLines(bRdr);
        }
    }

    /**
     * If the result is deterministic we can cache all the docids. All queries can share this
     * cache.
     */
    private ScoreDoc[] buildInternalDocidsCache(SearchArgs args) throws IOException {
        String index = args.axiom_index == null ? args.index : args.axiom_index;
        Path indexPath = Paths.get(index);
        if (!Files.exists(indexPath) || !Files.isDirectory(indexPath) || !Files.isReadable(indexPath)) {
            throw new IllegalArgumentException(index + " does not exist or is not a directory.");
        }
        IndexReader reader = DirectoryReader.open(FSDirectory.open(indexPath));
        IndexSearcher searcher = new IndexSearcher(reader);
        if (args.searchtweets) {
            return searcher.search(new FieldValueQuery(TweetGenerator.StatusField.ID_LONG.name), reader.maxDoc(),
                    BREAK_SCORE_TIES_BY_TWEETID).scoreDocs;
        }
        return searcher.search(new FieldValueQuery(LuceneDocumentGenerator.FIELD_ID), reader.maxDoc(),
                BREAK_SCORE_TIES_BY_DOCID).scoreDocs;
    }

    /**
     * If an external index is configured we first search against it and return the top
     * ranked documents.
     *
     * @param docs The initial ranking results against the target index; returned unchanged
     *             if no external index is configured.
     *
     * @return Top ranked ScoredDocuments from searching the external index
     */
    private ScoredDocuments processExternalContext(ScoredDocuments docs, RerankerContext<T> context)
            throws IOException {
        if (externalIndexPath != null) {
            Path indexPath = Paths.get(this.externalIndexPath);
            if (!Files.exists(indexPath) || !Files.isDirectory(indexPath) || !Files.isReadable(indexPath)) {
                throw new IllegalArgumentException(
                        this.externalIndexPath + " does not exist or is not a directory.");
            }
            IndexReader reader = DirectoryReader.open(FSDirectory.open(indexPath));
            IndexSearcher searcher = new IndexSearcher(reader);
            searcher.setSimilarity(context.getIndexSearcher().getSimilarity(true));

            SearchArgs args = new SearchArgs();
            args.hits = this.R;
            args.arbitraryScoreTieBreak = context.getSearchArgs().arbitraryScoreTieBreak;
            args.searchtweets = context.getSearchArgs().searchtweets;

            RerankerContext<T> externalContext = new RerankerContext<>(searcher, context.getQueryId(),
                    context.getQuery(), context.getQueryText(), context.getQueryTokens(), context.getFilter(),
                    args);

            return searchTopDocs(null, externalContext);
        } else {
            return docs;
        }
    }

    /**
     * Select {@code R*N} docs from the ranking results and the index as the reranking pool.
     * The process is:
     * 1. Keep the top {@code R} documents from the original ranking list
     * 2. Randomly pick {@code (N-1)*R} documents from the rest of the index so that in total
     * we have {@code R*N} documents
     *
     * @param docs The initial ranking results
     * @param context An instance of RerankerContext
     * @return a Set of {@code R*N} document ids
     */
    private Set<Integer> selectDocs(ScoredDocuments docs, RerankerContext<T> context) throws IOException {
        Set<Integer> docidSet = new HashSet<>(Arrays
                .asList(ArrayUtils.toObject(Arrays.copyOfRange(docs.ids, 0, Math.min(this.R, docs.ids.length)))));
        long targetSize = this.R * this.N;

        if (docidSet.size() < targetSize) {
            IndexReader reader;
            IndexSearcher searcher;
            if (this.externalIndexPath != null) {
                Path indexPath = Paths.get(this.externalIndexPath);
                if (!Files.exists(indexPath) || !Files.isDirectory(indexPath) || !Files.isReadable(indexPath)) {
                    throw new IllegalArgumentException(
                            this.externalIndexPath + " does not exist or is not a directory.");
                }
                reader = DirectoryReader.open(FSDirectory.open(indexPath));
                searcher = new IndexSearcher(reader);
            } else {
                searcher = context.getIndexSearcher();
                reader = searcher.getIndexReader();
            }
            int availableDocsCnt = reader.getDocCount(this.field);
            if (this.deterministic) { // internal docids cannot be relied upon due to multi-threaded
                                      // indexing, so we have to rely on external docids here
                Random random = new Random(this.seed);
                while (docidSet.size() < targetSize) {
                    if (this.externalDocidsCache != null) {
                        String docid = this.externalDocidsCache
                                .get(random.nextInt(this.externalDocidsCache.size()));
                        Query q = new TermQuery(new Term(LuceneDocumentGenerator.FIELD_ID, docid));
                        TopDocs rs = searcher.search(q, 1);
                        docidSet.add(rs.scoreDocs[0].doc);
                    } else {
                        docidSet.add(this.internalDocidsCache[random.nextInt(this.internalDocidsCache.length)].doc);
                    }
                }
            } else {
                Random random = new Random();
                while (docidSet.size() < targetSize) {
                    docidSet.add(random.nextInt(availableDocsCnt));
                }
            }
        }

        return docidSet;
    }
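
    /*
     * Worked example (illustrative sketch, not part of the original class): the pool
     * arithmetic behind selectDocs() for hypothetical values R = 20 and N = 3.
     */
    private static long exampleTargetPoolSize() {
        int r = 20;                // top documents kept from the initial ranking
        int n = 3;                 // sampling multiplier
        int sampled = (n - 1) * r; // 40 documents drawn at random from the rest of the index
        return r + sampled;        // 60 == R * N, the target size used by selectDocs()
    }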

    /**
     * Extract ALL the terms from the documents in the pool.
     *
     * @param docIds The reranking pool; see {@link #selectDocs} for details
     * @param context An instance of RerankerContext
     * @param filterPattern A regex pattern; when non-null, only terms that match it are collected. May be null.
     * @return A map from each term to the set of docids in which it occurs, i.e., a small inverted list
     */
    private Map<String, Set<Integer>> extractTerms(Set<Integer> docIds, RerankerContext<T> context,
            Pattern filterPattern) throws IOException {
        IndexReader reader;
        IndexSearcher searcher;
        if (this.externalIndexPath != null) {
            Path indexPath = Paths.get(this.externalIndexPath);
            if (!Files.exists(indexPath) || !Files.isDirectory(indexPath) || !Files.isReadable(indexPath)) {
                throw new IllegalArgumentException(
                        this.externalIndexPath + " does not exist or is not a directory.");
            }
            reader = DirectoryReader.open(FSDirectory.open(indexPath));
            searcher = new IndexSearcher(reader);
        } else {
            searcher = context.getIndexSearcher();
            reader = searcher.getIndexReader();
        }
        Map<String, Set<Integer>> termDocidSets = new HashMap<>();
        for (int docid : docIds) {
            Terms terms = reader.getTermVector(docid, LuceneDocumentGenerator.FIELD_BODY);
            if (terms == null) {
                LOG.warn("Document vector not stored for docid: " + docid);
                continue;
            }
            TermsEnum te = terms.iterator();
            if (te == null) {
                LOG.warn("Document vector not stored for docid: " + docid);
                continue;
            }
            while (te.next() != null) {
                String term = te.term().utf8ToString();
                // Noise filtering: a purely empirical heuristic
                if (term.length() < 2)
                    continue;
                if (!term.matches("[a-z]+"))
                    continue;
                if (filterPattern == null || filterPattern.matcher(term).matches()) {
                    termDocidSets.computeIfAbsent(term, k -> new HashSet<>()).add(docid);
                }
            }
        }
        return termDocidSets;
    }
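
    /*
     * Illustrative sketch (not part of the original class; terms and docids are
     * hypothetical): the shape of the map returned by extractTerms(). If "lucene"
     * occurs in pool documents 3 and 7 and "index" only in document 3, we get
     * { "lucene" -> {3, 7}, "index" -> {3} }.
     */
    private static Map<String, Set<Integer>> exampleInvertedList() {
        Map<String, Set<Integer>> termDocidSets = new HashMap<>();
        termDocidSets.put("lucene", new HashSet<>(Arrays.asList(3, 7)));
        termDocidSets.put("index", new HashSet<>(Arrays.asList(3)));
        return termDocidSets;
    }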

    /**
     * Calculate the scores (weights) of each term that occurred in the reranking pool.
     * The process:
     * 1. For each query term, calculate its score against each term in the reranking pool. The score
     * is calculated as
     * <pre>
     * P(both occurs)*log{P(both occurs)/P(t1 occurs)/P(t2 occurs)}
     * + P(both not occurs)*log{P(both not occurs)/P(t1 not occurs)/P(t2 not occurs)}
     * + P(t1 occurs t2 not occurs)*log{P(t1 occurs t2 not occurs)/P(t1 occurs)/P(t2 not occurs)}
     * + P(t1 not occurs t2 occurs)*log{P(t1 not occurs t2 occurs)/P(t1 not occurs)/P(t2 occurs)}
     * </pre>
     * 2. For each query term, the scores of all terms in the reranking pool are stored in a
     * PriorityQueue; only the top {@code K} are kept.
     * 3. Add up the scores of identical terms and pick the top {@code M} of them.
     *
     * @param termInvertedList A map from each term to the set of docids in which it occurs
     * @param context An instance of RerankerContext
     * @return The top terms and their weight scores in a HashMap
     */
    private Map<String, Double> computeTermScore(Map<String, Set<Integer>> termInvertedList,
            RerankerContext<T> context) throws IOException {
        class ScoreComparator implements Comparator<Pair<String, Double>> {
            public int compare(Pair<String, Double> a, Pair<String, Double> b) {
                int cmp = Double.compare(b.getRight(), a.getRight());
                if (cmp == 0) {
                    return a.getLeft().compareToIgnoreCase(b.getLeft());
                } else {
                    return cmp;
                }
            }
        }

        // get collection statistics so that we can get idf later on.
        IndexReader reader;
        if (this.externalIndexPath != null) {
            Path indexPath = Paths.get(this.externalIndexPath);
            if (!Files.exists(indexPath) || !Files.isDirectory(indexPath) || !Files.isReadable(indexPath)) {
                throw new IllegalArgumentException(
                        this.externalIndexPath + " does not exist or is not a directory.");
            }
            reader = DirectoryReader.open(FSDirectory.open(indexPath));
        } else {
            IndexSearcher searcher = context.getIndexSearcher();
            reader = searcher.getIndexReader();
        }
        final long docCount = reader.numDocs() == -1 ? reader.maxDoc() : reader.numDocs();

        // calculate the mutual information between each term in the pool and each query term
        List<String> queryTerms = context.getQueryTokens();
        Map<String, Integer> queryTermsCounts = new HashMap<>();
        for (String qt : queryTerms) {
            queryTermsCounts.put(qt, queryTermsCounts.getOrDefault(qt, 0) + 1);
        }

        Set<Integer> allDocIds = new HashSet<>();
        for (Set<Integer> s : termInvertedList.values()) {
            allDocIds.addAll(s);
        }
        int docIdsCount = allDocIds.size();

        // Each priority queue corresponds to a query term: it stores every term in the
        // reranking pool together with that term's score relative to the query term.
        List<PriorityQueue<Pair<String, Double>>> allTermScoresPQ = new ArrayList<>();
        for (Map.Entry<String, Integer> q : queryTermsCounts.entrySet()) {
            String queryTerm = q.getKey();
            long df = reader.docFreq(new Term(LuceneDocumentGenerator.FIELD_BODY, queryTerm));
            if (df == 0L) {
                continue;
            }
            float idf = (float) Math.log((1.0 + docCount) / df); // floating-point division; integer division would truncate the ratio
            int qtf = q.getValue();
            if (termInvertedList.containsKey(queryTerm)) {
                PriorityQueue<Pair<String, Double>> termScorePQ = new PriorityQueue<>(new ScoreComparator());
                double selfMI = computeMutualInformation(termInvertedList.get(queryTerm),
                        termInvertedList.get(queryTerm), docIdsCount);
                for (Map.Entry<String, Set<Integer>> termEntry : termInvertedList.entrySet()) {
                    double score;
                    if (termEntry.getKey().equals(queryTerm)) { // crossMI/selfMI is always 1 for the query term itself
                        score = idf * qtf;
                    } else {
                        double crossMI = computeMutualInformation(termInvertedList.get(queryTerm),
                                termEntry.getValue(), docIdsCount);
                        score = idf * beta * qtf * crossMI / selfMI;
                    }
                    termScorePQ.add(Pair.of(termEntry.getKey(), score));
                }
                allTermScoresPQ.add(termScorePQ);
            }
        }

        Map<String, Double> aggTermScores = new HashMap<>();
        for (PriorityQueue<Pair<String, Double>> termScores : allTermScoresPQ) {
            int numTerms = Math.min(termScores.size(), this.K); // capture the bound up front: polling shrinks the queue
            for (int i = 0; i < numTerms; i++) {
                Pair<String, Double> termScore = termScores.poll();
                String term = termScore.getLeft();
                Double score = termScore.getRight();
                if (score > 1e-8) {
                    aggTermScores.put(term, aggTermScores.getOrDefault(term, 0.0) + score);
                }
            }
        }
        PriorityQueue<Pair<String, Double>> termScoresPQ = new PriorityQueue<>(new ScoreComparator());
        for (Map.Entry<String, Double> termScore : aggTermScores.entrySet()) {
            termScoresPQ.add(Pair.of(termScore.getKey(), termScore.getValue() / queryTerms.size()));
        }
        Map<String, Double> resultTermScores = new HashMap<>();
        int numResults = Math.min(termScoresPQ.size(), this.M); // capture the bound up front: polling shrinks the queue
        for (int i = 0; i < numResults; i++) {
            Pair<String, Double> termScore = termScoresPQ.poll();
            String term = termScore.getLeft();
            double score = termScore.getRight();
            resultTermScores.put(term, score);
        }

        return resultTermScores;
    }
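
    /*
     * Minimal sketch (not part of the original class) of the per-term score from step 1
     * of computeTermScore(): a candidate term t receives, from query term q,
     * idf(q) * beta * qtf(q) * MI(q, t) / MI(q, q), where idf(q) = log((1 + docCount) / df(q)).
     */
    private double exampleExpansionTermScore(Set<Integer> queryTermDocids, Set<Integer> candidateTermDocids,
            int poolSize, float idf, int qtf) {
        double selfMI = computeMutualInformation(queryTermDocids, queryTermDocids, poolSize);
        double crossMI = computeMutualInformation(queryTermDocids, candidateTermDocids, poolSize);
        // a candidate that co-occurs with q as strongly as q itself scores idf * beta * qtf
        return idf * this.beta * qtf * crossMI / selfMI;
    }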

    private double computeMutualInformation(Set<Integer> docidsX, Set<Integer> docidsY, int totalDocCount) {
        int x1 = docidsX.size(), y1 = docidsY.size(); // number of documents in which x (resp. y) occurs
        int x0 = totalDocCount - x1, y0 = totalDocCount - y1; // number of documents in which x (resp. y) does not occur

        if (x1 == 0 || x0 == 0 || y1 == 0 || y0 == 0) {
            return 0;
        }

        float pX0 = 1.0f * x0 / totalDocCount;
        float pX1 = 1.0f * x1 / totalDocCount;
        float pY0 = 1.0f * y0 / totalDocCount;
        float pY1 = 1.0f * y1 / totalDocCount;

        // get the intersection of docids
        Set<Integer> docidsXClone = new HashSet<>(docidsX); // copy first: operating on docidsX directly would mutate it
        docidsXClone.retainAll(docidsY);
        int numXY11 = docidsXClone.size();
        int numXY10 = x1 - numXY11; // doc count where x occurs but y doesn't
        int numXY01 = y1 - numXY11; // doc count where y occurs but x doesn't
        int numXY00 = totalDocCount - numXY11 - numXY10 - numXY01; // doc count where neither x nor y occurs

        float pXY11 = 1.0f * numXY11 / totalDocCount;
        float pXY10 = 1.0f * numXY10 / totalDocCount;
        float pXY01 = 1.0f * numXY01 / totalDocCount;
        float pXY00 = 1.0f * numXY00 / totalDocCount;

        double m00 = 0, m01 = 0, m10 = 0, m11 = 0;
        if (pXY00 != 0)
            m00 = pXY00 * Math.log(pXY00 / (pX0 * pY0));
        if (pXY01 != 0)
            m01 = pXY01 * Math.log(pXY01 / (pX0 * pY1));
        if (pXY10 != 0)
            m10 = pXY10 * Math.log(pXY10 / (pX1 * pY0));
        if (pXY11 != 0)
            m11 = pXY11 * Math.log(pXY11 / (pX1 * pY1));
        return m00 + m10 + m01 + m11;
    }
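
    /*
     * Worked example (illustrative; not part of the original class): in a pool of
     * 100 documents, let x occur in docs 0-9 and y in docs 5-24, so they co-occur in
     * 5 documents. Then p(x=1,y=1) = 0.05, p(x=1) = 0.1, p(y=1) = 0.2, and the
     * (1,1) component alone contributes 0.05 * ln(0.05 / (0.1 * 0.2)) ~= 0.0458;
     * the full mutual information also sums the (0,0), (0,1) and (1,0) components.
     */
    private double exampleMutualInformation() {
        Set<Integer> x = new HashSet<>();
        Set<Integer> y = new HashSet<>();
        for (int i = 0; i < 10; i++) x.add(i); // x occurs in 10 documents
        for (int i = 5; i < 25; i++) y.add(i); // y occurs in 20, overlapping x in 5
        return computeMutualInformation(x, y, 100);
    }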
}