com.chicm.cmraft.rpc.RpcClient.java Source code

Java tutorial

Introduction

Here is the source code for com.chicm.cmraft.rpc.RpcClient.java

Source

/**
* Copyright 2014 The CmRaft Project
*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements.  See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership.  The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License.  You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/

package com.chicm.cmraft.rpc;

import io.netty.bootstrap.Bootstrap;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioSocketChannel;

import java.io.IOException;
import java.util.Random;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicInteger;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import com.chicm.cmraft.common.CmRaftConfiguration;
import com.chicm.cmraft.common.Configuration;
import com.chicm.cmraft.common.ServerInfo;
import com.chicm.cmraft.protobuf.generated.RaftProtos.RaftService;
import com.chicm.cmraft.protobuf.generated.RaftProtos.RequestHeader;
import com.chicm.cmraft.protobuf.generated.RaftProtos.RaftService.BlockingInterface;
import com.chicm.cmraft.protobuf.generated.RaftProtos.TestRpcRequest;
import com.chicm.cmraft.util.BlockingHashMap;
import com.google.common.base.Preconditions;
import com.google.protobuf.BlockingRpcChannel;
import com.google.protobuf.ByteString;
import com.google.protobuf.Message;
import com.google.protobuf.RpcController;
import com.google.protobuf.ServiceException;
import com.google.protobuf.Descriptors.MethodDescriptor;

/**
 * RpcClient implements the BlockingRpcChannel interface with an inner class. It translates RPC method
 * calls into RPC request packets and sends them to the RPC server, then translates the RPC response
 * packets received from the server into the objects returned by those method calls.
 * <p>
 * One instance talks to one remote server. {@link #connect()} and {@link #close()} are synchronized on
 * the instance; the generated blocking stub returned by {@link #getStub()} is shared by the worker
 * threads started from {@link #main(String[])}.
 *
 * @author chicm
 *
 */
public class RpcClient {
    static final Log LOG = LogFactory.getLog(RpcClient.class);
    private final static String RPC_TIMEOUT_KEY = "raft.rpc.timeout";
    private final static int DEFAULT_RPC_TIMEOUT = 3000;
    /** Number of calls issued by one {@link #sendRequest(int)} run. */
    private final static int TEST_RPC_ITERATIONS = 5000000;
    /**
     * Call id counter shared by every client in this JVM. It is seeded exactly once: re-seeding it each
     * time a client is constructed could hand out ids that are still in flight on another client.
     */
    //todo: to change call id init value
    private static final AtomicInteger client_call_id = new AtomicInteger(new Random().nextInt(1000) * 100);
    // stub and ctx are written under the instance lock in connect() but read without it by caller
    // threads and by the channel event listener, hence volatile.
    private volatile BlockingInterface stub = null;
    private volatile ChannelHandlerContext ctx = null;
    /** Event loop group backing the current connection; only touched while holding the instance lock. */
    private EventLoopGroup workerGroup = null;
    /** Responses delivered by the channel handler, keyed by call id. */
    private final BlockingHashMap<Integer, RpcCall> responsesMap = new BlockingHashMap<>();
    private final RpcClientEventListener listener = new RpcClientEventListenerImpl();
    private volatile boolean connected = false;
    /** Value handed to {@code responsesMap.take}; presumably milliseconds (default 3000) - TODO confirm. */
    private final int rpcTimeout;
    private final ServerInfo remoteServer;

    /**
     * Creates a client for the given remote server. No connection is opened until {@link #connect()},
     * {@link #getStub()} or {@link #sendRequest(int)} is called.
     *
     * @param conf configuration; {@code raft.rpc.timeout} is read from it, defaulting to 3000
     * @param remoteServer the server this client connects to
     */
    public RpcClient(Configuration conf, ServerInfo remoteServer) {
        rpcTimeout = conf.getInt(RPC_TIMEOUT_KEY, DEFAULT_RPC_TIMEOUT);
        this.remoteServer = remoteServer;
    }

    /**
     * @return true only if a channel context exists, its channel is active and the connected flag is set
     */
    public boolean isConnected() {
        // Read the volatile field once so both checks look at the same context.
        ChannelHandlerContext current = ctx;
        if (current == null || !current.channel().isActive()) {
            return false;
        }
        return connected;
    }

    /**
     * Connects to the remote server and creates the blocking stub. Does nothing when already connected.
     *
     * @return true once the client is connected
     * @throws IOException declared for callers; not thrown by the code visible here
     * @throws InterruptedException if the thread is interrupted while waiting for the connection
     * @throws ExecutionException declared for callers; not thrown by the code visible here
     */
    public synchronized boolean connect() throws IOException, InterruptedException, ExecutionException {

        if (isConnected()) {
            return true;
        }
        try {
            ctx = connectRemoteServer();
        } catch (Exception e) {
            LOG.error("Failed connecting to:" + getRemoteServer() + " : " + e.getMessage(), e);
            // ctx still refers to the previous connection (if any); make sure it is closed.
            try {
                ChannelHandlerContext previous = ctx;
                if (previous != null && previous.channel().isOpen()) {
                    previous.close().sync();
                }
            } catch (InterruptedException e2) {
                // Keep the interrupt visible to the caller instead of swallowing it.
                Thread.currentThread().interrupt();
                LOG.error("Failed closing ctx, " + e2.getMessage());
            } catch (Exception e2) {
                LOG.error("Failed closing ctx, " + e2.getMessage());
            }
            throw e;
        }

        BlockingRpcChannel c = createBlockingRpcChannel();
        stub = RaftService.newBlockingStub(c);

        connected = true;
        return connected;
    }

    /**
     * @return the server this client was created for
     */
    public ServerInfo getRemoteServer() {
        return remoteServer;
    }

    /**
     * Closes the connection, if any, and releases the event loop group that served it. Failures are
     * logged rather than thrown; the client is always marked as disconnected afterwards.
     */
    public synchronized void close() {
        LOG.info("Closing connection");
        try {
            ChannelHandlerContext current = ctx;
            if (current != null && current.channel().isOpen()) {
                current.close().sync();
            }
        } catch (InterruptedException e) {
            // Restore the interrupt flag so callers can still observe the interruption.
            Thread.currentThread().interrupt();
            LOG.error("Closing failed", e);
        } catch (Exception e) {
            LOG.error("Closing failed", e);
        } finally {
            connected = false;
            releaseWorkerGroup();
        }
    }

    /**
     * Returns the blocking stub, connecting first when necessary.
     *
     * @return the stub, or null if {@link #connect()} reports failure
     * @throws Exception if connecting fails
     */
    public BlockingInterface getStub() throws Exception {
        if (!isConnected()) {
            if (!connect()) {
                return null;
            }
        }
        return stub;
    }

    /**
     * Opens a new connection to the remote server on a fresh event loop group.
     * <p>
     * The group is shut down again if the attempt fails, so failed or repeated connection attempts do
     * not accumulate event loop threads. On success the group of any previous connection is released
     * and the new one is remembered so that {@link #close()} can shut it down. Must be called while
     * holding the instance lock.
     *
     * @return the channel handler context reported by the client channel handler
     * @throws InterruptedException if interrupted while waiting for the connect to complete
     */
    private ChannelHandlerContext connectRemoteServer() throws InterruptedException {
        EventLoopGroup group = new NioEventLoopGroup();
        boolean established = false;

        try {
            ClientChannelHandler channelHandler = new ClientChannelHandler(listener);
            Bootstrap b = new Bootstrap();
            b.group(group);
            b.channel(NioSocketChannel.class);
            b.option(ChannelOption.SO_KEEPALIVE, true);
            b.handler(channelHandler);

            b.connect(getRemoteServer().getHost(), getRemoteServer().getPort()).sync();
            LOG.debug("connected to: " + this.getRemoteServer());
            ChannelHandlerContext newCtx = channelHandler.getCtx();

            // The new connection is up: the previous group (if any) is no longer needed.
            releaseWorkerGroup();
            workerGroup = group;
            established = true;
            return newCtx;
        } finally {
            if (!established) {
                group.shutdownGracefully();
            }
        }
    }

    /**
     * Starts an asynchronous shutdown of the tracked event loop group, if there is one. Must be called
     * while holding the instance lock.
     */
    private void releaseWorkerGroup() {
        if (workerGroup != null) {
            workerGroup.shutdownGracefully();
            workerGroup = null;
        }
    }

    /**
     * @return the next call id, obtained by incrementing the shared counter
     */
    public static int generateCallId() {
        return client_call_id.incrementAndGet();
    }

    /**
     * @return the most recently generated call id
     */
    public static int getCallId() {
        return client_call_id.get();
    }

    private BlockingRpcChannel createBlockingRpcChannel() {
        return new BlockingRpcChannelImplementation();
    }

    /**
     * Channel used by the generated blocking stub: wraps each call into an {@link RpcCall}, writes it to
     * the channel and waits for the response with the same call id.
     */
    class BlockingRpcChannelImplementation implements BlockingRpcChannel {
        /**
         * Sends one request and waits for its response.
         *
         * @param md descriptor of the method being called; its name is put into the request header
         * @param controller not used by this implementation
         * @param request the request message
         * @param returnType not used by this implementation
         * @return the response message, or null if no response message was obtained
         * @throws ServiceException if the call times out or the client has no channel context
         */
        @Override
        public Message callBlockingMethod(MethodDescriptor md, RpcController controller, Message request,
                Message returnType) throws ServiceException {

            Message response = null;
            int callId = generateCallId();
            try {
                ChannelHandlerContext current = ctx;
                if (current == null) {
                    // Report a missing connection through the declared exception type rather than an NPE.
                    throw new ServiceException("Not connected to remote server: " + getRemoteServer());
                }

                RequestHeader.Builder builder = RequestHeader.newBuilder();
                builder.setId(callId);
                builder.setRequestName(md.getName());
                RequestHeader header = builder.build();

                LOG.debug("SENDING RPC, CALLID:" + header.getId());
                RpcCall call = new RpcCall(callId, header, request, md);
                long tm = System.currentTimeMillis();

                current.writeAndFlush(call);

                // The response is put into responsesMap by RpcClientEventListenerImpl.onRpcResponse.
                RpcCall result = responsesMap.take(callId, rpcTimeout);

                response = result != null ? result.getMessage() : null;
                if (response != null) {
                    LOG.debug("response taken: " + callId);
                    LOG.debug(String.format("RPC[%d] round trip takes %d ms", header.getId(),
                            (System.currentTimeMillis() - tm)));
                }
            } catch (RpcTimeoutException e) {
                LOG.error("Rpc Timeout, call ID:" + callId + ", remote server:" + getRemoteServer());
                LOG.error("Rpc Timeout, call:" + request);
                LOG.error("Rpc Timeout", e);
                throw new ServiceException(e.getMessage(), e);
            } catch (Exception e) {
                LOG.error("ctx:" + ctx);
                LOG.error("callBlockingMethod exception", e);
                throw e;
            }
            return response;
        }
    }

    /**
     * Receives events from the client channel handler.
     */
    class RpcClientEventListenerImpl implements RpcClientEventListener {
        /** Closes the current channel context, if any, and marks the client as disconnected. */
        @Override
        public void channelClosed() {
            // The event may arrive before connect() has stored a context.
            ChannelHandlerContext current = ctx;
            if (current != null) {
                current.close();
            }
            connected = false;
        }

        /** Publishes a received response so the caller waiting on its call id can take it. */
        @Override
        public void onRpcResponse(RpcCall call) {
            Preconditions.checkNotNull(call);
            responsesMap.put(call.getCallId(), call);
        }
    }

    /*
     * For testing purpose
     * @param args <server host> <server port> <clients number> <threads number> [packetsize]
     * @throws Exception
     */
    public static void main(String[] args) throws Exception {
        // The first four arguments are mandatory; only the packet size is optional.
        if (args.length < 4) {
            System.out.println(
                    "usage: RpcClient <server host> <server port> <clients number> <threads number> <packetsize>");
            return;
        }
        String host = args[0];
        int port = Integer.parseInt(args[1]);
        int nclients = Integer.parseInt(args[2]);
        int nThreads = Integer.parseInt(args[3]);
        int nPacketSize = 1024;

        if (args.length >= 5) {
            nPacketSize = Integer.parseInt(args[4]);
        }

        for (int j = 0; j < nclients; j++) {
            RpcClient client = new RpcClient(CmRaftConfiguration.create(), new ServerInfo(host, port));

            for (int i = 0; i < nThreads; i++) {
                new Thread(new TestRpcWorker(client, nPacketSize)).start();
            }
        }
    }

    /**
     * Test worker that runs {@link RpcClient#sendRequest(int)} on its own thread.
     */
    static class TestRpcWorker implements Runnable {
        private final RpcClient client;
        private final int packetSize;

        public TestRpcWorker(RpcClient client, int size) {
            this.client = client;
            this.packetSize = size;
        }

        @Override
        public void run() {
            client.sendRequest(packetSize);
        }
    }

    /**
     * Issues one testRpc call carrying a zero-filled payload of the given size.
     *
     * @param packetSize payload size in bytes
     * @throws IllegalStateException if the client has never been connected
     * @throws Exception if the RPC fails
     */
    public void testRpc(int packetSize) throws Exception {
        BlockingInterface current = stub;
        if (current == null) {
            throw new IllegalStateException("RpcClient is not connected, call connect() first");
        }
        TestRpcRequest.Builder builder = TestRpcRequest.newBuilder();
        byte[] bytes = new byte[packetSize];
        builder.setData(ByteString.copyFrom(bytes));

        current.testRpc(null, builder.build());
    }

    /*
     * For testing purpose: connects if necessary, then issues test RPCs in a loop and logs
     * throughput every 1000 calls. The loop stops at the first failing call.
     */
    public void sendRequest(int packetSize) {

        if (!this.isConnected()) {
            try {
                if (!connect()) {
                    LOG.error("INIT error");
                    return;
                }
            } catch (Exception e) {
                LOG.error("RpcClient init exception", e);
                return;
            }
        }

        LOG.info("client thread started");
        long starttime = System.currentTimeMillis();
        try {
            for (int i = 0; i < TEST_RPC_ITERATIONS; i++) {
                testRpc(packetSize);
                if (i != 0 && i % 1000 == 0) {
                    long ms = System.currentTimeMillis() - starttime;
                    LOG.debug("RPC CALL[ " + i + "] round trip time: " + ms);

                    long curtm = System.currentTimeMillis();
                    long elipsetm = (curtm - starttime) / 1000;
                    // Avoid dividing by zero during the first second.
                    if (elipsetm == 0) {
                        elipsetm = 1;
                    }
                    long tps = i / elipsetm;

                    LOG.info("response id: " + i + " time: " + elipsetm + " TPS: " + tps);
                }
            }
        } catch (Exception e) {
            LOG.error("RPC test loop aborted", e);
        }
    }
}