com.ss.language.model.gibblda.Estimator.java Source code

Java tutorial

Introduction

Here is the source code for com.ss.language.model.gibblda.Estimator.java

Source

/*
 * Copyright (C) 2007 by
 * 
 *    Xuan-Hieu Phan
 *   hieuxuan@ecei.tohoku.ac.jp or pxhieu@gmail.com
 *    Graduate School of Information Sciences
 *    Tohoku University
 * 
 *  Cam-Tu Nguyen
 *  ncamtu@gmail.com
 *  College of Technology
 *  Vietnam National University, Hanoi
 *
 * JGibbsLDA is a free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published
 * by the Free Software Foundation; either version 2 of the License,
 * or (at your option) any later version.
 *
 * JGibbsLDA is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with JGibbsLDA; if not, write to the Free Software Foundation,
 * Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA.
 */

package com.ss.language.model.gibblda;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.InputStreamReader;
import java.util.HashMap;
import java.util.Map;

import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;

/**
 * Drives LDA estimation by Gibbs sampling over a training {@link Model}.
 *
 * <p>
 * Construction initialises the model from the supplied options;
 * {@link #estimate()} then runs the sampling sweeps and saves the result.
 */
public class Estimator {

    // output model
    protected Model trnModel;
    LDACmdOption option;

    /**
     * Publishes {@code option} through {@code LDACmdOption.curOption} and
     * initialises the training model.
     *
     * <p>
     * A failed initialisation is reported on standard error instead of being
     * silently dropped; no exception is thrown, so existing callers keep
     * working.
     *
     * @param option
     *            command-line options controlling the estimation
     */
    public Estimator(LDACmdOption option) {
        LDACmdOption.curOption.set(option);
        if (!init(option)) {
            System.err.println("Estimator: model initialization failed");
        }
    }

    /**
     * Creates the training model according to {@code option}.
     *
     * <p>
     * With {@code option.est} a new model is initialised, the word map is
     * written to {@code option.dir/option.wordMapFileName} and the per-word
     * statistics file is produced. With {@code option.estc} a previously
     * estimated model is loaded instead. When neither flag is set only an
     * empty {@link Model} is created.
     *
     * @param option
     *            command-line options
     * @return {@code false} when the selected model initialisation fails,
     *         {@code true} otherwise
     */
    protected boolean init(LDACmdOption option) {
        this.option = option;
        trnModel = new Model();
        if (option.est) {
            if (!trnModel.initNewModel(option)) {
                return false;
            }
            trnModel.data.localDict.writeWordMap(option.dir + File.separator + option.wordMapFileName);
            writeEachwordsEachWord(trnModel.data.docs);
        } else if (option.estc) {
            if (!trnModel.initEstimatedModel(option)) {
                return false;
            }
        }

        return true;
    }

    /**
     * Appends, for every non-blank line of the dictionary's word-id file, one
     * record to {@code <option.dir>/<option.wordMapFileName>-statistic.txt}
     * listing the documents that contain that word and how often:
     * {@code [(docId:count),(docId:count)]}. Words that occur in no document
     * produce no record.
     *
     * <p>
     * The documents are indexed once up front so each word is a map lookup
     * rather than a rescan of every token of every document.
     *
     * <p>
     * Failures are printed and otherwise ignored (best effort).
     *
     * <p>
     * NOTE(review): the statistics file is opened in append mode and never
     * truncated here, so repeated runs accumulate records - confirm that this
     * is intended.
     *
     * @param docs
     *            training documents; nothing is written when {@code null} or
     *            empty
     */
    private void writeEachwordsEachWord(Document[] docs) {
        if (docs == null || docs.length == 0) {
            return;
        }
        BufferedReader br = null;
        try {
            br = new BufferedReader(new InputStreamReader(
                    new FileInputStream(trnModel.data.localDict.getWordIdsFile()), "UTF-8"));
            File statisticFile = new File(
                    option.dir + File.separator + option.wordMapFileName + "-statistic.txt");
            Map<String, StringBuilder> occurrences = indexWordOccurrences(docs);

            for (String line = br.readLine(); line != null; line = br.readLine()) {
                String wordId = line.trim();
                if (wordId.isEmpty()) {
                    continue;
                }
                StringBuilder entries = occurrences.get(wordId);
                if (entries == null) {
                    // word does not occur in any document
                    continue;
                }
                // every entry ends with ','; drop the last one before closing the bracket
                String record = "[" + entries.substring(0, entries.length() - 1) + "]";
                FileUtils.write(statisticFile, record + IOUtils.LINE_SEPARATOR, "UTF-8", true);
            }
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            if (br != null) {
                try {
                    br.close();
                } catch (Exception ignored) {
                    // nothing useful can be done if closing the reader fails
                }
            }
        }
    }

    /**
     * Builds, for every distinct word found in {@code docs}, the
     * comma-terminated list of {@code (docId:count)} entries in document
     * order.
     *
     * @param docs
     *            documents to index; must not be {@code null}
     * @return map from word to its accumulated {@code (docId:count),} entries
     */
    private static Map<String, StringBuilder> indexWordOccurrences(Document[] docs) {
        Map<String, StringBuilder> index = new HashMap<String, StringBuilder>();
        for (Document doc : docs) {
            String[] words = doc.getAllWords();
            if (words == null || words.length == 0) {
                continue;
            }
            // occurrences of each word inside this document
            Map<String, Integer> counts = new HashMap<String, Integer>();
            for (String word : words) {
                if (word == null) {
                    continue;
                }
                Integer seen = counts.get(word);
                counts.put(word, seen == null ? 1 : seen + 1);
            }
            for (Map.Entry<String, Integer> entry : counts.entrySet()) {
                StringBuilder entries = index.get(entry.getKey());
                if (entries == null) {
                    entries = new StringBuilder();
                    index.put(entry.getKey(), entries);
                }
                entries.append("(");
                entries.append(doc.getDocId());
                entries.append(":");
                entries.append(entry.getValue().intValue());
                entries.append("),");
            }
        }
        return index;
    }

    /**
     * Runs {@code trnModel.niters} sampling sweeps, continuing from the
     * iteration stored in {@code trnModel.liter}, and saves the final model as
     * {@code model-final}.
     *
     * <p>
     * When {@code option.savestep} is positive, an intermediate model named
     * {@code model-<iteration>} is saved every {@code savestep} iterations.
     */
    public void estimate() {
        System.out.println("Sampling " + trnModel.niters + " iteration!");

        int lastIter = trnModel.liter;
        // inclusive upper bound: exactly niters sweeps are performed
        // (the previous '<' stopped one sweep short of what was requested)
        for (trnModel.liter = lastIter + 1; trnModel.liter <= trnModel.niters + lastIter; trnModel.liter++) {
            System.out.println("Iteration " + trnModel.liter + " ...");

            // for all z_i
            for (int m = 0; m < trnModel.M; m++) {
                for (int n = 0; n < trnModel.data.docs[m].getLength(); n++) {
                    // z_i = z[m][n]
                    // sample from p(z_i|z_-i, w)
                    int topic = sampling(m, n);
                    trnModel.z[m].set(n, topic);
                } // end for each word
            } // end for each document

            if (option.savestep > 0 && trnModel.liter % option.savestep == 0) {
                System.out.println("Saving the model at iteration " + trnModel.liter + " ...");
                computeTheta();
                computePhi();
                trnModel.saveModel("model-" + Conversion.ZeroPad(trnModel.liter, 5));
            }
        } // end iterations

        System.out.println("Gibbs sampling completed!\n");
        System.out.println("Saving the final model!\n");
        computeTheta();
        computePhi();
        // the loop leaves liter one past the last completed iteration
        trnModel.liter--;
        trnModel.saveModel("model-final");
    }

    /**
     * Do sampling: draws a new topic for word {@code n} of document {@code m}
     * and updates the count tables accordingly.
     *
     * <p>
     * A word id that is not below {@code trnModel.V} has no row in
     * {@code nw}; it is treated as having a zero word-topic count so that
     * {@code p[]} is always recomputed instead of being reused from the
     * previous call. NOTE(review): confirm this is the intended handling for
     * such words.
     *
     * @param m
     *            document number
     * @param n
     *            word number
     * @return topic id
     */
    public int sampling(int m, int n) {
        // remove z_i from the count variable
        int topic = trnModel.z[m].get(n);
        int w = trnModel.data.docs[m].getWord(n);
        boolean inVocabulary = w < trnModel.V;
        if (inVocabulary) {
            trnModel.nw[w][topic] -= 1;
        }
        trnModel.nd[m][topic] -= 1;
        trnModel.nwsum[topic] -= 1;
        trnModel.ndsum[m] -= 1;

        double Vbeta = trnModel.V * trnModel.beta;
        double Kalpha = trnModel.K * trnModel.alpha;

        // do multinominal sampling via cumulative method
        for (int k = 0; k < trnModel.K; k++) {
            double wordTopicCount = inVocabulary ? trnModel.nw[w][k] : 0;
            trnModel.p[k] = (wordTopicCount + trnModel.beta) / (trnModel.nwsum[k] + Vbeta)
                    * (trnModel.nd[m][k] + trnModel.alpha) / (trnModel.ndsum[m] + Kalpha);
        }

        // cumulate multinomial parameters
        for (int k = 1; k < trnModel.K; k++) {
            trnModel.p[k] += trnModel.p[k - 1];
        }

        // scaled sample because of unnormalized p[]
        double u = Math.random() * trnModel.p[trnModel.K - 1];

        for (topic = 0; topic < trnModel.K; topic++) {
            if (trnModel.p[topic] > u) { // sample topic w.r.t distribution p
                break;
            }
        }
        if (topic >= trnModel.K) {
            // no cumulative value exceeded u (rounding, or an all-zero p[]);
            // keep the index inside the topic range
            topic = trnModel.K - 1;
        }

        // add newly estimated z_i to count variables
        if (inVocabulary) {
            trnModel.nw[w][topic] += 1;
        }
        trnModel.nd[m][topic] += 1;
        trnModel.nwsum[topic] += 1;
        trnModel.ndsum[m] += 1;

        return topic;
    }

    /**
     * Stores, for every document {@code m} and topic {@code k}, the value
     * {@code (nd[m][k] + alpha) / (ndsum[m] + K * alpha)} into
     * {@code trnModel.theta}.
     */
    public void computeTheta() {
        double Kalpha = trnModel.K * trnModel.alpha;
        for (int m = 0; m < trnModel.M; m++) {
            for (int k = 0; k < trnModel.K; k++) {
                trnModel.theta.save(m, k, (trnModel.nd[m][k] + trnModel.alpha) / (trnModel.ndsum[m] + Kalpha));
            }
        }
    }

    /**
     * Stores, for every topic {@code k} and word {@code w}, the value
     * {@code (nw[w][k] + beta) / (nwsum[k] + V * beta)} into
     * {@code trnModel.phi}.
     */
    public void computePhi() {
        double Vbeta = trnModel.V * trnModel.beta;
        for (int k = 0; k < trnModel.K; k++) {
            for (int w = 0; w < trnModel.V; w++) {
                trnModel.phi.save(k, w, (trnModel.nw[w][k] + trnModel.beta) / (trnModel.nwsum[k] + Vbeta));
            }
        }
    }
}