junto.algorithm.parallel.MADHadoop.java Source code

Java tutorial

Introduction

Here is the source code for junto.algorithm.parallel.MADHadoop.java

Source

package junto.algorithm.parallel;

/**
 * Copyright 2011 Partha Pratim Talukdar
 * 
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * 
 * http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import gnu.trove.map.hash.TObjectDoubleHashMap;
import gnu.trove.iterator.TObjectDoubleIterator;

import java.io.IOException;
import java.util.HashMap;
import java.util.Hashtable;
import java.util.Iterator;

import junto.config.*;
import junto.util.*;

import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapred.FileInputFormat;
import org.apache.hadoop.mapred.FileOutputFormat;
import org.apache.hadoop.mapred.JobClient;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.MapReduceBase;
import org.apache.hadoop.mapred.Mapper;
import org.apache.hadoop.mapred.OutputCollector;
import org.apache.hadoop.mapred.Reducer;
import org.apache.hadoop.mapred.Reporter;
import org.apache.hadoop.mapred.TextInputFormat;
import org.apache.hadoop.mapred.TextOutputFormat;
import org.apache.hadoop.mapred.jobcontrol.Job;

/**
 * Runs MAD label propagation as a chain of Hadoop jobs (old {@code mapred} API),
 * one job per iteration. The output directory of iteration {@code i} is the input
 * of iteration {@code i + 1}.
 *
 * <p>Each vertex is one tab-separated line:
 * {@code id gold_label injected_labels estimated_labels neighbors rw_probabilities}.
 */
public class MADHadoop {

    /** Field separator for vertex records and for map-to-reduce messages. */
    private static final String _kDelim = "\t";

    /**
     * For each vertex record, emits one message to the vertex itself carrying the
     * unmodified record. For every neighbor it emits two messages:
     * <ul>
     *   <li>a {@code labels} message with this vertex's current label scores;</li>
     *   <li>an {@code edge_info} message with the edge weight and this vertex's
     *       continuation probability.</li>
     * </ul>
     */
    public static class MADHadoopMap extends MapReduceBase implements Mapper<LongWritable, Text, Text, Text> {

        @Override
        public void map(LongWritable key, Text value, OutputCollector<Text, Text> output, Reporter reporter)
                throws IOException {
            String line = value.toString();

            // id gold_label injected_labels estimated_labels neighbors rw_probabilities
            String[] fields = line.split(_kDelim);
            String vertexId = fields[0];
            TObjectDoubleHashMap neighbors = CollectionUtil.String2Map(fields[4]);
            TObjectDoubleHashMap rwProbabilities = CollectionUtil.String2Map(fields[5]);

            // A seed node (non-empty injected labels) that has no estimated labels
            // yet advertises its seed labels as its current estimate. This is
            // expected mainly in the first iteration.
            boolean isSeedNode = fields[2].length() > 0;
            String currentScores = (isSeedNode && fields[3].length() == 0) ? fields[2] : fields[3];

            // TODO(partha): move messages to ProtocolBuffers

            // Message to self: the original record. The reducer rebuilds the vertex
            // from it and reads the injected labels and rw probabilities.
            output.collect(new Text(vertexId), new Text("labels" + _kDelim + line));

            // This vertex's continuation probability is the same for every edge.
            double contProb = rwProbabilities.get(Constants._kContProb);

            TObjectDoubleIterator neighIterator = neighbors.iterator();
            while (neighIterator.hasNext()) {
                neighIterator.advance();
                Text neighbor = new Text((String) neighIterator.key());

                // (neighbor, labels DELIM curr_node DELIM curr_node_label_scores)
                output.collect(neighbor,
                        new Text("labels" + _kDelim + vertexId + _kDelim + currentScores));

                // (neighbor, edge_info DELIM curr_node DELIM edge_weight DELIM curr_node_cont_prob)
                output.collect(neighbor,
                        new Text("edge_info" + _kDelim + vertexId + _kDelim
                                + neighIterator.value() + _kDelim + contProb));
            }
        }
    }

    /**
     * Recomputes one vertex's estimated label distribution from the messages
     * produced by {@link MADHadoopMap}. The new scores are built from three parts:
     * <ul>
     *   <li>{@code mu1 * p_inj *} injected labels;</li>
     *   <li>{@code mu2 *} the edge-weighted neighbor label scores;</li>
     *   <li>{@code mu3 * p_term *} the dummy label distribution.</li>
     * </ul>
     * The scores are then optionally pruned to the top K and divided by
     * {@link #GetNormalizationConstant}. Finally the vertex record is written back
     * in the input layout.
     */
    public static class MADHadoopReduce extends MapReduceBase implements Reducer<Text, Text, Text, Text> {

        /**
         * Number of fields in a complete self message:
         * labels, id, gold, injected, estimated, neighbors, rw_probabilities.
         */
        private static final int _kMinSelfMessageFields = 7;

        // Hyperparameters, read from the job configuration in configure().
        private double mu1;
        private double mu2;
        private double mu3;
        // Integer.MAX_VALUE (the driver's default) disables top-K pruning.
        private int keepTopKLabels;

        @Override
        public void configure(JobConf conf) {
            mu1 = Double.parseDouble(conf.get("mu1"));
            mu2 = Double.parseDouble(conf.get("mu2"));
            mu3 = Double.parseDouble(conf.get("mu3"));
            keepTopKLabels = Integer.parseInt(conf.get("keepTopKLabels"));
        }

        @Override
        public void reduce(Text key, Iterator<Text> values, OutputCollector<Text, Text> output, Reporter reporter)
                throws IOException {
            String vertexId = key.toString();

            // New label scores for the current vertex.
            TObjectDoubleHashMap newEstimatedScores = new TObjectDoubleHashMap();

            // Filled from the self message, which every vertex must receive.
            boolean isSelfMessageFound = false;
            String[] vertexFields = null;
            TObjectDoubleHashMap neighbors = null;
            TObjectDoubleHashMap randWalkProbs = null;

            // Per-neighbor data gathered from "labels" and "edge_info" messages.
            HashMap<String, String> neighScores = new HashMap<String, String>();
            TObjectDoubleHashMap incomingEdgeWeights = new TObjectDoubleHashMap();
            TObjectDoubleHashMap neighborContProb = new TObjectDoubleHashMap();

            while (values.hasNext()) {
                String val = values.next().toString();
                String[] fields = val.split(_kDelim);

                // The first field is the message type.
                String msgType = fields[0];

                if (msgType.equals("labels")) {
                    if (isSelfMessage(vertexId, fields)) {
                        // labels id gold_label injected_labels estimated_labels neighbors rw_probabilities
                        isSelfMessageFound = true;
                        vertexFields = fields;

                        TObjectDoubleHashMap injLabels = CollectionUtil.String2Map(fields[3]);
                        neighbors = CollectionUtil.String2Map(neighbors, fields[5]);
                        randWalkProbs = CollectionUtil.String2Map(fields[6]);

                        if (injLabels.size() > 0) {
                            // Add the injected labels to the estimated scores.
                            ProbUtil.AddScores(newEstimatedScores, mu1 * randWalkProbs.get(Constants._kInjProb),
                                    injLabels);
                        }
                    } else if (fields.length > 2) {
                        // labels neigh_vertex neigh_label_scores
                        // If the third field is missing, the neighbor has no label
                        // scores yet. split() drops that trailing empty field.
                        neighScores.put(fields[1], fields[2]);
                    }
                } else if (msgType.equals("edge_info")) {
                    // edge_info neigh_vertex incoming_edge_weight cont_prob
                    String neighId = fields[1];

                    // Keep only the first value seen for each neighbor.
                    if (!incomingEdgeWeights.containsKey(neighId)) {
                        incomingEdgeWeights.put(neighId, Double.parseDouble(fields[2]));
                    }
                    if (!neighborContProb.containsKey(neighId)) {
                        neighborContProb.put(neighId, Double.parseDouble(fields[3]));
                    }
                } else {
                    throw new RuntimeException("Invalid message: " + val);
                }
            }

            // Without the self message the vertex cannot be rebuilt.
            if (!isSelfMessageFound) {
                throw new RuntimeException("Self message not received for node " + vertexId);
            }

            // Combine the neighbors' label distributions into one distribution.
            // Each neighbor j is weighted by (p_i^cont * W_ij + p_j^cont * W_ji).
            double contProb = randWalkProbs.get(Constants._kContProb);
            TObjectDoubleHashMap weightedNeigLablDist = new TObjectDoubleHashMap();
            for (String neighName : neighScores.keySet()) {
                double mult = contProb * neighbors.get(neighName)
                        + neighborContProb.get(neighName) * incomingEdgeWeights.get(neighName);

                ProbUtil.AddScores(weightedNeigLablDist, mu2 * mult,
                        CollectionUtil.String2Map(neighScores.get(neighName)));
            }

            // Add the combined neighbor distribution to the current node's estimate.
            ProbUtil.AddScores(newEstimatedScores, 1.0, weightedNeigLablDist);

            // Add the dummy label scores.
            ProbUtil.AddScores(newEstimatedScores, mu3 * randWalkProbs.get(Constants._kTermProb),
                    Constants.GetDummyLabelDist());

            if (keepTopKLabels < Integer.MAX_VALUE) {
                ProbUtil.KeepTopScoringKeys(newEstimatedScores, keepTopKLabels);
            }

            ProbUtil.DivScores(newEstimatedScores, GetNormalizationConstant(neighbors, randWalkProbs,
                    incomingEdgeWeights, neighborContProb, mu1, mu2, mu3));

            // Rebuild the vertex record with the new estimated scores, so the next
            // iteration's mapper can read this output. The first two fields of the
            // self message (message type and vertex id) are dropped because the
            // id is emitted as the output key.
            String[] newVertexFields = new String[vertexFields.length - 2];
            System.arraycopy(vertexFields, 2, newVertexFields, 0, newVertexFields.length);
            newVertexFields[2] = CollectionUtil.Map2String(newEstimatedScores);

            output.collect(key, new Text(CollectionUtil.Join(newVertexFields, _kDelim)));
        }

        /**
         * Returns whether a "labels" message is the vertex's message to itself.
         * Matching the sender id is not enough: the message must also carry the
         * full vertex record. A vertex with a self-loop also receives a short
         * neighbor "labels" message from itself. That message must be treated as
         * a neighbor message; reading fields 3..6 from it would throw
         * ArrayIndexOutOfBoundsException.
         */
        private static boolean isSelfMessage(String vertexId, String[] fields) {
            return fields.length >= _kMinSelfMessageFields && vertexId.equals(fields[1]);
        }

        /**
         * Computes the normalization constant for a vertex i:
         * {@code mu1 * p_i^inj + mu2 * sum_j (p_i^cont * W_ij + p_j^cont * W_ji) + mu3}.
         * The sum runs over the entries of {@code neighbors}. No 0.5 factor is
         * applied to the neighbor term.
         *
         * @param neighbors           neighbor id to outgoing edge weight W_ij
         * @param randWalkProbs       this vertex's random-walk probabilities
         *                            (injection and continuation entries are read)
         * @param incomingEdgeWeights neighbor id to incoming edge weight W_ji
         * @param neighborContProb    neighbor id to that neighbor's continuation probability
         * @param mu1                 weight of the injection term
         * @param mu2                 weight of the neighbor term
         * @param mu3                 constant added for the dummy-label term
         * @return the normalization constant
         */
        public double GetNormalizationConstant(TObjectDoubleHashMap neighbors, TObjectDoubleHashMap randWalkProbs,
                TObjectDoubleHashMap incomingEdgeWeights, TObjectDoubleHashMap neighborContProb, double mu1,
                double mu2, double mu3) {
            double contProb = randWalkProbs.get(Constants._kContProb);
            double totalNeighWeight = 0;

            TObjectDoubleIterator nIter = neighbors.iterator();
            while (nIter.hasNext()) {
                nIter.advance();
                String neighName = (String) nIter.key();

                // Kept as two separate additions to preserve floating-point summation order.
                totalNeighWeight += contProb * nIter.value();
                totalNeighWeight += neighborContProb.get(neighName) * incomingEdgeWeights.get(neighName);
            }

            return mu1 * randWalkProbs.get(Constants._kInjProb) + mu2 * totalNeighWeight + mu3;
        }
    }

    /**
     * Driver: runs {@code iters} MapReduce jobs in sequence.
     *
     * <p>Configuration keys read:
     * <ul>
     *   <li>required: {@code hdfs_input_pattern}, {@code hdfs_output_base},
     *       {@code iters}, {@code mu1}, {@code mu2}, {@code mu3};</li>
     *   <li>optional: {@code num_reducers} (default 10) and
     *       {@code keep_top_k_labels} (default {@code Integer.MAX_VALUE}, meaning
     *       no pruning).</li>
     * </ul>
     *
     * <p>Iteration {@code k} writes to {@code <hdfs_output_base>_iter_<k>}.
     * Every iteration after the first reads all files of the previous iteration's
     * output directory.
     */
    public static void main(String[] args) throws Exception {
        Hashtable config = ConfigReader.read_config(args);

        String baseInputFilePat = Defaults.GetValueOrDie(config, "hdfs_input_pattern");
        String baseOutputFilePat = Defaults.GetValueOrDie(config, "hdfs_output_base");
        int numIterations = Integer.parseInt(Defaults.GetValueOrDie(config, "iters"));
        int numReducers = Defaults.GetValueOrDefault((String) config.get("num_reducers"), 10);

        String currInputFilePat = baseInputFilePat;
        String currOutputFilePat = "";
        for (int iter = 1; iter <= numIterations; ++iter) {
            JobConf conf = new JobConf(MADHadoop.class);
            conf.setJobName("mad_hadoop");

            conf.setOutputKeyClass(Text.class);
            conf.setOutputValueClass(Text.class);

            conf.setMapperClass(MADHadoopMap.class);
            conf.setReducerClass(MADHadoopReduce.class);
            conf.setNumReduceTasks(numReducers);

            conf.setInputFormat(TextInputFormat.class);
            conf.setOutputFormat(TextOutputFormat.class);

            // Hyperparameters, read by MADHadoopReduce.configure().
            conf.set("mu1", Defaults.GetValueOrDie(config, "mu1"));
            conf.set("mu2", Defaults.GetValueOrDie(config, "mu2"));
            conf.set("mu3", Defaults.GetValueOrDie(config, "mu3"));
            conf.set("keepTopKLabels", Defaults.GetValueOrDefault((String) config.get("keep_top_k_labels"),
                    Integer.toString(Integer.MAX_VALUE)));

            if (iter > 1) {
                // The previous iteration's output is this iteration's input.
                currInputFilePat = currOutputFilePat + "/*";
            }
            FileInputFormat.setInputPaths(conf, new Path(currInputFilePat));

            currOutputFilePat = baseOutputFilePat + "_iter_" + iter;
            FileOutputFormat.setOutputPath(conf, new Path(currOutputFilePat));

            JobClient.runJob(conf);
        }
    }
}