Source code for ivory.cascade.retrieval.CascadeThreadedQueryRunner.java, part of the Ivory
Hadoop toolkit for web-scale information retrieval.

/*
 * Ivory: A Hadoop toolkit for web-scale information retrieval
 * 
 * Licensed under the Apache License, Version 2.0 (the "License"); you
 * may not use this file except in compliance with the License. You may
 * obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0 
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied. See the License for the specific language governing
 * permissions and limitations under the License.
 */

package ivory.cascade.retrieval;

import ivory.core.RetrievalEnvironment;
import ivory.core.exception.ConfigurationException;
import ivory.smrf.model.MarkovRandomField;
import ivory.smrf.model.builder.MRFBuilder;
import ivory.smrf.model.expander.MRFExpander;
import ivory.smrf.retrieval.Accumulator;
import ivory.smrf.retrieval.MRFDocumentRanker;

import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

import org.apache.log4j.Logger;

import com.google.common.base.Preconditions;
import com.google.common.collect.Maps;

/**
 * @author Lidan Wang
 */
/**
 * Multithreaded runner for cascade-model queries. Queries submitted through
 * {@link #runQuery(String, String[])} execute asynchronously on a fixed-size
 * thread pool; results are fetched (blocking until completion if necessary)
 * through {@link #getResults(String)} or {@link #getResults()}.
 *
 * @author Lidan Wang
 */
public class CascadeThreadedQueryRunner implements CascadeQueryRunner {
    private static final Logger sLogger = Logger.getLogger(CascadeThreadedQueryRunner.class);

    // Capacity of the fixed-size per-query cost arrays below; assumes numeric
    // qids in [0, 1000).
    private static final int MAX_QUERIES = 1000;

    private final MRFBuilder mBuilder;
    private final MRFExpander mExpander; // may be null: no pseudo-relevance feedback
    private final ExecutorService mThreadPool;
    private final Map<String, Future<Accumulator[]>> mQueryResults;
    private final int mNumHits;
    // Results saved from the previous cascade stage for all queries, keyed by
    // numeric qid; each value holds [docno, score] rows. May be null.
    private final Map<Integer, Float[][]> savedResults_prevStage;
    private final int mK; // K value used in cascade model

    // Cascade cost per query, indexed by numeric qid; 0 means "not recorded".
    private final float[] cascadeCostAllQueries = new float[MAX_QUERIES];
    private final float[] cascadeCostAllQueries_lastStage = new float[MAX_QUERIES];

    /**
     * @param builder      builds the MRF for each query (must not be null)
     * @param expander     optional MRF expander for pseudo-relevance feedback; may be null
     * @param numThreads   number of worker threads (must be positive)
     * @param numHits      number of hits to retrieve per query (must be positive)
     * @param savedResults per-query results saved from the previous cascade stage, keyed by numeric qid
     * @param K            K value used in the cascade model
     */
    public CascadeThreadedQueryRunner(MRFBuilder builder, MRFExpander expander, int numThreads, int numHits,
            Map<Integer, Float[][]> savedResults, int K) {
        Preconditions.checkNotNull(builder);
        // Enforced unconditionally (the original asserts were no-ops unless -ea was set).
        Preconditions.checkArgument(numThreads > 0, "numThreads must be positive");
        Preconditions.checkArgument(numHits > 0, "numHits must be positive");

        mBuilder = builder;
        mExpander = expander;
        mThreadPool = Executors.newFixedThreadPool(numThreads);
        mQueryResults = Maps.newLinkedHashMap();
        mNumHits = numHits;
        savedResults_prevStage = savedResults;
        mK = K;
    }

    /**
     * Runs a query asynchronously. Results can be fetched using
     * {@link #getResults(String)}.
     *
     * @param qid   query id
     * @param query query terms
     */
    public void runQuery(String qid, String[] query) {
        Preconditions.checkNotNull(qid);
        Preconditions.checkNotNull(query);

        Future<Accumulator[]> future = mThreadPool
                .submit(new ThreadTask(query, mBuilder, mExpander, qid, mNumHits));
        mQueryResults.put(qid, future);
    }

    /**
     * Runs a query synchronously, waiting until completion.
     *
     * @param query query terms
     * @return ranked results, or null if execution failed or was interrupted
     */
    public Accumulator[] runQuery(String[] query) {
        Preconditions.checkNotNull(query);

        Future<Accumulator[]> future = mThreadPool
                .submit(new ThreadTask(query, mBuilder, mExpander, "query", mNumHits));
        try {
            return future.get();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt(); // preserve interrupt status
            sLogger.error("Interrupted while waiting for query results", e);
            return null;
        } catch (ExecutionException e) {
            sLogger.error("Error executing query", e);
            return null;
        }
    }

    /**
     * Fetches the results of a query. If necessary, waits until completion of
     * the query.
     *
     * @param qid query id
     * @return ranked results, or null if execution failed or was interrupted
     */
    public Accumulator[] getResults(String qid) {
        try {
            return mQueryResults.get(qid).get();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt(); // preserve interrupt status
            sLogger.error("Interrupted while waiting for query " + qid, e);
            return null;
        } catch (ExecutionException e) {
            sLogger.error("Error executing query " + qid, e);
            return null;
        }
    }

    /**
     * Clears all stored results.
     */
    public void clearResults() {
        mQueryResults.clear();
    }

    /**
     * Returns the results of all queries executed, in submission order,
     * omitting queries that produced no results.
     */
    public Map<String, Accumulator[]> getResults() {
        Map<String, Accumulator[]> results = Maps.newLinkedHashMap();
        for (Map.Entry<String, Future<Accumulator[]>> e : mQueryResults.entrySet()) {
            try {
                Accumulator[] a = e.getValue().get();
                if (a != null) {
                    // Reuse the fetched value instead of blocking on get() a second time.
                    results.put(e.getKey(), a);
                }
            } catch (InterruptedException ex) {
                Thread.currentThread().interrupt(); // preserve interrupt status
                sLogger.error("Interrupted while waiting for query " + e.getKey(), ex);
            } catch (ExecutionException ex) {
                sLogger.error("Error executing query " + e.getKey(), ex);
            }
        }
        return results;
    }

    /**
     * Returns the cascade cost for all queries, indexed by numeric qid.
     * Note: exposes the internal array (callers must not modify it).
     */
    public float[] getCascadeCostAllQueries() {
        return cascadeCostAllQueries;
    }

    /**
     * Returns the last-stage cascade cost for all queries, indexed by numeric
     * qid. Note: exposes the internal array (callers must not modify it).
     */
    public float[] getCascadeCostAllQueries_lastStage() {
        return cascadeCostAllQueries_lastStage;
    }

    // Parses a qid into an array index; returns -1 for non-numeric qids
    // (e.g. the synchronous path, which uses the literal qid "query").
    private static int parseQueryIndex(String qid) {
        try {
            return Integer.parseInt(qid);
        } catch (NumberFormatException e) {
            return -1;
        }
    }

    // Stores a per-query cost, ignoring qids that do not map to a valid index.
    private static void recordCost(float[] costs, String qid, float cost) {
        int idx = parseQueryIndex(qid);
        if (idx >= 0 && idx < costs.length) {
            costs[idx] = cost;
        }
    }

    // Fetches the previous-stage results for a query and unboxes them to the
    // primitive form used downstream; returns null when no results are saved.
    // Fixes the original lookup, which used the String qid as a key into the
    // Integer-keyed map (always null) and cast Float[][] to float[][] (CCE).
    private float[][] lookupSavedResults(String qid) {
        if (savedResults_prevStage == null) {
            return null;
        }
        int idx = parseQueryIndex(qid);
        if (idx < 0) {
            return null;
        }
        Float[][] boxed = savedResults_prevStage.get(idx);
        if (boxed == null) {
            return null;
        }
        float[][] unboxed = new float[boxed.length][];
        for (int i = 0; i < boxed.length; i++) {
            unboxed[i] = new float[boxed[i].length];
            for (int j = 0; j < boxed[i].length; j++) {
                unboxed[i][j] = boxed[i][j];
            }
        }
        return unboxed;
    }

    // Thread for running a query.  No need to expose implementation.
    private class ThreadTask implements Callable<Accumulator[]> {
        private final String[] mQuery;
        private final MRFBuilder mBuilder;
        private final MRFExpander mExpander;
        private final String mQid;
        private final int mNumHits;

        public ThreadTask(String[] query, MRFBuilder builder, MRFExpander expander, String qid, int numHits) {
            mQuery = query;
            mBuilder = builder;
            mExpander = expander;
            mQid = qid;
            mNumHits = numHits;
        }

        public Accumulator[] call() {
            try {
                long startTime = System.currentTimeMillis();

                // Docno/score pairs saved from the previous cascade stage, if any.
                float[][] savedResults = lookupSavedResults(mQid);

                // Build the MRF for this query.
                MarkovRandomField mrf = mBuilder.buildMRF(mQuery);

                Accumulator[] results = null;
                float cascadeCost = -1; // -1 means "not computed"
                // NOTE(review): never assigned a real value in this code; kept
                // for parity with the last-stage cost array.
                float cascadeCost_lastStage = -1;

                // An MRF with no cliques produces no results.
                if (mrf.getCliques().size() != 0) {
                    if (RetrievalEnvironment.mIsNewModel) {
                        // Rank the documents using the cascade model.
                        CascadeEval ranker = new CascadeEval(mrf, mNumHits, mQid, savedResults, mK);
                        results = ranker.rank();
                        cascadeCost = ranker.getCascadeCost();
                    } else {
                        // Retrieve documents using this MRF.
                        MRFDocumentRanker ranker = new MRFDocumentRanker(mrf, mNumHits);

                        // Perform pseudo-relevance feedback, if requested: run an
                        // initial retrieval, expand the MRF, then re-rank below.
                        if (mExpander != null) {
                            results = ranker.rank();
                            MarkovRandomField expandedMRF = mExpander.getExpandedMRF(mrf, results);
                            ranker = new MRFDocumentRanker(expandedMRF, mNumHits);
                        }

                        // Rank the documents.
                        results = ranker.rank();

                        cascadeCost = -1; // TODO: compute ranker cost for the non-cascade model
                    }
                }
                long endTime = System.currentTimeMillis();
                sLogger.info("Processed query " + mQid + " in " + (endTime - startTime) + " ms.");

                // Record the cascade cost for this query; guarded so that
                // non-numeric or out-of-range qids cannot crash the task.
                if (cascadeCost != -1) {
                    recordCost(cascadeCostAllQueries, mQid, cascadeCost);
                }
                if (cascadeCost_lastStage != -1) {
                    recordCost(cascadeCostAllQueries_lastStage, mQid, cascadeCost_lastStage);
                }

                return results;
            } catch (ConfigurationException e) {
                sLogger.error("Failed to process query " + mQid, e);
                return null;
            }
        }
    }
}