edu.mit.ll.vizlinc.highlight.WeightedSpanTermExtractor.java Source code

Java tutorial

Introduction

Here is the source code for edu.mit.ll.vizlinc.highlight.WeightedSpanTermExtractor.java

Source

package edu.mit.ll.vizlinc.highlight;

/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

import org.apache.lucene.analysis.CachingTokenFilter;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.index.FilterIndexReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.TermEnum;
import org.apache.lucene.index.memory.MemoryIndex;
import org.apache.lucene.search.*;
import org.apache.lucene.search.spans.FieldMaskingSpanQuery;
import org.apache.lucene.search.spans.SpanFirstQuery;
import org.apache.lucene.search.spans.SpanNearQuery;
import org.apache.lucene.search.spans.SpanNotQuery;
import org.apache.lucene.search.spans.SpanOrQuery;
import org.apache.lucene.search.spans.SpanQuery;
import org.apache.lucene.search.spans.SpanTermQuery;
import org.apache.lucene.search.spans.Spans;
import org.apache.lucene.util.StringHelper;

/**
 * Class used to extract {@link WeightedSpanTerm}s from a {@link Query} based on whether 
 * {@link Term}s from the {@link Query} are contained in a supplied {@link TokenStream}.
 */
public class WeightedSpanTermExtractor {

    private String fieldName;
    private TokenStream tokenStream;
    private Map<String, IndexReader> readers = new HashMap<String, IndexReader>(10);
    private String defaultField;
    private boolean expandMultiTermQuery;
    private boolean cachedTokenStream;
    private boolean wrapToCaching = true;
    private int maxDocCharsToAnalyze;

    public WeightedSpanTermExtractor() {
    }

    public WeightedSpanTermExtractor(String defaultField) {
        if (defaultField != null) {
            this.defaultField = StringHelper.intern(defaultField);
        }
    }

    private void closeReaders() {
        Collection<IndexReader> readerSet = readers.values();

        for (final IndexReader reader : readerSet) {
            try {
                reader.close();
            } catch (IOException e) {
                // alert?
            }
        }
    }

    /**
     * Fills a <code>Map</code> with <@link WeightedSpanTerm>s using the terms from the supplied <code>Query</code>.
     * 
     * @param query
     *          Query to extract Terms from
     * @param terms
     *          Map to place created WeightedSpanTerms in
     * @throws IOException
     */
    private void extract(Query query, Map<String, WeightedSpanTerm> terms) throws IOException {
        if (query instanceof BooleanQuery) {
            BooleanClause[] queryClauses = ((BooleanQuery) query).getClauses();

            for (int i = 0; i < queryClauses.length; i++) {
                if (!queryClauses[i].isProhibited()) {
                    extract(queryClauses[i].getQuery(), terms);
                }
            }
        } else if (query instanceof PhraseQuery) {
            PhraseQuery phraseQuery = ((PhraseQuery) query);
            Term[] phraseQueryTerms = phraseQuery.getTerms();
            SpanQuery[] clauses = new SpanQuery[phraseQueryTerms.length];
            for (int i = 0; i < phraseQueryTerms.length; i++) {
                clauses[i] = new SpanTermQuery(phraseQueryTerms[i]);
            }
            int slop = phraseQuery.getSlop();
            int[] positions = phraseQuery.getPositions();
            // add largest position increment to slop
            if (positions.length > 0) {
                int lastPos = positions[0];
                int largestInc = 0;
                int sz = positions.length;
                for (int i = 1; i < sz; i++) {
                    int pos = positions[i];
                    int inc = pos - lastPos;
                    if (inc > largestInc) {
                        largestInc = inc;
                    }
                    lastPos = pos;
                }
                if (largestInc > 1) {
                    slop += largestInc;
                }
            }

            boolean inorder = false;

            if (slop == 0) {
                inorder = true;
            }

            SpanNearQuery sp = new SpanNearQuery(clauses, slop, inorder);
            sp.setBoost(query.getBoost());
            extractWeightedSpanTerms(terms, sp);
        } else if (query instanceof TermQuery) {
            extractWeightedTerms(terms, query);
        } else if (query instanceof SpanQuery) {
            extractWeightedSpanTerms(terms, (SpanQuery) query);
        } else if (query instanceof FilteredQuery) {
            extract(((FilteredQuery) query).getQuery(), terms);
        } else if (query instanceof DisjunctionMaxQuery) {
            for (Iterator<Query> iterator = ((DisjunctionMaxQuery) query).iterator(); iterator.hasNext();) {
                extract(iterator.next(), terms);
            }
        } else if (query instanceof MultiTermQuery && expandMultiTermQuery) {
            MultiTermQuery mtq = ((MultiTermQuery) query);
            if (mtq.getRewriteMethod() != MultiTermQuery.SCORING_BOOLEAN_QUERY_REWRITE) {
                mtq = (MultiTermQuery) mtq.clone();
                mtq.setRewriteMethod(MultiTermQuery.SCORING_BOOLEAN_QUERY_REWRITE);
                query = mtq;
            }
            FakeReader fReader = new FakeReader();
            MultiTermQuery.SCORING_BOOLEAN_QUERY_REWRITE.rewrite(fReader, mtq);
            if (fReader.field != null) {
                IndexReader ir = getReaderForField(fReader.field);
                extract(query.rewrite(ir), terms);
            }
        } else if (query instanceof MultiPhraseQuery) {
            final MultiPhraseQuery mpq = (MultiPhraseQuery) query;
            final List<Term[]> termArrays = mpq.getTermArrays();
            final int[] positions = mpq.getPositions();
            if (positions.length > 0) {

                int maxPosition = positions[positions.length - 1];
                for (int i = 0; i < positions.length - 1; ++i) {
                    if (positions[i] > maxPosition) {
                        maxPosition = positions[i];
                    }
                }

                @SuppressWarnings("unchecked")
                final List<SpanQuery>[] disjunctLists = new List[maxPosition + 1];
                int distinctPositions = 0;

                for (int i = 0; i < termArrays.size(); ++i) {
                    final Term[] termArray = termArrays.get(i);
                    List<SpanQuery> disjuncts = disjunctLists[positions[i]];
                    if (disjuncts == null) {
                        disjuncts = (disjunctLists[positions[i]] = new ArrayList<SpanQuery>(termArray.length));
                        ++distinctPositions;
                    }
                    for (int j = 0; j < termArray.length; ++j) {
                        disjuncts.add(new SpanTermQuery(termArray[j]));
                    }
                }

                int positionGaps = 0;
                int position = 0;
                final SpanQuery[] clauses = new SpanQuery[distinctPositions];
                for (int i = 0; i < disjunctLists.length; ++i) {
                    List<SpanQuery> disjuncts = disjunctLists[i];
                    if (disjuncts != null) {
                        clauses[position++] = new SpanOrQuery(disjuncts.toArray(new SpanQuery[disjuncts.size()]));
                    } else {
                        ++positionGaps;
                    }
                }

                final int slop = mpq.getSlop();
                final boolean inorder = (slop == 0);

                SpanNearQuery sp = new SpanNearQuery(clauses, slop + positionGaps, inorder);
                sp.setBoost(query.getBoost());
                extractWeightedSpanTerms(terms, sp);
            }
        }
    }

    /**
     * Fills a <code>Map</code> with <@link WeightedSpanTerm>s using the terms from the supplied <code>SpanQuery</code>.
     * 
     * @param terms
     *          Map to place created WeightedSpanTerms in
     * @param spanQuery
     *          SpanQuery to extract Terms from
     * @throws IOException
     */
    private void extractWeightedSpanTerms(Map<String, WeightedSpanTerm> terms, SpanQuery spanQuery)
            throws IOException {
        Set<String> fieldNames;

        if (fieldName == null) {
            fieldNames = new HashSet<String>();
            collectSpanQueryFields(spanQuery, fieldNames);
        } else {
            fieldNames = new HashSet<String>(1);
            fieldNames.add(fieldName);
        }
        // To support the use of the default field name
        if (defaultField != null) {
            fieldNames.add(defaultField);
        }

        Map<String, SpanQuery> queries = new HashMap<String, SpanQuery>();

        Set<Term> nonWeightedTerms = new HashSet<Term>();
        final boolean mustRewriteQuery = mustRewriteQuery(spanQuery);
        if (mustRewriteQuery) {
            for (final String field : fieldNames) {
                final SpanQuery rewrittenQuery = (SpanQuery) spanQuery.rewrite(getReaderForField(field));
                queries.put(field, rewrittenQuery);
                rewrittenQuery.extractTerms(nonWeightedTerms);
            }
        } else {
            spanQuery.extractTerms(nonWeightedTerms);
        }

        List<PositionSpan> spanPositions = new ArrayList<PositionSpan>();

        for (final String field : fieldNames) {

            IndexReader reader = getReaderForField(field);
            final Spans spans;
            if (mustRewriteQuery) {
                spans = queries.get(field).getSpans(reader);
            } else {
                spans = spanQuery.getSpans(reader);
            }

            // collect span positions
            while (spans.next()) {
                spanPositions.add(new PositionSpan(spans.start(), spans.end() - 1));
            }

        }

        if (spanPositions.size() == 0) {
            // no spans found
            return;
        }

        for (final Term queryTerm : nonWeightedTerms) {

            if (fieldNameComparator(queryTerm.field())) {
                WeightedSpanTerm weightedSpanTerm = terms.get(queryTerm.text());

                if (weightedSpanTerm == null) {
                    weightedSpanTerm = new WeightedSpanTerm(spanQuery.getBoost(), queryTerm.text());
                    weightedSpanTerm.addPositionSpans(spanPositions);
                    weightedSpanTerm.positionSensitive = true;
                    terms.put(queryTerm.text(), weightedSpanTerm);
                } else {
                    if (spanPositions.size() > 0) {
                        weightedSpanTerm.addPositionSpans(spanPositions);
                    }
                }
            }
        }
    }

    /**
     * Fills a <code>Map</code> with <@link WeightedSpanTerm>s using the terms from the supplied <code>Query</code>.
     * 
     * @param terms
     *          Map to place created WeightedSpanTerms in
     * @param query
     *          Query to extract Terms from
     * @throws IOException
     */
    private void extractWeightedTerms(Map<String, WeightedSpanTerm> terms, Query query) throws IOException {
        Set<Term> nonWeightedTerms = new HashSet<Term>();
        query.extractTerms(nonWeightedTerms);

        for (final Term queryTerm : nonWeightedTerms) {

            if (fieldNameComparator(queryTerm.field())) {
                WeightedSpanTerm weightedSpanTerm = new WeightedSpanTerm(query.getBoost(), queryTerm.text());
                terms.put(queryTerm.text(), weightedSpanTerm);
            }
        }
    }

    /**
     * Necessary to implement matches for queries against <code>defaultField</code>
     */
    private boolean fieldNameComparator(String fieldNameToCheck) {
        boolean rv = fieldName == null || fieldNameToCheck == fieldName || fieldNameToCheck == defaultField;
        return rv;
    }

    private IndexReader getReaderForField(String field) throws IOException {
        if (wrapToCaching && !cachedTokenStream && !(tokenStream instanceof CachingTokenFilter)) {
            tokenStream = new CachingTokenFilter(new OffsetLimitTokenFilter(tokenStream, maxDocCharsToAnalyze));
            cachedTokenStream = true;
        }
        IndexReader reader = readers.get(field);
        if (reader == null) {
            MemoryIndex indexer = new MemoryIndex();
            indexer.addField(field, new OffsetLimitTokenFilter(tokenStream, maxDocCharsToAnalyze));
            tokenStream.reset();
            IndexSearcher searcher = indexer.createSearcher();
            reader = searcher.getIndexReader();
            readers.put(field, reader);
        }

        return reader;
    }

    /**
     * Creates a Map of <code>WeightedSpanTerms</code> from the given <code>Query</code> and <code>TokenStream</code>.
     * 
     * <p>
     * 
     * @param query
     *          that caused hit
     * @param tokenStream
     *          of text to be highlighted
     * @return Map containing WeightedSpanTerms
     * @throws IOException
     */
    public Map<String, WeightedSpanTerm> getWeightedSpanTerms(Query query, TokenStream tokenStream)
            throws IOException {
        return getWeightedSpanTerms(query, tokenStream, null);
    }

    /**
     * Creates a Map of <code>WeightedSpanTerms</code> from the given <code>Query</code> and <code>TokenStream</code>.
     * 
     * <p>
     * 
     * @param query
     *          that caused hit
     * @param tokenStream
     *          of text to be highlighted
     * @param fieldName
     *          restricts Term's used based on field name
     * @return Map containing WeightedSpanTerms
     * @throws IOException
     */
    public Map<String, WeightedSpanTerm> getWeightedSpanTerms(Query query, TokenStream tokenStream,
            String fieldName) throws IOException {
        if (fieldName != null) {
            this.fieldName = StringHelper.intern(fieldName);
        } else {
            this.fieldName = null;
        }

        Map<String, WeightedSpanTerm> terms = new PositionCheckingMap<String>();
        this.tokenStream = tokenStream;
        try {
            extract(query, terms);
        } finally {
            closeReaders();
        }

        return terms;
    }

    /**
     * Creates a Map of <code>WeightedSpanTerms</code> from the given <code>Query</code> and <code>TokenStream</code>. Uses a supplied
     * <code>IndexReader</code> to properly weight terms (for gradient highlighting).
     * 
     * <p>
     * 
     * @param query
     *          that caused hit
     * @param tokenStream
     *          of text to be highlighted
     * @param fieldName
     *          restricts Term's used based on field name
     * @param reader
     *          to use for scoring
     * @return Map of WeightedSpanTerms with quasi tf/idf scores
     * @throws IOException
     */
    public Map<String, WeightedSpanTerm> getWeightedSpanTermsWithScores(Query query, TokenStream tokenStream,
            String fieldName, IndexReader reader) throws IOException {
        if (fieldName != null) {
            this.fieldName = StringHelper.intern(fieldName);
        } else {
            this.fieldName = null;
        }
        this.tokenStream = tokenStream;

        Map<String, WeightedSpanTerm> terms = new PositionCheckingMap<String>();
        extract(query, terms);

        int totalNumDocs = reader.numDocs();
        Set<String> weightedTerms = terms.keySet();
        Iterator<String> it = weightedTerms.iterator();

        try {
            while (it.hasNext()) {
                WeightedSpanTerm weightedSpanTerm = terms.get(it.next());
                int docFreq = reader.docFreq(new Term(fieldName, weightedSpanTerm.term));
                // docFreq counts deletes
                if (totalNumDocs < docFreq) {
                    docFreq = totalNumDocs;
                }
                // IDF algorithm taken from DefaultSimilarity class
                float idf = (float) (Math.log((float) totalNumDocs / (double) (docFreq + 1)) + 1.0);
                weightedSpanTerm.weight *= idf;
            }
        } finally {

            closeReaders();
        }

        return terms;
    }

    private void collectSpanQueryFields(SpanQuery spanQuery, Set<String> fieldNames) {
        if (spanQuery instanceof FieldMaskingSpanQuery) {
            collectSpanQueryFields(((FieldMaskingSpanQuery) spanQuery).getMaskedQuery(), fieldNames);
        } else if (spanQuery instanceof SpanFirstQuery) {
            collectSpanQueryFields(((SpanFirstQuery) spanQuery).getMatch(), fieldNames);
        } else if (spanQuery instanceof SpanNearQuery) {
            for (final SpanQuery clause : ((SpanNearQuery) spanQuery).getClauses()) {
                collectSpanQueryFields(clause, fieldNames);
            }
        } else if (spanQuery instanceof SpanNotQuery) {
            collectSpanQueryFields(((SpanNotQuery) spanQuery).getInclude(), fieldNames);
        } else if (spanQuery instanceof SpanOrQuery) {
            for (final SpanQuery clause : ((SpanOrQuery) spanQuery).getClauses()) {
                collectSpanQueryFields(clause, fieldNames);
            }
        } else {
            fieldNames.add(spanQuery.getField());
        }
    }

    private boolean mustRewriteQuery(SpanQuery spanQuery) {
        if (!expandMultiTermQuery) {
            return false; // Will throw UnsupportedOperationException in case of a SpanRegexQuery.
        } else if (spanQuery instanceof FieldMaskingSpanQuery) {
            return mustRewriteQuery(((FieldMaskingSpanQuery) spanQuery).getMaskedQuery());
        } else if (spanQuery instanceof SpanFirstQuery) {
            return mustRewriteQuery(((SpanFirstQuery) spanQuery).getMatch());
        } else if (spanQuery instanceof SpanNearQuery) {
            for (final SpanQuery clause : ((SpanNearQuery) spanQuery).getClauses()) {
                if (mustRewriteQuery(clause)) {
                    return true;
                }
            }
            return false;
        } else if (spanQuery instanceof SpanNotQuery) {
            SpanNotQuery spanNotQuery = (SpanNotQuery) spanQuery;
            return mustRewriteQuery(spanNotQuery.getInclude()) || mustRewriteQuery(spanNotQuery.getExclude());
        } else if (spanQuery instanceof SpanOrQuery) {
            for (final SpanQuery clause : ((SpanOrQuery) spanQuery).getClauses()) {
                if (mustRewriteQuery(clause)) {
                    return true;
                }
            }
            return false;
        } else if (spanQuery instanceof SpanTermQuery) {
            return false;
        } else {
            return true;
        }
    }

    /**
     * This class makes sure that if both position sensitive and insensitive
     * versions of the same term are added, the position insensitive one wins.
     */
    static private class PositionCheckingMap<K> extends HashMap<K, WeightedSpanTerm> {

        @Override
        public void putAll(Map<? extends K, ? extends WeightedSpanTerm> m) {
            for (Map.Entry<? extends K, ? extends WeightedSpanTerm> entry : m.entrySet())
                this.put(entry.getKey(), entry.getValue());
        }

        @Override
        public WeightedSpanTerm put(K key, WeightedSpanTerm value) {
            WeightedSpanTerm prev = super.put(key, value);
            if (prev == null)
                return prev;
            WeightedSpanTerm prevTerm = prev;
            WeightedSpanTerm newTerm = value;
            if (!prevTerm.positionSensitive) {
                newTerm.positionSensitive = false;
            }
            return prev;
        }

    }

    public boolean getExpandMultiTermQuery() {
        return expandMultiTermQuery;
    }

    public void setExpandMultiTermQuery(boolean expandMultiTermQuery) {
        this.expandMultiTermQuery = expandMultiTermQuery;
    }

    public boolean isCachedTokenStream() {
        return cachedTokenStream;
    }

    public TokenStream getTokenStream() {
        return tokenStream;
    }

    /**
     * By default, {@link TokenStream}s that are not of the type
     * {@link CachingTokenFilter} are wrapped in a {@link CachingTokenFilter} to
     * ensure an efficient reset - if you are already using a different caching
     * {@link TokenStream} impl and you don't want it to be wrapped, set this to
     * false.
     * 
     * @param wrap
     */
    public void setWrapIfNotCachingTokenFilter(boolean wrap) {
        this.wrapToCaching = wrap;
    }

    /**
     * 
     * A fake IndexReader class to extract the field from a MultiTermQuery
     * 
     */
    static final class FakeReader extends FilterIndexReader {

        private static final IndexReader EMPTY_MEMORY_INDEX_READER = new MemoryIndex().createSearcher()
                .getIndexReader();

        String field;

        FakeReader() {
            super(EMPTY_MEMORY_INDEX_READER);
        }

        @Override
        public TermEnum terms(final Term t) throws IOException {
            // only set first fieldname, maybe use a Set?
            if (t != null && field == null)
                field = t.field();
            return super.terms(t);
        }

    }

    protected final void setMaxDocCharsToAnalyze(int maxDocCharsToAnalyze) {
        this.maxDocCharsToAnalyze = maxDocCharsToAnalyze;
    }

}