mlbench.bayes.train.IndexInstances.java Source code

Java tutorial

Introduction

Here is the source code for mlbench.bayes.train.IndexInstances.java

Source

/**
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License. You may obtain a copy of
 * the License at
 * <p>
 * http://www.apache.org/licenses/LICENSE-2.0
 * <p>
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */
package mlbench.bayes.train;

import mlbench.bayes.BayesUtils;
import mpi.MPIException;
import mpid.core.MPI_D;
import mpid.core.MPI_D_Combiner;
import mpid.core.MPI_D_Exception;
import mpid.core.util.MPI_D_Constants;
import mpid.util.DataMPIUtil;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.WritableComparable;
import org.apache.hadoop.mapred.FileSplit;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.RecordWriter;
import org.apache.hadoop.mapred.SequenceFileOutputFormat;
import org.apache.hadoop.mapred.SequenceFileRecordReader;
import org.apache.hadoop.mapreduce.TaskAttemptContext;
import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter;
import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl;
import org.apache.mahout.common.HadoopUtil;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.iterator.sequencefile.PathFilters;
import org.apache.mahout.common.iterator.sequencefile.PathType;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.function.Functions;
import org.apache.mahout.math.map.OpenObjectIntHashMap;

import java.io.IOException;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.regex.Pattern;

public class IndexInstances {
    // Job configuration; assigned in initConf() as a JobConf built from confPath.
    private static Configuration config;
    // Configuration key under which main() stores and reads the output directory.
    private static final String MAPRED_OUTPUT_DIR = "mapred.output.dir";
    // First command-line argument; handed to the JobConf constructor in initConf().
    private static String confPath = null;
    // Values taken from the command line by parseArgs(). outDir has
    // "/" + BayesUtils.SUMMED_OBSERVATIONS appended to the user-supplied value.
    // maxUsedMemPercent, partSize, outFileNum and spillPercent stay null unless supplied,
    // and initConf() only forwards the non-null ones to the MPI_D settings map.
    private static String inDir = null, outDir = null, labDir = null, maxUsedMemPercent = null, partSize = null,
            outFileNum = null, spillPercent = null;
    // Path forms of inDir and labDir, created in parseArgs().
    private static Path inPath = null, labPath = null;
    // Rank within MPI_D.COMM_BIPARTITE_O; only assigned when that communicator is non-null, otherwise stays -1.
    private static int rank = -1;

    /**
     * Combiner that folds every vector received for one key into a single element-wise sum.
     * <p>
     * The first vector handed in serves as the accumulator and is modified in place by the
     * subsequent additions. The returned iterator is backed by a list that is reused on every
     * call, so a caller must consume it before invoking {@code combine} again.
     */
    public static class IndexInstancesCombiner implements MPI_D_Combiner<WritableComparable<?>, VectorWritable> {
        private final VectorWritable combined = new VectorWritable();
        private final List<VectorWritable> output = new ArrayList<VectorWritable>();

        /**
         * Sums all vectors in {@code values}.
         *
         * @param arg0   the key the values belong to (not inspected)
         * @param values the vectors to add together
         * @return an iterator over exactly one {@link VectorWritable} wrapping the sum; the wrapped
         *         vector is {@code null} when {@code values} is empty
         */
        @Override
        public Iterator<VectorWritable> combine(WritableComparable<?> arg0, Iterator<VectorWritable> values)
                throws MPI_D_Exception {
            Vector sum = null;
            while (values.hasNext()) {
                Vector current = values.next().get();
                if (sum != null) {
                    sum.assign(current, Functions.PLUS);
                } else {
                    // First vector becomes the accumulator.
                    sum = current;
                }
            }
            combined.set(sum);
            output.clear();
            output.add(combined);
            return output.iterator();
        }

        /** Nothing to release. */
        public void close() {
        }

    }

    /**
     * Program entry point.
     * <p>
     * Parses the command line, initialises MPI_D in Common mode and then runs one of two roles
     * depending on which bipartite communicator is available to this process:
     * <ul>
     * <li>O communicator: reads the input splits and sends (label index, instance) pairs;</li>
     * <li>A communicator: receives the pairs, sums the vectors per key and writes the sums to a
     * sequence file under the output directory.</li>
     * </ul>
     *
     * @param args command-line arguments, see {@code parseArgs}
     * @throws MPI_D_Exception if an MPI_D operation fails
     * @throws IOException     if reading the input or writing the output fails
     * @throws MPIException    if the underlying MPI layer fails
     */
    @SuppressWarnings({ "deprecation" })
    public static void main(String[] args) throws MPI_D_Exception, IOException, MPIException {
        parseArgs(args);
        HashMap<String, String> conf = new HashMap<String, String>();
        initConf(conf);
        MPI_D.Init(args, MPI_D.Mode.Common, conf);
        if (MPI_D.COMM_BIPARTITE_O != null) {
            runOTask();
        } else if (MPI_D.COMM_BIPARTITE_A != null) {
            runATask();
        }

        MPI_D.Finalize();
    }

    /**
     * O-side work: builds the label index (rank 0 only), then sends every input instance whose
     * label is present in the index, keyed by the label's index.
     */
    @SuppressWarnings({ "deprecation" })
    private static void runOTask() throws MPI_D_Exception, IOException, MPIException {
        rank = MPI_D.Comm_rank(MPI_D.COMM_BIPARTITE_O);

        if (rank == 0) {
            System.out.println(IndexInstances.class.getSimpleName() + " O start.");
            createLabelIndex(labPath);
        }

        HadoopUtil.cacheFiles(labPath, config);

        // Wait until rank 0 has finished writing the label index before anyone reads it.
        MPI_D.COMM_BIPARTITE_O.Barrier();

        OpenObjectIntHashMap<String> labelIndex = BayesUtils.readIndexFromCache(config);

        FileSplit[] inputs = DataMPIUtil.HDFSDataLocalLocator.getTaskInputs(MPI_D.COMM_BIPARTITE_O,
                (JobConf) config, inDir, rank);
        for (FileSplit fsplit : inputs) {
            SequenceFileRecordReader<Text, VectorWritable> kvrr = new SequenceFileRecordReader<>(config,
                    fsplit);
            try {
                Text labelText = kvrr.createKey();
                VectorWritable instance = kvrr.createValue();
                while (kvrr.next(labelText, instance)) {
                    // The label is the second '/'-separated field of the key
                    // (presumably keys look like "/label/..." — verify against the input producer).
                    String label = SLASH.split(labelText.toString())[1];
                    if (labelIndex.containsKey(label)) {
                        MPI_D.Send(new IntWritable(labelIndex.get(label)), instance);
                    }
                }
            } finally {
                // Release the underlying file handle even if reading or sending fails.
                kvrr.close();
            }
        }
    }

    /**
     * A-side work: receives (key, vector) pairs, adds up the vectors of each run of equal keys
     * and writes one summed vector per run to this rank's output file.
     */
    @SuppressWarnings({ "deprecation" })
    private static void runATask() throws MPI_D_Exception, IOException, MPIException {
        int aRank = MPI_D.Comm_rank(MPI_D.COMM_BIPARTITE_A);
        JobConf jobConf = (JobConf) config;
        jobConf.set(MAPRED_OUTPUT_DIR, outDir);
        jobConf.set("mapred.task.id", DataMPIUtil.getHadoopTaskAttemptID().toString());
        jobConf.setOutputKeyClass(IntWritable.class);
        jobConf.setOutputValueClass(VectorWritable.class);
        TaskAttemptContext taskContext = new TaskAttemptContextImpl(jobConf,
                DataMPIUtil.getHadoopTaskAttemptID());
        SequenceFileOutputFormat<IntWritable, VectorWritable> outfile = new SequenceFileOutputFormat<>();
        FileSystem fs = FileSystem.get(jobConf);

        Path output = new Path(jobConf.get(MAPRED_OUTPUT_DIR));
        FileOutputCommitter fcommitter = new FileOutputCommitter(output, taskContext);
        RecordWriter<IntWritable, VectorWritable> outrw = null;
        try {
            fcommitter.setupJob(taskContext);
            outrw = outfile.getRecordWriter(fs, jobConf, getOutputName(aRank), null);
        } catch (IOException e) {
            e.printStackTrace();
            System.err.println("ERROR: Please set the HDFS configuration properly\n");
            System.exit(-1);
        }

        try {
            // NOTE(review): the previous key and the running sum are kept across Recv() calls, so this
            // assumes Recv() hands out distinct objects per call and delivers equal keys
            // consecutively — confirm against the MPI_D contract.
            IntWritable key = null;
            Vector vector = null;
            Object[] vals = MPI_D.Recv();
            while (vals != null) {
                IntWritable newKey = (IntWritable) vals[0];
                VectorWritable newPoint = (VectorWritable) vals[1];
                if (key != null && !key.equals(newKey)) {
                    // Key changed: flush the finished sum and start a new one.
                    outrw.write(key, new VectorWritable(vector));
                    vector = null;
                }
                if (vector == null) {
                    vector = newPoint.get();
                } else {
                    vector.assign(newPoint.get(), Functions.PLUS);
                }

                key = newKey;
                vals = MPI_D.Recv();
            }
            if (key != null) {
                // Flush the sum of the last run of keys.
                outrw.write(key, new VectorWritable(vector));
            }
        } finally {
            // Always close the writer so the output stream is not leaked when a receive or write fails.
            outrw.close(null);
        }

        if (fcommitter.needsTaskCommit(taskContext)) {
            fcommitter.commitTask(taskContext);
        }
    }

    // Pre-compiled pattern that splits a record key on '/'; element 1 of the split is used as the label.
    private static final Pattern SLASH = Pattern.compile("/");

    /**
     * Builds the output file name for a partition, e.g. {@code part-0}, {@code part-17}.
     * <p>
     * The number is rendered with plain, locale-independent digits and no grouping separators,
     * so large partition numbers never introduce characters such as ',' or '.' into the file
     * name (a comma would clash with Hadoop's comma-separated path lists).
     *
     * @param partition the partition (rank) number
     * @return {@code "part-"} followed by the decimal partition number
     */
    private static String getOutputName(int partition) {
        return "part-" + Integer.toString(partition);
    }

    /**
     * Writes the label index to {@code labPath}, built from the sequence files listed under
     * the input path.
     *
     * @param labPath destination of the label index
     * @return the value reported by {@code BayesUtils.writeLabelIndex}
     * @throws IOException if the index cannot be written
     */
    private static long createLabelIndex(Path labPath) throws IOException {
        Iterable<Pair<Text, IntWritable>> labelledEntries = new SequenceFileDirIterable<Text, IntWritable>(
                inPath, PathType.LIST, PathFilters.logsCRCFilter(), config);
        return BayesUtils.writeLabelIndex(config, labPath, labelledEntries);
    }

    /**
     * Fills the MPI_D settings map and creates the job configuration.
     * <p>
     * The key, value and combiner classes are always set. The four tuning values are only added
     * when they were supplied on the command line. Finally {@code config} is assigned a new
     * {@link JobConf} built from {@code confPath}.
     *
     * @param conf the map to populate; passed to {@code MPI_D.Init} by the caller
     */
    private static void initConf(HashMap<String, String> conf) {
        conf.put(MPI_D_Constants.ReservedKeys.KEY_CLASS, IntWritable.class.getName());
        conf.put(MPI_D_Constants.ReservedKeys.VALUE_CLASS, VectorWritable.class.getName());
        conf.put(MPI_D_Constants.ReservedKeys.COMBINER_CLASS, IndexInstancesCombiner.class.getName());

        // Optional settings as {key, value} pairs; a null value means "not supplied".
        String[][] optionalSettings = {
                { MPI_D_Constants.ReservedKeys.CommonModeKeys.MAX_MEM_USED_PERCENT, maxUsedMemPercent },
                { MPI_D_Constants.ReservedKeys.CommonModeKeys.BLOCK_SIZE, partSize },
                { MPI_D_Constants.ReservedKeys.CommonModeKeys.SEND_QUEUE_LENGTH, outFileNum },
                { MPI_D_Constants.ReservedKeys.CommonModeKeys.SPILL_PERCENT, spillPercent } };
        for (String[] setting : optionalSettings) {
            if (setting[1] != null) {
                conf.put(setting[0], setting[1]);
            }
        }

        config = new JobConf(confPath);
    }

    /**
     * Reads the command-line arguments into the static fields.
     * <p>
     * Four arguments are required: {@code <confPath> <inDir> <outDir> <label>}. Up to four more
     * may follow, in this order: maximum used memory percent, partition size, output file number
     * and spill percent. Each optional argument is read only when it is actually present, so any
     * count from 4 to 8 is accepted. Fewer than 4 or more than 8 arguments print an error and
     * terminate the JVM with exit code -1.
     * <p>
     * {@code outDir} is stored with {@code "/" + BayesUtils.SUMMED_OBSERVATIONS} appended.
     *
     * @param args the raw command-line arguments
     */
    private static void parseArgs(String[] args) {
        if (args.length < 4) {
            System.err.println("ERROR: Wrong number of parameters: " + args.length + " instead of 4.");
            System.err.println("Usage: IndexInstances <confPath> <inDir> <outDir> <label>");
            System.exit(-1);
            return;
        }
        if (args.length > 8) {
            System.err.println("ERROR: Error number of parameters.");
            System.exit(-1);
            return;
        }

        confPath = args[0];
        inDir = args[1];
        outDir = args[2] + "/" + BayesUtils.SUMMED_OBSERVATIONS;
        labDir = args[3];
        inPath = new Path(inDir);
        labPath = new Path(labDir);

        // Advanced usage: each tuning value is optional and only read when supplied.
        if (args.length > 4) {
            maxUsedMemPercent = args[4];
        }
        if (args.length > 5) {
            partSize = args[5];
        }
        if (args.length > 6) {
            outFileNum = args[6];
        }
        if (args.length > 7) {
            spillPercent = args[7];
        }
    }
}