Java tutorial
/* * Copyright 2014 Indiana University * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package edu.iu.kmeans; import it.unimi.dsi.fastutil.ints.Int2ObjectAVLTreeMap; import java.io.BufferedReader; import java.io.BufferedWriter; import java.io.IOException; import java.io.InputStreamReader; import java.io.OutputStreamWriter; import java.text.SimpleDateFormat; import java.util.ArrayList; import java.util.Arrays; import java.util.Calendar; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; import java.util.Map.Entry; import java.util.concurrent.atomic.AtomicLong; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.FSDataInputStream; import org.apache.hadoop.fs.FSDataOutputStream; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; import org.apache.hadoop.mapred.CollectiveMapper; import edu.iu.harp.arrpar.ArrCombiner; import edu.iu.harp.arrpar.ArrPartition; import edu.iu.harp.arrpar.ArrTable; import edu.iu.harp.arrpar.Double2DArrAvg; import edu.iu.harp.arrpar.DoubleArrPlus; import edu.iu.harp.comm.data.Array; import edu.iu.harp.comm.data.DoubleArray; import edu.iu.harp.comm.resource.ResourcePool; public class KMeansCollectiveMapper extends CollectiveMapper<String, String, Object, Object> { private int jobID; private int numMappers; private int vectorSize; private int cenVecSize; private int numCentroids; private int numCenPartitions; private int pointsPerFile; private int iteration; private int numMaxThreads; /** * Mapper configuration. */ @Override protected void setup(Context context) throws IOException, InterruptedException { LOG.info("start setup" + new SimpleDateFormat("yyyyMMdd_HHmmss").format(Calendar.getInstance().getTime())); long startTime = System.currentTimeMillis(); Configuration configuration = context.getConfiguration(); jobID = configuration.getInt(KMeansConstants.JOB_ID, 0); numMappers = configuration.getInt(KMeansConstants.NUM_MAPPERS, 10); numCentroids = configuration.getInt(KMeansConstants.NUM_CENTROIDS, 20); numCenPartitions = numMappers; pointsPerFile = configuration.getInt(KMeansConstants.POINTS_PER_FILE, 20); vectorSize = configuration.getInt(KMeansConstants.VECTOR_SIZE, 20); cenVecSize = vectorSize + 1; iteration = configuration.getInt(KMeansConstants.ITERATION_COUNT, 1); numMaxThreads = 32; long endTime = System.currentTimeMillis(); LOG.info("config (ms) :" + (endTime - startTime)); } protected void mapCollective(KeyValReader reader, Context context) throws IOException, InterruptedException { LOG.info("Start collective mapper."); long startTime = System.currentTimeMillis(); List<String> pointFiles = new ArrayList<String>(); while (reader.nextKeyValue()) { String key = reader.getCurrentKey(); String value = reader.getCurrentValue(); LOG.info("Key: " + key + ", Value: " + value); pointFiles.add(value); } Configuration conf = context.getConfiguration(); runKmeans(pointFiles, conf, context); LOG.info("Total iterations in master view: " + (System.currentTimeMillis() - startTime)); } private void runKmeans(List<String> fileNames, Configuration conf, Context context) throws IOException { // ----------------------------------------------------------------------------- // Load centroids long startTime = System.currentTimeMillis(); ArrTable<DoubleArray, DoubleArrPlus> cenTable = new ArrTable<DoubleArray, DoubleArrPlus>(0, DoubleArray.class, DoubleArrPlus.class); if (this.isMaster()) { Map<Integer, DoubleArray> cenParMap = createCenParMap(numCentroids, numCenPartitions, cenVecSize, this.getResourcePool()); loadCentroids(cenParMap, vectorSize, conf.get(KMeansConstants.CFILE), conf); addPartitionMapToTable(cenParMap, cenTable); } long endTime = System.currentTimeMillis(); LOG.info("Load centroids (ms): " + (endTime - startTime)); // ------------------------------------------------------------------------------ // Bcast centroids startTime = System.currentTimeMillis(); bcastCentroids(cenTable, numCenPartitions); endTime = System.currentTimeMillis(); LOG.info("Bcast centroids (ms): " + (endTime - startTime)); // --------------------------------------------------------------------------- // Load data points // Use less threads to load data startTime = System.currentTimeMillis(); List<DoubleArray> pointPartitions = doTasks(fileNames, "load-data-points", new PointLoadTask(pointsPerFile, vectorSize, conf), numMaxThreads); endTime = System.currentTimeMillis(); LOG.info("File read (ms): " + (endTime - startTime)); // ------------------------------------------------------------------------------------- // For iterations ArrPartition<DoubleArray>[] cenPartitions = null; AtomicLong appKernelTime = new AtomicLong(0); for (int i = 0; i < iteration; i++) { LOG.info("Iteration: " + i); // ----------------------------------------------------------------- // Calculate centroids dot prodcut c^2/2 cenPartitions = cenTable.getPartitions(); startTime = System.currentTimeMillis(); doTasks(cenPartitions, "centroids-dot-product", new CenDotProductTask(cenVecSize), numMaxThreads); endTime = System.currentTimeMillis(); LOG.info("Dot-product centroids (ms): " + (endTime - startTime)); // Calculate new centroids startTime = System.currentTimeMillis(); List<Map<Integer, DoubleArray>> localCenParMapList = doTasks( pointPartitions, "centroids-calc", new CenCalcTask(cenPartitions, numCentroids, numCenPartitions, vectorSize, this.getResourcePool(), context, appKernelTime), numMaxThreads); Set<Map<Integer, DoubleArray>> localCenParMapSet = new HashSet<Map<Integer, DoubleArray>>( localCenParMapList); localCenParMapList = null; endTime = System.currentTimeMillis(); LOG.info("Calculate centroids (ms): " + (endTime - startTime)); // ----------------------------------------------------------------- // Merge localCenParMaps to cenPartitions // localCenParMaps uses thread local objects // we convert list to set to avoid duplicates startTime = System.currentTimeMillis(); LOG.info("Local centroid data maps: " + localCenParMapSet.size()); doTasks(cenPartitions, "centroids-merge", new CenMergeTask(localCenParMapSet), numMaxThreads); localCenParMapSet = null; endTime = System.currentTimeMillis(); LOG.info("Merge centroids (ms): " + (endTime - startTime)); // ----------------------------------------------------------------- // Allreduce context.progress(); ArrTable<DoubleArray, DoubleArrPlus> newCenTable = new ArrTable<DoubleArray, DoubleArrPlus>(i, DoubleArray.class, DoubleArrPlus.class); // Old table is altered during allreduce, ignore it. // Allreduce: regroup-combine-aggregate(reduce)-allgather. // Old table is at the state after regroup-combine-aggregate, // but new table is after allgather. try { startTime = System.currentTimeMillis(); allreduce(cenTable, newCenTable, new Double2DArrAvg(cenVecSize)); endTime = System.currentTimeMillis(); LOG.info("Allreduce time (ms): " + (endTime - startTime)); } catch (Exception e) { LOG.error("Fail to do allreduce.", e); throw new IOException(e); } cenTable = newCenTable; newCenTable = null; context.progress(); logMemUsage(); logGCTime(); } // Write out new table if (this.isMaster()) { // LOG.info("App kernel time (nano seconds): " + appKernelTime.get()); LOG.info("Start to write out centroids."); startTime = System.currentTimeMillis(); storeCentroids(conf, cenTable, cenVecSize, jobID + 1); endTime = System.currentTimeMillis(); LOG.info("Store centroids time (ms): " + (endTime - startTime)); } // Clean all the references cenPartitions = cenTable.getPartitions(); releasePartitions(cenPartitions); cenPartitions = null; } /** * Less parameters, better. * * @param numCentroids * @param numCenPartitions * @param cenVecSize * @param resourcePool * @return */ static Map<Integer, DoubleArray> createCenParMap(int numCentroids, int numCenPartitions, int cenVecSize, ResourcePool resourcePool) { int cenParSize = numCentroids / numCenPartitions; int cenRest = numCentroids % numCenPartitions; Map<Integer, DoubleArray> cenParMap = new Int2ObjectAVLTreeMap<DoubleArray>(); double[] doubles = null; int size = 0; for (int i = 0; i < numCenPartitions; i++) { DoubleArray doubleArray = new DoubleArray(); if (cenRest > 0) { size = (cenParSize + 1) * cenVecSize; doubles = resourcePool.getDoubleArrayPool().getArray(size); doubleArray.setArray(doubles); doubleArray.setSize(size); cenRest--; } else if (cenParSize > 0) { size = cenParSize * cenVecSize; doubles = resourcePool.getDoubleArrayPool().getArray(size); doubleArray.setArray(doubles); doubleArray.setSize(size); } else { // cenParSize <= 0 break; } Arrays.fill(doubles, 0); cenParMap.put(i, doubleArray); } return cenParMap; } /** * Fill data from centroid file to cenDataMap * * @param cenDataMap * @param vectorSize * @param cFileName * @param configuration * @throws IOException */ private static void loadCentroids(Map<Integer, DoubleArray> cenDataMap, int vectorSize, String cFileName, Configuration configuration) throws IOException { Path cPath = new Path(cFileName); FileSystem fs = FileSystem.get(configuration); FSDataInputStream in = fs.open(cPath); BufferedReader br = new BufferedReader(new InputStreamReader(in)); double[] cData; int start; int size; int collen = vectorSize + 1; String[] curLine = null; int curPos = 0; for (DoubleArray array : cenDataMap.values()) { cData = array.getArray(); start = array.getStart(); size = array.getSize(); for (int i = start; i < (start + size); i++) { // Don't set the first element in each row if (i % collen != 0) { // cData[i] = in.readDouble(); if (curLine == null || curPos == curLine.length - 1) { curLine = br.readLine().split(" "); curPos = 0; } else { curPos++; } cData[i] = Double.parseDouble(curLine[curPos]); } } } br.close(); } private <A extends Array<?>, C extends ArrCombiner<A>> void addPartitionMapToTable(Map<Integer, A> map, ArrTable<A, C> table) throws IOException { for (Entry<Integer, A> entry : map.entrySet()) { try { table.addPartition(new ArrPartition<A>(entry.getValue(), entry.getKey())); } catch (Exception e) { LOG.error("Fail to add partitions", e); throw new IOException(e); } } } /** * Broadcast centroids data in partitions * * @param table * @param numPartitions * @throws IOException */ private <T, A extends Array<T>, C extends ArrCombiner<A>> void bcastCentroids(ArrTable<A, C> table, int numTotalPartitions) throws IOException { boolean success = true; try { success = arrTableBcastTotalKnown(table, numTotalPartitions); } catch (Exception e) { LOG.error("Fail to bcast.", e); } if (!success) { throw new IOException("Fail to bcast"); } } private void releasePartitions(ArrPartition<DoubleArray>[] partitions) { for (int i = 0; i < partitions.length; i++) { this.getResourcePool().getDoubleArrayPool().releaseArrayInUse(partitions[i].getArray().getArray()); } } private static void storeCentroids(Configuration configuration, ArrTable<DoubleArray, DoubleArrPlus> cenTable, int cenVecSize, int jobID) throws IOException { String cFile = configuration.get(KMeansConstants.CFILE); Path cPath = new Path(cFile.substring(0, cFile.lastIndexOf("_") + 1) + jobID); LOG.info("centroids path: " + cPath.toString()); FileSystem fs = FileSystem.get(configuration); fs.delete(cPath, true); FSDataOutputStream out = fs.create(cPath); BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(out)); ArrPartition<DoubleArray> partitions[] = cenTable.getPartitions(); int linePos = 0; for (ArrPartition<DoubleArray> partition : partitions) { for (int i = 0; i < partition.getArray().getSize(); i++) { linePos = i % cenVecSize; if (linePos == (cenVecSize - 1)) { bw.write(partition.getArray().getArray()[i] + "\n"); } else if (linePos > 0) { // Every row with vectorSize + 1 length, the first one is a count, // ignore it in output bw.write(partition.getArray().getArray()[i] + " "); } } } bw.flush(); bw.close(); // out.flush(); // out.sync(); // out.close(); } }