org.apache.lucene.search.join.ToParentBlockJoinQuery.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.lucene.search.join.ToParentBlockJoinQuery.java

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.lucene.search.join;

import java.io.IOException;
import java.util.Collection;
import java.util.Collections;
import java.util.Locale;

import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.ConstantScoreQuery;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.FilterWeight;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Matches;
import org.apache.lucene.search.MatchesUtils;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.ScorerSupplier;
import org.apache.lucene.search.TwoPhaseIterator;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.BitSet;

import static org.apache.lucene.search.ScoreMode.COMPLETE;

/**
 * This query requires that you index
 * children and parent docs as a single block, using the
 * {@link IndexWriter#addDocuments IndexWriter.addDocuments()} or {@link
 * IndexWriter#updateDocuments IndexWriter.updateDocuments()} API.  In each block, the
 * child documents must appear first, ending with the parent
 * document.  At search time you provide a {@link BitSetProducer}
 * identifying the parents; it returns a {@link BitSet} per leaf reader
 * (segment), and a {@code null} bit set means nothing in that segment can match.
 *
 * <p>Once the block index is built, use this query to wrap
 * any sub-query matching only child docs and join matches in that
 * child document space up to the parent document space.
 * You can then use this Query as a clause with
 * other queries in the parent document space.</p>
 *
 * <p>See {@link ToChildBlockJoinQuery} if you need to join
 * in the reverse order.
 *
 * <p>The child documents must be orthogonal to the parent
 * documents: the wrapped child query must never
 * return a parent document. If it does, scoring the affected parent
 * throws an {@link IllegalStateException}.</p>
 *
 * <p>See {@link org.apache.lucene.search.join} for an
 * overview. </p>
 *
 * @lucene.experimental
 */
public class ToParentBlockJoinQuery extends Query {

    /** Produces, per segment, the bit set marking parent documents. */
    private final BitSetProducer parentsFilter;
    /** Query evaluated in the child document space. */
    private final Query childQuery;
    /** How the scores of matching children are combined into the parent score. */
    private final ScoreMode scoreMode;

    /** Create a ToParentBlockJoinQuery.
     *
     * @param childQuery Query matching child documents.
     * @param parentsFilter Producer of the bit set identifying the parent documents.
     * @param scoreMode How to aggregate multiple child scores
     * into a single parent score.
     **/
    public ToParentBlockJoinQuery(Query childQuery, BitSetProducer parentsFilter, ScoreMode scoreMode) {
        this.childQuery = childQuery;
        this.parentsFilter = parentsFilter;
        this.scoreMode = scoreMode;
    }

    /**
     * Reports this query to the visitor as a leaf. The wrapped child query is not visited.
     */
    @Override
    public void visit(QueryVisitor visitor) {
        visitor.visitLeaf(this);
    }

    /**
     * Builds the join weight.
     *
     * <p>If the caller does not need scores, or the join's score mode is {@link ScoreMode#None},
     * the child query is wrapped in a {@link ConstantScoreQuery} and weighted with a boost of 0.
     * Otherwise the child weight is created in {@code COMPLETE} mode with the given boost.
     *
     * @param searcher searcher used to rewrite and weight the child query
     * @param weightScoreMode score mode requested for this query
     * @param boost boost forwarded to the child weight when child scores are needed
     * @return a weight whose scorers match parent documents
     * @throws IOException if creating the child weight fails
     */
    @Override
    public Weight createWeight(IndexSearcher searcher, org.apache.lucene.search.ScoreMode weightScoreMode,
            float boost) throws IOException {
        ScoreMode childScoreMode = weightScoreMode.needsScores() ? scoreMode : ScoreMode.None;
        final Weight childWeight;
        if (childScoreMode == ScoreMode.None) {
            // we don't need to compute a score for the child query so we wrap
            // it under a constant score query that can early terminate if the
            // minimum score is greater than 0 and the total hits that match the
            // query is not requested.
            childWeight = searcher.rewrite(new ConstantScoreQuery(childQuery)).createWeight(searcher,
                    weightScoreMode, 0f);
        } else {
            // This branch is only reached when weightScoreMode.needsScores() is true. Scores are
            // needed, so we force the collection mode to COMPLETE: the child query cannot skip
            // non-competitive documents because every child score feeds the parent aggregate.
            childWeight = childQuery.createWeight(searcher, COMPLETE, boost);
        }
        return new BlockJoinWeight(this, childWeight, parentsFilter, childScoreMode);
    }

    /** Return our child query. */
    public Query getChildQuery() {
        return childQuery;
    }

    /**
     * Weight wrapping the child weight ({@code in}). Its scorers map child matches onto the
     * parent document that closes each block.
     */
    private static class BlockJoinWeight extends FilterWeight {
        private final BitSetProducer parentsFilter;
        private final ScoreMode scoreMode;

        public BlockJoinWeight(Query joinQuery, Weight childWeight, BitSetProducer parentsFilter,
                ScoreMode scoreMode) {
            super(joinQuery, childWeight);
            this.parentsFilter = parentsFilter;
            this.scoreMode = scoreMode;
        }

        /**
         * Returns a scorer for the segment, or {@code null} when nothing can match there.
         * Obtained from {@link #scorerSupplier} with an unbounded lead cost.
         */
        @Override
        public Scorer scorer(LeafReaderContext context) throws IOException {
            final ScorerSupplier scorerSupplier = scorerSupplier(context);
            if (scorerSupplier == null) {
                return null;
            }
            return scorerSupplier.get(Long.MAX_VALUE);
        }

        /**
         * Returns a supplier of parent-space scorers, or {@code null} in either of two cases:
         * the child weight has no scorer for this segment, or the segment has no parent bit set.
         * The reported cost is the child supplier's cost.
         */
        @Override
        public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException {
            final ScorerSupplier childScorerSupplier = in.scorerSupplier(context);
            if (childScorerSupplier == null) {
                return null;
            }

            // NOTE: this does not take accept docs into account, the responsibility
            // to not match deleted docs is on the scorer
            final BitSet parents = parentsFilter.getBitSet(context);
            if (parents == null) {
                // No matches
                return null;
            }

            return new ScorerSupplier() {

                @Override
                public Scorer get(long leadCost) throws IOException {
                    return new BlockJoinScorer(BlockJoinWeight.this, childScorerSupplier.get(leadCost), parents,
                            scoreMode);
                }

                @Override
                public long cost() {
                    return childScorerSupplier.cost();
                }
            };
        }

        /**
         * Explains the score of parent {@code doc} by re-running the child weight over the
         * parent's block. Returns a no-match explanation when {@code doc} is not a matching parent.
         */
        @Override
        public Explanation explain(LeafReaderContext context, int doc) throws IOException {
            BlockJoinScorer scorer = (BlockJoinScorer) scorer(context);
            if (scorer != null && scorer.iterator().advance(doc) == doc) {
                return scorer.explain(context, in);
            }
            return Explanation.noMatch("Not a match");
        }

        /**
         * Reports whether parent {@code doc} matches. On a match it returns
         * {@link MatchesUtils#MATCH_WITH_NO_TERMS}; otherwise it returns {@code null}.
         */
        @Override
        public Matches matches(LeafReaderContext context, int doc) throws IOException {
            // The default implementation would delegate to the wrapped child Weight, which
            // matches on children.  We need to match on the parent instead
            Scorer scorer = scorer(context);
            if (scorer == null) {
                return null;
            }
            final TwoPhaseIterator twoPhase = scorer.twoPhaseIterator();
            if (twoPhase == null) {
                if (scorer.iterator().advance(doc) != doc) {
                    return null;
                }
            } else {
                if (twoPhase.approximation().advance(doc) != doc || twoPhase.matches() == false) {
                    return null;
                }
            }
            return MatchesUtils.MATCH_WITH_NO_TERMS;
        }
    }

    /**
     * Iterator over parent documents whose block contains at least one child that the child
     * approximation reports. It is exact when the child iterator has no two-phase verification.
     */
    private static class ParentApproximation extends DocIdSetIterator {

        private final DocIdSetIterator childApproximation;
        private final BitSet parentBits;
        private int doc = -1;

        ParentApproximation(DocIdSetIterator childApproximation, BitSet parentBits) {
            this.childApproximation = childApproximation;
            this.parentBits = parentBits;
        }

        @Override
        public int docID() {
            return doc;
        }

        @Override
        public int nextDoc() throws IOException {
            return advance(doc + 1);
        }

        /**
         * Moves to the first parent at or after {@code target} that has a candidate child.
         */
        @Override
        public int advance(int target) throws IOException {
            if (target >= parentBits.length()) {
                return doc = NO_MORE_DOCS;
            }
            // First doc of the block that can end at or after target: one past the last parent
            // strictly before target (or 0 when there is none).
            final int firstChildTarget = target == 0 ? 0 : parentBits.prevSetBit(target - 1) + 1;
            int childDoc = childApproximation.docID();
            if (childDoc < firstChildTarget) {
                childDoc = childApproximation.advance(firstChildTarget);
            }
            // A child at the last slot (or NO_MORE_DOCS) cannot be followed by a parent.
            if (childDoc >= parentBits.length() - 1) {
                return doc = NO_MORE_DOCS;
            }
            // The parent owning this child is the next parent bit after it.
            return doc = parentBits.nextSetBit(childDoc + 1);
        }

        @Override
        public long cost() {
            return childApproximation.cost();
        }
    }

    /**
     * Two-phase verification for parents. A parent matches if any candidate child in its block
     * passes the child query's own two-phase check.
     */
    private static class ParentTwoPhase extends TwoPhaseIterator {

        private final ParentApproximation parentApproximation;
        private final DocIdSetIterator childApproximation;
        private final TwoPhaseIterator childTwoPhase;

        ParentTwoPhase(ParentApproximation parentApproximation, TwoPhaseIterator childTwoPhase) {
            super(parentApproximation);
            this.parentApproximation = parentApproximation;
            this.childApproximation = childTwoPhase.approximation();
            this.childTwoPhase = childTwoPhase;
        }

        /**
         * Walks the candidate children of the current parent's block, stopping at the first one
         * that matches. Returns {@code false} once the child iterator reaches or passes the parent.
         */
        @Override
        public boolean matches() throws IOException {
            assert childApproximation.docID() < parentApproximation.docID();
            do {
                if (childTwoPhase.matches()) {
                    return true;
                }
            } while (childApproximation.nextDoc() < parentApproximation.docID());
            return false;
        }

        @Override
        public float matchCost() {
            // TODO: how could we compute a match cost?
            // The +10 looks like an arbitrary surcharge for scanning a block -- NOTE(review).
            return childTwoPhase.matchCost() + 10;
        }
    }

    /**
     * Scorer in the parent document space. It aggregates the scores of the matching children in
     * each block according to {@link ScoreMode}.
     */
    static class BlockJoinScorer extends Scorer {
        private final Scorer childScorer;
        private final BitSet parentBits;
        private final ScoreMode scoreMode;
        private final DocIdSetIterator childApproximation;
        private final TwoPhaseIterator childTwoPhase;
        private final ParentApproximation parentApproximation;
        private final ParentTwoPhase parentTwoPhase;
        /** Aggregated score of the current parent, as computed by {@link #setScoreAndFreq()}. */
        private float score;

        public BlockJoinScorer(Weight weight, Scorer childScorer, BitSet parentBits, ScoreMode scoreMode) {
            super(weight);
            this.parentBits = parentBits;
            this.childScorer = childScorer;
            this.scoreMode = scoreMode;
            childTwoPhase = childScorer.twoPhaseIterator();
            if (childTwoPhase == null) {
                // The child iterator is exact, so the parent approximation is exact as well.
                childApproximation = childScorer.iterator();
                parentApproximation = new ParentApproximation(childApproximation, parentBits);
                parentTwoPhase = null;
            } else {
                childApproximation = childTwoPhase.approximation();
                parentApproximation = new ParentApproximation(childTwoPhase.approximation(), parentBits);
                parentTwoPhase = new ParentTwoPhase(parentApproximation, childTwoPhase);
            }
        }

        @Override
        public Collection<ChildScorable> getChildren() {
            return Collections.singleton(new ChildScorable(childScorer, "BLOCK_JOIN"));
        }

        @Override
        public DocIdSetIterator iterator() {
            if (parentTwoPhase == null) {
                // the approximation is exact
                return parentApproximation;
            } else {
                return TwoPhaseIterator.asDocIdSetIterator(parentTwoPhase);
            }
        }

        @Override
        public TwoPhaseIterator twoPhaseIterator() {
            return parentTwoPhase;
        }

        @Override
        public int docID() {
            return parentApproximation.docID();
        }

        /** Returns the aggregated child score for the current parent. */
        @Override
        public float score() throws IOException {
            setScoreAndFreq();
            return score;
        }

        /**
         * Delegates to the child scorer only for {@link ScoreMode#None}. For the aggregating modes
         * no bound is known, so positive infinity is returned.
         */
        @Override
        public float getMaxScore(int upTo) throws IOException {
            if (scoreMode == ScoreMode.None) {
                return childScorer.getMaxScore(upTo);
            }
            return Float.POSITIVE_INFINITY;
        }

        /** Forwards the threshold to the child scorer only for {@link ScoreMode#None}. */
        @Override
        public void setMinCompetitiveScore(float minScore) throws IOException {
            if (scoreMode == ScoreMode.None) {
                childScorer.setMinCompetitiveScore(minScore);
            }
        }

        /**
         * Consumes the remaining children of the current parent's block and stores their
         * aggregated score in {@link #score}.
         *
         * <p>If the child iterator is already at or past the parent, the block was consumed by a
         * previous call and the stored score is kept.
         *
         * @throws IllegalStateException if the child query matches the parent document itself
         */
        private void setScoreAndFreq() throws IOException {
            if (childApproximation.docID() >= parentApproximation.docID()) {
                return;
            }
            // The child iterator is positioned on the first matching child of this block.
            double score = scoreMode == ScoreMode.None ? 0 : childScorer.score();
            int freq = 1;
            while (childApproximation.nextDoc() < parentApproximation.docID()) {
                if (childTwoPhase == null || childTwoPhase.matches()) {
                    final float childScore = scoreMode == ScoreMode.None ? 0 : childScorer.score();
                    freq += 1;
                    switch (scoreMode) {
                    case Total:
                    case Avg:
                        score += childScore;
                        break;
                    case Min:
                        score = Math.min(score, childScore);
                        break;
                    case Max:
                        score = Math.max(score, childScore);
                        break;
                    case None:
                        break;
                    default:
                        throw new AssertionError();
                    }
                }
            }
            // The loop stops at the parent; if the child query matches there, parents and
            // children are not orthogonal.
            if (childApproximation.docID() == parentApproximation.docID()
                    && (childTwoPhase == null || childTwoPhase.matches())) {
                throw new IllegalStateException("Child query must not match same docs with parent filter. "
                        + "Combine them as must clauses (+) to find a problem doc. " + "docId="
                        + parentApproximation.docID() + ", " + childScorer.getClass());
            }
            if (scoreMode == ScoreMode.Avg) {
                score /= freq;
            }
            this.score = (float) score;
        }

        /**
         * Explains the current parent's score by explaining every child slot of its block.
         *
         * <p>The attached detail is the child that determined the score: the lowest-scoring
         * matching child for {@link ScoreMode#Min}, and the highest-scoring one otherwise.
         *
         * @param context leaf the scorer is positioned in
         * @param childWeight weight used to explain each child document
         * @return a match explanation whose value is {@link #score()}
         */
        public Explanation explain(LeafReaderContext context, Weight childWeight) throws IOException {
            int prevParentDoc = parentBits.prevSetBit(parentApproximation.docID() - 1);
            int start = context.docBase + prevParentDoc + 1; // +1 b/c prevParentDoc is previous parent doc
            int end = context.docBase + parentApproximation.docID() - 1; // -1 b/c parentDoc is parent doc

            Explanation bestChild = null;
            Explanation worstChild = null;
            int matches = 0;
            for (int childDoc = start; childDoc <= end; childDoc++) {
                Explanation child = childWeight.explain(context, childDoc - context.docBase);
                if (child.isMatch()) {
                    matches++;
                    final float childScore = child.getValue().floatValue();
                    if (bestChild == null || childScore > bestChild.getValue().floatValue()) {
                        bestChild = child;
                    }
                    if (worstChild == null || childScore < worstChild.getValue().floatValue()) {
                        worstChild = child;
                    }
                }
            }

            // With ScoreMode.Min the parent score is the lowest child score, so the lowest-scoring
            // child is the one that explains it; every other mode reports the highest-scoring one.
            final boolean reportWorst = scoreMode == ScoreMode.Min;
            final Explanation decisiveChild = reportWorst ? worstChild : bestChild;
            final String description = String.format(Locale.ROOT,
                    "Score based on %d child docs in range from %d to %d, %s match:", matches, start, end,
                    reportWorst ? "worst" : "best");
            if (decisiveChild == null) {
                // Defensive: never attach a null detail to the explanation.
                return Explanation.match(score(), description);
            }
            return Explanation.match(score(), description, decisiveChild);
        }
    }

    /**
     * Rewrites the child query. A new join query is returned only when the child query changes;
     * otherwise rewriting is delegated to {@link Query#rewrite}.
     */
    @Override
    public Query rewrite(IndexReader reader) throws IOException {
        final Query childRewrite = childQuery.rewrite(reader);
        if (childRewrite != childQuery) {
            return new ToParentBlockJoinQuery(childRewrite, parentsFilter, scoreMode);
        } else {
            return super.rewrite(reader);
        }
    }

    /** Renders the child query in its default form; the {@code field} argument is not used. */
    @Override
    public String toString(String field) {
        return "ToParentBlockJoinQuery (" + childQuery.toString() + ")";
    }

    /**
     * Two join queries are equal when they have the same class, child query, parent producer and
     * score mode.
     */
    @Override
    public boolean equals(Object other) {
        return sameClassAs(other) && equalsTo(getClass().cast(other));
    }

    private boolean equalsTo(ToParentBlockJoinQuery other) {
        return childQuery.equals(other.childQuery) && parentsFilter.equals(other.parentsFilter)
                && scoreMode == other.scoreMode;
    }

    /** Combines the class hash with the child query, score mode and parent producer hashes. */
    @Override
    public int hashCode() {
        final int prime = 31;
        int hash = classHash();
        hash = prime * hash + childQuery.hashCode();
        hash = prime * hash + scoreMode.hashCode();
        hash = prime * hash + parentsFilter.hashCode();
        return hash;
    }
}