/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.classifier.sequencelearning.hmm.hadoop;

import com.google.common.io.Closeables;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.FileUtil;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.*;
import org.apache.mahout.classifier.sequencelearning.hmm.HmmModel;
import org.apache.mahout.classifier.sequencelearning.hmm.HmmUtils;
import org.apache.mahout.common.HadoopUtil;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.iterator.sequencefile.PathFilters;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collection;

/**
 * Utilities to convert between HmmModel and Sequence File representation.
 */
public class BaumWelchUtils {

  private static final Logger log = LoggerFactory.getLogger(BaumWelchUtils.class);

  private BaumWelchUtils() {
  }

  /**
   * Converts the sequence files present in a directory to a {@link HmmModel} model.
   *
   * @param nrOfHiddenStates Number of hidden states
   * @param nrOfOutputStates Number of output states
   * @param modelPath        Location of the sequence files containing the model's distributions
   * @param conf             Configuration object
   * @return HmmModel the encoded model
   * @throws IOException
   */
  public static HmmModel createHmmModel(int nrOfHiddenStates,
                                        int nrOfOutputStates,
                                        Path modelPath,
                                        Configuration conf) throws IOException {
    log.info("Entering Create Hmm Model. Model Path = {}", modelPath.toUri());
    Vector initialProbabilities = new DenseVector(nrOfHiddenStates);
    Matrix transitionMatrix = new DenseMatrix(nrOfHiddenStates, nrOfHiddenStates);
    Matrix emissionMatrix = new DenseMatrix(nrOfHiddenStates, nrOfOutputStates);

    // Get the path location where the seq files encoding the model are stored
    Path modelFilesPath = new Path(modelPath, "*");

    Collection<Path> result = new ArrayList<Path>();

    // get all filtered file names in the result list
    FileSystem fs = modelFilesPath.getFileSystem(conf);
    FileStatus[] matches = fs.listStatus(
        FileUtil.stat2Paths(fs.globStatus(modelFilesPath, PathFilters.partFilter())),
        PathFilters.partFilter());
    for (FileStatus match : matches) {
      result.add(fs.makeQualified(match.getPath()));
    }

    // iterate through the result path list
    for (Path path : result) {
      for (Pair<Writable, MapWritable> pair
          : new SequenceFileIterable<Writable, MapWritable>(path, true, conf)) {
        Text key = (Text) pair.getFirst();
        MapWritable valueMap = pair.getSecond();
        if (key.charAt(0) == (int) 'I') {
          // initial distribution stripe
          for (MapWritable.Entry<Writable, Writable> entry : valueMap.entrySet()) {
            initialProbabilities.set(((IntWritable) entry.getKey()).get(),
                ((DoubleWritable) entry.getValue()).get());
          }
        } else if (key.charAt(0) == (int) 'T') {
          // transition distribution stripe; key is of the form TRANSIT_0, TRANSIT_1 etc.
          int stateID = Integer.parseInt(key.toString().split("_")[1]);
          for (MapWritable.Entry<Writable, Writable> entry : valueMap.entrySet()) {
            transitionMatrix.set(stateID, ((IntWritable) entry.getKey()).get(),
                ((DoubleWritable) entry.getValue()).get());
          }
        } else if (key.charAt(0) == (int) 'E') {
          // emission distribution stripe; key is of the form EMIT_0, EMIT_1 etc.
          int stateID = Integer.parseInt(key.toString().split("_")[1]);
          for (MapWritable.Entry<Writable, Writable> entry : valueMap.entrySet()) {
            emissionMatrix.set(stateID, ((IntWritable) entry.getKey()).get(),
                ((DoubleWritable) entry.getValue()).get());
          }
        } else {
          throw new IllegalStateException("Error creating HmmModel from Sequence File Path");
        }
      }
    }

    return new HmmModel(transitionMatrix, emissionMatrix, initialProbabilities);
  }

  /**
   * Builds a random {@link HmmModel} encoded as a Sequence File and writes it to the specified location.
   *
   * @param numHidden   Number of hidden states
   * @param numObserved Number of observed states
   * @param modelPath   Directory path for storing the created HmmModel
   * @param conf        Configuration object
   * @throws IOException
   */
  public static void buildRandomModel(int numHidden,
                                      int numObserved,
                                      Path modelPath,
                                      Configuration conf) throws IOException {
    HmmModel model = new HmmModel(numHidden, numObserved);
    HmmUtils.validate(model);
    writeModelToDirectory(model, modelPath, conf);
  }
  /**
   * Constructs a HmmModel object using the given distributions and stores it as a sequence file.
   *
   * @param initialProb    initial hidden state probability distribution vector
   * @param transitionProb hidden state transition probability distribution matrix
   * @param emissionProb   emission probability distribution matrix
   * @param modelPath      path to store the constructed {@link HmmModel}
   * @param conf           Configuration object
   * @throws IOException
   */
  public static void buildHmmModelFromDistributions(double[] initialProb,
                                                    double[][] transitionProb,
                                                    double[][] emissionProb,
                                                    Path modelPath,
                                                    Configuration conf) throws IOException {
    HmmModel model = new HmmModel(new DenseMatrix(transitionProb),
        new DenseMatrix(emissionProb), new DenseVector(initialProb));
    HmmUtils.validate(model);
    writeModelToDirectory(model, modelPath, conf);
  }

  /**
   * Encodes a particular HmmModel as a Sequence File and writes it to the specified location.
   *
   * @param model     HmmModel to be encoded
   * @param modelPath Location to store the encoded model
   * @param conf      Configuration object
   * @throws IOException
   */
  protected static void writeModelToDirectory(HmmModel model, Path modelPath, Configuration conf)
      throws IOException {

    int numHidden = model.getNrOfHiddenStates();
    int numObserved = model.getNrOfOutputStates();
    Matrix emissionMatrix = model.getEmissionMatrix();
    Matrix transitionMatrix = model.getTransitionMatrix();
    Vector initialProbability = model.getInitialProbabilities();

    MapWritable initialDistributionMap = new MapWritable();
    MapWritable transitionDistributionMap = new MapWritable();
    MapWritable emissionDistributionMap = new MapWritable();

    // delete the output directory
    HadoopUtil.delete(conf, modelPath);

    // create a new file to store the HMM
    FileSystem fs = FileSystem.get(modelPath.toUri(), conf);
    Path outFile = new Path(modelPath, "part-randomSeed");
    boolean newFile = fs.createNewFile(outFile);

    if (newFile) {
      SequenceFile.Writer writer =
          SequenceFile.createWriter(fs, conf, outFile, Text.class, MapWritable.class);
      try {
        for (int i = 0; i < numHidden; i++) {
          IntWritable initialDistributionKey = new IntWritable(i);
          DoubleWritable initialDistributionValue = new DoubleWritable(initialProbability.get(i));
          initialDistributionMap.put(initialDistributionKey, initialDistributionValue);

          Text transitionDistributionKey = new Text("TRANSIT_" + Integer.toString(i));
          MapWritable transitionDistributionValue = new MapWritable();
          for (int j = 0; j < numHidden; j++) {
            IntWritable transitionDistributionInnerKey = new IntWritable(j);
            DoubleWritable transitionDistributionInnerValue =
                new DoubleWritable(transitionMatrix.get(i, j));
            transitionDistributionValue.put(transitionDistributionInnerKey,
                transitionDistributionInnerValue);
          }
          transitionDistributionMap.put(transitionDistributionKey, transitionDistributionValue);

          Text emissionDistributionKey = new Text("EMIT_" + Integer.toString(i));
          MapWritable emissionDistributionValue = new MapWritable();
          for (int j = 0; j < numObserved; j++) {
            IntWritable emissionDistributionInnerKey = new IntWritable(j);
            DoubleWritable emissionDistributionInnerValue =
                new DoubleWritable(emissionMatrix.get(i, j));
            emissionDistributionValue.put(emissionDistributionInnerKey,
                emissionDistributionInnerValue);
          }
          emissionDistributionMap.put(emissionDistributionKey, emissionDistributionValue);
        }

        writer.append(new Text("INITIAL"), initialDistributionMap);
        log.info("Wrote random Initial Distribution Map to {}", outFile);

        for (MapWritable.Entry<Writable, Writable> transitionEntry
            : transitionDistributionMap.entrySet()) {
          writer.append(transitionEntry.getKey(), transitionEntry.getValue());
        }
        log.info("Wrote random Transition Distribution Map to {}", outFile);

        for (MapWritable.Entry<Writable, Writable> emissionEntry
            : emissionDistributionMap.entrySet()) {
          writer.append(emissionEntry.getKey(), emissionEntry.getValue());
        }
        log.info("Wrote random Emission Distribution Map to {}", outFile);
      } finally {
        Closeables.closeQuietly(writer);
      }
    }
  }
  /**
   * Checks convergence of two HMM models by computing a simple distance between
   * their emission / transition matrices.
   *
   * @param oldModel Old HMM Model
   * @param newModel New HMM Model
   * @param epsilon  Convergence Factor
   * @return true if training converged to a stable state
   */
  public static boolean checkConvergence(HmmModel oldModel, HmmModel newModel, double epsilon) {
    // check convergence of transitionProbabilities
    Matrix oldTransitionMatrix = oldModel.getTransitionMatrix();
    Matrix newTransitionMatrix = newModel.getTransitionMatrix();
    double diff = 0;
    for (int i = 0; i < oldModel.getNrOfHiddenStates(); ++i) {
      for (int j = 0; j < oldModel.getNrOfHiddenStates(); ++j) {
        double tmp = oldTransitionMatrix.getQuick(i, j) - newTransitionMatrix.getQuick(i, j);
        diff += tmp * tmp;
      }
    }
    double norm = Math.sqrt(diff);
    diff = 0;
    // check convergence of emissionProbabilities
    Matrix oldEmissionMatrix = oldModel.getEmissionMatrix();
    Matrix newEmissionMatrix = newModel.getEmissionMatrix();
    for (int i = 0; i < oldModel.getNrOfHiddenStates(); i++) {
      for (int j = 0; j < oldModel.getNrOfOutputStates(); j++) {
        double tmp = oldEmissionMatrix.getQuick(i, j) - newEmissionMatrix.getQuick(i, j);
        diff += tmp * tmp;
      }
    }
    norm += Math.sqrt(diff);
    return norm < epsilon;
  }

  /** Packs two doubles into a 16-byte (big-endian) array. */
  public static byte[] doublePairToByteArray(double d1, double d2) {
    byte[] bytes = new byte[16];
    ByteBuffer.wrap(bytes).putDouble(d1);
    ByteBuffer.wrap(bytes).putDouble(8, d2);
    return bytes;
  }

  /** Unpacks a pair of doubles written by {@link #doublePairToByteArray(double, double)}. */
  public static double[] toDoublePair(byte[] bytes) {
    double[] pair = new double[2];
    pair[0] = ByteBuffer.wrap(bytes).getDouble();
    pair[1] = ByteBuffer.wrap(bytes).getDouble(8);
    return pair;
  }
}
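
// --- Usage sketch (illustrative only, not part of the original utility class) ---
// A minimal, hypothetical driver showing how these helpers fit together: seed a
// random HMM, persist it as a sequence file, decode it back, and test convergence.
// The model path and state counts below are assumptions chosen only for illustration.
class BaumWelchUtilsExample {

  public static void main(String[] args) throws IOException {
    Configuration conf = new Configuration();
    Path modelPath = new Path("/tmp/hmm-model");   // hypothetical output location
    int numHidden = 3;                             // number of hidden states (assumed)
    int numObserved = 5;                           // number of observed states (assumed)

    // Write a randomly initialized, validated HMM as "part-randomSeed" under modelPath.
    BaumWelchUtils.buildRandomModel(numHidden, numObserved, modelPath, conf);

    // Decode the sequence file stripes (INITIAL, TRANSIT_i, EMIT_i) back into an HmmModel.
    HmmModel seedModel = BaumWelchUtils.createHmmModel(numHidden, numObserved, modelPath, conf);

    // In a real Baum-Welch run the second model would come from a training iteration;
    // comparing the seed model with itself trivially reports convergence.
    boolean converged = BaumWelchUtils.checkConvergence(seedModel, seedModel, 1e-4);
    System.out.println("Converged: " + converged);
  }
}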