com.marklogic.contentpump.LocalJobRunner.java Source code

Java tutorial

Introduction

Here is the source code for com.marklogic.contentpump.LocalJobRunner.java

Source

/*
 * Copyright 2003-2016 MarkLogic Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.marklogic.contentpump;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;

import org.apache.commons.cli.CommandLine;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.mapreduce.Counter;
import org.apache.hadoop.mapreduce.CounterGroup;
import org.apache.hadoop.mapreduce.InputFormat;
import org.apache.hadoop.mapreduce.InputSplit;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.JobID;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.OutputCommitter;
import org.apache.hadoop.mapreduce.OutputFormat;
import org.apache.hadoop.mapreduce.RecordReader;
import org.apache.hadoop.mapreduce.RecordWriter;
import org.apache.hadoop.mapreduce.TaskAttemptContext;
import org.apache.hadoop.mapreduce.TaskAttemptID;
import org.apache.hadoop.mapreduce.TaskID;
import org.apache.hadoop.mapreduce.TaskType;
import org.apache.hadoop.util.ReflectionUtils;
import org.apache.hadoop.util.StringUtils;

import com.marklogic.contentpump.utilities.ReflectionUtil;

/**
 * Runs a job in-process, potentially multi-threaded.  Only supports map-only
 * jobs.
 * 
 * @author jchen
 *
 */
public class LocalJobRunner implements ConfigConstants {
    public static final Log LOG = LogFactory.getLog(LocalJobRunner.class);
    public static final int DEFAULT_THREAD_COUNT = 4;

    private Job job;
    private ExecutorService pool;
    private AtomicInteger[] progress;
    private AtomicBoolean jobComplete;
    private long startTime;
    private int threadsPerSplit = 0;
    private int threadCount;
    //TODO confusing, rename it
    private int availableThreads = 1;
    // minimally required thread per task defined by the job
    private int minThreads = 1;
    private Command cmd;
    private ContentPumpReporter reporter;

    public LocalJobRunner(Job job, CommandLine cmdline, Command cmd) {
        this.job = job;
        this.cmd = cmd;

        threadCount = DEFAULT_THREAD_COUNT;
        if (cmdline.hasOption(THREAD_COUNT)) {
            threadCount = Integer.parseInt(cmdline.getOptionValue(THREAD_COUNT));
        }
        if (threadCount > 1) {
            pool = Executors.newFixedThreadPool(threadCount);
            if (LOG.isDebugEnabled()) {
                LOG.debug("Thread pool size: " + threadCount);
            }
        }

        if (cmdline.hasOption(THREADS_PER_SPLIT)) {
            threadsPerSplit = Integer.parseInt(cmdline.getOptionValue(THREADS_PER_SPLIT));
        }

        Configuration conf = job.getConfiguration();
        minThreads = conf.getInt(CONF_MIN_THREADS, minThreads);

        jobComplete = new AtomicBoolean();
        startTime = System.currentTimeMillis();
    }

    /**
     * Run the job.  Get the input splits, create map tasks and submit them to
     * the thread pool if there is one; otherwise, run the tasks one by one on
     * the calling thread.
     * <p>
     * If the output specification check fails, the error is logged and the
     * method returns without running anything.  Once the progress monitor has
     * been started it is always stopped again, and the thread pool is always
     * shut down, even when a task or the wait for tasks throws.
     * 
     * @param <INKEY> mapper input key type
     * @param <INVALUE> mapper input value type
     * @param <OUTKEY> mapper output key type
     * @param <OUTVALUE> mapper output value type
     * @param <T> concrete input split type
     * @throws Exception if creating splits, running a task on the calling
     *         thread, or waiting for pooled tasks fails
     */
    @SuppressWarnings("unchecked")
    public <INKEY, INVALUE, OUTKEY, OUTVALUE, T extends org.apache.hadoop.mapreduce.InputSplit> void run()
            throws Exception {
        Configuration conf = job.getConfiguration();
        InputFormat<INKEY, INVALUE> inputFormat = (InputFormat<INKEY, INVALUE>) ReflectionUtils
                .newInstance(job.getInputFormatClass(), conf);
        List<InputSplit> splits = inputFormat.getSplits(job);
        T[] array = (T[]) splits.toArray(new org.apache.hadoop.mapreduce.InputSplit[splits.size()]);

        // sort the splits into order based on size, so that the biggest
        // goes first
        Arrays.sort(array, new SplitLengthComparator());
        OutputFormat<OUTKEY, OUTVALUE> outputFormat = (OutputFormat<OUTKEY, OUTVALUE>) ReflectionUtils
                .newInstance(job.getOutputFormatClass(), conf);
        Class<? extends Mapper<?, ?, ?, ?>> mapperClass = job.getMapperClass();
        Mapper<INKEY, INVALUE, OUTKEY, OUTVALUE> mapper = (Mapper<INKEY, INVALUE, OUTKEY, OUTVALUE>) ReflectionUtils
                .newInstance(mapperClass, conf);
        try {
            outputFormat.checkOutputSpecs(job);
        } catch (Exception ex) {
            if (LOG.isDebugEnabled()) {
                LOG.debug("Error checking output specification: ", ex);
            } else {
                LOG.error("Error checking output specification: ");
                LOG.error(ex.getMessage());
            }
            return;
        }
        conf = job.getConfiguration();
        // one progress slot (percent complete) per split, read by the monitor
        progress = new AtomicInteger[splits.size()];
        for (int i = 0; i < splits.size(); i++) {
            progress[i] = new AtomicInteger();
        }
        Monitor monitor = new Monitor();
        monitor.start();
        try {
            reporter = new ContentPumpReporter();
            List<Future<Object>> taskList = new ArrayList<Future<Object>>();
            for (int i = 0; i < array.length; i++) {
                InputSplit split = array[i];
                if (pool != null) {
                    LocalMapTask<INKEY, INVALUE, OUTKEY, OUTVALUE> task = new LocalMapTask<INKEY, INVALUE, OUTKEY, OUTVALUE>(
                            inputFormat, outputFormat, conf, i, split, reporter, progress[i]);
                    availableThreads = assignThreads(i, array.length);
                    Class<? extends Mapper<?, ?, ?, ?>> runtimeMapperClass = job.getMapperClass();
                    if (availableThreads > 1 && availableThreads != threadsPerSplit) {
                        // possible runtime adjustment
                        if (runtimeMapperClass != (Class) MultithreadedMapper.class) {
                            runtimeMapperClass = (Class<? extends Mapper<INKEY, INVALUE, OUTKEY, OUTVALUE>>) cmd
                                    .getRuntimeMapperClass(job, mapperClass, threadsPerSplit, availableThreads);
                        }
                        if (runtimeMapperClass != mapperClass) {
                            task.setMapperClass(runtimeMapperClass);
                        }
                        if (runtimeMapperClass == (Class) MultithreadedMapper.class) {
                            task.setThreadCount(availableThreads);
                            if (LOG.isDebugEnabled()) {
                                LOG.debug("Thread Count for Split#" + i + " : " + availableThreads);
                            }
                        }
                    }

                    if (runtimeMapperClass == (Class) MultithreadedMapper.class) {
                        // Block until notified on the pool before submitting
                        // the next split.  LocalMapTask notifies on failure;
                        // NOTE(review): presumably MultithreadedMapper
                        // notifies once it has claimed its threads - confirm.
                        synchronized (pool) {
                            taskList.add(pool.submit(task));
                            pool.wait();
                        }
                    } else {
                        pool.submit(task);
                    }
                } else { // single-threaded
                    JobID jid = new JobID();
                    TaskID taskId = new TaskID(jid.getJtIdentifier(), jid.getId(), TaskType.MAP, i);
                    TaskAttemptID taskAttemptId = new TaskAttemptID(taskId, 0);
                    TaskAttemptContext context = ReflectionUtil.createTaskAttemptContext(conf, taskAttemptId);
                    RecordReader<INKEY, INVALUE> reader = inputFormat.createRecordReader(split, context);
                    RecordWriter<OUTKEY, OUTVALUE> writer = outputFormat.getRecordWriter(context);
                    OutputCommitter committer = outputFormat.getOutputCommitter(context);
                    TrackingRecordReader trackingReader = new TrackingRecordReader(reader, progress[i]);

                    Mapper.Context mapperContext = ReflectionUtil.createMapperContext(mapper, conf, taskAttemptId,
                            trackingReader, writer, committer, reporter, split);

                    trackingReader.initialize(split, mapperContext);

                    // no thread pool (only 1 thread specified)
                    Class<? extends Mapper<?, ?, ?, ?>> mapClass = job.getMapperClass();
                    mapperContext.getConfiguration().setClass(CONF_MAPREDUCE_JOB_MAP_CLASS, mapClass, Mapper.class);
                    mapper = (Mapper<INKEY, INVALUE, OUTKEY, OUTVALUE>) ReflectionUtils.newInstance(mapClass,
                            mapperContext.getConfiguration());
                    mapper.run(mapperContext);
                    trackingReader.close();
                    writer.close(mapperContext);
                    committer.commitTask(context);
                }
            }
            // wait till all tasks are done
            if (pool != null) {
                for (Future<Object> f : taskList) {
                    f.get();
                }
                pool.shutdown();
                while (!pool.awaitTermination(1, TimeUnit.DAYS)) {
                    // keep waiting until every submitted task has finished
                }
            }
        } finally {
            // Reached with the pool still open only when something above
            // threw; stop the workers so their non-daemon threads do not
            // outlive the failed job.
            if (pool != null && !pool.isShutdown()) {
                pool.shutdownNow();
            }
            // Always stop the monitor; otherwise its loop would keep running
            // (and keep the JVM alive) after a failure.
            jobComplete.set(true);
            monitor.interrupt();
        }
        monitor.join(1000);

        // report counters
        Iterator<CounterGroup> groupIt = reporter.counters.iterator();
        while (groupIt.hasNext()) {
            CounterGroup group = groupIt.next();
            LOG.info(group.getDisplayName() + ": ");
            Iterator<Counter> counterIt = group.iterator();
            while (counterIt.hasNext()) {
                Counter counter = counterIt.next();
                LOG.info(counter.getDisplayName() + ": " + counter.getValue());
            }
        }
        LOG.info("Total execution time: " + (System.currentTimeMillis() - startTime) / 1000 + " sec");
    }

    /**
     * Decide how many threads the split at the given position should get.
     * <p>
     * An explicit threads-per-split setting always wins.  A lone split gets
     * the whole pool.  When the pool cannot give every split its minimum,
     * each split gets exactly the minimum.  Otherwise the pool is divided
     * evenly, with the remainder handed out one thread at a time to the
     * lowest-indexed splits.
     * 
     * @param splitIndex zero-based index of the split
     * @param splitCount total number of splits in the job
     * @return number of threads to use for this split
     */
    private int assignThreads(int splitIndex, int splitCount) {
        if (threadsPerSplit > 0) {
            return threadsPerSplit;
        }
        if (splitCount == 1) {
            return threadCount;
        }
        final boolean poolTooSmall = splitCount * minThreads > threadCount;
        if (poolTooSmall) {
            return minThreads;
        }
        final int evenShare = threadCount / splitCount;
        final int leftover = threadCount % splitCount;
        final boolean getsExtraThread = splitIndex % threadCount < leftover;
        return getsExtraThread ? evenShare + 1 : evenShare;
    }

    /**
     * A map task to be run in a thread.  Each task processes one input split:
     * it builds its own reader, writer, committer and mapper instance, runs
     * the mapper, and then closes the reader and writer and commits the task.
     * Failures are logged rather than propagated, so {@link #call()} always
     * returns normally.
     * 
     * @author jchen
     *
     * @param <INKEY> mapper input key type
     * @param <INVALUE> mapper input value type
     * @param <OUTKEY> mapper output key type
     * @param <OUTVALUE> mapper output value type
     */
    public class LocalMapTask<INKEY, INVALUE, OUTKEY, OUTVALUE> implements Callable<Object> {
        private InputFormat<INKEY, INVALUE> inputFormat;
        private OutputFormat<OUTKEY, OUTVALUE> outputFormat;
        private Mapper<INKEY, INVALUE, OUTKEY, OUTVALUE> mapper;
        private Configuration conf;
        private int id;
        private InputSplit split;
        private AtomicInteger pctProgress;
        private ContentPumpReporter reporter;
        private Class<? extends Mapper<?, ?, ?, ?>> mapperClass;
        private int threadCount = 0;

        /**
         * Creates a task for one split.  The mapper class defaults to the
         * job's mapper class; if that class cannot be found the error is
         * logged and the mapper class stays unset.
         *
         * @param inputFormat  used to create the record reader for the split
         * @param outputFormat used to create the record writer and committer
         * @param conf         configuration for the task attempt and mapper
         * @param id           task number, used to build the task id
         * @param split        the input split to process
         * @param reporter     reporter handed to the mapper context
         * @param pctProgress  slot updated with this split's percent progress
         */
        public LocalMapTask(InputFormat<INKEY, INVALUE> inputFormat, OutputFormat<OUTKEY, OUTVALUE> outputFormat,
                Configuration conf, int id, InputSplit split, ContentPumpReporter reporter,
                AtomicInteger pctProgress) {
            this.inputFormat = inputFormat;
            this.outputFormat = outputFormat;
            this.conf = conf;
            this.id = id;
            this.split = split;
            this.pctProgress = pctProgress;
            this.reporter = reporter;
            try {
                mapperClass = job.getMapperClass();
            } catch (ClassNotFoundException e) {
                LOG.error("Mapper class not found", e);
            }
        }

        /** @return the thread count passed to a MultithreadedMapper */
        public int getThreadCount() {
            return threadCount;
        }

        /** @param threads the thread count to pass to a MultithreadedMapper */
        public void setThreadCount(int threads) {
            threadCount = threads;
        }

        /** @return the mapper class this task will instantiate */
        public Class<? extends Mapper<?, ?, ?, ?>> getMapperClass() {
            return mapperClass;
        }

        /** @param runtimeMapperClass mapper class to use instead of the job's */
        public void setMapperClass(Class<? extends Mapper<?, ?, ?, ?>> runtimeMapperClass) {
            mapperClass = runtimeMapperClass;
        }

        /**
         * Runs the mapper over this task's split.
         * <p>
         * Any error while setting up or running the mapper is logged, and the
         * pool monitor is notified so that a thread waiting on the pool in
         * the enclosing runner is released.  Cleanup then runs in three
         * independent steps (close reader, close writer, commit task) so that
         * a failure in one step does not skip the others.
         *
         * @return always {@code null}
         */
        @SuppressWarnings("unchecked")
        @Override
        public Object call() {
            TaskAttemptContext context = null;
            Mapper.Context mapperContext = null;
            TrackingRecordReader trackingReader = null;
            RecordWriter<OUTKEY, OUTVALUE> writer = null;
            OutputCommitter committer = null;
            JobID jid = new JobID();
            TaskID taskId = new TaskID(jid.getJtIdentifier(), jid.getId(), TaskType.MAP, id);
            TaskAttemptID taskAttemptId = new TaskAttemptID(taskId, 0);
            try {
                context = ReflectionUtil.createTaskAttemptContext(conf, taskAttemptId);
                RecordReader<INKEY, INVALUE> reader = inputFormat.createRecordReader(split, context);
                writer = outputFormat.getRecordWriter(context);
                committer = outputFormat.getOutputCommitter(context);
                trackingReader = new TrackingRecordReader(reader, pctProgress);
                mapper = (Mapper<INKEY, INVALUE, OUTKEY, OUTVALUE>) ReflectionUtils.newInstance(mapperClass, conf);
                mapperContext = ReflectionUtil.createMapperContext(mapper, conf, taskAttemptId, trackingReader,
                        writer, committer, reporter, split);
                trackingReader.initialize(split, mapperContext);
                if (mapperClass == (Class) MultithreadedMapper.class) {
                    ((MultithreadedMapper) mapper).setThreadCount(threadCount);
                    ((MultithreadedMapper) mapper).setThreadPool(pool);
                }
                mapper.run(mapperContext);
            } catch (Throwable t) {
                LOG.error("Error running task: ", t);
                // release a submitter that may be blocked in pool.wait()
                try {
                    synchronized (pool) {
                        pool.notify();
                    }
                } catch (Throwable t1) {
                    LOG.error(t1);
                }
            } finally {
                // Each cleanup step has its own try block: a failure closing
                // the reader must not leave the writer open or skip the
                // commit.
                try {
                    if (trackingReader != null) {
                        trackingReader.close();
                    }
                } catch (Throwable t) {
                    LOG.error("Error committing task: ", t);
                }
                try {
                    if (writer != null) {
                        writer.close(mapperContext);
                    }
                } catch (Throwable t) {
                    LOG.error("Error committing task: ", t);
                }
                try {
                    // committer is null when setup failed before it was
                    // obtained; there is nothing to commit in that case
                    if (committer != null) {
                        committer.commitTask(context);
                    }
                } catch (Throwable t) {
                    LOG.error("Error committing task: ", t);
                }
            }
            return null;
        }
    }

    /**
     * Record reader wrapper that forwards every call to an underlying reader
     * and, after each advance, publishes the reader's progress as a whole
     * percentage into a shared counter.
     *
     * @param <K> key type
     * @param <V> value type
     */
    class TrackingRecordReader<K, V> extends RecordReader<K, V> {
        private final RecordReader<K, V> delegate;
        private AtomicInteger percentDone;

        /**
         * @param real        the reader that does the actual work
         * @param pctProgress counter updated with the percent complete
         */
        TrackingRecordReader(RecordReader<K, V> real, AtomicInteger pctProgress) {
            delegate = real;
            percentDone = pctProgress;
        }

        @Override
        public void initialize(InputSplit split, TaskAttemptContext context)
                throws IOException, InterruptedException {
            delegate.initialize(split, context);
        }

        /** Advances the underlying reader, then records its progress. */
        @Override
        public boolean nextKeyValue() throws IOException, InterruptedException {
            final boolean advanced = delegate.nextKeyValue();
            final int percent = (int) (getProgress() * 100);
            percentDone.set(percent);
            return advanced;
        }

        @Override
        public K getCurrentKey() throws IOException, InterruptedException {
            return delegate.getCurrentKey();
        }

        @Override
        public V getCurrentValue() throws IOException, InterruptedException {
            return delegate.getCurrentValue();
        }

        @Override
        public float getProgress() throws IOException, InterruptedException {
            return delegate.getProgress();
        }

        @Override
        public void close() throws IOException {
            delegate.close();
        }
    }

    /**
     * Background thread that logs the job's overall progress about once a
     * second, skipping reports identical to the previous one.  It stops when
     * the job is flagged complete or the thread is interrupted, and logs one
     * last report on the way out if it differs from the last one logged.
     */
    class Monitor extends Thread {
        private String lastReport;

        @Override
        public void run() {
            try {
                while (true) {
                    if (jobComplete.get() || interrupted()) {
                        break;
                    }
                    Thread.sleep(1000);
                    String current = progressReport();
                    if (current.equals(lastReport)) {
                        continue;
                    }
                    LOG.info(current);
                    lastReport = current;
                }
            } catch (InterruptedException e) {
                // interruption is the stop signal; fall through to the
                // final report
            } catch (Throwable t) {
                LOG.error("Error in monitor thread", t);
            }
            String finalReport = progressReport();
            if (!finalReport.equals(lastReport)) {
                LOG.info(finalReport);
            }
        }

        /** Formats the current overall progress as a log line. */
        private String progressReport() {
            return " completed " + StringUtils.formatPercent(computeProgress(), 0);
        }
    }

    /**
     * Computes overall job progress as the average of the per-split
     * percentages.
     *
     * @return a fraction between 0 and 1; 0 when progress tracking has not
     *         been set up yet (before {@link #run()} creates the per-split
     *         counters), and 1 when the job has no splits
     */
    public double computeProgress() {
        // Read the field once; it stays null until run() allocates it.
        final AtomicInteger[] perSplit = progress;
        if (perSplit == null) {
            return (double) 0;
        }
        if (perSplit.length == 0) {
            return (double) 1;
        }
        long result = 0;
        for (AtomicInteger pct : perSplit) {
            result += pct.longValue();
        }
        // each entry is a percentage, hence the extra division by 100
        return (double) result / perSplit.length / 100;
    }

    /**
     * Orders input splits by descending length, so that sorting with this
     * comparator puts the largest split first.
     */
    private static class SplitLengthComparator implements Comparator<org.apache.hadoop.mapreduce.InputSplit> {

        /**
         * @return a positive value when {@code o1} is shorter than
         *         {@code o2}, zero when they are equally long, and a negative
         *         value when {@code o1} is longer
         * @throws RuntimeException wrapping the failure if either split
         *         cannot report its length
         */
        @Override
        public int compare(org.apache.hadoop.mapreduce.InputSplit o1, org.apache.hadoop.mapreduce.InputSplit o2) {
            try {
                long len1 = o1.getLength();
                long len2 = o2.getLength();
                if (len1 < len2) {
                    return 1;
                } else if (len1 == len2) {
                    return 0;
                } else {
                    return -1;
                }
            } catch (IOException ie) {
                throw new RuntimeException("exception in compare", ie);
            } catch (InterruptedException ie) {
                // restore the interrupt status so callers further up can
                // still observe the interruption
                Thread.currentThread().interrupt();
                throw new RuntimeException("exception in compare", ie);
            }
        }
    }

    /**
     * Returns the reporter that collects this job's counters.
     *
     * @return the reporter, or {@code null} if {@link #run()} has not yet
     *         created it
     */
    public ContentPumpReporter getReporter() {
        return reporter;
    }
}