hivemall.topicmodel.ProbabilisticTopicModelBaseUDTF.java Source code

Java tutorial

Introduction

Here is the source code for hivemall.topicmodel.ProbabilisticTopicModelBaseUDTF.java

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package hivemall.topicmodel;

import hivemall.UDTFWithOptions;
import hivemall.annotations.VisibleForTesting;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.io.FileUtils;
import hivemall.utils.io.NIOUtils;
import hivemall.utils.io.NioStatefulSegment;
import hivemall.utils.lang.NumberUtils;
import hivemall.utils.lang.Primitives;
import hivemall.utils.lang.SizeOf;

import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.SortedMap;

import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;

import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.io.FloatWritable;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapred.Counters;
import org.apache.hadoop.mapred.Reporter;

/**
 * Base UDTF for probabilistic topic models trained with mini-batch updates.
 *
 * <p>Each input row is an {@code array<string>} document. Documents are buffered into mini-batches
 * of {@code -mini_batch_size} and handed to the model created by {@link #createModel()}. When
 * {@code -iterations} is greater than 1, every document is also serialized into a buffer, which
 * spills to a temporary file when full. On {@link #close()} those documents are replayed until
 * the mean mini-batch perplexity changes by less than {@code -epsilon} or the iteration limit is
 * reached. Finally, one {@code (topic, word, score)} row is forwarded per topic word.
 */
public abstract class ProbabilisticTopicModelBaseUDTF extends UDTFWithOptions {
    private static final Log logger = LogFactory.getLog(ProbabilisticTopicModelBaseUDTF.class);

    public static final int DEFAULT_TOPICS = 10;

    // Options
    /** Number of topics ({@code -topics}, default {@value #DEFAULT_TOPICS}). */
    protected int topics;
    /** Maximum number of passes over the training data ({@code -iterations}, default 10). */
    protected int iterations;
    /** Convergence threshold on the change of mean perplexity between passes ({@code -epsilon}). */
    protected double eps;
    /** Number of documents per model update ({@code -mini_batch_size}, default 128). */
    protected int miniBatchSize;

    /** Documents waiting for the next update; only the first {@link #miniBatchCount} are valid. */
    protected String[][] miniBatch;
    protected int miniBatchCount;

    protected transient AbstractProbabilisticTopicModel model;

    protected ListObjectInspector wordCountsOI;

    // for iterations
    /** Temporary file that serialized training documents spill to when {@link #inputBuf} is full. */
    protected NioStatefulSegment fileIO;
    /** Buffer of serialized training documents replayed in the 2nd and later iterations. */
    protected ByteBuffer inputBuf;

    /** Sum of the per-mini-batch perplexities observed in the current pass. */
    private float cumPerplexity;

    public ProbabilisticTopicModelBaseUDTF() {
        this.topics = DEFAULT_TOPICS;
        this.iterations = 10;
        this.eps = 1E-1d;
        this.miniBatchSize = 128; // if 1, truly online setting
    }

    @Override
    protected Options getOptions() {
        Options opts = new Options();
        opts.addOption("k", "topics", true, "The number of topics [default: 10]");
        opts.addOption("iter", "iterations", true, "The maximum number of iterations [default: 10]");
        opts.addOption("eps", "epsilon", true,
                "Check convergence based on the difference of perplexity [default: 1E-1]");
        opts.addOption("s", "mini_batch_size", true, "Repeat model updating per mini-batch [default: 128]");
        return opts;
    }

    /**
     * Parses the optional constant option string given as the second argument.
     *
     * @return the parsed command line, or {@code null} when no option string was given
     * @throws UDFArgumentException if {@code -topics}, {@code -iterations} or
     *         {@code -mini_batch_size} is less than 1
     */
    @Override
    protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
        CommandLine cl = null;

        if (argOIs.length >= 2) {
            String rawArgs = HiveUtils.getConstString(argOIs[1]);
            cl = parseOptions(rawArgs);
            this.topics = Primitives.parseInt(cl.getOptionValue("topics"), DEFAULT_TOPICS);
            if (topics < 1) {
                throw new UDFArgumentException("'-topics' must be greater than or equals to 1: " + topics);
            }
            this.iterations = Primitives.parseInt(cl.getOptionValue("iterations"), 10);
            if (iterations < 1) {
                throw new UDFArgumentException("'-iterations' must be greater than or equals to 1: " + iterations);
            }
            this.eps = Primitives.parseDouble(cl.getOptionValue("epsilon"), 1E-1d);
            this.miniBatchSize = Primitives.parseInt(cl.getOptionValue("mini_batch_size"), 128);
            // the mini-batch array is sized by this value, so it must be positive
            if (miniBatchSize < 1) {
                throw new UDFArgumentException(
                    "'-mini_batch_size' must be greater than or equals to 1: " + miniBatchSize);
            }
        }

        return cl;
    }

    /**
     * Validates the arguments, resets training state, and declares the output struct
     * {@code (topic int, word string, score float)}.
     */
    @Override
    public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
        if (argOIs.length < 1) {
            throw new UDFArgumentException(
                    "_FUNC_ takes 1 arguments: array<string> words [, const string options]");
        }

        this.wordCountsOI = HiveUtils.asListOI(argOIs[0]);
        HiveUtils.validateFeatureOI(wordCountsOI.getListElementObjectInspector());

        processOptions(argOIs);

        this.model = null;
        this.miniBatch = new String[miniBatchSize][];
        this.miniBatchCount = 0;
        this.cumPerplexity = 0.f;

        ArrayList<String> fieldNames = new ArrayList<String>();
        ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
        fieldNames.add("topic");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
        fieldNames.add("word");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
        fieldNames.add("score");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);

        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    }

    /** Creates the topic model trained by this UDTF; called lazily on the first row. */
    @Nonnull
    protected abstract AbstractProbabilisticTopicModel createModel();

    /**
     * Consumes one document. Empty (or null) documents are skipped.
     *
     * @throws HiveException if the document contains a null element
     */
    @Override
    public void process(Object[] args) throws HiveException {
        if (model == null) {
            this.model = createModel();
        }

        final int length = wordCountsOI.getListLength(args[0]);
        // skip empty documents; a null list may be reported with a negative length
        if (length <= 0) {
            return;
        }
        final String[] wordCounts = new String[length];
        for (int i = 0; i < length; i++) {
            Object o = wordCountsOI.getListElement(args[0], i);
            if (o == null) {
                throw new HiveException("Given feature vector contains invalid null elements");
            }
            wordCounts[i] = o.toString();
        }

        model.accumulateDocCount();

        update(wordCounts);

        recordTrainSampleToTempFile(wordCounts);
    }

    /**
     * Serializes a document into {@link #inputBuf} so later iterations can replay it. The buffer
     * spills to a temporary file when full. This is a no-op when only one iteration is
     * requested.
     *
     * <p>Record layout: {@code recordBytes, #words, len(w1), chars(w1), len(w2), chars(w2), ...}.
     *
     * @throws HiveException if the temporary file cannot be created or written
     */
    protected void recordTrainSampleToTempFile(@Nonnull final String[] wordCounts) throws HiveException {
        if (iterations == 1) {
            return;
        }

        ByteBuffer buf = inputBuf;
        NioStatefulSegment dst = fileIO;

        if (buf == null) {
            final File file;
            try {
                file = File.createTempFile("hivemall_topicmodel", ".sgmt");
                file.deleteOnExit();
            } catch (IOException ioe) {
                throw new UDFArgumentException(ioe);
            } catch (Throwable e) {
                throw new UDFArgumentException(e);
            }
            // checked outside the try so that this message is not re-wrapped by catch (Throwable)
            if (!file.canWrite()) {
                throw new UDFArgumentException("Cannot write a temporary file: " + file.getAbsolutePath());
            }
            logger.info("Record training samples to a file: " + file.getAbsolutePath());
            this.inputBuf = buf = ByteBuffer.allocateDirect(1024 * 1024); // 1 MB
            this.fileIO = dst = new NioStatefulSegment(file, false);
        }

        // wordCounts length, wc1 length, wc1 string, wc2 length, wc2 string, ...
        int wcLengthTotal = 0;
        for (String wc : wordCounts) {
            if (wc != null) {
                wcLengthTotal += wc.length();
            }
        }
        final int recordBytes = SizeOf.INT + SizeOf.INT * wordCounts.length + wcLengthTotal * SizeOf.CHAR;
        final int requiredBytes = SizeOf.INT + recordBytes; // need to allocate space for "recordBytes" itself

        if (buf.remaining() < requiredBytes) {
            if (buf.position() > 0) {
                writeBuffer(buf, dst);
            }
            if (buf.capacity() < requiredBytes) {
                // A single record larger than the buffer would overflow it. Grow the buffer so the
                // record fits; the same buffer is reused to read records back during iterations.
                this.inputBuf = buf = ByteBuffer.allocateDirect(requiredBytes);
            }
        }

        buf.putInt(recordBytes);
        buf.putInt(wordCounts.length);
        for (String wc : wordCounts) {
            NIOUtils.putString(wc, buf);
        }
    }

    /** Adds a document to the current mini-batch and trains once the batch is full. */
    private void update(@Nonnull final String[] wordCounts) {
        miniBatch[miniBatchCount] = wordCounts;
        miniBatchCount++;

        if (miniBatchCount == miniBatchSize) {
            train();
        }
    }

    /**
     * Trains the model on the buffered documents, which may be a partial mini-batch. The
     * resulting perplexity is added to the running sum, and the batch is cleared. This is a
     * no-op when no document is buffered.
     */
    protected void train() {
        if (miniBatchCount == 0) {
            return;
        }

        // pass only the filled prefix so a partial batch never contains null documents
        final String[][] docs = (miniBatchCount == miniBatch.length) ? miniBatch
                : Arrays.copyOfRange(miniBatch, 0, miniBatchCount);
        model.train(docs);

        this.cumPerplexity += model.computePerplexity();

        Arrays.fill(miniBatch, null); // clear
        miniBatchCount = 0;
    }

    /** Writes the pending bytes of {@code srcBuf} to {@code dst} and clears the buffer. */
    private static void writeBuffer(@Nonnull ByteBuffer srcBuf, @Nonnull NioStatefulSegment dst)
            throws HiveException {
        srcBuf.flip();
        try {
            dst.write(srcBuf);
        } catch (IOException e) {
            throw new HiveException("Exception causes while writing a buffer to file", e);
        }
        srcBuf.clear();
    }

    /**
     * Decodes the body of one record (the word count and the words) at the buffer's current
     * position. The leading {@code recordBytes} field must already have been consumed.
     */
    @Nonnull
    private static String[] readRecord(@Nonnull final ByteBuffer buf) {
        final int wcLength = buf.getInt();
        final String[] wordCounts = new String[wcLength];
        for (int j = 0; j < wcLength; j++) {
            wordCounts[j] = NIOUtils.getString(buf);
        }
        return wordCounts;
    }

    /** Finishes training and forwards the learned topic words, if any document was seen. */
    @Override
    public void close() throws HiveException {
        finalizeTraining();
        if (model != null) { // null when no non-empty document was processed
            forwardModel();
        }
        this.model = null;
    }

    /**
     * Trains on the trailing partial mini-batch, then runs the remaining iterations. Sets
     * {@link #model} to {@code null} when no non-empty document was processed.
     */
    @VisibleForTesting
    void finalizeTraining() throws HiveException {
        if (model == null || model.getDocCount() == 0L) {
            this.model = null;
            return;
        }
        // Update for the remaining samples. train() also records their perplexity, which numTrain
        // counts, and empties the batch so the next pass starts from a clean mini-batch.
        train();
        if (iterations > 1) {
            runIterativeTraining(iterations);
        }
    }

    /**
     * Replays the recorded documents for iterations 2..{@code iterations}. It stops early when the
     * mean mini-batch perplexity changes by less than {@link #eps}. Replay comes from memory if
     * the buffer never spilled, otherwise from the temporary file. The temporary file is always
     * closed and deleted afterwards.
     *
     * @throws HiveException if reading/writing the temporary file fails or training throws
     */
    protected final void runIterativeTraining(@Nonnegative final int iterations) throws HiveException {
        final ByteBuffer buf = this.inputBuf;
        final NioStatefulSegment dst = this.fileIO;
        assert (buf != null);
        assert (dst != null);
        final long numTrainingExamples = model.getDocCount();

        // number of mini-batch updates per pass; the last one may be partial
        long numTrain = numTrainingExamples / miniBatchSize;
        if (numTrainingExamples % miniBatchSize != 0L) {
            numTrain++;
        }

        final Reporter reporter = getReporter();
        final Counters.Counter iterCounter = (reporter == null) ? null
                : reporter.getCounter("hivemall.topicmodel.ProbabilisticTopicModel$Counter", "iteration");

        try {
            if (dst.getPosition() == 0L) {// run iterations w/o temporary file
                if (buf.position() == 0) {
                    return; // no training example
                }
                buf.flip();

                int iter = 2;
                float perplexity = cumPerplexity / numTrain;
                float perplexityPrev;
                for (; iter <= iterations; iter++) {
                    perplexityPrev = perplexity;
                    cumPerplexity = 0.f;

                    reportProgress(reporter);
                    setCounterValue(iterCounter, iter);

                    while (buf.remaining() > 0) {
                        int recordBytes = buf.getInt();
                        assert (recordBytes > 0) : recordBytes;
                        update(readRecord(buf));
                    }
                    buf.rewind();
                    train(); // flush this pass's trailing partial mini-batch

                    // mean perplexity over `numTrain` mini-batches
                    perplexity = cumPerplexity / numTrain;
                    logger.info("Mean perplexity over mini-batches: " + perplexity);
                    if (Math.abs(perplexityPrev - perplexity) < eps) {
                        break;
                    }
                }
                logger.info("Performed " + Math.min(iter, iterations) + " iterations of "
                        + NumberUtils.formatNumber(numTrainingExamples) + " training examples on memory (thus "
                        + NumberUtils.formatNumber(numTrainingExamples * Math.min(iter, iterations))
                        + " training updates in total) ");
            } else {// read training examples in the temporary file and invoke train for each example
                // Write the training examples still pending in the buffer to the temporary file.
                // The buffer is in write mode, so position() is the number of pending bytes.
                if (buf.position() > 0) {
                    writeBuffer(buf, dst);
                }
                try {
                    dst.flush();
                } catch (IOException e) {
                    throw new HiveException("Failed to flush a file: " + dst.getFile().getAbsolutePath(), e);
                }
                if (logger.isInfoEnabled()) {
                    File tmpFile = dst.getFile();
                    logger.info(
                            "Wrote " + numTrainingExamples + " records to a temporary file for iterative training: "
                                    + tmpFile.getAbsolutePath() + " (" + FileUtils.prettyFileSize(tmpFile) + ")");
                }

                // run iterations
                int iter = 2;
                float perplexity = cumPerplexity / numTrain;
                float perplexityPrev;
                for (; iter <= iterations; iter++) {
                    perplexityPrev = perplexity;
                    cumPerplexity = 0.f;

                    setCounterValue(iterCounter, iter);

                    buf.clear();
                    dst.resetPosition();
                    while (true) {
                        reportProgress(reporter);
                        // TODO prefetch
                        // reads training examples from the temporary file into the buffer
                        final int bytesRead;
                        try {
                            bytesRead = dst.read(buf);
                        } catch (IOException e) {
                            throw new HiveException("Failed to read a file: " + dst.getFile().getAbsolutePath(), e);
                        }
                        if (bytesRead <= 0) { // reached file EOF
                            break;
                        }

                        // reads training examples from a buffer
                        buf.flip();
                        int remain = buf.remaining();
                        if (remain < SizeOf.INT) {
                            throw new HiveException("Illegal file format was detected");
                        }
                        while (remain >= SizeOf.INT) {
                            int pos = buf.position();
                            int recordBytes = buf.getInt();
                            remain -= SizeOf.INT;
                            if (remain < recordBytes) {
                                // incomplete record: rewind and keep it for the next read
                                buf.position(pos);
                                break;
                            }

                            update(readRecord(buf));

                            remain -= recordBytes;
                        }
                        buf.compact();
                    }
                    train(); // flush this pass's trailing partial mini-batch

                    // mean perplexity over `numTrain` mini-batches
                    perplexity = cumPerplexity / numTrain;
                    logger.info("Mean perplexity over mini-batches: " + perplexity);
                    if (Math.abs(perplexityPrev - perplexity) < eps) {
                        break;
                    }
                }
                logger.info("Performed " + Math.min(iter, iterations) + " iterations of "
                        + NumberUtils.formatNumber(numTrainingExamples)
                        + " training examples on a secondary storage (thus "
                        + NumberUtils.formatNumber(numTrainingExamples * Math.min(iter, iterations))
                        + " training updates in total)");
            }
        } catch (Throwable e) {
            throw new HiveException("Exception caused in the iterative training", e);
        } finally {
            // delete the temporary file and release resources
            try {
                dst.close(true);
            } catch (IOException e) {
                throw new HiveException("Failed to close a file: " + dst.getFile().getAbsolutePath(), e);
            }
            this.inputBuf = null;
            this.fileIO = null;
        }
    }

    /** Forwards one {@code (topic, word, score)} row for every word of every topic. */
    protected void forwardModel() throws HiveException {
        final IntWritable topicIdx = new IntWritable();
        final Text word = new Text();
        final FloatWritable score = new FloatWritable();

        final Object[] forwardObjs = new Object[3];
        forwardObjs[0] = topicIdx;
        forwardObjs[1] = word;
        forwardObjs[2] = score;

        for (int k = 0; k < topics; k++) {
            topicIdx.set(k);

            final SortedMap<Float, List<String>> topicWords = model.getTopicWords(k);
            for (Map.Entry<Float, List<String>> e : topicWords.entrySet()) {
                score.set(e.getKey().floatValue());
                for (String v : e.getValue()) {
                    word.set(v);
                    forward(forwardObjs);
                }
            }
        }

        logger.info("Forwarded topic words each of " + topics + " topics");
    }

    @VisibleForTesting
    float getWordScore(String label, int k) {
        return model.getWordScore(label, k);
    }

    @VisibleForTesting
    SortedMap<Float, List<String>> getTopicWords(int k) {
        return model.getTopicWords(k);
    }

    @VisibleForTesting
    float[] getTopicDistribution(@Nonnull String[] doc) {
        return model.getTopicDistribution(doc);
    }
}