org.apache.tez.runtime.library.common.shuffle.impl.ShuffleScheduler.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.tez.runtime.library.common.shuffle.impl.ShuffleScheduler.java

Source

/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.tez.runtime.library.common.shuffle.impl;

import java.io.IOException;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.DelayQueue;
import java.util.concurrent.Delayed;
import java.util.concurrent.TimeUnit;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.IntWritable;
import org.apache.tez.common.TezJobConfig;
import org.apache.tez.common.counters.TezCounter;
import org.apache.tez.runtime.api.Event;
import org.apache.tez.runtime.api.TezInputContext;
import org.apache.tez.runtime.api.events.InputReadErrorEvent;
import org.apache.tez.runtime.library.common.InputAttemptIdentifier;
import org.apache.tez.runtime.library.common.TezRuntimeUtils;

import com.google.common.collect.Lists;

class ShuffleScheduler {
    /**
     * Per-thread timestamp in milliseconds. It is set by {@code getHost()} when a host is handed
     * to the calling thread and read by {@code freeHost()} to log the elapsed time. Defaults to 0.
     */
    static ThreadLocal<Long> shuffleStart = new ThreadLocal<Long>() {
        protected Long initialValue() {
            return 0L;
        }
    };

    private static final Log LOG = LogFactory.getLog(ShuffleScheduler.class);
    // Maximum number of inputs returned by a single getMapsForHost() call.
    private static final int MAX_MAPS_AT_ONCE = 20;
    // Penalty delay in milliseconds is INITIAL_PENALTY * PENALTY_GROWTH_RATE ^ failures (see copyFailed).
    private static final long INITIAL_PENALTY = 10000;
    private static final float PENALTY_GROWTH_RATE = 1.3f;

    // TODO NEWTEZ May need to be a string if attempting to fetch from multiple inputs.
    // One flag per input, indexed by source task index; guarded by synchronizing on the array itself.
    private boolean[] finishedMaps;
    // Total number of inputs this scheduler was created for.
    private final int numInputs;
    // Number of inputs not yet committed; the shuffle is done when this reaches 0.
    private int remainingMaps;
    // Known hosts, keyed by MapHost.createIdentifier(hostName, partitionId).
    private Map<String, MapHost> mapLocations = new HashMap<String, MapHost>();
    //TODO NEWTEZ Clean this and other maps at some point
    // Keyed by "<pathComponent>_<partitionId>" (see getIdentifierFromPathAndReduceId).
    private ConcurrentMap<String, InputAttemptIdentifier> pathToIdentifierMap = new ConcurrentHashMap<String, InputAttemptIdentifier>();
    // Hosts that getHost() may hand out.
    private Set<MapHost> pendingHosts = new HashSet<MapHost>();
    // Attempts that getMapsForHost() must skip.
    private Set<InputAttemptIdentifier> obsoleteMaps = new HashSet<InputAttemptIdentifier>();

    // Used by getHost() to pick a pending host at random.
    private final Random random = new Random(System.currentTimeMillis());
    // Penalized hosts waiting for their delay to expire; drained by the referee thread.
    private final DelayQueue<Penalty> penalties = new DelayQueue<Penalty>();
    private final Referee referee = new Referee();
    // Number of failed fetches per source attempt.
    private final Map<InputAttemptIdentifier, IntWritable> failureCounts = new HashMap<InputAttemptIdentifier, IntWritable>();
    // Number of failed fetches per host name; only written (never read) in the visible code.
    private final Map<String, IntWritable> hostFailures = new HashMap<String, IntWritable>();
    private final TezInputContext inputContext;
    private final Shuffle shuffle;
    // Per-attempt failure count at which copyFailed() reports an exception to the Shuffle.
    private final int abortFailureLimit;
    private final TezCounter shuffledMapsCounter;
    private final TezCounter reduceShuffleBytes;
    private final TezCounter failedShuffleCounter;

    // Construction time and time of the last committed input, both in milliseconds.
    private final long startTime;
    private long lastProgressTime;

    // Largest duration reported through informMaxMapRunTime(); used by checkReducerHealth().
    private int maxMapRuntime = 0;
    // Capped to the number of inputs in the constructor; compared with failureCounts.size().
    private int maxFailedUniqueFetches = 5;
    // A failure is reported whenever the per-attempt count is a multiple of this value.
    private int maxFetchFailuresBeforeReporting;

    // Running byte total and formatter used by logProgress().
    private long totalBytesShuffledTillNow = 0;
    private DecimalFormat mbpsFormat = new DecimalFormat("0.00");

    // When true, a failure flagged as a read error is reported immediately.
    private boolean reportReadErrorImmediately = true;

    /**
     * Creates a scheduler that tracks the fetch state of {@code numberOfInputs} inputs.
     *
     * @param inputContext context used to obtain the source vertex name and to send events
     * @param conf configuration supplying the fetch-failure reporting settings
     * @param numberOfInputs number of inputs that must be fetched before the shuffle is done
     * @param shuffle receiver of fatal exceptions raised by the scheduler
     * @param shuffledMapsCounter incremented once per committed input
     * @param reduceShuffleBytes incremented by the byte count of each committed input
     * @param failedShuffleCounter incremented once per failed fetch
     */
    public ShuffleScheduler(TezInputContext inputContext, Configuration conf, int numberOfInputs, Shuffle shuffle,
            TezCounter shuffledMapsCounter, TezCounter reduceShuffleBytes, TezCounter failedShuffleCounter) {
        this.inputContext = inputContext;
        this.numInputs = numberOfInputs;
        this.abortFailureLimit = Math.max(30, numberOfInputs / 10);
        this.remainingMaps = numberOfInputs;
        this.finishedMaps = new boolean[numberOfInputs]; // default init to false
        this.shuffle = shuffle;
        this.shuffledMapsCounter = shuffledMapsCounter;
        this.reduceShuffleBytes = reduceShuffleBytes;
        this.failedShuffleCounter = failedShuffleCounter;
        this.startTime = System.currentTimeMillis();
        this.lastProgressTime = startTime;
        this.maxFailedUniqueFetches = Math.min(numberOfInputs, this.maxFailedUniqueFetches);
        this.maxFetchFailuresBeforeReporting = conf.getInt(TezJobConfig.TEZ_RUNTIME_SHUFFLE_FETCH_FAILURES,
                TezJobConfig.DEFAULT_TEZ_RUNTIME_SHUFFLE_FETCH_FAILURES_LIMIT);
        this.reportReadErrorImmediately = conf.getBoolean(TezJobConfig.TEZ_RUNTIME_SHUFFLE_NOTIFY_READERROR,
                TezJobConfig.DEFAULT_TEZ_RUNTIME_SHUFFLE_NOTIFY_READERROR);

        // Start the referee only after every field is assigned and every configuration lookup has
        // succeeded: a failure above must not leave a running thread behind, and the thread must
        // never observe a partially initialised scheduler.
        referee.start();
    }

    /**
     * Records a successful fetch of one source attempt.
     *
     * <p>The failure records for the attempt and for the host are cleared. If the input has not
     * been marked finished yet, the output is committed, the input is marked finished, counters and
     * progress bookkeeping are updated, and waiters are notified once no inputs remain.
     *
     * @param srcAttemptIdentifier the source attempt whose output was fetched
     * @param host the host the output was fetched from
     * @param bytes number of bytes fetched; added to the shuffle byte counters
     * @param milis not used by this method
     * @param output the fetched output; committed only if the input was not already finished
     * @throws IOException if committing the output fails
     */
    public synchronized void copySucceeded(InputAttemptIdentifier srcAttemptIdentifier, MapHost host, long bytes,
            long milis, MapOutput output) throws IOException {
        // failureCounts is keyed by InputAttemptIdentifier (see copyFailed); removing by a String
        // key would never match and would leave stale failure entries behind.
        failureCounts.remove(srcAttemptIdentifier);
        hostFailures.remove(host.getHostName());

        int srcTaskIndex = srcAttemptIdentifier.getInputIdentifier().getSrcTaskIndex();
        if (!isInputFinished(srcTaskIndex)) {
            output.commit();
            setInputFinished(srcTaskIndex);
            shuffledMapsCounter.increment(1);
            if (--remainingMaps == 0) {
                // Wake up threads blocked in waitUntilDone().
                notifyAll();
            }

            // update the status
            lastProgressTime = System.currentTimeMillis();
            totalBytesShuffledTillNow += bytes;
            logProgress();
            reduceShuffleBytes.increment(bytes);
            if (LOG.isDebugEnabled()) {
                LOG.debug(
                        "src task: " + TezRuntimeUtils.getTaskAttemptIdentifier(inputContext.getSourceVertexName(),
                                srcTaskIndex, srcAttemptIdentifier.getAttemptNumber()) + " done");
            }
        }
        // TODO NEWTEZ Should this be releasing the output, if not committed ? Possible memory leak in case of speculation.
    }

    /**
     * Logs how many inputs have been copied so far and the average transfer rate in MB/s,
     * measured over whole seconds since construction (plus one, which avoids division by zero).
     */
    private void logProgress() {
        final long elapsedSeconds = (System.currentTimeMillis() - startTime) / 1000 + 1;
        final float megabytes = (float) totalBytesShuffledTillNow / (1024 * 1024);
        final float rate = megabytes / elapsedSeconds;
        final int completed = numInputs - remainingMaps;

        StringBuilder message = new StringBuilder("copy(");
        message.append(completed).append(" of ").append(numInputs);
        message.append(" at ").append(mbpsFormat.format(rate)).append(" MB/s)");
        LOG.info(message.toString());
    }

    /**
     * Records a failed fetch of {@code srcAttempt} from {@code host}.
     *
     * <p>In order: calls {@code host.penalize()}, bumps the per-attempt and per-host failure
     * counts, reports an exception to the shuffle once the per-attempt count reaches
     * {@code abortFailureLimit}, possibly sends a read-error event, runs the reducer health check,
     * queues a penalty whose delay grows with the failure count, and finally increments the failed
     * shuffle counter.
     *
     * @param srcAttempt the source attempt that could not be fetched
     * @param host the host the fetch was attempted from
     * @param readError whether the failure was a read error
     */
    public synchronized void copyFailed(InputAttemptIdentifier srcAttempt, MapHost host, boolean readError) {
        host.penalize();

        // Per-attempt failure count.
        final int failures;
        IntWritable attemptCount = failureCounts.get(srcAttempt);
        if (attemptCount == null) {
            failureCounts.put(srcAttempt, new IntWritable(1));
            failures = 1;
        } else {
            attemptCount.set(attemptCount.get() + 1);
            failures = attemptCount.get();
        }

        // Per-host failure count.
        final String hostname = host.getHostName();
        IntWritable hostCount = hostFailures.get(hostname);
        if (hostCount == null) {
            hostFailures.put(hostname, new IntWritable(1));
        } else {
            hostCount.set(hostCount.get() + 1);
        }

        if (failures >= abortFailureLimit) {
            String attemptName = TezRuntimeUtils.getTaskAttemptIdentifier(inputContext.getSourceVertexName(),
                    srcAttempt.getInputIdentifier().getSrcTaskIndex(), srcAttempt.getAttemptNumber());
            IOException ioe = new IOException(failures + " failures downloading " + attemptName);
            ioe.fillInStackTrace();
            shuffle.reportException(ioe);
        }

        checkAndInformJobTracker(failures, srcAttempt, readError);

        checkReducerHealth();

        // Delay in milliseconds; grows geometrically with the number of failures of this attempt.
        final long delay = (long) (INITIAL_PENALTY * Math.pow(PENALTY_GROWTH_RATE, failures));
        penalties.add(new Penalty(host, delay));

        failedShuffleCounter.increment(1);
    }

    /**
     * Sends an {@link InputReadErrorEvent} for {@code srcAttempt} through the input context when
     * either {@code reportReadErrorImmediately} is set and the failure was a read error, or the
     * failure count is a multiple of {@code maxFetchFailuresBeforeReporting}. Otherwise does nothing.
     *
     * @param failures number of failures recorded so far for {@code srcAttempt}
     * @param srcAttempt the source attempt that failed
     * @param readError whether the latest failure was a read error
     */
    private void checkAndInformJobTracker(int failures, InputAttemptIdentifier srcAttempt, boolean readError) {
        final boolean immediateReadError = reportReadErrorImmediately && readError;
        // The modulus is evaluated only when the immediate-report condition does not hold.
        if (!immediateReadError && (failures % maxFetchFailuresBeforeReporting) != 0) {
            return;
        }

        final int srcTaskIndex = srcAttempt.getInputIdentifier().getSrcTaskIndex();
        final String attemptName = TezRuntimeUtils.getTaskAttemptIdentifier(inputContext.getSourceVertexName(),
                srcTaskIndex, srcAttempt.getAttemptNumber());

        LOG.info("Reporting fetch failure for " + attemptName + " to jobtracker.");

        List<Event> failedEvents = Lists.newArrayListWithCapacity(1);
        failedEvents.add(new InputReadErrorEvent("Fetch failure for " + attemptName + " to jobtracker.",
                srcTaskIndex, srcAttempt.getAttemptNumber()));
        inputContext.sendEvents(failedEvents);
    }

    /**
     * Reports a fatal exception to the shuffle when fetch failures are widespread and progress is
     * insufficient.
     *
     * <p>The exception is reported only if all of the following hold:
     * <ul>
     * <li>the number of distinct failed attempts is at least {@code maxFailedUniqueFetches}, or
     * equals the number of inputs still outstanding;</li>
     * <li>the ratio of failed fetches to (failed + completed) is not below 0.5;</li>
     * <li>fewer than half of the inputs are done, or the time since the last progress is at least
     * half of the larger of the time spent making progress and {@code maxMapRuntime}.</li>
     * </ul>
     */
    private void checkReducerHealth() {
        final float maxFailedFetchFraction = 0.5f;
        final float minProgressFraction = 0.5f;
        final float maxStallFraction = 0.5f;

        final long totalFailures = failedShuffleCounter.getValue();
        final int doneMaps = numInputs - remainingMaps;

        // Healthy when the failed share of all fetch attempts is below the threshold.
        final float failedFraction = (float) totalFailures / (totalFailures + doneMaps);
        final boolean reducerHealthy = failedFraction < maxFailedFetchFraction;

        // Enough progress when at least the required share of inputs is done.
        final float doneFraction = (float) doneMaps / numInputs;
        final boolean reducerProgressedEnough = doneFraction >= minProgressFraction;

        // Time without progress, and time during which progress was made (milliseconds).
        final int stallDuration = (int) (System.currentTimeMillis() - lastProgressTime);
        final int shuffleProgressDuration = (int) (lastProgressTime - startTime);

        // Minimum time the reducer should run without getting killed.
        final int minShuffleRunDuration = Math.max(shuffleProgressDuration, maxMapRuntime);
        final float stallFraction = (float) stallDuration / minShuffleRunDuration;
        final boolean reducerStalled = stallFraction >= maxStallFraction;

        final int uniqueFailures = failureCounts.size();
        final boolean failuresWidespread = uniqueFailures >= maxFailedUniqueFetches
                || uniqueFailures == (numInputs - doneMaps);

        if (!failuresWidespread || reducerHealthy) {
            return;
        }
        if (reducerProgressedEnough && !reducerStalled) {
            return;
        }

        // Not healthy and progress is insufficient: give up.
        LOG.fatal("Shuffle failed with too many fetch failures and insufficient progress!");
        shuffle.reportException(new IOException("Exceeded MAX_FAILED_UNIQUE_FETCHES; bailing-out."));
    }

    /**
     * Registers a source attempt as available on the given host and partition.
     *
     * <p>A {@code MapHost} is created for the host/partition pair on first use. The attempt is also
     * indexed by its path component and partition so it can be looked up again through
     * {@code getIdentifierForFetchedOutput}. If the host is in the PENDING state afterwards, it is
     * added to the pending set and waiting threads are notified.
     *
     * @param hostName name of the host serving the output
     * @param partitionId partition to fetch from that host
     * @param hostUrl URL passed to the {@code MapHost} constructor
     * @param srcAttempt the source attempt available on the host
     */
    public synchronized void addKnownMapOutput(String hostName, int partitionId, String hostUrl,
            InputAttemptIdentifier srcAttempt) {
        final String hostKey = MapHost.createIdentifier(hostName, partitionId);
        MapHost target = mapLocations.get(hostKey);
        if (target == null) {
            target = new MapHost(partitionId, hostName, hostUrl);
            assert hostKey.equals(target.getIdentifier());
            mapLocations.put(hostKey, target);
        }
        target.addKnownMap(srcAttempt);

        final String pathKey = getIdentifierFromPathAndReduceId(srcAttempt.getPathComponent(), partitionId);
        pathToIdentifierMap.put(pathKey, srcAttempt);

        // Only a host in PENDING state is offered to fetchers.
        if (target.getState() != MapHost.State.PENDING) {
            return;
        }
        pendingHosts.add(target);
        notifyAll();
    }

    /**
     * Marks a source attempt as obsolete so that {@code getMapsForHost} no longer returns it.
     *
     * @param srcAttempt the attempt to mark obsolete
     */
    public synchronized void obsoleteMapOutput(InputAttemptIdentifier srcAttempt) {
        // The incoming srcAttempt does not contain a path component.
        this.obsoleteMaps.add(srcAttempt);
    }

    /**
     * Adds a source attempt to the given host's known maps again.
     *
     * @param host the host to add the attempt to
     * @param srcAttempt the attempt to put back
     */
    public synchronized void putBackKnownMapOutput(MapHost host, InputAttemptIdentifier srcAttempt) {
        final MapHost owner = host;
        owner.addKnownMap(srcAttempt);
    }

    /**
     * Blocks until at least one host is pending, then removes a randomly chosen pending host,
     * marks it busy, and returns it. The calling thread's {@code shuffleStart} is set to the
     * current time.
     *
     * @return the host assigned to the calling thread
     * @throws InterruptedException if the thread is interrupted while waiting for a pending host
     */
    public synchronized MapHost getHost() throws InterruptedException {
        while (pendingHosts.isEmpty()) {
            wait();
        }

        // Walk the set iterator to a uniformly random position.
        Iterator<MapHost> candidates = pendingHosts.iterator();
        int toSkip = random.nextInt(pendingHosts.size());
        MapHost chosen = candidates.next();
        while (toSkip > 0) {
            chosen = candidates.next();
            toSkip--;
        }

        pendingHosts.remove(chosen);
        chosen.markBusy();

        final String threadName = Thread.currentThread().getName();
        LOG.info("Assigning " + chosen + " with " + chosen.getNumKnownMapOutputs() + " to " + threadName);
        shuffleStart.set(System.currentTimeMillis());

        return chosen;
    }

    /**
     * Looks up the source attempt registered for a path component and reduce (partition) id.
     *
     * @param path path component of the fetched output
     * @param reduceId reduce/partition id of the fetched output
     * @return the registered attempt, or {@code null} if none is registered for that pair
     */
    public InputAttemptIdentifier getIdentifierForFetchedOutput(String path, int reduceId) {
        final String lookupKey = getIdentifierFromPathAndReduceId(path, reduceId);
        return pathToIdentifierMap.get(lookupKey);
    }

    /**
     * Takes the host's known maps and returns up to {@code MAX_MAPS_AT_ONCE} of those that are
     * still needed, i.e. neither obsolete nor already finished. Needed maps beyond the limit are
     * added back to the host; maps that are not needed are dropped.
     *
     * @param host the host whose known maps are taken
     * @return the attempts to fetch from the host in this round
     */
    public synchronized List<InputAttemptIdentifier> getMapsForHost(MapHost host) {
        final List<InputAttemptIdentifier> known = host.getAndClearKnownMaps();
        final List<InputAttemptIdentifier> selected = new ArrayList<InputAttemptIdentifier>();

        for (InputAttemptIdentifier candidate : known) {
            boolean stillNeeded = !obsoleteMaps.contains(candidate)
                    && !isInputFinished(candidate.getInputIdentifier().getSrcTaskIndex());
            if (!stillNeeded) {
                continue;
            }
            if (selected.size() < MAX_MAPS_AT_ONCE) {
                selected.add(candidate);
            } else {
                // Over the per-call limit: keep it on the host for a later round.
                host.addKnownMap(candidate);
            }
        }

        LOG.info("assigned " + selected.size() + " of " + known.size() + " to " + host + " to "
                + Thread.currentThread().getName());
        return selected;
    }

    /**
     * Releases a host previously obtained from {@code getHost()}.
     *
     * <p>Unless the host is in the PENALIZED state, it is marked available; if that leaves it in
     * the PENDING state it is put back into the pending set and waiting threads are notified. The
     * time elapsed since the calling thread's {@code shuffleStart} is logged in milliseconds.
     *
     * @param host the host to release
     */
    public synchronized void freeHost(MapHost host) {
        if (host.getState() != MapHost.State.PENALIZED) {
            if (host.markAvailable() == MapHost.State.PENDING) {
                pendingHosts.add(host);
                notifyAll();
            }
        }
        // Both operands are System.currentTimeMillis() values, so the difference is in milliseconds.
        LOG.info(host + " freed by " + Thread.currentThread().getName() + " in "
                + (System.currentTimeMillis() - shuffleStart.get()) + "ms");
    }

    /**
     * Clears the known host locations, the pending host set, the obsolete attempt set, and the
     * path-to-attempt index.
     */
    public synchronized void resetKnownMaps() {
        this.pathToIdentifierMap.clear();
        this.pendingHosts.clear();
        this.obsoleteMaps.clear();
        this.mapLocations.clear();
    }

    /**
     * Checks whether the shuffle data fetch is complete.
     *
     * @return {@code true} if no inputs remain to be fetched
     */
    public synchronized boolean isDone() {
        final boolean nothingLeft = (remainingMaps == 0);
        return nothingLeft;
    }

    /**
     * Waits for the shuffle to finish, for at most one {@code wait(millis)} call.
     *
     * <p>Returns immediately if no inputs remain. Otherwise waits once on this scheduler's monitor
     * and then reports the current state, so the wait may end before the timeout if the monitor is
     * notified for another reason.
     *
     * @param millis maximum wait time in milliseconds
     * @return {@code true} if the shuffle is done
     * @throws InterruptedException if the thread is interrupted while waiting
     */
    public synchronized boolean waitUntilDone(int millis) throws InterruptedException {
        if (remainingMaps <= 0) {
            return true;
        }
        wait(millis);
        return remainingMaps == 0;
    }

    /**
     * Delay-queue entry that pairs a host with the wall-clock time (milliseconds) at which its
     * penalty ends. Entries are ordered by that end time.
     */
    private static class Penalty implements Delayed {
        MapHost host;
        private long endTime;

        /**
         * @param host the penalized host
         * @param delay penalty length in milliseconds, counted from now
         */
        Penalty(MapHost host, long delay) {
            final long now = System.currentTimeMillis();
            this.host = host;
            this.endTime = now + delay;
        }

        /** Returns the time left until the penalty ends, converted to {@code unit}. */
        @Override
        public long getDelay(TimeUnit unit) {
            return unit.convert(endTime - System.currentTimeMillis(), TimeUnit.MILLISECONDS);
        }

        /** Orders penalties by end time; {@code o} must be a {@code Penalty}. */
        @Override
        public int compareTo(Delayed o) {
            final long otherEnd = ((Penalty) o).endTime;
            if (endTime < otherEnd) {
                return -1;
            }
            if (endTime > otherEnd) {
                return 1;
            }
            return 0;
        }

    }

    /**
     * Builds the key used by {@code pathToIdentifierMap}: the path, an underscore, and the reduce id.
     *
     * @param path path component of the output
     * @param reduceId reduce/partition id
     * @return {@code path + "_" + reduceId}
     */
    private String getIdentifierFromPathAndReduceId(String path, int reduceId) {
        StringBuilder key = new StringBuilder();
        key.append(path).append("_").append(reduceId);
        return key.toString();
    }

    /**
     * Daemon thread that takes hosts off the penalty queue once their delay has expired and, if a
     * host becomes PENDING when marked available, offers it to fetchers again.
     */
    private class Referee extends Thread {
        public Referee() {
            setDaemon(true);
            setName("ShufflePenaltyReferee");
        }

        /**
         * Loops until interrupted. Any other throwable is reported to the shuffle and ends the loop.
         */
        @Override
        public void run() {
            try {
                for (;;) {
                    // Blocks until the earliest penalty has expired.
                    Penalty expired = penalties.take();
                    MapHost released = expired.host;
                    synchronized (ShuffleScheduler.this) {
                        MapHost.State newState = released.markAvailable();
                        if (newState == MapHost.State.PENDING) {
                            pendingHosts.add(released);
                            ShuffleScheduler.this.notifyAll();
                        }
                    }
                }
            } catch (InterruptedException ie) {
                // Interrupted by close(): exit quietly.
                return;
            } catch (Throwable t) {
                shuffle.reportException(t);
            }
        }
    }

    /**
     * Stops the referee thread by interrupting it and then waits for it to terminate.
     *
     * @throws InterruptedException if the calling thread is interrupted while joining
     */
    public void close() throws InterruptedException {
        final Thread penaltyThread = referee;
        penaltyThread.interrupt();
        penaltyThread.join();
    }

    /**
     * Raises the recorded maximum map runtime to {@code duration} if it is larger than the
     * current value; smaller or equal values leave it unchanged.
     *
     * @param duration a reported runtime, in the same unit as the stored maximum
     */
    public synchronized void informMaxMapRunTime(int duration) {
        maxMapRuntime = Math.max(maxMapRuntime, duration);
    }

    /**
     * Marks the input at {@code inputIndex} as finished, synchronizing on the flag array.
     *
     * @param inputIndex index into the finished-input flags
     */
    void setInputFinished(int inputIndex) {
        final boolean[] flags = finishedMaps;
        synchronized (flags) {
            flags[inputIndex] = true;
        }
    }

    /**
     * Reports whether the input at {@code inputIndex} has been marked finished, synchronizing on
     * the flag array.
     *
     * @param inputIndex index into the finished-input flags
     * @return {@code true} if the input has been marked finished
     */
    boolean isInputFinished(int inputIndex) {
        final boolean[] flags = finishedMaps;
        final boolean finished;
        synchronized (flags) {
            finished = flags[inputIndex];
        }
        return finished;
    }
}