com.chicm.cmraft.rpc.PacketUtils.java Source code

Java tutorial

Introduction

Here is the source code for com.chicm.cmraft.rpc.PacketUtils.java

Source

/**
* Copyright 2014 The CmRaft Project
*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements.  See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership.  The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License.  You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/

package com.chicm.cmraft.rpc;

import java.io.ByteArrayOutputStream;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.nio.channels.AsynchronousSocketChannel;
import java.nio.channels.Channels;
import java.nio.channels.SocketChannel;
import java.util.concurrent.ExecutionException;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import com.chicm.cmraft.protobuf.generated.RaftProtos.RequestHeader;
import com.chicm.cmraft.protobuf.generated.RaftProtos.ResponseHeader;
import com.google.protobuf.BlockingService;
import com.google.protobuf.CodedInputStream;
import com.google.protobuf.CodedOutputStream;
import com.google.protobuf.Message;
import com.google.protobuf.Descriptors.MethodDescriptor;
import com.google.protobuf.Message.Builder;

/**
 * Static utility methods dealing with protobuf RPC packets. 
 * @author chicm
 *
 */
public class PacketUtils {
    static final Log LOG = LogFactory.getLog(PacketUtils.class);
    static final int DEFAULT_BYTEBUFFER_SIZE = 1000;
    static final int MESSAGE_LENGHT_FIELD_SIZE = 4;
    static final int DEFAULT_CHANNEL_READ_RETRIES = 5;

    public static byte[] int2Bytes(int val) {
        byte[] b = new byte[4];
        for (int i = 3; i > 0; i--) {
            b[i] = (byte) val;
            val >>>= 8;
        }
        b[0] = (byte) val;
        return b;
    }

    public static int bytes2Int(byte[] bytes) {
        int n = 0;
        for (int i = 0; i < 4; i++) {
            n <<= 8;
            n ^= bytes[i] & 0xFF;
        }
        return n;
    }

    public static void writeIntToStream(int n, OutputStream os) throws IOException {
        byte[] b = int2Bytes(n);
        os.write(b);
    }

    public static int getTotalSizeofMessages(Message... messages) {
        int totalSize = 0;
        for (Message m : messages) {
            if (m == null)
                continue;
            totalSize += m.getSerializedSize();
            totalSize += CodedOutputStream.computeRawVarint32Size(m.getSerializedSize());
        }
        return totalSize;
    }

    public static int writeRpc(AsynchronousSocketChannel channel, Message header, Message body)
            throws IOException, InterruptedException, ExecutionException {
        int totalSize = getTotalSizeofMessages(header, body);
        return writeRpc(channel, header, body, totalSize);
    }

    private static int writeRpc(AsynchronousSocketChannel channel, Message header, Message body, int totalSize)
            throws IOException, InterruptedException, ExecutionException {
        // writing total size so that server can read all request data in one read
        //LOG.debug("total size:" + totalSize);
        long t = System.currentTimeMillis();

        ByteArrayOutputStream bos = new ByteArrayOutputStream();
        writeIntToStream(totalSize, bos);

        header.writeDelimitedTo(bos);
        if (body != null)
            body.writeDelimitedTo(bos);

        bos.flush();
        byte[] b = bos.toByteArray();
        ByteBuffer buf = ByteBuffer.allocateDirect(totalSize + 4);
        buf.put(b);

        buf.flip();
        channel.write(buf).get();

        if (LOG.isTraceEnabled()) {
            LOG.trace("Write Rpc message to socket, takes " + (System.currentTimeMillis() - t) + " ms, size "
                    + totalSize);
            LOG.trace("message:" + body);
        }
        return totalSize;
    }

    private static int writeRpc_backup(SocketChannel channel, Message header, Message body, int totalSize)
            throws IOException {
        // writing total size so that server can read all request data in one read
        LOG.debug("total size:" + totalSize);
        long t = System.currentTimeMillis();
        OutputStream os = Channels.newOutputStream(channel);
        writeIntToStream(totalSize, os);
        header.writeDelimitedTo(os);
        if (body != null)
            body.writeDelimitedTo(os);
        os.flush();
        LOG.debug("" + (System.currentTimeMillis() - t) + " ms");
        LOG.debug("flushed:" + totalSize);
        return totalSize;
    }

    public static RpcCall parseRpcRequestFromChannel(AsynchronousSocketChannel channel, BlockingService service)
            throws InterruptedException, ExecutionException, IOException {
        RpcCall call = null;
        long t = System.currentTimeMillis();
        InputStream in = Channels.newInputStream(channel);
        byte[] datasize = new byte[MESSAGE_LENGHT_FIELD_SIZE];
        in.read(datasize);
        int nDataSize = bytes2Int(datasize);

        int len = 0;
        ByteBuffer buf = ByteBuffer.allocateDirect(nDataSize);
        for (; len < nDataSize;) {
            len += channel.read(buf).get();
        }
        if (len < nDataSize) {
            LOG.error("SOCKET READ FAILED, len:" + len);
            return call;
        }
        byte[] data = new byte[nDataSize];
        buf.flip();
        buf.get(data);
        int offset = 0;
        CodedInputStream cis = CodedInputStream.newInstance(data, offset, nDataSize - offset);
        int headerSize = cis.readRawVarint32();
        offset += cis.getTotalBytesRead();
        RequestHeader header = RequestHeader.newBuilder().mergeFrom(data, offset, headerSize).build();

        offset += headerSize;
        cis.skipRawBytes(headerSize);
        cis.resetSizeCounter();
        int bodySize = cis.readRawVarint32();
        offset += cis.getTotalBytesRead();
        //LOG.debug("header parsed:" + header.toString());

        MethodDescriptor md = service.getDescriptorForType().findMethodByName(header.getRequestName());
        Builder builder = service.getRequestPrototype(md).newBuilderForType();
        Message body = null;
        if (builder != null) {
            body = builder.mergeFrom(data, offset, bodySize).build();
            //LOG.debug("server : request parsed:" + body.toString());
        }
        call = new RpcCall(header.getId(), header, body, md);
        if (LOG.isTraceEnabled()) {
            LOG.trace("Parse Rpc request from socket: " + call.getCallId() + ", takes"
                    + (System.currentTimeMillis() - t) + " ms");
        }

        return call;
    }

    public static RpcCall parseRpcResponseFromChannel(AsynchronousSocketChannel channel, BlockingService service)
            throws InterruptedException, ExecutionException, IOException {
        RpcCall call = null;
        long t = System.currentTimeMillis();
        InputStream in = Channels.newInputStream(channel);
        byte[] datasize = new byte[MESSAGE_LENGHT_FIELD_SIZE];
        in.read(datasize);
        int nDataSize = bytes2Int(datasize);

        LOG.debug("message size: " + nDataSize);

        int len = 0;
        ByteBuffer buf = ByteBuffer.allocateDirect(nDataSize);
        for (; len < nDataSize;) {
            len += channel.read(buf).get();
        }
        if (len < nDataSize) {
            LOG.error("SOCKET READ FAILED, len:" + len);
            return call;
        }
        byte[] data = new byte[nDataSize];
        buf.flip();
        buf.get(data);
        int offset = 0;
        CodedInputStream cis = CodedInputStream.newInstance(data, offset, nDataSize - offset);
        int headerSize = cis.readRawVarint32();
        offset += cis.getTotalBytesRead();
        ResponseHeader header = ResponseHeader.newBuilder().mergeFrom(data, offset, headerSize).build();

        offset += headerSize;
        cis.skipRawBytes(headerSize);
        cis.resetSizeCounter();
        int bodySize = cis.readRawVarint32();
        offset += cis.getTotalBytesRead();

        MethodDescriptor md = service.getDescriptorForType().findMethodByName(header.getResponseName());
        Builder builder = service.getResponsePrototype(md).newBuilderForType();
        Message body = null;
        if (builder != null) {
            body = builder.mergeFrom(data, offset, bodySize).build();
        }
        call = new RpcCall(header.getId(), header, body, md);
        if (LOG.isTraceEnabled()) {
            LOG.trace("Parse Rpc response from socket: " + call.getCallId() + ", takes"
                    + (System.currentTimeMillis() - t) + " ms");
        }

        return call;
    }

}