org.apache.mahout.classifier.sequencelearning.hmm.hadoop.BaumWelchDriver.java Source code

Introduction

Here is the source code for org.apache.mahout.classifier.sequencelearning.hmm.hadoop.BaumWelchDriver.java, the driver that launches iterative Baum-Welch MapReduce training of a Hidden Markov Model in Apache Mahout.
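
Before the listing, here is a minimal usage sketch that launches the driver through Hadoop's ToolRunner with the command-line flags defined in run() below. The wrapper class name, all paths, the numeric values, and the scaling method name are hypothetical placeholders chosen for illustration, not values prescribed by Mahout.

// Hedged usage sketch: every path and parameter value below is a hypothetical placeholder.
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.util.ToolRunner;
import org.apache.mahout.classifier.sequencelearning.hmm.hadoop.BaumWelchDriver;

public class BaumWelchDriverLauncher {
    public static void main(String[] args) throws Exception {
        String[] driverArgs = {
                "--input", "/tmp/hmm/observations",       // SequenceFile of training observation sequences
                "--output", "/tmp/hmm/output",            // directory for the per-iteration models
                "--model", "/tmp/hmm/model",              // initial HmmModel location
                "--hiddenStateToIDMap", "/tmp/hmm/hmap",  // hidden-state-to-ID map
                "--emittedStateToIDMap", "/tmp/hmm/smap", // emitted-state-to-ID map
                "--buildRandom", "true",                  // generate a random initial model first
                "--Scaling", "logscaling",                // scaling method name (assumed value)
                "--nrOfHiddenStates", "3",
                "--nrOfObservedStates", "10",
                "--epsilon", "0.001",                     // convergence threshold
                "--maxIterations", "10"
        };
        ToolRunner.run(new Configuration(), new BaumWelchDriver(), driverArgs);
    }
}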

Source

/* Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.classifier.sequencelearning.hmm.hadoop;

import java.io.IOException;

import org.apache.commons.cli2.CommandLine;
import org.apache.commons.cli2.Group;
import org.apache.commons.cli2.Option;
import org.apache.commons.cli2.OptionException;
import org.apache.commons.cli2.builder.ArgumentBuilder;
import org.apache.commons.cli2.builder.DefaultOptionBuilder;
import org.apache.commons.cli2.builder.GroupBuilder;
import org.apache.commons.cli2.commandline.Parser;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.MapWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
import org.apache.hadoop.util.ToolRunner;
import org.apache.mahout.classifier.sequencelearning.hmm.HmmModel;
import org.apache.mahout.common.AbstractJob;
import org.apache.mahout.common.CommandLineUtil;
import org.apache.mahout.common.HadoopUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Main class for launching iterative Baum-Welch MapReduce training.
 * Baum-Welch, like k-means, is an Expectation-Maximization (EM) algorithm that estimates the
 * model's parameters according to the maximum-likelihood criterion.
 * <p/>
 * As shown by Andrew Ng et al., EM algorithms fit nicely into the MapReduce framework.
 * <p/>
 * The training work is split evenly between the mappers and reducers, with combiners reducing the network traffic.
 */

public class BaumWelchDriver extends AbstractJob {

    private static final Logger log = LoggerFactory.getLogger(BaumWelchDriver.class);

    public static void main(String[] args) throws Exception {
        ToolRunner.run(new Configuration(), new BaumWelchDriver(), args);
    }

    @Override
    public int run(String[] args) throws Exception {

        DefaultOptionBuilder optionBuilder = new DefaultOptionBuilder();
        ArgumentBuilder argumentBuilder = new ArgumentBuilder();

        Option inputOption = optionBuilder.withLongName("input")
                .withDescription("Sequence file containing VectorWritables as training sequence").withShortName("i")
                .withArgument(argumentBuilder.withMaximum(1).withMinimum(1).withName("path").create())
                .withRequired(true).create();

        Option outputOption = optionBuilder.withLongName("output")
                .withDescription("Output path to store the trained model encoded as Sequence Files")
                .withShortName("o")
                .withArgument(argumentBuilder.withMaximum(1).withMinimum(1).withName("path").create())
                .withRequired(true).create();

        Option modelOption = optionBuilder.withLongName("model")
                .withDescription("Initial HmmModel encoded as a Sequence File. "
                        + "Will be constructed with a random distribution if the 'buildRandom' option is set to true.")
                .withShortName("im")
                .withArgument(argumentBuilder.withMaximum(1).withMinimum(1).withName("path").create())
                .withRequired(false).create();

        Option hiddenStateMapPath = optionBuilder.withLongName("hiddenStateToIDMap")
                .withDescription("Hidden states to ID map path.").withShortName("hmap")
                .withArgument(argumentBuilder.withMaximum(1).withMinimum(1).withName("path").create())
                .withRequired(true).create();

        Option emitStateMapPath = optionBuilder.withLongName("emittedStateToIDMap")
                .withDescription("Emitted states to ID map path.").withShortName("smap")
                .withArgument(argumentBuilder.withMaximum(1).withMinimum(1).withName("path").create())
                .withRequired(true).create();

        Option randomOption = optionBuilder.withLongName("buildRandom")
                .withDescription(
                        "Optional argument to generate a random initial HmmModel and store it in 'model' directory")
                .withShortName("r")
                .withArgument(argumentBuilder.withMaximum(1).withMinimum(1).withName("boolean").create())
                .withRequired(false).create();

        Option scalingOption = optionBuilder.withLongName("Scaling")
                .withDescription("The scaling method to use during training").withShortName("l")
                .withArgument(argumentBuilder.withMaximum(1).withMinimum(1).withName("string").create())
                .withRequired(true).create();

        Option stateNumberOption = optionBuilder.withLongName("nrOfHiddenStates")
                .withDescription("Number of hidden states").withShortName("nh")
                .withArgument(argumentBuilder.withMaximum(1).withMinimum(1).withName("number").create())
                .withRequired(true).create();

        Option observedStateNumberOption = optionBuilder.withLongName("nrOfObservedStates")
                .withDescription("Number of observed states").withShortName("no")
                .withArgument(argumentBuilder.withMaximum(1).withMinimum(1).withName("number").create())
                .withRequired(true).create();

        Option epsilonOption = optionBuilder.withLongName("epsilon").withDescription("Convergence threshold")
                .withShortName("e")
                .withArgument(argumentBuilder.withMaximum(1).withMinimum(1).withName("number").create())
                .withRequired(true).create();

        Option iterationsOption = optionBuilder.withLongName("maxIterations")
                .withDescription("Maximum iterations number").withShortName("m")
                .withArgument(argumentBuilder.withMaximum(1).withMinimum(1).withName("number").create())
                .withRequired(true).create();

        Group optionGroup = new GroupBuilder().withOption(inputOption).withOption(outputOption)
                .withOption(modelOption).withOption(hiddenStateMapPath).withOption(emitStateMapPath)
                .withOption(randomOption).withOption(scalingOption).withOption(stateNumberOption)
                .withOption(observedStateNumberOption).withOption(epsilonOption).withOption(iterationsOption)
                .withName("Options").create();

        try {
            Parser parser = new Parser();
            parser.setGroup(optionGroup);
            CommandLine commandLine = parser.parse(args);

            String input = (String) commandLine.getValue(inputOption);
            String output = (String) commandLine.getValue(outputOption);
            String modelIn = (String) commandLine.getValue(modelOption);
            String hiddenStateToIdMap = (String) commandLine.getValue(hiddenStateMapPath);
            String emittedStateToIdMap = (String) commandLine.getValue(emitStateMapPath);

            boolean buildRandom = commandLine.hasOption(randomOption);
            String scaling = (String) commandLine.getValue(scalingOption);

            int numHidden = Integer.parseInt((String) commandLine.getValue(stateNumberOption));
            int numObserved = Integer.parseInt((String) commandLine.getValue(observedStateNumberOption));

            double convergenceDelta = Double.parseDouble((String) commandLine.getValue(epsilonOption));
            int maxIterations = Integer.parseInt((String) commandLine.getValue(iterationsOption));

            if (getConf() == null) {
                setConf(new Configuration());
            }
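            // When requested, seed training by writing a randomly initialized HmmModel to the model path.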
            if (buildRandom) {
                BaumWelchUtils.buildRandomModel(numHidden, numObserved, new Path(modelIn), getConf());
            }
            run(getConf(), new Path(input), new Path(modelIn), new Path(output), new Path(hiddenStateToIdMap),
                    new Path(emittedStateToIdMap), numHidden, numObserved, convergenceDelta, scaling,
                    maxIterations);
        } catch (OptionException e) {
            CommandLineUtil.printHelp(optionGroup);
        }

        return 0;

    }

    /**
     * Run the Baum-Welch Map Reduce algorithm using the supplied arguments
     *
     * @param conf                the Configuration to use
     * @param input               the Path to the directory containing input observed sequences
     * @param modelIn             the Path to the HmmModel stored as a SequenceFile
     * @param output              the Path to the output directory
     * @param hiddenStateToIdMap  the Path to the map of hidden states to ids
     * @param emittedStateToIdMap the Path to the map of emitted states to ids
     * @param numHidden           the number of Hidden states
     * @param numObserved         the number of Observed states
     * @param convergenceDelta    the convergence delta value
     * @param scaling             the name of the scaling method to use
     * @param maxIterations       the maximum number of iterations
     * @throws IOException
     * @throws InterruptedException
     * @throws ClassNotFoundException
     */
    public static void run(Configuration conf, Path input, Path modelIn, Path output, Path hiddenStateToIdMap,
            Path emittedStateToIdMap, int numHidden, int numObserved, double convergenceDelta, String scaling,
            int maxIterations) throws IOException, InterruptedException, ClassNotFoundException {

        // Iterate until the model converges or maxIterations is reached
        String delta = Double.toString(convergenceDelta);
        if (log.isInfoEnabled()) {
            log.info("Input: {} Model In: {} Out: {}", new Object[] { input, modelIn, output });
            log.info("convergence: {} max Iterations: {}", new Object[] { convergenceDelta, maxIterations });
        }
        Path modelOut = runBaumWelchMR(conf, input, modelIn, output, hiddenStateToIdMap, emittedStateToIdMap,
                numHidden, numObserved, delta, scaling, maxIterations);
        log.info("Baum-Welch training finished. Final model written to: {}", modelOut);
    }

    /**
     * Run the Baum-Welch Map Reduce algorithm using the supplied arguments
     *
     * @param conf                the Configuration to use
     * @param input               the Path to the directory containing input observed sequences
     * @param modelIn             the Path to the HmmModel stored as a SequenceFile
     * @param output              the Path to the output directory
     * @param hiddenStateToIdMap  the Path to the map of hidden states to ids
     * @param emittedStateToIdMap the Path to the map of emitted states to ids
     * @param numHidden           the number of Hidden states
     * @param numObserved         the number of Observed states
     * @param delta               the convergence delta value
     * @param scaling             the name of the scaling method to use
     * @param maxIterations       the maximum number of iterations
     * @return the path to the output model directory
     * @throws IOException
     * @throws InterruptedException
     * @throws ClassNotFoundException
     */
    public static Path runBaumWelchMR(Configuration conf, Path input, Path modelIn, Path output,
            Path hiddenStateToIdMap, Path emittedStateToIdMap, int numHidden, int numObserved, String delta,
            String scaling, int maxIterations) throws IOException, InterruptedException, ClassNotFoundException {
        boolean converged = false;
        int iteration = 1;
        while (!converged && iteration <= maxIterations) {
            log.info("Baum-Welch MR Iteration " + iteration);
            // point the output to a new directory per iteration
            Path modelOut = new Path(output, "model-" + iteration);
            converged = runIteration(conf, input, modelIn, modelOut, hiddenStateToIdMap, emittedStateToIdMap,
                    numHidden, numObserved, scaling, delta);
            modelIn = modelOut;
            iteration++;
        }
        return new Path(output, "model-" + (iteration - 1));
    }

    /**
     * Run one iteration of the Baum-Welch Map Reduce algorithm using the supplied arguments
     *
     * @param conf                the Configuration to use
     * @param input               the Path to the directory containing input
     * @param modelIn             the Path to the HmmModel
     * @param modelOut            the Path to the output directory
     * @param hiddenStateToIdMap  the Path to the map of hidden states to ids
     * @param emittedStateToIdMap the Path to the map of emitted states to ids
     * @param numHidden           the number of Hidden states
     * @param numObserved         the number of Observed states
     * @param scaling             name of the scaling method
     * @param delta               the convergence delta value
     * @return true if the new model has converged with respect to the previous one, false otherwise
     */

    private static boolean runIteration(Configuration conf, Path input, Path modelIn, Path modelOut,
            Path hiddenStateToIdMap, Path emittedStateToIdMap, int numHidden, int numObserved, String scaling,
            String delta) throws IOException, InterruptedException, ClassNotFoundException {

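        // Pass the per-iteration parameters to the mapper, combiner and reducer through the job Configuration.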
        conf.set(BaumWelchConfigKeys.EMITTED_STATES_MAP_PATH, emittedStateToIdMap.toString());
        conf.set(BaumWelchConfigKeys.HIDDEN_STATES_MAP_PATH, hiddenStateToIdMap.toString());
        conf.set(BaumWelchConfigKeys.SCALING_OPTION_KEY, scaling);
        conf.set(BaumWelchConfigKeys.MODEL_PATH_KEY, modelIn.toString());
        conf.set(BaumWelchConfigKeys.NUMBER_OF_HIDDEN_STATES_KEY, Integer.toString(numHidden));
        conf.set(BaumWelchConfigKeys.NUMBER_OF_EMITTED_STATES_KEY, Integer.toString(numObserved));
        conf.set(BaumWelchConfigKeys.MODEL_CONVERGENCE_KEY, delta);

        Job job = new Job(conf, "Baum-Welch Driver running runIteration over modelIn: "
                + conf.get(BaumWelchConfigKeys.MODEL_PATH_KEY));
        job.setMapOutputKeyClass(Text.class);
        job.setMapOutputValueClass(MapWritable.class);
        job.setOutputKeyClass(Text.class);
        job.setOutputValueClass(MapWritable.class);

        job.setInputFormatClass(SequenceFileInputFormat.class);
        job.setOutputFormatClass(SequenceFileOutputFormat.class);
        job.setMapperClass(BaumWelchMapper.class);
        job.setCombinerClass(BaumWelchCombiner.class);
        job.setReducerClass(BaumWelchReducer.class);

        FileInputFormat.addInputPath(job, input);
        FileOutputFormat.setOutputPath(job, modelOut);

        job.setJarByClass(BaumWelchDriver.class);
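        // Remove any stale output for this iteration before submitting the job.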
        HadoopUtil.delete(conf, modelOut);
        if (!job.waitForCompletion(true)) {
            throw new InterruptedException("Baum-Welch Iteration failed processing " + modelIn);
        }

        return isConverged(modelIn, modelOut, numHidden, numObserved, conf);
    }

    /**
     * Check for convergence of two Hmm models
     *
     * @param modelIn     the Path to the model from the previous iteration
     * @param modelOut    the Path to the model from the current iteration
     * @param numHidden   the number of Hidden states
     * @param numObserved the number of Observed states
     * @param conf        the Configuration to use
     * @return true if converged, false otherwise
     */

    private static boolean isConverged(Path modelIn, Path modelOut, int numHidden, int numObserved,
            Configuration conf) throws IOException {

        log.info("-----------Checking Convergence----------");
        HmmModel previousModel = BaumWelchUtils.createHmmModel(numHidden, numObserved, modelIn, conf);
        if (previousModel == null) {
            throw new IllegalStateException("HmmModel from previous iteration is empty!");
        }
        HmmModel newModel = BaumWelchUtils.createHmmModel(numHidden, numObserved, modelOut, conf);
        if (newModel == null) {
            throw new IllegalStateException("HmmModel from current iteration is empty!");
        }

        return BaumWelchUtils.checkConvergence(previousModel, newModel,
                Double.parseDouble(conf.get(BaumWelchConfigKeys.MODEL_CONVERGENCE_KEY)));

    }

}
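
For programmatic use that bypasses command-line parsing, the static run(...) method can be called directly. The sketch below is illustrative only; the wrapper class name, the paths, and every parameter value are assumptions rather than requirements of the API.

// Hedged sketch of calling the static run(...) API directly; all values are placeholders.
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.mahout.classifier.sequencelearning.hmm.hadoop.BaumWelchDriver;

public class BaumWelchProgrammaticTraining {
    public static void main(String[] args) throws Exception {
        Configuration conf = new Configuration();
        BaumWelchDriver.run(conf,
                new Path("/tmp/hmm/observations"),  // input observation sequences (hypothetical path)
                new Path("/tmp/hmm/model"),         // initial HmmModel as a SequenceFile (hypothetical path)
                new Path("/tmp/hmm/output"),        // per-iteration models are written under this directory
                new Path("/tmp/hmm/hmap"),          // hidden-state-to-ID map (hypothetical path)
                new Path("/tmp/hmm/smap"),          // emitted-state-to-ID map (hypothetical path)
                3,                                  // number of hidden states
                10,                                 // number of observed states
                0.001,                              // convergence delta
                "logscaling",                       // scaling method name (assumed value)
                10);                                // maximum number of iterations
    }
}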