hivemall.topicmodel.OnlineLDAModel.java Source code

Introduction

Here is the source code for hivemall.topicmodel.OnlineLDAModel.java, Apache Hivemall's implementation of online variational Bayes for Latent Dirichlet Allocation. Each call to train() runs one E step, which fits the per-document variational parameters (gamma and phi) for the current mini-batch, and one M step, which blends the mini-batch estimate into the global topic-word parameter lambda. The class also provides helpers for computing perplexity, ranking the top words of each topic, and inferring the topic distribution of a document.
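Before the full listing, here is a minimal usage sketch. It is hypothetical: the example class and the toy documents are invented; train(), computePerplexity(), and getTopicDistribution() are protected, so the sketch assumes same-package access (as Hivemall's own tests have); and the "word:count" feature-string format is assumed to be what AbstractProbabilisticTopicModel.initMiniBatch() parses. In production these methods are driven by Hivemall's LDA UDTF rather than called directly.

package hivemall.topicmodel;

import java.util.Arrays;

// Hypothetical demo class; not part of Hivemall itself.
public class OnlineLDAModelExample {

    public static void main(String[] args) {
        final int K = 2; // number of topics

        // (K, alpha, eta, D, tau0, kappa, delta); D = 4 documents in this toy corpus
        OnlineLDAModel model = new OnlineLDAModel(K, 1.f / K, 1 / 20.f, 4L, 1020.d, 0.7, 1E-5d);

        // each document is an array of "word:count" feature strings
        String[][] miniBatch = new String[][] {
                {"fruits:1", "healthy:1", "vegetables:1"},
                {"apples:1", "avocados:1", "colds:1"},
                {"colds:1", "flu:1", "healthy:1"},
                {"apples:1", "oranges:1", "fruits:1"}};

        model.train(miniBatch); // one online E/M update

        System.out.println("perplexity: " + model.computePerplexity());
        System.out.println("topic 0 top words: " + model.getTopicWords(0, 3));

        float[] distr = model.getTopicDistribution(new String[] {"fruits:1", "apples:1"});
        System.out.println("P(topic | doc): " + Arrays.toString(distr));
    }
}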

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package hivemall.topicmodel;

import hivemall.annotations.VisibleForTesting;
import hivemall.utils.lang.ArrayUtils;
import hivemall.utils.math.MathUtils;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.SortedMap;
import java.util.TreeMap;

import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;

import org.apache.commons.math3.distribution.GammaDistribution;
import org.apache.commons.math3.special.Gamma;

public final class OnlineLDAModel extends AbstractProbabilisticTopicModel {

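    // Gamma(shape=100, scale=1/100) has mean 1 (= shape * scale) and variance 0.01
    // (= shape * scale^2), so random initial parameter values concentrate around 1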
    private static final double SHAPE = 100.d;
    private static final double SCALE = 1.d / SHAPE;

    // ---------------------------------
    // HyperParameters

    // prior on weight vectors "theta ~ Dir(_alpha)"
    private final float _alpha;

    // prior on topics "beta"
    private final float _eta;

    // positive value which downweights early iterations
    @Nonnegative
    private final double _tau0;

    // exponential decay rate (i.e., learning rate) which must be in (0.5, 1] to guarantee convergence
    @Nonnegative
    private final double _kappa;

    // check convergence in the expectation (E) step
    private final double _delta;

    // ---------------------------------

    // number of times the EM step has been run; later EM steps forget the old lambda less drastically
    private long _updateCount = 0L;

    // defined by (tau0 + updateCount)^(-kappa_)
    // controls how much old lambda is forgotten
    private double _rhot;

    // if the `num_docs` option is not given, this flag is set to true;
    // in that case, the UDTF automatically accumulates the document count into the `_D` parameter of the online LDA model
    private final boolean _isAutoD;

    // parameters
    private List<Map<String, float[]>> _phi;
    private float[][] _gamma;
    @Nonnull
    private final Map<String, float[]> _lambda;

    // random number generator
    @Nonnull
    private final GammaDistribution _gd;

    // for computing perplexity
    private float _docRatio = 1.f;
    private double _valueSum = 0.d;

    public OnlineLDAModel(int K, float alpha, double delta) { // for E step only instantiation
        this(K, alpha, 1 / 20.f, -1L, 1020, 0.7, delta);
    }

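    /**
     * @param K number of topics
     * @param alpha prior on the per-document topic distribution theta
     * @param eta prior on the per-topic word distribution beta
     * @param D total number of documents in the corpus (a negative value enables auto-counting)
     * @param tau0 positive value which downweights early iterations
     * @param kappa exponential decay rate which must be in (0.5, 1]
     * @param delta convergence threshold for gamma in the E step
     */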
    public OnlineLDAModel(int K, float alpha, float eta, long D, double tau0, double kappa, double delta) {
        super(K);

        if (tau0 < 0.d) {
            throw new IllegalArgumentException("tau0 MUST be positive: " + tau0);
        }
        if (kappa <= 0.5 || 1.d < kappa) {
            throw new IllegalArgumentException("kappa MUST be in (0.5, 1.0]: " + kappa);
        }

        this._alpha = alpha;
        this._eta = eta;
        this._D = D;
        this._tau0 = tau0;
        this._kappa = kappa;
        this._delta = delta;

        this._isAutoD = (_D < 0L);

        // initialize a random number generator
        this._gd = new GammaDistribution(SHAPE, SCALE);
        _gd.reseedRandomGenerator(1001);

        // initialize the parameters
        this._lambda = new HashMap<String, float[]>(100);
    }

    @Override
    protected void accumulateDocCount() {
        /*
         * In a truly online setting, the total number of documents equals the number of
         * documents that have ever been seen. To support that case, this method accumulates
         * the estimated `_D` one document at a time as each document arrives. Note that,
         * since the same set of documents could be passed to `train()` repeatedly, simply
         * accumulating `_miniBatchSize`s as the estimated `_D` is not sufficient.
         */
        if (_isAutoD) {
            this._D += 1;
        }
    }

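    /**
     * Performs one online variational Bayes update on a mini-batch (cf. Hoffman et al.,
     * "Online Learning for Latent Dirichlet Allocation", NIPS 2010): the E step fits the
     * per-document variational parameters (gamma, phi), and the M step blends the
     * mini-batch estimate of lambda into the global lambda with weight
     * rhot = (tau0 + updateCount)^-kappa.
     */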
    protected void train(@Nonnull final String[][] miniBatch) {
        preprocessMiniBatch(miniBatch);

        initParams(true);

        // Expectation
        eStep();

        this._rhot = Math.pow(_tau0 + _updateCount, -_kappa);

        // Maximization
        mStep();

        _updateCount++;
    }

    private void preprocessMiniBatch(@Nonnull final String[][] miniBatch) {
        initMiniBatch(miniBatch, _miniBatchDocs);

        this._miniBatchSize = _miniBatchDocs.size();

        // accumulate the number of words for each document
        double valueSum = 0.d;
        for (int d = 0; d < _miniBatchSize; d++) {
            for (Float n : _miniBatchDocs.get(d).values()) {
                valueSum += n.floatValue();
            }
        }
        this._valueSum = valueSum;

        this._docRatio = (float) ((double) _D / _miniBatchSize);
    }

    private void initParams(final boolean gammaWithRandom) {
        final List<Map<String, float[]>> phi = new ArrayList<Map<String, float[]>>();
        final float[][] gamma = new float[_miniBatchSize][];

        for (int d = 0; d < _miniBatchSize; d++) {
            if (gammaWithRandom) {
                gamma[d] = ArrayUtils.newRandomFloatArray(_K, _gd);
            } else {
                gamma[d] = ArrayUtils.newFloatArray(_K, 1.f);
            }

            final Map<String, float[]> phi_d = new HashMap<String, float[]>();
            phi.add(phi_d);
            for (final String label : _miniBatchDocs.get(d).keySet()) {
                phi_d.put(label, new float[_K]);
                if (!_lambda.containsKey(label)) { // lambda for newly observed word
                    _lambda.put(label, ArrayUtils.newRandomFloatArray(_K, _gd));
                }
            }
        }

        this._phi = phi;
        this._gamma = gamma;
    }

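    /**
     * E step: for each document in the mini-batch, alternates the phi and gamma updates
     * until the mean absolute change in gamma falls below `_delta`.
     */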
    private void eStep() {
        // since lambda is invariant in the expectation step,
        // `digamma`s of lambda values for Elogbeta are pre-computed
        final double[] lambdaSum = new double[_K];
        final Map<String, float[]> digamma_lambda = new HashMap<String, float[]>();
        for (Map.Entry<String, float[]> e : _lambda.entrySet()) {
            String label = e.getKey();
            float[] lambda_label = e.getValue();

            // for digamma(lambdaSum)
            MathUtils.add(lambda_label, lambdaSum, _K);

            digamma_lambda.put(label, MathUtils.digamma(lambda_label));
        }

        final double[] digamma_lambdaSum = MathUtils.digamma(lambdaSum);
        // for each of mini-batch documents, update gamma until convergence
        float[] gamma_d, gammaPrev_d;
        Map<String, float[]> eLogBeta_d;
        for (int d = 0; d < _miniBatchSize; d++) {
            gamma_d = _gamma[d];
            eLogBeta_d = computeElogBetaPerDoc(d, digamma_lambda, digamma_lambdaSum);

            do {
                gammaPrev_d = gamma_d.clone(); // deep copy the last gamma values

                updatePhiPerDoc(d, eLogBeta_d);
                updateGammaPerDoc(d);
            } while (!checkGammaDiff(gammaPrev_d, gamma_d));
        }
    }

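    /**
     * Computes Elog(beta_kw) = digamma(lambda_wk) - digamma(sum_w' lambda_w'k) for every
     * word w in document d, using the pre-computed digamma values.
     */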
    @Nonnull
    private Map<String, float[]> computeElogBetaPerDoc(@Nonnegative final int d,
            @Nonnull final Map<String, float[]> digamma_lambda, @Nonnull final double[] digamma_lambdaSum) {
        final Map<String, Float> doc = _miniBatchDocs.get(d);

        // Dirichlet expectation (2d) for lambda
        final Map<String, float[]> eLogBeta_d = new HashMap<String, float[]>(doc.size());
        for (final String label : doc.keySet()) {
            float[] eLogBeta_label = eLogBeta_d.get(label);
            if (eLogBeta_label == null) {
                eLogBeta_label = new float[_K];
                eLogBeta_d.put(label, eLogBeta_label);
            }
            final float[] digamma_lambda_label = digamma_lambda.get(label);
            for (int k = 0; k < _K; k++) {
                eLogBeta_label[k] = (float) (digamma_lambda_label[k] - digamma_lambdaSum[k]);
            }
        }

        return eLogBeta_d;
    }

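    /**
     * Updates phi for document d: phi[w][k] is proportional to
     * exp(Elog(theta_dk) + Elog(beta_kw)), normalized over the K topics for each word w.
     */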
    private void updatePhiPerDoc(@Nonnegative final int d, @Nonnull final Map<String, float[]> eLogBeta_d) {
        // Dirichlet expectation (2d) for gamma
        final float[] gamma_d = _gamma[d];
        final double digamma_gammaSum_d = Gamma.digamma(MathUtils.sum(gamma_d));
        final double[] eLogTheta_d = new double[_K];
        for (int k = 0; k < _K; k++) {
            eLogTheta_d[k] = Gamma.digamma(gamma_d[k]) - digamma_gammaSum_d;
        }

        // updating phi w/ normalization
        final Map<String, float[]> phi_d = _phi.get(d);
        final Map<String, Float> doc = _miniBatchDocs.get(d);
        for (String label : doc.keySet()) {
            final float[] phi_label = phi_d.get(label);
            final float[] eLogBeta_label = eLogBeta_d.get(label);

            double normalizer = 0.d;
            for (int k = 0; k < _K; k++) {
                float phiVal = (float) Math.exp(eLogBeta_label[k] + eLogTheta_d[k]) + 1E-20f;
                phi_label[k] = phiVal;
                normalizer += phiVal;
            }

            for (int k = 0; k < _K; k++) {
                phi_label[k] /= normalizer;
            }
        }
    }

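    /**
     * Updates gamma for document d: gamma_d[k] = alpha + sum_w count(w) * phi[w][k].
     */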
    private void updateGammaPerDoc(@Nonnegative final int d) {
        final Map<String, Float> doc = _miniBatchDocs.get(d);
        final Map<String, float[]> phi_d = _phi.get(d);

        final float[] gamma_d = _gamma[d];
        for (int k = 0; k < _K; k++) {
            gamma_d[k] = _alpha;
        }
        for (Map.Entry<String, Float> e : doc.entrySet()) {
            final float[] phi_label = phi_d.get(e.getKey());
            final float val = e.getValue().floatValue();
            for (int k = 0; k < _K; k++) {
                gamma_d[k] += phi_label[k] * val;
            }
        }
    }

    private boolean checkGammaDiff(@Nonnull final float[] gammaPrev, @Nonnull final float[] gammaNext) {
        double diff = 0.d;
        for (int k = 0; k < _K; k++) {
            diff += Math.abs(gammaPrev[k] - gammaNext[k]);
        }
        return (diff / _K) < _delta;
    }

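    /**
     * M step: builds lambdaTilde[w][k] = eta + docRatio * sum_d phi_d[w][k] from the
     * mini-batch, then blends it into the global parameter as
     * lambda = (1 - rhot) * lambda + rhot * lambdaTilde.
     */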
    private void mStep() {
        // calculate lambdaTilde for vocabularies in the current mini-batch
        final Map<String, float[]> lambdaTilde = new HashMap<String, float[]>();
        for (int d = 0; d < _miniBatchSize; d++) {
            final Map<String, float[]> phi_d = _phi.get(d);
            for (String label : _miniBatchDocs.get(d).keySet()) {
                float[] lambdaTilde_label = lambdaTilde.get(label);
                if (lambdaTilde_label == null) {
                    lambdaTilde_label = ArrayUtils.newFloatArray(_K, _eta);
                    lambdaTilde.put(label, lambdaTilde_label);
                }

                final float[] phi_label = phi_d.get(label);
                for (int k = 0; k < _K; k++) {
                    lambdaTilde_label[k] += _docRatio * phi_label[k];
                }
            }
        }

        // update lambda for all vocabularies
        for (Map.Entry<String, float[]> e : _lambda.entrySet()) {
            String label = e.getKey();
            final float[] lambda_label = e.getValue();

            float[] lambdaTilde_label = lambdaTilde.get(label);
            if (lambdaTilde_label == null) {
                lambdaTilde_label = ArrayUtils.newFloatArray(_K, _eta);
            }

            for (int k = 0; k < _K; k++) {
                lambda_label[k] = (float) ((1.d - _rhot) * lambda_label[k] + _rhot * lambdaTilde_label[k]);
            }
        }
    }

    /**
     * Calculates the approximate perplexity for the current mini-batch:
     * exp(-bound / N), where N is the mini-batch word count scaled by the document ratio.
     */
    protected float computePerplexity() {
        double bound = computeApproxBound();
        double perWordBound = bound / (_docRatio * _valueSum);
        return (float) Math.exp(-1.d * perWordBound);
    }

    /**
     * Estimates the variational bound over all documents using only the documents
     * passed in as the mini-batch.
     */
    private double computeApproxBound() {
        // prepare
        final double[] gammaSum = new double[_miniBatchSize];
        for (int d = 0; d < _miniBatchSize; d++) {
            gammaSum[d] = MathUtils.sum(_gamma[d]);
        }
        final double[] digamma_gammaSum = MathUtils.digamma(gammaSum);

        final double[] lambdaSum = new double[_K];
        for (float[] lambda_label : _lambda.values()) {
            MathUtils.add(lambda_label, lambdaSum, _K);
        }
        final double[] digamma_lambdaSum = MathUtils.digamma(lambdaSum);

        final double logGamma_alpha = Gamma.logGamma(_alpha);
        final double logGamma_alphaSum = Gamma.logGamma(_K * _alpha);

        double score = 0.d;
        for (int d = 0; d < _miniBatchSize; d++) {
            final double digamma_gammaSum_d = digamma_gammaSum[d];
            final float[] gamma_d = _gamma[d];

            // E[log p(doc | theta, beta)]
            for (Map.Entry<String, Float> e : _miniBatchDocs.get(d).entrySet()) {
                final float[] lambda_label = _lambda.get(e.getKey());

                // logsumexp( Elogthetad + Elogbetad )
                final double[] temp = new double[_K];
                double max = Double.NEGATIVE_INFINITY; // NOT Double.MIN_VALUE, which is the smallest POSITIVE double
                for (int k = 0; k < _K; k++) {
                    double eLogTheta_dk = Gamma.digamma(gamma_d[k]) - digamma_gammaSum_d;
                    double eLogBeta_kw = Gamma.digamma(lambda_label[k]) - digamma_lambdaSum[k];
                    final double tempK = eLogTheta_dk + eLogBeta_kw;
                    if (tempK > max) {
                        max = tempK;
                    }
                    temp[k] = tempK;
                }
                double logsumexp = MathUtils.logsumexp(temp, max);

                // sum( word count * logsumexp(...) )
                score += e.getValue().floatValue() * logsumexp;
            }

            // E[log p(theta | alpha) - log q(theta | gamma)]
            for (int k = 0; k < _K; k++) {
                float gamma_dk = gamma_d[k];

                // sum( (alpha - gammad) * Elogthetad )
                score += (_alpha - gamma_dk) * (Gamma.digamma(gamma_dk) - digamma_gammaSum_d);

                // sum( gammaln(gammad) - gammaln(alpha) )
                score += Gamma.logGamma(gamma_dk) - logGamma_alpha;
            }
            score += logGamma_alphaSum; // gammaln(sum(alpha))
            score -= Gamma.logGamma(gammaSum[d]); // gammaln(sum(gammad))
        }

        // the mini-batch covers only a subset of the whole corpus (i.e., online setting),
        // so scale the document-dependent terms by `_docRatio` to keep the estimated
        // likelihood roughly on the same scale across mini-batches
        score *= _docRatio;

        final double logGamma_eta = Gamma.logGamma(_eta);
        final double logGamma_etaSum = Gamma.logGamma(_eta * _lambda.size()); // vocabulary size * eta

        // E[log p(beta | eta) - log q (beta | lambda)]
        for (final float[] lambda_label : _lambda.values()) {
            for (int k = 0; k < _K; k++) {
                float lambda_label_k = lambda_label[k];

                // sum( (eta - lambda) * Elogbeta )
                score += (_eta - lambda_label_k) * (Gamma.digamma(lambda_label_k) - digamma_lambdaSum[k]);

                // sum( gammaln(lambda) - gammaln(eta) )
                score += Gamma.logGamma(lambda_label_k) - logGamma_eta;
            }
        }
        for (int k = 0; k < _K; k++) {
            // sum( gammaln(etaSum) - gammaln(lambdaSum_k) )
            score += logGamma_etaSum - Gamma.logGamma(lambdaSum[k]);
        }

        return score;
    }

    @VisibleForTesting
    float getWordScore(@Nonnull final String label, @Nonnegative final int k) {
        final float[] lambda_label = _lambda.get(label);
        if (lambda_label == null) {
            throw new IllegalArgumentException("Word `" + label + "` is not in the corpus.");
        }
        if (k >= lambda_label.length) {
            throw new IllegalArgumentException("Topic index must be in [0, " + _lambda.get(label).length + "]");
        }
        return lambda_label[k];
    }

    protected void setWordScore(@Nonnull final String label, @Nonnegative final int k, final float lambda_k) {
        float[] lambda_label = _lambda.get(label);
        if (lambda_label == null) {
            lambda_label = ArrayUtils.newRandomFloatArray(_K, _gd);
            _lambda.put(label, lambda_label);
        }
        lambda_label[k] = lambda_k;
    }

    @Nonnull
    protected SortedMap<Float, List<String>> getTopicWords(@Nonnegative final int k) {
        return getTopicWords(k, _lambda.keySet().size());
    }

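    /**
     * Returns the top-N words of topic k, keyed by each word's lambda value normalized
     * by the sum of lambda over the whole vocabulary for that topic.
     */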
    @Nonnull
    public SortedMap<Float, List<String>> getTopicWords(@Nonnegative final int k, @Nonnegative int topN) {
        double lambdaSum = 0.d;
        final SortedMap<Float, List<String>> sortedLambda = new TreeMap<Float, List<String>>(
                Collections.reverseOrder());

        for (Map.Entry<String, float[]> e : _lambda.entrySet()) {
            final float lambda_k = e.getValue()[k];
            lambdaSum += lambda_k;

            List<String> labels = sortedLambda.get(lambda_k);
            if (labels == null) {
                labels = new ArrayList<String>();
                sortedLambda.put(lambda_k, labels);
            }
            labels.add(e.getKey());
        }

        final SortedMap<Float, List<String>> ret = new TreeMap<Float, List<String>>(Collections.reverseOrder());

        topN = Math.min(topN, _lambda.keySet().size());
        int tt = 0;
        for (Map.Entry<Float, List<String>> e : sortedLambda.entrySet()) {
            float key = (float) (e.getKey().floatValue() / lambdaSum);
            ret.put(Float.valueOf(key), e.getValue());

            if (++tt == topN) {
                break;
            }
        }

        return ret;
    }

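    /**
     * Infers the topic distribution of a single document by running the E step with
     * gamma initialized to all ones, then L1-normalizing the resulting gamma.
     */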
    @Nonnull
    protected float[] getTopicDistribution(@Nonnull final String[] doc) {
        preprocessMiniBatch(new String[][] { doc });

        initParams(false);

        eStep();

        // normalize topic distribution
        final float[] topicDistr = new float[_K];
        final float[] gamma0 = _gamma[0];
        final double gammaSum = MathUtils.sum(gamma0);
        for (int k = 0; k < _K; k++) {
            topicDistr[k] = (float) (gamma0[k] / gammaSum);
        }
        return topicDistr;
    }

}