org.apache.mahout.classifier.bayes.mapreduce.common.BayesFeatureMapper.java Source code


Introduction

Here is the source code for org.apache.mahout.classifier.bayes.mapreduce.common.BayesFeatureMapper.java

Source

/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.classifier.bayes.mapreduce.common;

import java.io.IOException;
import java.util.regex.Pattern;

import com.google.common.collect.Iterators;
import org.apache.commons.lang.mutable.MutableDouble;
import org.apache.hadoop.io.DoubleWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.MapReduceBase;
import org.apache.hadoop.mapred.Mapper;
import org.apache.hadoop.mapred.OutputCollector;
import org.apache.hadoop.mapred.Reporter;
import org.apache.lucene.analysis.shingle.ShingleFilter;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.mahout.classifier.bayes.BayesParameters;
import org.apache.mahout.common.StringTuple;
import org.apache.mahout.common.lucene.IteratorTokenStream;
import org.apache.mahout.math.function.ObjectIntProcedure;
import org.apache.mahout.math.map.OpenObjectIntHashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Reads the input training set (preprocessed using the {@link org.apache.mahout.classifier.BayesFileFormatter}).
 */
public class BayesFeatureMapper extends MapReduceBase implements Mapper<Text, Text, StringTuple, DoubleWritable> {

    private static final Logger log = LoggerFactory.getLogger(BayesFeatureMapper.class);

    private static final DoubleWritable ONE = new DoubleWritable(1.0);
    private static final Pattern SPACE_TAB = Pattern.compile("[ \t]+");

    private int gramSize = 1;

    /**
     * We need to count the number of times we've seen a term with a given label, and we need to output that.
     * But this Mapper does more than just output the count. First, it applies weight normalisation. Secondly,
     * for each unique word in a document it outputs the value 1, to be summed up as the term document
     * frequency, which is later used to calculate the IDF. Thirdly, for each label it outputs the number of
     * times a document was seen (also used in the IDF calculation).
     * 
     * @param key
     *          The label
     * @param value
     *          the features associated with this label, as a space- or tab-separated string
     * @param output
     *          The OutputCollector to write the results to
     * @param reporter
     *          Used to report the status of the mapper
     */
    @Override
    public void map(Text key, Text value, final OutputCollector<StringTuple, DoubleWritable> output,
            Reporter reporter) throws IOException {
        final String label = key.toString();
        String[] tokens = SPACE_TAB.split(value.toString());
        OpenObjectIntHashMap<String> wordList = new OpenObjectIntHashMap<String>(tokens.length * gramSize);
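        // wordList accumulates the occurrence count of each term (and, when gramSize > 1, each shingle)
        // in this document. For example, with gramSize = 2 the tokens {"the", "quick", "fox"} would yield
        // "the", "the quick", "quick", "quick fox" and "fox", assuming ShingleFilter's default of also
        // emitting unigrams.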

        if (gramSize > 1) {
            ShingleFilter sf = new ShingleFilter(new IteratorTokenStream(Iterators.forArray(tokens)), gramSize);
            while (sf.incrementToken()) {
                String term = sf.getAttribute(CharTermAttribute.class).toString();
                if (!term.isEmpty()) {
                    if (wordList.containsKey(term)) {
                        wordList.put(term, 1 + wordList.get(term));
                    } else {
                        wordList.put(term, 1);
                    }
                }
            }
        } else {
            for (String term : tokens) {
                if (wordList.containsKey(term)) {
                    wordList.put(term, 1 + wordList.get(term));
                } else {
                    wordList.put(term, 1);
                }
            }
        }
        final MutableDouble lengthNormalisationMut = new MutableDouble(0.0);
        wordList.forEachPair(new ObjectIntProcedure<String>() {
            @Override
            public boolean apply(String word, int dKJ) {
                long squared = (long) dKJ * (long) dKJ;
                lengthNormalisationMut.add(squared);
                return true;
            }
        });

        final double lengthNormalisation = Math.sqrt(lengthNormalisationMut.doubleValue());

        // Output the length-normalized, TF-transformed frequency per word per class:
        // Log(1 + D_ij) / SQRT( SIGMA(k, D_kj^2) )
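        // Illustrative example (not part of the original source): if this document has the term
        // counts {awesome: 2, mahout: 1}, then lengthNormalisation = sqrt(2^2 + 1^2) = sqrt(5), and the
        // weights emitted below are log(1 + 2)/sqrt(5) for "awesome" and log(1 + 1)/sqrt(5) for "mahout".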
        wordList.forEachPair(new ObjectIntProcedure<String>() {
            @Override
            public boolean apply(String token, int dKJ) {
                try {
                    StringTuple tuple = new StringTuple();
                    tuple.add(BayesConstants.WEIGHT);
                    tuple.add(label);
                    tuple.add(token);
                    DoubleWritable f = new DoubleWritable(Math.log1p(dKJ) / lengthNormalisation);
                    output.collect(tuple, f);
                } catch (IOException e) {
                    throw new IllegalStateException(e);
                }
                return true;
            }
        });
        reporter.setStatus("Bayes Feature Mapper: Document Label: " + label);

        // Output Document Frequency per Word per Class
        // Corpus Document Frequency (FEATURE_COUNT)
        // Corpus Term Frequency (FEATURE_TF)
        wordList.forEachPair(new ObjectIntProcedure<String>() {
            @Override
            public boolean apply(String token, int dKJ) {
                try {
                    StringTuple dfTuple = new StringTuple();
                    dfTuple.add(BayesConstants.DOCUMENT_FREQUENCY);
                    dfTuple.add(label);
                    dfTuple.add(token);
                    output.collect(dfTuple, ONE);

                    StringTuple tokenCountTuple = new StringTuple();
                    tokenCountTuple.add(BayesConstants.FEATURE_COUNT);
                    tokenCountTuple.add(token);
                    output.collect(tokenCountTuple, ONE);

                    StringTuple tokenTfTuple = new StringTuple();
                    tokenTfTuple.add(BayesConstants.FEATURE_TF);
                    tokenTfTuple.add(token);
                    output.collect(tokenTfTuple, new DoubleWritable(dKJ));
                } catch (IOException e) {
                    throw new IllegalStateException(e);
                }
                return true;
            }
        });

        // Output that we have seen this label once more, used to calculate the count of documents
        // per class
        StringTuple labelCountTuple = new StringTuple();
        labelCountTuple.add(BayesConstants.LABEL_COUNT);
        labelCountTuple.add(label);
        output.collect(labelCountTuple, ONE);
    }

    @Override
    public void configure(JobConf job) {
        try {
            BayesParameters params = new BayesParameters(job.get("bayes.parameters", ""));
            log.info("Bayes Parameter {}", params.print());
            gramSize = params.getGramSize();

        } catch (IOException ex) {
            log.warn(ex.toString(), ex);
        }
    }

}
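
Example usage

The listing above covers only the mapper. As a point of reference, the sketch below shows one way a mapper with this signature could be wired into an old-API Hadoop job. It is a minimal, hypothetical driver, not Mahout's actual BayesFeatureDriver: the class name BayesFeatureJobSketch, the command-line arguments, and the string passed to "bayes.parameters" are assumptions made for illustration, and no combiner or reducer is configured.

import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.DoubleWritable;
import org.apache.hadoop.mapred.FileInputFormat;
import org.apache.hadoop.mapred.FileOutputFormat;
import org.apache.hadoop.mapred.JobClient;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.KeyValueTextInputFormat;
import org.apache.hadoop.mapred.SequenceFileOutputFormat;
import org.apache.mahout.classifier.bayes.mapreduce.common.BayesFeatureMapper;
import org.apache.mahout.common.StringTuple;

public class BayesFeatureJobSketch {

    public static void main(String[] args) throws Exception {
        JobConf conf = new JobConf(BayesFeatureJobSketch.class);
        conf.setJobName("bayes-feature-extraction-sketch");

        // The mapper expects "label <TAB> term term term ..." lines, so KeyValueTextInputFormat
        // (which splits each line on the first tab into a Text key and a Text value) fits its signature.
        conf.setInputFormat(KeyValueTextInputFormat.class);
        conf.setOutputFormat(SequenceFileOutputFormat.class);

        conf.setMapperClass(BayesFeatureMapper.class);
        conf.setOutputKeyClass(StringTuple.class);
        conf.setOutputValueClass(DoubleWritable.class);

        // configure() reads "bayes.parameters"; the serialized BayesParameters string is assumed
        // to be supplied by the caller (hypothetical third command-line argument).
        conf.set("bayes.parameters", args.length > 2 ? args[2] : "");

        FileInputFormat.setInputPaths(conf, new Path(args[0]));   // hypothetical input path
        FileOutputFormat.setOutputPath(conf, new Path(args[1]));  // hypothetical output path

        JobClient.runJob(conf);
    }
}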