org.apache.mahout.math.hadoop.stats.BasicStats.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.mahout.math.hadoop.stats.BasicStats.java

Source

/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.math.hadoop.stats;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.DoubleWritable;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
import org.apache.mahout.common.HadoopUtil;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.iterator.sequencefile.PathType;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;

import java.io.IOException;

/**
 * Methods for calculating basic stats (mean, variance, stdDev, etc.) in map/reduce
 */
public final class BasicStats {

    private BasicStats() {
    }

    /**
     * Calculate the variance of values stored as
     *
     * @param input    The input file containing the key and the count
     * @param output   The output to store the intermediate values
     * @param baseConf
     * @return The variance (based on sample estimation)
     */
    public static double variance(Path input, Path output, Configuration baseConf)
            throws IOException, InterruptedException, ClassNotFoundException {
        VarianceTotals varianceTotals = computeVarianceTotals(input, output, baseConf);
        return varianceTotals.computeVariance();
    }

    /**
     * Calculate the variance by a predefined mean of values stored as
     *
     * @param input    The input file containing the key and the count
     * @param output   The output to store the intermediate values
     * @param mean The mean based on which to compute the variance
     * @param baseConf
     * @return The variance (based on sample estimation)
     */
    public static double varianceForGivenMean(Path input, Path output, double mean, Configuration baseConf)
            throws IOException, InterruptedException, ClassNotFoundException {
        VarianceTotals varianceTotals = computeVarianceTotals(input, output, baseConf);
        return varianceTotals.computeVarianceForGivenMean(mean);
    }

    private static VarianceTotals computeVarianceTotals(Path input, Path output, Configuration baseConf)
            throws IOException, InterruptedException, ClassNotFoundException {
        Configuration conf = new Configuration(baseConf);
        conf.set("io.serializations", "org.apache.hadoop.io.serializer.JavaSerialization,"
                + "org.apache.hadoop.io.serializer.WritableSerialization");
        Job job = HadoopUtil.prepareJob(input, output, SequenceFileInputFormat.class,
                StandardDeviationCalculatorMapper.class, IntWritable.class, DoubleWritable.class,
                StandardDeviationCalculatorReducer.class, IntWritable.class, DoubleWritable.class,
                SequenceFileOutputFormat.class, conf);
        HadoopUtil.delete(conf, output);
        job.setCombinerClass(StandardDeviationCalculatorReducer.class);
        boolean succeeded = job.waitForCompletion(true);
        if (!succeeded) {
            throw new IllegalStateException("Job failed!");
        }

        // Now extract the computed sum
        Path filesPattern = new Path(output, "part-*");
        double sumOfSquares = 0;
        double sum = 0;
        double totalCount = 0;
        for (Pair<Writable, Writable> record : new SequenceFileDirIterable<Writable, Writable>(filesPattern,
                PathType.GLOB, null, null, true, conf)) {

            int key = ((IntWritable) record.getFirst()).get();
            if (key == StandardDeviationCalculatorMapper.SUM_OF_SQUARES.get()) {
                sumOfSquares += ((DoubleWritable) record.getSecond()).get();
            } else if (key == StandardDeviationCalculatorMapper.TOTAL_COUNT.get()) {
                totalCount += ((DoubleWritable) record.getSecond()).get();
            } else if (key == StandardDeviationCalculatorMapper.SUM.get()) {
                sum += ((DoubleWritable) record.getSecond()).get();
            }
        }

        VarianceTotals varianceTotals = new VarianceTotals();
        varianceTotals.setSum(sum);
        varianceTotals.setSumOfSquares(sumOfSquares);
        varianceTotals.setTotalCount(totalCount);

        return varianceTotals;
    }

    /**
     * Calculate the standard deviation
     *
     * @param input    The input file containing the key and the count
     * @param output   The output file to write the counting results to
     * @param baseConf The base configuration
     * @return The standard deviation
     */
    public static double stdDev(Path input, Path output, Configuration baseConf)
            throws IOException, InterruptedException, ClassNotFoundException {
        return Math.sqrt(variance(input, output, baseConf));
    }

    /**
     * Calculate the standard deviation given a predefined mean
     *
     * @param input    The input file containing the key and the count
     * @param output   The output file to write the counting results to
     * @param mean The mean based on which to compute the standard deviation
     * @param baseConf The base configuration
     * @return The standard deviation
     */
    public static double stdDevForGivenMean(Path input, Path output, double mean, Configuration baseConf)
            throws IOException, InterruptedException, ClassNotFoundException {
        return Math.sqrt(varianceForGivenMean(input, output, mean, baseConf));
    }
}