com.elex.dmp.lda.ModelTrainer.java Source code

Java tutorial

Introduction

Here is the source code for com.elex.dmp.lda.ModelTrainer.java, a multithreaded, in-memory trainer for LDA topic models.

Source

/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.elex.dmp.lda;

import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import org.apache.hadoop.fs.Path;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.MatrixSlice;
import org.apache.mahout.math.SparseRowMatrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorIterable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.Callable;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

/**
 * Multithreaded LDA model trainer, which primarily operates by running a "map/reduce"
 * operation, all in memory locally (i.e. not a Hadoop job!). The "map" operation takes
 * the "read-only" {@link TopicModel} and uses it to iteratively learn the p(topic|term, doc)
 * distribution for documents; this can be done in parallel across many documents, as the
 * "read-only" model is, well, read-only. The outputs of this are then "reduced" onto the
 * "write" model, and these updates are not parallelizable in the same way: individual
 * documents can't be added to the same entries in different threads at the same time, but
 * updates across many topics to the same term from the same document can be done in parallel,
 * so they are.
 *
 * <p>Because computation is done asynchronously, when iteration is done it's important to call
 * the {@link #stop()} method. It waits (up to 60 seconds for the training pool) for queued work
 * to finish and then swaps the read and write models.
 *
 * <p>Setting the read model and the write model to be the same object may not quite work yet,
 * on account of parallelism badness. In that configuration, whole-corpus training gathers
 * documents into batches of {@code numTrainThreads}, infers each batch in parallel, and only
 * then applies the batch's updates.
 */
public class ModelTrainer {

    private static final Logger log = LoggerFactory.getLogger(ModelTrainer.class);

    private final int numTopics;
    private final int numTerms;
    private TopicModel readModel;
    private TopicModel writeModel;
    private ThreadPoolExecutor threadPool;
    private BlockingQueue<Runnable> workQueue;
    private final int numTrainThreads;
    private final boolean isReadWrite;

    /**
     * Creates a trainer that infers with one model and accumulates updates into another.
     *
     * @param initialReadModel model used for per-document inference and perplexity
     * @param initialWriteModel model receiving the updates; if this is the very same instance as
     *        {@code initialReadModel}, whole-corpus training switches to batched mode
     * @param numTrainThreads number of worker threads in the training pool
     * @param numTopics number of rows of each per-document topic/term matrix
     * @param numTerms number of columns of each per-document topic/term matrix
     */
    public ModelTrainer(TopicModel initialReadModel, TopicModel initialWriteModel, int numTrainThreads,
            int numTopics, int numTerms) {
        this.readModel = initialReadModel;
        this.writeModel = initialWriteModel;
        this.numTrainThreads = numTrainThreads;
        this.numTopics = numTopics;
        this.numTerms = numTerms;
        isReadWrite = initialReadModel == initialWriteModel;
    }

    /**
     * WARNING: this constructor may not lead to good behavior.  What should be verified is that
     * the model updating process does not conflict with model reading.  It might work, but then
     * again, it might not!
     * @param model to be used for both reading (inference) and accumulating (learning)
     * @param numTrainThreads number of worker threads in the training pool
     * @param numTopics number of rows of each per-document topic/term matrix
     * @param numTerms number of columns of each per-document topic/term matrix
     */
    public ModelTrainer(TopicModel model, int numTrainThreads, int numTopics, int numTerms) {
        this(model, model, numTrainThreads, numTopics, numTerms);
    }

    /**
     * Returns the model currently used for inference. After {@link #stop()} this is the model
     * that received the most recent updates.
     */
    public TopicModel getReadModel() {
        return readModel;
    }

    /**
     * Creates a fresh fixed-size training pool.
     *
     * <p>The pool uses a bounded work queue of {@code numTrainThreads * 10} slots. All core
     * threads are started eagerly because {@link #train(Vector, Vector, boolean, int)} puts tasks
     * straight onto the work queue instead of going through {@code execute()}. Such tasks are
     * only ever picked up by threads that already exist.
     */
    public void start() {
        log.info("Starting training threadpool with " + numTrainThreads + " threads");
        workQueue = new ArrayBlockingQueue<Runnable>(numTrainThreads * 10);
        threadPool = new ThreadPoolExecutor(numTrainThreads, numTrainThreads, 0, TimeUnit.SECONDS, workQueue);
        threadPool.allowCoreThreadTimeOut(false);
        threadPool.prestartAllCoreThreads();
    }

    /** Trains on the whole corpus with a single inference pass per document. */
    public void train(VectorIterable matrix, VectorIterable docTopicCounts) {
        train(matrix, docTopicCounts, 1);
    }

    /** Computes perplexity over every document (equivalent to a test fraction of 0). */
    public double calculatePerplexity(VectorIterable matrix, VectorIterable docTopicCounts) {
        return calculatePerplexity(matrix, docTopicCounts, 0);
    }

    /**
     * Computes the token-normalized perplexity of a sample of documents against the read model.
     *
     * <p>Each sampled document first gets 10 synchronous inference passes (no model update).
     * Its perplexity is then summed, and the total is divided by the summed L1 norms of the
     * sampled documents.
     *
     * @param matrix document vectors; each slice's {@code index()} is used as the document id
     * @param docTopicCounts per-document topic vectors, aligned slice-for-slice with {@code matrix}
     * @param testFraction approximate fraction of documents to score. 0 scores every document.
     *        Otherwise documents whose id is a multiple of {@code round(1 / testFraction)} are
     *        scored, so values of 1 or more also score every document.
     * @return summed perplexity divided by summed document L1 norm; {@code NaN} if nothing was scored
     * @throws IllegalArgumentException if {@code testFraction} is negative or NaN
     */
    public double calculatePerplexity(VectorIterable matrix, VectorIterable docTopicCounts, double testFraction) {
        long testStride = perplexityTestStride(testFraction);
        Iterator<MatrixSlice> docIterator = matrix.iterator();
        Iterator<MatrixSlice> docTopicIterator = docTopicCounts.iterator();
        double perplexity = 0;
        double matrixNorm = 0;
        while (docIterator.hasNext() && docTopicIterator.hasNext()) {
            // Advance both iterators every time so documents and topic vectors stay aligned.
            MatrixSlice docSlice = docIterator.next();
            MatrixSlice topicSlice = docTopicIterator.next();
            if (docSlice.index() % testStride != 0) {
                continue;
            }
            Vector document = docSlice.vector();
            Vector topicDist = topicSlice.vector();
            trainSync(document, topicDist, false, 10);
            perplexity += readModel.perplexity(document, topicDist);
            matrixNorm += document.norm(1);
        }
        return perplexity / matrixNorm;
    }

    /**
     * Converts a test fraction into an integer document-id stride (at least 1).
     *
     * <p>An integer stride is needed because a floating-point remainder such as
     * {@code docId % (1 / 0.3)} is essentially never 0 except for id 0.
     */
    private static long perplexityTestStride(double testFraction) {
        if (testFraction == 0) {
            return 1;
        }
        if (!(testFraction > 0)) {
            throw new IllegalArgumentException("testFraction must be a non-negative number, got " + testFraction);
        }
        return Math.max(1L, Math.round(1 / testFraction));
    }

    /**
     * Trains on every (document, topic vector) pair, walking both iterables in lock step until
     * either runs out.
     *
     * <p>A fresh pool is started before training and {@link #stop()} is called afterwards, so
     * on normal return the read and write models have been swapped. If training throws, the pool
     * is shut down immediately, the models are not swapped, and the exception propagates.
     *
     * <p>The training mode depends on whether the read and write models are the same object:
     * <ul>
     *   <li>Same object: documents are gathered into batches of {@code numTrainThreads}. Each
     *       batch is trained via {@link #batchTrain}, including the final partial batch.</li>
     *   <li>Different objects: each document is queued individually via
     *       {@link #train(Vector, Vector, boolean, int)}.</li>
     * </ul>
     *
     * @param matrix document vectors
     * @param docTopicCounts per-document topic vectors, aligned slice-for-slice with {@code matrix}
     * @param numDocTopicIters inference passes per document
     */
    public void train(VectorIterable matrix, VectorIterable docTopicCounts, int numDocTopicIters) {
        start();
        boolean finished = false;
        try {
            Iterator<MatrixSlice> docIterator = matrix.iterator();
            Iterator<MatrixSlice> docTopicIterator = docTopicCounts.iterator();
            long startTime = System.nanoTime();
            int i = 0;
            double[] times = new double[100];
            Map<Vector, Vector> batch = Maps.newHashMap();
            int numTokensInBatch = 0;
            long batchStart = System.nanoTime();
            while (docIterator.hasNext() && docTopicIterator.hasNext()) {
                i++;
                Vector document = docIterator.next().vector();
                Vector topicDist = docTopicIterator.next().vector();
                if (isReadWrite) {
                    // Every document joins the batch; a full batch is trained and then emptied.
                    batch.put(document, topicDist);
                    if (log.isDebugEnabled()) {
                        numTokensInBatch += document.getNumNondefaultElements();
                    }
                    if (batch.size() >= numTrainThreads) {
                        batchStart = trainAndClearBatch(batch, numTokensInBatch, batchStart, numDocTopicIters);
                        numTokensInBatch = 0;
                    }
                } else {
                    long submitStart = System.nanoTime();
                    train(document, topicDist, true, numDocTopicIters);
                    if (log.isDebugEnabled()) {
                        // NOTE: train(Vector, ...) only enqueues the work. This therefore times the
                        // submission, including any wait for queue space, not the inference itself.
                        times[i % times.length] = (System.nanoTime() - submitStart)
                                / (1.0e6 * Math.max(1, document.getNumNondefaultElements()));
                        if (i % 100 == 0) {
                            long time = System.nanoTime() - startTime;
                            log.debug("trained " + i + " documents in " + (time / 1.0e6) + "ms");
                            if (i % 500 == 0) {
                                Arrays.sort(times);
                                log.debug("training took median " + times[times.length / 2] + "ms per token-instance");
                            }
                        }
                    }
                }
            }
            if (!batch.isEmpty()) {
                // The trailing partial batch would otherwise never be trained.
                trainAndClearBatch(batch, numTokensInBatch, batchStart, numDocTopicIters);
            }
            finished = true;
        } finally {
            if (!finished) {
                // Don't leave the prestarted worker threads running after a failure.
                threadPool.shutdownNow();
            }
        }
        stop();
    }

    /**
     * Trains one read/write batch with updates enabled. It then logs the batch timing at debug
     * level and empties the batch so it can be refilled.
     *
     * @return the {@link System#nanoTime()} at which this batch finished, which becomes the start
     *         time of the next batch
     */
    private long trainAndClearBatch(Map<Vector, Vector> batch, int numTokensInBatch, long batchStart,
            int numDocTopicIters) {
        batchTrain(batch, true, numDocTopicIters);
        long time = System.nanoTime();
        log.debug("trained {} docs with {} tokens, start time {}, end time {}",
                new Object[] { batch.size(), numTokensInBatch, batchStart, time });
        batch.clear();
        return time;
    }

    /**
     * Runs inference for every (document, topic vector) entry in parallel and waits for all of
     * them. If {@code update} is set, it then applies each document's topic/term matrix to the
     * write model, one document at a time on the calling thread.
     *
     * <p>Each task is run through {@code invokeAll}, i.e. via {@code TrainerRunnable.call()}.
     * That call also computes a perplexity which is discarded here.
     *
     * <p>If interrupted while waiting, the interrupt is logged and the whole batch is retried
     * with fresh per-document matrices. {@code invokeAll} cancels the interrupted attempt's
     * unfinished tasks. NOTE(review): tasks that were already running may still be working on
     * the same topic vectors as the retry; confirm this is acceptable.
     */
    public void batchTrain(Map<Vector, Vector> batch, boolean update, int numDocTopicsIters) {
        while (true) {
            try {
                List<TrainerRunnable> runnables = Lists.newArrayList();
                for (Map.Entry<Vector, Vector> entry : batch.entrySet()) {
                    runnables.add(new TrainerRunnable(readModel, null, entry.getKey(), entry.getValue(),
                            new SparseRowMatrix(numTopics, numTerms, true), numDocTopicsIters));
                }
                threadPool.invokeAll(runnables);
                if (update) {
                    for (TrainerRunnable runnable : runnables) {
                        writeModel.update(runnable.docTopicModel);
                    }
                }
                break;
            } catch (InterruptedException e) {
                log.warn("Interrupted during batch training, retrying!", e);
            }
        }
    }

    /**
     * Queues one document for asynchronous inference on the pool created by {@link #start()}.
     * If {@code update} is set, the task also applies the result to the write model.
     *
     * <p>The task is put directly on the bounded work queue, so this blocks while the queue is
     * full. It must therefore be called between {@link #start()} and {@link #stop()}. An
     * interrupt while waiting is logged and the submission retried.
     */
    public void train(Vector document, Vector docTopicCounts, boolean update, int numDocTopicIters) {
        while (true) {
            try {
                workQueue.put(new TrainerRunnable(readModel, update ? writeModel : null, document, docTopicCounts,
                        new SparseRowMatrix(numTopics, numTerms, true), numDocTopicIters));
                return;
            } catch (InterruptedException e) {
                log.warn("Interrupted waiting to submit document to work queue: " + document, e);
            }
        }
    }

    /**
     * Runs inference for one document on the calling thread. If {@code update} is set, the
     * result is also applied to the write model.
     */
    public void trainSync(Vector document, Vector docTopicCounts, boolean update, int numDocTopicIters) {
        new TrainerRunnable(readModel, update ? writeModel : null, document, docTopicCounts,
                new SparseRowMatrix(numTopics, numTerms, true), numDocTopicIters).run();
    }

    /**
     * Runs inference for one document on the calling thread, without updating the write model.
     * It then returns the read model's perplexity for that document.
     */
    public double calculatePerplexity(Vector document, Vector docTopicCounts, int numDocTopicIters) {
        TrainerRunnable runner = new TrainerRunnable(readModel, null, document, docTopicCounts,
                new SparseRowMatrix(numTopics, numTerms, true), numDocTopicIters);
        return runner.call();
    }

    /**
     * Finishes a training pass in three steps:
     * <ol>
     *   <li>Shuts the pool down and waits up to 60 seconds for queued work, logging a warning on
     *       timeout.</li>
     *   <li>Waits for the write model via {@code writeModel.awaitTermination()}.</li>
     *   <li>Swaps the read and write models, so the freshly updated model becomes the read model
     *       and the previous read model is {@code reset()} to serve as the next write model.</li>
     * </ol>
     *
     * <p>If interrupted, the swap is skipped, the error is logged, and the thread's interrupt
     * status is restored.
     */
    public void stop() {
        long startTime = System.nanoTime();
        log.info("Initiating stopping of training threadpool");
        try {
            threadPool.shutdown();
            if (!threadPool.awaitTermination(60, TimeUnit.SECONDS)) {
                log.warn("Threadpool timed out on await termination - jobs still running!");
            }
            long newTime = System.nanoTime();
            log.info("threadpool took: " + (newTime - startTime) / 1.0e6 + "ms");
            startTime = newTime;
            writeModel.awaitTermination();
            newTime = System.nanoTime();
            log.info("writeModel.awaitTermination() took " + (newTime - startTime) / 1.0e6 + "ms");
            TopicModel tmpModel = writeModel;
            writeModel = readModel;
            readModel = tmpModel;
            // NOTE(review): when readModel == writeModel (isReadWrite) the swap is a no-op, so this
            // reset() applies to the only model, the one just trained. Confirm this is intended.
            writeModel.reset();
        } catch (InterruptedException e) {
            log.error("Interrupted shutting down!", e);
            // Let the caller observe the interrupt instead of silently swallowing it.
            Thread.currentThread().interrupt();
        }
    }

    /** Persists the current read model to {@code outputPath} (passing {@code true} to {@code TopicModel.persist}). */
    public void persist(Path outputPath) throws IOException {
        readModel.persist(outputPath, true);
    }

    /**
     * One unit of training work for a single document.
     *
     * <p>The task runs {@code numDocTopicIters} passes of
     * {@code readModel.trainDocTopicModel(document, docTopics, docTopicModel)}. If a write model
     * was supplied, it then hands the accumulated {@code docTopicModel} to
     * {@code writeModel.update}.
     */
    private static class TrainerRunnable implements Runnable, Callable<Double> {
        private final TopicModel readModel;
        private final TopicModel writeModel;
        private final Vector document;
        private final Vector docTopics;
        private final Matrix docTopicModel;
        private final int numDocTopicIters;

        private TrainerRunnable(TopicModel readModel, TopicModel writeModel, Vector document, Vector docTopics,
                Matrix docTopicModel, int numDocTopicIters) {
            this.readModel = readModel;
            this.writeModel = writeModel;
            this.document = document;
            this.docTopics = docTopics;
            this.docTopicModel = docTopicModel;
            this.numDocTopicIters = numDocTopicIters;
        }

        @Override
        public void run() {
            for (int i = 0; i < numDocTopicIters; i++) {
                // synchronous read-only call:
                readModel.trainDocTopicModel(document, docTopics, docTopicModel);
            }
            if (writeModel != null) {
                // parallel call which is read-only on the docTopicModel, and write-only on the writeModel
                // this method does not return until all rows of the docTopicModel have been submitted
                // to write work queues
                writeModel.update(docTopicModel);
            }
        }

        /** Runs the task, then returns the read model's perplexity for this document. */
        @Override
        public Double call() {
            run();
            return readModel.perplexity(document, docTopics);
        }
    }
}