weka.distributed.hadoop.KMeansCentroidSketchHadoopMapper.java Source code


Introduction

Here is the source code for weka.distributed.hadoop.KMeansCentroidSketchHadoopMapper.java, the Hadoop mapper that runs one iteration of the k-means|| initialization procedure in distributed Weka. Two usage sketches (driver-side configuration and reduce-side deserialization of the emitted sketches) follow the listing.

Source

/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

/*
 *    KMeansCentroidSketchHadoopMapper
 *    Copyright (C) 2014 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.distributed.hadoop;

import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.util.zip.GZIPOutputStream;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.BytesWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Mapper;

import weka.clusterers.CentroidSketch;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Utils;
import weka.distributed.CSVToARFFHeaderMapTask;
import weka.distributed.CSVToARFFHeaderReduceTask;
import weka.distributed.KMeansMapTask;
import distributed.core.DistributedJobConfig;

/**
 * Hadoop mapper for the k-means|| initialization procedure. Uses weighted
 * reservoir sampling to maintain a random selection of the instances with the
 * greatest distance from the current set of centroid candidates.
 * 
 * @author Mark Hall (mhall{[at]}pentaho{[dot]}com)
 * @version $Revision$
 */
public class KMeansCentroidSketchHadoopMapper extends Mapper<LongWritable, Text, Text, BytesWritable> {

    /**
     * The key in the Configuration that the options for this task are associated
     * with
     */
    public static final String CENTROID_SKETCH_MAP_TASK_OPTIONS = "*weka.distributed.centroid_sketch_map_task_opts";

    /** File prefix for serialized sketch files */
    public static final String SKETCH_FILE_PREFIX = "sketch_run";

    /** The underlying centroid sketch tasks - one for each run */
    protected CentroidSketch[] m_tasks;

    /** Use a configured KMeansMapTask to apply user-specified filters to the data */
    protected KMeansMapTask m_forFilteringOnly;

    /** Helper Weka CSV map task - used simply for parsing CSV entering the map */
    protected CSVToARFFHeaderMapTask m_rowHelper;

    /** The ARFF header of the data */
    protected Instances m_trainingHeader;

    /** The number of runs of k-means being performed */
    protected int m_numRuns = 1;

    /** True if this is the first iteration of the k-means|| initialization */
    protected boolean m_isFirstIteration;

    /**
     * Helper method for loading serialized centroid sketches from the distributed
     * cache
     * 
     * @param prefix the filename prefix for serialized centroid sketches
     * @param numRuns the number of runs (and hence the number of sketches) that
     *          we're expecting
     * @return an array of CentroidSketch objects (one for each run)
     * @throws Exception if a problem occurs
     */
    protected static CentroidSketch[] loadSketchesFromRunFiles(String prefix, int numRuns) throws Exception {
        CentroidSketch[] sketches = new CentroidSketch[numRuns];

        for (int i = 0; i < numRuns; i++) {
            File f = new File(prefix + i);

            if (!f.exists()) {
                throw new IOException("The centroid sketch file '" + f.toString()
                        + "' does not seem to exist in the distributed cache!");
            }

            ObjectInputStream ois = null;

            try {
                ois = new ObjectInputStream(new BufferedInputStream(new FileInputStream(f)));

                CentroidSketch s = (CentroidSketch) ois.readObject();
                sketches[i] = s;
            } finally {
                if (ois != null) {
                    ois.close();
                }
            }
        }

        for (CentroidSketch sketch : sketches) {
            System.err.println(
                    "Starting sketch - num instances in sketch: " + sketch.getCurrentSketch().numInstances());
        }

        return sketches;
    }

    /**
     * Initialize the mapper: configure the CSV row helper, load the training
     * ARFF header, read the number of runs and the first-iteration flag from
     * the task options, load the serialized centroid sketches (one per run)
     * from the distributed cache, and set up a KMeansMapTask that is used only
     * to apply any user-specified filters to incoming instances.
     * 
     * @param context the context of the job
     * @throws IOException if a problem occurs
     */
    @Override
    public void setup(Context context) throws IOException {
        m_rowHelper = new CSVToARFFHeaderMapTask();

        Configuration conf = context.getConfiguration();
        String taskOptsS = conf.get(CENTROID_SKETCH_MAP_TASK_OPTIONS);
        String csvOptsS = conf.get(CSVToArffHeaderHadoopMapper.CSV_TO_ARFF_HEADER_MAP_TASK_OPTIONS);

        try {
            if (!DistributedJobConfig.isEmpty(csvOptsS)) {
                String[] csvOpts = Utils.splitOptions(csvOptsS);
                m_rowHelper.setOptions(csvOpts);
            }

            if (!DistributedJobConfig.isEmpty(taskOptsS)) {
                String[] taskOpts = Utils.splitOptions(taskOptsS);

                // name of the training ARFF header file
                String arffHeaderFileName = Utils.getOption("arff-header", taskOpts);
                if (DistributedJobConfig.isEmpty(arffHeaderFileName)) {
                    throw new IOException("Can't continue without the name of the ARFF header file!");
                }
                Instances headerWithSummary = WekaClassifierHadoopMapper.loadTrainingHeader(arffHeaderFileName);
                m_trainingHeader = CSVToARFFHeaderReduceTask.stripSummaryAtts(headerWithSummary);

                m_rowHelper
                        .initParserOnly(CSVToARFFHeaderMapTask.instanceHeaderToAttributeNameList(m_trainingHeader));

                // num runs
                String numRuns = Utils.getOption("num-runs", taskOpts);
                if (!DistributedJobConfig.isEmpty(numRuns)) {
                    try {
                        m_numRuns = Integer.parseInt(numRuns);
                    } catch (NumberFormatException ex) {
                        throw new IOException("Unable to parse number of runs from -num-runs option");
                    }
                } else {
                    throw new IOException("Unable to continue without knowing how many runs are being performed!");
                }

                // first iteration?
                m_isFirstIteration = Utils.getFlag("first-iteration", taskOpts);

                // load centroid sketches from the distributed cache
                // m_tasks = loadSketches("centroidSketches.ser");
                m_tasks = loadSketchesFromRunFiles(SKETCH_FILE_PREFIX, m_numRuns);

                // init a KMeansMapTask to use for data filtering
                m_forFilteringOnly = new KMeansMapTask();
                try {
                    m_forFilteringOnly.setOptions(taskOpts);

                    m_forFilteringOnly.init(headerWithSummary);
                } catch (Exception ex) {
                    throw new IOException(ex);
                }
            }
        } catch (Exception ex) {
            throw new IOException(ex);
        }
    }

    /**
     * Process a single CSV row: parse it, convert it to an Instance, pass it
     * through any configured filters and then update the sketch for every run.
     * 
     * @param row the CSV row to process
     * @throws IOException if parsing or processing fails
     */
    protected void processRow(String row) throws IOException {
        if (row != null) {
            String[] parsed = m_rowHelper.parseRowOnly(row);

            if (parsed.length != m_trainingHeader.numAttributes()) {
                throw new IOException("Parsed a row that contains a different number of values than "
                        + "there are attributes in the training ARFF header: " + row);
            }

            try {
                Instance toProcess = m_rowHelper.makeInstance(m_trainingHeader, true, parsed);

                // make sure it goes through any filters first!
                toProcess = m_forFilteringOnly.applyFilters(toProcess);

                for (int k = 0; k < m_numRuns; k++) {
                    m_tasks[k].process(toProcess, m_isFirstIteration);
                }

            } catch (Exception ex) {
                throw new IOException(ex);
            }
        }
    }

    /**
     * Process one line of CSV input by delegating to processRow().
     * 
     * @param key the byte offset of the line in the input split (unused)
     * @param value the CSV row to process
     * @param context the context of the job
     * @throws IOException if a problem occurs
     */
    @Override
    public void map(LongWritable key, Text value, Context context) throws IOException {
        if (value != null) {
            processRow(value.toString());
        }
    }

    /**
     * Serialize and gzip-compress a centroid sketch so that it can be emitted
     * as a BytesWritable value.
     * 
     * @param sketch the sketch to serialize
     * @return the compressed, serialized sketch as an array of bytes
     * @throws IOException if serialization fails
     */
    protected static byte[] sketchToBytes(CentroidSketch sketch) throws IOException {
        ObjectOutputStream p = null;
        byte[] bytes = null;

        try {
            ByteArrayOutputStream ostream = new ByteArrayOutputStream();
            OutputStream os = ostream;

            p = new ObjectOutputStream(new BufferedOutputStream(new GZIPOutputStream(os)));
            p.writeObject(sketch);
            p.flush();
            p.close();

            bytes = ostream.toByteArray();

            p = null;
        } finally {
            if (p != null) {
                p.close();
            }
        }

        return bytes;
    }

    /**
     * Emit one serialized, compressed centroid sketch per run, keyed by run
     * number.
     * 
     * @param context the context of the job
     * @throws IOException if serialization fails
     * @throws InterruptedException if writing to the context is interrupted
     */
    @Override
    public void cleanup(Context context) throws IOException, InterruptedException {
        // emit serialized sketch tasks with run number as key
        for (int i = 0; i < m_tasks.length; i++) {
            System.err.println("Number of instances in sketch: " + m_tasks[i].getCurrentSketch().numInstances());
            System.err.println(
                    "Number of instances in reservoir: " + m_tasks[i].getReservoirSample().getSample().size());
            byte[] bytes = sketchToBytes(m_tasks[i]);
            String runNum = "run" + i;
            Text key = new Text();
            key.set(runNum);
            BytesWritable value = new BytesWritable();
            value.set(bytes, 0, bytes.length);
            context.write(key, value);
        }
    }
}
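
Usage sketch: configuring the mapper in a driver

The options that setup() parses come from the job Configuration. The snippet below is a minimal, hypothetical driver-side sketch (the real job runner in distributed Weka is not part of this listing) showing how the keys read by setup() could be populated. The option names -arff-header, -num-runs and -first-iteration match what setup() parses; the file name, run count, job name and omitted job plumbing are illustrative assumptions.

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.BytesWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;

import weka.distributed.hadoop.KMeansCentroidSketchHadoopMapper;

public class SketchJobConfigExample {

    public static void main(String[] args) throws Exception {
        Configuration conf = new Configuration();

        // Options consumed by KMeansCentroidSketchHadoopMapper.setup(): the name of
        // the ARFF header file, the number of k-means runs and (optionally) the
        // first-iteration flag. "header.arff" and "2" are placeholder values.
        conf.set(KMeansCentroidSketchHadoopMapper.CENTROID_SKETCH_MAP_TASK_OPTIONS,
                "-arff-header header.arff -num-runs 2 -first-iteration");

        // CSV parsing options (if any) go under the same key that the
        // CSV-to-ARFF header job uses; setup() treats them as optional.
        // conf.set(CSVToArffHeaderHadoopMapper.CSV_TO_ARFF_HEADER_MAP_TASK_OPTIONS, "...");

        Job job = Job.getInstance(conf, "k-means|| sketch iteration");
        job.setMapperClass(KMeansCentroidSketchHadoopMapper.class);
        job.setMapOutputKeyClass(Text.class);
        job.setMapOutputValueClass(BytesWritable.class);

        // setup() also expects header.arff and the serialized sketch_run<i> files to
        // be available via the distributed cache; adding them (and setting the
        // reducer, input/output paths, etc.) is omitted here.
    }
}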
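
Usage sketch: deserializing the emitted sketches

cleanup() writes one gzip-compressed, serialized CentroidSketch per run, keyed "run0", "run1", and so on. The reducer actually shipped with distributed Weka is not shown in this listing; the hypothetical reducer below only illustrates how such values could be turned back into CentroidSketch objects by mirroring sketchToBytes() in reverse. How the partial sketches from different mappers are then merged is left out.

import java.io.BufferedInputStream;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.util.zip.GZIPInputStream;

import org.apache.hadoop.io.BytesWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Reducer;

import weka.clusterers.CentroidSketch;

public class KMeansCentroidSketchExampleReducer extends
    Reducer<Text, BytesWritable, Text, BytesWritable> {

    /** Reverse of KMeansCentroidSketchHadoopMapper.sketchToBytes(). */
    protected static CentroidSketch bytesToSketch(BytesWritable value)
            throws IOException, ClassNotFoundException {
        ObjectInputStream ois = null;
        try {
            ois = new ObjectInputStream(new BufferedInputStream(new GZIPInputStream(
                    new ByteArrayInputStream(value.getBytes(), 0, value.getLength()))));
            return (CentroidSketch) ois.readObject();
        } finally {
            if (ois != null) {
                ois.close();
            }
        }
    }

    @Override
    public void reduce(Text key, Iterable<BytesWritable> values, Context context)
            throws IOException, InterruptedException {
        for (BytesWritable v : values) {
            try {
                CentroidSketch partial = bytesToSketch(v);
                System.err.println(key.toString() + " - instances in partial sketch: "
                        + partial.getCurrentSketch().numInstances());
                // ... merge the partial sketches for this run here ...
            } catch (ClassNotFoundException ex) {
                throw new IOException(ex);
            }
        }
    }
}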