org.apache.flink.runtime.query.netty.KvStateClient.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.flink.runtime.query.netty.KvStateClient.java

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.runtime.query.netty;

import akka.dispatch.Futures;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import io.netty.handler.stream.ChunkedWriteHandler;
import org.apache.flink.runtime.io.network.netty.NettyBufferPool;
import org.apache.flink.runtime.query.KvStateID;
import org.apache.flink.runtime.query.KvStateServerAddress;
import org.apache.flink.runtime.query.netty.message.KvStateRequestSerializer;
import org.apache.flink.util.Preconditions;
import scala.concurrent.Future;
import scala.concurrent.Promise;

import java.nio.channels.ClosedChannelException;
import java.util.ArrayDeque;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;

/**
 * Netty-based client querying {@link KvStateServer} instances.
 *
 * <p>This client can be used by multiple threads concurrently. Operations are
 * executed asynchronously and return Futures to their result.
 *
 * <p>The incoming pipeline looks as follows:
 * <pre>
 * Socket.read() -> LengthFieldBasedFrameDecoder -> KvStateClientHandler
 * </pre>
 *
 * <p>Received binary messages are expected to contain a frame length field. Netty's
 * {@link LengthFieldBasedFrameDecoder} is used to fully receive the frame before
 * giving it to our {@link KvStateClientHandler}.
 *
 * <p>Connections are established and closed by the client. The server only
 * closes the connection on a fatal failure that cannot be recovered.
 */
public class KvStateClient {

    /** Netty's Bootstrap. */
    private final Bootstrap bootstrap;

    /** Statistics tracker. */
    private final KvStateRequestStats stats;

    /** Established connections, keyed by server address. */
    private final ConcurrentHashMap<KvStateServerAddress, EstablishedConnection> establishedConnections = new ConcurrentHashMap<>();

    /**
     * Pending connections, keyed by server address. An entry is registered via
     * {@code putIfAbsent} when a connect attempt starts and is unregistered when
     * its channel is handed in or when the pending connection is closed/failed.
     */
    private final ConcurrentHashMap<KvStateServerAddress, PendingConnection> pendingConnections = new ConcurrentHashMap<>();

    /** Atomic shut down flag. */
    private final AtomicBoolean shutDown = new AtomicBoolean();

    /**
     * Creates a client with the specified number of event loop threads.
     *
     * @param numEventLoopThreads Number of event loop threads (minimum 1).
     * @param stats Statistics tracker notified about connections and requests (must not be null).
     */
    public KvStateClient(int numEventLoopThreads, KvStateRequestStats stats) {
        Preconditions.checkArgument(numEventLoopThreads >= 1, "Non-positive number of event loop threads.");
        // Validate every argument before creating the event loop group; otherwise
        // a null tracker would leave behind a group that is never shut down.
        this.stats = Preconditions.checkNotNull(stats, "Statistics tracker");

        NettyBufferPool bufferPool = new NettyBufferPool(numEventLoopThreads);

        ThreadFactory threadFactory = new ThreadFactoryBuilder().setDaemon(true)
                .setNameFormat("Flink KvStateClient Event Loop Thread %d").build();

        NioEventLoopGroup nioGroup = new NioEventLoopGroup(numEventLoopThreads, threadFactory);

        this.bootstrap = new Bootstrap().group(nioGroup).channel(NioSocketChannel.class)
                .option(ChannelOption.ALLOCATOR, bufferPool).handler(new ChannelInitializer<SocketChannel>() {
                    @Override
                    protected void initChannel(SocketChannel ch) throws Exception {
                        // 4-byte length field at offset 0, stripped before the frame is passed on.
                        ch.pipeline().addLast(new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4))
                                // ChunkedWriteHandler respects Channel writability
                                .addLast(new ChunkedWriteHandler());
                    }
                });
    }

    /**
     * Returns a future holding the serialized request result.
     *
     * <p>If the server does not serve a KvState instance with the given ID,
     * the Future will be failed with a {@link UnknownKvStateID}.
     *
     * <p>If the KvState instance does not hold any data for the given key
     * and namespace, the Future will be failed with a {@link UnknownKeyOrNamespace}.
     *
     * <p>All other failures are forwarded to the Future.
     *
     * @param serverAddress Address of the server to query
     * @param kvStateId ID of the KvState instance to query
     * @param serializedKeyAndNamespace Serialized key and namespace to query KvState instance with
     * @return Future holding the serialized result
     */
    public Future<byte[]> getKvState(KvStateServerAddress serverAddress, KvStateID kvStateId,
            byte[] serializedKeyAndNamespace) {

        if (shutDown.get()) {
            return Futures.failed(new IllegalStateException("Shut down"));
        }

        EstablishedConnection connection = establishedConnections.get(serverAddress);
        if (connection != null) {
            return connection.getKvState(kvStateId, serializedKeyAndNamespace);
        }

        PendingConnection pendingConnection = pendingConnections.get(serverAddress);
        if (pendingConnection != null) {
            // A connect attempt is already in progress; queue on it.
            return pendingConnection.getKvState(kvStateId, serializedKeyAndNamespace);
        }

        // We try to connect to the server.
        PendingConnection pending = new PendingConnection(serverAddress);
        PendingConnection previous = pendingConnections.putIfAbsent(serverAddress, pending);

        if (previous != null) {
            // There was a race, use the existing pending connection.
            return previous.getKvState(kvStateId, serializedKeyAndNamespace);
        }

        // OK, we are responsible to connect.
        try {
            bootstrap.connect(serverAddress.getHost(), serverAddress.getPort()).addListener(pending);
        } catch (Throwable t) {
            // Fail (and unregister) the pending connection instead of leaving it
            // registered without a connect attempt that could ever complete it.
            pending.close(t);
        }

        return pending.getKvState(kvStateId, serializedKeyAndNamespace);
    }

    /**
     * Shuts down the client and closes all connections.
     *
     * <p>After a call to this method, all returned futures will be failed.
     */
    public void shutDown() {
        if (shutDown.compareAndSet(false, true)) {
            for (Map.Entry<KvStateServerAddress, EstablishedConnection> conn : establishedConnections.entrySet()) {
                if (establishedConnections.remove(conn.getKey(), conn.getValue())) {
                    conn.getValue().close();
                }
            }

            for (Map.Entry<KvStateServerAddress, PendingConnection> conn : pendingConnections.entrySet()) {
                // Remove exactly the mapping we iterate over, so that we never close
                // a different instance than the one we actually removed.
                if (pendingConnections.remove(conn.getKey(), conn.getValue())) {
                    conn.getValue().close();
                }
            }

            EventLoopGroup group = bootstrap.group();
            if (group != null) {
                group.shutdownGracefully(0, 10, TimeUnit.SECONDS);
            }
        }
    }

    /**
     * Closes the connection to the given server address if it exists.
     *
     * <p>If there is a request to the server a new connection will be established.
     *
     * @param serverAddress Target address of the connection to close
     */
    public void closeConnection(KvStateServerAddress serverAddress) {
        // Unregister the pending connection as well; otherwise later requests
        // would keep hitting the closed instance instead of reconnecting.
        PendingConnection pending = pendingConnections.remove(serverAddress);
        if (pending != null) {
            pending.close();
        }

        EstablishedConnection established = establishedConnections.remove(serverAddress);
        if (established != null) {
            established.close();
        }
    }

    /**
     * A pending connection that is in the process of connecting.
     */
    private class PendingConnection implements ChannelFutureListener {

        /** Lock to guard the connect call, channel hand in, etc. */
        private final Object connectLock = new Object();

        /** Address of the server we are connecting to. */
        private final KvStateServerAddress serverAddress;

        /** Queue of requests while connecting. */
        private final ArrayDeque<PendingRequest> queuedRequests = new ArrayDeque<>();

        /** The established connection after the connect succeeds. */
        private EstablishedConnection established;

        /** Closed flag. */
        private boolean closed;

        /** Failure cause if something goes wrong. */
        private Throwable failureCause;

        /**
         * Creates a pending connection to the given server.
         *
         * @param serverAddress Address of the server to connect to.
         */
        private PendingConnection(KvStateServerAddress serverAddress) {
            this.serverAddress = serverAddress;
        }

        @Override
        public void operationComplete(ChannelFuture future) throws Exception {
            // Callback from the Bootstrap's connect call.
            if (future.isSuccess()) {
                handInChannel(future.channel());
            } else {
                close(future.cause());
            }
        }

        /**
         * Returns a future holding the serialized request result.
         *
         * <p>If the channel has been established, forward the call to the
         * established channel, otherwise queue it for when the channel is
         * handed in.
         *
         * @param kvStateId                 ID of the KvState instance to query
         * @param serializedKeyAndNamespace Serialized key and namespace to query KvState instance
         *                                  with
         * @return Future holding the serialized result
         */
        public Future<byte[]> getKvState(KvStateID kvStateId, byte[] serializedKeyAndNamespace) {
            synchronized (connectLock) {
                if (failureCause != null) {
                    return Futures.failed(failureCause);
                } else if (closed) {
                    return Futures.failed(new ClosedChannelException());
                } else {
                    if (established != null) {
                        return established.getKvState(kvStateId, serializedKeyAndNamespace);
                    } else {
                        // Queue this and handle when connected
                        PendingRequest pending = new PendingRequest(kvStateId, serializedKeyAndNamespace);
                        queuedRequests.add(pending);
                        return pending.promise.future();
                    }
                }
            }
        }

        /**
         * Hands in a channel after a successful connection.
         *
         * @param channel Channel to hand in
         */
        private void handInChannel(Channel channel) {
            synchronized (connectLock) {
                if (closed || failureCause != null) {
                    // Close the channel and we are done. Any queued requests
                    // are removed on the close/failure call and after that no
                    // new ones can be enqueued.
                    channel.close();
                } else {
                    established = new EstablishedConnection(serverAddress, channel);

                    PendingRequest pending;
                    while ((pending = queuedRequests.poll()) != null) {
                        Future<byte[]> resultFuture = established.getKvState(pending.kvStateId,
                                pending.serializedKeyAndNamespace);

                        pending.promise.completeWith(resultFuture);
                    }

                    // Publish the channel for the general public. Only remove
                    // this instance from the pending map, never someone else's.
                    establishedConnections.put(serverAddress, established);
                    pendingConnections.remove(serverAddress, this);

                    // Check shut down for possible race with shut down. We
                    // don't want any lingering connections after shut down,
                    // which can happen if we don't check this here.
                    if (shutDown.get()) {
                        if (establishedConnections.remove(serverAddress, established)) {
                            established.close();
                        }
                    }
                }
            }
        }

        /**
         * Close the connecting channel with a ClosedChannelException.
         */
        private void close() {
            close(new ClosedChannelException());
        }

        /**
         * Close the connecting channel with the given cause, or forward the
         * close to the established channel if the connect already succeeded.
         *
         * <p>Queued requests are failed with {@code cause}. Afterwards this
         * instance is unregistered from the pending connections, so that the
         * next request to the same server triggers a fresh connect attempt
         * instead of being failed with the stale cause forever.
         *
         * @param cause Failure cause used for queued requests and later calls
         */
        private void close(Throwable cause) {
            synchronized (connectLock) {
                if (!closed) {
                    if (failureCause == null) {
                        failureCause = cause;
                    }

                    if (established != null) {
                        // Unregister before closing so no new request is routed to it.
                        establishedConnections.remove(serverAddress, established);
                        established.close();
                    } else {
                        PendingRequest pending;
                        while ((pending = queuedRequests.poll()) != null) {
                            pending.promise.tryFailure(cause);
                        }
                    }

                    closed = true;

                    // Conditional remove: a newer pending connection for the same
                    // address (if any) must not be touched.
                    pendingConnections.remove(serverAddress, this);
                }
            }
        }

        /**
         * A pending request queued while the channel is connecting.
         */
        private final class PendingRequest {

            private final KvStateID kvStateId;
            private final byte[] serializedKeyAndNamespace;
            private final Promise<byte[]> promise;

            private PendingRequest(KvStateID kvStateId, byte[] serializedKeyAndNamespace) {
                this.kvStateId = kvStateId;
                this.serializedKeyAndNamespace = serializedKeyAndNamespace;
                this.promise = Futures.promise();
            }
        }

        @Override
        public String toString() {
            synchronized (connectLock) {
                return "PendingConnection{" + "serverAddress=" + serverAddress + ", queuedRequests="
                        + queuedRequests.size() + ", established=" + (established != null) + ", closed=" + closed
                        + '}';
            }
        }
    }

    /**
     * An established connection that wraps the actual channel instance and is
     * registered at the {@link KvStateClientHandler} for callbacks.
     */
    private class EstablishedConnection implements KvStateClientHandlerCallback {

        /** Address of the server we are connected to. */
        private final KvStateServerAddress serverAddress;

        /** The actual TCP channel. */
        private final Channel channel;

        /** Pending requests keyed by request ID. */
        private final ConcurrentHashMap<Long, PromiseAndTimestamp> pendingRequests = new ConcurrentHashMap<>();

        /** Current request number used to assign unique request IDs. */
        private final AtomicLong requestCount = new AtomicLong();

        /** Reference to a failure that was reported by the channel. */
        private final AtomicReference<Throwable> failureCause = new AtomicReference<>();

        /**
         * Creates an established connection with the given channel.
         *
         * @param serverAddress Address of the server connected to
         * @param channel The actual TCP channel
         */
        EstablishedConnection(KvStateServerAddress serverAddress, Channel channel) {
            this.serverAddress = Preconditions.checkNotNull(serverAddress, "KvStateServerAddress");
            this.channel = Preconditions.checkNotNull(channel, "Channel");

            // Add the client handler with the callback
            channel.pipeline().addLast("KvStateClientHandler", new KvStateClientHandler(this));

            stats.reportActiveConnection();
        }

        /**
         * Close the channel with a ClosedChannelException.
         */
        void close() {
            close(new ClosedChannelException());
        }

        /**
         * Close the channel with a cause and fail all pending requests with it.
         *
         * @param cause The cause to close the channel with.
         * @return {@code true} if this call closed the connection,
         *         {@code false} if it had already been closed before
         */
        private boolean close(Throwable cause) {
            if (failureCause.compareAndSet(null, cause)) {
                channel.close();
                stats.reportInactiveConnection();

                for (long requestId : pendingRequests.keySet()) {
                    failPendingRequest(requestId, cause);
                }

                return true;
            }

            return false;
        }

        /**
         * Returns a future holding the serialized request result.
         *
         * @param kvStateId                 ID of the KvState instance to query
         * @param serializedKeyAndNamespace Serialized key and namespace to query KvState instance
         *                                  with
         * @return Future holding the serialized result
         */
        Future<byte[]> getKvState(KvStateID kvStateId, byte[] serializedKeyAndNamespace) {
            final long requestId = requestCount.getAndIncrement();
            final PromiseAndTimestamp requestPromiseTs = new PromiseAndTimestamp(Futures.<byte[]>promise(),
                    System.nanoTime());

            try {
                pendingRequests.put(requestId, requestPromiseTs);

                stats.reportRequest();

                ByteBuf buf = KvStateRequestSerializer.serializeKvStateRequest(channel.alloc(), requestId,
                        kvStateId, serializedKeyAndNamespace);

                channel.writeAndFlush(buf).addListener(new ChannelFutureListener() {
                    @Override
                    public void operationComplete(ChannelFuture future) throws Exception {
                        if (!future.isSuccess()) {
                            // Fail promise if not failed to write
                            failPendingRequest(requestId, future.cause());
                        }
                    }
                });

                // Check failure for possible race. We don't want any lingering
                // promises after a failure, which can happen if we don't check
                // this here. Note that close is treated as a failure as well.
                Throwable failure = failureCause.get();
                if (failure != null) {
                    failPendingRequest(requestId, failure);
                }
            } catch (Throwable t) {
                // Unregister the request so it does not linger in the pending map
                // and is counted as failed; fall back to failing the promise
                // directly if it was never (or is no longer) registered.
                if (!failPendingRequest(requestId, t)) {
                    requestPromiseTs.promise.tryFailure(t);
                }
            }

            return requestPromiseTs.promise.future();
        }

        /**
         * Removes the pending request with the given ID and fails its promise.
         *
         * <p>Removing before failing guards against concurrent completion and
         * makes sure a request is counted as failed at most once.
         *
         * @param requestId ID of the request to fail
         * @param cause     Failure cause
         * @return {@code true} if the request was pending and this call failed it
         */
        private boolean failPendingRequest(long requestId, Throwable cause) {
            PromiseAndTimestamp pending = pendingRequests.remove(requestId);
            if (pending != null && pending.promise.tryFailure(cause)) {
                stats.reportFailedRequest();
                return true;
            }
            return false;
        }

        @Override
        public void onRequestResult(long requestId, byte[] serializedValue) {
            PromiseAndTimestamp pending = pendingRequests.remove(requestId);
            if (pending != null && pending.promise.trySuccess(serializedValue)) {
                long durationMillis = (System.nanoTime() - pending.timestamp) / 1_000_000;
                stats.reportSuccessfulRequest(durationMillis);
            }
        }

        @Override
        public void onRequestFailure(long requestId, Throwable cause) {
            failPendingRequest(requestId, cause);
        }

        @Override
        public void onFailure(Throwable cause) {
            if (close(cause)) {
                // Remove from established channels, otherwise future
                // requests will be handled by this failed channel.
                establishedConnections.remove(serverAddress, this);
            }
        }

        @Override
        public String toString() {
            return "EstablishedConnection{" + "serverAddress=" + serverAddress + ", channel=" + channel
                    + ", pendingRequests=" + pendingRequests.size() + ", requestCount=" + requestCount
                    + ", failureCause=" + failureCause + '}';
        }

        /**
         * Pair of promise and a timestamp (from {@link System#nanoTime()}).
         */
        private class PromiseAndTimestamp {

            private final Promise<byte[]> promise;
            private final long timestamp;

            public PromiseAndTimestamp(Promise<byte[]> promise, long timestamp) {
                this.promise = promise;
                this.timestamp = timestamp;
            }
        }

    }

}