// Java tutorial
/*
 * Copyright (C) 2007 by
 *
 * Xuan-Hieu Phan
 * hieuxuan@ecei.tohoku.ac.jp or pxhieu@gmail.com
 * Graduate School of Information Sciences
 * Tohoku University
 *
 * Cam-Tu Nguyen
 * ncamtu@gmail.com
 * College of Technology
 * Vietnam National University, Hanoi
 *
 * JGibbsLDA is a free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published
 * by the Free Software Foundation; either version 2 of the License,
 * or (at your option) any later version.
 *
 * JGibbsLDA is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with JGibbsLDA; if not, write to the Free Software Foundation,
 * Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA.
 */
package com.ss.language.model.gibblda;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.InputStreamReader;

import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;

/**
 * Estimates an LDA model via collapsed Gibbs sampling.
 *
 * <p>Depending on {@code option.est}/{@code option.estc} it either initializes
 * a brand-new model or continues from a previously estimated one, then
 * {@link #estimate()} runs {@code niters} sampling iterations, periodically
 * saving intermediate models and finally "model-final".
 */
public class Estimator {

	// The model being estimated (output of this class).
	protected Model trnModel;
	LDACmdOption option;

	public Estimator(LDACmdOption option) {
		LDACmdOption.curOption.set(option);
		// NOTE(review): init()'s boolean result is ignored here, so a failed
		// initialization is only observable later — kept for compatibility.
		init(option);
	}

	/**
	 * Initializes the model from the command-line options.
	 *
	 * @param option estimation options ({@code est} = new model,
	 *               {@code estc} = continue a previously estimated model)
	 * @return {@code true} on success, {@code false} if model initialization failed
	 */
	protected boolean init(LDACmdOption option) {
		this.option = option;
		trnModel = new Model();

		if (option.est) {
			if (!trnModel.initNewModel(option)) {
				return false;
			}
			trnModel.data.localDict.writeWordMap(option.dir + File.separator + option.wordMapFileName);
			writeEachwordsEachWord(trnModel.data.docs);
		} else if (option.estc) {
			if (!trnModel.initEstimatedModel(option)) {
				return false;
			}
		}
		return true;
	}

	/**
	 * For every word id listed in the dictionary's word-id file, appends a line
	 * of per-document occurrence counts, formatted as
	 * {@code [(docId:count),(docId:count)]}, to
	 * {@code <dir>/<wordMapFileName>-statistic.txt}.
	 *
	 * <p>Best-effort: any I/O error is printed and swallowed, matching the
	 * original behavior.
	 *
	 * @param docs the documents to scan; no-op when {@code null} or empty
	 */
	private void writeEachwordsEachWord(Document[] docs) {
		if (docs == null || docs.length == 0) {
			return;
		}
		// try-with-resources replaces the original manual close-in-finally.
		try (BufferedReader br = new BufferedReader(new InputStreamReader(
				new FileInputStream(trnModel.data.localDict.getWordIdsFile()), "UTF-8"))) {
			for (String wordId = br.readLine(); wordId != null; wordId = br.readLine()) {
				wordId = wordId.trim();
				if (wordId.isEmpty()) {
					continue;
				}
				// Collect "(docId:times)," for every document containing this word.
				StringBuilder sb = new StringBuilder();
				for (Document doc : docs) {
					String[] words = doc.getAllWords();
					if (words == null || words.length == 0) {
						continue;
					}
					int times = 0;
					for (String w : words) {
						if (wordId.equals(w)) {
							times += 1;
						}
					}
					if (times > 0) {
						sb.append("(").append(doc.getDocId()).append(":").append(times).append("),");
					}
				}
				// Wrap in brackets; the trailing comma becomes the closing bracket.
				if (sb.length() > 0) {
					File file = new File(
							option.dir + File.separator + option.wordMapFileName + "-statistic.txt");
					sb.insert(0, "[");
					sb.setCharAt(sb.length() - 1, ']');
					FileUtils.write(file, sb.toString() + IOUtils.LINE_SEPARATOR, "UTF-8", true);
				}
			}
		} catch (Exception e) {
			e.printStackTrace();
		}
	}

	/**
	 * Runs the Gibbs-sampling iterations, saving intermediate models every
	 * {@code option.savestep} iterations and the final model at the end.
	 */
	public void estimate() {
		System.out.println("Sampling " + trnModel.niters + " iteration!");

		int lastIter = trnModel.liter;
		for (trnModel.liter = lastIter + 1; trnModel.liter < trnModel.niters + lastIter; trnModel.liter++) {
			System.out.println("Iteration " + trnModel.liter + " ...");

			// For all z_i: resample the topic of word n in document m
			// from p(z_i | z_-i, w).
			for (int m = 0; m < trnModel.M; m++) {
				for (int n = 0; n < trnModel.data.docs[m].getLength(); n++) {
					int topic = sampling(m, n);
					trnModel.z[m].set(n, topic);
				} // end for each word
			} // end for each document

			if (option.savestep > 0 && trnModel.liter % option.savestep == 0) {
				System.out.println("Saving the model at iteration " + trnModel.liter + " ...");
				computeTheta();
				computePhi();
				trnModel.saveModel("model-" + Conversion.ZeroPad(trnModel.liter, 5));
			}
		} // end iterations

		System.out.println("Gibbs sampling completed!\n");
		System.out.println("Saving the final model!\n");
		computeTheta();
		computePhi();
		trnModel.liter--;
		trnModel.saveModel("model-final");
	}

	/**
	 * Do sampling
	 *
	 * @param m
	 *            document number
	 * @param n
	 *            word number
	 * @return topic id
	 */
	public int sampling(int m, int n) {
		// Remove z_i from the count variables.
		int topic = trnModel.z[m].get(n);
		int w = trnModel.data.docs[m].getWord(n);
		if (w < trnModel.V) {
			trnModel.nw[w][topic] -= 1;
		}
		trnModel.nd[m][topic] -= 1;
		trnModel.nwsum[topic] -= 1;
		trnModel.ndsum[m] -= 1;

		double Vbeta = trnModel.V * trnModel.beta;
		double Kalpha = trnModel.K * trnModel.alpha;

		// Multinomial sampling via the cumulative method.
		// NOTE(review): when w >= V, p[k] keeps its previous contents —
		// preserved from the original implementation.
		for (int k = 0; k < trnModel.K; k++) {
			if (w < trnModel.V) {
				trnModel.p[k] = (trnModel.nw[w][k] + trnModel.beta) / (trnModel.nwsum[k] + Vbeta)
						* (trnModel.nd[m][k] + trnModel.alpha) / (trnModel.ndsum[m] + Kalpha);
			}
		}

		// Cumulate multinomial parameters.
		for (int k = 1; k < trnModel.K; k++) {
			trnModel.p[k] += trnModel.p[k - 1];
		}

		// Scaled sample because p[] is unnormalized.
		double u = Math.random() * trnModel.p[trnModel.K - 1];
		for (topic = 0; topic < trnModel.K; topic++) {
			if (trnModel.p[topic] > u) { // sample topic w.r.t distribution p
				break;
			}
		}

		// Add the newly sampled z_i back to the count variables.
		if (w < trnModel.V) {
			trnModel.nw[w][topic] += 1;
		}
		trnModel.nd[m][topic] += 1;
		trnModel.nwsum[topic] += 1;
		trnModel.ndsum[m] += 1;

		return topic;
	}

	/**
	 * Computes the document-topic distribution theta[m][k] from the current
	 * counts (with alpha smoothing).
	 */
	public void computeTheta() {
		for (int m = 0; m < trnModel.M; m++) {
			for (int k = 0; k < trnModel.K; k++) {
				trnModel.theta.save(m, k, (trnModel.nd[m][k] + trnModel.alpha)
						/ (trnModel.ndsum[m] + trnModel.K * trnModel.alpha));
			}
		}
	}

	/**
	 * Computes the topic-word distribution phi[k][w] from the current counts
	 * (with beta smoothing).
	 */
	public void computePhi() {
		for (int k = 0; k < trnModel.K; k++) {
			// The original re-checked (w < trnModel.V) inside the loop body;
			// the loop bound already guarantees it, so the check is dropped.
			for (int w = 0; w < trnModel.V; w++) {
				trnModel.phi.save(k, w, (trnModel.nw[w][k] + trnModel.beta)
						/ (trnModel.nwsum[k] + trnModel.V * trnModel.beta));
			}
		}
	}
}