package org.apache.mahout.classifier.sequencelearning.hmm.hadoop;


import java.io.IOException;

import org.apache.commons.cli2.CommandLine;
import org.apache.commons.cli2.Group;
import org.apache.commons.cli2.Option;
import org.apache.commons.cli2.OptionException;
import org.apache.commons.cli2.builder.ArgumentBuilder;
import org.apache.commons.cli2.builder.DefaultOptionBuilder;
import org.apache.commons.cli2.builder.GroupBuilder;
import org.apache.commons.cli2.commandline.Parser;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
import org.apache.hadoop.util.ToolRunner;
import org.apache.mahout.classifier.sequencelearning.hmm.HmmModel;
import org.apache.mahout.common.AbstractJob;
import org.apache.mahout.common.CommandLineUtil;
import org.apache.mahout.common.HadoopUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Main class for launching iterative Baum-Welch MapReduce training.
 * Baum-Welch, like K-Means, is an Expectation Maximization (EM) algorithm which tries to estimate
 * a model's parameters using the maximum likelihood criterion.
 * <p/>
 * As shown by Andrew Ng et al., EM algorithms fit nicely into the MapReduce framework.
 * <p/>
 * The training splits evenly between the mappers and reducers, with combiners reducing the network traffic.
 */

public class BaumWelchDriver extends AbstractJob {

    // SLF4J logger for this driver; the static training methods in this class report progress through it.
    private static final Logger log = LoggerFactory.getLogger(BaumWelchDriver.class);

    public static void main(String[] args) throws Exception { Configuration(), new BaumWelchDriver(), args);

    /**
     * Parses the command-line options, optionally asks for a random initial model to be built, and
     * starts Baum-Welch training through the static {@code run} overload.
     * <p/>
     * NOTE(review): as it appears here, several statements in this method look truncated -- the option
     * builder chains have no terminating call or semicolon, the group definition stops after two
     * options, the parser is never given the group, and some closing braces are absent. Confirm
     * against the upstream source before relying on this text.
     *
     * @param args command-line arguments handed to the commons-cli2 parser
     * @return 0 when the method reaches its final statement
     * @throws Exception declared broadly; number parsing and the training run can both fail
     */
    public int run(String[] args) throws Exception {

        DefaultOptionBuilder optionBuilder = new DefaultOptionBuilder();
        // NOTE(review): argumentBuilder is never used in the visible code; presumably the missing
        // builder-chain tails attached arguments through it -- TODO confirm.
        ArgumentBuilder argumentBuilder = new ArgumentBuilder();

        // One commons-cli2 option per driver setting follows.
        Option inputOption = optionBuilder.withLongName("input")
                .withDescription("Sequence file containing VectorWritables as training sequence").withShortName("i")

        Option outputOption = optionBuilder.withLongName("output")
                .withDescription("Output path to store the trained model encoded as Sequence Files")

        Option modelOption = optionBuilder.withLongName("model")
                .withDescription("Initial HmmModel encoded as a Sequence File. "
                        + "Will be constructed with a random distribution if the 'buildRandom' option is set to true.")

        Option hiddenStateMapPath = optionBuilder.withLongName("hiddenStateToIDMap")
                .withDescription("Hidden states to ID map path.").withShortName("hmap")

        Option emitStateMapPath = optionBuilder.withLongName("emittedStateToIDMap")
                .withDescription("Emitted states to ID map path.").withShortName("smap")

        // NOTE(review): the withDescription( call that should open the next string literal is missing.
        Option randomOption = optionBuilder.withLongName("buildRandom")
                        "Optional argument to generate a random initial HmmModel and store it in 'model' directory")

        Option scalingOption = optionBuilder.withLongName("Scaling")
                .withDescription("Optional argument to invoke scaled training").withShortName("l")

        Option stateNumberOption = optionBuilder.withLongName("nrOfHiddenStates")
                .withDescription("Number of hidden states").withShortName("nh")

        Option observedStateNumberOption = optionBuilder.withLongName("nrOfObservedStates")
                .withDescription("Number of observed states").withShortName("no")

        Option epsilonOption = optionBuilder.withLongName("epsilon").withDescription("Convergence threshold")

        Option iterationsOption = optionBuilder.withLongName("maxIterations")
                .withDescription("Maximum iterations number").withShortName("m")

        // NOTE(review): only the input and output options are added to the group in the visible text.
        Group optionGroup = new GroupBuilder().withOption(inputOption).withOption(outputOption)

        try {
            Parser parser = new Parser();
            CommandLine commandLine = parser.parse(args);

            // Path-like settings are read as strings and wrapped in Path objects further down.
            String input = (String) commandLine.getValue(inputOption);
            String output = (String) commandLine.getValue(outputOption);
            String modelIn = (String) commandLine.getValue(modelOption);
            String hiddenStateToIdMap = (String) commandLine.getValue(hiddenStateMapPath);
            String emittedStateToIdMap = (String) commandLine.getValue(emitStateMapPath);

            // buildRandom is treated as a flag: only its presence on the command line is checked.
            Boolean buildRandom = commandLine.hasOption(randomOption);
            String scaling = (String) commandLine.getValue(scalingOption);

            // Numeric settings arrive as strings and are parsed here.
            int numHidden = Integer.parseInt((String) commandLine.getValue(stateNumberOption));
            int numObserved = Integer.parseInt((String) commandLine.getValue(observedStateNumberOption));

            double convergenceDelta = Double.parseDouble((String) commandLine.getValue(epsilonOption));
            int maxIterations = Integer.parseInt((String) commandLine.getValue(iterationsOption));

            // Fall back to a default Configuration when none has been set on this job.
            if (getConf() == null) {
                setConf(new Configuration());
            // Ask BaumWelchUtils to build a random initial model at the model path.
            if (buildRandom) {

                BaumWelchUtils.buildRandomModel(numHidden, numObserved, new Path(modelIn), getConf());
            // Delegate the actual training to the static overload.
            run(getConf(), new Path(input), new Path(modelIn), new Path(output), new Path(hiddenStateToIdMap),
                    new Path(emittedStateToIdMap), numHidden, numObserved, convergenceDelta, scaling,
        // NOTE(review): the handler body for a command-line parsing failure is not visible here.
        } catch (OptionException e) {

        return 0;


    /**
     * Run the Baum-Welch Map Reduce algorithm using the supplied arguments
     *
     * @param conf                the Configuration to use
     * @param input               the Path to the directory containing input observed sequences
     * @param modelIn             the Path to the HmmModel stored as a SequenceFile
     * @param output              the Path to the output directory
     * @param hiddenStateToIdMap  the Path to the map of hidden states to ids
     * @param emittedStateToIdMap the Path to the map of emitted states to ids
     * @param numHidden           the number of Hidden states
     * @param numObserved         the number of Observed states
     * @param convergenceDelta    the convergence delta value
     * @param scaling             use the log scaled version if set to true
     * @param maxIterations       the maximum number of iterations
     * @throws IOException
     * @throws InterruptedException
     * @throws ClassNotFoundException
     */
    public static void run(Configuration conf, Path input, Path modelIn, Path output, Path hiddenStateToIdMap,
            Path emittedStateToIdMap, int numHidden, int numObserved, double convergenceDelta, String scaling,
            int maxIterations) throws IOException, InterruptedException, ClassNotFoundException {

        //iterate until the model converges or until maxIterations
        String delta = Double.toString(convergenceDelta);
        if (log.isInfoEnabled()) {
  "Input: {} Model In: {} Out: {}", new Object[] { input, modelIn, output });
  "convergence: {} max Iterations: {}", new Object[] { convergenceDelta, maxIterations });
        Path modelOut = runBaumWelchMR(conf, input, modelIn, output, hiddenStateToIdMap, emittedStateToIdMap,
                numHidden, numObserved, delta, scaling, maxIterations);

    /**
     * Run the Baum-Welch Map Reduce algorithm using the supplied arguments
     *
     * @param conf                the Configuration to use
     * @param input               the Path to the directory containing input observed sequences
     * @param modelIn             the Path to the HmmModel stored as a SequenceFile
     * @param output              the Path to the output directory
     * @param hiddenStateToIdMap  the Path to the map of hidden states to ids
     * @param emittedStateToIdMap the Path to the map of emitted states to ids
     * @param numHidden           the number of Hidden states
     * @param numObserved         the number of Observed states
     * @param delta               the convergence delta value
     * @param scaling             use the log scaled variant if set to true
     * @param maxIterations       the maximum number of iterations
     * @return the path to the output model directory
     * @throws IOException
     * @throws InterruptedException
     * @throws ClassNotFoundException
     */
    public static Path runBaumWelchMR(Configuration conf, Path input, Path modelIn, Path output,
            Path hiddenStateToIdMap, Path emittedStateToIdMap, int numHidden, int numObserved, String delta,
            String scaling, int maxIterations) throws IOException, InterruptedException, ClassNotFoundException {
        boolean converged = false;
        int iteration = 1;
        while (!converged && iteration <= maxIterations) {
  "Baum-Welch MR Iteration " + iteration);
            // point the output to a new directory per iteration
            Path modelOut = new Path(output, "model-" + iteration);
            converged = runIteration(conf, input, modelIn, modelOut, hiddenStateToIdMap, emittedStateToIdMap,
                    numHidden, numObserved, scaling, delta);
            modelIn = modelOut;
        return new Path(output.toString() + "/model-" + --iteration);

    /**
     * Run one iteration of the Baum-Welch Map Reduce algorithm using the supplied arguments
     *
     * @param conf                the Configuration to use
     * @param input               the Path to the directory containing input
     * @param modelIn             the Path to the HmmModel
     * @param modelOut            the Path to the output directory
     * @param hiddenStateToIdMap  the Path to the map of hidden states to ids
     * @param emittedStateToIdMap the Path to the map of emitted states to ids
     * @param numHidden           the number of Hidden states
     * @param numObserved         the number of Observed states
     * @param scaling             name of the scaling method
     * @param delta               the convergence delta value
     * @return true or false depending on convergence check
     */

    private static boolean runIteration(Configuration conf, Path input, Path modelIn, Path modelOut,
            Path hiddenStateToIdMap, Path emittedStateToIdMap, int numHidden, int numObserved, String scaling,
            String delta) throws IOException, InterruptedException, ClassNotFoundException {

        conf.set(BaumWelchConfigKeys.EMITTED_STATES_MAP_PATH, emittedStateToIdMap.toString());
        conf.set(BaumWelchConfigKeys.HIDDEN_STATES_MAP_PATH, hiddenStateToIdMap.toString());
        conf.set(BaumWelchConfigKeys.SCALING_OPTION_KEY, scaling);
        conf.set(BaumWelchConfigKeys.MODEL_PATH_KEY, modelIn.toString());
        conf.set(BaumWelchConfigKeys.NUMBER_OF_HIDDEN_STATES_KEY, ((Integer) numHidden).toString());
        conf.set(BaumWelchConfigKeys.NUMBER_OF_EMITTED_STATES_KEY, ((Integer) numObserved).toString());
        conf.set(BaumWelchConfigKeys.MODEL_CONVERGENCE_KEY, delta);

        Job job = new Job(conf, "Baum-Welch Driver running runIteration over modelIn: "
                + conf.get(BaumWelchConfigKeys.MODEL_PATH_KEY));


        FileInputFormat.addInputPath(job, input);
        FileOutputFormat.setOutputPath(job, modelOut);

        HadoopUtil.delete(conf, modelOut);
        if (!job.waitForCompletion(true)) {
            throw new InterruptedException("Baum-Welch Iteration failed processing " + modelIn);

        return isConverged(modelIn, modelOut, numHidden, numObserved, conf);

    /**
     * Check for convergence of two Hmm models
     *
     * @param modelIn     the Path to the HmmModel
     * @param modelOut    the Path to the output directory
     * @param numHidden   the number of Hidden states
     * @param numObserved the number of Observed states
     * @param conf        the Configuration to use
     * @return true if converged, false otherwise
     */

    private static boolean isConverged(Path modelIn, Path modelOut, int numHidden, int numObserved,
            Configuration conf) throws IOException {"-----------Checking Convergence----------");
        HmmModel previousModel = BaumWelchUtils.createHmmModel(numHidden, numObserved, modelIn, conf);
        if (previousModel == null) {
            throw new IllegalStateException("HmmModel from previous iteration is empty!");
        HmmModel newModel = BaumWelchUtils.createHmmModel(numHidden, numObserved, modelOut, conf);
        if (newModel == null) {
            throw new IllegalStateException("HmmModel from current iteration is empty!");

        return BaumWelchUtils.checkConvergence(previousModel, newModel,

