hu.ppke.itk.nlpg.purepos.decoder.BeamedViterbi.java Source code

Java tutorial

Introduction

Here is the source code for hu.ppke.itk.nlpg.purepos.decoder.BeamedViterbi.java

Source

/*******************************************************************************
 * Copyright (c) 2012 György Orosz, Attila Novák.
 * All rights reserved. This program and the accompanying materials
 * are made available under the terms of the GNU Lesser Public License v3
 * which accompanies this distribution, and is available at
 * http://www.gnu.org/licenses/
 * 
 * This file is part of PurePos.
 * 
 * PurePos is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 * 
 * PurePos is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Lesser Public License for more details.
 * 
 * Contributors:
 *     György Orosz - initial API and implementation
 ******************************************************************************/
package hu.ppke.itk.nlpg.purepos.decoder;

import hu.ppke.itk.nlpg.purepos.model.internal.CompiledModel;
import hu.ppke.itk.nlpg.purepos.model.internal.NGram;
import hu.ppke.itk.nlpg.purepos.morphology.IMorphologicalAnalyzer;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;

import org.apache.commons.lang3.tuple.Pair;

import com.google.common.collect.HashBasedTable;
import com.google.common.collect.Table;
import com.google.common.collect.Table.Cell;

/**
 * Decoder that implements the Viterbi search, sped up by beam pruning: after
 * each observation only the states whose weight is within {@code logTheta} of
 * the best state's weight are kept.
 * 
 * @author György Orosz
 * 
 */
public class BeamedViterbi extends AbstractDecoder {

    /**
     * Creates a beamed Viterbi decoder.
     * 
     * @param model
     *            the compiled model, passed to {@link AbstractDecoder}
     * @param morphologicalAnalyzer
     *            the morphological analyzer, passed to {@link AbstractDecoder}
     * @param logTheta
     *            beam width: states whose weight is more than this below the
     *            best state's weight are pruned after each observation
     * @param sufTheta
     *            passed to {@link AbstractDecoder} (not used directly here)
     * @param maxGuessedTags
     *            passed to {@link AbstractDecoder} (not used directly here)
     */
    public BeamedViterbi(CompiledModel<String, Integer> model, IMorphologicalAnalyzer morphologicalAnalyzer,
            double logTheta, double sufTheta, int maxGuessedTags) {
        super(model, morphologicalAnalyzer, logTheta, sufTheta, maxGuessedTags);
    }

    /**
     * Tags a sequence of observations.
     * 
     * @param observations
     *            the tokens to tag; preprocessed with prepareObservations
     *            before the search
     * @param maxResultsNumber
     *            the maximum number of tag sequences the search collects
     * @return the (tag sequence, weight) pairs found by
     *         {@link #beamedSearch}, post-processed by cleanResults
     */
    @Override
    public List<Pair<List<Integer>, Double>> decode(List<String> observations, int maxResultsNumber) {
        List<String> obs = prepareObservations(observations);
        NGram<Integer> startNGram = createInitialElement();
        return cleanResults(beamedSearch(startNGram, obs, maxResultsNumber));
    }

    /**
     * Runs the beam-pruned Viterbi search over the observations.
     * <p>
     * For every observation, each state in the current beam is extended with
     * every candidate tag returned by getNextProbs. A new state's weight is
     * the source state's weight plus the transition score, plus the
     * observation score (see below). Weights are summed, so they are
     * presumably log-probabilities. The resulting states are then pruned.
     * 
     * @param start
     *            the initial state
     * @param observations
     *            the (already prepared) observations
     * @param resultsNumber
     *            the maximum number of results to return
     * @return at most {@code resultsNumber} (tag sequence, weight) pairs,
     *         highest-ranked first
     */
    public List<Pair<List<Integer>, Double>> beamedSearch(final NGram<Integer> start,
            final List<String> observations, int resultsNumber) {
        HashMap<NGram<Integer>, Node> beam = new HashMap<NGram<Integer>, Node>();
        beam.put(start, startNode(start));

        int pos = 0;
        for (String obs : observations) {
            // for each context: tag -> (transition score, observation score)
            Map<NGram<Integer>, Map<Integer, Pair<Double, Double>>> nexts = getNextProbs(beam.keySet(), obs, pos,
                    pos == 0);

            HashMap<NGram<Integer>, Node> newBeam = new HashMap<NGram<Integer>, Node>();
            // observation score of each new state; added only after all transitions are scored
            Map<NGram<Integer>, Double> obsProbs = new HashMap<NGram<Integer>, Double>();
            int transitionCount = 0;

            for (Map.Entry<NGram<Integer>, Map<Integer, Pair<Double, Double>>> nextsEntry : nexts.entrySet()) {
                NGram<Integer> context = nextsEntry.getKey();
                Node from = beam.get(context);
                for (Map.Entry<Integer, Pair<Double, Double>> entry : nextsEntry.getValue().entrySet()) {
                    NGram<Integer> newState = context.add(entry.getKey());
                    Pair<Double, Double> probs = entry.getValue();
                    obsProbs.put(newState, probs.getRight());
                    update(newBeam, newState, probs.getLeft() + from.getWeight(), from);
                    ++transitionCount;
                }
            }

            // Observation scores are only added when more than one transition was scored. With a
            // single candidate, every surviving path passes through that one state, so the term
            // could not change the ranking. The returned weights then omit it.
            if (transitionCount > 1) {
                addObservationProbs(newBeam, obsProbs);
            }

            beam = prune(newBeam);
            ++pos;
        }
        return findMax(beam, resultsNumber);
    }

    /**
     * Adds each state's observation score to its accumulated weight, in place.
     * 
     * @param states
     *            the freshly built beam; its nodes are mutated
     * @param obsProbs
     *            the observation score for every key of {@code states}
     */
    private void addObservationProbs(Map<NGram<Integer>, Node> states, Map<NGram<Integer>, Double> obsProbs) {
        for (Map.Entry<NGram<Integer>, Node> entry : states.entrySet()) {
            Node node = entry.getValue();
            Double obsProb = obsProbs.get(entry.getKey());
            node.setWeight(obsProb + node.getWeight());
        }
    }

    /**
     * Returns the {@code resultsNumber} highest-ranked states of the final
     * beam, best first. Each state is decomposed into its tag sequence and
     * paired with its weight.
     */
    private List<Pair<List<Integer>, Double>> findMax(final Map<NGram<Integer>, Node> beam, int resultsNumber) {
        // Sort a list rather than using a TreeSet. A TreeSet treats nodes whose compareTo returns 0
        // (presumably nodes with equal weights) as duplicates, and would silently keep only one.
        List<Node> ranked = new ArrayList<Node>(beam.values());
        Collections.sort(ranked, Collections.<Node> reverseOrder());

        List<Pair<List<Integer>, Double>> ret = new ArrayList<Pair<List<Integer>, Double>>();
        for (int i = 0; i < resultsNumber && i < ranked.size(); ++i) {
            Node node = ranked.get(i);
            List<Integer> tagSeq = decompose(node);
            ret.add(Pair.<List<Integer>, Double> of(tagSeq, node.getWeight()));
        }
        return ret;
    }

    /**
     * Beam pruning: keeps the states whose weight is not below
     * {@code bestWeight - logTheta}.
     * 
     * @param beam
     *            the candidate states; not modified
     * @return a new map with the surviving states (empty if {@code beam} is
     *         empty)
     */
    private HashMap<NGram<Integer>, Node> prune(final Map<NGram<Integer>, Node> beam) {
        HashMap<NGram<Integer>, Node> ret = new HashMap<NGram<Integer>, Node>();
        if (beam.isEmpty()) {
            // Nothing to prune; Collections.max would throw NoSuchElementException here.
            return ret;
        }
        double threshold = Collections.max(beam.values()).getWeight() - logTheta;
        for (Map.Entry<NGram<Integer>, Node> entry : beam.entrySet()) {
            // Deliberately !(a < b) rather than a >= b, so that a NaN weight is kept, not pruned.
            if (!(entry.getValue().getWeight() < threshold)) {
                ret.put(entry.getKey(), entry.getValue());
            }
        }
        return ret;
    }

    /**
     * Records a transition into {@code newState}.
     * <p>
     * If the state is not yet in the beam, it is created with the given
     * weight and back-pointer. Otherwise its back-pointer and weight are
     * replaced only when {@code newWeight} is strictly higher, so on ties the
     * first transition seen wins.
     */
    private void update(Map<NGram<Integer>, Node> beam, NGram<Integer> newState, Double newWeight,
            Node fromNode) {
        Node existing = beam.get(newState);
        if (existing == null) {
            beam.put(newState, new Node(newState, newWeight, fromNode));
        } else if (existing.getWeight() < newWeight) {
            existing.setPrevious(fromNode);
            existing.setWeight(newWeight);
        }
    }
}