io.airlift.drift.transport.netty.HeaderMessageEncoding.java Source code

Java tutorial

Introduction

Here is the source code for io.airlift.drift.transport.netty.HeaderMessageEncoding.java

Source

/*
 * Copyright (C) 2013 Facebook, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may
 * not use this file except in compliance with the License. You may obtain
 * a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.airlift.drift.transport.netty;

import com.google.common.collect.ImmutableMap;
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
import io.airlift.drift.TApplicationException;
import io.airlift.drift.protocol.TBinaryProtocol;
import io.airlift.drift.protocol.TCompactProtocol;
import io.airlift.drift.protocol.TFacebookCompactProtocol;
import io.airlift.drift.protocol.TProtocolFactory;
import io.airlift.drift.protocol.TTransportException;
import io.airlift.drift.transport.MethodMetadata;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.Unpooled;

import javax.annotation.concurrent.ThreadSafe;

import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.OptionalInt;

import static com.google.common.base.Verify.verify;
import static io.airlift.drift.TApplicationException.Type.BAD_SEQUENCE_ID;
import static java.nio.charset.StandardCharsets.UTF_8;

/**
 * {@link MessageEncoding} that wraps each Thrift message in a header frame.
 *
 * <p>A frame written by this class is laid out as:
 * <pre>
 *   2 bytes  magic ({@code 0x0FFF})
 *   2 bytes  flags
 *   4 bytes  sequence id
 *   2 bytes  header size, in 4-byte words (unsigned)
 *   N bytes  header: protocol id, transform count, transforms, key/value headers, zero padding
 *   ...      the encoded Thrift message
 * </pre>
 *
 * <p>Only the binary and compact protocols are supported, and gzip transforms are rejected.
 */
@ThreadSafe
class HeaderMessageEncoding implements MessageEncoding {
    private static final int HEADER_MAGIC = 0x0FFF;

    // Byte offset of the sequence id within a frame: it follows the 2-byte magic and 2-byte flags
    private static final int HEADER_SEQUENCE_ID_OFFSET = 4;

    // The header size field is an unsigned 16-bit count of 4-byte words
    private static final int MAX_HEADER_WORDS = 0xFFFF;

    // A 32-bit varint occupies at most 5 bytes, so the largest valid shift is 28
    private static final int MAX_VARINT_SHIFT = 28;

    private static final int PROTOCOL_BINARY = 0;
    private static final int PROTOCOL_COMPACT = 2;

    private static final int FLAG_SUPPORT_OUT_OF_ORDER = 1;

    // Type tags that introduce a block of key/value headers inside the frame header.
    // Persistent headers use info id 2 in the THeader format; it must differ from NORMAL_HEADERS.
    private static final int NORMAL_HEADERS = 1;
    private static final int PERSISTENT_HEADERS = 2;

    private final TProtocolFactory protocolFactory;
    private final int protocolId;

    // GZip compression is currently not supported since this is rarely desired for an RPC system
    @SuppressFBWarnings("SS_SHOULD_BE_STATIC")
    private final boolean gzip = false;

    /**
     * Creates an encoding for the given protocol.
     *
     * @param protocolFactory a binary, compact, or Facebook compact protocol factory
     * @throws IllegalArgumentException if the factory is of any other type
     */
    public HeaderMessageEncoding(TProtocolFactory protocolFactory) {
        if (protocolFactory instanceof TBinaryProtocol.Factory) {
            protocolId = PROTOCOL_BINARY;
        } else if (protocolFactory instanceof TCompactProtocol.Factory
                || protocolFactory instanceof TFacebookCompactProtocol.Factory) {
            protocolId = PROTOCOL_COMPACT;
        } else {
            throw new IllegalArgumentException("Unknown protocol: " + protocolFactory.getClass().getName());
        }
        this.protocolFactory = protocolFactory;
    }

    /**
     * Encodes a request as a single header frame.
     *
     * @param allocator allocator handed to {@code MessageEncoding.encodeRequest} for the message body
     * @param sequenceId sequence id written into the frame header
     * @param method the method being invoked
     * @param parameters the method arguments
     * @param headers key/value headers to send with the request; may be empty
     * @return a buffer composed of the frame header, encoding info, headers, padding, and message
     * @throws IllegalArgumentException if the encoded headers do not fit the 16-bit header size field
     */
    @Override
    public ByteBuf writeRequest(ByteBufAllocator allocator, int sequenceId, MethodMetadata method,
            List<Object> parameters, Map<String, String> headers) throws Exception {
        // describe the encoding (Thrift protocol, compression info)
        ByteBuf encodingInfo = Unpooled.buffer(3);
        encodingInfo.writeByte(protocolId);
        // number of "transforms"
        encodingInfo.writeByte(gzip ? 1 : 0);
        if (gzip) {
            // The transform id is a varint, but the gzip id (0x01) fits in a single byte
            encodingInfo.writeByte(0x01);
        }

        // headers
        ByteBuf encodedHeaders = encodeHeaders(headers);

        // Padding - header size must be a multiple of 4
        int headerSize = encodingInfo.readableBytes() + encodedHeaders.readableBytes();
        ByteBuf padding = getPadding(headerSize);
        headerSize += padding.readableBytes();

        // Reject oversized headers instead of silently truncating the size field, which would
        // produce a corrupt frame. This runs before the message body is encoded so that no
        // allocator-backed buffer is left unreleased when the check fails.
        int headerWords = headerSize >> 2;
        if (headerWords > MAX_HEADER_WORDS) {
            throw new IllegalArgumentException("Request headers are too large: " + headerSize + " bytes");
        }

        // frame header (magic, flags, sequenceId, headerSize)
        ByteBuf frameHeader = Unpooled.buffer(12);
        frameHeader.writeShort(HEADER_MAGIC);
        frameHeader.writeShort(FLAG_SUPPORT_OUT_OF_ORDER);
        frameHeader.writeInt(sequenceId);
        frameHeader.writeShort(headerWords);

        ByteBuf message = MessageEncoding.encodeRequest(protocolFactory, allocator, sequenceId, method, parameters);

        return Unpooled.wrappedBuffer(frameHeader, encodingInfo, encodedHeaders, padding, message);
    }

    /**
     * Returns the zero bytes that round {@code headerSize} up to a multiple of 4.
     *
     * <p>Between 1 and 4 bytes are always produced: a header that is already aligned
     * receives a full word of padding.
     */
    private static ByteBuf getPadding(int headerSize) {
        int paddingSize = 4 - headerSize % 4;
        ByteBuf padding = Unpooled.buffer(paddingSize);
        padding.writeZero(paddingSize);
        return padding;
    }

    /**
     * Encodes the headers as one block: the type tag, a varint entry count, and then a
     * length-prefixed UTF-8 key and value for each entry. An empty map encodes to no bytes.
     */
    private static ByteBuf encodeHeaders(Map<String, String> headers) {
        if (headers.isEmpty()) {
            return Unpooled.EMPTY_BUFFER;
        }

        // 1 byte for header type, 5 for header count vint, and 5 for each header key and value length vint
        int estimatedSize = 1 + 5 + (headers.size() * 10);
        for (Entry<String, String> entry : headers.entrySet()) {
            // assume the key and value are ASCII; the buffer grows if they are not
            estimatedSize += entry.getKey().length() + entry.getValue().length();
        }

        ByteBuf headersBuffer = Unpooled.buffer(estimatedSize);
        // non persistent header
        headersBuffer.writeByte(NORMAL_HEADERS);
        writeVint(headersBuffer, headers.size());
        for (Entry<String, String> entry : headers.entrySet()) {
            writeString(headersBuffer, entry.getKey());
            writeString(headersBuffer, entry.getValue());
        }
        return headersBuffer;
    }

    /** Writes the UTF-8 bytes of {@code value}, prefixed with their length as a varint. */
    private static void writeString(ByteBuf out, String value) {
        byte[] bytes = value.getBytes(UTF_8);
        writeVint(out, bytes.length);
        out.writeBytes(bytes);
    }

    /** Writes {@code n} as a varint: 7 bits per byte, low bits first, high bit set on all but the last byte. */
    private static void writeVint(ByteBuf out, int n) {
        while (true) {
            if ((n & ~0x7F) == 0) {
                out.writeByte(n);
                return;
            }

            out.writeByte(n | 0x80);
            n >>>= 7;
        }
    }

    /**
     * Peeks at the sequence id of a response frame without moving the reader index.
     *
     * @return the sequence id, or empty if the buffer is too short to contain one
     */
    @Override
    public OptionalInt extractResponseSequenceId(ByteBuf buffer) {
        if (buffer.readableBytes() < HEADER_SEQUENCE_ID_OFFSET + Integer.BYTES) {
            return OptionalInt.empty();
        }
        return OptionalInt.of(buffer.getInt(buffer.readerIndex() + HEADER_SEQUENCE_ID_OFFSET));
    }

    /**
     * Decodes a response frame and returns the result of the contained message.
     *
     * @param buffer the response frame, positioned at the frame magic
     * @param sequenceId the sequence id of the request this response must answer
     * @param method the method that was invoked
     * @return whatever {@code MessageEncoding.decodeResponse} produces for the message body
     * @throws TApplicationException if the frame carries a different sequence id
     * @throws TTransportException if the response declares a gzip transform
     */
    @Override
    @SuppressFBWarnings("DLS_DEAD_LOCAL_STORE")
    public Object readResponse(ByteBuf buffer, int sequenceId, MethodMetadata method) throws Exception {
        short magic = buffer.readShort();
        verify(magic == HEADER_MAGIC, "Unexpected response header magic");
        short flags = buffer.readShort();
        verify(flags == FLAG_SUPPORT_OUT_OF_ORDER, "Unexpected response header flags");

        int frameSequenceId = buffer.readInt();
        if (frameSequenceId != sequenceId) {
            throw new TApplicationException(BAD_SEQUENCE_ID,
                    method.getName() + " failed: out of sequence response");
        }
        // The size is an unsigned word count; a signed read would turn large headers negative
        int headerSize = buffer.readUnsignedShort() << 2;

        // A slice shares the frame's memory, so there is no separate buffer that must be released
        ByteBuf messageHeader = buffer.readSlice(headerSize);
        int protocolId = messageHeader.readUnsignedByte();
        verify(protocolId == this.protocolId, "response protocol is different than request protocol");

        int numberOfTransforms = messageHeader.readUnsignedByte();
        verify(numberOfTransforms < 128, "Too many transforms for response");

        boolean gzipCompressed = false;
        for (int i = 0; i < numberOfTransforms; i++) {
            int transform = messageHeader.readUnsignedByte();
            verify(transform == 0x01, "Unsupported response transform");
            gzipCompressed = true;
        }

        // Currently we ignore response headers from the server because there is no API to fetch the headers
        Map<String, String> normalHeaders = decodeHeaders(NORMAL_HEADERS, messageHeader);
        Map<String, String> persistentHeaders = decodeHeaders(PERSISTENT_HEADERS, messageHeader);

        // Fail before copying the message body so that no buffer is allocated just to be dropped
        if (gzipCompressed) {
            // todo decompress
            throw new TTransportException("gzip compression not implemented");
        }

        // NOTE(review): readBytes allocates a new buffer; confirm that
        // MessageEncoding.decodeResponse takes ownership of it and releases it.
        ByteBuf message = buffer.readBytes(buffer.readableBytes());
        return MessageEncoding.decodeResponse(protocolFactory, message, sequenceId, method);
    }

    /**
     * Decodes one block of key/value headers if the next byte is {@code expectedHeadersType}.
     *
     * <p>The type byte is only peeked at, so when it does not match nothing is consumed and the
     * buffer is left positioned for the next attempt.
     *
     * @return the decoded headers, or an empty map if the buffer is exhausted or the type differs
     */
    private static Map<String, String> decodeHeaders(int expectedHeadersType, ByteBuf messageHeader) {
        if (messageHeader.readableBytes() == 0) {
            return ImmutableMap.of();
        }

        byte headersType = messageHeader.getByte(messageHeader.readerIndex());
        if (headersType != expectedHeadersType) {
            return ImmutableMap.of();
        }
        messageHeader.skipBytes(1);

        ImmutableMap.Builder<String, String> headers = ImmutableMap.builder();
        int headerCount = readVarint(messageHeader);
        for (int i = 0; i < headerCount; i++) {
            String key = readString(messageHeader);
            String value = readString(messageHeader);
            headers.put(key, value);
        }
        return headers.build();
    }

    /** Reads a varint length followed by that many bytes, decoded as UTF-8. */
    private static String readString(ByteBuf messageHeader) {
        int length = readVarint(messageHeader);
        // readSlice performs the same bounds checks as readBytes without allocating a buffer
        // that would have to be released
        return messageHeader.readSlice(length).toString(UTF_8);
    }

    /**
     * Reads a varint as written by {@link #writeVint}.
     *
     * <p>Fails verification if the varint continues past 5 bytes, the most a 32-bit value can occupy.
     */
    private static int readVarint(ByteBuf messageHeader) {
        int result = 0;
        int shift = 0;

        while (true) {
            byte b = messageHeader.readByte();
            result |= (b & 0x7f) << shift;
            if ((b & 0x80) != 0x80) {
                break;
            }
            shift += 7;
            // Java wraps int shift distances at 32, so a longer varint would corrupt the low bits
            verify(shift <= MAX_VARINT_SHIFT, "Varint in response header is too long");
        }

        return result;
    }
}