com.github.sparkfy.network.client.TransportClient.java Source code

Java tutorial

Introduction

Here is the source code for com.github.sparkfy.network.client.TransportClient.java

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.github.sparkfy.network.client;

import com.github.sparkfy.network.buffer.NioManagedBuffer;
import com.github.sparkfy.network.protocol.*;
import com.github.sparkfy.network.util.NettyUtils;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Objects;
import com.google.common.base.Preconditions;
import com.google.common.base.Throwables;
import com.google.common.util.concurrent.SettableFuture;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.Closeable;
import java.io.IOException;
import java.net.SocketAddress;
import java.nio.ByteBuffer;
import java.util.UUID;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;

/**
 * Client for fetching consecutive chunks of a pre-negotiated stream. This API is intended to allow
 * efficient transfer of a large amount of data, broken up into chunks with size ranging from
 * hundreds of KB to a few MB.
 * <p/>
 * Note that while this client deals with the fetching of chunks from a stream (i.e., data plane),
 * the actual setup of the streams is done outside the scope of the transport layer. The convenience
 * method "sendRpc" is provided to enable control plane communication between the client and server
 * to perform this setup.
 * <p/>
 * For example, a typical workflow might be:
 * client.sendRpc(new OpenFile("/foo")) --&gt; returns StreamId = 100
 * client.fetchChunk(streamId = 100, chunkIndex = 0, callback)
 * client.fetchChunk(streamId = 100, chunkIndex = 1, callback)
 * ...
 * client.sendRpc(new CloseStream(100))
 * <p/>
 * Construct an instance of TransportClient using {@link TransportClientFactory}. A single
 * TransportClient may be used for multiple streams, but any given stream must be restricted to a
 * single client, in order to avoid out-of-order responses.
 * <p/>
 * NB: This class is used to make requests to the server, while {@link TransportResponseHandler} is
 * responsible for handling responses from the server.
 * <p/>
 * Concurrency: thread safe and can be called from multiple threads
 */
public class TransportClient implements Closeable {
    private static final Logger logger = LoggerFactory.getLogger(TransportClient.class);

    private final Channel channel;
    private final TransportResponseHandler handler;
    private String clientId;
    private volatile boolean timedOut;

    /**
     * @param channel The Netty channel requests are written to; must not be null.
     * @param handler Handler that tracks outstanding requests and dispatches their responses;
     *                must not be null.
     */
    public TransportClient(Channel channel, TransportResponseHandler handler) {
        this.channel = Preconditions.checkNotNull(channel);
        this.handler = Preconditions.checkNotNull(handler);
        this.timedOut = false;
    }

    /**
     * Returns the underlying Netty channel.
     */
    public Channel getChannel() {
        return channel;
    }

    /**
     * Returns true when this client has not been marked as timed out (see {@link #timeOut()}) and
     * its channel is either open or active.
     */
    public boolean isActive() {
        return !timedOut && (channel.isOpen() || channel.isActive());
    }

    /**
     * Returns the remote address of the underlying channel.
     */
    public SocketAddress getSocketAddress() {
        return channel.remoteAddress();
    }

    /**
     * Returns the ID used by the client to authenticate itself when authentication is enabled.
     *
     * @return The client ID, or null if authentication is disabled.
     */
    public String getClientId() {
        return clientId;
    }

    /**
     * Sets the authenticated client ID. This is meant to be used by the authentication layer.
     * <p/>
     * Trying to set a different client ID after it's been set will result in an exception.
     */
    public void setClientId(String id) {
        Preconditions.checkState(clientId == null, "Client ID has already been set.");
        this.clientId = id;
    }

    /**
     * Requests a single chunk from the remote side, from the pre-negotiated streamId.
     * <p/>
     * Chunk indices go from 0 onwards. It is valid to request the same chunk multiple times, though
     * some streams may not support this.
     * <p/>
     * Multiple fetchChunk requests may be outstanding simultaneously, and the chunks are guaranteed
     * to be returned in the same order that they were requested, assuming only a single
     * TransportClient is used to fetch the chunks.
     *
     * @param streamId   Identifier that refers to a stream in the remote StreamManager. This should
     *                   be agreed upon by client and server beforehand.
     * @param chunkIndex 0-based index of the chunk to fetch
     * @param callback   Callback invoked upon successful receipt of chunk, or upon any failure.
     */
    public void fetchChunk(long streamId, final int chunkIndex, final ChunkReceivedCallback callback) {
        final String serverAddr = NettyUtils.getRemoteAddress(channel);
        final long startTime = System.currentTimeMillis();
        logger.debug("Sending fetch chunk request {} to {}", chunkIndex, serverAddr);

        final StreamChunkId streamChunkId = new StreamChunkId(streamId, chunkIndex);
        // Register the callback before writing so a fast response always finds it.
        handler.addFetchRequest(streamChunkId, callback);

        channel.writeAndFlush(new ChunkFetchRequest(streamChunkId)).addListener(new ChannelFutureListener() {
            @Override
            public void operationComplete(ChannelFuture future) throws Exception {
                if (future.isSuccess()) {
                    long timeTaken = System.currentTimeMillis() - startTime;
                    logger.trace("Sending request {} to {} took {} ms", streamChunkId, serverAddr, timeTaken);
                } else {
                    String errorMsg = String.format("Failed to send request %s to %s: %s", streamChunkId,
                            serverAddr, future.cause());
                    logger.error(errorMsg, future.cause());
                    handler.removeFetchRequest(streamChunkId);
                    channel.close();
                    try {
                        callback.onFailure(chunkIndex, new IOException(errorMsg, future.cause()));
                    } catch (Exception e) {
                        logger.error("Uncaught exception in RPC response callback handler!", e);
                    }
                }
            }
        });
    }

    /**
     * Request to stream the data with the given stream ID from the remote end.
     *
     * @param streamId The stream to fetch.
     * @param callback Object to call with the stream data.
     */
    public void stream(final String streamId, final StreamCallback callback) {
        final String serverAddr = NettyUtils.getRemoteAddress(channel);
        final long startTime = System.currentTimeMillis();
        logger.debug("Sending stream request for {} to {}", streamId, serverAddr);

        // Need to synchronize here so that the callback is added to the queue and the RPC is
        // written to the socket atomically, so that callbacks are called in the right order
        // when responses arrive.
        synchronized (this) {
            handler.addStreamCallback(callback);
            channel.writeAndFlush(new StreamRequest(streamId)).addListener(new ChannelFutureListener() {
                @Override
                public void operationComplete(ChannelFuture future) throws Exception {
                    if (future.isSuccess()) {
                        long timeTaken = System.currentTimeMillis() - startTime;
                        logger.trace("Sending request for {} to {} took {} ms", streamId, serverAddr, timeTaken);
                    } else {
                        String errorMsg = String.format("Failed to send request for %s to %s: %s", streamId,
                                serverAddr, future.cause());
                        logger.error(errorMsg, future.cause());
                        channel.close();
                        try {
                            callback.onFailure(streamId, new IOException(errorMsg, future.cause()));
                        } catch (Exception e) {
                            logger.error("Uncaught exception in RPC response callback handler!", e);
                        }
                    }
                }
            });
        }
    }

    /**
     * Synchronously sends an opaque message to the RpcHandler on the server-side and waits for its
     * reply, closing the channel if anything goes wrong.
     * <p/>
     * Unlike {@link #sendRpcSync(ByteBuffer, long)}, this first waits for the request to be written
     * to the channel and then for the reply, both within the same {@code timeoutMs} budget. On any
     * failure, timeout or interruption the RPC is unregistered and the channel is closed.
     * <p/>
     * This blocks the calling thread, so it should not be called from the channel's own Netty
     * event loop.
     *
     * @param message   The message to send.
     * @param timeoutMs Total time, in milliseconds, allowed for both the write and the reply.
     * @return A private copy of the server's reply.
     * @throws RuntimeException wrapping the underlying cause if the write or the RPC fails, times
     *                          out, or the calling thread is interrupted.
     */
    public ByteBuffer sendRpcSyncSafely(ByteBuffer message, long timeoutMs) {
        final String serverAddr = NettyUtils.getRemoteAddress(channel);
        final long startTime = System.currentTimeMillis();
        logger.trace("Sending RPC to {}", serverAddr);

        final long requestId = Math.abs(UUID.randomUUID().getLeastSignificantBits());
        final SettableFuture<ByteBuffer> result = SettableFuture.create();
        final RpcResponseCallback callback = copyingCallback(result);
        handler.addRpcRequest(requestId, callback);

        try {
            // Deliberately not synchronized: replies are matched by requestId, so there is no
            // ordering to protect, and blocking on the write while holding this client's monitor
            // would stall concurrent stream() calls (which lock it) for up to timeoutMs.
            channel.writeAndFlush(new RpcRequest(requestId, new NioManagedBuffer(message)))
                    .addListener(rpcWriteListener(requestId, serverAddr, startTime, callback))
                    .get(timeoutMs, TimeUnit.MILLISECONDS);
            long remainingMs = timeoutMs - (System.currentTimeMillis() - startTime);
            return result.get(remainingMs, TimeUnit.MILLISECONDS);
        } catch (Exception e) {
            // Give up on this RPC: nobody will wait for its reply any more.
            handler.removeRpcRequest(requestId);
            channel.close();
            if (e instanceof InterruptedException) {
                // Restore the interrupt flag; wrapping the exception would otherwise swallow it.
                Thread.currentThread().interrupt();
            }
            throw Throwables.propagate(e instanceof ExecutionException ? e.getCause() : e);
        }
    }

    /**
     * Sends an opaque message to the RpcHandler on the server-side. The callback will be invoked
     * with the server's response or upon any failure.
     *
     * @param message  The message to send.
     * @param callback Callback to handle the RPC's reply.
     * @return The RPC's id.
     */
    public long sendRpc(ByteBuffer message, final RpcResponseCallback callback) {
        final String serverAddr = NettyUtils.getRemoteAddress(channel);
        final long startTime = System.currentTimeMillis();
        logger.trace("Sending RPC to {}", serverAddr);

        final long requestId = Math.abs(UUID.randomUUID().getLeastSignificantBits());
        handler.addRpcRequest(requestId, callback);

        channel.writeAndFlush(new RpcRequest(requestId, new NioManagedBuffer(message)))
                .addListener(rpcWriteListener(requestId, serverAddr, startTime, callback));

        return requestId;
    }

    /**
     * Synchronously sends an opaque message to the RpcHandler on the server-side, waiting for up to
     * a specified timeout for a response.
     *
     * @param message   The message to send.
     * @param timeoutMs Maximum time, in milliseconds, to wait for the reply.
     * @return A private copy of the server's reply.
     * @throws RuntimeException wrapping the underlying cause if the RPC fails, times out, or the
     *                          calling thread is interrupted.
     */
    public ByteBuffer sendRpcSync(ByteBuffer message, long timeoutMs) {
        final SettableFuture<ByteBuffer> result = SettableFuture.create();

        sendRpc(message, copyingCallback(result));

        try {
            return result.get(timeoutMs, TimeUnit.MILLISECONDS);
        } catch (ExecutionException e) {
            throw Throwables.propagate(e.getCause());
        } catch (Exception e) {
            if (e instanceof InterruptedException) {
                // Restore the interrupt flag; wrapping the exception would otherwise swallow it.
                Thread.currentThread().interrupt();
            }
            throw Throwables.propagate(e);
        }
    }

    /**
     * Sends an opaque message to the RpcHandler on the server-side. No reply is expected for the
     * message, and no delivery guarantees are made.
     *
     * @param message The message to send.
     */
    public void send(ByteBuffer message) {
        channel.writeAndFlush(new OneWayMessage(new NioManagedBuffer(message)));
    }

    /**
     * Removes any state associated with the given RPC.
     *
     * @param requestId The RPC id returned by {@link #sendRpc(ByteBuffer, RpcResponseCallback)}.
     */
    public void removeRpcRequest(long requestId) {
        handler.removeRpcRequest(requestId);
    }

    /**
     * Mark this channel as having timed out.
     */
    public void timeOut() {
        this.timedOut = true;
    }

    @VisibleForTesting
    public TransportResponseHandler getHandler() {
        return handler;
    }

    /**
     * Closes the underlying channel, waiting up to 10 seconds for the close to complete.
     */
    @Override
    public void close() {
        // close is a local operation and should finish with milliseconds; timeout just to be safe
        channel.close().awaitUninterruptibly(10, TimeUnit.SECONDS);
    }

    @Override
    public String toString() {
        return Objects.toStringHelper(this).add("remoteAddress", channel.remoteAddress()).add("clientId", clientId)
                .add("isActive", isActive()).toString();
    }

    /**
     * Builds the listener attached to an RPC write. On success it traces how long the write took.
     * On failure it unregisters the RPC, closes the channel and fails {@code callback} with an
     * IOException.
     */
    private ChannelFutureListener rpcWriteListener(final long requestId, final String serverAddr,
            final long startTime, final RpcResponseCallback callback) {
        return new ChannelFutureListener() {
            @Override
            public void operationComplete(ChannelFuture future) throws Exception {
                if (future.isSuccess()) {
                    long timeTaken = System.currentTimeMillis() - startTime;
                    logger.trace("Sending request {} to {} took {} ms", requestId, serverAddr, timeTaken);
                } else {
                    String errorMsg = String.format("Failed to send RPC %s to %s: %s", requestId,
                            serverAddr, future.cause());
                    logger.error(errorMsg, future.cause());
                    handler.removeRpcRequest(requestId);
                    channel.close();
                    try {
                        callback.onFailure(new IOException(errorMsg, future.cause()));
                    } catch (Exception e) {
                        logger.error("Uncaught exception in RPC response callback handler!", e);
                    }
                }
            }
        };
    }

    /**
     * Creates a callback that completes {@code result} with a private heap copy of the RPC reply,
     * or with the failure cause.
     * <p/>
     * The reply is copied because the buffer passed to onSuccess may not stay valid after the
     * callback returns; sendRpcSync has always copied it for this reason.
     * NOTE(review): presumably TransportResponseHandler releases the response body after onSuccess.
     * Confirm this.
     */
    private static RpcResponseCallback copyingCallback(final SettableFuture<ByteBuffer> result) {
        return new RpcResponseCallback() {
            @Override
            public void onSuccess(ByteBuffer response) {
                ByteBuffer copy = ByteBuffer.allocate(response.remaining());
                copy.put(response);
                // flip "copy" to make it readable
                copy.flip();
                result.set(copy);
            }

            @Override
            public void onFailure(Throwable e) {
                result.setException(e);
            }
        };
    }
}