Java tutorial
package com.tetsuyaodaka.hadoop.math.matrix; import java.io.DataInput; import java.io.DataOutput; import java.io.IOException; import java.math.BigDecimal; import java.util.ArrayList; import java.util.Date; import java.util.HashMap; import java.util.List; import java.util.Map; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.Path; import org.apache.hadoop.io.DoubleWritable; import org.apache.hadoop.io.LongWritable; import org.apache.hadoop.io.Text; import org.apache.hadoop.io.WritableComparable; import org.apache.hadoop.mapreduce.Job; import org.apache.hadoop.mapreduce.Reducer; import org.apache.hadoop.mapreduce.Mapper; import org.apache.hadoop.mapreduce.lib.input.MultipleInputs; import org.apache.hadoop.mapreduce.lib.input.TextInputFormat; import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat; /** * Matrix Multiplication on Hadoop Map Reduce * * author : tetsuya.odaka@gmail.com * tested on Hadoop1.2 * * Split the Large Scale Matrix to SubMatrices. * Split size (Number Of Rows or Columns) can be specified by arguments. * * This should be decided according to your resources. * Partitioner and Conditioner are not implemented here. * Can calculate real numbers (format double) and be expected. * * This program is distributed under ASF2.0 LICENSE. * */ public class MatrixMult { /* * IndexPair Class * *reduce??MatrixA???MatrixB??????? *customized key for reduce function consists of row BlockNum of MatrixA, MatrixB, and number of elements. * */ public static class IndexPair implements WritableComparable<MatrixMult.IndexPair> { public int index1; public int index2; public IndexPair() { } public IndexPair(int index1, int index2) { this.index1 = index1; this.index2 = index2; } public void write(DataOutput out) throws IOException { out.writeInt(index1); out.writeInt(index2); } public void readFields(DataInput in) throws IOException { index1 = in.readInt(); index2 = in.readInt(); } public int compareTo(MatrixMult.IndexPair o) { if (this.index1 < o.index1) { return -1; } else if (this.index1 > o.index1) { return +1; } if (this.index2 < o.index2) { return -1; } else if (this.index2 > o.index2) { return +1; } return 0; } /* * hasHash() is used by HashPartitionar. */ public int hashCode() { int ib = this.index1; int jb = this.index2; int num = ib * Integer.MAX_VALUE + jb; int hash = new Integer(num).hashCode(); return Math.abs(hash); } } /* * MapA Class * *Matrix A?????? * read MatrixA and decompose it to blocks * */ public static class MapA extends Mapper<LongWritable, Text, MatrixMult.IndexPair, Text> { @Override protected void map(LongWritable key, Text value, Context context) throws IOException, InterruptedException { String strArr[] = value.toString().split("\t"); int i = Integer.parseInt(strArr[0]); String v = strArr[1]; int m = 0; // retrieve from configuration int IB = Integer.parseInt(context.getConfiguration().get("IB")); int N = Integer.parseInt(context.getConfiguration().get("N")); if (i % IB == 0) { m = i / IB; } else { m = i / IB + 1; } for (int j = 1; j < (N + 1); j++) { context.write(new MatrixMult.IndexPair(m, j), new Text("0" + "," + i + "," + v)); } } } /* * MapB Class * *Matrix B'?????? * read MatrixB and decompose it to blocks * */ public static class MapB extends Mapper<LongWritable, Text, MatrixMult.IndexPair, Text> { @Override protected void map(LongWritable key, Text value, Context context) throws IOException, InterruptedException { String strArr[] = value.toString().split("\t"); int k = Integer.parseInt(strArr[0]); String v = strArr[1]; int n = 0; // retrieve from configuration int KB = Integer.parseInt(context.getConfiguration().get("KB")); int M = Integer.parseInt(context.getConfiguration().get("M")); if (k % KB == 0) { n = k / KB; } else { n = k / KB + 1; } for (int j = 1; j < (M + 1); j++) { context.write(new MatrixMult.IndexPair(j, n), new Text("1" + "," + k + "," + v)); } } } /* * Reduce Class * */ public static class Reduce extends Reducer<MatrixMult.IndexPair, Text, Text, DoubleWritable> { @Override protected void reduce(MatrixMult.IndexPair key, Iterable<Text> values, Context context) throws IOException, InterruptedException { List<RowContents> aList = new ArrayList<RowContents>(); List<RowContents> bList = new ArrayList<RowContents>(); Map<String, List<RowContents>> cMap = new HashMap<String, List<RowContents>>(); cMap.put("A", aList); cMap.put("B", bList); for (Text value : values) { String strVal = value.toString(); String mtx; String sRow; String[] strArray = strVal.split(","); if (Integer.parseInt(strArray[0]) == 0) { mtx = "A"; } else { mtx = "B"; } sRow = strArray[1] + "," + strArray[2]; cMap.get(mtx).add(new RowContents(sRow)); } for (RowContents ra : cMap.get("A")) { for (RowContents rb : cMap.get("B")) { int indexA = ra.index; int indexB = rb.index; double sum = 0; for (int i = 0; i < ra.lstRow.size(); i++) { sum += ra.lstRow.get(i) * rb.lstRow.get(i); } BigDecimal bd = new BigDecimal(sum); BigDecimal r = bd.setScale(2, BigDecimal.ROUND_HALF_UP); context.write(new Text(indexA + " " + indexB + " "), new DoubleWritable(r.doubleValue())); } } } public class RowContents { public String strRow; public int index; // means row index public List<Double> lstRow; // list of elements of row. public RowContents() { } public RowContents(String strRow) { this.strRow = strRow; this.lstRow = new ArrayList<Double>(); this.calculate(); } public void calculate() { String[] strArr = this.strRow.split(","); this.index = Integer.parseInt(strArr[0]); String[] aArr = strArr[1].split(" "); for (int i = 0; i < aArr.length; i++) { this.lstRow.add(Double.parseDouble(aArr[i])); } return; } } } public static void main(String[] args) throws IOException, InterruptedException, ClassNotFoundException { Date startProc = new Date(System.currentTimeMillis()); System.out.println("process started at " + startProc); Configuration conf = new Configuration(); int I = Integer.parseInt(args[3]); // Num of Row of MatrixA int K = Integer.parseInt(args[4]); // Num of Row of MatrixB' int IB = Integer.parseInt(args[5]); // RowBlock Size of MatrixA int KB = Integer.parseInt(args[6]); // RowBlock Size of MatrixB' int M = 0; if (I % IB == 0) { M = I / IB; } else { M = I / IB + 1; } int N = 0; if (K % KB == 0) { N = K / KB; } else { N = K / KB + 1; } conf.set("I", args[3]); // Num of Row of MatrixA conf.set("K", args[4]); // Num of Row of MatrixB' conf.set("IB", args[5]); // RowBlock Size of MatrixA conf.set("KB", args[6]); // RowBlock Size of MatrixB' conf.set("M", new Integer(M).toString()); conf.set("N", new Integer(N).toString()); Job job = new Job(conf, "MatrixMultiplication"); job.setJarByClass(MatrixMult.class); job.setReducerClass(Reduce.class); job.setMapOutputKeyClass(MatrixMult.IndexPair.class); job.setMapOutputValueClass(Text.class); job.setOutputKeyClass(Text.class); job.setOutputValueClass(Text.class); // Mapper????? MultipleInputs.addInputPath(job, new Path(args[0]), TextInputFormat.class, MapA.class); // matrixA MultipleInputs.addInputPath(job, new Path(args[1]), TextInputFormat.class, MapB.class); // matrixB FileOutputFormat.setOutputPath(job, new Path(args[2])); // output path System.out.println("num of MatrixA RowBlock(M) is " + M); System.out.println("num of MatrixB RowBlock(N) is " + N); boolean success = job.waitForCompletion(true); Date endProc = new Date(System.currentTimeMillis()); System.out.println("process ended at " + endProc); System.out.println(success); } }