opennlp.tools.postag.POSTaggerME.java Source code

Java tutorial

Introduction

Here is the source code for opennlp.tools.postag.POSTaggerME.java

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package opennlp.tools.postag;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.atomic.AtomicInteger;

import opennlp.tools.dictionary.Dictionary;
import opennlp.tools.ml.BeamSearch;
import opennlp.tools.ml.EventModelSequenceTrainer;
import opennlp.tools.ml.EventTrainer;
import opennlp.tools.ml.SequenceTrainer;
import opennlp.tools.ml.TrainerFactory;
import opennlp.tools.ml.TrainerFactory.TrainerType;
import opennlp.tools.ml.model.Event;
import opennlp.tools.ml.model.MaxentModel;
import opennlp.tools.ml.model.SequenceClassificationModel;
import opennlp.tools.ngram.NGramModel;
import opennlp.tools.util.ObjectStream;
import opennlp.tools.util.Sequence;
import opennlp.tools.util.SequenceValidator;
import opennlp.tools.util.StringList;
import opennlp.tools.util.StringUtil;
import opennlp.tools.util.TrainingParameters;
import opennlp.tools.util.featuregen.StringPattern;

/**
 * A part-of-speech tagger that uses maximum entropy.  Tries to predict whether
 * words are nouns, verbs, or any of 70 other POS tags depending on their
 * surrounding context.
 *
 * <p>The tagger is stateful: {@link #tag(String[])} and
 * {@link #tag(String[], Object[])} remember the sequence they produced so that
 * {@link #probs()} and {@link #probs(double[])} can report on it. No
 * synchronization is performed in this class, so an instance should presumably
 * not be shared between threads without external coordination.
 */
public class POSTaggerME implements POSTagger {

    /**
     * Beam size used when neither the model manifest (see the constructor) nor
     * the training parameters (see {@code train}) specify one.
     */
    public static final int DEFAULT_BEAM_SIZE = 3;

    /**
     * The model package this tagger was created from; consulted by
     * {@code getOrderedTags} to obtain the underlying event model.
     */
    private POSModel modelPackage;

    /**
     * The feature context generator.
     */
    protected POSContextGenerator contextGen;

    /**
     * Tag dictionary used for restricting words to a fixed set of tags.
     */
    protected TagDictionary tagDictionary;

    /**
     * N-gram dictionary. It is neither assigned nor read inside this class; it
     * is kept because it is part of the protected surface available to subclasses.
     */
    protected Dictionary ngramDictionary;

    /**
     * Says whether a filter should be used to check whether a tag assignment
     * is to a word outside of a closed class.
     */
    protected boolean useClosedClassTagsFilter = false;

    /**
     * The size of the beam to be used in determining the best sequence of pos tags.
     */
    protected int size;

    /**
     * The sequence produced by the most recent {@code tag(String[] ...)} call,
     * or {@code null} if no sentence has been tagged yet.
     */
    private Sequence bestSequence;

    /**
     * The sequence classification model that performs the actual tagging.
     */
    private SequenceClassificationModel<String> model;

    /**
     * Validator obtained from the model's factory and handed to every
     * sequence search.
     */
    private SequenceValidator<String> sequenceValidator;

    /**
     * Initializes the current instance with the provided model.
     *
     * <p>The beam size is read from the model manifest property
     * {@code BeamSearch.BEAM_SIZE_PARAMETER} and falls back to
     * {@link #DEFAULT_BEAM_SIZE} when the property is absent. The same value is
     * passed to the factory when requesting the context generator. If the model
     * package provides a sequence model it is used directly; otherwise a
     * {@code BeamSearch} is built around the package's event model.
     *
     * @param model the model package supplying the factory, the manifest and
     *     the classification model; must not be {@code null}
     * @throws NumberFormatException if the manifest contains a beam size that
     *     is not a parsable integer
     */
    public POSTaggerME(POSModel model) {
        POSTaggerFactory factory = model.getFactory();

        int beamSize = POSTaggerME.DEFAULT_BEAM_SIZE;

        String beamSizeString = model.getManifestProperty(BeamSearch.BEAM_SIZE_PARAMETER);
        if (beamSizeString != null) {
            beamSize = Integer.parseInt(beamSizeString);
        }

        modelPackage = model;

        contextGen = factory.getPOSContextGenerator(beamSize);
        tagDictionary = factory.getTagDictionary();
        size = beamSize;

        sequenceValidator = factory.getSequenceValidator();

        // Ask the package for its sequence model only once and reuse the result.
        SequenceClassificationModel<String> sequenceModel = model.getPosSequenceModel();
        if (sequenceModel != null) {
            this.model = sequenceModel;
        } else {
            this.model = new BeamSearch<>(beamSize, model.getPosModel(), 0);
        }
    }

    /**
     * Retrieves an array of all possible part-of-speech tags from the
     * tagger.
     *
     * @return the outcomes reported by the underlying sequence model
     */
    public String[] getAllPosTags() {
        return model.getOutcomes();
    }

    /**
     * Tags the sentence without any additional context.
     *
     * @param sentence the tokens of the sentence
     * @return one tag per position of the best sequence found
     */
    public String[] tag(String[] sentence) {
        return this.tag(sentence, null);
    }

    /**
     * Tags the sentence and remembers the resulting sequence so that
     * {@link #probs()} and {@link #probs(double[])} can report on it.
     *
     * @param sentence the tokens of the sentence
     * @param additionaContext extra context objects forwarded unchanged to the
     *     model; may be {@code null}
     * @return the outcomes of the best sequence found by the model
     */
    public String[] tag(String[] sentence, Object[] additionaContext) {
        bestSequence = model.bestSequence(sentence, additionaContext, contextGen, sequenceValidator);
        List<String> outcomes = bestSequence.getOutcomes();
        return outcomes.toArray(new String[outcomes.size()]);
    }

    /**
     * Returns at most the specified number of taggings for the specified sentence.
     *
     * <p>This method does not update the sequence reported by {@link #probs()}.
     *
     * @param numTaggings The number of tagging to be returned.
     * @param sentence An array of tokens which make up a sentence.
     *
     * @return At most the specified number of taggings for the specified sentence.
     */
    public String[][] tag(int numTaggings, String[] sentence) {
        Sequence[] bestSequences = model.bestSequences(numTaggings, sentence, null, contextGen, sequenceValidator);
        String[][] tags = new String[bestSequences.length][];
        for (int si = 0; si < tags.length; si++) {
            List<String> outcomes = bestSequences[si].getOutcomes();
            tags[si] = outcomes.toArray(new String[outcomes.size()]);
        }
        return tags;
    }

    /**
     * Returns the top sequences for the sentence without additional context.
     *
     * @param sentence the tokens of the sentence
     * @return the sequences returned by the model
     */
    public Sequence[] topKSequences(String[] sentence) {
        return this.topKSequences(sentence, null);
    }

    /**
     * Requests {@link #size} (the configured beam size) sequences from the
     * model for the sentence.
     *
     * <p>This method does not update the sequence reported by {@link #probs()}.
     *
     * @param sentence the tokens of the sentence
     * @param additionaContext extra context objects forwarded unchanged to the
     *     model; may be {@code null}
     * @return the sequences returned by the model
     */
    public Sequence[] topKSequences(String[] sentence, Object[] additionaContext) {
        return model.bestSequences(size, sentence, additionaContext, contextGen, sequenceValidator);
    }

    /**
     * Populates the specified array with the probabilities for each tag of the last tagged sentence.
     *
     * @param probs An array to put the probabilities into.
     * @throws IllegalStateException if no sentence has been tagged with
     *     {@code tag(String[] ...)} yet
     */
    public void probs(double[] probs) {
        lastTaggedSequence().getProbs(probs);
    }

    /**
     * Returns an array with the probabilities for each tag of the last tagged sentence.
     *
     * @return an array with the probabilities for each tag of the last tagged sentence.
     * @throws IllegalStateException if no sentence has been tagged with
     *     {@code tag(String[] ...)} yet
     */
    public double[] probs() {
        return lastTaggedSequence().getProbs();
    }

    /**
     * Returns the remembered sequence, failing with a descriptive exception
     * instead of a bare null dereference when nothing has been tagged yet.
     */
    private Sequence lastTaggedSequence() {
        if (bestSequence == null) {
            throw new IllegalStateException(
                    "No tagged sentence is available; call tag(String[]) before requesting probabilities.");
        }
        return bestSequence;
    }

    /**
     * Same as {@link #getOrderedTags(List, List, int, double[])} without
     * collecting the probabilities.
     *
     * @param words the tokens of the sentence
     * @param tags the tags assigned so far
     * @param index the position in {@code words} to evaluate
     * @return every outcome of the event model, most probable first
     * @throws UnsupportedOperationException if the model package has no event model
     */
    public String[] getOrderedTags(List<String> words, List<String> tags, int index) {
        return getOrderedTags(words, tags, index, null);
    }

    /**
     * Evaluates the event model for the token at {@code index} and returns all
     * of its outcomes ordered by descending probability. Outcomes with equal
     * probability keep the model's own outcome order.
     *
     * @param words the tokens of the sentence
     * @param tags the tags assigned so far
     * @param index the position in {@code words} to evaluate
     * @param tprobs if not {@code null}, receives the probability of each
     *     returned tag at the same position; it must be at least as long as
     *     the number of outcomes
     * @return every outcome of the event model exactly once, most probable first
     * @throws UnsupportedOperationException if the model package has no event model
     */
    public String[] getOrderedTags(List<String> words, List<String> tags, int index, double[] tprobs) {

        MaxentModel posModel = modelPackage.getPosModel();
        if (posModel == null) {
            throw new UnsupportedOperationException(
                    "This method can only be called if the " + "classifcation model is an event model!");
        }

        double[] probs = posModel.eval(contextGen.getContext(index, words.toArray(new String[words.size()]),
                tags.toArray(new String[tags.size()]), null));

        int outcomeCount = probs.length;
        String[] orderedTags = new String[outcomeCount];

        // Track chosen outcomes explicitly rather than overwriting their
        // probability with 0: zero or negative scores would otherwise make the
        // same outcome win repeatedly while others are never emitted.
        boolean[] taken = new boolean[outcomeCount];
        for (int rank = 0; rank < outcomeCount; rank++) {
            int best = -1;
            for (int ti = 0; ti < outcomeCount; ti++) {
                if (!taken[ti] && (best < 0 || probs[ti] > probs[best])) {
                    best = ti;
                }
            }
            taken[best] = true;
            orderedTags[rank] = posModel.getOutcome(best);
            if (tprobs != null) {
                tprobs[rank] = probs[best];
            }
        }
        return orderedTags;
    }

    /**
     * Trains a POS model from the given samples.
     *
     * <p>The trainer type is derived from {@code trainParams}. Event and
     * event-sequence trainers produce an event model, which is packaged
     * together with the beam size read from
     * {@code BeamSearch.BEAM_SIZE_PARAMETER} (default
     * {@link #DEFAULT_BEAM_SIZE}); a sequence trainer produces a sequence
     * model, which is packaged without a beam size.
     *
     * @param languageCode the language code stored in the resulting model
     * @param samples the training samples
     * @param trainParams the training parameters, including the trainer type
     * @param posFactory the factory providing the context generator and stored
     *     in the resulting model
     * @return the trained model package
     * @throws IOException propagated from reading the samples or from training
     * @throws IllegalArgumentException if the trainer type is not one of the
     *     three supported types
     */
    public static POSModel train(String languageCode, ObjectStream<POSSample> samples,
            TrainingParameters trainParams, POSTaggerFactory posFactory) throws IOException {

        int beamSize = trainParams.getIntParameter(BeamSearch.BEAM_SIZE_PARAMETER, POSTaggerME.DEFAULT_BEAM_SIZE);

        POSContextGenerator contextGenerator = posFactory.getPOSContextGenerator();

        Map<String, String> manifestInfoEntries = new HashMap<>();

        TrainerType trainerType = TrainerFactory.getTrainerType(trainParams);

        MaxentModel posModel = null;
        SequenceClassificationModel<String> seqPosModel = null;
        if (TrainerType.EVENT_MODEL_TRAINER.equals(trainerType)) {
            ObjectStream<Event> es = new POSSampleEventStream(samples, contextGenerator);

            EventTrainer trainer = TrainerFactory.getEventTrainer(trainParams, manifestInfoEntries);
            posModel = trainer.train(es);
        } else if (TrainerType.EVENT_MODEL_SEQUENCE_TRAINER.equals(trainerType)) {
            POSSampleSequenceStream ss = new POSSampleSequenceStream(samples, contextGenerator);
            EventModelSequenceTrainer trainer = TrainerFactory.getEventModelSequenceTrainer(trainParams,
                    manifestInfoEntries);
            posModel = trainer.train(ss);
        } else if (TrainerType.SEQUENCE_TRAINER.equals(trainerType)) {
            SequenceTrainer trainer = TrainerFactory.getSequenceModelTrainer(trainParams, manifestInfoEntries);

            // TODO: This will probably cause issue, since the feature generator uses the outcomes array

            POSSampleSequenceStream ss = new POSSampleSequenceStream(samples, contextGenerator);
            seqPosModel = trainer.train(ss);
        } else {
            throw new IllegalArgumentException("Trainer type is not supported: " + trainerType);
        }

        if (posModel != null) {
            return new POSModel(languageCode, posModel, beamSize, manifestInfoEntries, posFactory);
        } else {
            return new POSModel(languageCode, seqPosModel, manifestInfoEntries, posFactory);
        }
    }

    /**
     * Builds a dictionary from the sentences of the given samples.
     *
     * <p>Every non-empty sentence is added to an {@code NGramModel} via
     * {@code add(tokens, 1, 1)} (presumably min/max n-gram length 1, i.e.
     * single tokens - confirm against {@code NGramModel}). The model is then
     * reduced with {@code cutoff(cutoff, Integer.MAX_VALUE)} and converted with
     * {@code toDictionary(true)}.
     *
     * @param samples the samples whose sentences are read until the stream is exhausted
     * @param cutoff the lower bound passed to {@code NGramModel.cutoff}
     * @return the dictionary produced by the n-gram model
     * @throws IOException propagated from reading the samples
     */
    public static Dictionary buildNGramDictionary(ObjectStream<POSSample> samples, int cutoff) throws IOException {

        NGramModel ngramModel = new NGramModel();

        POSSample sample;
        while ((sample = samples.read()) != null) {
            String[] words = sample.getSentence();

            if (words.length > 0) {
                ngramModel.add(new StringList(words), 1, 1);
            }
        }

        ngramModel.cutoff(cutoff, Integer.MAX_VALUE);

        return ngramModel.toDictionary(true);
    }

    /**
     * Expands a tag dictionary with the word/tag pairs observed in the samples.
     *
     * <p>Tokens containing a digit are skipped. Words are lower-cased first
     * when the dictionary is not case sensitive. Tags the dictionary already
     * lists for a word start with a count of {@code cutoff}, so they are always
     * retained; every other tag must be observed at least {@code cutoff} times
     * for the word to be written back. Progress and timing messages are printed
     * to {@code System.out}.
     *
     * @param samples the samples that are read until the stream is exhausted
     * @param dict the dictionary to read existing tags from and to update
     * @param cutoff the minimum number of occurrences a word/tag pair needs
     * @throws IOException propagated from reading the samples
     */
    public static void populatePOSDictionary(ObjectStream<POSSample> samples, MutableTagDictionary dict, int cutoff)
            throws IOException {
        System.out.println("Expanding POS Dictionary ...");
        long start = System.nanoTime();

        // the data structure will store the word, the tag, and the number of
        // occurrences
        Map<String, Map<String, AtomicInteger>> newEntries = new HashMap<>();
        POSSample sample;
        while ((sample = samples.read()) != null) {
            String[] words = sample.getSentence();
            String[] tags = sample.getTags();

            for (int i = 0; i < words.length; i++) {
                // only store words
                if (StringPattern.recognize(words[i]).containsDigit()) {
                    continue;
                }

                String word = dict.isCaseSensitive() ? words[i] : StringUtil.toLowerCase(words[i]);

                Map<String, AtomicInteger> tagCounts = newEntries.get(word);
                if (tagCounts == null) {
                    tagCounts = new HashMap<>();
                    newEntries.put(word, tagCounts);

                    // The dictionary is not modified while the samples are read,
                    // so seeding its known tags once per distinct word is enough.
                    String[] dictTags = dict.getTags(word);
                    if (dictTags != null) {
                        for (String tag : dictTags) {
                            // for this tags we start with the cutoff
                            if (!tagCounts.containsKey(tag)) {
                                tagCounts.put(tag, new AtomicInteger(cutoff));
                            }
                        }
                    }
                }

                AtomicInteger count = tagCounts.get(tags[i]);
                if (count == null) {
                    tagCounts.put(tags[i], new AtomicInteger(1));
                } else {
                    count.incrementAndGet();
                }
            }
        }

        // now we check if the word + tag pairs have enough occurrences, if yes we
        // add it to the dictionary
        for (Entry<String, Map<String, AtomicInteger>> wordEntry : newEntries.entrySet()) {
            List<String> tagsForWord = new ArrayList<>();
            for (Entry<String, AtomicInteger> entry : wordEntry.getValue().entrySet()) {
                if (entry.getValue().get() >= cutoff) {
                    tagsForWord.add(entry.getKey());
                }
            }
            if (!tagsForWord.isEmpty()) {
                dict.put(wordEntry.getKey(), tagsForWord.toArray(new String[tagsForWord.size()]));
            }
        }

        System.out.println(
                "... finished expanding POS Dictionary. [" + (System.nanoTime() - start) / 1000000 + "ms]");
    }
}