org.apache.lucene.queries.CommonTermsQuery.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.lucene.queries.CommonTermsQuery.java

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.lucene.queries;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;

import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.TermStates;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.BooleanClause.Occur;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.BoostQuery;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.TermQuery;

/**
 * A query that executes high-frequency terms in an optional sub-query to prevent
 * slow queries due to "common" terms like stopwords. This query
 * builds 2 queries off the {@link #add(Term) added} terms: low-frequency
 * terms are added to a required boolean clause and high-frequency terms are
 * added to an optional boolean clause. The optional clause is only executed if
 * the required "low-frequency" clause matches. In most cases, high-frequency terms are
 * unlikely to significantly contribute to the document score unless at least
 * one of the low-frequency terms is matched. This query can improve
 * query execution times significantly if applicable.
 * <p>
 * {@link CommonTermsQuery} has several advantages over stopword filtering at
 * index or query time since a term can be "classified" based on the actual
 * document frequency in the index and can prevent slow queries even across
 * domains without specialized stopword files.
 * </p>
 * <p>
 * <b>Note:</b> if the query only contains high-frequency terms the query is
 * rewritten into a plain conjunction query, i.e. all high-frequency terms need to
 * match in order to match a document (unless a non-zero high-frequency minimum
 * number should match applies).
 * </p>
 */
public class CommonTermsQuery extends Query {
    /*
     * TODO maybe it would make sense to abstract this even further and allow to
     * rewrite to dismax rather than boolean. Yet, this can already be subclassed
     * to do so.
     */
    /** Terms added via {@link #add(Term)}, in insertion order. */
    protected final List<Term> terms = new ArrayList<>();
    /** Fraction in [0..1) of maxDoc, or absolute doc count (&gt;= 1), above which a term is high-frequency. */
    protected final float maxTermFrequency;
    protected final Occur lowFreqOccur;
    protected final Occur highFreqOccur;
    protected float lowFreqBoost = 1.0f;
    protected float highFreqBoost = 1.0f;
    protected float lowFreqMinNrShouldMatch = 0;
    protected float highFreqMinNrShouldMatch = 0;

    /**
     * Creates a new {@link CommonTermsQuery}
     * 
     * @param highFreqOccur
     *          {@link Occur} used for high frequency terms
     * @param lowFreqOccur
     *          {@link Occur} used for low frequency terms
     * @param maxTermFrequency
     *          a value in [0..1) (or absolute number &gt;=1) representing the
     *          maximum threshold of a term's document frequency to be considered a
     *          low frequency term.
     * @throws IllegalArgumentException
     *           if {@link Occur#MUST_NOT} is passed as lowFreqOccur or
     *           highFreqOccur
     */
    public CommonTermsQuery(Occur highFreqOccur, Occur lowFreqOccur, float maxTermFrequency) {
        if (highFreqOccur == Occur.MUST_NOT) {
            throw new IllegalArgumentException("highFreqOccur should be MUST or SHOULD but was MUST_NOT");
        }
        if (lowFreqOccur == Occur.MUST_NOT) {
            throw new IllegalArgumentException("lowFreqOccur should be MUST or SHOULD but was MUST_NOT");
        }
        this.highFreqOccur = highFreqOccur;
        this.lowFreqOccur = lowFreqOccur;
        this.maxTermFrequency = maxTermFrequency;
    }

    /**
     * Adds a term to the {@link CommonTermsQuery}
     * 
     * @param term
     *          the term to add
     * @throws IllegalArgumentException
     *           if {@code term} is {@code null}
     */
    public void add(Term term) {
        if (term == null) {
            throw new IllegalArgumentException("Term must not be null");
        }
        this.terms.add(term);
    }

    /**
     * Rewrites this query into a {@link BooleanQuery} of low- and high-frequency
     * clauses.
     * <p>
     * With no terms a {@link MatchNoDocsQuery} is returned; with exactly one term
     * a plain term query is returned (no frequency classification, no boosts).
     * Otherwise per-term statistics are collected from every leaf of
     * {@code reader} and handed to {@link #buildQuery}.
     * </p>
     */
    @Override
    public Query rewrite(IndexReader reader) throws IOException {
        if (this.terms.isEmpty()) {
            return new MatchNoDocsQuery("CommonTermsQuery with no terms");
        } else if (this.terms.size() == 1) {
            return newTermQuery(this.terms.get(0), null);
        }
        final List<LeafReaderContext> leaves = reader.leaves();
        final int maxDoc = reader.maxDoc();
        final TermStates[] contextArray = new TermStates[terms.size()];
        final Term[] queryTerms = this.terms.toArray(new Term[0]);
        collectTermStates(reader, leaves, contextArray, queryTerms);
        return buildQuery(maxDoc, contextArray, queryTerms);
    }

    /**
     * Passes the terms whose field is accepted by {@code visitor} to a
     * {@link Occur#SHOULD} sub-visitor; does nothing if no field is accepted.
     */
    @Override
    public void visit(QueryVisitor visitor) {
        Term[] selectedTerms = terms.stream().filter(t -> visitor.acceptField(t.field())).toArray(Term[]::new);
        if (selectedTerms.length > 0) {
            QueryVisitor v = visitor.getSubVisitor(Occur.SHOULD, this);
            v.consumeTerms(this, selectedTerms);
        }
    }

    /**
     * Computes the minimum number of low-frequency clauses that must match, given
     * {@code numOptional} low-frequency clauses. Subclasses may override.
     */
    protected int calcLowFreqMinimumNumberShouldMatch(int numOptional) {
        return minNrShouldMatch(lowFreqMinNrShouldMatch, numOptional);
    }

    /**
     * Computes the minimum number of high-frequency clauses that must match, given
     * {@code numOptional} high-frequency clauses. Subclasses may override.
     */
    protected int calcHighFreqMinimumNumberShouldMatch(int numOptional) {
        return minNrShouldMatch(highFreqMinNrShouldMatch, numOptional);
    }

    /**
     * Values of 0 or &gt;= 1 are absolute counts (truncated to int); any other value
     * is a fraction of {@code numOptional}, rounded to the nearest int.
     */
    private int minNrShouldMatch(float minNrShouldMatch, int numOptional) {
        if (minNrShouldMatch >= 1.0f || minNrShouldMatch == 0.0f) {
            return (int) minNrShouldMatch;
        }
        return Math.round(minNrShouldMatch * numOptional);
    }

    /**
     * Builds the final query from the collected per-term statistics.
     * <p>
     * A term is classified as high-frequency when its document frequency exceeds
     * {@link #maxTermFrequency} (if that is &gt;= 1) or exceeds
     * {@code ceil(maxTermFrequency * maxDoc)}; otherwise, or when it was not found
     * in any leaf ({@code contextArray[i] == null}), it is low-frequency. The
     * low-frequency clauses are wrapped in a {@link BoostQuery} with
     * {@link #lowFreqBoost} and added as {@link Occur#MUST}; the high-frequency
     * clauses are wrapped with {@link #highFreqBoost} and added as
     * {@link Occur#SHOULD}.
     * </p>
     *
     * @param maxDoc the reader's maxDoc, used for fractional thresholds
     * @param contextArray per-term statistics, parallel to {@code queryTerms}; entries may be null
     * @param queryTerms the terms of this query
     * @return the rewritten boolean query
     */
    protected Query buildQuery(final int maxDoc, final TermStates[] contextArray, final Term[] queryTerms) {
        List<Query> lowFreqQueries = new ArrayList<>();
        List<Query> highFreqQueries = new ArrayList<>();
        for (int i = 0; i < queryTerms.length; i++) {
            TermStates termStates = contextArray[i];
            if (termStates == null) {
                // term does not occur in any leaf: treat as low-frequency
                lowFreqQueries.add(newTermQuery(queryTerms[i], null));
            } else if ((maxTermFrequency >= 1f && termStates.docFreq() > maxTermFrequency)
                    || (termStates.docFreq() > (int) Math.ceil(maxTermFrequency * (float) maxDoc))) {
                highFreqQueries.add(newTermQuery(queryTerms[i], termStates));
            } else {
                lowFreqQueries.add(newTermQuery(queryTerms[i], termStates));
            }
        }
        final int numLowFreqClauses = lowFreqQueries.size();
        final int numHighFreqClauses = highFreqQueries.size();
        Occur lowFreqOccur = this.lowFreqOccur;
        Occur highFreqOccur = this.highFreqOccur;
        int lowFreqMinShouldMatch = 0;
        int highFreqMinShouldMatch = 0;
        if (lowFreqOccur == Occur.SHOULD && numLowFreqClauses > 0) {
            lowFreqMinShouldMatch = calcLowFreqMinimumNumberShouldMatch(numLowFreqClauses);
        }
        if (highFreqOccur == Occur.SHOULD && numHighFreqClauses > 0) {
            highFreqMinShouldMatch = calcHighFreqMinimumNumberShouldMatch(numHighFreqClauses);
        }
        if (lowFreqQueries.isEmpty() && highFreqMinShouldMatch == 0) {
            /*
             * if lowFreq is empty we rewrite the high freq terms in a conjunction to
             * prevent slow queries.
             */
            highFreqOccur = Occur.MUST;
        }
        BooleanQuery.Builder builder = new BooleanQuery.Builder();

        if (lowFreqQueries.isEmpty() == false) {
            BooleanQuery.Builder lowFreq = new BooleanQuery.Builder();
            for (Query query : lowFreqQueries) {
                lowFreq.add(query, lowFreqOccur);
            }
            lowFreq.setMinimumNumberShouldMatch(lowFreqMinShouldMatch);
            Query lowFreqQuery = lowFreq.build();
            builder.add(new BoostQuery(lowFreqQuery, lowFreqBoost), Occur.MUST);
        }
        if (highFreqQueries.isEmpty() == false) {
            BooleanQuery.Builder highFreq = new BooleanQuery.Builder();
            for (Query query : highFreqQueries) {
                highFreq.add(query, highFreqOccur);
            }
            highFreq.setMinimumNumberShouldMatch(highFreqMinShouldMatch);
            Query highFreqQuery = highFreq.build();
            builder.add(new BoostQuery(highFreqQuery, highFreqBoost), Occur.SHOULD);
        }
        return builder.build();
    }

    /**
     * Collects, for every query term, its {@link TermStates} across all
     * {@code leaves}, storing the result in {@code contextArray} (parallel to
     * {@code queryTerms}). The first leaf in which a term is found creates its
     * {@link TermStates}; later leaves are registered on it. Entries for terms
     * found in no leaf are left unchanged (null when the array is freshly allocated).
     *
     * @param reader the top-level reader whose context backs new {@link TermStates}
     * @param leaves the leaves to scan
     * @param contextArray output array, modified in place
     * @param queryTerms the terms to look up
     * @throws IOException if reading term dictionaries fails
     */
    public void collectTermStates(IndexReader reader, List<LeafReaderContext> leaves, TermStates[] contextArray,
            Term[] queryTerms) throws IOException {
        for (LeafReaderContext context : leaves) {
            for (int i = 0; i < queryTerms.length; i++) {
                final Term term = queryTerms[i];
                final Terms fieldTerms = context.reader().terms(term.field());
                if (fieldTerms == null) {
                    // field does not exist in this leaf
                    continue;
                }
                final TermsEnum termsEnum = fieldTerms.iterator();
                assert termsEnum != null;
                if (termsEnum == TermsEnum.EMPTY || termsEnum.seekExact(term.bytes()) == false) {
                    // term not present in this leaf
                    continue;
                }
                final TermStates termStates = contextArray[i];
                if (termStates == null) {
                    contextArray[i] = new TermStates(reader.getContext(), termsEnum.termState(), context.ord,
                            termsEnum.docFreq(), termsEnum.totalTermFreq());
                } else {
                    termStates.register(termsEnum.termState(), context.ord, termsEnum.docFreq(),
                            termsEnum.totalTermFreq());
                }
            }
        }
    }

    /**
     * Specifies a minimum number of the low frequent optional BooleanClauses which must be
     * satisfied in order to produce a match on the low frequency terms query
     * part. This method accepts a float value in the range [0..1) as a fraction
     * of the actual query terms in the low frequent clause or a number
     * {@code >= 1} as an absolute number of clauses that need to match.
     * 
     * <p>
     * By default no optional clauses are necessary for a match (unless there are
     * no required clauses). If this method is used, then the specified number of
     * clauses is required.
     * </p>
     * 
     * @param min
     *          the number of optional clauses that must match
     */
    public void setLowFreqMinimumNumberShouldMatch(float min) {
        this.lowFreqMinNrShouldMatch = min;
    }

    /**
     * Gets the minimum number of the optional low frequent BooleanClauses which must be
     * satisfied.
     */
    public float getLowFreqMinimumNumberShouldMatch() {
        return lowFreqMinNrShouldMatch;
    }

    /**
     * Specifies a minimum number of the high frequent optional BooleanClauses which must be
     * satisfied in order to produce a match on the high frequency terms query
     * part. This method accepts a float value in the range [0..1) as a fraction
     * of the actual query terms in the high frequent clause or a number
     * {@code >= 1} as an absolute number of clauses that need to match.
     * 
     * <p>
     * By default no optional clauses are necessary for a match (unless there are
     * no required clauses). If this method is used, then the specified number of
     * clauses is required.
     * </p>
     * 
     * @param min
     *          the number of optional clauses that must match
     */
    public void setHighFreqMinimumNumberShouldMatch(float min) {
        this.highFreqMinNrShouldMatch = min;
    }

    /**
     * Gets the minimum number of the optional high frequent BooleanClauses which must be
     * satisfied.
     */
    public float getHighFreqMinimumNumberShouldMatch() {
        return highFreqMinNrShouldMatch;
    }

    /**
     * Gets the list of terms.
     */
    public List<Term> getTerms() {
        return Collections.unmodifiableList(terms);
    }

    /**
     * Gets the maximum threshold of a term's document frequency to be considered a
     * low frequency term.
     */
    public float getMaxTermFrequency() {
        return maxTermFrequency;
    }

    /**
     * Gets the {@link Occur} used for low frequency terms.
     */
    public Occur getLowFreqOccur() {
        return lowFreqOccur;
    }

    /**
     * Gets the {@link Occur} used for high frequency terms.
     */
    public Occur getHighFreqOccur() {
        return highFreqOccur;
    }

    /**
     * Gets the boost used for low frequency terms.
     */
    public float getLowFreqBoost() {
        return lowFreqBoost;
    }

    /**
     * Gets the boost used for high frequency terms.
     */
    public float getHighFreqBoost() {
        return highFreqBoost;
    }

    /**
     * Renders the terms as a comma-separated list, followed by
     * {@code ~(lowFreqMin, highFreqMin)} when either minimum-should-match is positive.
     */
    @Override
    public String toString(String field) {
        final StringBuilder buffer = new StringBuilder();
        final float lowFreqMin = getLowFreqMinimumNumberShouldMatch();
        final float highFreqMin = getHighFreqMinimumNumberShouldMatch();
        final boolean needParens = lowFreqMin > 0;
        if (needParens) {
            buffer.append("(");
        }
        for (int i = 0; i < terms.size(); i++) {
            if (i > 0) {
                buffer.append(", ");
            }
            buffer.append(newTermQuery(terms.get(i), null).toString());
        }
        if (needParens) {
            buffer.append(")");
        }
        if (lowFreqMin > 0 || highFreqMin > 0) {
            // separate the two values; appending them back to back (e.g. "2.00.0") is ambiguous
            buffer.append('~');
            buffer.append("(");
            buffer.append(lowFreqMin);
            buffer.append(", ");
            buffer.append(highFreqMin);
            buffer.append(")");
        }
        return buffer.toString();
    }

    @Override
    public int hashCode() {
        final int prime = 31;
        int result = classHash();
        result = prime * result + Float.floatToIntBits(highFreqBoost);
        result = prime * result + Objects.hashCode(highFreqOccur);
        result = prime * result + Objects.hashCode(lowFreqOccur);
        result = prime * result + Float.floatToIntBits(lowFreqBoost);
        result = prime * result + Float.floatToIntBits(maxTermFrequency);
        result = prime * result + Float.floatToIntBits(lowFreqMinNrShouldMatch);
        result = prime * result + Float.floatToIntBits(highFreqMinNrShouldMatch);
        result = prime * result + Objects.hashCode(terms);
        return result;
    }

    @Override
    public boolean equals(Object other) {
        return sameClassAs(other) && equalsTo(getClass().cast(other));
    }

    /**
     * Field-by-field comparison. All floats are compared by bit pattern, exactly as
     * {@link #hashCode()} hashes them, so that equal queries always have equal hash
     * codes (plain {@code ==} treats 0.0f and -0.0f as equal while their bits differ,
     * and never considers NaN equal to itself).
     */
    private boolean equalsTo(CommonTermsQuery other) {
        return Float.floatToIntBits(highFreqBoost) == Float.floatToIntBits(other.highFreqBoost)
                && highFreqOccur == other.highFreqOccur
                && lowFreqOccur == other.lowFreqOccur
                && Float.floatToIntBits(lowFreqBoost) == Float.floatToIntBits(other.lowFreqBoost)
                && Float.floatToIntBits(maxTermFrequency) == Float.floatToIntBits(other.maxTermFrequency)
                && Float.floatToIntBits(lowFreqMinNrShouldMatch) == Float.floatToIntBits(other.lowFreqMinNrShouldMatch)
                && Float.floatToIntBits(highFreqMinNrShouldMatch) == Float.floatToIntBits(other.highFreqMinNrShouldMatch)
                && terms.equals(other.terms);
    }

    /**
     * Builds a new TermQuery instance.
     * <p>This is intended for subclasses that wish to customize the generated queries.</p>
     * @param term term
     * @param termStates the TermStates to be used to create the low level term query. Can be <code>null</code>.
     * @return new TermQuery instance
     */
    protected Query newTermQuery(Term term, TermStates termStates) {
        return termStates == null ? new TermQuery(term) : new TermQuery(term, termStates);
    }
}