Java tutorial

The listing below is Apache Mahout's org.apache.mahout.clustering.lda.LDADriver, which estimates an LDA topic model over a corpus of word-count vectors by chaining Hadoop MapReduce jobs, one per iteration.
/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.clustering.lda;

import java.io.IOException;
import java.util.Random;

import org.apache.commons.cli2.CommandLine;
import org.apache.commons.cli2.Group;
import org.apache.commons.cli2.Option;
import org.apache.commons.cli2.OptionException;
import org.apache.commons.cli2.builder.ArgumentBuilder;
import org.apache.commons.cli2.builder.DefaultOptionBuilder;
import org.apache.commons.cli2.builder.GroupBuilder;
import org.apache.commons.cli2.commandline.Parser;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.DoubleWritable;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
import org.apache.mahout.common.CommandLineUtil;
import org.apache.mahout.common.HadoopUtil;
import org.apache.mahout.common.IntPairWritable;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.common.commandline.DefaultOptionCreator;
import org.apache.mahout.math.DenseMatrix;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Estimates an LDA model from a corpus of documents, which are SparseVectors of word counts. At each phase,
 * it outputs a matrix of log probabilities of each topic.
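 * <p>
 * The driver seeds {@code <output>/state-0} with a small random initial state, then runs one
 * MapReduce pass per iteration over {@link LDAMapper} and {@link LDAReducer}. Each pass reads the
 * previous state directory and writes {@code state-<n>}; iteration stops when {@code maxIterations}
 * is reached or the relative change in log likelihood drops below {@code OVERALL_CONVERGENCE}.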
 */
public final class LDADriver {

  static final String STATE_IN_KEY = "org.apache.mahout.clustering.lda.stateIn";
  static final String NUM_TOPICS_KEY = "org.apache.mahout.clustering.lda.numTopics";
  static final String NUM_WORDS_KEY = "org.apache.mahout.clustering.lda.numWords";
  static final String TOPIC_SMOOTHING_KEY = "org.apache.mahout.clustering.lda.topicSmoothing";

  static final int LOG_LIKELIHOOD_KEY = -2;
  static final int TOPIC_SUM_KEY = -1;

  static final double OVERALL_CONVERGENCE = 1.0E-5;

  private static final Logger log = LoggerFactory.getLogger(LDADriver.class);

  private LDADriver() {
  }

  public static void main(String[] args) throws ClassNotFoundException, IOException, InterruptedException {
    Option inputOpt = DefaultOptionCreator.inputOption().create();
    Option outputOpt = DefaultOptionCreator.outputOption().create();
    Option overwriteOutput = DefaultOptionCreator.overwriteOption().create();
    Option topicsOpt = DefaultOptionCreator.numTopicsOption().create();
    Option wordsOpt = DefaultOptionCreator.numWordsOption().create();
    Option topicSmOpt = DefaultOptionCreator.topicSmoothingOption().create();
    Option maxIterOpt = DefaultOptionCreator.maxIterationsOption().withRequired(false).create();
    Option numReducOpt = DefaultOptionCreator.numReducersOption().create();
    Option helpOpt = DefaultOptionCreator.helpOption();

    Group group = new GroupBuilder().withName("Options").withOption(inputOpt).withOption(outputOpt)
        .withOption(topicsOpt).withOption(wordsOpt).withOption(topicSmOpt).withOption(maxIterOpt)
        .withOption(numReducOpt).withOption(overwriteOutput).withOption(helpOpt).create();

    try {
      Parser parser = new Parser();
      parser.setGroup(group);
      CommandLine cmdLine = parser.parse(args);
      if (cmdLine.hasOption(helpOpt)) {
        CommandLineUtil.printHelp(group);
        return;
      }
      Path input = new Path(cmdLine.getValue(inputOpt).toString());
      Path output = new Path(cmdLine.getValue(outputOpt).toString());
      if (cmdLine.hasOption(overwriteOutput)) {
        HadoopUtil.overwriteOutput(output);
      }
      int maxIterations = Integer.parseInt(cmdLine.getValue(maxIterOpt).toString());
      int numReduceTasks = Integer.parseInt(cmdLine.getValue(numReducOpt).toString());
      int numTopics = Integer.parseInt(cmdLine.getValue(topicsOpt).toString());
      int numWords = Integer.parseInt(cmdLine.getValue(wordsOpt).toString());
      double topicSmoothing = Double.parseDouble(cmdLine.getValue(topicSmOpt).toString());
      if (topicSmoothing < 1) {
        topicSmoothing = 50.0 / numTopics;
      }
      runJob(input, output, numTopics, numWords, topicSmoothing, maxIterations, numReduceTasks);
    } catch (OptionException e) {
      log.error("Exception", e);
      CommandLineUtil.printHelp(group);
    }
  }

  /**
   * Run the job using supplied arguments
   *
   * @param input
   *          the directory pathname for input points
   * @param output
   *          the directory pathname for output points
   * @param numTopics
   *          the number of topics
   * @param numWords
   *          the number of words
   * @param topicSmoothing
   *          pseudocounts for each topic, typically small (&lt; 0.5)
   * @param maxIterations
   *          the maximum number of iterations
   * @param numReducers
   *          the number of Reducers desired
   * @throws IOException
   */
  public static void runJob(Path input, Path output, int numTopics, int numWords, double topicSmoothing,
      int maxIterations, int numReducers) throws IOException, InterruptedException, ClassNotFoundException {
    Path stateIn = new Path(output, "state-0");
    writeInitialState(stateIn, numTopics, numWords);
    double oldLL = Double.NEGATIVE_INFINITY;
    boolean converged = false;
    for (int iteration = 1; ((maxIterations < 1) || (iteration <= maxIterations)) && !converged; iteration++) {
      log.info("Iteration {}", iteration);
      // point the output to a new directory per iteration
      Path stateOut = new Path(output, "state-" + iteration);
      double ll = runIteration(input, stateIn, stateOut, numTopics, numWords, topicSmoothing, numReducers);
      double relChange = (oldLL - ll) / oldLL;
      log.info("Iteration {} finished. Log Likelihood: {}", iteration, ll);
      log.info("(Old LL: {})", oldLL);
      log.info("(Rel Change: {})", relChange);
      converged = (iteration > 3) && (relChange < OVERALL_CONVERGENCE);
      // now point the input to the old output directory
      stateIn = stateOut;
      oldLL = ll;
    }
  }

  private static void writeInitialState(Path statePath, int numTopics, int numWords) throws IOException {
    Configuration job = new Configuration();
    FileSystem fs = statePath.getFileSystem(job);
    DoubleWritable v = new DoubleWritable();
    Random random = RandomUtils.getRandom();
    for (int k = 0; k < numTopics; ++k) {
      Path path = new Path(statePath, "part-" + k);
      SequenceFile.Writer writer = new SequenceFile.Writer(fs, job, path,
          IntPairWritable.class, DoubleWritable.class);
      double total = 0.0; // total number of pseudo counts we made
      for (int w = 0; w < numWords; ++w) {
        IntPairWritable kw = new IntPairWritable(k, w);
        // A small amount of random noise, minimized by having a floor.
        double pseudocount = random.nextDouble() + 1.0E-8;
        total += pseudocount;
        v.set(Math.log(pseudocount));
        writer.append(kw, v);
      }
      IntPairWritable kTsk = new IntPairWritable(k, TOPIC_SUM_KEY);
      v.set(Math.log(total));
      writer.append(kTsk, v);
      writer.close();
    }
  }

  private static double findLL(Path statePath, Configuration job) throws IOException {
    FileSystem fs = statePath.getFileSystem(job);
    double ll = 0.0;
    IntPairWritable key = new IntPairWritable();
    DoubleWritable value = new DoubleWritable();
    for (FileStatus status : fs.globStatus(new Path(statePath, "part-*"))) {
      Path path = status.getPath();
      SequenceFile.Reader reader = new SequenceFile.Reader(fs, path, job);
      while (reader.next(key, value)) {
        if (key.getFirst() == LOG_LIKELIHOOD_KEY) {
          ll = value.get();
          break;
        }
      }
      reader.close();
    }
    return ll;
  }

  /**
   * Run a single iteration of the job using supplied arguments
   *
   * @param input
   *          the directory pathname for input points
   * @param stateIn
   *          the directory pathname for input state
   * @param stateOut
   *          the directory pathname for output state
   * @param numTopics
   *          the number of topics
   * @param numWords
   *          the number of words
   * @param topicSmoothing
   *          pseudocounts for each topic
   * @param numReducers
   *          the number of Reducers desired
   */
  public static double runIteration(Path input, Path stateIn, Path stateOut, int numTopics, int numWords,
      double topicSmoothing, int numReducers) throws IOException, InterruptedException, ClassNotFoundException {
    Configuration conf = new Configuration();
    conf.set(STATE_IN_KEY, stateIn.toString());
    conf.set(NUM_TOPICS_KEY, Integer.toString(numTopics));
    conf.set(NUM_WORDS_KEY, Integer.toString(numWords));
    conf.set(TOPIC_SMOOTHING_KEY, Double.toString(topicSmoothing));

    Job job = new Job(conf);
    job.setOutputKeyClass(IntPairWritable.class);
    job.setOutputValueClass(DoubleWritable.class);
    FileInputFormat.addInputPaths(job, input.toString());
    FileOutputFormat.setOutputPath(job, stateOut);

    job.setMapperClass(LDAMapper.class);
    job.setReducerClass(LDAReducer.class);
    job.setCombinerClass(LDAReducer.class);
    job.setNumReduceTasks(numReducers);
    job.setOutputFormatClass(SequenceFileOutputFormat.class);
    job.setInputFormatClass(SequenceFileInputFormat.class);
    job.setJarByClass(LDADriver.class);

    job.waitForCompletion(true);
    return findLL(stateOut, conf);
  }

  static LDAState createState(Configuration job) throws IOException {
    String statePath = job.get(STATE_IN_KEY);
    int numTopics = Integer.parseInt(job.get(NUM_TOPICS_KEY));
    int numWords = Integer.parseInt(job.get(NUM_WORDS_KEY));
    double topicSmoothing = Double.parseDouble(job.get(TOPIC_SMOOTHING_KEY));

    Path dir = new Path(statePath);
    FileSystem fs = dir.getFileSystem(job);

    DenseMatrix pWgT = new DenseMatrix(numTopics, numWords);
    double[] logTotals = new double[numTopics];
    double ll = 0.0;

    IntPairWritable key = new IntPairWritable();
    DoubleWritable value = new DoubleWritable();
    for (FileStatus status : fs.globStatus(new Path(dir, "part-*"))) {
      Path path = status.getPath();
      SequenceFile.Reader reader = new SequenceFile.Reader(fs, path, job);
      while (reader.next(key, value)) {
        int topic = key.getFirst();
        int word = key.getSecond();
        if (word == TOPIC_SUM_KEY) {
          logTotals[topic] = value.get();
          if (Double.isInfinite(value.get())) {
            throw new IllegalArgumentException();
          }
        } else if (topic == LOG_LIKELIHOOD_KEY) {
          ll = value.get();
        } else {
          if (!((topic >= 0) && (word >= 0))) {
            throw new IllegalArgumentException(topic + " " + word);
          }
          if (pWgT.getQuick(topic, word) != 0.0) {
            throw new IllegalArgumentException();
          }
          pWgT.setQuick(topic, word, value.get());
          if (Double.isInfinite(pWgT.getQuick(topic, word))) {
            throw new IllegalArgumentException();
          }
        }
      }
      reader.close();
    }
    return new LDAState(numTopics, numWords, topicSmoothing, pWgT, logTotals, ll);
  }

}
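For a sense of how the pieces fit together, here is a minimal sketch of driving the same estimation programmatically instead of through main(). The paths, topic count, vocabulary size, and iteration count are hypothetical placeholders; the input directory is assumed to already hold SequenceFiles of the sparse word-count vectors described in the class javadoc, and Hadoop plus the Mahout jars are assumed to be on the classpath.

import org.apache.hadoop.fs.Path;
import org.apache.mahout.clustering.lda.LDADriver;

public class LDAEstimationExample {
  public static void main(String[] args) throws Exception {
    Path input = new Path("/tmp/lda/corpus");   // hypothetical: SequenceFile(s) of sparse word-count vectors
    Path output = new Path("/tmp/lda/model");   // state-0, state-1, ... directories are written here
    int numTopics = 20;                         // number of topics to estimate
    int numWords = 50000;                       // vocabulary size
    double topicSmoothing = 50.0 / numTopics;   // same fallback main() uses when the option value is < 1
    int maxIterations = 10;                     // a value < 1 means "run until converged"
    int numReducers = 2;

    // Seeds state-0, then runs one MapReduce iteration at a time until converged or maxIterations is hit.
    LDADriver.runJob(input, output, numTopics, numWords, topicSmoothing, maxIterations, numReducers);
  }
}

Each state-<n> directory holds SequenceFiles of IntPairWritable/DoubleWritable pairs: a log value for every (topic, word) pair, a per-topic total keyed by TOPIC_SUM_KEY, and the iteration's log likelihood keyed by LOG_LIKELIHOOD_KEY, which is what findLL reads back after each pass.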