net.dongliu.prettypb.rpc.server.RequestHandler.java Source code

Java tutorial

Introduction

Here is the source code for net.dongliu.prettypb.rpc.server.RequestHandler.java

Source

/**
 *   Copyright 2010-2014 Peter Klauser
 *
 *   Licensed under the Apache License, Version 2.0 (the "License");
 *   you may not use this file except in compliance with the License.
 *   You may obtain a copy of the License at
 *
 *       http://www.apache.org/licenses/LICENSE-2.0
 *
 *   Unless required by applicable law or agreed to in writing, software
 *   distributed under the License is distributed on an "AS IS" BASIS,
 *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *   See the License for the specific language governing permissions and
 *   limitations under the License.
 */
package net.dongliu.prettypb.rpc.server;

import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandler.Sharable;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPipeline;
import io.netty.handler.codec.MessageToMessageDecoder;
import io.netty.handler.codec.compression.ZlibCodecFactory;
import io.netty.handler.codec.compression.ZlibWrapper;
import net.dongliu.prettypb.rpc.common.PeerInfo;
import net.dongliu.prettypb.rpc.protocol.ConnectErrorCode;
import net.dongliu.prettypb.rpc.protocol.ConnectRequest;
import net.dongliu.prettypb.rpc.protocol.ConnectResponse;
import net.dongliu.prettypb.rpc.protocol.WirePayload;
import net.dongliu.prettypb.rpc.utils.Handlers;
import net.dongliu.prettypb.runtime.ExtensionRegistry;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.List;
import java.util.concurrent.ThreadPoolExecutor;

/**
 * Handles the connection handshake on the server side.
 *
 * <p>When an incoming {@link WirePayload} carries a {@link ConnectRequest}, this handler
 * builds an {@link RpcServerChannel} for the client and tries to register it with the
 * {@link RpcServerChannelRegistry}:
 * <ul>
 *   <li>if registration succeeds, a {@link ConnectResponse} carrying the server PID and the
 *       requested compress flag is written back, and the pipeline is completed with the
 *       optional GZIP handlers and an {@link RpcServerHandler};</li>
 *   <li>if registration is refused, a {@link ConnectResponse} with
 *       {@link ConnectErrorCode#ALREADY_CONNECTED} is written back and the channel is closed
 *       once that write completes.</li>
 * </ul>
 * Payloads without a ConnectRequest are passed on to the next handler unchanged.
 *
 * <p>The handler is annotated {@code @Sharable}; it only holds final references that are
 * supplied at construction time and keeps no per-connection state in its fields.
 *
 * @author Peter Klauser
 */
@Sharable
public class RequestHandler extends MessageToMessageDecoder<WirePayload> {

    private static final Logger logger = LoggerFactory.getLogger(RequestHandler.class);

    private final PeerInfo serverPeer;
    private final RpcServiceRegistry rpcServiceRegistry;
    /**
     * extension for rpc messages
     */
    private final ExtensionRegistry extensionRegistry;

    private final ThreadPoolExecutor rpcServiceExecutor;
    private final RpcServerChannelRegistry rpcServerChannelRegistry;

    /**
     * Creates the handshake handler.
     *
     * @param serverPeer               identity of this server; its PID is sent in the ConnectResponse
     * @param rpcServiceRegistry       passed to every {@link RpcServerHandler} created by this handler
     * @param extensionRegistry        extension registry for rpc messages, passed to every
     *                                 {@link RpcServerHandler} created by this handler
     * @param rpcServiceExecutor       executor passed to every {@link RpcServerHandler} created by
     *                                 this handler
     * @param rpcServerChannelRegistry registry that decides whether a new client channel is accepted
     */
    public RequestHandler(PeerInfo serverPeer, RpcServiceRegistry rpcServiceRegistry,
            ExtensionRegistry extensionRegistry, ThreadPoolExecutor rpcServiceExecutor,
            RpcServerChannelRegistry rpcServerChannelRegistry) {
        this.serverPeer = serverPeer;
        this.rpcServiceRegistry = rpcServiceRegistry;
        this.extensionRegistry = extensionRegistry;
        this.rpcServiceExecutor = rpcServiceExecutor;
        this.rpcServerChannelRegistry = rpcServerChannelRegistry;
    }

    /**
     * Consumes payloads that carry a ConnectRequest and forwards every other payload.
     *
     * @param ctx the context of the channel the payload arrived on
     * @param msg the decoded wire payload
     * @param out receives {@code msg} when it is not a connect request
     */
    @Override
    protected void decode(ChannelHandlerContext ctx, WirePayload msg, List<Object> out) throws Exception {
        if (!msg.hasConnectRequest()) {
            // Not part of the handshake: hand over to the next handler untouched.
            out.add(msg);
            return;
        }
        handleConnectRequest(ctx, msg.getConnectRequest());
    }

    /**
     * Registers the connecting client and answers with an accepting or rejecting ConnectResponse.
     */
    private void handleConnectRequest(ChannelHandlerContext ctx, ConnectRequest connectRequest) {
        logger.info("Received ConnectRequest from {}:{}, use compress: {}.", connectRequest.getClientHostName(),
                connectRequest.getClientPort(), connectRequest.isCompress());
        PeerInfo clientInfo = new PeerInfo(connectRequest.getClientHostName(), connectRequest.getClientPort(),
                connectRequest.getClientPID());
        RpcServerChannel rpcServerChannel = new RpcServerChannel(ctx.channel(), serverPeer, clientInfo,
                connectRequest.isCompress());

        ConnectResponse connectResponse = new ConnectResponse();
        connectResponse.setCorrelationId(connectRequest.getCorrelationId());

        if (rpcServerChannelRegistry.registerRpcServerChannel(rpcServerChannel)) {
            connectResponse.setServerPID(serverPeer.getPid());
            connectResponse.setCompress(connectRequest.isCompress());
            // The response is written before the compression handlers are installed.
            // NOTE(review): presumably so the handshake reply itself goes out uncompressed - confirm.
            ctx.channel().writeAndFlush(toPayload(connectResponse));

            completePipeline(rpcServerChannel);
        } else {
            connectResponse.setErrorCode(ConnectErrorCode.ALREADY_CONNECTED);
            WirePayload payload = toPayload(connectResponse);

            logger.debug("Sending ConnectResponse({}). Already Connected.", connectResponse.getCorrelationId());
            ChannelFuture future = ctx.channel().writeAndFlush(payload);
            future.addListener(ChannelFutureListener.CLOSE); // close after write response.
        }
    }

    /**
     * Wraps a ConnectResponse into a new WirePayload.
     */
    private static WirePayload toPayload(ConnectResponse connectResponse) {
        WirePayload payload = new WirePayload();
        payload.setConnectResponse(connectResponse);
        return payload;
    }

    /**
     * Adds the handlers needed after a successful handshake: the GZIP encoder/decoder pair when
     * the channel uses compression, and always an {@link RpcServerHandler} at the end of the pipeline.
     */
    private void completePipeline(RpcServerChannel rpcServerChannel) {
        ChannelPipeline p = rpcServerChannel.getChannel().pipeline();

        if (rpcServerChannel.isCompress()) {
            // Encoder goes in front of the frame decoder, the decoder directly after the encoder.
            p.addBefore(Handlers.FRAME_DECODER, Handlers.COMPRESSOR,
                    ZlibCodecFactory.newZlibEncoder(ZlibWrapper.GZIP));
            p.addAfter(Handlers.COMPRESSOR, Handlers.DECOMPRESSOR,
                    ZlibCodecFactory.newZlibDecoder(ZlibWrapper.GZIP));
        }

        RpcServerHandler rpcServerHandler = new RpcServerHandler(rpcServerChannel, rpcServiceRegistry,
                rpcServiceExecutor, rpcServerChannelRegistry, extensionRegistry);
        p.addLast(Handlers.RPC_SERVER, rpcServerHandler);
    }

    /**
     * Delegates to the superclass implementation first (in Netty's inbound handler adapter this
     * forwards the exception to the next handler), then logs the cause and closes the connection.
     */
    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        super.exceptionCaught(ctx, cause);
        logger.warn("Exception caught during RPC connection handshake.", cause);
        ctx.close();
    }

}