org.apache.spark.network.netty.NettyTransportClientFactory.java Source code

Introduction

Here is the source code for org.apache.spark.network.netty.NettyTransportClientFactory.java, a factory from Spark's network layer that creates, pools, and reuses Netty-based TransportClients.

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.network.netty;

import java.io.Closeable;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicReference;

import com.google.common.base.Preconditions;
import com.google.common.base.Throwables;
import com.google.common.collect.Lists;
import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.PooledByteBufAllocator;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportClientBootstrap;
import org.apache.spark.network.client.TransportClientFactory;
import org.apache.spark.network.server.TransportChannelHandler;
import org.apache.spark.network.util.IOMode;
import org.apache.spark.network.util.JavaUtils;
import org.apache.spark.network.util.NettyUtils;
import org.apache.spark.network.util.TransportConf;

/**
 * Factory for creating {@link NettyTransportClient}s via {@link #createClient(String, int)}.
 *
 * The factory maintains a connection pool to other hosts and should return the same
 * TransportClient for the same remote host. It also shares a single worker thread pool for
 * all TransportClients.
 *
 * TransportClients will be reused whenever possible. Prior to completing the creation of a new
 * TransportClient, all given {@link TransportClientBootstrap}s will be run.
 */
public class NettyTransportClientFactory implements TransportClientFactory {

    /** A simple data structure to track the pool of clients between two peer nodes. */
    private static class ClientPool {
        NettyTransportClient[] clients;
        Object[] locks;

        public ClientPool(int size) {
            clients = new NettyTransportClient[size];
            locks = new Object[size];
            for (int i = 0; i < size; i++) {
                locks[i] = new Object();
            }
        }
    }

    private final Logger logger = LoggerFactory.getLogger(NettyTransportClientFactory.class);

    private final NettyTransportContext context;
    private final TransportConf conf;
    private final List<TransportClientBootstrap> clientBootstraps;
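    /** Pool of clients, keyed by remote address; each pool holds up to numConnectionsPerPeer clients. */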
    private final ConcurrentHashMap<SocketAddress, ClientPool> connectionPool;

    /** Random number generator for picking connections between peers. */
    private final Random rand;
    private final int numConnectionsPerPeer;

    private final Class<? extends Channel> socketChannelClass;
    private EventLoopGroup workerGroup;
    private PooledByteBufAllocator pooledAllocator;

    public NettyTransportClientFactory(NettyTransportContext context,
            List<TransportClientBootstrap> clientBootstraps) {
        this.context = Preconditions.checkNotNull(context);
        this.conf = context.getConf();
        this.clientBootstraps = Lists.newArrayList(Preconditions.checkNotNull(clientBootstraps));
        this.connectionPool = new ConcurrentHashMap<SocketAddress, ClientPool>();
        this.numConnectionsPerPeer = conf.numConnectionsPerPeer();
        this.rand = new Random();

        IOMode ioMode = IOMode.valueOf(conf.ioMode());
        this.socketChannelClass = NettyUtils.getClientChannelClass(ioMode);
        // TODO: Make thread pool name configurable.
        this.workerGroup = NettyUtils.createEventLoop(ioMode, conf.clientThreads(), "shuffle-client");
        this.pooledAllocator = NettyUtils.createPooledByteBufAllocator(conf.preferDirectBufs(),
                false /* allowCache */, conf.clientThreads());
    }

    /**
     * Create a {@link NettyTransportClient} connecting to the given remote host / port.
     *
     * We maintain an array of clients (size determined by spark.shuffle.io.numConnectionsPerPeer)
     * and randomly picks one to use. If no client was previously created in the randomly selected
     * spot, this function creates a new client and places it there.
     *
     * Prior to the creation of a new TransportClient, we will execute all
     * {@link TransportClientBootstrap}s that are registered with this factory.
     *
     * This blocks until a connection is successfully established and fully bootstrapped.
     *
     * Concurrency: This method is safe to call from multiple threads.
     */
    public TransportClient createClient(String remoteHost, int remotePort) throws IOException {
        // Get connection from the connection pool first.
        // If it is not found or not active, create a new one.
        final InetSocketAddress address = new InetSocketAddress(remoteHost, remotePort);

        // Create the ClientPool if we don't have it yet.
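        // putIfAbsent guarantees that concurrent callers converge on a single pool instance;
        // the subsequent get() picks up whichever instance won the race.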
        ClientPool clientPool = connectionPool.get(address);
        if (clientPool == null) {
            connectionPool.putIfAbsent(address, new ClientPool(numConnectionsPerPeer));
            clientPool = connectionPool.get(address);
        }

        int clientIndex = rand.nextInt(numConnectionsPerPeer);
        NettyTransportClient cachedClient = clientPool.clients[clientIndex];

        if (cachedClient != null && cachedClient.isActive()) {
            // Make sure that the channel will not timeout by updating the last use time of the
            // handler. Then check that the client is still alive, in case it timed out before
            // this code was able to update things.
            TransportChannelHandler handler = cachedClient.getChannel().pipeline()
                    .get(TransportChannelHandler.class);
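            // Synchronizing on the handler is assumed to pair with the idle-timeout check in
            // TransportChannelHandler (as in stock Spark), so a timeout-driven close cannot
            // race with this liveness update.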
            synchronized (handler) {
                handler.getResponseHandler().updateTimeOfLastRequest();
            }

            if (cachedClient.isActive()) {
                logger.trace("Returning cached connection to {}: {}", address, cachedClient);
                return cachedClient;
            }
        }

        // If we reach here, we don't have an existing connection open. Let's create a new one.
        // Multiple threads might race here to create new connections. Keep only one of them active.
        synchronized (clientPool.locks[clientIndex]) {
            cachedClient = clientPool.clients[clientIndex];

            if (cachedClient != null) {
                if (cachedClient.isActive()) {
                    logger.trace("Returning cached connection to {}: {}", address, cachedClient);
                    return cachedClient;
                } else {
                    logger.info("Found inactive connection to {}, creating a new one.", address);
                }
            }
            clientPool.clients[clientIndex] = createClient(address);
            return clientPool.clients[clientIndex];
        }
    }

    /**
     * Create a completely new {@link TransportClient} to the given remote host / port.
     * This connection is not pooled.
     *
     * As with {@link #createClient(String, int)}, this method is blocking.
     */
    public NettyTransportClient createUnmanagedClient(String remoteHost, int remotePort) throws IOException {
        final InetSocketAddress address = new InetSocketAddress(remoteHost, remotePort);
        return createClient(address);
    }

    /** Create a completely new {@link TransportClient} to the remote address. */
    private NettyTransportClient createClient(InetSocketAddress address) throws IOException {
        logger.debug("Creating new connection to " + address);

        Bootstrap bootstrap = new Bootstrap();
        bootstrap.group(workerGroup).channel(socketChannelClass)
                // Disable Nagle's Algorithm since we don't want packets to wait
                .option(ChannelOption.TCP_NODELAY, true)
                .option(ChannelOption.SO_KEEPALIVE, true)
                .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, conf.connectionTimeoutMs())
                .option(ChannelOption.ALLOCATOR, pooledAllocator);

        final AtomicReference<NettyTransportClient> clientRef = new AtomicReference<NettyTransportClient>();
        final AtomicReference<Channel> channelRef = new AtomicReference<Channel>();
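        // initChannel runs on the Netty event loop during channel registration, so both
        // references are populated before the connect future below can complete successfully.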

        bootstrap.handler(new ChannelInitializer<SocketChannel>() {
            @Override
            public void initChannel(SocketChannel ch) {
                TransportChannelHandler clientHandler = context.initializePipeline(ch);
                clientRef.set(clientHandler.getClient());
                channelRef.set(ch);
            }
        });

        // Connect to the remote server
        long preConnect = System.nanoTime();
        ChannelFuture cf = bootstrap.connect(address);
        if (!cf.awaitUninterruptibly(conf.connectionTimeoutMs())) {
            throw new IOException(
                    String.format("Connecting to %s timed out (%s ms)", address, conf.connectionTimeoutMs()));
        } else if (cf.cause() != null) {
            throw new IOException(String.format("Failed to connect to %s", address), cf.cause());
        }

        NettyTransportClient client = clientRef.get();
        Channel channel = channelRef.get();
        assert client != null : "Channel future completed successfully with null client";

        // Execute any client bootstraps synchronously before marking the Client as successful.
        long preBootstrap = System.nanoTime();
        logger.debug("Connection to {} successful, running bootstraps...", address);
        try {
            for (TransportClientBootstrap clientBootstrap : clientBootstraps) {
                clientBootstrap.doBootstrap(client, channel);
            }
        } catch (Exception e) { // catch non-RuntimeExceptions too as bootstrap may be written in Scala
            long bootstrapTimeMs = (System.nanoTime() - preBootstrap) / 1000000;
            logger.error("Exception while bootstrapping client after " + bootstrapTimeMs + " ms", e);
            client.close();
            throw Throwables.propagate(e);
        }
        long postBootstrap = System.nanoTime();

        logger.debug("Successfully created connection to {} after {} ms ({} ms spent in bootstraps)", address,
                (postBootstrap - preConnect) / 1000000, (postBootstrap - preBootstrap) / 1000000);

        return client;
    }

    /** Close all connections in the connection pool, and shutdown the worker thread pool. */
    @Override
    public void close() {
        // Go through all clients and close them if they are active.
        for (ClientPool clientPool : connectionPool.values()) {
            for (int i = 0; i < clientPool.clients.length; i++) {
                NettyTransportClient client = clientPool.clients[i];
                if (client != null) {
                    clientPool.clients[i] = null;
                    JavaUtils.closeQuietly(client);
                }
            }
        }
        connectionPool.clear();

        if (workerGroup != null) {
            workerGroup.shutdownGracefully();
            workerGroup = null;
        }
    }
}
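
Example

The sketch below shows one way the factory might be used; it is not part of the original file. The obtainContext() helper is a hypothetical placeholder (in stock Spark, a TransportContext is built from a TransportConf and an RpcHandler, and the Netty-prefixed variant here is assumed to mirror that), and the host name and port are illustrative.

import java.util.Collections;

import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportClientBootstrap;
import org.apache.spark.network.netty.NettyTransportClientFactory;
import org.apache.spark.network.netty.NettyTransportContext;

public class NettyTransportClientFactoryExample {
    public static void main(String[] args) throws Exception {
        // Assumption: obtainContext() stands in for whatever application code builds the
        // transport context; it is not part of the file above.
        NettyTransportContext context = obtainContext();

        NettyTransportClientFactory factory = new NettyTransportClientFactory(
                context, Collections.<TransportClientBootstrap>emptyList());
        try {
            // Blocks until the connection is established and all bootstraps have run.
            // A second call with the same host/port would normally reuse a pooled client.
            TransportClient client = factory.createClient("remote-host.example.com", 7337);
            System.out.println("Connected to " + client);
        } finally {
            // Closes every pooled connection and shuts down the shared worker event loop.
            factory.close();
        }
    }

    // Hypothetical helper: constructing a NettyTransportContext is application-specific
    // (in stock Spark, a TransportContext is built from a TransportConf and an RpcHandler).
    private static NettyTransportContext obtainContext() {
        throw new UnsupportedOperationException("application-specific setup");
    }
}

When a caller wants a dedicated connection that bypasses the pool, createUnmanagedClient can be used instead; the caller is then responsible for closing that client itself.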