mlbench.bayes.BayesUtils.java Source code

Introduction

Here is the source code for mlbench.bayes.BayesUtils.java. The class collects Hadoop/Mahout helper routines used by the mlbench naive Bayes benchmark: chunked dictionary and frequency writers, label-index readers and writers, and a loader that rebuilds a Mahout NaiveBayesModel from the output of TrainNaiveBayesJob.

Source

/**
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License. You may obtain a copy of
 * the License at
 * 
 * http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */
package mlbench.bayes;

import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.google.common.io.Closeables;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.mapred.JobConf;
import org.apache.mahout.classifier.naivebayes.NaiveBayesModel;
import org.apache.mahout.classifier.naivebayes.training.ThetaMapper;
import org.apache.mahout.classifier.naivebayes.training.TrainNaiveBayesJob;
import org.apache.mahout.common.HadoopUtil;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.iterator.sequencefile.PathFilters;
import org.apache.mahout.common.iterator.sequencefile.PathType;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.SparseMatrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.map.OpenObjectIntHashMap;

import java.io.IOException;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.regex.Pattern;

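/**
 * Static helper routines for the mlbench naive Bayes benchmark: dictionary and
 * frequency chunking, label-index reading and writing, and loading a Mahout
 * {@link NaiveBayesModel} from the output of {@link TrainNaiveBayesJob}.
 */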
public class BayesUtils {
    public static final Pattern SLASH = Pattern.compile("/");
    public static final String SUMMED_OBSERVATIONS = "summedObservations";
    public static final String WEIGHTS = "weights";
    private static final String DICTIONARY_FILE = "dictionary.file-";
    private static final String OUTPUT_FILES_PATTERN = "part-*";
    private static final int DICTIONARY_BYTE_OVERHEAD = 4;
    private static final String FREQUENCY_FILE = "frequency.file-";
    private static final int SEQUENCEFILE_BYTE_OVERHEAD = 45;

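    /**
     * Splits the feature-count SequenceFiles under {@code featureCountPath} into
     * frequency chunk files of at most {@code chunkSizeInMegabytes} each, and returns
     * the {featureCount, vectorCount} pair together with the chunk paths.
     */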
    static Pair<Long[], List<Path>> createDictionaryChunks(Path featureCountPath, Path dictionaryPathBase,
            Configuration baseConf, int chunkSizeInMegabytes) throws IOException {
        List<Path> chunkPaths = Lists.newArrayList();
        Configuration conf = new Configuration(baseConf);

        FileSystem fs = FileSystem.get(featureCountPath.toUri(), conf);

        long chunkSizeLimit = chunkSizeInMegabytes * 1024L * 1024L;
        int chunkIndex = 0;
        Path chunkPath = new Path(dictionaryPathBase, FREQUENCY_FILE + chunkIndex);
        chunkPaths.add(chunkPath);
        SequenceFile.Writer freqWriter = new SequenceFile.Writer(fs, conf, chunkPath, IntWritable.class,
                LongWritable.class);

        try {
            long currentChunkSize = 0;
            long featureCount = 0;
            long vectorCount = Long.MAX_VALUE;
            Path filesPattern = new Path(featureCountPath, OUTPUT_FILES_PATTERN);
            for (Pair<IntWritable, LongWritable> record : new SequenceFileDirIterable<IntWritable, LongWritable>(
                    filesPattern, PathType.GLOB, null, null, true, conf)) {

                if (currentChunkSize > chunkSizeLimit) {
                    Closeables.close(freqWriter, false);
                    chunkIndex++;

                    chunkPath = new Path(dictionaryPathBase, FREQUENCY_FILE + chunkIndex);
                    chunkPaths.add(chunkPath);

                    freqWriter = new SequenceFile.Writer(fs, conf, chunkPath, IntWritable.class,
                            LongWritable.class);
                    currentChunkSize = 0;
                }

                int fieldSize = SEQUENCEFILE_BYTE_OVERHEAD + Integer.SIZE / 8 + Long.SIZE / 8;
                currentChunkSize += fieldSize;
                IntWritable key = record.getFirst();
                LongWritable value = record.getSecond();
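                // A key of -1 carries the total vector (document) count; non-negative keys are feature ids.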
                if (key.get() >= 0) {
                    freqWriter.append(key, value);
                } else if (key.get() == -1) {
                    vectorCount = value.get();
                }
                featureCount = Math.max(key.get(), featureCount);

            }
            featureCount++;
            Long[] counts = { featureCount, vectorCount };
            return new Pair<Long[], List<Path>>(counts, chunkPaths);
        } finally {
            Closeables.close(freqWriter, false);
        }
    }

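    /**
     * Writes the dictionary (term to integer id) read from {@code wordCountPath} into
     * chunk files of at most {@code chunkSizeInMegabytes} each; on return,
     * {@code maxTermDimension[0]} holds the number of terms written.
     */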
    static List<Path> createDictionaryChunks(Path wordCountPath, Path dictionaryPathBase, Configuration baseConf,
            int chunkSizeInMegabytes, int[] maxTermDimension) throws IOException {
        List<Path> chunkPaths = Lists.newArrayList();

        Configuration conf = new Configuration(baseConf);

        FileSystem fs = FileSystem.get(wordCountPath.toUri(), conf);

        long chunkSizeLimit = chunkSizeInMegabytes * 1024L * 1024L;
        int chunkIndex = 0;
        Path chunkPath = new Path(dictionaryPathBase, DICTIONARY_FILE + chunkIndex);
        chunkPaths.add(chunkPath);

        SequenceFile.Writer dictWriter = new SequenceFile.Writer(fs, conf, chunkPath, Text.class,
                IntWritable.class);

        try {
            long currentChunkSize = 0;
            Path filesPattern = new Path(wordCountPath, OUTPUT_FILES_PATTERN);
            int i = 0;
            for (Pair<Writable, Writable> record : new SequenceFileDirIterable<Writable, Writable>(filesPattern,
                    PathType.GLOB, null, null, true, conf)) {
                if (currentChunkSize > chunkSizeLimit) {
                    Closeables.close(dictWriter, false);
                    chunkIndex++;

                    chunkPath = new Path(dictionaryPathBase, DICTIONARY_FILE + chunkIndex);
                    chunkPaths.add(chunkPath);

                    dictWriter = new SequenceFile.Writer(fs, conf, chunkPath, Text.class, IntWritable.class);
                    currentChunkSize = 0;
                }

                Writable key = record.getFirst();
                int fieldSize = DICTIONARY_BYTE_OVERHEAD + key.toString().length() * 2 + Integer.SIZE / 8;
                currentChunkSize += fieldSize;
                dictWriter.append(key, new IntWritable(i++));
            }
            maxTermDimension[0] = i;
        } finally {
            Closeables.close(dictWriter, false);
        }

        return chunkPaths;
    }

    /** Write the list of labels into a label-index SequenceFile and return the number of labels written. */
    public static int writeLabelIndex(Configuration conf, Iterable<String> labels, Path indexPath)
            throws IOException {
        FileSystem fs = FileSystem.get(indexPath.toUri(), conf);
        SequenceFile.Writer writer = new SequenceFile.Writer(fs, conf, indexPath, Text.class, IntWritable.class);
        int i = 0;
        try {
            for (String label : labels) {
                writer.append(new Text(label), new IntWritable(i++));
            }
        } finally {
            Closeables.close(writer, false);
        }
        return i;
    }

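    /**
     * Writes a label index derived from the keys of {@code labels}: the label is the
     * second path component of each key (the text after the first '/'), and duplicates
     * are written only once. Returns the number of distinct labels.
     */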
    public static int writeLabelIndex(Configuration conf, Path indexPath, Iterable<Pair<Text, IntWritable>> labels)
            throws IOException {
        FileSystem fs = FileSystem.get(indexPath.toUri(), conf);
        SequenceFile.Writer writer = new SequenceFile.Writer(fs, conf, indexPath, Text.class, IntWritable.class);
        Collection<String> seen = Sets.newHashSet();
        int i = 0;
        try {
            for (Pair<Text, IntWritable> label : labels) {
                String theLabel = SLASH.split(label.getFirst().toString())[1];
                if (!seen.contains(theLabel)) {
                    writer.append(new Text(theLabel), new IntWritable(i++));
                    seen.add(theLabel);
                }
            }
        } finally {
            Closeables.close(writer, false);
        }
        return i;
    }

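    /** Loads the label index from the single file registered in the distributed cache. */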
    public static OpenObjectIntHashMap<String> readIndexFromCache(Configuration conf) throws IOException {
        OpenObjectIntHashMap<String> index = new OpenObjectIntHashMap<String>();
        for (Pair<Writable, IntWritable> entry : new SequenceFileIterable<Writable, IntWritable>(
                HadoopUtil.getSingleCachedFile(conf), conf)) {
            index.put(entry.getFirst().toString(), entry.getSecond().get());
        }
        return index;
    }

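    /**
     * Rebuilds a {@link NaiveBayesModel} from the weight vectors and summed observations
     * written by {@link TrainNaiveBayesJob} under {@code base}; the per-label theta
     * normalizer is left empty (see the commented-out block below).
     */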
    public static NaiveBayesModel readModelFromDir(Path base, Configuration conf) {

        float alphaI = conf.getFloat(ThetaMapper.ALPHA_I, 1.0f);

        // read feature sums and label sums
        Vector scoresPerLabel = null;
        Vector scoresPerFeature = null;
        for (Pair<Text, VectorWritable> record : new SequenceFileDirIterable<Text, VectorWritable>(
                new Path(base, TrainNaiveBayesJob.WEIGHTS), PathType.LIST, PathFilters.partFilter(), conf)) {
            String key = record.getFirst().toString();
            VectorWritable value = record.getSecond();
            if (key.equals(TrainNaiveBayesJob.WEIGHTS_PER_FEATURE)) {
                scoresPerFeature = value.get();
            } else if (key.equals(TrainNaiveBayesJob.WEIGHTS_PER_LABEL)) {
                scoresPerLabel = value.get();
            }
        }

        // Preconditions.checkNotNull(scoresPerFeature);
        // Preconditions.checkNotNull(scoresPerLabel);

        Matrix scoresPerLabelAndFeature = new SparseMatrix(scoresPerLabel.size(), scoresPerFeature.size());
        for (Pair<IntWritable, VectorWritable> entry : new SequenceFileDirIterable<IntWritable, VectorWritable>(
                new Path(base, TrainNaiveBayesJob.SUMMED_OBSERVATIONS), PathType.LIST, PathFilters.partFilter(),
                conf)) {
            scoresPerLabelAndFeature.assignRow(entry.getFirst().get(), entry.getSecond().get());
        }

        Vector perlabelThetaNormalizer = scoresPerLabel.like();
        /*
         * for (Pair<Text,VectorWritable> entry : new
         * SequenceFileDirIterable<Text,VectorWritable>( new Path(base,
         * TrainNaiveBayesJob.THETAS), PathType.LIST, PathFilters.partFilter(),
         * conf)) { if (entry.getFirst().toString().equals(TrainNaiveBayesJob.
         * LABEL_THETA_NORMALIZER)) { perlabelThetaNormalizer =
         * entry.getSecond().get(); } }
         * 
         * Preconditions.checkNotNull(perlabelThetaNormalizer);
         */
        return new NaiveBayesModel(scoresPerLabelAndFeature, scoresPerFeature, scoresPerLabel,
                perlabelThetaNormalizer, alphaI, false);
    }

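    /** Reads a label-index SequenceFile into a map from label id to label string. */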
    public static Map<Integer, String> readLabelIndex(Configuration conf, Path indexPath) {
        Map<Integer, String> labelMap = new HashMap<Integer, String>();
        for (Pair<Text, IntWritable> pair : new SequenceFileIterable<Text, IntWritable>(indexPath, true, conf)) {
            labelMap.put(pair.getSecond().get(), pair.getFirst().toString());
        }
        return labelMap;
    }

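    /**
     * Reads a trained model from {@code args[0]}, serializes it to {@code args[1]}, and
     * uses {@code args[2]} as the job configuration file.
     */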
    public static void main(String[] args) throws IOException {
        Configuration conf = new JobConf(args[2]);
        NaiveBayesModel naiveBayesModel = readModelFromDir(new Path(args[0]), conf);
        naiveBayesModel.serialize(new Path(args[1]), conf);
    }
}
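
Example

A minimal usage sketch, assuming Mahout and Hadoop are on the classpath and that /tmp/bayes/model is a (hypothetical) output directory produced by Mahout's TrainNaiveBayesJob; the paths and label names below are placeholders, not part of the original source.

import java.io.IOException;
import java.util.Arrays;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.mahout.classifier.naivebayes.NaiveBayesModel;

import mlbench.bayes.BayesUtils;

public class BayesUtilsExample {
    public static void main(String[] args) throws IOException {
        Configuration conf = new Configuration();

        // Hypothetical paths; adjust to the actual cluster layout.
        Path labelIndexPath = new Path("/tmp/bayes/labelindex");
        Path modelDir = new Path("/tmp/bayes/model");

        // Write a small label index and read it back as an id-to-label map.
        int numLabels = BayesUtils.writeLabelIndex(conf, Arrays.asList("spam", "ham"), labelIndexPath);
        System.out.println(numLabels + " labels: " + BayesUtils.readLabelIndex(conf, labelIndexPath));

        // Rebuild the naive Bayes model from the training output and serialize it.
        NaiveBayesModel model = BayesUtils.readModelFromDir(modelDir, conf);
        model.serialize(new Path("/tmp/bayes/naiveBayesModel.bin"), conf);
    }
}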