Java tutorial
/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.spark.network.client; import java.io.Closeable; import java.io.IOException; import java.net.InetSocketAddress; import java.net.SocketAddress; import java.util.List; import java.util.Random; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicReference; import com.codahale.metrics.MetricSet; import com.google.common.base.Preconditions; import com.google.common.base.Throwables; import com.google.common.collect.Lists; import io.netty.bootstrap.Bootstrap; import io.netty.buffer.PooledByteBufAllocator; import io.netty.channel.Channel; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelOption; import io.netty.channel.EventLoopGroup; import io.netty.channel.socket.SocketChannel; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.spark.network.TransportContext; import org.apache.spark.network.server.TransportChannelHandler; import org.apache.spark.network.util.*; /** * Factory for creating {@link TransportClient}s by using createClient. * * The factory maintains a connection pool to other hosts and should return the same * TransportClient for the same remote host. It also shares a single worker thread pool for * all TransportClients. * * TransportClients will be reused whenever possible. Prior to completing the creation of a new * TransportClient, all given {@link TransportClientBootstrap}s will be run. */ public class TransportClientFactory implements Closeable { /** A simple data structure to track the pool of clients between two peer nodes. */ private static class ClientPool { TransportClient[] clients; Object[] locks; ClientPool(int size) { clients = new TransportClient[size]; locks = new Object[size]; for (int i = 0; i < size; i++) { locks[i] = new Object(); } } } private static final Logger logger = LoggerFactory.getLogger(TransportClientFactory.class); private final TransportContext context; private final TransportConf conf; private final List<TransportClientBootstrap> clientBootstraps; private final ConcurrentHashMap<SocketAddress, ClientPool> connectionPool; /** Random number generator for picking connections between peers. */ private final Random rand; private final int numConnectionsPerPeer; private final Class<? extends Channel> socketChannelClass; private EventLoopGroup workerGroup; private PooledByteBufAllocator pooledAllocator; private final NettyMemoryMetrics metrics; public TransportClientFactory(TransportContext context, List<TransportClientBootstrap> clientBootstraps) { this.context = Preconditions.checkNotNull(context); this.conf = context.getConf(); this.clientBootstraps = Lists.newArrayList(Preconditions.checkNotNull(clientBootstraps)); this.connectionPool = new ConcurrentHashMap<>(); this.numConnectionsPerPeer = conf.numConnectionsPerPeer(); this.rand = new Random(); IOMode ioMode = IOMode.valueOf(conf.ioMode()); this.socketChannelClass = NettyUtils.getClientChannelClass(ioMode); this.workerGroup = NettyUtils.createEventLoop(ioMode, conf.clientThreads(), conf.getModuleName() + "-client"); this.pooledAllocator = NettyUtils.createPooledByteBufAllocator(conf.preferDirectBufs(), false /* allowCache */, conf.clientThreads()); this.metrics = new NettyMemoryMetrics(this.pooledAllocator, conf.getModuleName() + "-client", conf); } public MetricSet getAllMetrics() { return metrics; } /** * Create a {@link TransportClient} connecting to the given remote host / port. * * We maintains an array of clients (size determined by spark.shuffle.io.numConnectionsPerPeer) * and randomly picks one to use. If no client was previously created in the randomly selected * spot, this function creates a new client and places it there. * * Prior to the creation of a new TransportClient, we will execute all * {@link TransportClientBootstrap}s that are registered with this factory. * * This blocks until a connection is successfully established and fully bootstrapped. * * Concurrency: This method is safe to call from multiple threads. */ public TransportClient createClient(String remoteHost, int remotePort) throws IOException, InterruptedException { // Get connection from the connection pool first. // If it is not found or not active, create a new one. // Use unresolved address here to avoid DNS resolution each time we creates a client. final InetSocketAddress unresolvedAddress = InetSocketAddress.createUnresolved(remoteHost, remotePort); // Create the ClientPool if we don't have it yet. ClientPool clientPool = connectionPool.get(unresolvedAddress); if (clientPool == null) { connectionPool.putIfAbsent(unresolvedAddress, new ClientPool(numConnectionsPerPeer)); clientPool = connectionPool.get(unresolvedAddress); } int clientIndex = rand.nextInt(numConnectionsPerPeer); TransportClient cachedClient = clientPool.clients[clientIndex]; if (cachedClient != null && cachedClient.isActive()) { // Make sure that the channel will not timeout by updating the last use time of the // handler. Then check that the client is still alive, in case it timed out before // this code was able to update things. TransportChannelHandler handler = cachedClient.getChannel().pipeline() .get(TransportChannelHandler.class); synchronized (handler) { handler.getResponseHandler().updateTimeOfLastRequest(); } if (cachedClient.isActive()) { logger.trace("Returning cached connection to {}: {}", cachedClient.getSocketAddress(), cachedClient); return cachedClient; } } // If we reach here, we don't have an existing connection open. Let's create a new one. // Multiple threads might race here to create new connections. Keep only one of them active. final long preResolveHost = System.nanoTime(); final InetSocketAddress resolvedAddress = new InetSocketAddress(remoteHost, remotePort); final long hostResolveTimeMs = (System.nanoTime() - preResolveHost) / 1000000; if (hostResolveTimeMs > 2000) { logger.warn("DNS resolution for {} took {} ms", resolvedAddress, hostResolveTimeMs); } else { logger.trace("DNS resolution for {} took {} ms", resolvedAddress, hostResolveTimeMs); } synchronized (clientPool.locks[clientIndex]) { cachedClient = clientPool.clients[clientIndex]; if (cachedClient != null) { if (cachedClient.isActive()) { logger.trace("Returning cached connection to {}: {}", resolvedAddress, cachedClient); return cachedClient; } else { logger.info("Found inactive connection to {}, creating a new one.", resolvedAddress); } } clientPool.clients[clientIndex] = createClient(resolvedAddress); return clientPool.clients[clientIndex]; } } /** * Create a completely new {@link TransportClient} to the given remote host / port. * This connection is not pooled. * * As with {@link #createClient(String, int)}, this method is blocking. */ public TransportClient createUnmanagedClient(String remoteHost, int remotePort) throws IOException, InterruptedException { final InetSocketAddress address = new InetSocketAddress(remoteHost, remotePort); return createClient(address); } /** Create a completely new {@link TransportClient} to the remote address. */ private TransportClient createClient(InetSocketAddress address) throws IOException, InterruptedException { logger.debug("Creating new connection to {}", address); Bootstrap bootstrap = new Bootstrap(); bootstrap.group(workerGroup).channel(socketChannelClass) // Disable Nagle's Algorithm since we don't want packets to wait .option(ChannelOption.TCP_NODELAY, true).option(ChannelOption.SO_KEEPALIVE, true) .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, conf.connectionTimeoutMs()) .option(ChannelOption.ALLOCATOR, pooledAllocator); if (conf.receiveBuf() > 0) { bootstrap.option(ChannelOption.SO_RCVBUF, conf.receiveBuf()); } if (conf.sendBuf() > 0) { bootstrap.option(ChannelOption.SO_SNDBUF, conf.sendBuf()); } final AtomicReference<TransportClient> clientRef = new AtomicReference<>(); final AtomicReference<Channel> channelRef = new AtomicReference<>(); bootstrap.handler(new ChannelInitializer<SocketChannel>() { @Override public void initChannel(SocketChannel ch) { TransportChannelHandler clientHandler = context.initializePipeline(ch); clientRef.set(clientHandler.getClient()); channelRef.set(ch); } }); // Connect to the remote server long preConnect = System.nanoTime(); ChannelFuture cf = bootstrap.connect(address); if (!cf.await(conf.connectionTimeoutMs())) { throw new IOException( String.format("Connecting to %s timed out (%s ms)", address, conf.connectionTimeoutMs())); } else if (cf.cause() != null) { throw new IOException(String.format("Failed to connect to %s", address), cf.cause()); } TransportClient client = clientRef.get(); Channel channel = channelRef.get(); assert client != null : "Channel future completed successfully with null client"; // Execute any client bootstraps synchronously before marking the Client as successful. long preBootstrap = System.nanoTime(); logger.debug("Connection to {} successful, running bootstraps...", address); try { for (TransportClientBootstrap clientBootstrap : clientBootstraps) { clientBootstrap.doBootstrap(client, channel); } } catch (Exception e) { // catch non-RuntimeExceptions too as bootstrap may be written in Scala long bootstrapTimeMs = (System.nanoTime() - preBootstrap) / 1000000; logger.error("Exception while bootstrapping client after " + bootstrapTimeMs + " ms", e); client.close(); throw Throwables.propagate(e); } long postBootstrap = System.nanoTime(); logger.info("Successfully created connection to {} after {} ms ({} ms spent in bootstraps)", address, (postBootstrap - preConnect) / 1000000, (postBootstrap - preBootstrap) / 1000000); return client; } /** Close all connections in the connection pool, and shutdown the worker thread pool. */ @Override public void close() { // Go through all clients and close them if they are active. for (ClientPool clientPool : connectionPool.values()) { for (int i = 0; i < clientPool.clients.length; i++) { TransportClient client = clientPool.clients[i]; if (client != null) { clientPool.clients[i] = null; JavaUtils.closeQuietly(client); } } } connectionPool.clear(); if (workerGroup != null) { workerGroup.shutdownGracefully(); workerGroup = null; } } }