Java tutorial
/******************************************************************************* * Copyright (c) 2012 Gyrgy Orosz, Attila Novk. * All rights reserved. This program and the accompanying materials * are made available under the terms of the GNU Lesser Public License v3 * which accompanies this distribution, and is available at * http://www.gnu.org/licenses/ * * This file is part of PurePos. * * PurePos is free software: you can redistribute it and/or modify * it under the terms of the GNU Lesser Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * PurePos is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser Public License for more details. * * Contributors: * Gyrgy Orosz - initial API and implementation ******************************************************************************/ package hu.ppke.itk.nlpg.purepos.model.internal; import hu.ppke.itk.nlpg.purepos.model.INGramModel; import hu.ppke.itk.nlpg.purepos.model.IProbabilityModel; import java.io.Serializable; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.ListIterator; import java.util.Map; import java.util.Map.Entry; import org.apache.commons.lang3.tuple.ImmutablePair; import org.apache.commons.lang3.tuple.Pair; /** * N-gram model implementation which uses tries to store these elements. * (Similar to SRILM.) * * Tries are stored in HashTables, for the sake of efficiency. * * @author Gyrgy Orosz * * @param <C> * context type * @param <W> * word type */ public class NGramModel<W> extends INGramModel<Integer, W> implements Serializable { private static final long serialVersionUID = 5159356902216485765L; // Logger logger = Logger.getLogger(this.getClass()); protected IntTrieNode<W> root; /* * lambda1 is at position 1 and so on; lamda0 is seen to be used in Hunpos * when calculating probs: * * P(C| A B) = l3 * ML (C| A B) + l2 * ML (C | B) + l1 * ML (C) + l0 */ protected ArrayList<Double> lambdas; public NGramModel(int n) { super(n); root = new IntTrieNode<W>(IntVocabulary.getExtremalElement()); lambdas = new ArrayList<Double>(); } @Override public void addWord(List<Integer> context, W word) { ListIterator<Integer> iterator = context.listIterator(context.size()); IntTrieNode<W> act = root; int i = 0; int size = n - 1; act.addWord(word); while (iterator.hasPrevious() && i < size) { act = (IntTrieNode<W>) act.addChild(iterator.previous()); act.addWord(word); i++; } } @Override public List<Double> getWordFrequency(List<Integer> context, W word) { ArrayList<Double> ret = new ArrayList<Double>(); ret.add(root.getAprioriProb(word)); if (!(context == null || context.size() == 0)) { ListIterator<Integer> it = context.listIterator(context.size()); Integer previous; IntTrieNode<W> actNode = root; while (it.hasPrevious() && actNode != null) { previous = it.previous(); if (actNode.hasChild(previous)) { actNode = (IntTrieNode<W>) actNode.getChild(previous); ret.add(actNode.getAprioriProb(word)); } else { ret.add(0.0); while (it.hasPrevious()) { ret.add(0.0); } actNode = null; } } } return ret; } @Override public int getTotalFrequency() { return root.getNum(); } protected double calculateModifiedFreqVal(List<TrieNode<Integer, Integer, W>> nodeList, int position, W word) { double contextFreq = nodeList.get(position).getNum(); double wordFreq = nodeList.get(position).getWord(word); if (contextFreq == 1 || wordFreq == 1) return -1; else // TODO: RESEARCH: what if we would substract any value instead of // 1? return (wordFreq - 1) / (contextFreq - 1); } /** * Finds the maximal frequency element in a nodelist. * * @param list * @param word * @return */ protected Pair<Integer, Double> findMax(ArrayList<TrieNode<Integer, Integer, W>> list, W word) { Integer maxPos; Double maxVal; if (!(list == null || list.size() == 0)) { maxPos = -1; maxVal = 0.0; for (int i = list.size() - 1; i >= 0; --i) { double val = calculateModifiedFreqVal(list, i, word); if (val > maxVal) { maxPos = i; maxVal = val; } } } else { maxPos = null; maxVal = null; } ImmutablePair<Integer, Double> ret = new ImmutablePair<Integer, Double>(maxPos, maxVal); return ret; } @Override protected void calculateNGramLamdas() { adjustLamdas(); // logger.trace("pure lambdas: " + lambdas); // normalization double sum = 0.0; lambdas.set(0, 0.0); for (Double e : lambdas) { sum += e; } // logger.debug(lambdas.toString()); if (sum > 0) { for (int i = 0; i < lambdas.size(); ++i) { lambdas.set(i, lambdas.get(i) / sum); } } // logger.debug(lambdas); // logger.debug(lambdas.toString()); } /** * Calculate the lambdas, without smoothing */ protected void adjustLamdas() { lambdas = new ArrayList<Double>(); for (int i = 0; i < n + 1; ++i) { lambdas.add(0.0); } ArrayList<TrieNode<Integer, Integer, W>> acc = new ArrayList<TrieNode<Integer, Integer, W>>(); iterate(root, acc); } protected void iterate(TrieNode<Integer, Integer, W> node, ArrayList<TrieNode<Integer, Integer, W>> acc) { acc.add(node); if (node.getChildNodes() == null || node.getChildNodes().size() == 0) { for (W word : node.getWords().keySet()) { Pair<Integer, Double> max = findMax(acc, word); int index = max.getKey() + 1; if (max.getValue() != -1) { lambdas.set(index, lambdas.get(index) + node.getWord(word)); } // logger.debug("Max:" + max + " add:" + node.getWord(word) // + " to:" + index + " lambdas:" + lambdas); } } else { for (TrieNode<Integer, Integer, W> child : node.getChildNodes().values()) { iterate(child, acc); } } acc.remove(acc.size() - 1); } @Override public IProbabilityModel<Integer, W> createProbabilityModel() { // logger.trace("NGramModel: " + getReprString()); calculateNGramLamdas(); return new ProbModel<W>(root, lambdas); } @Override public Map<W, Integer> getWords() { return root.getWords(); } @Override public Map<W, Double> getWordAprioriProbs() { Map<W, Double> ret = new HashMap<W, Double>(); double sumFreg = root.getNum(); for (Entry<W, Integer> e : root.getWords().entrySet()) { double val = e.getValue(); ret.put(e.getKey(), val / sumFreg); } return ret; } public String getReprString() { calculateNGramLamdas(); return "tree:\n" + root.getReprString() + "lambdas: " + lambdas; } }