org.apache.tez.runtime.library.shuffle.common.Fetcher.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.tez.runtime.library.shuffle.common.Fetcher.java

Source

/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.tez.runtime.library.shuffle.common;

import java.io.DataInputStream;
import java.io.IOException;
import java.net.HttpURLConnection;
import java.net.MalformedURLException;
import java.net.URL;
import java.net.URLConnection;
import java.security.GeneralSecurityException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.atomic.AtomicInteger;

import javax.crypto.SecretKey;
import javax.net.ssl.HttpsURLConnection;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.IOUtils;
import org.apache.hadoop.io.compress.CompressionCodec;
import org.apache.hadoop.security.ssl.SSLFactory;
import org.apache.hadoop.yarn.api.records.ApplicationId;
import org.apache.tez.common.TezJobConfig;
import org.apache.tez.runtime.library.common.InputAttemptIdentifier;
import org.apache.tez.runtime.library.common.security.SecureShuffleUtils;
import org.apache.tez.runtime.library.common.shuffle.impl.ShuffleHeader;
import org.apache.tez.runtime.library.shuffle.common.FetchedInput.Type;

import com.google.common.base.Preconditions;

/**
 * Responsible for fetching inputs served by the ShuffleHandler for a single
 * host. Construct using {@link FetcherBuilder}
 */
public class Fetcher implements Callable<FetchResult> {

    private static final Log LOG = LogFactory.getLog(Fetcher.class);

    private static final int UNIT_CONNECT_TIMEOUT = 60 * 1000;
    private static final AtomicInteger fetcherIdGen = new AtomicInteger(0);

    // Configurable fields.
    private CompressionCodec codec;
    private int connectionTimeout;
    private int readTimeout;

    private boolean ifileReadAhead = TezJobConfig.TEZ_RUNTIME_IFILE_READAHEAD_DEFAULT;
    private int ifileReadAheadLength = TezJobConfig.TEZ_RUNTIME_IFILE_READAHEAD_BYTES_DEFAULT;

    private final SecretKey shuffleSecret;

    private final FetcherCallback fetcherCallback;
    private final FetchedInputAllocator inputManager;
    private final ApplicationId appId;

    private static boolean sslShuffle = false;
    private static SSLFactory sslFactory;
    private static boolean sslFactoryInited;

    private final int fetcherIdentifier;

    // Parameters to track work.
    private List<InputAttemptIdentifier> srcAttempts;
    private String host;
    private int port;
    private int partition;

    // Maps from the pathComponents (unique per srcTaskId) to the specific taskId
    private final Map<String, InputAttemptIdentifier> pathToAttemptMap;
    private LinkedHashSet<InputAttemptIdentifier> remaining;

    private URL url;
    private String encHash;
    private String msgToEncode;

    private Fetcher(FetcherCallback fetcherCallback, FetchedInputAllocator inputManager, ApplicationId appId,
            SecretKey shuffleSecret, Configuration conf) {
        this.fetcherCallback = fetcherCallback;
        this.inputManager = inputManager;
        this.shuffleSecret = shuffleSecret;
        this.appId = appId;
        this.pathToAttemptMap = new HashMap<String, InputAttemptIdentifier>();

        this.fetcherIdentifier = fetcherIdGen.getAndIncrement();

        // TODO NEWTEZ Ideally, move this out from here into a static initializer block.
        // Re-enable when ssl shuffle support is needed.
        //    synchronized (Fetcher.class) {
        //      if (!sslFactoryInited) {
        //        sslFactoryInited = true;
        //        sslShuffle = conf.getBoolean(
        //            TezJobConfig.TEZ_RUNTIME_SHUFFLE_ENABLE_SSL,
        //            TezJobConfig.DEFAULT_TEZ_RUNTIME_SHUFFLE_ENABLE_SSL);
        //        if (sslShuffle) {
        //          sslFactory = new SSLFactory(SSLFactory.Mode.CLIENT, conf);
        //          try {
        //            sslFactory.init();
        //          } catch (Exception ex) {
        //            sslFactory.destroy();
        //            throw new RuntimeException(ex);
        //          }
        //        }
        //      }
        //    }
    }

    @Override
    public FetchResult call() throws Exception {
        if (srcAttempts.size() == 0) {
            return new FetchResult(host, port, partition, srcAttempts);
        }

        for (InputAttemptIdentifier in : srcAttempts) {
            pathToAttemptMap.put(in.getPathComponent(), in);
        }

        remaining = new LinkedHashSet<InputAttemptIdentifier>(srcAttempts);

        HttpURLConnection connection;
        try {
            connection = connectToShuffleHandler(host, port, partition, srcAttempts);
        } catch (IOException e) {
            // ioErrs.increment(1);
            // If connect did not succeed, just mark all the maps as failed,
            // indirectly penalizing the host
            for (Iterator<InputAttemptIdentifier> leftIter = remaining.iterator(); leftIter.hasNext();) {
                fetcherCallback.fetchFailed(host, leftIter.next(), true);
            }
            return new FetchResult(host, port, partition, remaining);
        }

        DataInputStream input;

        try {
            input = new DataInputStream(connection.getInputStream());
            validateConnectionResponse(connection, url, msgToEncode, encHash);
        } catch (IOException e) {
            // ioErrs.increment(1);
            // If we got a read error at this stage, it implies there was a problem
            // with the first map, typically lost map. So, penalize only that map
            // and add the rest
            InputAttemptIdentifier firstAttempt = srcAttempts.get(0);
            LOG.warn("Fetch Failure from host while connecting: " + host + ", attempt: " + firstAttempt
                    + " Informing ShuffleManager: ", e);
            fetcherCallback.fetchFailed(host, firstAttempt, false);
            return new FetchResult(host, port, partition, remaining);
        }

        // By this point, the connection is setup and the response has been
        // validated.

        // Loop through available map-outputs and fetch them
        // On any error, faildTasks is not null and we exit
        // after putting back the remaining maps to the
        // yet_to_be_fetched list and marking the failed tasks.
        InputAttemptIdentifier[] failedInputs = null;
        while (!remaining.isEmpty() && failedInputs == null) {
            failedInputs = fetchInputs(input);
        }

        if (failedInputs != null && failedInputs.length > 0) {
            LOG.warn("copyInputs failed for tasks " + Arrays.toString(failedInputs));
            for (InputAttemptIdentifier left : failedInputs) {
                fetcherCallback.fetchFailed(host, left, false);
            }
        }

        IOUtils.cleanup(LOG, input);

        // Sanity check
        if (failedInputs == null && !remaining.isEmpty()) {
            throw new IOException("server didn't return all expected map outputs: " + remaining.size() + " left.");
        }

        return new FetchResult(host, port, partition, remaining);

    }

    private InputAttemptIdentifier[] fetchInputs(DataInputStream input) {
        FetchedInput fetchedInput = null;
        InputAttemptIdentifier srcAttemptId = null;
        long decompressedLength = -1;
        long compressedLength = -1;

        try {
            long startTime = System.currentTimeMillis();
            int responsePartition = -1;
            // Read the shuffle header
            String pathComponent = null;
            try {
                ShuffleHeader header = new ShuffleHeader();
                header.readFields(input);
                pathComponent = header.getMapId();

                srcAttemptId = pathToAttemptMap.get(pathComponent);
                compressedLength = header.getCompressedLength();
                decompressedLength = header.getUncompressedLength();
                responsePartition = header.getPartition();
            } catch (IllegalArgumentException e) {
                // badIdErrs.increment(1);
                LOG.warn("Invalid src id ", e);
                // Don't know which one was bad, so consider all of them as bad
                return remaining.toArray(new InputAttemptIdentifier[remaining.size()]);
            }

            // Do some basic sanity verification
            if (!verifySanity(compressedLength, decompressedLength, responsePartition, srcAttemptId,
                    pathComponent)) {
                if (srcAttemptId == null) {
                    LOG.warn("Was expecting " + getNextRemainingAttempt() + " but got null");
                    srcAttemptId = getNextRemainingAttempt();
                }
                assert (srcAttemptId != null);
                return new InputAttemptIdentifier[] { srcAttemptId };
            }

            if (LOG.isDebugEnabled()) {
                LOG.debug("header: " + srcAttemptId + ", len: " + compressedLength + ", decomp len: "
                        + decompressedLength);
            }

            // Get the location for the map output - either in-memory or on-disk
            fetchedInput = inputManager.allocate(decompressedLength, compressedLength, srcAttemptId);

            // TODO NEWTEZ No concept of WAIT at the moment.
            // // Check if we can shuffle *now* ...
            // if (fetchedInput.getType() == FetchedInput.WAIT) {
            // LOG.info("fetcher#" + id +
            // " - MergerManager returned Status.WAIT ...");
            // //Not an error but wait to process data.
            // return EMPTY_ATTEMPT_ID_ARRAY;
            // }

            // Go!
            LOG.info("fetcher" + " about to shuffle output of srcAttempt "
                    + fetchedInput.getInputAttemptIdentifier() + " decomp: " + decompressedLength + " len: "
                    + compressedLength + " to " + fetchedInput.getType());

            if (fetchedInput.getType() == Type.MEMORY) {
                ShuffleUtils.shuffleToMemory((MemoryFetchedInput) fetchedInput, input, (int) decompressedLength,
                        (int) compressedLength, codec, ifileReadAhead, ifileReadAheadLength, LOG);
            } else {
                ShuffleUtils.shuffleToDisk((DiskFetchedInput) fetchedInput, input, compressedLength, LOG);
            }

            // Inform the shuffle scheduler
            long endTime = System.currentTimeMillis();
            fetcherCallback.fetchSucceeded(host, srcAttemptId, fetchedInput, compressedLength,
                    (endTime - startTime));

            // Note successful shuffle
            remaining.remove(srcAttemptId);
            // metrics.successFetch();
            return null;
        } catch (IOException ioe) {
            // ioErrs.increment(1);
            if (srcAttemptId == null || fetchedInput == null) {
                LOG.info("fetcher" + " failed to read map header" + srcAttemptId + " decomp: " + decompressedLength
                        + ", " + compressedLength, ioe);
                if (srcAttemptId == null) {
                    return remaining.toArray(new InputAttemptIdentifier[remaining.size()]);
                } else {
                    return new InputAttemptIdentifier[] { srcAttemptId };
                }
            }
            LOG.warn("Failed to shuffle output of " + srcAttemptId + " from " + host, ioe);

            // Inform the shuffle-scheduler
            try {
                fetchedInput.abort();
            } catch (IOException e) {
                LOG.info("Failure to cleanup fetchedInput: " + fetchedInput);
            }
            // metrics.failedFetch();
            return new InputAttemptIdentifier[] { srcAttemptId };
        }
    }

    /**
     * Do some basic verification on the input received -- Being defensive
     * 
     * @param compressedLength
     * @param decompressedLength
     * @param fetchPartition
     * @param remaining
     * @param mapId
     * @return true/false, based on if the verification succeeded or not
     */
    private boolean verifySanity(long compressedLength, long decompressedLength, int fetchPartition,
            InputAttemptIdentifier srcAttemptId, String pathComponent) {
        if (compressedLength < 0 || decompressedLength < 0) {
            // wrongLengthErrs.increment(1);
            LOG.warn(" invalid lengths in input header -> headerPathComponent: " + pathComponent
                    + ", nextRemainingSrcAttemptId: " + getNextRemainingAttempt() + ", mappedSrcAttemptId: "
                    + srcAttemptId + " len: " + compressedLength + ", decomp len: " + decompressedLength);
            return false;
        }

        if (fetchPartition != this.partition) {
            // wrongReduceErrs.increment(1);
            LOG.warn(" data for the wrong reduce -> headerPathComponent: " + pathComponent
                    + "nextRemainingSrcAttemptId: " + getNextRemainingAttempt() + ", mappedSrcAttemptId: "
                    + srcAttemptId + " len: " + compressedLength + " decomp len: " + decompressedLength
                    + " for reduce " + fetchPartition);
            return false;
        }

        // Sanity check
        if (!remaining.contains(srcAttemptId)) {
            // wrongMapErrs.increment(1);
            LOG.warn("Invalid input. Received output for headerPathComponent: " + pathComponent
                    + "nextRemainingSrcAttemptId: " + getNextRemainingAttempt() + ", mappedSrcAttemptId: "
                    + srcAttemptId);
            return false;
        }
        return true;
    }

    private InputAttemptIdentifier getNextRemainingAttempt() {
        if (remaining.size() > 0) {
            return remaining.iterator().next();
        } else {
            return null;
        }
    }

    private HttpURLConnection connectToShuffleHandler(String host, int port, int partition,
            List<InputAttemptIdentifier> inputs) throws IOException {
        try {
            this.url = constructInputURL(host, port, partition, inputs);
            HttpURLConnection connection = openConnection(url);

            // generate hash of the url
            this.msgToEncode = SecureShuffleUtils.buildMsgFrom(url);
            this.encHash = SecureShuffleUtils.hashFromString(msgToEncode, shuffleSecret);

            // put url hash into http header
            connection.addRequestProperty(SecureShuffleUtils.HTTP_HEADER_URL_HASH, encHash);
            // set the read timeout
            connection.setReadTimeout(readTimeout);
            // put shuffle version into http header
            connection.addRequestProperty(ShuffleHeader.HTTP_HEADER_NAME, ShuffleHeader.DEFAULT_HTTP_HEADER_NAME);
            connection.addRequestProperty(ShuffleHeader.HTTP_HEADER_VERSION,
                    ShuffleHeader.DEFAULT_HTTP_HEADER_VERSION);

            connect(connection, connectionTimeout);
            return connection;
        } catch (IOException e) {
            LOG.warn("Failed to connect to " + host + " with " + srcAttempts.size() + " inputs", e);
            throw e;
        }
    }

    private void validateConnectionResponse(HttpURLConnection connection, URL url, String msgToEncode,
            String encHash) throws IOException {
        int rc = connection.getResponseCode();
        if (rc != HttpURLConnection.HTTP_OK) {
            throw new IOException(
                    "Got invalid response code " + rc + " from " + url + ": " + connection.getResponseMessage());
        }

        if (!ShuffleHeader.DEFAULT_HTTP_HEADER_NAME
                .equals(connection.getHeaderField(ShuffleHeader.HTTP_HEADER_NAME))
                || !ShuffleHeader.DEFAULT_HTTP_HEADER_VERSION
                        .equals(connection.getHeaderField(ShuffleHeader.HTTP_HEADER_VERSION))) {
            throw new IOException("Incompatible shuffle response version");
        }

        // get the replyHash which is HMac of the encHash we sent to the server
        String replyHash = connection.getHeaderField(SecureShuffleUtils.HTTP_HEADER_REPLY_URL_HASH);
        if (replyHash == null) {
            throw new IOException("security validation of TT Map output failed");
        }
        if (LOG.isDebugEnabled()) {
            LOG.debug("url=" + msgToEncode + ";encHash=" + encHash + ";replyHash=" + replyHash);
        }
        // verify that replyHash is HMac of encHash
        SecureShuffleUtils.verifyReply(replyHash, encHash, shuffleSecret);
        LOG.info("for url=" + msgToEncode + " sent hash and receievd reply");
    }

    protected HttpURLConnection openConnection(URL url) throws IOException {
        HttpURLConnection conn = (HttpURLConnection) url.openConnection();
        if (sslShuffle) {
            HttpsURLConnection httpsConn = (HttpsURLConnection) conn;
            try {
                httpsConn.setSSLSocketFactory(sslFactory.createSSLSocketFactory());
            } catch (GeneralSecurityException ex) {
                throw new IOException(ex);
            }
            httpsConn.setHostnameVerifier(sslFactory.getHostnameVerifier());
        }
        return conn;
    }

    /**
     * The connection establishment is attempted multiple times and is given up
     * only on the last failure. Instead of connecting with a timeout of X, we try
     * connecting with a timeout of x < X but multiple times.
     */
    private void connect(URLConnection connection, int connectionTimeout) throws IOException {
        int unit = 0;
        if (connectionTimeout < 0) {
            throw new IOException("Invalid timeout " + "[timeout = " + connectionTimeout + " ms]");
        } else if (connectionTimeout > 0) {
            unit = Math.min(UNIT_CONNECT_TIMEOUT, connectionTimeout);
        }
        // set the connect timeout to the unit-connect-timeout
        connection.setConnectTimeout(unit);
        while (true) {
            try {
                connection.connect();
                break;
            } catch (IOException ioe) {
                // update the total remaining connect-timeout
                connectionTimeout -= unit;

                // throw an exception if we have waited for timeout amount of time
                // note that the updated value if timeout is used here
                if (connectionTimeout == 0) {
                    throw ioe;
                }

                // reset the connect timeout for the last try
                if (connectionTimeout < unit) {
                    unit = connectionTimeout;
                    // reset the connect time out for the final connect
                    connection.setConnectTimeout(unit);
                }
            }
        }
    }

    private URL constructInputURL(String host, int port, int partition, List<InputAttemptIdentifier> inputs)
            throws MalformedURLException {
        StringBuilder url = ShuffleUtils.constructBaseURIForShuffleHandler(host, port, partition, appId);
        boolean first = true;
        for (InputAttemptIdentifier input : inputs) {
            if (first) {
                first = false;
                url.append(input.getPathComponent());
            } else {
                url.append(",").append(input.getPathComponent());
            }
        }
        if (LOG.isDebugEnabled()) {
            LOG.debug("InputFetch URL for: " + host + " : " + url.toString());
        }
        return new URL(url.toString());
    }

    /**
     * Builder for the construction of Fetchers
     */
    public static class FetcherBuilder {
        private Fetcher fetcher;
        private boolean workAssigned = false;

        public FetcherBuilder(FetcherCallback fetcherCallback, FetchedInputAllocator inputManager,
                ApplicationId appId, SecretKey shuffleSecret, Configuration conf) {
            this.fetcher = new Fetcher(fetcherCallback, inputManager, appId, shuffleSecret, conf);
        }

        public FetcherBuilder setCompressionParameters(CompressionCodec codec) {
            fetcher.codec = codec;
            return this;
        }

        public FetcherBuilder setConnectionParameters(int connectionTimeout, int readTimeout) {
            fetcher.connectionTimeout = connectionTimeout;
            fetcher.readTimeout = readTimeout;
            return this;
        }

        public FetcherBuilder setIFileParams(boolean readAhead, int readAheadBytes) {
            fetcher.ifileReadAhead = readAhead;
            fetcher.ifileReadAheadLength = readAheadBytes;
            return this;
        }

        public FetcherBuilder assignWork(String host, int port, int partition,
                List<InputAttemptIdentifier> inputs) {
            fetcher.host = host;
            fetcher.port = port;
            fetcher.partition = partition;
            fetcher.srcAttempts = inputs;
            workAssigned = true;
            return this;
        }

        public Fetcher build() {
            Preconditions.checkState(workAssigned == true, "Cannot build a fetcher withot assigning work to it");
            return fetcher;
        }
    }

    @Override
    public int hashCode() {
        return fetcherIdentifier;
    }

    @Override
    public boolean equals(Object obj) {
        if (this == obj)
            return true;
        if (obj == null)
            return false;
        if (getClass() != obj.getClass())
            return false;
        Fetcher other = (Fetcher) obj;
        if (fetcherIdentifier != other.fetcherIdentifier)
            return false;
        return true;
    }
}