com.marklogic.contentpump.MultithreadedMapper.java Source code

Java tutorial

Introduction

Here is the source code for com.marklogic.contentpump.MultithreadedMapper.java

Source

/*
 * Copyright 2003-2016 MarkLogic Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.marklogic.contentpump;

import java.io.IOException;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.mapreduce.Counter;
import org.apache.hadoop.mapreduce.InputSplit;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.JobContext;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.OutputFormat;
import org.apache.hadoop.mapreduce.RecordReader;
import org.apache.hadoop.mapreduce.RecordWriter;
import org.apache.hadoop.mapreduce.StatusReporter;
import org.apache.hadoop.mapreduce.TaskAttemptContext;
import org.apache.hadoop.util.ReflectionUtils;

import com.marklogic.contentpump.utilities.ReflectionUtil;

/**
 * Multithreaded implementation for {@link org.apache.hadoop.mapreduce.Mapper},
 * currently used by import operations when applicable to leverage concurrent
 * updates to MarkLogic.
 */
public class MultithreadedMapper<K1, V1, K2, V2> extends Mapper<K1, V1, K2, V2> {
    // Logger shared by the mapper and its nested helper classes.
    private static final Log LOG = LogFactory.getLog(MultithreadedMapper.class);
    // Mapper implementation instantiated for every MapRunner; resolved from
    // the configuration at the start of run().
    private Class<? extends BaseMapper<K1, V1, K2, V2>> mapClass;
    // Context passed to run(); the nested reader, reporter and runner classes
    // all delegate to it.
    private Context outer;
    // Threads started by run() when no thread pool has been supplied.
    private List<MapRunner> runners;
    // Explicit thread count; values <= 0 make getThreadCount() fall back to
    // the job configuration.
    private int threadCount = 0;
    // Optional executor; when non-null, run() submits its runners to it
    // instead of starting its own threads.
    private ExecutorService threadPool;

    /**
     * Resolve the number of threads this mapper should use.
     * <p>
     * A positive value supplied through {@link #setThreadCount(int)} takes
     * precedence; otherwise the value is looked up in the job configuration
     * via {@link #getNumberOfThreads(JobContext)}.
     *
     * @param context task context whose configuration supplies the fallback
     * @return the explicitly set thread count when positive, else the
     *         configured one
     */
    public int getThreadCount(Context context) {
        return threadCount > 0 ? threadCount : getNumberOfThreads(context);
    }

    /**
     * Override the thread count used by this mapper instance.
     *
     * @param threadCount number of threads to use; zero or a negative value
     *            makes {@link #getThreadCount(Context)} fall back to the job
     *            configuration
     */
    public void setThreadCount(int threadCount) {
        this.threadCount = threadCount;
    }

    /**
     * Return the executor previously supplied through
     * {@link #setThreadPool(ExecutorService)}.
     *
     * @return the thread pool in use, or {@code null} when none has been set
     */
    public ExecutorService getThreadPool() {
        return this.threadPool;
    }

    /**
     * Supply the executor that {@link #run(Context)} should submit its map
     * runners to. When left unset, {@code run} starts its own threads.
     *
     * @param pool executor to run map tasks on
     */
    public void setThreadPool(ExecutorService pool) {
        threadPool = pool;
    }

    /**
     * Read the per-split thread count from the job configuration.
     *
     * @param job the job whose configuration is consulted
     * @return the value stored under
     *         {@code ConfigConstants.CONF_THREADS_PER_SPLIT}, or 10 when the
     *         property is not set
     */
    public static int getNumberOfThreads(JobContext job) {
        Configuration conf = job.getConfiguration();
        return conf.getInt(ConfigConstants.CONF_THREADS_PER_SPLIT, 10);
    }

    /**
     * Store the per-split thread count in the configuration of the given job.
     *
     * @param job     the job to modify
     * @param threads the number of threads to record under
     *                {@code ConfigConstants.CONF_THREADS_PER_SPLIT}
     */
    public static void setNumberOfThreads(Job job, int threads) {
        Configuration conf = job.getConfiguration();
        conf.setInt(ConfigConstants.CONF_THREADS_PER_SPLIT, threads);
    }

    /**
     * Store the per-split thread count directly in a configuration.
     *
     * @param conf    the configuration to modify
     * @param threads the number of threads to record under
     *                {@code ConfigConstants.CONF_THREADS_PER_SPLIT}
     */
    public static void setNumberOfThreads(Configuration conf, int threads) {
        conf.setInt(ConfigConstants.CONF_THREADS_PER_SPLIT, threads);
    }

    /**
     * Look up the mapper class that the worker threads should instantiate.
     *
     * @param <K1> the map's input key type
     * @param <V1> the map's input value type
     * @param <K2> the map's output key type
     * @param <V2> the map's output value type
     * @param job  the job whose configuration is consulted
     * @return the class stored under
     *         {@code ConfigConstants.CONF_MULTITHREADEDMAPPER_CLASS}, or
     *         {@code BaseMapper.class} when the property is not set
     */
    @SuppressWarnings("unchecked")
    public static <K1, V1, K2, V2> Class<BaseMapper<K1, V1, K2, V2>> getMapperClass(JobContext job) {
        Class<?> configured = job.getConfiguration()
                .getClass(ConfigConstants.CONF_MULTITHREADEDMAPPER_CLASS, BaseMapper.class);
        return (Class<BaseMapper<K1, V1, K2, V2>>) configured;
    }

    /**
     * Record the mapper class that the worker threads should instantiate.
     *
     * @param <K1> the map input key type
     * @param <V1> the map input value type
     * @param <K2> the map output key type
     * @param <V2> the map output value type
     * @param conf the configuration to modify
     * @param internalMapperClass the class to use as the mapper
     * @throws IllegalArgumentException if {@code internalMapperClass} is
     *             itself a {@code MultithreadedMapper}
     */
    public static <K1, V1, K2, V2> void setMapperClass(Configuration conf,
            Class<? extends BaseMapper<?, ?, ?, ?>> internalMapperClass) {
        boolean recursive = MultithreadedMapper.class.isAssignableFrom(internalMapperClass);
        if (recursive) {
            throw new IllegalArgumentException("Can't have recursive MultithreadedMapper instances.");
        }
        conf.setClass(ConfigConstants.CONF_MULTITHREADEDMAPPER_CLASS, internalMapperClass, Mapper.class);
    }

    /**
     * Run the application's maps on several threads.
     * <p>
     * One {@code MapRunner} always executes on the calling thread. The
     * remaining runners are either submitted to the thread pool supplied via
     * {@link #setThreadPool(ExecutorService)} or, when no pool is set, started
     * as dedicated threads. The method returns only after every submitted
     * task or started thread has finished.
     *
     * @param context the task context shared by all runners
     * @throws IOException if a runner cannot be created, or a dedicated
     *             thread recorded an {@code IOException}
     * @throws InterruptedException if the thread pool has been shut down, the
     *             wait for workers is interrupted, or a dedicated thread
     *             recorded an {@code InterruptedException}
     */
    @Override
    public void run(Context context) throws IOException, InterruptedException {
        outer = context;
        int numberOfThreads = getThreadCount(context);
        mapClass = getMapperClass(context);
        if (LOG.isDebugEnabled()) {
            LOG.debug("Running with " + numberOfThreads + " threads");
        }
        // The calling thread runs one MapRunner itself, so only the remainder
        // has to be handed off. Clamp at zero so that a configured count <= 0
        // cannot turn into a negative list capacity further down.
        int helperThreads = Math.max(0, numberOfThreads - 1);

        InputSplit split = context.getInputSplit();

        try {
            if (threadPool != null) {
                runWithThreadPool(helperThreads, split);
            } else {
                runWithDedicatedThreads(helperThreads);
            }
        } catch (ClassNotFoundException e) {
            LOG.error("MapRunner class not found", e);
        }
    }

    /**
     * Submit runners to the shared thread pool, run one on the calling
     * thread, then wait for every submitted task.
     */
    private void runWithThreadPool(int helperThreads, InputSplit split)
            throws IOException, InterruptedException, ClassNotFoundException {
        List<Future<?>> taskList = new ArrayList<Future<?>>();
        int remaining = helperThreads;
        synchronized (threadPool) {
            for (int i = 0; i < remaining; ++i) {
                // Check before building a runner: the MapRunner constructor
                // obtains a RecordWriter that would otherwise be abandoned.
                if (threadPool.isShutdown()) {
                    throw new InterruptedException("Thread Pool has been shut down");
                }
                MapRunner runner = new MapRunner();
                BaseMapper<K1, V1, K2, V2> mapper = runner.getMapper();
                Collection<Future<Object>> tasks = mapper.submitTasks(threadPool, split);
                taskList.addAll(tasks);
                // Tasks the mapper submitted itself count against the budget.
                remaining -= tasks.size();
                taskList.add(threadPool.submit(runner));
            }
            threadPool.notify();
        }

        // MapRunner that runs in current thread
        MapRunner r = new MapRunner();
        r.run();

        for (Future<?> f : taskList) {
            try {
                f.get();
            } catch (ExecutionException e) {
                // Keep waiting for the remaining tasks: returning early would
                // let the caller proceed while they are still using the
                // shared context.
                LOG.error("Error waiting for MapRunner threads to complete", e);
            }
        }
    }

    /**
     * Start dedicated runner threads, run one runner on the calling thread,
     * then join the threads and rethrow any failure they recorded.
     */
    private void runWithDedicatedThreads(int helperThreads)
            throws IOException, InterruptedException, ClassNotFoundException {
        runners = new ArrayList<MapRunner>(helperThreads);
        for (int i = 0; i < helperThreads; ++i) {
            MapRunner thread = new MapRunner();
            thread.start();
            runners.add(thread);
        }

        // MapRunner runs in current thread
        MapRunner r = new MapRunner();
        r.run();

        for (MapRunner thread : runners) {
            thread.join();
            Throwable th = thread.throwable;
            if (th == null) {
                continue;
            }
            if (th instanceof IOException) {
                throw (IOException) th;
            } else if (th instanceof InterruptedException) {
                throw (InterruptedException) th;
            } else {
                throw new RuntimeException(th);
            }
        }
    }

    /**
     * Record reader given to each sub-mapper context. It advances the
     * enclosing mapper's context ({@code outer}) and keeps its own copies of
     * the current key and value, so that a later advance of {@code outer}
     * does not change the objects already handed out by this reader.
     */
    private class SubMapRecordReader extends RecordReader<K1, V1> {
        private K1 key;
        private V1 value;

        /** No-op: this reader holds no resources of its own. */
        @Override
        public void close() throws IOException {
        }

        /**
         * Progress is not tracked by this reader.
         *
         * @return always 0
         */
        @Override
        public float getProgress() throws IOException, InterruptedException {
            return 0;
        }

        /** No-op: records are pulled from {@code outer}, nothing to set up. */
        @Override
        public void initialize(InputSplit split, TaskAttemptContext context)
                throws IOException, InterruptedException {
        }

        /**
         * Advance {@code outer} to its next record and copy that record's key
         * and value into this reader.
         *
         * @return {@code false} when {@code outer} has no more records,
         *         {@code true} otherwise
         */
        @SuppressWarnings("unchecked")
        @Override
        public boolean nextKeyValue() throws IOException, InterruptedException {
            if (!outer.nextKeyValue()) {
                return false;
            }
            K1 outerKey = outer.getCurrentKey();
            if (outerKey == null) {
                // NOTE(review): key and value keep the copies made for the
                // previous record in this case - confirm that is intended.
                return true;
            }
            Configuration conf = outer.getConfiguration();
            key = (K1) ReflectionUtils.newInstance(outerKey.getClass(), conf);
            key = ReflectionUtils.copy(conf, outerKey, key);

            V1 outerValue = outer.getCurrentValue();
            if (outerValue == null) {
                // Pass a missing value through instead of failing with a
                // NullPointerException on getClass().
                value = null;
            } else {
                value = (V1) ReflectionUtils.newInstance(outerValue.getClass(), conf);
                value = ReflectionUtils.copy(conf, outerValue, value);
            }
            return true;
        }

        /** @return the key copied by the last {@link #nextKeyValue()} call */
        @Override
        public K1 getCurrentKey() {
            return key;
        }

        /** @return the value copied by the last {@link #nextKeyValue()} call */
        @Override
        public V1 getCurrentValue() {
            return value;
        }
    }

    /**
     * Status reporter given to each sub-mapper context; every call is
     * forwarded to the enclosing mapper's context ({@code outer}).
     */
    private class SubMapStatusReporter extends StatusReporter {

        /** Forward the counter lookup by enum name to {@code outer}. */
        @Override
        public Counter getCounter(Enum<?> name) {
            return outer.getCounter(name);
        }

        /** Forward the counter lookup by group and name to {@code outer}. */
        @Override
        public Counter getCounter(String group, String name) {
            return outer.getCounter(group, name);
        }

        /** Forward the progress signal to {@code outer}. */
        @Override
        public void progress() {
            outer.progress();
        }

        /** Forward the status message to {@code outer}. */
        @Override
        public void setStatus(String status) {
            outer.setStatus(status);
        }

        /**
         * Ask {@code outer} for its progress through a reflective call to a
         * public no-argument {@code getProgress} method.
         * <p>
         * NOTE(review): there is no {@code @Override} here - presumably
         * because not every supported Hadoop version declares
         * {@code getProgress()} on {@code StatusReporter}; confirm before
         * adding one.
         *
         * @return the progress reported by {@code outer}, or 0 when the
         *         lookup or the call fails
         */
        public float getProgress() {
            try {
                Method progressGetter = outer.getClass().getMethod("getProgress");
                return (Float) progressGetter.invoke(outer);
            } catch (Exception e) {
                LOG.error("Error getting progress", e);
                return 0;
            }
        }

    }

    /**
     * A unit of map work: owns one mapper instance, one record writer and a
     * sub-context wired to the enclosing mapper's context ({@code outer}).
     * It can be started as a thread, submitted to an executor, or have
     * {@link #run()} called directly on the current thread.
     */
    private class MapRunner extends Thread {
        private BaseMapper<K1, V1, K2, V2> mapper;
        private Context subcontext;
        // NOTE(review): nothing in this class assigns this field, so the
        // rethrow logic that reads it after join() never fires - confirm
        // whether failures are meant to propagate or only to be logged.
        private Throwable throwable;
        private RecordWriter<K2, V2> writer;

        /**
         * Instantiate the configured mapper, obtain a record writer from the
         * job's output format and build the sub-context that ties them to
         * {@code outer}.
         *
         * @throws IOException if the writer cannot be obtained or the
         *             sub-context cannot be created
         * @throws InterruptedException if obtaining the writer is interrupted
         * @throws ClassNotFoundException if the output format class cannot be
         *             resolved
         */
        MapRunner() throws IOException, InterruptedException, ClassNotFoundException {
            // initiate the real mapper that does the work
            mapper = ReflectionUtils.newInstance(mapClass, outer.getConfiguration());
            @SuppressWarnings("unchecked")
            OutputFormat<K2, V2> outputFormat = (OutputFormat<K2, V2>) ReflectionUtils
                    .newInstance(outer.getOutputFormatClass(), outer.getConfiguration());
            writer = outputFormat.getRecordWriter(outer);
            try {
                subcontext = (Context) ReflectionUtil.createMapperContext(mapper, outer.getConfiguration(),
                        outer.getTaskAttemptID(), new SubMapRecordReader(), writer, outer.getOutputCommitter(),
                        new SubMapStatusReporter(), outer.getInputSplit());
            } catch (Exception e) {
                throw new IOException("Error creating mapper context", e);
            }
        }

        /** @return the mapper instance this runner drives */
        public BaseMapper<K1, V1, K2, V2> getMapper() {
            return mapper;
        }

        /**
         * Run the mapper over the shared input and close the writer.
         * Failures are logged rather than propagated.
         */
        @Override
        public void run() {
            try {
                mapper.runThreadSafe(outer, subcontext);
            } catch (Throwable ie) {
                LOG.error(ie.getMessage(), ie);
            } finally {
                // Close the writer even when the mapper failed; otherwise the
                // writer obtained in the constructor is never released.
                try {
                    writer.close(subcontext);
                } catch (Throwable ie) {
                    LOG.error(ie.getMessage(), ie);
                }
            }
        }
    }

}