Java tutorial: CountingStringLM, a Lucene-backed counting language model

This tutorial walks through CountingStringLM from the de.tudarmstadt.lt.lm project. The class implements a counting language model whose ngram counts and vocabulary live in two Lucene indices on disk; probabilities are maximum-likelihood estimates computed from the stored counts. The listing below is split into segments, each followed by short notes.
/*
 * Copyright 2014
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package de.tudarmstadt.lt.lm.lucenebased;

import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Properties;

import org.apache.commons.collections.IteratorUtils;
import org.apache.commons.collections.Transformer;
import org.apache.commons.lang.ArrayUtils;
import org.apache.commons.lang.StringUtils;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.StoredField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexableField;
import org.apache.lucene.index.MultiFields;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.MMapDirectory;
import org.apache.lucene.util.Bits;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import de.tudarmstadt.lt.lm.AbstractLanguageModel;
import de.tudarmstadt.lt.lm.LanguageModel;
import de.tudarmstadt.lt.lm.PseudoSymbol;
import de.tudarmstadt.lt.lm.service.LMProviderUtils;

/**
 * @author Steffen Remus
 */
public class CountingStringLM extends AbstractLanguageModel<String> implements LanguageModel<String> {

    private static final Logger LOG = LoggerFactory.getLogger(CountingStringLM.class);

    protected static Document _UNKOWN_NGRAM_LUCENE_DOCUMENT;
    protected static Document _UNKOWN_WORD_LUCENE_DOCUMENT;

    static {
        _UNKOWN_WORD_LUCENE_DOCUMENT = new Document();
        _UNKOWN_WORD_LUCENE_DOCUMENT.add(new StoredField("word", PseudoSymbol.UNKOWN_WORD.asString()));
        _UNKOWN_NGRAM_LUCENE_DOCUMENT = new Document();
        _UNKOWN_NGRAM_LUCENE_DOCUMENT.add(new StoredField("ngram", PseudoSymbol.UNKOWN_WORD.asString()));
        _UNKOWN_NGRAM_LUCENE_DOCUMENT.add(new StoredField("num", 0));
        _UNKOWN_NGRAM_LUCENE_DOCUMENT.add(new StoredField("cardinality", 0));
    }

    protected int _order;
    protected IndexReader _reader_ngram;
    protected IndexSearcher _searcher_ngram;
    protected IndexReader _reader_vocab;
    protected IndexSearcher _searcher_vocab;
    protected double[][] _num_ngrams;
    protected double[] _sum_ngrams;
    protected double[][] _N;
    protected boolean _fixed = false;

    public CountingStringLM(int order, File index_dir) {
        _order = order;
        try {
            LOG.info("Loading index from or creating index in '{}'.", index_dir.getAbsolutePath());
            File index_dir_vocab = new File(index_dir, "vocab");
            File index_dir_ngram = new File(index_dir, "ngram");
            _fixed = true;

            Directory directory = MMapDirectory.open(index_dir_ngram);
            // directory = new RAMDirectory(directory, IOContext.DEFAULT);
            _reader_ngram = DirectoryReader.open(directory);
            _searcher_ngram = new IndexSearcher(_reader_ngram);

            directory = MMapDirectory.open(index_dir_vocab);
            // directory = new RAMDirectory(directory, IOContext.DEFAULT);
            _reader_vocab = DirectoryReader.open(directory);
            _searcher_vocab = new IndexSearcher(_reader_vocab);

            LOG.info("Computing number of ngram occurrences.");
            File sumfile = new File(index_dir, "__sum_ngrams__");
            try {
                InputStream in = new FileInputStream(sumfile);
                Properties p = new Properties();
                p.load(in);
                in.close();
                int max_n = Integer.parseInt(p.getProperty("max_n"));
                if (max_n < order)
                    LOG.error("max_n={} in {} is smaller than the order of the language model ({}).", max_n, sumfile, order);
                max_n = Math.max(_order, max_n);
                int max_c = Integer.parseInt(p.getProperty("max_c"));
                _N = new double[max_n + 1][max_c];
                _sum_ngrams = new double[max_n + 1];
                for (String name : p.stringPropertyNames()) {
                    if (name.startsWith("n")) {
                        int n = Integer.parseInt(name.substring(1, name.length()));
                        String[] v = p.getProperty(name).split(",");
                        for (int i = 0; i < v.length; i++)
                            _N[n][i] = Double.parseDouble(v[i]);
                    } else if (name.startsWith("s")) {
                        int n = Integer.parseInt(name.substring(1, name.length()));
                        _sum_ngrams[n] = Double.parseDouble(p.getProperty(name));
                    }
                }
            } catch (Exception e) {
                LOG.error("Could not read ngram sum file '{}'.", sumfile, e);
                _N = new double[order + 1][6];
                _sum_ngrams = new double[order + 1];
            }

            _num_ngrams = new double[_N.length][4];
            long sum = 0;
            for (int n = 0; n < _N.length; n++) {
                for (int i = 0; i < 3; i++)
                    _num_ngrams[n][i] = _N[n][i];
                for (int i = 3; i < _N[n].length; i++)
                    _num_ngrams[n][3] += _N[n][i];
                sum += _num_ngrams[n][0];
            }

            LOG.info("Number of Ngrams {}.", _searcher_ngram.collectionStatistics("ngram").docCount());
            LOG.info("Number of Ngrams {}.", sum);
            LOG.info("Vocabulary Size {}.", _searcher_vocab.collectionStatistics("word").docCount());
        } catch (IOException e) {
            LOG.error("Could not open lucene index: Dir={}; Dir exists={}; ", index_dir, index_dir.exists() && index_dir.isDirectory(), e);
        }
    }
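The constructor opens the two Lucene indices under index_dir/ngram and index_dir/vocab as read-only, memory-mapped directories, then loads aggregate count statistics from a Java properties file named __sum_ngrams__ next to them. The parsing loop tells us which keys that file must contain: max_n (the largest ngram length), max_c (the number of entries per n<k> line), one n<k> line per ngram length k with comma-separated count statistics (index 0 is used later as the normalizer for unigram probabilities), and one s<k> line with the occurrence sum for length k. The exact semantics of the individual buckets are defined by the companion tool that builds the index, which is not shown here. A hypothetical file for a bigram model might look like this (all numbers invented for illustration):

    max_n=2
    max_c=6
    n1=5000,2100,800,400,250,150
    n2=42000,17000,6000,2500,1200,800
    s1=150000
    s2=149000

If the file is missing or unreadable, the constructor logs an error and falls back to zero-initialized arrays; the unigram normalizer _num_ngrams[1][0] is then zero, so a model without this file is effectively unusable.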
    @Override
    public int getOrder() {
        return _order;
    }

    @Override
    public String predictNextWord(List<String> history_words) {
        // check length
        // assert history_words.size() >= getOrder() - 1 : "Length of history must be at least ngram order - 1.";
        // List<String> pruned_history_words = history_words.subList(Math.max(1, (history_words.size() - getOrder()) + 1), history_words.size());
        // LOG.debug("History: {}; pruned: {}.", history_words.toString(), pruned_history_words.toString());
        //
        // List<Integer> wordIds = toIntegerList(pruned_history_words);
        // wordIds.add(-1);
        // List<String> ngram = new ArrayList<String>(pruned_history_words);
        // ngram.add("<unk>");
        // int lastIndex = ngram.size() - 1;
        // double max_value = -Double.MAX_VALUE;
        // String max_word = ngram.get(lastIndex);
        // for (Iterator<String> iter = getVocabularyIterator(); iter.hasNext();) {
        //     ngram.set(lastIndex, iter.next());
        //     double logprob = getNgramLogProbability(ngram);
        //     LOG.trace("Word '{}' log10Prob: {}", ngram.get(lastIndex), logprob);
        //     if (logprob > max_value) {
        //         max_value = logprob;
        //         max_word = ngram.get(lastIndex);
        //     }
        // }
        // return max_word;
        // FIXME:
        return "";
    }

    @Override
    public String getWord(int wordId) {
        try {
            if (wordId < 0 || wordId >= _searcher_vocab.getIndexReader().maxDoc())
                return PseudoSymbol.UNKOWN_WORD.asString();
            return _searcher_vocab.doc(wordId).get("word");
        } catch (IOException e) {
            LOG.error("Could not get word for id {}. Querying lucene index failed. Max doc: {}.", wordId, _searcher_vocab.getIndexReader().maxDoc(), e);
            return null;
        }
    }
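getOrder returns the model order fixed at construction time. predictNextWord is effectively disabled: the commented-out body sketches the intended brute-force argmax (substitute every vocabulary word into the last slot after the history and keep the candidate with the highest log probability), but the active code just returns the empty string, as the FIXME indicates. getWord reuses Lucene document ids as word ids: it fetches the stored "word" field of the vocabulary document with that id and maps out-of-range ids to the unknown-word pseudo symbol. This id scheme only stays valid because the index is never modified after it is built (note _fixed = true in the constructor); Lucene renumbers documents on merges and deletions. getWordIndex, next, runs the same mapping in the other direction with a TermQuery.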
Max doc: {}.", wordId, _searcher_vocab.getIndexReader().maxDoc(), e); return null; } } @Override public int getWordIndex(String word) { if (word == null) return -1; Query query = new TermQuery(new Term("word", word)); try { ScoreDoc[] hits = _searcher_vocab.search(query, null, 2).scoreDocs; if (hits.length < 1) return -1; if (hits.length > 1) LOG.warn("Found more than one entry for '{}', expected only one.", word); return hits[0].doc; } catch (IOException e) { LOG.error("Could not get id for word {}. Querying luceneindex failed: {}.", word, query); return -1; } } @Override public double getNgramLogProbability(int[] wordIds) { return getNgramLogProbabilityFromIds(Arrays.asList(ArrayUtils.toObject(wordIds))); } public double getNgramLogProbabilityFromIds(List<Integer> ngram) { return getNgramLogProbability(toWordList(ngram)); } @Override public double getNgramLogProbability(List<String> ngram) { // check length assert ngram .size() <= _order : "Length of Ngram must be lower or equal to the order of the language model."; if (ngram.size() < 1) return Double.NEGATIVE_INFINITY; // c(w_1 ... w_n) Long nominator = getQuantity(ngram); if (nominator == 0) return Double.NEGATIVE_INFINITY; // c(w_1) / N if (ngram.size() == 1) return Math.log10((double) nominator) - Math.log10(_num_ngrams[1][0]); // c(w_1 ... w_n-1) Long denominator = getQuantity(ngram.subList(0, ngram.size() - 1)); if (denominator == 0) return Double.NEGATIVE_INFINITY; double logprob = Math.log10((double) nominator) - Math.log10((double) denominator); return logprob; } public long getQuantity(List<String> ngram) { if (ngram == null) throw new IllegalAccessError("Ngram is null."); if (ngram.isEmpty()) return 0L; String ngram_str = StringUtils.join(ngram, ' '); Term query_term = new Term("ngram", ngram_str); Query query = new TermQuery(query_term); try { Document doc = null; ScoreDoc[] hits = _searcher_ngram.search(query, 2).scoreDocs; if (hits.length >= 1) { if (hits.length > 1) LOG.warn("Found more than one entry for '{}', expected only one.", ngram_str); doc = _searcher_ngram.doc(hits[0].doc); return getQuantity(doc); } } catch (IOException e) { LOG.error("Could not get ngram {}. 
Luceneindex failed.", ngram_str, e); } return 0L; } long getQuantity(Document ngram) { return getNgramCountFromDoc(ngram); } long getNgramCountFromDoc(Document ngram) { IndexableField field = ngram.getField("num"); return field.numericValue().longValue(); } int getCardinality(Document ngram) { IndexableField field = ngram.getField("cardinality"); return field.numericValue().intValue(); } String getNgramString(Document ngram) { IndexableField field = ngram.getField("ngram"); if (field == null) return "<unk>"; return field.stringValue(); } String getWordString(Document word) { IndexableField field = word.getField("word"); if (field == null) return "<unk>"; return field.stringValue(); } public Document getNgramLuceneDoc(List<String> ngram) { if (ngram == null) throw new IllegalAccessError("Ngram is null."); if (ngram.isEmpty()) return _UNKOWN_NGRAM_LUCENE_DOCUMENT; String ngram_str = StringUtils.join(ngram, ' '); return getNgramLuceneDoc(ngram_str); } public Document getNgramLuceneDoc(String ngram_str) { if (ngram_str == null) throw new IllegalAccessError("Ngram is null."); if (ngram_str.isEmpty()) return _UNKOWN_NGRAM_LUCENE_DOCUMENT; Query query = new TermQuery(new Term("ngram", ngram_str)); try { Document doc = null; ScoreDoc[] hits = _searcher_ngram.search(query, 2).scoreDocs; if (hits.length >= 1) { if (hits.length > 1) LOG.warn("Found more than one entry for '{}', expected only one.", ngram_str); doc = _searcher_ngram.doc(hits[0].doc); return doc; } } catch (IOException e) { LOG.error("Could not get ngram {}. Luceneindex failed.", ngram_str, e); } return _UNKOWN_NGRAM_LUCENE_DOCUMENT; } public Document getWordLuceneDoc(String word) { if (word == null) throw new IllegalAccessError("Word is null."); if (de.tudarmstadt.lt.utilities.StringUtils.trim(word).isEmpty()) return _UNKOWN_WORD_LUCENE_DOCUMENT; Query query = new TermQuery(new Term("word", word)); try { Document doc = null; ScoreDoc[] hits = _searcher_vocab.search(query, 2).scoreDocs; if (hits.length >= 1) { if (hits.length > 1) LOG.warn("Found more than one entry for '{}', expected only one.", word); doc = _searcher_vocab.doc(hits[0].doc); return doc; } } catch (IOException e) { LOG.error("Could not get word {}. 
Luceneindex failed.", word, e); } return _UNKOWN_WORD_LUCENE_DOCUMENT; } static List<String> getNgramHistory(List<String> ngram) { assert ngram.size() > 0 : "Ngram must be longer than 0!"; return ngram.subList(0, ngram.size() - 1); } static List<String> getLowerOrderNgram(List<String> ngram) { assert ngram.size() > 0 : "Ngram must be longer than 0!"; return ngram.subList(1, ngram.size()); } static boolean isDefined(double d) { return !(Double.isInfinite(d) || Double.isNaN(d)) && d > 0; } @Override public Iterator<List<String>> getNgramIterator() { return new NgramIterator(); } public Iterator<String> getVocabularyIterator() { return new VocabularyIterator(); } @SuppressWarnings("unchecked") @Override public Iterator<List<Integer>> getNgramIdIterator() { return IteratorUtils.transformedIterator(getNgramIterator(), new Transformer() { @Override public Object transform(Object ngram) { return getNgramAsIds((List<String>) ngram); } }); } // public int addSequence(List<String> sequence) throws IllegalAccessException { // if(sequence.size() < 1) // return 0; // int c = 0; // for(int n = 1; n <= getOrder(); n++){ // List<String>[] ngram_sequence = LMProviderUtils.getNgramSequence(sequence, n); // for (List<String> ngram : ngram_sequence) // addNgram(ngram); // c += ngram_sequence.length; // } // return c; // } public int addNgramSequence(List<String>[] ngram_sequence) throws IllegalAccessException { for (List<String> ngram : ngram_sequence) addNgram(ngram); return ngram_sequence.length; } public int addNgram(List<String> ngram) throws IllegalAccessException { throw new UnsupportedOperationException(); } public int addNgramAsIds(List<Integer> ngram) throws IllegalAccessException { List<String> ngram_s = Arrays.asList(new String[ngram.size()]); for (int i = 0; i < ngram.size(); i++) { String w = getWord(ngram.get(i)); ngram_s.set(i, w == null ? 
String.format("_%d_", ngram.get(i)) : w); } return addNgram(ngram_s); } public Integer getOrAddWord(String word) throws IllegalAccessException { throw new UnsupportedOperationException(); } private List<String> toWordList(List<Integer> ngramIds) { List<String> ngram = new ArrayList<String>(ngramIds.size()); for (Integer ngram_i : ngramIds) ngram.add(getWord(ngram_i)); return ngram; } public boolean isFix() { return _fixed; } public void fixItNow() { _fixed = true; } public class NgramIterator implements Iterator<List<String>> { Bits _liveDocs = MultiFields.getLiveDocs(CountingStringLM.this._reader_ngram); int current_docid = -1; List<String> current_ngram = null; /* (non-Javadoc) * @see java.util.Iterator#hasNext() */ @Override public boolean hasNext() { if (++current_docid >= CountingStringLM.this._reader_ngram.maxDoc()) return false; current_ngram = null; for (; current_docid < CountingStringLM.this._reader_ngram.maxDoc(); current_docid++) { if (_liveDocs != null && !_liveDocs.get(current_docid)) continue; try { Document doc = CountingStringLM.this._reader_ngram.document(current_docid); if (getCardinality(doc) != _order) continue; current_ngram = Arrays.asList(getNgramString(doc).split(" ")); break; } catch (IOException e) { LOG.error("Could not get ngram lucene doc with id {}.", current_docid, e); return false; } } return current_ngram != null; } /* (non-Javadoc) * @see java.util.Iterator#next() */ @Override public List<String> next() { return current_ngram; } } public class VocabularyIterator implements Iterator<String> { Bits _liveDocs = MultiFields.getLiveDocs(CountingStringLM.this._reader_vocab); int current_docid = -1; String current_word = null; /* (non-Javadoc) * @see java.util.Iterator#hasNext() */ @Override public boolean hasNext() { if (++current_docid >= CountingStringLM.this._reader_vocab.maxDoc()) return false; current_word = null; for (; current_docid < CountingStringLM.this._reader_vocab.maxDoc(); current_docid++) { if (_liveDocs != null && !_liveDocs.get(current_docid)) continue; try { Document doc = CountingStringLM.this._reader_vocab.document(current_docid); current_word = getWordString(doc); break; } catch (IOException e) { LOG.error("Could not get ngram lucene doc with id {}.", current_docid, e); return false; } } return current_word != null; } /* (non-Javadoc) * @see java.util.Iterator#next() */ @Override public String next() { return current_word; } } }