cs6240.project.decisiontree.Pseudohigstest.java Source code

Introduction

Here is the source code for cs6240.project.decisiontree.Pseudohigstest.java, a Hadoop MapReduce job that evaluates a histogram-based naive Bayes classifier on test data.

Source

/**
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 */

package cs6240.project.decisiontree;

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.net.URI;
import java.util.Arrays;
import java.util.HashMap;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.filecache.DistributedCache;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Partitioner;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.util.GenericOptionsParser;

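/**
 * Testing job for the histogram-based naive Bayes classifier: the aggregated
 * histogram file produced by training is loaded from the distributed cache,
 * each mapper classifies its test records and counts correct and incorrect
 * predictions, and the reducers sum those counts.
 */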
public class Pseudohigstest {

    // prediction counters (shared within a mapper JVM) and class prior ratios
    public static int Correct_counter = 0;
    public static int Incorrect_counter = 0;
    public static float spam_ratio = 0;
    public static float nspam_ratio = 0;

    public static class TestingMapper extends Mapper<Object, Text, IntWritable, IntWritable> {

        // hashmap to maintain the histograms, keyed by attribute index
        HashMap<Integer, Double[][]> metadata = new HashMap<Integer, Double[][]>();

        // number of buckets per attribute histogram
        Integer Bucket_count = 5;
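
        /*
         * Each histogram is a 4 x Bucket_count array:
         *   row 0 - upper bucket boundaries
         *   row 1 - total record count per bucket
         *   row 2 - spam record count per bucket
         *   row 3 - non-spam record count per bucket
         */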

        /*
         * The part files from the training phase are aggregated into a single
         * file containing histogram metadata for all of the attributes. The
         * file lives on S3 and is shipped to the testing machines through the
         * distributed cache; its contents are read here and a hash of
         * histograms, indexed by attribute index, is constructed.
         */
        public void setupMetadata(Context context) {

            int att;
            String[] input = null;
            String sCurrentLine;
            String[] flatHistogram = null;
            BufferedReader br = null;

            try {

                Path[] uris = DistributedCache.getLocalCacheFiles(context.getConfiguration());

                // prefer a cached file whose name contains "metadata";
                // otherwise fall back to the first cached file
                Path histogramFile = uris[0];
                for (int i = 0; i < uris.length; i++) {
                    if (uris[i].toString().contains("metadata")) {
                        histogramFile = uris[i];
                        break;
                    }
                }

                // populating the histogram
                br = new BufferedReader(new FileReader(histogramFile.toString()));

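                // each line has the form: <attribute index> TAB "row col value, row col value, ..."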
                while ((sCurrentLine = br.readLine()) != null) {
                    Double[][] his = new Double[4][Bucket_count];
                    for (Double[] temp : his)
                        Arrays.fill(temp, 0.0);
                    System.out.println(sCurrentLine);
                    input = sCurrentLine.split("\t");
                    att = Integer.parseInt((input[0]));

                    flatHistogram = input[1].split(", ");
                    for (String hi : flatHistogram) {
                        String[] parameter = hi.split(" ");
                        int i = Integer.parseInt(parameter[0]);
                        int j = Integer.parseInt(parameter[1]);
                        his[i][j] = Double.parseDouble(parameter[2]);
                    }
                    metadata.put(att, his);

                }
            } catch (IOException e) {
                e.printStackTrace();
            } finally {
                try {
                    if (br != null)
                        br.close();
                } catch (IOException ex) {
                    ex.printStackTrace();
                }
            }

        }

        /*
         * Each dataset has its own ratio of spam to non-spam records, which is
         * required as the class prior in the naive Bayes formula and is
         * therefore calculated here from the histogram counts.
         */
        public void calculateSpamRatio() {
            Double GrandTotal = 0.0;
            Double SpamTotal = 0.0;
            Double NspamTotal = 0.0;
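            // the metadata map holds one histogram per attribute (15 attributes, indices 0-14)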
            for (int i = 0; i < 15; i++) {
                Double[][] his = (Double[][]) (metadata.get(i));
                for (int j = 0; j < Bucket_count; j++) {
                    GrandTotal += his[1][j];
                    SpamTotal += his[2][j];
                    NspamTotal += his[3][j];
                }
            }
            spam_ratio = (float) (SpamTotal / GrandTotal);
            nspam_ratio = (float) (NspamTotal / GrandTotal);

        }

        protected void setup(Context context) throws IOException, InterruptedException {
            setupMetadata(context);
            calculateSpamRatio();
        }

        /*
         * The testing map iterates over each record, fits every attribute value
         * into a bucket, and looks up the per-label probability for that bucket.
         * The product of these probabilities over all attributes (seeded with the
         * class priors) yields two scores, and the larger one determines the
         * predicted label. The prediction is compared with the record's actual
         * label: on a match the "correct" counter is incremented, otherwise the
         * "incorrect" counter, and both are emitted to the reducer in cleanup().
         */
        public void map(Object key, Text value, Context context) throws IOException, InterruptedException {

            String[] records = value.toString().split("\n");
            for (String record : records) {
                String[] attributes = record.trim().split(",");

                // attributes[0] is the label; the scores start at the class priors
                Double label = Double.parseDouble(attributes[0]);
                float SPAM = spam_ratio;
                float NONSPAM = nspam_ratio;
                Double prediction = 0.0;
                Integer container_bucket = null;
                for (int i = 1; i < attributes.length; i++) {
                    Double attrib = Double.parseDouble(attributes[i]);

                    Double[][] his = (Double[][]) metadata.get(i - 1);
                    Double[] buckets = his[0];
                    // =======get bucket number that contains attrib=============
                    for (int col = 0; col < Bucket_count; col++) {
                        if (attrib <= buckets[col]) {
                            container_bucket = col;
                            Double TC = his[1][container_bucket];
                            Double SC = his[2][container_bucket];
                            Double NSC = his[3][container_bucket];

                            // ===multiply in the per-label probability of this bucket
                            SPAM *= (float) (SC / TC);
                            NONSPAM *= (float) (NSC / TC);

                            break;
                        }
                    }

                }
                if (SPAM > NONSPAM)
                    prediction = 1.0;
                else
                    prediction = 0.0;
                // ===========Verification===========
                if (Double.compare(label, prediction) != 0) {
                    Incorrect_counter++;
                } else {
                    Correct_counter++;
                }

            }
        }

        protected void cleanup(Context context) throws IOException, InterruptedException {

            super.cleanup(context);
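            // key 1 carries the correct count, key 0 the incorrect count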
            context.write(new IntWritable(1), new IntWritable(Correct_counter));
            context.write(new IntWritable(0), new IntWritable(Incorrect_counter));
            // DistributedCache.purgeCache(context.getConfiguration());
        }

    }

    /*
     * The partitioning happens over two reduce tasks, based on the binary
     * key values used for the correct and incorrect counters.
     */
    public static class TestingPartioner extends Partitioner<IntWritable, IntWritable> {
        @Override
        public int getPartition(IntWritable key, IntWritable value, int numofpartitions) {

            return Math.abs(key.hashCode() % numofpartitions);
        }
    }

    /*
     * The reducer simply iterates over the values for each counter key,
     * sums them up, and outputs the total.
     */
    public static class TestingReducer extends Reducer<IntWritable, IntWritable, IntWritable, IntWritable> {

        public void reduce(IntWritable key, Iterable<IntWritable> values, Context context)
                throws IOException, InterruptedException {

            // sum the per-mapper counts for this key (1 = correct, 0 = incorrect)
            int sum = 0;
            for (IntWritable val : values) {
                sum += val.get();
            }
            context.write(key, new IntWritable(sum));
        }

    }

    public static void main(String[] args) throws Exception {
        Configuration conf = new Configuration();
        String[] otherArgs = new GenericOptionsParser(conf, args).getRemainingArgs();
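        // ship the aggregated histogram file to every task via the distributed cache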
        DistributedCache.addCacheFile(new URI("s3://hr6240/higs/testing/5/higshistogram"), conf);
        // DistributedCache.addCacheFile(new
        // URI("/home/hraj17/Downloads/part-hig"),conf);
        Job job = new Job(conf, "higgs naive bayes testing");
        job.setJarByClass(Pseudohigstest.class);
        job.setMapperClass(TestingMapper.class);
        job.setReducerClass(TestingReducer.class);
        job.setOutputKeyClass(IntWritable.class);
        job.setOutputValueClass(IntWritable.class);
        job.setMapOutputKeyClass(IntWritable.class);
        job.setMapOutputValueClass(IntWritable.class);

        job.setPartitionerClass(TestingPartioner.class);
        job.setNumReduceTasks(2);

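        // otherArgs[0] = test data input path, otherArgs[1] = output path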
        FileInputFormat.addInputPath(job, new Path(otherArgs[0]));
        FileOutputFormat.setOutputPath(job, new Path(otherArgs[1]));
        System.exit(job.waitForCompletion(true) ? 0 : 1);
    }
}