com.github.sparkfy.network.server.TransportRequestHandler.java Source code

Java tutorial

Introduction

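Here is the source code for com.github.sparkfy.network.server.TransportRequestHandler.java, the server-side Netty handler that dispatches chunk-fetch, stream, RPC, and one-way requests received from clients.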

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.github.sparkfy.network.server;

import com.github.sparkfy.network.buffer.ManagedBuffer;
import com.github.sparkfy.network.buffer.NioManagedBuffer;
import com.github.sparkfy.network.client.RpcResponseCallback;
import com.github.sparkfy.network.client.TransportClient;
import com.github.sparkfy.network.protocol.*;
import com.github.sparkfy.network.util.NettyUtils;
import com.google.common.base.Throwables;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.nio.ByteBuffer;

/**
 * A handler that processes requests from clients and writes chunk data back. Each handler is
 * attached to a single Netty channel, and keeps track of which streams have been fetched via this
 * channel, in order to clean them up if the channel is terminated (see #channelUnregistered).
 *
 * The messages should have been processed by the pipeline setup by {@link TransportServer}.
 */
public class TransportRequestHandler extends MessageHandler<RequestMessage> {
    private final Logger logger = LoggerFactory.getLogger(TransportRequestHandler.class);

    /** The Netty channel that this handler is associated with. */
    private final Channel channel;

    /** Client on the same channel allowing us to talk back to the requester. */
    private final TransportClient reverseClient;

    /** Handles all RPC messages. */
    private final RpcHandler rpcHandler;

    /** Returns each chunk part of a stream. */
    private final StreamManager streamManager;

    public TransportRequestHandler(Channel channel, TransportClient reverseClient, RpcHandler rpcHandler) {
        this.channel = channel;
        this.reverseClient = reverseClient;
        this.rpcHandler = rpcHandler;
        this.streamManager = rpcHandler.getStreamManager();
    }

    @Override
    public void exceptionCaught(Throwable cause) {
        rpcHandler.exceptionCaught(cause, reverseClient);
    }

    @Override
    public void channelActive() {
        rpcHandler.channelActive(reverseClient);
    }

    @Override
    public void channelInactive() {
        if (streamManager != null) {
            try {
                streamManager.connectionTerminated(channel);
            } catch (RuntimeException e) {
                logger.error("StreamManager connectionTerminated() callback failed.", e);
            }
        }
        rpcHandler.channelInactive(reverseClient);
    }

    @Override
    public void handle(RequestMessage request) {
        if (request instanceof ChunkFetchRequest) {
            processFetchRequest((ChunkFetchRequest) request);
        } else if (request instanceof RpcRequest) {
            processRpcRequest((RpcRequest) request);
        } else if (request instanceof OneWayMessage) {
            processOneWayMessage((OneWayMessage) request);
        } else if (request instanceof StreamRequest) {
            processStreamRequest((StreamRequest) request);
        } else {
            throw new IllegalArgumentException("Unknown request type: " + request);
        }
    }

    private void processFetchRequest(final ChunkFetchRequest req) {
        final String client = NettyUtils.getRemoteAddress(channel);

        logger.trace("Received req from {} to fetch block {}", client, req.streamChunkId);

        ManagedBuffer buf;
        try {
            streamManager.checkAuthorization(reverseClient, req.streamChunkId.streamId);
            streamManager.registerChannel(channel, req.streamChunkId.streamId);
            buf = streamManager.getChunk(req.streamChunkId.streamId, req.streamChunkId.chunkIndex);
        } catch (Exception e) {
            logger.error(String.format("Error opening block %s for request from %s", req.streamChunkId, client), e);
            respond(new ChunkFetchFailure(req.streamChunkId, Throwables.getStackTraceAsString(e)));
            return;
        }

        respond(new ChunkFetchSuccess(req.streamChunkId, buf));
    }

    private void processStreamRequest(final StreamRequest req) {
        final String client = NettyUtils.getRemoteAddress(channel);
        ManagedBuffer buf;
        try {
            buf = streamManager.openStream(req.streamId);
        } catch (Exception e) {
            logger.error(String.format("Error opening stream %s for request from %s", req.streamId, client), e);
            respond(new StreamFailure(req.streamId, Throwables.getStackTraceAsString(e)));
            return;
        }

        if (buf != null) {
            respond(new StreamResponse(req.streamId, buf.size(), buf));
        } else {
            respond(new StreamFailure(req.streamId, String.format("Stream '%s' was not found.", req.streamId)));
        }
    }

    private void processRpcRequest(final RpcRequest req) {
        try {
            rpcHandler.receive(reverseClient, req.body().nioByteBuffer(), new RpcResponseCallback() {
                @Override
                public void onSuccess(ByteBuffer response) {
                    respond(new RpcResponse(req.requestId, new NioManagedBuffer(response)));
                }

                @Override
                public void onFailure(Throwable e) {
                    respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e)));
                }
            });
        } catch (Exception e) {
            logger.error("Error while invoking RpcHandler#receive() on RPC id " + req.requestId, e);
            respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e)));
        } finally {
            req.body().release();
        }
    }

    private void processOneWayMessage(OneWayMessage req) {
        try {
            rpcHandler.receive(reverseClient, req.body().nioByteBuffer());
        } catch (Exception e) {
            logger.error("Error while invoking RpcHandler#receive() for one-way message.", e);
        } finally {
            req.body().release();
        }
    }

    /**
     * Responds to a single message with some Encodable object. If a failure occurs while sending,
     * it will be logged and the channel closed.
     */
    private void respond(final Encodable result) {
        final String remoteAddress = channel.remoteAddress().toString();

        channel.writeAndFlush(result).addListener(new ChannelFutureListener() {
            @Override
            public void operationComplete(ChannelFuture future) throws Exception {
                if (future.isSuccess()) {
                    logger.trace(String.format("Sent result %s to client %s", result, remoteAddress));
                } else {
                    logger.error(String.format("Error sending result %s to %s; closing connection", result,
                            remoteAddress), future.cause());
                    channel.close();
                }
            }
        });
    }
}