org.index.TermScore.java Source code

Java tutorial

Introduction

Here is the source code for org.index.TermScore.java

Source

/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */

package org.index;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.TreeMap;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.core.WhitespaceAnalyzer;
import org.apache.lucene.analysis.payloads.PayloadHelper;
import org.apache.lucene.analysis.standard.StandardAnalyzer;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.document.Document;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopScoreDocCollector;
import org.apache.lucene.search.payloads.AveragePayloadFunction;
import org.apache.lucene.search.payloads.PayloadFunction;
import org.apache.lucene.search.payloads.PayloadTermQuery;
import org.apache.lucene.search.similarities.BasicStats;
import org.apache.lucene.search.similarities.LMJelinekMercerSimilarity;
import org.apache.lucene.search.similarities.LMSimilarity;
import org.apache.lucene.store.FSDirectory;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.Version;
import org.sostruct.Question;

/**
 * A term paired with a score and an occurrence counter.
 *
 * <p>The natural ordering is by <em>descending</em> {@link #score}: sorting a
 * list of {@code TermScore} objects puts the highest-scoring term first.
 * The ordering looks only at the score, so it is not consistent with
 * {@code equals} (which is inherited from {@code Object}).
 *
 * @author dganguly
 */
class TermScore implements Comparable<TermScore> {
    /** The term text. */
    String term;
    /** Score attached to the term; mutable so that it can be overwritten later. */
    float score;
    /** Occurrence counter; not set by the constructor, so it starts at 0. */
    int freq;

    /**
     * Creates a term/score pair with a frequency of zero.
     *
     * @param term  the term text
     * @param score the initial score of the term
     */
    public TermScore(String term, float score) {
        this.term = term;
        this.score = score;
    }

    /**
     * Orders by descending score.
     *
     * <p>{@link Float#compare(float, float)} is used rather than the
     * {@code <} / {@code ==} operators so that the result is a valid total
     * order even when a score is NaN; with the raw operators every comparison
     * involving NaN is false, which makes the comparison asymmetric and can
     * cause sort routines to fail. Under this ordering NaN sorts before all
     * other scores and {@code 0.0f} sorts before {@code -0.0f}.
     *
     * @param that the object to compare against
     * @return a negative value if this score is higher, zero if the scores
     *         are equal, a positive value if this score is lower
     */
    @Override
    public int compareTo(TermScore that) {
        // Arguments are swapped on purpose: higher scores must come first.
        return Float.compare(that.score, score);
    }
}

/**
 * Pairs a question id with the Lucene query built for that question.
 */
class SOQuery {
    /** Id of the question this query was built from. */
    int qid;
    /** The Lucene query to execute for the question. */
    Query luceneQueryObj;

    /**
     * @param qid            id of the question
     * @param luceneQueryObj Lucene query built for the question
     */
    public SOQuery(int qid, Query luceneQueryObj) {
        this.qid = qid;
        this.luceneQueryObj = luceneQueryObj;
    }

    /**
     * Returns the string form of the wrapped Lucene query; the question id is
     * not included.
     *
     * @return the query's {@code toString()}, or {@code "null"} when no query
     *         object is set
     */
    @Override
    public String toString() {
        // String.valueOf avoids a NullPointerException for a null query.
        return String.valueOf(luceneQueryObj);
    }
}

// A simple query constructor which takes all terms to formulate a query.
class QueryConstructor {
    IndexReader reader;
    float lambda;
    float tmLambda;
    Properties prop;
    TreeMap<String, TermScore> tfMap;

    public QueryConstructor(Properties prop, IndexReader reader) {
        this.prop = prop;
        this.reader = reader;
        lambda = Float.parseFloat(prop.getProperty("retrieve.lambda", "0.3"));
        tmLambda = Float.parseFloat(prop.getProperty("retrieve.tm_lambda", "0.3"));
        tfMap = new TreeMap<>();
    }

    private List<String> getBagOfWords(String text) throws Exception {

        List<String> terms = new ArrayList<>();
        text = Question.removeTags(text);

        boolean toStem = Boolean.parseBoolean(prop.getProperty("stem", "true"));
        String stopFile = prop.getProperty("stopfile");
        Analyzer analyzer = new WhitespaceAnalyzer(Version.LUCENE_4_9);
        /*SOAnalyzer(toStem, stopFile)*/;
        TokenStream stream = analyzer.tokenStream("bow", new StringReader(text));
        CharTermAttribute termAtt = stream.addAttribute(CharTermAttribute.class);
        stream.reset();

        while (stream.incrementToken()) {
            String term = termAtt.toString();
            terms.add(term);
        }

        stream.end();
        stream.close();

        return terms;
    }

    // Take as input a StackOverflow document and return the list
    // of terms
    BooleanQuery construct(List<String> terms, String[] tagTerms) {
        BooleanQuery q = new BooleanQuery();
        PayloadFunction pf = new AveragePayloadFunction();

        Boolean useTags = Boolean.parseBoolean(prop.getProperty("retrieve.use_tags", "false"));
        Term thisTerm;

        //System.out.println("About to add " + terms.size() + " query terms..");
        for (String term : terms) {
            thisTerm = new Term(QuestionIndexer.FIELD_BODY_WITH_TOPICS, term);
            q.add(new LMLinearCombinationTermQuery(thisTerm, pf, lambda, tmLambda), BooleanClause.Occur.SHOULD);
        }

        if (useTags) {
            for (String term : tagTerms) {
                thisTerm = new Term(QuestionIndexer.FIELD_TAGS, term);
                q.add(new TermQuery(thisTerm), BooleanClause.Occur.SHOULD);
            }
        }
        return q;
    }

    List<String> getSelectedTerms(String text) throws Exception {
        int nTerms = Integer.parseInt(prop.getProperty("query.nterms", "10"));
        List<String> terms = new ArrayList<String>();
        TreeMap<String, TermScore> tfMap = new TreeMap<>();
        String[] wtWords = text.split("\\s+");
        int doclen = wtWords.length;
        float termScore;
        float q_lambda = Float.parseFloat(prop.getProperty("query.lambda", "0.4"));
        float q_tmLambda = Float.parseFloat(prop.getProperty("query.tm_lambda", "0.4"));

        float collectionProb, collWt = (1 - q_lambda - q_tmLambda);
        long numDocs = reader.numDocs();
        TermScore ts;
        List<TermScore> termScores = new ArrayList<>();

        // Collect the term so as to compute the frequencies
        for (String wtWord : wtWords) {
            String[] tokens = wtWord.split("\\" + String.valueOf(PayloadAnalyzer.delim));
            String word = tokens[0];
            float tmProb = Float.parseFloat(tokens[1]);
            ts = tfMap.get(word);
            if (ts == null)
                ts = new TermScore(word, tmProb);
            ts.freq++;
            tfMap.put(word, ts);
        }

        // Revisit the map in second pass and compute scores
        for (Map.Entry<String, TermScore> entry : tfMap.entrySet()) {
            ts = entry.getValue();
            Term thisTerm = new Term(QuestionIndexer.FIELD_BODY_WITH_TOPICS, ts.term);
            collectionProb = (float) (numDocs / (double) reader.totalTermFreq(thisTerm)); // reciprocal
            termScore = (float) (Math.log(1 + collectionProb
                    * (q_lambda / collWt * ts.freq / (float) doclen + q_tmLambda / collWt * ts.score)));
            ts.score = termScore;
            termScores.add(ts);
        }

        // Sort the terms by the score values and return the top ones
        Collections.sort(termScores);
        int termCount = 0;
        for (TermScore topTermScore : termScores) {
            if (termCount >= nTerms)
                break;
            terms.add(topTermScore.term);
            termCount++;
            System.out.print(topTermScore.term + " ");
        }
        System.out.println();

        return terms;
    }

    public List<SOQuery> constructQueries() throws Exception {
        List<SOQuery> queryList = new LinkedList<SOQuery>();
        boolean allTerms = true;
        BooleanQuery.setMaxClauseCount(20000);

        String queryFile = prop.getProperty("query.tsv.file");
        if (queryFile == null) {
            // the weighted tsv files
            queryFile = prop.getProperty("query.wtsv.file");
            allTerms = false;
        }

        FileReader fr = new FileReader(queryFile);
        BufferedReader br = new BufferedReader(fr);
        String line;
        List<String> terms;

        while ((line = br.readLine()) != null) {
            String[] tokens = line.split("\t");
            String tags = tokens[1];
            String[] tagTerms = tokens[1].split("\\,");
            // the first token is the question id, the second is a list
            // of labels...
            if (allTerms)
                terms = getBagOfWords(tokens[2]);
            else
                terms = getSelectedTerms(tokens[2]);

            Query q = construct(terms, tagTerms);
            queryList.add(new SOQuery(Integer.parseInt(tokens[0]), q));
        }
        return queryList;
    }
}

/**
 * Orders {@link ScoreDoc} objects by descending score, so that sorting an
 * array of hits puts the best-scoring document first.
 */
class ScoreDocComparator implements Comparator<ScoreDoc> {

    /**
     * Compares two hits by score, highest first.
     *
     * <p>{@link Float#compare(float, float)} is used instead of the
     * {@code <} / {@code ==} operators so that the ordering stays a valid
     * total order when a score is NaN; the raw operators return the same
     * sign for both argument orders in that case, which sort routines may
     * reject.
     *
     * @return a negative value if {@code o1} has the higher score, zero if
     *         the scores are equal, a positive value otherwise
     */
    @Override
    public int compare(ScoreDoc o1, ScoreDoc o2) {
        // Arguments are swapped on purpose: higher scores must come first.
        return Float.compare(o2.score, o1.score);
    }
}

/**
 * Runs every query of the configured query file against a Lucene index and
 * writes the ranked results to a file.
 *
 * <p>All settings come from a properties file; see the constructor and
 * {@link #retrieveAll()} for the keys that are read.
 */
public class SOQuestionSearcher {
    /** Reader over the index; closed at the end of {@link #retrieveAll()}. */
    IndexReader reader;
    /** Searcher over {@link #reader}. */
    IndexSearcher searcher;
    /** Properties loaded from the file given to the constructor. */
    Properties prop;
    /** Number of hits collected per query ({@code retrieve.num_wanted}, default 100). */
    int numWanted;
    /** Builds the queries to run. */
    QueryConstructor qc;
    /** Predicted score per document id; never populated by this class (the loading code is disabled). */
    HashMap<Integer, Float> docScorePredictionMap;
    /** Enables re-ranking by predicted scores; never set to true by this class. */
    boolean isSupervised;

    /**
     * Loads the properties file and opens the index.
     *
     * <p>The index directory is taken from the property
     * {@code <model>.index}, where {@code <model>} is the value of
     * {@code retrieve.topicmodel} (default {@code plda}). Failures are
     * propagated to the caller rather than logged and ignored, because an
     * instance without an open index cannot run any query.
     *
     * @param propFile path of the properties file
     * @throws IllegalArgumentException if the index property is missing
     * @throws Exception                if the properties file or the index
     *                                  cannot be opened, or a numeric property
     *                                  is malformed
     */
    public SOQuestionSearcher(String propFile) throws Exception {
        prop = new Properties();
        try (FileReader propReader = new FileReader(propFile)) {
            prop.load(propReader);
        }

        String className = prop.getProperty("retrieve.topicmodel", "plda");
        String index_dir = prop.getProperty(className + ".index");
        if (index_dir == null) {
            throw new IllegalArgumentException(
                    "Property '" + className + ".index' is not set in " + propFile);
        }

        System.out.println("Running queries against index: " + index_dir);
        boolean useTopicWeights = Boolean.parseBoolean(prop.getProperty("retrieve.use_topics", "true"));
        // Parsed before the index is opened so that a bad value cannot leak an open reader.
        numWanted = Integer.parseInt(prop.getProperty("retrieve.num_wanted", "100"));

        reader = DirectoryReader.open(FSDirectory.open(new File(index_dir)));
        try {
            searcher = new IndexSearcher(reader);
            searcher.setSimilarity(new GeneralizedLMSimilarity(useTopicWeights));
            qc = new QueryConstructor(prop, reader);
        } catch (Exception ex) {
            // Do not leave the index open behind a half-built instance.
            try {
                reader.close();
            } catch (Exception closeEx) {
                ex.addSuppressed(closeEx);
            }
            throw ex;
        }

        /*
        if (className.equals("splda") || className.equals("slda")) {
        isSupervised = true;
        String predictionsFile = prop.getProperty("slda.infer.predictions");
        docScorePredictionMap = new HashMap<>();
        FileReader fr = new FileReader(predictionsFile);
        BufferedReader br = new BufferedReader(fr);
        String line;
        while ((line = br.readLine()) != null) {
            String[] tokens = line.split("\\s+");
            int docId = Integer.parseInt(tokens[0]);
            float score = Float.parseFloat(tokens[1]);
            score = (score + 1)/2;  // [-1,1] => [0,1]
            docScorePredictionMap.put(docId, score);
        }
        br.close();
        fr.close();
        }
        */
    }

    /**
     * Runs all queries and writes one result line per hit.
     *
     * <p>Each line has the tab-separated fields: query id, the literal
     * {@code Q0}, the document's id field, the 1-based rank, the score, and
     * the run name ({@code retrieve.runname}, default {@code baseline}). The
     * output goes to the file named by {@code retrieve.results_file}, which is
     * only opened once the queries have been built successfully.
     *
     * <p>The index reader is closed when this method returns or throws, so the
     * method can be called at most once per instance.
     *
     * @throws IllegalStateException if {@code retrieve.results_file} is not
     *                               set, or re-ranking is enabled but a
     *                               predicted score is missing
     * @throws Exception             if building the queries, searching or
     *                               writing fails
     */
    public void retrieveAll() throws Exception {
        String runName = prop.getProperty("retrieve.runname", "baseline");
        String resultsFile = prop.getProperty("retrieve.results_file");

        // The alias exists only so that try-with-resources closes the reader
        // on every exit path.
        try (IndexReader openReader = reader) {
            if (resultsFile == null) {
                throw new IllegalStateException("Property 'retrieve.results_file' is not set");
            }

            List<SOQuery> queries = qc.constructQueries();

            try (FileWriter fw = new FileWriter(resultsFile)) {
                for (SOQuery thisQuery : queries) {
                    TopScoreDocCollector collector = TopScoreDocCollector.create(numWanted, true);
                    searcher.search(thisQuery.luceneQueryObj, collector);
                    ScoreDoc[] hits = collector.topDocs().scoreDocs;

                    if (isSupervised) {
                        rerankByPredictions(hits);
                    }

                    StringBuilder buff = new StringBuilder();
                    for (int i = 0; i < hits.length; ++i) {
                        Document d = searcher.doc(hits[i].doc);

                        buff.append(thisQuery.qid).append("\tQ0\t").append(d.get(QuestionIndexer.FIELD_ID))
                                .append("\t").append((i + 1)).append("\t").append(hits[i].score).append("\t")
                                .append(runName).append("\n");
                    }
                    fw.write(buff.toString());
                }
            }
        }
    }

    /**
     * Replaces every hit's score by the average of that score and the
     * predicted score of the document, then re-sorts the hits best first.
     *
     * @param hits the hits to re-rank in place
     * @throws IllegalStateException if no predicted score is available for a hit
     * @throws Exception             if a stored document cannot be loaded
     */
    private void rerankByPredictions(ScoreDoc[] hits) throws Exception {
        for (ScoreDoc hit : hits) {
            Document d = searcher.doc(hit.doc);
            int docId = Integer.parseInt(d.get(QuestionIndexer.FIELD_ID));
            Float predFactor = docScorePredictionMap == null ? null : docScorePredictionMap.get(docId);
            if (predFactor == null) {
                // Fail with a clear message instead of a bare NullPointerException on unboxing.
                throw new IllegalStateException("No predicted score available for document id " + docId);
            }
            hit.score = 0.5f * predFactor + 0.5f * hit.score;
        }
        Arrays.sort(hits, new ScoreDocComparator());
    }

    /**
     * Command-line entry point.
     *
     * @param args optional single argument: the path of the properties file
     *             (default {@code retrieve.properties})
     */
    public static void main(String[] args) {
        String propFile = args.length < 1 ? "retrieve.properties" : args[0];
        try {
            SOQuestionSearcher searcher = new SOQuestionSearcher(propFile);
            searcher.retrieveAll();
        } catch (Exception ex) {
            ex.printStackTrace();
        }
    }
}