Java tutorial: Apache Giraph's NettyClient (org.apache.giraph.comm.netty) — a thread-safe, Netty-based client used by workers to send requests to each other.
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.giraph.comm.netty;

import org.apache.giraph.comm.netty.handler.AddressRequestIdGenerator;
import org.apache.giraph.comm.netty.handler.ClientRequestId;
import org.apache.giraph.comm.netty.handler.RequestEncoder;
import org.apache.giraph.comm.netty.handler.RequestInfo;
import org.apache.giraph.comm.netty.handler.RequestServerHandler;
import org.apache.giraph.comm.netty.handler.ResponseClientHandler;
/*if_not[HADOOP_NON_SECURE]*/
import org.apache.giraph.comm.netty.handler.SaslClientHandler;
import org.apache.giraph.comm.requests.RequestType;
import org.apache.giraph.comm.requests.SaslTokenMessageRequest;
/*end[HADOOP_NON_SECURE]*/
import org.apache.giraph.comm.requests.WritableRequest;
import org.apache.giraph.conf.GiraphConstants;
import org.apache.giraph.conf.ImmutableClassesGiraphConfiguration;
import org.apache.giraph.graph.TaskInfo;
import org.apache.giraph.utils.PipelineUtils;
import org.apache.giraph.utils.ProgressableUtils;
import org.apache.giraph.utils.ThreadUtils;
import org.apache.giraph.utils.TimedLogger;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.log4j.Logger;

import com.google.common.collect.Lists;
import com.google.common.collect.MapMaker;
import com.google.common.collect.Maps;

import java.io.IOException;
import java.net.InetSocketAddress;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;

import io.netty.bootstrap.Bootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.FixedLengthFrameDecoder;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import io.netty.util.AttributeKey;
import io.netty.util.concurrent.DefaultEventExecutorGroup;
import io.netty.util.concurrent.EventExecutorGroup;

import static org.apache.giraph.conf.GiraphConstants.CLIENT_RECEIVE_BUFFER_SIZE;
import static org.apache.giraph.conf.GiraphConstants.CLIENT_SEND_BUFFER_SIZE;
import static org.apache.giraph.conf.GiraphConstants.MAX_REQUEST_MILLISECONDS;
import static org.apache.giraph.conf.GiraphConstants.MAX_RESOLVE_ADDRESS_ATTEMPTS;
import static org.apache.giraph.conf.GiraphConstants.NETTY_CLIENT_EXECUTION_AFTER_HANDLER;
import static org.apache.giraph.conf.GiraphConstants.NETTY_CLIENT_EXECUTION_THREADS;
import static org.apache.giraph.conf.GiraphConstants.NETTY_CLIENT_USE_EXECUTION_HANDLER;
import static org.apache.giraph.conf.GiraphConstants.NETTY_MAX_CONNECTION_FAILURES;
import static org.apache.giraph.conf.GiraphConstants.WAITING_REQUEST_MSECS;

/**
 * Netty client for sending requests. Thread-safe.
 */
public class NettyClient {
  /** Do we have a limit on number of open requests we can have */
  public static final String LIMIT_NUMBER_OF_OPEN_REQUESTS =
      "giraph.waitForRequestsConfirmation";
  /** Default choice about having a limit on number of open requests */
  public static final boolean LIMIT_NUMBER_OF_OPEN_REQUESTS_DEFAULT = false;
  /** Maximum number of requests without confirmation we should have */
  public static final String MAX_NUMBER_OF_OPEN_REQUESTS =
      "giraph.maxNumberOfOpenRequests";
  /** Default maximum number of requests without confirmation */
  public static final int MAX_NUMBER_OF_OPEN_REQUESTS_DEFAULT = 10000;
  /** Maximum number of requests to list (for debugging) */
  public static final int MAX_REQUESTS_TO_LIST = 10;
  /**
   * Maximum number of destination task ids with open requests to list
   * (for debugging)
   */
  public static final int MAX_DESTINATION_TASK_IDS_TO_LIST = 10;
  /** 30 seconds to connect by default */
  public static final int MAX_CONNECTION_MILLISECONDS_DEFAULT = 30 * 1000;
/*if_not[HADOOP_NON_SECURE]*/
  /** Used to authenticate with other workers acting as servers */
  public static final AttributeKey<SaslNettyClient> SASL =
      AttributeKey.valueOf("saslNettyClient");
/*end[HADOOP_NON_SECURE]*/
  /** Class logger */
  private static final Logger LOG = Logger.getLogger(NettyClient.class);
  /** Context used to report progress */
  private final Mapper<?, ?, ?, ?>.Context context;
  /** Client bootstrap */
  private final Bootstrap bootstrap;
  /**
   * Map of the peer connections, mapping from remote socket address to client
   * meta data
   */
  private final ConcurrentMap<InetSocketAddress, ChannelRotater>
  addressChannelMap = new MapMaker().makeMap();
  /**
   * Map from task id to address of its server
   */
  private final Map<Integer, InetSocketAddress> taskIdAddressMap =
      new MapMaker().makeMap();
  /**
   * Request map of client request ids to request information.
   */
  private final ConcurrentMap<ClientRequestId, RequestInfo>
  clientRequestIdRequestInfoMap;
  /** Number of channels per server */
  private final int channelsPerServer;
  /** Inbound byte counter for this client */
  private final InboundByteCounter inboundByteCounter =
      new InboundByteCounter();
  /** Outbound byte counter for this client */
  private final OutboundByteCounter outboundByteCounter =
      new OutboundByteCounter();
  /** Send buffer size */
  private final int sendBufferSize;
  /** Receive buffer size */
  private final int receiveBufferSize;
  /** Do we have a limit on number of open requests */
  private final boolean limitNumberOfOpenRequests;
  /** Maximum number of requests without confirmation we can have */
  private final int maxNumberOfOpenRequests;
  /** Maximum number of connection failures */
  private final int maxConnectionFailures;
  /** Maximum number of milliseconds for a request */
  private final int maxRequestMilliseconds;
  /** Waiting interval for checking outstanding requests, in msecs */
  private final int waitingRequestMsecs;
  /** Timed logger for printing request debugging */
  private final TimedLogger requestLogger = new TimedLogger(15 * 1000, LOG);
  /** Worker executor group */
  private final EventLoopGroup workerGroup;
  /** Address request id generator */
  private final AddressRequestIdGenerator addressRequestIdGenerator =
      new AddressRequestIdGenerator();
  /** Task info */
  private final TaskInfo myTaskInfo;
  /** Maximum thread pool size */
  private final int maxPoolSize;
  /** Maximum number of attempts to resolve an address */
  private final int maxResolveAddressAttempts;
  /** Use execution handler? */
  private final boolean useExecutionGroup;
  /** EventExecutor Group (if used) */
  private final EventExecutorGroup executionGroup;
  /** Name of the handler to use execution group for (if used) */
  private final String handlerToUseExecutionGroup;
  /** When was the last time we checked if we should resend some requests */
  private final AtomicLong lastTimeCheckedRequestsForProblems =
      new AtomicLong(0);
  /**
   * Logger used to dump stack traces for every exception that happens
   * in netty client threads.
   */
  private final LogOnErrorChannelFutureListener logErrorListener =
      new LogOnErrorChannelFutureListener();
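  // Example (illustrative, not part of the original source): a job can cap
  // the number of in-flight requests via the keys above. The cap is off by
  // default; once enabled, MAX_NUMBER_OF_OPEN_REQUESTS_DEFAULT (10000)
  // applies unless overridden. Assuming a Hadoop-style Configuration object
  // named "conf":
  //
  //   conf.setBoolean(NettyClient.LIMIT_NUMBER_OF_OPEN_REQUESTS, true);
  //   conf.setInt(NettyClient.MAX_NUMBER_OF_OPEN_REQUESTS, 5000);
  //
  // With these set, sendWritableRequest() blocks in waitSomeRequests() once
  // more than 5000 requests are awaiting confirmation.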
  /**
   * Only constructor
   *
   * @param context Context for progress
   * @param conf Configuration
   * @param myTaskInfo Current task info
   * @param exceptionHandler handler for uncaught exception. Will
   *        terminate job.
   */
  public NettyClient(Mapper<?, ?, ?, ?>.Context context,
      final ImmutableClassesGiraphConfiguration conf,
      TaskInfo myTaskInfo,
      final Thread.UncaughtExceptionHandler exceptionHandler) {
    this.context = context;
    this.myTaskInfo = myTaskInfo;
    this.channelsPerServer = GiraphConstants.CHANNELS_PER_SERVER.get(conf);
    sendBufferSize = CLIENT_SEND_BUFFER_SIZE.get(conf);
    receiveBufferSize = CLIENT_RECEIVE_BUFFER_SIZE.get(conf);

    limitNumberOfOpenRequests = conf.getBoolean(
        LIMIT_NUMBER_OF_OPEN_REQUESTS,
        LIMIT_NUMBER_OF_OPEN_REQUESTS_DEFAULT);
    if (limitNumberOfOpenRequests) {
      maxNumberOfOpenRequests = conf.getInt(
          MAX_NUMBER_OF_OPEN_REQUESTS,
          MAX_NUMBER_OF_OPEN_REQUESTS_DEFAULT);
      if (LOG.isInfoEnabled()) {
        LOG.info("NettyClient: Limit number of open requests to " +
            maxNumberOfOpenRequests);
      }
    } else {
      maxNumberOfOpenRequests = -1;
    }

    maxRequestMilliseconds = MAX_REQUEST_MILLISECONDS.get(conf);
    maxConnectionFailures = NETTY_MAX_CONNECTION_FAILURES.get(conf);
    waitingRequestMsecs = WAITING_REQUEST_MSECS.get(conf);
    maxPoolSize = GiraphConstants.NETTY_CLIENT_THREADS.get(conf);
    maxResolveAddressAttempts = MAX_RESOLVE_ADDRESS_ATTEMPTS.get(conf);

    clientRequestIdRequestInfoMap =
        new MapMaker().concurrencyLevel(maxPoolSize).makeMap();

    handlerToUseExecutionGroup =
        NETTY_CLIENT_EXECUTION_AFTER_HANDLER.get(conf);
    useExecutionGroup = NETTY_CLIENT_USE_EXECUTION_HANDLER.get(conf);
    if (useExecutionGroup) {
      int executionThreads = NETTY_CLIENT_EXECUTION_THREADS.get(conf);
      executionGroup = new DefaultEventExecutorGroup(executionThreads,
          ThreadUtils.createThreadFactory("netty-client-exec-%d",
              exceptionHandler));
      if (LOG.isInfoEnabled()) {
        LOG.info("NettyClient: Using execution handler with " +
            executionThreads + " threads after " +
            handlerToUseExecutionGroup + ".");
      }
    } else {
      executionGroup = null;
    }

    workerGroup = new NioEventLoopGroup(maxPoolSize,
        ThreadUtils.createThreadFactory("netty-client-worker-%d",
            exceptionHandler));

    bootstrap = new Bootstrap();
    bootstrap.group(workerGroup)
        .channel(NioSocketChannel.class)
        .option(ChannelOption.CONNECT_TIMEOUT_MILLIS,
            MAX_CONNECTION_MILLISECONDS_DEFAULT)
        .option(ChannelOption.TCP_NODELAY, true)
        .option(ChannelOption.SO_KEEPALIVE, true)
        .option(ChannelOption.SO_SNDBUF, sendBufferSize)
        .option(ChannelOption.SO_RCVBUF, receiveBufferSize)
        .option(ChannelOption.ALLOCATOR, conf.getNettyAllocator())
        .handler(new ChannelInitializer<SocketChannel>() {
          @Override
          protected void initChannel(SocketChannel ch) throws Exception {
/*if_not[HADOOP_NON_SECURE]*/
            if (conf.authenticate()) {
              LOG.info("Using Netty with authentication.");

              // Our pipeline starts with just byteCounter, and then we use
              // addLast() to incrementally add pipeline elements, so that we
              // can name them for identification for removal or replacement
              // after client is authenticated by server.
              // After authentication is complete, the pipeline's
              // SASL-specific functionality is removed, restoring the
              // pipeline to exactly the same configuration as it would be
              // without authentication.
              PipelineUtils.addLastWithExecutorCheck(
                  "clientInboundByteCounter", inboundByteCounter,
                  handlerToUseExecutionGroup, executionGroup, ch);
              if (conf.doCompression()) {
                PipelineUtils.addLastWithExecutorCheck("compressionDecoder",
                    conf.getNettyCompressionDecoder(),
                    handlerToUseExecutionGroup, executionGroup, ch);
              }
              PipelineUtils.addLastWithExecutorCheck(
                  "clientOutboundByteCounter", outboundByteCounter,
                  handlerToUseExecutionGroup, executionGroup, ch);
              if (conf.doCompression()) {
                PipelineUtils.addLastWithExecutorCheck("compressionEncoder",
                    conf.getNettyCompressionEncoder(),
                    handlerToUseExecutionGroup, executionGroup, ch);
              }
              // The following pipeline component is needed to decode the
              // server's SASL tokens. It is replaced with a
              // FixedLengthFrameDecoder (same as used with the
              // non-authenticated pipeline) after authentication
              // completes (as in non-auth pipeline below).
              PipelineUtils.addLastWithExecutorCheck(
                  "length-field-based-frame-decoder",
                  new LengthFieldBasedFrameDecoder(1024, 0, 4, 0, 4),
                  handlerToUseExecutionGroup, executionGroup, ch);
              PipelineUtils.addLastWithExecutorCheck("request-encoder",
                  new RequestEncoder(conf),
                  handlerToUseExecutionGroup, executionGroup, ch);
              // The following pipeline component responds to the server's
              // SASL tokens with its own responses. Both client and server
              // share the same Hadoop Job token, which is used to create
              // the SASL tokens to authenticate with each other.
              // After authentication finishes, this pipeline component
              // is removed.
              PipelineUtils.addLastWithExecutorCheck("sasl-client-handler",
                  new SaslClientHandler(conf),
                  handlerToUseExecutionGroup, executionGroup, ch);
              PipelineUtils.addLastWithExecutorCheck("response-handler",
                  new ResponseClientHandler(clientRequestIdRequestInfoMap,
                      conf),
                  handlerToUseExecutionGroup, executionGroup, ch);
            } else {
              LOG.info("Using Netty without authentication.");
/*end[HADOOP_NON_SECURE]*/
              PipelineUtils.addLastWithExecutorCheck(
                  "clientInboundByteCounter", inboundByteCounter,
                  handlerToUseExecutionGroup, executionGroup, ch);
              if (conf.doCompression()) {
                PipelineUtils.addLastWithExecutorCheck("compressionDecoder",
                    conf.getNettyCompressionDecoder(),
                    handlerToUseExecutionGroup, executionGroup, ch);
              }
              PipelineUtils.addLastWithExecutorCheck(
                  "clientOutboundByteCounter", outboundByteCounter,
                  handlerToUseExecutionGroup, executionGroup, ch);
              if (conf.doCompression()) {
                PipelineUtils.addLastWithExecutorCheck("compressionEncoder",
                    conf.getNettyCompressionEncoder(),
                    handlerToUseExecutionGroup, executionGroup, ch);
              }
              PipelineUtils.addLastWithExecutorCheck(
                  "fixed-length-frame-decoder",
                  new FixedLengthFrameDecoder(
                      RequestServerHandler.RESPONSE_BYTES),
                  handlerToUseExecutionGroup, executionGroup, ch);
              PipelineUtils.addLastWithExecutorCheck("request-encoder",
                  new RequestEncoder(conf),
                  handlerToUseExecutionGroup, executionGroup, ch);
              PipelineUtils.addLastWithExecutorCheck("response-handler",
                  new ResponseClientHandler(clientRequestIdRequestInfoMap,
                      conf),
                  handlerToUseExecutionGroup, executionGroup, ch);
/*if_not[HADOOP_NON_SECURE]*/
            }
/*end[HADOOP_NON_SECURE]*/
          }
        });
  }
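  // Note (added for this walkthrough, not in the original source): the
  // LengthFieldBasedFrameDecoder(1024, 0, 4, 0, 4) above reads frames that
  // begin with a 4-byte length prefix: max frame length 1024, length field
  // at offset 0, 4 bytes long, no length adjustment, and the 4 prefix bytes
  // are stripped before the payload is passed downstream. For example, the
  // bytes {0x00, 0x00, 0x00, 0x02, 'o', 'k'} decode to the 2-byte payload
  // {'o', 'k'}. The non-authenticated branch can instead use the simpler
  // FixedLengthFrameDecoder because every server response is exactly
  // RequestServerHandler.RESPONSE_BYTES long, whereas SASL token messages
  // vary in size.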
  /**
   * Pair object for connectAllAddresses().
   */
  private static class ChannelFutureAddress {
    /** Future object */
    private final ChannelFuture future;
    /** Address of the future */
    private final InetSocketAddress address;
    /** Task id */
    private final Integer taskId;

    /**
     * Constructor.
     *
     * @param future Immutable future
     * @param address Immutable address
     * @param taskId Immutable taskId
     */
    ChannelFutureAddress(
        ChannelFuture future, InetSocketAddress address, Integer taskId) {
      this.future = future;
      this.address = address;
      this.taskId = taskId;
    }

    @Override
    public String toString() {
      return "(future=" + future + ",address=" + address + ",taskId=" +
          taskId + ")";
    }
  }

  /**
   * Connect to a collection of task servers
   *
   * @param tasks Tasks to connect to (if not already connected)
   */
  public void connectAllAddresses(Collection<? extends TaskInfo> tasks) {
    List<ChannelFutureAddress> waitingConnectionList =
        Lists.newArrayListWithCapacity(tasks.size() * channelsPerServer);
    for (TaskInfo taskInfo : tasks) {
      context.progress();
      InetSocketAddress address = taskIdAddressMap.get(taskInfo.getTaskId());
      if (address == null ||
          !address.getHostName().equals(taskInfo.getHostname()) ||
          address.getPort() != taskInfo.getPort()) {
        address = resolveAddress(maxResolveAddressAttempts,
            taskInfo.getInetSocketAddress());
        taskIdAddressMap.put(taskInfo.getTaskId(), address);
      }
      if (address == null || address.getHostName() == null ||
          address.getHostName().isEmpty()) {
        throw new IllegalStateException("connectAllAddresses: Null address " +
            "in addresses " + tasks);
      }
      if (address.isUnresolved()) {
        throw new IllegalStateException("connectAllAddresses: Unresolved " +
            "address " + address);
      }

      if (addressChannelMap.containsKey(address)) {
        continue;
      }

      // Start connecting to the remote server, up to channelsPerServer times
      for (int i = 0; i < channelsPerServer; ++i) {
        ChannelFuture connectionFuture = bootstrap.connect(address);
        waitingConnectionList.add(new ChannelFutureAddress(
            connectionFuture, address, taskInfo.getTaskId()));
      }
    }

    // Wait for all the connections to succeed, retrying failed connections
    // until maxConnectionFailures is reached
    int failures = 0;
    int connected = 0;
    while (failures < maxConnectionFailures) {
      List<ChannelFutureAddress> nextCheckFutures = Lists.newArrayList();
      for (ChannelFutureAddress waitingConnection : waitingConnectionList) {
        context.progress();
        ChannelFuture future = waitingConnection.future;
        ProgressableUtils.awaitChannelFuture(future, context);
        if (!future.isSuccess()) {
          LOG.warn("connectAllAddresses: Future failed " +
              "to connect with " + waitingConnection.address + " with " +
              failures + " failures because of " + future.cause());
          ChannelFuture connectionFuture =
              bootstrap.connect(waitingConnection.address);
          nextCheckFutures.add(new ChannelFutureAddress(connectionFuture,
              waitingConnection.address, waitingConnection.taskId));
          ++failures;
        } else {
          Channel channel = future.channel();
          if (LOG.isDebugEnabled()) {
            LOG.debug("connectAllAddresses: Connected to " +
                channel.remoteAddress() + ", open = " + channel.isOpen());
          }

          if (channel.remoteAddress() == null) {
            throw new IllegalStateException(
                "connectAllAddresses: Null remote address!");
          }

          ChannelRotater rotater =
              addressChannelMap.get(waitingConnection.address);
          if (rotater == null) {
            ChannelRotater newRotater =
                new ChannelRotater(waitingConnection.taskId);
            rotater = addressChannelMap.putIfAbsent(
                waitingConnection.address, newRotater);
            if (rotater == null) {
              rotater = newRotater;
            }
          }
          rotater.addChannel(future.channel());
          ++connected;
        }
      }
      LOG.info("connectAllAddresses: Successfully added " +
          (waitingConnectionList.size() - nextCheckFutures.size()) +
          " connections, (" + connected + " total connected) " +
          nextCheckFutures.size() + " failed, " +
          failures + " failures total.");
      if (nextCheckFutures.isEmpty()) {
        break;
      }
      waitingConnectionList = nextCheckFutures;
    }
    if (failures >= maxConnectionFailures) {
      throw new IllegalStateException(
          "connectAllAddresses: Too many failures (" + failures + ").");
    }
  }
/*if_not[HADOOP_NON_SECURE]*/
  /**
   * Authenticate all servers in addressChannelMap.
   */
  public void authenticate() {
    LOG.info("authenticate: NettyClient starting authentication with " +
        "servers.");
    for (InetSocketAddress address : addressChannelMap.keySet()) {
      if (LOG.isDebugEnabled()) {
        LOG.debug("authenticate: Authenticating with address:" + address);
      }
      ChannelRotater channelRotater = addressChannelMap.get(address);
      for (Channel channel : channelRotater.getChannels()) {
        if (LOG.isDebugEnabled()) {
          LOG.debug("authenticate: Authenticating with server on channel: " +
              channel);
        }
        authenticateOnChannel(channelRotater.getTaskId(), channel);
      }
    }
    if (LOG.isInfoEnabled()) {
      LOG.info("authenticate: NettyClient successfully authenticated with " +
          addressChannelMap.size() + " server" +
          ((addressChannelMap.size() != 1) ? "s" : "") +
          " - continuing with normal work.");
    }
  }

  /**
   * Authenticate with server connected at given channel.
   *
   * @param taskId Task id of the channel
   * @param channel Connection to server to authenticate with.
   */
  private void authenticateOnChannel(Integer taskId, Channel channel) {
    try {
      SaslNettyClient saslNettyClient = channel.attr(SASL).get();
      if (channel.attr(SASL).get() == null) {
        if (LOG.isDebugEnabled()) {
          LOG.debug("authenticateOnChannel: Creating saslNettyClient now " +
              "for channel: " + channel);
        }
        saslNettyClient = new SaslNettyClient();
        channel.attr(SASL).set(saslNettyClient);
      }
      if (!saslNettyClient.isComplete()) {
        if (LOG.isDebugEnabled()) {
          LOG.debug("authenticateOnChannel: Waiting for authentication " +
              "to complete...");
        }
        SaslTokenMessageRequest saslTokenMessage =
            saslNettyClient.firstToken();
        sendWritableRequest(taskId, saslTokenMessage);
        // We now wait for Netty's thread pool to communicate over this
        // channel to authenticate with another worker acting as a server.
        try {
          synchronized (saslNettyClient.getAuthenticated()) {
            while (!saslNettyClient.isComplete()) {
              saslNettyClient.getAuthenticated().wait();
            }
          }
        } catch (InterruptedException e) {
          LOG.error("authenticateOnChannel: Interrupted while waiting for " +
              "authentication.");
        }
      }
      if (LOG.isDebugEnabled()) {
        LOG.debug("authenticateOnChannel: Authentication on channel: " +
            channel + " has completed successfully.");
      }
    } catch (IOException e) {
      LOG.error("authenticateOnChannel: Failed to authenticate with server " +
          "due to error: " + e);
    }
  }
/*end[HADOOP_NON_SECURE]*/
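  // Example lifecycle (illustrative only; names such as "workerTasks",
  // "destTaskId", and "request" are hypothetical, and in Giraph these calls
  // are normally driven by the surrounding worker code):
  //
  //   NettyClient client =
  //       new NettyClient(context, conf, myTaskInfo, exceptionHandler);
  //   client.connectAllAddresses(workerTasks); // open channels to peers
  //   client.authenticate();                   // SASL handshake, if enabled
  //   client.sendWritableRequest(destTaskId, request);
  //   client.waitAllRequests();                // block until acks arrive
  //   client.stop();                           // close channels, free pools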
  /**
   * Stop the client.
   */
  public void stop() {
    if (LOG.isInfoEnabled()) {
      LOG.info("stop: Halting netty client");
    }
    // Close connections asynchronously, in a Netty-approved
    // way, without cleaning up thread pools until all channels
    // in addressChannelMap are closed (success or failure)
    int channelCount = 0;
    for (ChannelRotater channelRotater : addressChannelMap.values()) {
      channelCount += channelRotater.size();
    }
    final int done = channelCount;
    final AtomicInteger count = new AtomicInteger(0);
    for (ChannelRotater channelRotater : addressChannelMap.values()) {
      channelRotater.closeChannels(new ChannelFutureListener() {
        @Override
        public void operationComplete(ChannelFuture cf) {
          context.progress();
          if (count.incrementAndGet() == done) {
            if (LOG.isInfoEnabled()) {
              LOG.info("stop: reached wait threshold, " +
                  done + " connections closed, releasing " +
                  "resources now.");
            }
            workerGroup.shutdownGracefully();
            if (executionGroup != null) {
              executionGroup.shutdownGracefully();
            }
          }
        }
      });
    }
    ProgressableUtils.awaitTerminationFuture(workerGroup, context);
    if (executionGroup != null) {
      ProgressableUtils.awaitTerminationFuture(executionGroup, context);
    }
    if (LOG.isInfoEnabled()) {
      LOG.info("stop: Netty client halted");
    }
  }

  /**
   * Get the next available channel, reconnecting if necessary
   *
   * @param remoteServer Remote server to get a channel for
   * @return Available channel for this remote server
   */
  private Channel getNextChannel(InetSocketAddress remoteServer) {
    Channel channel = addressChannelMap.get(remoteServer).nextChannel();
    if (channel == null) {
      throw new IllegalStateException(
          "getNextChannel: No channel exists for " + remoteServer);
    }

    // Return this channel if it is connected
    if (channel.isActive()) {
      return channel;
    }

    // Get rid of the failed channel
    if (addressChannelMap.get(remoteServer).removeChannel(channel)) {
      LOG.warn("getNextChannel: Unlikely event that the channel " +
          channel + " was already removed!");
    }
    if (LOG.isInfoEnabled()) {
      LOG.info("getNextChannel: Fixing disconnected channel to " +
          remoteServer + ", open = " + channel.isOpen() + ", " +
          "bound = " + channel.isRegistered());
    }
    int reconnectFailures = 0;
    while (reconnectFailures < maxConnectionFailures) {
      ChannelFuture connectionFuture = bootstrap.connect(remoteServer);
      ProgressableUtils.awaitChannelFuture(connectionFuture, context);
      if (connectionFuture.isSuccess()) {
        if (LOG.isInfoEnabled()) {
          LOG.info("getNextChannel: Connected to " + remoteServer + "!");
        }
        addressChannelMap.get(remoteServer).addChannel(
            connectionFuture.channel());
        return connectionFuture.channel();
      }
      ++reconnectFailures;
      LOG.warn("getNextChannel: Failed to reconnect to " + remoteServer +
          " on attempt " + reconnectFailures + " out of " +
          maxConnectionFailures + " max attempts, sleeping for 5 secs",
          connectionFuture.cause());
      try {
        Thread.sleep(5000);
      } catch (InterruptedException e) {
        LOG.warn("getNextChannel: Unexpected interrupted exception", e);
      }
    }
    throw new IllegalStateException("getNextChannel: Failed to connect " +
        "to " + remoteServer + " in " + reconnectFailures +
        " connect attempts");
  }
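  // ChannelRotater is defined elsewhere in this package; from its use here
  // it hands out a server's channels in rotation so requests are spread over
  // the channelsPerServer connections. A minimal sketch of that idea (an
  // assumption for illustration, not the actual Giraph implementation):
  //
  //   class RoundRobinChannels {
  //     private final java.util.List<Channel> channels =
  //         new java.util.ArrayList<>();
  //     private int next = 0;
  //     synchronized void add(Channel c) { channels.add(c); }
  //     synchronized Channel next() {
  //       if (channels.isEmpty()) { return null; }
  //       next = (next + 1) % channels.size();
  //       return channels.get(next);
  //     }
  //   }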
  /**
   * Send a request to a remote server (should be already connected)
   *
   * @param destTaskId Destination task id
   * @param request Request to send
   */
  public void sendWritableRequest(Integer destTaskId,
      WritableRequest request) {
    InetSocketAddress remoteServer = taskIdAddressMap.get(destTaskId);
    if (clientRequestIdRequestInfoMap.isEmpty()) {
      inboundByteCounter.resetAll();
      outboundByteCounter.resetAll();
    }
    boolean registerRequest = true;
/*if_not[HADOOP_NON_SECURE]*/
    if (request.getType() == RequestType.SASL_TOKEN_MESSAGE_REQUEST) {
      registerRequest = false;
    }
/*end[HADOOP_NON_SECURE]*/

    Channel channel = getNextChannel(remoteServer);
    RequestInfo newRequestInfo = new RequestInfo(remoteServer, request);
    if (registerRequest) {
      request.setClientId(myTaskInfo.getTaskId());
      request.setRequestId(
          addressRequestIdGenerator.getNextRequestId(remoteServer));
      ClientRequestId clientRequestId =
          new ClientRequestId(destTaskId, request.getRequestId());
      RequestInfo oldRequestInfo = clientRequestIdRequestInfoMap.putIfAbsent(
          clientRequestId, newRequestInfo);
      if (oldRequestInfo != null) {
        throw new IllegalStateException("sendWritableRequest: Impossible " +
            "to have a previous request id = " + request.getRequestId() +
            ", request info of " + oldRequestInfo);
      }
    }

    ChannelFuture writeFuture = channel.write(request);
    newRequestInfo.setWriteFuture(writeFuture);
    writeFuture.addListener(logErrorListener);

    if (limitNumberOfOpenRequests &&
        clientRequestIdRequestInfoMap.size() > maxNumberOfOpenRequests) {
      waitSomeRequests(maxNumberOfOpenRequests);
    }
  }

  /**
   * Ensure all the requests sent so far are complete.
   */
  public void waitAllRequests() {
    waitSomeRequests(0);
    if (LOG.isInfoEnabled()) {
      LOG.info("waitAllRequests: Finished all requests. " +
          inboundByteCounter.getMetrics() + "\n" +
          outboundByteCounter.getMetrics());
    }
  }

  /**
   * Ensure that at most maxOpenRequests are not complete. Periodically,
   * check the state of every request. If we find the connection failed,
   * re-establish it and re-send the request.
   *
   * @param maxOpenRequests Maximum number of requests which can be not
   *        complete
   */
  private void waitSomeRequests(int maxOpenRequests) {
    while (clientRequestIdRequestInfoMap.size() > maxOpenRequests) {
      // Wait for requests to complete for some time
      logInfoAboutOpenRequests(maxOpenRequests);
      synchronized (clientRequestIdRequestInfoMap) {
        if (clientRequestIdRequestInfoMap.size() <= maxOpenRequests) {
          break;
        }
        try {
          clientRequestIdRequestInfoMap.wait(waitingRequestMsecs);
        } catch (InterruptedException e) {
          LOG.error("waitSomeRequests: Got unexpected InterruptedException",
              e);
        }
      }
      // Make sure that waiting doesn't kill the job
      context.progress();
      checkRequestsForProblems();
    }
  }
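  // Note (added for this walkthrough): waitSomeRequests() blocks with a
  // timed wait() on clientRequestIdRequestInfoMap. The matching wake-up
  // presumably comes from the response path — in this code base,
  // ResponseClientHandler removes the completed entry and notifies waiters.
  // A minimal sketch of that pairing, with a hypothetical name:
  //
  //   // on the response thread, after removing the completed request:
  //   synchronized (requestMap) {
  //     requestMap.notifyAll();
  //   }
  //
  // The timed wait also doubles as the polling interval for
  // checkRequestsForProblems(), so progress is still made if a notification
  // is lost along with its channel.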
  /**
   * Log the status of open requests.
   *
   * @param maxOpenRequests Maximum number of requests which can be not
   *        complete
   */
  private void logInfoAboutOpenRequests(int maxOpenRequests) {
    if (LOG.isInfoEnabled() && requestLogger.isPrintable()) {
      LOG.info("logInfoAboutOpenRequests: Waiting interval of " +
          waitingRequestMsecs + " msecs, " +
          clientRequestIdRequestInfoMap.size() +
          " open requests, waiting for it to be <= " + maxOpenRequests +
          ", " + inboundByteCounter.getMetrics() + "\n" +
          outboundByteCounter.getMetrics());

      if (clientRequestIdRequestInfoMap.size() < MAX_REQUESTS_TO_LIST) {
        for (Map.Entry<ClientRequestId, RequestInfo> entry :
            clientRequestIdRequestInfoMap.entrySet()) {
          LOG.info("logInfoAboutOpenRequests: Waiting for request " +
              entry.getKey() + " - " + entry.getValue());
        }
      }

      // Count how many open requests each task has
      Map<Integer, Integer> openRequestCounts = Maps.newHashMap();
      for (ClientRequestId clientRequestId :
          clientRequestIdRequestInfoMap.keySet()) {
        int taskId = clientRequestId.getDestinationTaskId();
        Integer currentCount = openRequestCounts.get(taskId);
        openRequestCounts.put(taskId,
            (currentCount == null ? 0 : currentCount) + 1);
      }
      // Sort it in decreasing order of number of open requests
      List<Map.Entry<Integer, Integer>> sorted =
          Lists.newArrayList(openRequestCounts.entrySet());
      Collections.sort(sorted,
          new Comparator<Map.Entry<Integer, Integer>>() {
            @Override
            public int compare(Map.Entry<Integer, Integer> entry1,
                Map.Entry<Integer, Integer> entry2) {
              int value1 = entry1.getValue();
              int value2 = entry2.getValue();
              return (value1 < value2) ? 1 : ((value1 == value2) ? 0 : -1);
            }
          });
      // Print task ids which have the most open requests
      StringBuilder message = new StringBuilder();
      message.append("logInfoAboutOpenRequests: ");
      int itemsToPrint =
          Math.min(MAX_DESTINATION_TASK_IDS_TO_LIST, sorted.size());
      for (int i = 0; i < itemsToPrint; i++) {
        message.append(sorted.get(i).getValue())
            .append(" requests for taskId=")
            .append(sorted.get(i).getKey())
            .append(", ");
      }
      LOG.info(message);
    }
  }

  /**
   * Check if there are some open requests which have been sent a long time
   * ago, and if so resend them.
   */
  private void checkRequestsForProblems() {
    long lastTimeChecked = lastTimeCheckedRequestsForProblems.get();
    // If not enough time passed from the previous check, return
    if (System.currentTimeMillis() < lastTimeChecked + waitingRequestMsecs) {
      return;
    }
    // If another thread did the check already, return
    if (!lastTimeCheckedRequestsForProblems.compareAndSet(lastTimeChecked,
        System.currentTimeMillis())) {
      return;
    }
    List<ClientRequestId> addedRequestIds = Lists.newArrayList();
    List<RequestInfo> addedRequestInfos = Lists.newArrayList();
    // Check all the requests for problems
    for (Map.Entry<ClientRequestId, RequestInfo> entry :
        clientRequestIdRequestInfoMap.entrySet()) {
      RequestInfo requestInfo = entry.getValue();
      ChannelFuture writeFuture = requestInfo.getWriteFuture();
      // Request wasn't sent yet
      if (writeFuture == null) {
        continue;
      }
      // If not connected anymore, request failed, or the request is taking
      // too long, re-establish and resend
      if (!writeFuture.channel().isActive() ||
          (writeFuture.isDone() && !writeFuture.isSuccess()) ||
          (requestInfo.getElapsedMsecs() > maxRequestMilliseconds)) {
        LOG.warn("checkRequestsForProblems: Problem with request id " +
            entry.getKey() + " connected = " +
            writeFuture.channel().isActive() +
            ", future done = " + writeFuture.isDone() + ", " +
            "success = " + writeFuture.isSuccess() + ", " +
            "cause = " + writeFuture.cause() + ", " +
            "elapsed time = " + requestInfo.getElapsedMsecs() + ", " +
            "destination = " + writeFuture.channel().remoteAddress() +
            " " + requestInfo);
        addedRequestIds.add(entry.getKey());
        addedRequestInfos.add(new RequestInfo(
            requestInfo.getDestinationAddress(), requestInfo.getRequest()));
      }
    }

    // Add any new requests to the system, connect if necessary, and re-send
    for (int i = 0; i < addedRequestIds.size(); ++i) {
      ClientRequestId requestId = addedRequestIds.get(i);
      RequestInfo requestInfo = addedRequestInfos.get(i);

      if (clientRequestIdRequestInfoMap.put(requestId, requestInfo) ==
          null) {
        LOG.warn("checkRequestsForProblems: Request " + requestId +
            " completed prior to sending the next request");
        clientRequestIdRequestInfoMap.remove(requestId);
      }
      InetSocketAddress remoteServer = requestInfo.getDestinationAddress();
      Channel channel = getNextChannel(remoteServer);
      if (LOG.isInfoEnabled()) {
        LOG.info("checkRequestsForProblems: Re-issuing request " +
            requestInfo);
      }
      ChannelFuture writeFuture = channel.write(requestInfo.getRequest());
      requestInfo.setWriteFuture(writeFuture);
      writeFuture.addListener(logErrorListener);
    }
    addedRequestIds.clear();
    addedRequestInfos.clear();
  }
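  // Note (added for this walkthrough): the two early returns at the top of
  // checkRequestsForProblems() make the periodic scan both rate-limited and
  // effectively single-threaded without holding a lock. The idiom in
  // isolation (a sketch with hypothetical names):
  //
  //   long last = lastCheck.get();
  //   long now = System.currentTimeMillis();
  //   if (now >= last + intervalMs &&
  //       lastCheck.compareAndSet(last, now)) {
  //     // at most one thread per interval gets here
  //   }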
  /**
   * Utility method for resolving addresses
   *
   * @param maxResolveAddressAttempts Maximum number of attempts to resolve
   *        the address
   * @param address The address we are attempting to resolve
   * @return The successfully resolved address.
   * @throws IllegalStateException if the address is not resolved
   *         in <code>maxResolveAddressAttempts</code> tries.
   */
  private static InetSocketAddress resolveAddress(
      int maxResolveAddressAttempts, InetSocketAddress address) {
    int resolveAttempts = 0;
    while (address.isUnresolved() &&
        resolveAttempts < maxResolveAddressAttempts) {
      ++resolveAttempts;
      LOG.warn("resolveAddress: Failed to resolve " + address +
          " on attempt " + resolveAttempts + " of " +
          maxResolveAddressAttempts + " attempts, sleeping for 5 seconds");
      try {
        Thread.sleep(5000);
      } catch (InterruptedException e) {
        LOG.warn("resolveAddress: Interrupted.", e);
      }
      address = new InetSocketAddress(address.getHostName(),
          address.getPort());
    }
    if (address.isUnresolved()) {
      throw new IllegalStateException("resolveAddress: Couldn't " +
          "resolve " + address + " in " + resolveAttempts + " tries.");
    }
    return address;
  }

  /**
   * This listener class just dumps exception stack traces if
   * something happens.
   */
  private static class LogOnErrorChannelFutureListener
      implements ChannelFutureListener {
    @Override
    public void operationComplete(ChannelFuture future) throws Exception {
      if (future.isDone() && !future.isSuccess()) {
        LOG.error("Request failed", future.cause());
      }
    }
  }
}