org.apache.drill.exec.rpc.RpcEncoder.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.drill.exec.rpc.RpcEncoder.java

Source

/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.drill.exec.rpc;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufOutputStream;
import io.netty.buffer.CompositeByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.MessageToMessageEncoder;

import java.io.OutputStream;
import java.util.List;

import org.apache.drill.exec.proto.GeneralRPCProtos.CompleteRpcMessage;
import org.apache.drill.exec.proto.GeneralRPCProtos.RpcHeader;

import com.google.protobuf.CodedOutputStream;
import com.google.protobuf.WireFormat;

/**
 * Converts an RPCMessage into wire format.
 */
class RpcEncoder extends MessageToMessageEncoder<OutboundRpcMessage> {
    final org.slf4j.Logger logger;

    static final int HEADER_TAG = makeTag(CompleteRpcMessage.HEADER_FIELD_NUMBER,
            WireFormat.WIRETYPE_LENGTH_DELIMITED);
    static final int PROTOBUF_BODY_TAG = makeTag(CompleteRpcMessage.PROTOBUF_BODY_FIELD_NUMBER,
            WireFormat.WIRETYPE_LENGTH_DELIMITED);
    static final int RAW_BODY_TAG = makeTag(CompleteRpcMessage.RAW_BODY_FIELD_NUMBER,
            WireFormat.WIRETYPE_LENGTH_DELIMITED);
    static final int HEADER_TAG_LENGTH = getRawVarintSize(HEADER_TAG);
    static final int PROTOBUF_BODY_TAG_LENGTH = getRawVarintSize(PROTOBUF_BODY_TAG);
    static final int RAW_BODY_TAG_LENGTH = getRawVarintSize(RAW_BODY_TAG);

    public RpcEncoder(String name) {
        this.logger = org.slf4j.LoggerFactory.getLogger(RpcEncoder.class.getCanonicalName() + "-" + name);
    }

    @Override
    protected void encode(ChannelHandlerContext ctx, OutboundRpcMessage msg, List<Object> out) throws Exception {
        if (RpcConstants.EXTRA_DEBUGGING) {
            logger.debug("Rpc Encoder called with msg {}", msg);
        }

        if (!ctx.channel().isOpen()) {
            //output.add(ctx.alloc().buffer(0));
            logger.debug("Channel closed, skipping encode.");
            msg.release();
            return;
        }

        try {
            if (RpcConstants.EXTRA_DEBUGGING) {
                logger.debug("Encoding outbound message {}", msg);
            }
            // first we build the RpcHeader
            RpcHeader header = RpcHeader.newBuilder() //
                    .setMode(msg.mode) //
                    .setCoordinationId(msg.coordinationId) //
                    .setRpcType(msg.rpcType).build();

            // figure out the full length
            int headerLength = header.getSerializedSize();
            int protoBodyLength = msg.pBody.getSerializedSize();
            int rawBodyLength = msg.getRawBodySize();
            int fullLength = //
                    HEADER_TAG_LENGTH + getRawVarintSize(headerLength) + headerLength + //
                            PROTOBUF_BODY_TAG_LENGTH + getRawVarintSize(protoBodyLength) + protoBodyLength; //

            if (rawBodyLength > 0) {
                fullLength += (RAW_BODY_TAG_LENGTH + getRawVarintSize(rawBodyLength) + rawBodyLength);
            }

            ByteBuf buf = ctx.alloc().buffer();
            OutputStream os = new ByteBufOutputStream(buf);
            CodedOutputStream cos = CodedOutputStream.newInstance(os);

            // write full length first (this is length delimited stream).
            cos.writeRawVarint32(fullLength);

            // write header
            cos.writeRawVarint32(HEADER_TAG);
            cos.writeRawVarint32(headerLength);
            header.writeTo(cos);

            // write protobuf body length and body
            cos.writeRawVarint32(PROTOBUF_BODY_TAG);
            cos.writeRawVarint32(protoBodyLength);
            msg.pBody.writeTo(cos);

            // if exists, write data body and tag.
            if (msg.getRawBodySize() > 0) {
                if (RpcConstants.EXTRA_DEBUGGING) {
                    logger.debug("Writing raw body of size {}", msg.getRawBodySize());
                }

                cos.writeRawVarint32(RAW_BODY_TAG);
                cos.writeRawVarint32(rawBodyLength);
                cos.flush(); // need to flush so that dbody goes after if cos is caching.

                CompositeByteBuf cbb = new CompositeByteBuf(buf.alloc(), true, msg.dBodies.length + 1);
                cbb.addComponent(buf);
                int bufLength = buf.readableBytes();
                for (ByteBuf b : msg.dBodies) {
                    cbb.addComponent(b);
                    bufLength += b.readableBytes();
                }
                cbb.writerIndex(bufLength);
                out.add(cbb);
            } else {
                cos.flush();
                out.add(buf);
            }

            if (RpcConstants.SOME_DEBUGGING) {
                logger.debug("Wrote message length {}:{} bytes (head:body).  Message: " + msg,
                        getRawVarintSize(fullLength), fullLength);
            }
            if (RpcConstants.EXTRA_DEBUGGING) {
                logger.debug("Sent message.  Ending writer index was {}.", buf.writerIndex());
            }
        } finally {
            // make sure to release Rpc Messages underlying byte buffers.
            //msg.release();
        }
    }

    /** Makes a tag value given a field number and wire type, copied from WireFormat since it isn't public.  */
    static int makeTag(final int fieldNumber, final int wireType) {
        return (fieldNumber << 3) | wireType;
    }

    public static int getRawVarintSize(int value) {
        int count = 0;
        while (true) {
            if ((value & ~0x7F) == 0) {
                count++;
                return count;
            } else {
                count++;
                value >>>= 7;
            }
        }
    }

}