// Java tutorial
/* * Ivory: A Hadoop toolkit for web-scale information retrieval * * Licensed under the Apache License, Version 2.0 (the "License"); you * may not use this file except in compliance with the License. You may * obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. See the License for the specific language governing * permissions and limitations under the License. */ package ivory.cascade.retrieval; import ivory.core.RetrievalEnvironment; import ivory.core.exception.ConfigurationException; import ivory.smrf.model.MarkovRandomField; import ivory.smrf.model.builder.MRFBuilder; import ivory.smrf.model.expander.MRFExpander; import ivory.smrf.retrieval.Accumulator; import ivory.smrf.retrieval.MRFDocumentRanker; import java.util.Map; import java.util.concurrent.Callable; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; import org.apache.log4j.Logger; import com.google.common.base.Preconditions; import com.google.common.collect.Maps; /** * @author Lidan Wang */ public class CascadeThreadedQueryRunner implements CascadeQueryRunner { private static final Logger sLogger = Logger.getLogger(CascadeThreadedQueryRunner.class); private MRFBuilder mBuilder; private MRFExpander mExpander; private ExecutorService mThreadPool; private Map<String, Future<Accumulator[]>> mQueryResults; private int mNumHits; private Map<Integer, Float[][]> savedResults_prevStage = Maps.newHashMap(); //for all queries private int mK; //K value used in cascade model //assume no more than 1000 queries private float[] cascadeCostAllQueries = new float[1000]; private float[] cascadeCostAllQueries_lastStage = new float[1000]; public 
CascadeThreadedQueryRunner(MRFBuilder builder, MRFExpander expander, int numThreads, int numHits, Map<Integer, Float[][]> savedResults, int K) { Preconditions.checkNotNull(builder); assert (numThreads > 0); assert (numHits > 0); mBuilder = builder; mExpander = expander; mThreadPool = Executors.newFixedThreadPool(numThreads); mQueryResults = Maps.newLinkedHashMap(); mNumHits = numHits; savedResults_prevStage = savedResults; mK = K; } /** * Runs a query asynchronously. Results can be fetched using * {@link getResults}. */ public void runQuery(String qid, String[] query) { Preconditions.checkNotNull(qid); Preconditions.checkNotNull(query); Future<Accumulator[]> future = mThreadPool .submit(new ThreadTask(query, mBuilder, mExpander, qid, mNumHits)); mQueryResults.put(qid, future); } /** * Runs a query synchronously, waiting until completion. */ public Accumulator[] runQuery(String[] query) { Preconditions.checkNotNull(query); Future<Accumulator[]> future = mThreadPool .submit(new ThreadTask(query, mBuilder, mExpander, "query", mNumHits)); Accumulator[] results = null; try { results = future.get(); } catch (Exception e) { e.printStackTrace(); } return results; } /** * Fetches the results of a query. If necessary, waits until completion of * the query. * * @param qid * query id */ public Accumulator[] getResults(String qid) { try { return mQueryResults.get(qid).get(); } catch (InterruptedException e) { e.printStackTrace(); return null; } catch (ExecutionException e) { e.printStackTrace(); return null; } } /** * Clears all stored results. */ public void clearResults() { mQueryResults.clear(); } /** * Returns results of all queries executed. 
*/ public Map<String, Accumulator[]> getResults() { Map<String, Accumulator[]> results = Maps.newLinkedHashMap(); for (Map.Entry<String, Future<Accumulator[]>> e : mQueryResults.entrySet()) { try { Accumulator[] a = e.getValue().get(); if (a != null) { results.put(e.getKey(), e.getValue().get()); } } catch (Exception ex) { ex.printStackTrace(); } } return results; } //return cascade cost for all queries public float[] getCascadeCostAllQueries() { return cascadeCostAllQueries; } public float[] getCascadeCostAllQueries_lastStage() { return cascadeCostAllQueries_lastStage; } // Thread for running a query. No need to expose implementation. private class ThreadTask implements Callable<Accumulator[]> { private String[] mQuery; private MRFBuilder mBuilder; private MRFExpander mExpander; private String mQid; private int mNumHits; public ThreadTask(String[] query, MRFBuilder builder, MRFExpander expander, String qid, int numHits) { mQuery = query; mBuilder = builder; mExpander = expander; mQid = qid; mNumHits = numHits; } public Accumulator[] call() { try { long startTime; long endTime; startTime = System.currentTimeMillis(); // Build the MRF for this query. Object r = savedResults_prevStage.get(mQid); float[][] savedResults = null; //store docno and score if (r != null) { savedResults = (float[][]) r; } MarkovRandomField mrf = mBuilder.buildMRF(mQuery); // Run initial query, if necessary. Accumulator[] results = null; float cascadeCost = -1; float cascadeCost_lastStage = -1; if (mrf.getCliques().size() == 0) { } else { if (RetrievalEnvironment.mIsNewModel) { CascadeEval ranker = new CascadeEval(mrf, mNumHits, mQid, savedResults, mK); // Rank the documents using the cascade model. results = ranker.rank(); cascadeCost = ranker.getCascadeCost(); } else { // Retrieve documents using this MRF. MRFDocumentRanker ranker = new MRFDocumentRanker(mrf, mNumHits); if (mExpander != null) { results = ranker.rank(); } // Perform pseudo-relevance feedback, if requested. 
if (mExpander != null) { // Get expanded MRF. MarkovRandomField expandedMRF = mExpander.getExpandedMRF(mrf, results); // Re-rank documents according to expanded MRF. ranker = new MRFDocumentRanker(expandedMRF, mNumHits); } // Rank the documents. results = ranker.rank(); //cascadeCost = ranker.getCost(); cascadeCost = -1; //todo: later } } endTime = System.currentTimeMillis(); sLogger.info("Processed query " + mQid + " in " + (endTime - startTime) + " ms."); //This stores the cascade cost for this query using the model represented by modelID if (cascadeCost != -1) { //String key = BatchQueryRunner.model_ID + " "+mQid; //BatchQueryRunner.cascadeCosts.put(key, cascadeCost+""); cascadeCostAllQueries[Integer.parseInt(mQid)] = cascadeCost; } if (cascadeCost_lastStage != -1) { cascadeCostAllQueries_lastStage[Integer.parseInt(mQid)] = cascadeCost_lastStage; } return results; } catch (ConfigurationException e) { e.printStackTrace(); sLogger.error(e.getMessage()); return null; } } } }