opennlp.tools.ml.BeamSearch.java Source code

Java tutorial

Introduction

Here is the source code for opennlp.tools.ml.BeamSearch.java

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package opennlp.tools.ml;

import java.util.Arrays;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Queue;

import opennlp.tools.ml.model.MaxentModel;
import opennlp.tools.ml.model.SequenceClassificationModel;
import opennlp.tools.util.BeamSearchContextGenerator;
import opennlp.tools.util.Cache;
import opennlp.tools.util.Sequence;
import opennlp.tools.util.SequenceValidator;

/**
 * Performs k-best search over a sequence. This is based on the description in
 * Ratnaparkhi (1998), PhD diss, Univ. of Pennsylvania.
 *
 * <p>An instance passes the same scratch buffer to {@code model.eval} on every search step,
 * so presumably a single instance should not be used by several threads at once without
 * external synchronization.
 *
 * @see Sequence
 * @see SequenceValidator
 * @see BeamSearchContextGenerator
 */
public class BeamSearch<T> implements SequenceClassificationModel<T> {

    /** Parameter name associated with the beam size; this class does not read it itself. */
    public static final String BEAM_SIZE_PARAMETER = "BeamSize";

    /** Substituted when a caller passes {@code null} as additional context. */
    private static final Object[] EMPTY_ADDITIONAL_CONTEXT = new Object[0];

    /**
     * Default {@code minSequenceScore} used by the overload that does not take one;
     * presumably a stand-in for log(0), i.e. effectively "no pruning by score".
     */
    private static final int ZERO_LOG = -100000;

    /** The size of the beam (k): at most this many hypotheses are expanded per position. */
    protected int size;

    /** The model that assigns probabilities to the outcomes at each position. */
    protected MaxentModel model;

    /** Scratch buffer handed to {@code model.eval}; one slot per model outcome. */
    private final double[] probs;

    /** Optional score cache keyed by context array; {@code null} when caching is disabled. */
    private final Cache<String[], double[]> contextsCache;

    /**
     * Creates new search object without a score cache.
     *
     * @param size The size of the beam (k).
     * @param model the model for assigning probabilities to the sequence outcomes.
     */
    public BeamSearch(int size, MaxentModel model) {
        this(size, model, 0);
    }

    /**
     * Creates new search object with an optional score cache.
     *
     * @param size The size of the beam (k).
     * @param model the model for assigning probabilities to the sequence outcomes.
     * @param cacheSize passed to the {@link Cache} constructor (presumably its capacity);
     *     values {@code <= 0} disable caching.
     */
    public BeamSearch(int size, MaxentModel model, int cacheSize) {
        this.size = size;
        this.model = model;

        if (cacheSize > 0) {
            contextsCache = new Cache<>(cacheSize);
        } else {
            contextsCache = null;
        }

        this.probs = new double[model.getNumOutcomes()];
    }

    /**
     * Returns up to {@code numSequences} best-scoring sequences of outcomes for the input.
     *
     * @param numSequences the maximum number of sequences to return.
     * @param sequence The input sequence.
     * @param additionalContext An Object[] of additional context; {@code null} is treated as
     *     empty. This is passed to the context generator blindly with the
     *     assumption that the context is appropriate.
     * @param minSequenceScore candidates whose {@code getScore()} is not strictly greater than
     *     this value are discarded.
     * @param cg produces the feature context for each position.
     * @param validator decides whether an outcome may extend a partial sequence.
     *
     * @return the surviving sequences, taken from the head of the final beam's priority queue
     *     (so in {@link Sequence}'s natural order, presumably best first); never {@code null},
     *     but empty if no sequence could be found.
     */
    public Sequence[] bestSequences(int numSequences, T[] sequence, Object[] additionalContext,
            double minSequenceScore, BeamSearchContextGenerator<T> cg, SequenceValidator<T> validator) {

        Queue<Sequence> prev = new PriorityQueue<>(size);
        Queue<Sequence> next = new PriorityQueue<>(size);
        prev.add(new Sequence());

        if (additionalContext == null) {
            additionalContext = EMPTY_ADDITIONAL_CONTEXT;
        }

        for (int i = 0; i < sequence.length; i++) {
            int sz = Math.min(size, prev.size());

            for (int sc = 0; !prev.isEmpty() && sc < sz; sc++) {
                Sequence top = prev.remove();
                List<String> topOutcomes = top.getOutcomes();
                String[] outcomes = topOutcomes.toArray(new String[0]);
                String[] contexts = cg.getContext(i, sequence, outcomes, additionalContext);
                double[] scores = evalContexts(contexts);

                // Threshold = the `size`-th largest score; only outcomes at or above it
                // (ties included) are tried first. Sort a copy so `scores` keeps outcome order.
                double[] sortedScores = scores.clone();
                Arrays.sort(sortedScores);
                double min = sortedScores[Math.max(0, scores.length - size)];

                for (int p = 0; p < scores.length; p++) {
                    if (scores[p] >= min) {
                        String out = model.getOutcome(p);
                        if (validator.validSequence(i, sequence, outcomes, out)) {
                            Sequence ns = new Sequence(top, out, scores[p]);
                            if (ns.getScore() > minSequenceScore) {
                                next.add(ns);
                            }
                        }
                    }
                }

                // If nothing has been advanced at this position yet, fall back to advancing
                // every valid outcome of this hypothesis regardless of its rank.
                if (next.isEmpty()) {
                    for (int p = 0; p < scores.length; p++) {
                        String out = model.getOutcome(p);
                        if (validator.validSequence(i, sequence, outcomes, out)) {
                            Sequence ns = new Sequence(top, out, scores[p]);
                            if (ns.getScore() > minSequenceScore) {
                                next.add(ns);
                            }
                        }
                    }
                }
            }

            // Swap the queues: `next` becomes the beam for the following position, and the
            // cleared old beam is reused as the new `next`.
            prev.clear();
            Queue<Sequence> tmp = prev;
            prev = next;
            next = tmp;
        }

        int numSeq = Math.min(numSequences, prev.size());
        Sequence[] topSequences = new Sequence[numSeq];

        for (int seqIndex = 0; seqIndex < numSeq; seqIndex++) {
            topSequences[seqIndex] = prev.remove();
        }

        return topSequences;
    }

    /**
     * Scores {@code contexts} with the model, going through the cache when it is enabled.
     *
     * <p>Cached values are private copies. Without the copy every cache entry would be the
     * shared {@code probs} buffer, which each later evaluation overwrites, so cache hits
     * would return another context's scores.
     *
     * <p>NOTE(review): Java arrays do not override {@code equals}/{@code hashCode}, so
     * {@code String[]} keys match only by identity; hits presumably require the context
     * generator to return the same array instance.
     *
     * @param contexts the feature context for the current position.
     * @return the per-outcome scores; callers must treat the array as read-only.
     */
    private double[] evalContexts(String[] contexts) {
        if (contextsCache == null) {
            return model.eval(contexts, probs);
        }
        return contextsCache.computeIfAbsent(contexts, c -> model.eval(c, probs).clone());
    }

    /**
     * Returns up to {@code numSequences} best sequences, using the default minimum score
     * {@link #ZERO_LOG}.
     *
     * @param numSequences the maximum number of sequences to return.
     * @param sequence The input sequence.
     * @param additionalContext additional context for the generator; may be {@code null}.
     * @param cg produces the feature context for each position.
     * @param validator decides whether an outcome may extend a partial sequence.
     * @return the surviving sequences; never {@code null}, possibly empty.
     */
    public Sequence[] bestSequences(int numSequences, T[] sequence, Object[] additionalContext,
            BeamSearchContextGenerator<T> cg, SequenceValidator<T> validator) {
        return bestSequences(numSequences, sequence, additionalContext, ZERO_LOG, cg, validator);
    }

    /**
     * Returns the single best sequence of outcomes.
     *
     * @param sequence The input sequence.
     * @param additionalContext additional context for the generator; may be {@code null}.
     * @param cg produces the feature context for each position.
     * @param validator decides whether an outcome may extend a partial sequence.
     * @return the first sequence produced by {@code bestSequences(1, ...)}, or {@code null}
     *     if no sequence could be found.
     */
    public Sequence bestSequence(T[] sequence, Object[] additionalContext, BeamSearchContextGenerator<T> cg,
            SequenceValidator<T> validator) {
        Sequence[] sequences = bestSequences(1, sequence, additionalContext, cg, validator);
        return sequences.length > 0 ? sequences[0] : null;
    }

    /**
     * Returns the names of all model outcomes, indexed by outcome id.
     *
     * @return a new array with {@code model.getOutcome(i)} at index {@code i}.
     */
    @Override
    public String[] getOutcomes() {
        int numOutcomes = model.getNumOutcomes();
        String[] outcomes = new String[numOutcomes];

        for (int i = 0; i < numOutcomes; i++) {
            outcomes[i] = model.getOutcome(i);
        }

        return outcomes;
    }
}