IrqaQuery.java Source code

Java tutorial

Introduction

Here is the source code for IrqaQuery.java

Source

/**
 * Copyright 2016, Emory University
 * <p/>
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * <p/>
 * http://www.apache.org/licenses/LICENSE-2.0
 * <p/>
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.core.StopFilter;
import org.apache.lucene.analysis.en.EnglishAnalyzer;

import org.apache.lucene.document.*;
import org.apache.lucene.index.*;
import org.apache.lucene.queryparser.classic.QueryParser;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.similarities.BM25Similarity;
import org.apache.lucene.search.similarities.ClassicSimilarity;
import org.apache.lucene.search.similarities.TFIDFSimilarity;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FSDirectory;

import java.io.*;
import java.nio.charset.StandardCharsets;
import java.nio.file.*;
import java.util.*;

import org.json.simple.JSONArray;
import org.json.simple.JSONObject;
import org.json.simple.parser.JSONParser;

/**
 * Lucene retrieval utilities: building a paragraph index ({@link #makeIndexWriter},
 * {@link #indexDoc}), querying it ({@link #query}), measuring how often the gold paragraph is
 * retrieved in the top results ({@link #batch_query}), and writing tab-separated
 * question/sentence/label files ({@link #pipeline}, {@link #get_sentence_from_json}).
 *
 * <p>No synchronization is performed; the index writer lives in a shared static field.
 *
 * @author Bonggun Shin ({@code bonggun.shin@emory.edu}).
 */
public class IrqaQuery {
    /** Analyzed field that queries are parsed against. */
    private static final String CONTENTS_FIELD = "contents";

    /** Stored, un-analyzed field holding each document's identifier. */
    private static final String DOCID_FIELD = "docid";

    /** Number of retrieved documents inspected per question in the batch experiments. */
    private static final int TOP_K = 5;

    /** Sentences equal to one of these strings carry no content and are not written out. */
    private static final Set<String> BLANK_SENTENCES =
            new HashSet<>(Arrays.asList("", " ", "  ", "   ", "''"));

    /** Base directory used by {@link #main} when none is given on the command line. */
    private static final String DEFAULT_BASEDIR = "/Users/bong/works/research/irqa";

    /**
     * Writer created by {@link #makeIndexWriter}; {@code null} until then. Nothing in this class
     * closes it, so callers are responsible for doing so.
     */
    public static IndexWriter writer;

    public IrqaQuery() {
    }

    /**
     * Opens an {@link IndexWriter} on {@code indexPath} and stores it in {@link #writer}.
     *
     * <p>Any writer previously stored in {@link #writer} is replaced without being closed.
     *
     * @param indexPath directory of the Lucene index
     * @param stopPath  stop-word file, one word per line
     * @param sim       {@code "TFIDF"} selects classic TF-IDF scoring; any other value selects BM25
     * @throws IOException if the directory cannot be opened or the writer cannot be created
     */
    public static void makeIndexWriter(String indexPath, String stopPath, String sim) throws IOException {
        System.out.println("[makeIndexWriter] started");
        System.out.println("[makeIndexWriter]" + stopPath);
        Directory dir = FSDirectory.open(Paths.get(indexPath));
        IndexWriterConfig iwc = new IndexWriterConfig(makeAnalyzer(stopPath));
        iwc.setSimilarity("TFIDF".equals(sim) ? new ClassicSimilarity() : new BM25Similarity());

        try {
            writer = new IndexWriter(dir, iwc);
        } catch (IOException | RuntimeException e) {
            // Don't leak the directory handle when the writer cannot be created.
            dir.close();
            throw e;
        }
    }

    /**
     * Adds one document to {@link #writer}.
     *
     * @param docid identifier stored verbatim (not analyzed) in the {@code docid} field
     * @param args  alternating field names and field texts, e.g. {@code "title", t, "contents", c};
     *              texts are analyzed and indexed but not stored
     * @throws IllegalStateException    if {@link #makeIndexWriter} has not been called
     * @throws IllegalArgumentException if {@code args} does not contain complete name/text pairs
     * @throws IOException              if the writer fails to add the document
     */
    public static void indexDoc(String docid, String... args) throws IOException {
        if (writer == null) {
            throw new IllegalStateException("makeIndexWriter must be called before indexDoc");
        }
        if (args.length % 2 != 0) {
            throw new IllegalArgumentException("indexDoc expects (field, text) pairs but got " + args.length
                    + " values for docid " + docid);
        }

        Document doc = new Document();
        doc.add(new StringField(DOCID_FIELD, docid, Field.Store.YES));
        for (int i = 0; i < args.length; i += 2) {
            doc.add(new TextField(args[i], args[i + 1], Field.Store.NO));
        }

        System.out.println("adding " + docid);
        writer.addDocument(doc);
    }

    /**
     * Reads a stop-word file (UTF-8, one word per line, each line trimmed).
     *
     * <p>Best effort: if the file cannot be read, the problem is reported on stderr and whatever
     * was read so far (possibly nothing) is returned.
     *
     * @param stopFile path of the stop-word file
     * @return the stop words in file order
     */
    public static List<String> mygetStopwords(String stopFile) {
        List<String> stopwords = new ArrayList<>();
        try (BufferedReader br = openUtf8Reader(stopFile)) {
            for (String line = br.readLine(); line != null; line = br.readLine()) {
                stopwords.add(line.trim());
            }
        } catch (IOException ex) {
            System.err.println("[mygetStopwords] could not read " + stopFile + ": " + ex);
        }
        return stopwords;
    }

    /**
     * Opens the index, runs a single query against its {@code contents} field, and closes the index.
     *
     * <p>For many queries against the same index, prefer reusing one open searcher; this method
     * reopens the index and re-reads the stop-word file on every call.
     *
     * @param index     directory of the Lucene index
     * @param stoppath  stop-word file used to build the query analyzer
     * @param question  free-text question; Lucene query syntax characters are escaped
     * @param numResult maximum number of documents to return
     * @param sim       {@code "TFIDF"} selects classic TF-IDF scoring; any other value selects BM25
     * @return up to {@code numResult} stored documents in Lucene's ranking order; empty when the
     *         question is null/blank or {@code numResult <= 0}
     * @throws Exception if the index cannot be opened or the query cannot be parsed
     */
    public static List<Document> query(String index, String stoppath, String question, int numResult, String sim)
            throws Exception {
        try (Directory dir = FSDirectory.open(Paths.get(index));
             IndexReader reader = DirectoryReader.open(dir);
             Analyzer analyzer = makeAnalyzer(stoppath)) {
            IndexSearcher searcher = new IndexSearcher(reader);
            searcher.setSimilarity("TFIDF".equals(sim) ? new ClassicSimilarity() : new BM25Similarity());
            // Returned Documents hold their own copies of the stored fields, so they stay valid after close.
            return search(searcher, analyzer, question, numResult);
        }
    }

    /**
     * Measures how often the gold paragraph appears among the top {@value #TOP_K} BM25 results.
     *
     * <p>Reads {@code basedir/data/questions.json}, an array of objects with {@code "question"} and
     * {@code "paragraph_id"} keys. Prints progress every 1000 questions, then the accuracy (percent),
     * hit count, question count and elapsed seconds.
     *
     * @param basedir   experiment root directory
     * @param indexpath index-name suffix; the index is {@code basedir/index_all<indexpath>/}
     * @throws Exception if the question file or index cannot be read
     */
    public static void batch_query(String basedir, String indexpath) throws Exception {
        String index = basedir + "/index_all" + indexpath + "/";
        String stopwords = basedir + "/stopwords.txt";

        JSONArray questions = (JSONArray) parseJson(basedir + "/data/questions.json");

        long startTime = System.currentTimeMillis();
        int answercount = 0;
        int questioncount = 0;

        // Open the index once for the whole batch instead of once per question.
        try (Directory dir = FSDirectory.open(Paths.get(index));
             IndexReader reader = DirectoryReader.open(dir);
             Analyzer analyzer = makeAnalyzer(stopwords)) {
            IndexSearcher searcher = new IndexSearcher(reader);
            searcher.setSimilarity(new BM25Similarity());

            for (Object o : questions) {
                JSONObject q = (JSONObject) o;
                String query = (String) q.get("question");
                String gold_id = (String) q.get("paragraph_id");

                questioncount++;
                if (containsDocid(search(searcher, analyzer, query, TOP_K), gold_id)) {
                    answercount++;
                }

                if (questioncount % 1000 == 0) {
                    long midtime = System.currentTimeMillis() - startTime;
                    System.out.format("[%d] midtime=%f\n", questioncount, midtime / 1000.0);
                }
            }
        }

        // Avoid printing NaN when the question file is empty.
        double accuracy = questioncount == 0 ? 0.0 : answercount * 1.0 / questioncount * 100;
        System.out.format("acc=%f\t%d\t%d\n", accuracy, answercount, questioncount);
        long estimatedTime = System.currentTimeMillis() - startTime;
        System.out.println(estimatedTime / 1000.0);
    }

    /**
     * Writes one {@code question<TAB>sentence<TAB>label} line per sentence of every entry in
     * {@code raw_list}.
     *
     * <p>The label is 1 only for sentences of the entry whose {@code "paragraph_id"} equals
     * {@code docid} and whose {@code "question"} equals {@code question}, and only when the
     * sentence's position is listed in that entry's {@code "candidates"} value. That value is a
     * comma-separated list of 1-based sentence numbers, and blank entries are ignored. Every other
     * sentence gets label 0.
     *
     * @param raw_list JSON array of objects with {@code "question"}, {@code "paragraph_id"},
     *                 {@code "candidates"} and {@code "sentences"} keys
     * @param question question text written in the first column
     * @param docid    paragraph id whose candidates are labelled 1
     * @param out      destination for the generated lines
     * @throws Exception if writing fails or an entry has an unexpected shape
     */
    public static void get_sentence_from_json(JSONArray raw_list, String question, String docid, BufferedWriter out)
            throws Exception {
        for (Object o : raw_list) {
            JSONObject rl = (JSONObject) o;
            String query = (String) rl.get("question");
            String pid = (String) rl.get("paragraph_id");

            boolean matched = docid.equals(pid) && question.equals(query);
            Set<Integer> answers = matched ? parseCandidates(rl.get("candidates")) : Collections.<Integer>emptySet();

            int index_of_sen = 0;
            for (Object sen : (JSONArray) rl.get("sentences")) {
                int zero_one = answers.contains(index_of_sen) ? 1 : 0;
                out.write(String.format("%s\t%s\t%d\n", question, sen, zero_one));
                index_of_sen++;
            }
        }
    }

    /**
     * Builds a sentence-level file for one data split from the top {@value #TOP_K} BM25 results.
     *
     * <p>Reads {@code WikiQASent-<set>.txt} (tab-separated question, sentence, label; consecutive
     * lines with the same question form one group) and {@code <set>_raw_list.json}, whose i-th
     * entry supplies the gold {@code "paragraph_id"} for the i-th question group. For each retrieved
     * document, it writes the group's own sentences and labels when the document is the gold
     * paragraph. Otherwise it writes the document's sentences from {@code lookup_sent} with label 0.
     * Blank sentences are skipped. The output goes to
     * {@code newsplit<indexpath>_<set>.txt} under {@code basedir/stats/data_for_analysis/newTACL}.
     *
     * @param basedir     experiment root directory
     * @param indexpath   index-name suffix; the index is {@code basedir/index_all<indexpath>/}
     * @param set         split name used in the input/output file names (e.g. {@code "dev"})
     * @param lookup_sent JSON object mapping document ids to arrays of sentences
     * @throws Exception if an input is missing or malformed, or the index cannot be read
     */
    public static void pipeline(String basedir, String indexpath, String set, JSONObject lookup_sent)
            throws Exception {
        System.out.println(set + " started...");
        String index = basedir + "/index_all" + indexpath + "/";
        String stopwords = basedir + "/stopwords.txt";

        // Plain concatenation rather than String.format(basedir + ...): a '%' in basedir must not
        // be interpreted as a format specifier.
        String dataDir = basedir + "/stats/data_for_analysis/newTACL/";
        String answer_filename = dataDir + set + "_raw_list.json";
        String file = dataDir + "WikiQASent-" + set + ".txt";
        String outfilename = dataDir + "newsplit" + indexpath + "_" + set + ".txt";

        JSONArray answer_list = (JSONArray) parseJson(answer_filename);

        List<String> questions = new ArrayList<>();
        List<List<String>> sentlistAll = new ArrayList<>();
        List<List<String>> alistAll = new ArrayList<>();
        int numline = readGroupedSentences(file, questions, sentlistAll, alistAll);

        System.out.println(questions.size());

        if (answer_list.size() < questions.size()) {
            throw new IllegalStateException(String.format("%s has %d entries but %s has %d question groups",
                    answer_filename, answer_list.size(), file, questions.size()));
        }

        try (Directory dir = FSDirectory.open(Paths.get(index));
             IndexReader reader = DirectoryReader.open(dir);
             Analyzer analyzer = makeAnalyzer(stopwords);
             BufferedWriter outfile = openUtf8Writer(outfilename)) {
            IndexSearcher searcher = new IndexSearcher(reader);
            searcher.setSimilarity(new BM25Similarity());

            for (int i = 0; i < questions.size(); i++) {
                String query = questions.get(i);
                String gold_pid = (String) ((JSONObject) answer_list.get(i)).get("paragraph_id");

                for (Document d : search(searcher, analyzer, query, TOP_K)) {
                    String docid = d.get(DOCID_FIELD);
                    if (docid != null && docid.equals(gold_pid)) {
                        writeGoldSentences(outfile, query, sentlistAll.get(i), alistAll.get(i));
                    } else {
                        writeLookupSentences(outfile, query, docid, lookup_sent);
                    }
                }
            }
        }

        System.out.println(numline);
    }

    /**
     * Command-line entry point.
     *
     * <p>Usage: {@code IrqaQuery [basedir [pipeline]]}. With no arguments it runs
     * {@link #batch_query} for each configured index under the default base directory. If the
     * second argument is {@code pipeline}, it instead loads
     * {@code basedir/data/wikilookup_clean_sentence.json} and runs {@link #pipeline} for the
     * dev, test and train splits.
     */
    public static void main(String[] args) throws Exception {
        String basedir = args.length > 0 ? args[0] : DEFAULT_BASEDIR;
        boolean runPipeline = args.length > 1 && "pipeline".equals(args[1]);

        // Index-name suffixes to evaluate. Earlier runs also used _c_1024, _c_512, _c_256, _c_128,
        // _c_64, _c_32, _c_16, _c_8, _c_4, _c_2 and _c_0.
        List<String> exps = Arrays.asList("_c_2048");

        if (runPipeline) {
            JSONObject lookup_sent = (JSONObject) parseJson(basedir + "/data/wikilookup_clean_sentence.json");
            for (String indexpath : exps) {
                pipeline(basedir, indexpath, "dev", lookup_sent);
                pipeline(basedir, indexpath, "test", lookup_sent);
                pipeline(basedir, indexpath, "train", lookup_sent);
            }
        } else {
            for (String indexpath : exps) {
                batch_query(basedir, indexpath);
            }
        }
    }

    /** Builds the English analyzer used for both indexing and querying, with stop words from {@code stopPath}. */
    private static Analyzer makeAnalyzer(String stopPath) {
        return new EnglishAnalyzer(StopFilter.makeStopSet(mygetStopwords(stopPath)));
    }

    /**
     * Runs one escaped free-text query against the {@code contents} field of an open searcher.
     *
     * @return up to {@code numResult} stored documents in ranking order; empty for a null/blank
     *         question (which the query parser would reject) or a non-positive {@code numResult}
     */
    private static List<Document> search(IndexSearcher searcher, Analyzer analyzer, String question, int numResult)
            throws Exception {
        List<Document> docs = new ArrayList<>();
        if (question == null || question.trim().isEmpty() || numResult <= 0) {
            return docs;
        }

        QueryParser parser = new QueryParser(CONTENTS_FIELD, analyzer);
        Query query = parser.parse(QueryParser.escape(question));
        TopDocs results = searcher.search(query, numResult);

        // scoreDocs already holds at most numResult hits, in score order.
        for (ScoreDoc hit : results.scoreDocs) {
            docs.add(searcher.doc(hit.doc));
        }
        return docs;
    }

    /** Returns whether any of {@code docs} has the given (non-null) {@code docid}. */
    private static boolean containsDocid(List<Document> docs, String docid) {
        if (docid == null) {
            return false;
        }
        for (Document d : docs) {
            if (docid.equals(d.get(DOCID_FIELD))) {
                return true;
            }
        }
        return false;
    }

    /**
     * Parses a comma-separated list of 1-based sentence numbers into 0-based indices.
     * Spaces are removed and empty entries (e.g. from an empty list) are ignored.
     */
    private static Set<Integer> parseCandidates(Object candidatesField) {
        Set<Integer> indices = new HashSet<>();
        if (candidatesField == null) {
            return indices;
        }
        for (String token : candidatesField.toString().split(",")) {
            String digits = token.replace(" ", "");
            if (!digits.isEmpty()) {
                indices.add(Integer.parseInt(digits) - 1);
            }
        }
        return indices;
    }

    /**
     * Reads a tab-separated question/sentence/label file and groups consecutive lines that share
     * a question. The three output lists are filled in parallel, one element per group.
     *
     * @return number of lines read
     * @throws IOException if the file cannot be read or a line has fewer than three columns
     */
    private static int readGroupedSentences(String file, List<String> questions, List<List<String>> sentlistAll,
            List<List<String>> alistAll) throws IOException {
        int numline = 0;
        String current = null;
        List<String> sentlist = null;
        List<String> alist = null;

        try (BufferedReader br = openUtf8Reader(file)) {
            for (String r = br.readLine(); r != null; r = br.readLine()) {
                numline++;
                String[] cols = r.split("\t");
                if (cols.length < 3) {
                    throw new IOException(String.format(
                            "%s:%d: expected question<TAB>sentence<TAB>label but got: %s", file, numline, r));
                }

                if (!cols[0].equals(current)) {
                    // A new question starts a new group; the lists are registered now and filled below.
                    current = cols[0];
                    sentlist = new ArrayList<>();
                    alist = new ArrayList<>();
                    questions.add(current);
                    sentlistAll.add(sentlist);
                    alistAll.add(alist);
                }
                sentlist.add(cols[1]);
                alist.add(cols[2]);
            }
        }
        return numline;
    }

    /** Writes the gold group's non-blank sentences with their original labels. */
    private static void writeGoldSentences(Writer out, String query, List<String> sentences, List<String> labels)
            throws IOException {
        for (int j = 0; j < sentences.size(); j++) {
            String sentence = sentences.get(j);
            if (!BLANK_SENTENCES.contains(sentence)) {
                out.write(String.format("%s\t%s\t%s\n", query, sentence, labels.get(j)));
            }
        }
    }

    /** Writes a non-gold document's non-blank sentences from the lookup table, all labelled 0. */
    private static void writeLookupSentences(Writer out, String query, String docid, JSONObject lookup_sent)
            throws IOException {
        JSONArray sents = (JSONArray) lookup_sent.get(docid);
        if (sents == null) {
            System.out.println("noway, " + docid + "\n");
            return;
        }
        for (Object sent : sents) {
            String text = sent.toString();
            if (!BLANK_SENTENCES.contains(text)) {
                out.write(String.format("%s\t%s\t%s\n", query, text, "0"));
            }
        }
    }

    /** Parses a UTF-8 JSON file, closing it afterwards. */
    private static Object parseJson(String path) throws Exception {
        try (Reader in = openUtf8Reader(path)) {
            return new JSONParser().parse(in);
        }
    }

    /** Opens a buffered UTF-8 reader; malformed input is replaced rather than rejected. */
    private static BufferedReader openUtf8Reader(String path) throws IOException {
        return new BufferedReader(new InputStreamReader(new FileInputStream(path), StandardCharsets.UTF_8));
    }

    /** Opens a buffered UTF-8 writer, truncating any existing file. */
    private static BufferedWriter openUtf8Writer(String path) throws IOException {
        return new BufferedWriter(new OutputStreamWriter(new FileOutputStream(path), StandardCharsets.UTF_8));
    }
}