Java tutorial
/** * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.tez.runtime.library.shuffle.common; import java.io.DataInputStream; import java.io.IOException; import java.net.HttpURLConnection; import java.net.MalformedURLException; import java.net.URL; import java.net.URLConnection; import java.security.GeneralSecurityException; import java.util.Arrays; import java.util.HashMap; import java.util.Iterator; import java.util.LinkedHashSet; import java.util.List; import java.util.Map; import java.util.concurrent.Callable; import java.util.concurrent.atomic.AtomicInteger; import javax.crypto.SecretKey; import javax.net.ssl.HttpsURLConnection; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.io.IOUtils; import org.apache.hadoop.io.compress.CompressionCodec; import org.apache.hadoop.security.ssl.SSLFactory; import org.apache.hadoop.yarn.api.records.ApplicationId; import org.apache.tez.common.TezJobConfig; import org.apache.tez.runtime.library.common.InputAttemptIdentifier; import org.apache.tez.runtime.library.common.security.SecureShuffleUtils; import org.apache.tez.runtime.library.common.shuffle.impl.ShuffleHeader; import org.apache.tez.runtime.library.shuffle.common.FetchedInput.Type; import com.google.common.base.Preconditions; /** * Responsible for fetching inputs served by the ShuffleHandler for a single * host. Construct using {@link FetcherBuilder} */ public class Fetcher implements Callable<FetchResult> { private static final Log LOG = LogFactory.getLog(Fetcher.class); private static final int UNIT_CONNECT_TIMEOUT = 60 * 1000; private static final AtomicInteger fetcherIdGen = new AtomicInteger(0); // Configurable fields. private CompressionCodec codec; private int connectionTimeout; private int readTimeout; private boolean ifileReadAhead = TezJobConfig.TEZ_RUNTIME_IFILE_READAHEAD_DEFAULT; private int ifileReadAheadLength = TezJobConfig.TEZ_RUNTIME_IFILE_READAHEAD_BYTES_DEFAULT; private final SecretKey shuffleSecret; private final FetcherCallback fetcherCallback; private final FetchedInputAllocator inputManager; private final ApplicationId appId; private static boolean sslShuffle = false; private static SSLFactory sslFactory; private static boolean sslFactoryInited; private final int fetcherIdentifier; // Parameters to track work. private List<InputAttemptIdentifier> srcAttempts; private String host; private int port; private int partition; // Maps from the pathComponents (unique per srcTaskId) to the specific taskId private final Map<String, InputAttemptIdentifier> pathToAttemptMap; private LinkedHashSet<InputAttemptIdentifier> remaining; private URL url; private String encHash; private String msgToEncode; private Fetcher(FetcherCallback fetcherCallback, FetchedInputAllocator inputManager, ApplicationId appId, SecretKey shuffleSecret, Configuration conf) { this.fetcherCallback = fetcherCallback; this.inputManager = inputManager; this.shuffleSecret = shuffleSecret; this.appId = appId; this.pathToAttemptMap = new HashMap<String, InputAttemptIdentifier>(); this.fetcherIdentifier = fetcherIdGen.getAndIncrement(); // TODO NEWTEZ Ideally, move this out from here into a static initializer block. // Re-enable when ssl shuffle support is needed. // synchronized (Fetcher.class) { // if (!sslFactoryInited) { // sslFactoryInited = true; // sslShuffle = conf.getBoolean( // TezJobConfig.TEZ_RUNTIME_SHUFFLE_ENABLE_SSL, // TezJobConfig.DEFAULT_TEZ_RUNTIME_SHUFFLE_ENABLE_SSL); // if (sslShuffle) { // sslFactory = new SSLFactory(SSLFactory.Mode.CLIENT, conf); // try { // sslFactory.init(); // } catch (Exception ex) { // sslFactory.destroy(); // throw new RuntimeException(ex); // } // } // } // } } @Override public FetchResult call() throws Exception { if (srcAttempts.size() == 0) { return new FetchResult(host, port, partition, srcAttempts); } for (InputAttemptIdentifier in : srcAttempts) { pathToAttemptMap.put(in.getPathComponent(), in); } remaining = new LinkedHashSet<InputAttemptIdentifier>(srcAttempts); HttpURLConnection connection; try { connection = connectToShuffleHandler(host, port, partition, srcAttempts); } catch (IOException e) { // ioErrs.increment(1); // If connect did not succeed, just mark all the maps as failed, // indirectly penalizing the host for (Iterator<InputAttemptIdentifier> leftIter = remaining.iterator(); leftIter.hasNext();) { fetcherCallback.fetchFailed(host, leftIter.next(), true); } return new FetchResult(host, port, partition, remaining); } DataInputStream input; try { input = new DataInputStream(connection.getInputStream()); validateConnectionResponse(connection, url, msgToEncode, encHash); } catch (IOException e) { // ioErrs.increment(1); // If we got a read error at this stage, it implies there was a problem // with the first map, typically lost map. So, penalize only that map // and add the rest InputAttemptIdentifier firstAttempt = srcAttempts.get(0); LOG.warn("Fetch Failure from host while connecting: " + host + ", attempt: " + firstAttempt + " Informing ShuffleManager: ", e); fetcherCallback.fetchFailed(host, firstAttempt, false); return new FetchResult(host, port, partition, remaining); } // By this point, the connection is setup and the response has been // validated. // Loop through available map-outputs and fetch them // On any error, faildTasks is not null and we exit // after putting back the remaining maps to the // yet_to_be_fetched list and marking the failed tasks. InputAttemptIdentifier[] failedInputs = null; while (!remaining.isEmpty() && failedInputs == null) { failedInputs = fetchInputs(input); } if (failedInputs != null && failedInputs.length > 0) { LOG.warn("copyInputs failed for tasks " + Arrays.toString(failedInputs)); for (InputAttemptIdentifier left : failedInputs) { fetcherCallback.fetchFailed(host, left, false); } } IOUtils.cleanup(LOG, input); // Sanity check if (failedInputs == null && !remaining.isEmpty()) { throw new IOException("server didn't return all expected map outputs: " + remaining.size() + " left."); } return new FetchResult(host, port, partition, remaining); } private InputAttemptIdentifier[] fetchInputs(DataInputStream input) { FetchedInput fetchedInput = null; InputAttemptIdentifier srcAttemptId = null; long decompressedLength = -1; long compressedLength = -1; try { long startTime = System.currentTimeMillis(); int responsePartition = -1; // Read the shuffle header String pathComponent = null; try { ShuffleHeader header = new ShuffleHeader(); header.readFields(input); pathComponent = header.getMapId(); srcAttemptId = pathToAttemptMap.get(pathComponent); compressedLength = header.getCompressedLength(); decompressedLength = header.getUncompressedLength(); responsePartition = header.getPartition(); } catch (IllegalArgumentException e) { // badIdErrs.increment(1); LOG.warn("Invalid src id ", e); // Don't know which one was bad, so consider all of them as bad return remaining.toArray(new InputAttemptIdentifier[remaining.size()]); } // Do some basic sanity verification if (!verifySanity(compressedLength, decompressedLength, responsePartition, srcAttemptId, pathComponent)) { if (srcAttemptId == null) { LOG.warn("Was expecting " + getNextRemainingAttempt() + " but got null"); srcAttemptId = getNextRemainingAttempt(); } assert (srcAttemptId != null); return new InputAttemptIdentifier[] { srcAttemptId }; } if (LOG.isDebugEnabled()) { LOG.debug("header: " + srcAttemptId + ", len: " + compressedLength + ", decomp len: " + decompressedLength); } // Get the location for the map output - either in-memory or on-disk fetchedInput = inputManager.allocate(decompressedLength, compressedLength, srcAttemptId); // TODO NEWTEZ No concept of WAIT at the moment. // // Check if we can shuffle *now* ... // if (fetchedInput.getType() == FetchedInput.WAIT) { // LOG.info("fetcher#" + id + // " - MergerManager returned Status.WAIT ..."); // //Not an error but wait to process data. // return EMPTY_ATTEMPT_ID_ARRAY; // } // Go! LOG.info("fetcher" + " about to shuffle output of srcAttempt " + fetchedInput.getInputAttemptIdentifier() + " decomp: " + decompressedLength + " len: " + compressedLength + " to " + fetchedInput.getType()); if (fetchedInput.getType() == Type.MEMORY) { ShuffleUtils.shuffleToMemory((MemoryFetchedInput) fetchedInput, input, (int) decompressedLength, (int) compressedLength, codec, ifileReadAhead, ifileReadAheadLength, LOG); } else { ShuffleUtils.shuffleToDisk((DiskFetchedInput) fetchedInput, input, compressedLength, LOG); } // Inform the shuffle scheduler long endTime = System.currentTimeMillis(); fetcherCallback.fetchSucceeded(host, srcAttemptId, fetchedInput, compressedLength, (endTime - startTime)); // Note successful shuffle remaining.remove(srcAttemptId); // metrics.successFetch(); return null; } catch (IOException ioe) { // ioErrs.increment(1); if (srcAttemptId == null || fetchedInput == null) { LOG.info("fetcher" + " failed to read map header" + srcAttemptId + " decomp: " + decompressedLength + ", " + compressedLength, ioe); if (srcAttemptId == null) { return remaining.toArray(new InputAttemptIdentifier[remaining.size()]); } else { return new InputAttemptIdentifier[] { srcAttemptId }; } } LOG.warn("Failed to shuffle output of " + srcAttemptId + " from " + host, ioe); // Inform the shuffle-scheduler try { fetchedInput.abort(); } catch (IOException e) { LOG.info("Failure to cleanup fetchedInput: " + fetchedInput); } // metrics.failedFetch(); return new InputAttemptIdentifier[] { srcAttemptId }; } } /** * Do some basic verification on the input received -- Being defensive * * @param compressedLength * @param decompressedLength * @param fetchPartition * @param remaining * @param mapId * @return true/false, based on if the verification succeeded or not */ private boolean verifySanity(long compressedLength, long decompressedLength, int fetchPartition, InputAttemptIdentifier srcAttemptId, String pathComponent) { if (compressedLength < 0 || decompressedLength < 0) { // wrongLengthErrs.increment(1); LOG.warn(" invalid lengths in input header -> headerPathComponent: " + pathComponent + ", nextRemainingSrcAttemptId: " + getNextRemainingAttempt() + ", mappedSrcAttemptId: " + srcAttemptId + " len: " + compressedLength + ", decomp len: " + decompressedLength); return false; } if (fetchPartition != this.partition) { // wrongReduceErrs.increment(1); LOG.warn(" data for the wrong reduce -> headerPathComponent: " + pathComponent + "nextRemainingSrcAttemptId: " + getNextRemainingAttempt() + ", mappedSrcAttemptId: " + srcAttemptId + " len: " + compressedLength + " decomp len: " + decompressedLength + " for reduce " + fetchPartition); return false; } // Sanity check if (!remaining.contains(srcAttemptId)) { // wrongMapErrs.increment(1); LOG.warn("Invalid input. Received output for headerPathComponent: " + pathComponent + "nextRemainingSrcAttemptId: " + getNextRemainingAttempt() + ", mappedSrcAttemptId: " + srcAttemptId); return false; } return true; } private InputAttemptIdentifier getNextRemainingAttempt() { if (remaining.size() > 0) { return remaining.iterator().next(); } else { return null; } } private HttpURLConnection connectToShuffleHandler(String host, int port, int partition, List<InputAttemptIdentifier> inputs) throws IOException { try { this.url = constructInputURL(host, port, partition, inputs); HttpURLConnection connection = openConnection(url); // generate hash of the url this.msgToEncode = SecureShuffleUtils.buildMsgFrom(url); this.encHash = SecureShuffleUtils.hashFromString(msgToEncode, shuffleSecret); // put url hash into http header connection.addRequestProperty(SecureShuffleUtils.HTTP_HEADER_URL_HASH, encHash); // set the read timeout connection.setReadTimeout(readTimeout); // put shuffle version into http header connection.addRequestProperty(ShuffleHeader.HTTP_HEADER_NAME, ShuffleHeader.DEFAULT_HTTP_HEADER_NAME); connection.addRequestProperty(ShuffleHeader.HTTP_HEADER_VERSION, ShuffleHeader.DEFAULT_HTTP_HEADER_VERSION); connect(connection, connectionTimeout); return connection; } catch (IOException e) { LOG.warn("Failed to connect to " + host + " with " + srcAttempts.size() + " inputs", e); throw e; } } private void validateConnectionResponse(HttpURLConnection connection, URL url, String msgToEncode, String encHash) throws IOException { int rc = connection.getResponseCode(); if (rc != HttpURLConnection.HTTP_OK) { throw new IOException( "Got invalid response code " + rc + " from " + url + ": " + connection.getResponseMessage()); } if (!ShuffleHeader.DEFAULT_HTTP_HEADER_NAME .equals(connection.getHeaderField(ShuffleHeader.HTTP_HEADER_NAME)) || !ShuffleHeader.DEFAULT_HTTP_HEADER_VERSION .equals(connection.getHeaderField(ShuffleHeader.HTTP_HEADER_VERSION))) { throw new IOException("Incompatible shuffle response version"); } // get the replyHash which is HMac of the encHash we sent to the server String replyHash = connection.getHeaderField(SecureShuffleUtils.HTTP_HEADER_REPLY_URL_HASH); if (replyHash == null) { throw new IOException("security validation of TT Map output failed"); } if (LOG.isDebugEnabled()) { LOG.debug("url=" + msgToEncode + ";encHash=" + encHash + ";replyHash=" + replyHash); } // verify that replyHash is HMac of encHash SecureShuffleUtils.verifyReply(replyHash, encHash, shuffleSecret); LOG.info("for url=" + msgToEncode + " sent hash and receievd reply"); } protected HttpURLConnection openConnection(URL url) throws IOException { HttpURLConnection conn = (HttpURLConnection) url.openConnection(); if (sslShuffle) { HttpsURLConnection httpsConn = (HttpsURLConnection) conn; try { httpsConn.setSSLSocketFactory(sslFactory.createSSLSocketFactory()); } catch (GeneralSecurityException ex) { throw new IOException(ex); } httpsConn.setHostnameVerifier(sslFactory.getHostnameVerifier()); } return conn; } /** * The connection establishment is attempted multiple times and is given up * only on the last failure. Instead of connecting with a timeout of X, we try * connecting with a timeout of x < X but multiple times. */ private void connect(URLConnection connection, int connectionTimeout) throws IOException { int unit = 0; if (connectionTimeout < 0) { throw new IOException("Invalid timeout " + "[timeout = " + connectionTimeout + " ms]"); } else if (connectionTimeout > 0) { unit = Math.min(UNIT_CONNECT_TIMEOUT, connectionTimeout); } // set the connect timeout to the unit-connect-timeout connection.setConnectTimeout(unit); while (true) { try { connection.connect(); break; } catch (IOException ioe) { // update the total remaining connect-timeout connectionTimeout -= unit; // throw an exception if we have waited for timeout amount of time // note that the updated value if timeout is used here if (connectionTimeout == 0) { throw ioe; } // reset the connect timeout for the last try if (connectionTimeout < unit) { unit = connectionTimeout; // reset the connect time out for the final connect connection.setConnectTimeout(unit); } } } } private URL constructInputURL(String host, int port, int partition, List<InputAttemptIdentifier> inputs) throws MalformedURLException { StringBuilder url = ShuffleUtils.constructBaseURIForShuffleHandler(host, port, partition, appId); boolean first = true; for (InputAttemptIdentifier input : inputs) { if (first) { first = false; url.append(input.getPathComponent()); } else { url.append(",").append(input.getPathComponent()); } } if (LOG.isDebugEnabled()) { LOG.debug("InputFetch URL for: " + host + " : " + url.toString()); } return new URL(url.toString()); } /** * Builder for the construction of Fetchers */ public static class FetcherBuilder { private Fetcher fetcher; private boolean workAssigned = false; public FetcherBuilder(FetcherCallback fetcherCallback, FetchedInputAllocator inputManager, ApplicationId appId, SecretKey shuffleSecret, Configuration conf) { this.fetcher = new Fetcher(fetcherCallback, inputManager, appId, shuffleSecret, conf); } public FetcherBuilder setCompressionParameters(CompressionCodec codec) { fetcher.codec = codec; return this; } public FetcherBuilder setConnectionParameters(int connectionTimeout, int readTimeout) { fetcher.connectionTimeout = connectionTimeout; fetcher.readTimeout = readTimeout; return this; } public FetcherBuilder setIFileParams(boolean readAhead, int readAheadBytes) { fetcher.ifileReadAhead = readAhead; fetcher.ifileReadAheadLength = readAheadBytes; return this; } public FetcherBuilder assignWork(String host, int port, int partition, List<InputAttemptIdentifier> inputs) { fetcher.host = host; fetcher.port = port; fetcher.partition = partition; fetcher.srcAttempts = inputs; workAssigned = true; return this; } public Fetcher build() { Preconditions.checkState(workAssigned == true, "Cannot build a fetcher withot assigning work to it"); return fetcher; } } @Override public int hashCode() { return fetcherIdentifier; } @Override public boolean equals(Object obj) { if (this == obj) return true; if (obj == null) return false; if (getClass() != obj.getClass()) return false; Fetcher other = (Fetcher) obj; if (fetcherIdentifier != other.fetcherIdentifier) return false; return true; } }