Java tutorial
/** * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.mahout.classifier.naivebayes.trainer; import java.io.IOException; import java.net.URI; import java.util.Iterator; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.filecache.DistributedCache; import org.apache.hadoop.fs.Path; import org.apache.hadoop.io.IntWritable; import org.apache.hadoop.io.Text; import org.apache.hadoop.io.Writable; import org.apache.hadoop.mapreduce.Mapper; import org.apache.mahout.classifier.naivebayes.BayesConstants; import org.apache.mahout.common.Pair; import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable; import org.apache.mahout.math.Vector; import org.apache.mahout.math.Vector.Element; import org.apache.mahout.math.VectorWritable; import org.apache.mahout.math.map.OpenObjectIntHashMap; public class NaiveBayesThetaComplementaryMapper extends Mapper<IntWritable, VectorWritable, Text, VectorWritable> { private final OpenObjectIntHashMap<String> labelMap = new OpenObjectIntHashMap<String>(); private Vector featureSum; private Vector labelSum; private Vector perLabelThetaNormalizer; private double alphaI = 1.0; private double vocabCount; private double totalSum; @Override protected void map(IntWritable key, VectorWritable value, Context context) throws IOException, InterruptedException { Vector vector = value.get(); int label = key.get(); double sigmaK = labelSum.get(label); Iterator<Element> it = vector.iterateNonZero(); while (it.hasNext()) { Element e = it.next(); double numerator = featureSum.get(e.index()) - e.get() + alphaI; double denominator = totalSum - sigmaK + vocabCount; double weight = Math.log(numerator / denominator); perLabelThetaNormalizer.set(label, perLabelThetaNormalizer.get(label) + weight); } } @Override protected void setup(Context context) throws IOException, InterruptedException { super.setup(context); Configuration conf = context.getConfiguration(); URI[] localFiles = DistributedCache.getCacheFiles(conf); if (localFiles == null || localFiles.length < 2) { throw new IllegalArgumentException("missing paths from the DistributedCache"); } alphaI = conf.getFloat(NaiveBayesTrainer.ALPHA_I, 1.0f); Path weightFile = new Path(localFiles[0].getPath()); for (Pair<Text, VectorWritable> record : new SequenceFileIterable<Text, VectorWritable>(weightFile, true, conf)) { Text key = record.getFirst(); VectorWritable value = record.getSecond(); if (key.toString().equals(BayesConstants.FEATURE_SUM)) { featureSum = value.get(); } else if (key.toString().equals(BayesConstants.LABEL_SUM)) { labelSum = value.get(); } } perLabelThetaNormalizer = labelSum.like(); totalSum = labelSum.zSum(); vocabCount = featureSum.getNumNondefaultElements(); Path labelMapFile = new Path(localFiles[1].getPath()); // key is word value is id for (Pair<Writable, IntWritable> record : new SequenceFileIterable<Writable, IntWritable>(labelMapFile, true, conf)) { labelMap.put(record.getFirst().toString(), record.getSecond().get()); } } @Override protected void cleanup(Context context) throws IOException, InterruptedException { context.write(new Text(BayesConstants.LABEL_THETA_NORMALIZER), new VectorWritable(perLabelThetaNormalizer)); super.cleanup(context); } }