org.apache.mahout.clustering.lda.LDADriver.java Source code

Introduction

Here is the source code for org.apache.mahout.clustering.lda.LDADriver.java, the Hadoop driver that estimates an LDA model from a corpus of documents represented as sparse vectors of word counts. Each iteration writes a matrix of per-topic log probabilities to a new state directory under the output path.

Source

/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.clustering.lda;

import java.io.IOException;
import java.util.Random;

import org.apache.commons.cli2.CommandLine;
import org.apache.commons.cli2.Group;
import org.apache.commons.cli2.Option;
import org.apache.commons.cli2.OptionException;
import org.apache.commons.cli2.builder.ArgumentBuilder;
import org.apache.commons.cli2.builder.DefaultOptionBuilder;
import org.apache.commons.cli2.builder.GroupBuilder;
import org.apache.commons.cli2.commandline.Parser;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.DoubleWritable;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
import org.apache.mahout.common.CommandLineUtil;
import org.apache.mahout.common.HadoopUtil;
import org.apache.mahout.common.IntPairWritable;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.common.commandline.DefaultOptionCreator;
import org.apache.mahout.math.DenseMatrix;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Estimates an LDA model from a corpus of documents, which are SparseVectors of word counts. At each phase,
 * it outputs a matrix of log probabilities of each topic.
 */
public final class LDADriver {

    static final String STATE_IN_KEY = "org.apache.mahout.clustering.lda.stateIn";

    static final String NUM_TOPICS_KEY = "org.apache.mahout.clustering.lda.numTopics";

    static final String NUM_WORDS_KEY = "org.apache.mahout.clustering.lda.numWords";

    static final String TOPIC_SMOOTHING_KEY = "org.apache.mahout.clustering.lda.topicSmoothing";

    static final int LOG_LIKELIHOOD_KEY = -2;

    static final int TOPIC_SUM_KEY = -1;

    static final double OVERALL_CONVERGENCE = 1.0E-5;

    private static final Logger log = LoggerFactory.getLogger(LDADriver.class);

    private LDADriver() {
    }

    public static void main(String[] args) throws ClassNotFoundException, IOException, InterruptedException {
        Option inputOpt = DefaultOptionCreator.inputOption().create();
        Option outputOpt = DefaultOptionCreator.outputOption().create();
        Option overwriteOutput = DefaultOptionCreator.overwriteOption().create();
        Option topicsOpt = DefaultOptionCreator.numTopicsOption().create();
        Option wordsOpt = DefaultOptionCreator.numWordsOption().create();
        Option topicSmOpt = DefaultOptionCreator.topicSmoothingOption().create();
        Option maxIterOpt = DefaultOptionCreator.maxIterationsOption().withRequired(false).create();
        Option numReducOpt = DefaultOptionCreator.numReducersOption().create();
        Option helpOpt = DefaultOptionCreator.helpOption();

        Group group = new GroupBuilder().withName("Options").withOption(inputOpt).withOption(outputOpt)
                .withOption(topicsOpt).withOption(wordsOpt).withOption(topicSmOpt).withOption(maxIterOpt)
                .withOption(numReducOpt).withOption(overwriteOutput).withOption(helpOpt).create();
        try {
            Parser parser = new Parser();
            parser.setGroup(group);
            CommandLine cmdLine = parser.parse(args);

            if (cmdLine.hasOption(helpOpt)) {
                CommandLineUtil.printHelp(group);
                return;
            }
            Path input = new Path(cmdLine.getValue(inputOpt).toString());
            Path output = new Path(cmdLine.getValue(outputOpt).toString());
            if (cmdLine.hasOption(overwriteOutput)) {
                HadoopUtil.overwriteOutput(output);
            }
            int maxIterations = Integer.parseInt(cmdLine.getValue(maxIterOpt).toString());
            int numReduceTasks = Integer.parseInt(cmdLine.getValue(numReducOpt).toString());
            int numTopics = Integer.parseInt(cmdLine.getValue(topicsOpt).toString());
            int numWords = Integer.parseInt(cmdLine.getValue(wordsOpt).toString());
            double topicSmoothing = Double.parseDouble(cmdLine.getValue(topicSmOpt).toString());
            if (topicSmoothing < 1) {
                topicSmoothing = 50.0 / numTopics;
            }

            runJob(input, output, numTopics, numWords, topicSmoothing, maxIterations, numReduceTasks);

        } catch (OptionException e) {
            log.error("Exception", e);
            CommandLineUtil.printHelp(group);
        }
    }

    /**
     * Run the job using supplied arguments
     * 
     * @param input
     *          the directory pathname for input points
     * @param output
     *          the directory pathname for output state
     * @param numTopics
     *          the number of topics
     * @param numWords
     *          the number of words
     * @param topicSmoothing
     *          pseudocounts for each topic, typically small &lt; .5
     * @param maxIterations
     *          the maximum number of iterations
     * @param numReducers
     *          the number of Reducers desired
     * @throws IOException
     */
    public static void runJob(Path input, Path output, int numTopics, int numWords, double topicSmoothing,
            int maxIterations, int numReducers) throws IOException, InterruptedException, ClassNotFoundException {

        Path stateIn = new Path(output, "state-0");
        writeInitialState(stateIn, numTopics, numWords);
        double oldLL = Double.NEGATIVE_INFINITY;
        boolean converged = false;

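        // Keep iterating until the relative change in log likelihood falls below OVERALL_CONVERGENCE
        // (only checked after the third iteration); a maxIterations value below 1 means no iteration cap.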
        for (int iteration = 1; ((maxIterations < 1) || (iteration <= maxIterations)) && !converged; iteration++) {
            log.info("Iteration {}", iteration);
            // point the output to a new directory per iteration
            Path stateOut = new Path(output, "state-" + iteration);
            double ll = runIteration(input, stateIn, stateOut, numTopics, numWords, topicSmoothing, numReducers);
            double relChange = (oldLL - ll) / oldLL;

            // now point the input to the old output directory
            log.info("Iteration {} finished. Log Likelihood: {}", iteration, ll);
            log.info("(Old LL: {})", oldLL);
            log.info("(Rel Change: {})", relChange);

            converged = (iteration > 3) && (relChange < OVERALL_CONVERGENCE);
            stateIn = stateOut;
            oldLL = ll;
        }
    }

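    /**
     * Seeds the initial state directory with small random pseudo-counts (stored as logs) for every
     * (topic, word) pair, plus a per-topic total keyed by TOPIC_SUM_KEY, writing one part file per topic.
     */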
    private static void writeInitialState(Path statePath, int numTopics, int numWords) throws IOException {
        Configuration job = new Configuration();
        FileSystem fs = statePath.getFileSystem(job);

        DoubleWritable v = new DoubleWritable();

        Random random = RandomUtils.getRandom();

        for (int k = 0; k < numTopics; ++k) {
            Path path = new Path(statePath, "part-" + k);
            SequenceFile.Writer writer = new SequenceFile.Writer(fs, job, path, IntPairWritable.class,
                    DoubleWritable.class);

            double total = 0.0; // total number of pseudo counts we made
            for (int w = 0; w < numWords; ++w) {
                IntPairWritable kw = new IntPairWritable(k, w);
                // A small amount of random noise, minimized by having a floor.
                double pseudocount = random.nextDouble() + 1.0E-8;
                total += pseudocount;
                v.set(Math.log(pseudocount));
                writer.append(kw, v);
            }
            IntPairWritable kTsk = new IntPairWritable(k, TOPIC_SUM_KEY);
            v.set(Math.log(total));
            writer.append(kTsk, v);

            writer.close();
        }
    }

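    /**
     * Scans the part files of the given state directory for the LOG_LIKELIHOOD_KEY entry and returns
     * the log likelihood recorded there.
     */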
    private static double findLL(Path statePath, Configuration job) throws IOException {
        FileSystem fs = statePath.getFileSystem(job);

        double ll = 0.0;

        IntPairWritable key = new IntPairWritable();
        DoubleWritable value = new DoubleWritable();
        for (FileStatus status : fs.globStatus(new Path(statePath, "part-*"))) {
            Path path = status.getPath();
            SequenceFile.Reader reader = new SequenceFile.Reader(fs, path, job);
            while (reader.next(key, value)) {
                if (key.getFirst() == LOG_LIKELIHOOD_KEY) {
                    ll = value.get();
                    break;
                }
            }
            reader.close();
        }

        return ll;
    }

    /**
     * Run a single iteration of LDA inference using supplied arguments
     * 
     * @param input
     *          the directory pathname for input points
     * @param stateIn
     *          the directory pathname for input state
     * @param stateOut
     *          the directory pathname for output state
     * @param numTopics
     *          the number of topics
     * @param numWords
     *          the number of words
     * @param topicSmoothing
     *          pseudocounts for each topic
     * @param numReducers
     *          the number of Reducers desired
     * @return the log likelihood of the model after this iteration
     */
    public static double runIteration(Path input, Path stateIn, Path stateOut, int numTopics, int numWords,
            double topicSmoothing, int numReducers)
            throws IOException, InterruptedException, ClassNotFoundException {
        Configuration conf = new Configuration();
        conf.set(STATE_IN_KEY, stateIn.toString());
        conf.set(NUM_TOPICS_KEY, Integer.toString(numTopics));
        conf.set(NUM_WORDS_KEY, Integer.toString(numWords));
        conf.set(TOPIC_SMOOTHING_KEY, Double.toString(topicSmoothing));

        Job job = new Job(conf);

        job.setOutputKeyClass(IntPairWritable.class);
        job.setOutputValueClass(DoubleWritable.class);
        FileInputFormat.addInputPaths(job, input.toString());
        FileOutputFormat.setOutputPath(job, stateOut);

        job.setMapperClass(LDAMapper.class);
        job.setReducerClass(LDAReducer.class);
        job.setCombinerClass(LDAReducer.class);
        job.setNumReduceTasks(numReducers);
        job.setOutputFormatClass(SequenceFileOutputFormat.class);
        job.setInputFormatClass(SequenceFileInputFormat.class);
        job.setJarByClass(LDADriver.class);

        job.waitForCompletion(true);
        return findLL(stateOut, conf);
    }

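    /**
     * Reads a state directory written by writeInitialState or by a previous iteration back into an
     * LDAState: the per-(topic, word) log probabilities, the per-topic log totals, and the log likelihood.
     */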
    static LDAState createState(Configuration job) throws IOException {
        String statePath = job.get(STATE_IN_KEY);
        int numTopics = Integer.parseInt(job.get(NUM_TOPICS_KEY));
        int numWords = Integer.parseInt(job.get(NUM_WORDS_KEY));
        double topicSmoothing = Double.parseDouble(job.get(TOPIC_SMOOTHING_KEY));

        Path dir = new Path(statePath);
        FileSystem fs = dir.getFileSystem(job);

        DenseMatrix pWgT = new DenseMatrix(numTopics, numWords);
        double[] logTotals = new double[numTopics];
        double ll = 0.0;

        IntPairWritable key = new IntPairWritable();
        DoubleWritable value = new DoubleWritable();
        for (FileStatus status : fs.globStatus(new Path(dir, "part-*"))) {
            Path path = status.getPath();
            SequenceFile.Reader reader = new SequenceFile.Reader(fs, path, job);
            while (reader.next(key, value)) {
                int topic = key.getFirst();
                int word = key.getSecond();
                if (word == TOPIC_SUM_KEY) {
                    logTotals[topic] = value.get();
                    if (Double.isInfinite(value.get())) {
                        throw new IllegalArgumentException();
                    }
                } else if (topic == LOG_LIKELIHOOD_KEY) {
                    ll = value.get();
                } else {
                    if (!((topic >= 0) && (word >= 0))) {
                        throw new IllegalArgumentException(topic + " " + word);
                    }
                    if (pWgT.getQuick(topic, word) != 0.0) {
                        throw new IllegalArgumentException();
                    }
                    pWgT.setQuick(topic, word, value.get());
                    if (Double.isInfinite(pWgT.getQuick(topic, word))) {
                        throw new IllegalArgumentException();
                    }
                }
            }
            reader.close();
        }

        return new LDAState(numTopics, numWords, topicSmoothing, pWgT, logTotals, ll);
    }
}
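
Example usage

The listing above only defines the driver. The sketch below shows one way runJob might be called programmatically, assuming the input directory already contains SequenceFiles of sparse term-frequency vectors and that mahout-core and the Hadoop libraries are on the classpath. The paths, topic count, and vocabulary size are hypothetical placeholders, not values taken from the original file.

import org.apache.hadoop.fs.Path;
import org.apache.mahout.clustering.lda.LDADriver;

public class RunLDAExample {

    public static void main(String[] args) throws Exception {
        Path input = new Path("examples/tf-vectors");   // hypothetical: directory of word-count vectors
        Path output = new Path("examples/lda-output");  // hypothetical: where the state-N directories go
        int numTopics = 20;                             // hypothetical model size
        int numWords = 10000;                           // must cover the vocabulary of the input vectors
        double topicSmoothing = 50.0 / numTopics;       // same default the command line falls back to
        int maxIterations = 40;                         // a value below 1 would mean "run until convergence"
        int numReducers = 1;

        // Each iteration writes its state to <output>/state-N; the final state directory holds the
        // (topic, word) -> log-probability entries that createState() reads back.
        LDADriver.runJob(input, output, numTopics, numWords, topicSmoothing, maxIterations, numReducers);
    }
}

The same job can also be started from the command line through main(), whose options (input, output, number of topics, number of words, topic smoothing, maximum iterations, number of reducers, overwrite, and help) are built with DefaultOptionCreator in the source above.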