alluxio.grpc.GrpcSerializationUtils.java Source code

Java tutorial

Introduction

Here is the source code for alluxio.grpc.GrpcSerializationUtils.java

Source

/*
 * The Alluxio Open Foundation licenses this work under the Apache License, version 2.0
 * (the "License"). You may not use this work except in compliance with the License, which is
 * available at www.apache.org/licenses/LICENSE-2.0
 *
 * This software is distributed on an "AS IS" basis, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
 * either express or implied, as more fully set forth in the License.
 *
 * See the NOTICE file distributed with this work for information regarding copyright ownership.
 */

package alluxio.grpc;

import io.grpc.CallOptions;
import io.grpc.MethodDescriptor;
import io.grpc.ServerMethodDefinition;
import io.grpc.ServerServiceDefinition;
import io.grpc.ServiceDescriptor;
import io.grpc.internal.CompositeReadableBuffer;
import io.grpc.internal.ReadableBuffer;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.CompositeByteBuf;
import io.netty.buffer.PooledByteBufAllocator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.InputStream;
import java.io.OutputStream;
import java.lang.reflect.Constructor;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Queue;

/**
 * Utilities for gRPC message serialization.
 */
public class GrpcSerializationUtils {
    /**
     * Call option key created with the debug name "overridden method descriptor". How the stored
     * {@link MethodDescriptor} is consumed is decided by code outside this class.
     */
    public static final CallOptions.Key<MethodDescriptor> OVERRIDDEN_METHOD_DESCRIPTOR = CallOptions.Key
            .create("overridden method descriptor");

    private static final Logger LOG = LoggerFactory.getLogger(GrpcSerializationUtils.class);

    /** Number of low-order bits of a tag that hold the wire type. */
    private static final int TAG_TYPE_BITS = 3;

    // Names of the non-public gRPC classes and fields that are reached through reflection.
    private static final String BUFFER_INPUT_STREAM_CLASS_NAME = "io.grpc.internal.ReadableBuffers$BufferInputStream";
    private static final String BUFFER_FIELD_NAME = "buffer";
    private static final String BUFFERS_FIELD_NAME = "buffers";
    private static final String NETTY_WRITABLE_BUFFER_CLASS_NAME = "io.grpc.netty.NettyWritableBuffer";
    private static final String NETTY_READABLE_BUFFER_CLASS_NAME = "io.grpc.netty.NettyReadableBuffer";
    private static final String BUFFER_CHAIN_OUTPUT_STREAM_CLASS_NAME = "io.grpc.internal.MessageFramer$BufferChainOutputStream";
    private static final String BUFFER_LIST_FIELD_NAME = "bufferList";
    private static final String CURRENT_FIELD_NAME = "current";

    // Reflective handles used by the send path (addBuffersToStream).
    private static Constructor<?> sNettyWritableBufferConstructor;
    private static Field sBufferList;
    private static Field sCurrent;
    // Reflective handles used by the receive path (getBufferFromStream, getByteBufFromReadableBuffer).
    private static Field sCompositeBuffers = null;
    private static Field sReadableBufferField = null;
    private static Field sReadableByteBuf = null;
    // Each flag stays true only if every handle of the corresponding path was resolved.
    private static boolean sZeroCopySendSupported = true;
    private static boolean sZeroCopyReceiveSupported = true;

    static {
        // Receive path: every handle that the two "get buffer" methods dereference is resolved
        // together, so that sZeroCopyReceiveSupported == true guarantees none of them is null.
        try {
            sReadableBufferField = getPrivateField(BUFFER_INPUT_STREAM_CLASS_NAME, BUFFER_FIELD_NAME);
            sCompositeBuffers = getPrivateField(CompositeReadableBuffer.class.getName(), BUFFERS_FIELD_NAME);
            sReadableByteBuf = getPrivateField(NETTY_READABLE_BUFFER_CLASS_NAME, BUFFER_FIELD_NAME);
        } catch (Exception e) {
            LOG.warn("Cannot get gRPC input stream buffer, zero copy receive will be disabled.", e);
            sZeroCopyReceiveSupported = false;
        }
        // Send path: likewise, sZeroCopySendSupported == true guarantees that the constructor and
        // both output stream fields used by addBuffersToStream are available.
        try {
            sNettyWritableBufferConstructor = getPrivateConstructor(NETTY_WRITABLE_BUFFER_CLASS_NAME,
                    ByteBuf.class);
            sBufferList = getPrivateField(BUFFER_CHAIN_OUTPUT_STREAM_CLASS_NAME, BUFFER_LIST_FIELD_NAME);
            sCurrent = getPrivateField(BUFFER_CHAIN_OUTPUT_STREAM_CLASS_NAME, CURRENT_FIELD_NAME);
        } catch (Exception e) {
            LOG.warn("Cannot get gRPC output stream buffer, zero copy send will be disabled.", e);
            sZeroCopySendSupported = false;
        }
    }

    /**
     * Looks up a declared field by class and field name and makes it accessible.
     *
     * @param className fully qualified name of the declaring class
     * @param fieldName name of the field
     * @return the accessible field
     * @throws NoSuchFieldException if the class does not declare the field
     * @throws ClassNotFoundException if the class cannot be loaded
     */
    private static Field getPrivateField(String className, String fieldName)
            throws NoSuchFieldException, ClassNotFoundException {
        Class<?> declaringClass = Class.forName(className);
        Field field = declaringClass.getDeclaredField(fieldName);
        field.setAccessible(true);
        return field;
    }

    /**
     * Looks up a declared constructor by class name and parameter types and makes it accessible.
     *
     * @param className fully qualified name of the declaring class
     * @param parameterTypes parameter types of the constructor
     * @return the accessible constructor
     * @throws ClassNotFoundException if the class cannot be loaded
     * @throws NoSuchMethodException if the class does not declare a matching constructor
     */
    private static Constructor<?> getPrivateConstructor(String className, Class<?>... parameterTypes)
            throws ClassNotFoundException, NoSuchMethodException {
        Class<?> declaringClass = Class.forName(className);
        Constructor<?> constructor = declaringClass.getDeclaredConstructor(parameterTypes);
        constructor.setAccessible(true);
        return constructor;
    }

    /**
     * Makes a protobuf wire-format tag for a field.
     *
     * @param fieldNumber field number
     * @param wireType wire type of the field
     * @return the tag, i.e. the field number shifted left by three bits combined with the wire type
     */
    public static int makeTag(final int fieldNumber, final int wireType) {
        // This is a public version of WireFormat.makeTag.
        return (fieldNumber << TAG_TYPE_BITS) | wireType;
    }

    /**
     * Gets a buffer directly from a gRPC input stream.
     *
     * @param stream the input stream
     * @return the raw data buffer, or null if zero copy receive is unavailable, the stream is null
     *         or not of the expected gRPC stream class, or the reflective read fails
     */
    public static ReadableBuffer getBufferFromStream(InputStream stream) {
        if (!sZeroCopyReceiveSupported || stream == null
                || !stream.getClass().equals(sReadableBufferField.getDeclaringClass())) {
            return null;
        }
        try {
            return (ReadableBuffer) sReadableBufferField.get(stream);
        } catch (Exception e) {
            LOG.warn("Failed to get data buffer from stream.", e);
            return null;
        }
    }

    /**
     * Gets a Netty buffer directly from a gRPC ReadableBuffer.
     *
     * <p>For a {@link CompositeReadableBuffer} holding exactly one buffer, the ByteBuf of that
     * buffer is returned; otherwise a new {@link CompositeByteBuf} is assembled from the ByteBufs
     * of all contained buffers.
     *
     * @param buffer the input buffer
     * @return the raw ByteBuf, or null if the ByteBuf cannot be extracted
     */
    @SuppressWarnings("unchecked")
    public static ByteBuf getByteBufFromReadableBuffer(ReadableBuffer buffer) {
        if (!sZeroCopyReceiveSupported || buffer == null) {
            return null;
        }
        try {
            if (buffer instanceof CompositeReadableBuffer) {
                Queue<ReadableBuffer> buffers = (Queue<ReadableBuffer>) sCompositeBuffers.get(buffer);
                if (buffers.size() == 1) {
                    return getByteBufFromReadableBuffer(buffers.peek());
                }
                // Resolve every component before allocating the composite. Allocating first and
                // bailing out on an unresolvable component would leak the pooled composite, and
                // releasing it instead would also release the components already added to it.
                List<ByteBuf> components = new ArrayList<ByteBuf>(buffers.size());
                for (ReadableBuffer readableBuffer : buffers) {
                    ByteBuf subBuffer = getByteBufFromReadableBuffer(readableBuffer);
                    if (subBuffer == null) {
                        return null;
                    }
                    components.add(subBuffer);
                }
                CompositeByteBuf composite = PooledByteBufAllocator.DEFAULT.compositeBuffer();
                for (ByteBuf component : components) {
                    composite.addComponent(true, component);
                }
                return composite;
            } else if (buffer.getClass().equals(sReadableByteBuf.getDeclaringClass())) {
                return (ByteBuf) sReadableByteBuf.get(buffer);
            }
        } catch (Exception e) {
            LOG.warn("Failed to get data buffer from stream: {}.", e.getMessage());
            return null;
        }
        return null;
    }

    /**
     * Add the given buffers directly to the gRPC output stream.
     *
     * <p>Nothing is added when the stream's "current" buffer is already set. Each added buffer is
     * retained once.
     *
     * @param buffers the buffers to be added
     * @param stream the output stream
     * @return whether the buffers are added successfully
     */
    @SuppressWarnings("unchecked")
    public static boolean addBuffersToStream(ByteBuf[] buffers, OutputStream stream) {
        if (!sZeroCopySendSupported || stream == null
                || !stream.getClass().equals(sBufferList.getDeclaringClass())) {
            return false;
        }
        try {
            if (sCurrent.get(stream) != null) {
                return false;
            }
            // Wrap everything first: reflective construction is the step most likely to fail, and
            // failing here leaves the stream untouched so the caller can safely fall back.
            List<Object> wrapped = new ArrayList<Object>(buffers.length);
            for (ByteBuf buffer : buffers) {
                wrapped.add(sNettyWritableBufferConstructor.newInstance(buffer));
            }
            List<Object> bufferList = (List<Object>) sBufferList.get(stream);
            for (int i = 0; i < buffers.length; i++) {
                Object nettyBuffer = wrapped.get(i);
                bufferList.add(nettyBuffer);
                buffers[i].retain();
                sCurrent.set(stream, nettyBuffer);
            }
            return true;
        } catch (Exception e) {
            LOG.warn("Failed to add data buffer to stream: {}.", e.getMessage());
            return false;
        }
    }

    /**
     * Creates a service definition that uses custom marshallers.
     *
     * <p>Every method of the given service whose descriptor is a key of {@code marshallers} is
     * re-registered under the mapped descriptor with its original call handler; all other methods
     * are carried over unchanged.
     *
     * @param service the service to intercept
     * @param marshallers a map that specifies which marshaller to use for each method
     * @return the new service definition
     */
    public static ServerServiceDefinition overrideMethods(final ServerServiceDefinition service,
            final Map<MethodDescriptor, MethodDescriptor> marshallers) {
        List<ServerMethodDefinition<?, ?>> newMethods = new ArrayList<ServerMethodDefinition<?, ?>>();
        List<MethodDescriptor<?, ?>> newDescriptors = new ArrayList<MethodDescriptor<?, ?>>();
        // intercepts the descriptors
        for (final ServerMethodDefinition<?, ?> definition : service.getMethods()) {
            ServerMethodDefinition<?, ?> newMethod = interceptMethod(definition, marshallers);
            newDescriptors.add(newMethod.getMethodDescriptor());
            newMethods.add(newMethod);
        }
        // builds the new service descriptor
        final ServerServiceDefinition.Builder serviceBuilder = ServerServiceDefinition
                .builder(new ServiceDescriptor(service.getServiceDescriptor().getName(), newDescriptors));
        // creates the new service definition
        for (ServerMethodDefinition<?, ?> definition : newMethods) {
            serviceBuilder.addMethod(definition);
        }
        return serviceBuilder.build();
    }

    /**
     * Swaps the descriptor of a single method definition if a replacement is registered for it.
     *
     * @param definition the original method definition
     * @param newMethods map from original descriptors to their replacements
     * @return a definition pairing the replacement descriptor with the original call handler, or
     *         the original definition if the map has no entry for its descriptor
     */
    private static <ReqT, RespT> ServerMethodDefinition<ReqT, RespT> interceptMethod(
            final ServerMethodDefinition<ReqT, RespT> definition,
            final Map<MethodDescriptor, MethodDescriptor> newMethods) {
        MethodDescriptor<ReqT, RespT> descriptor = definition.getMethodDescriptor();
        // The map is declared with raw descriptor types, so the request/response types of the
        // replacement cannot be verified here; callers must register compatible descriptors.
        @SuppressWarnings("unchecked")
        MethodDescriptor<ReqT, RespT> replacement = (MethodDescriptor<ReqT, RespT>) newMethods.get(descriptor);
        if (replacement != null) {
            return ServerMethodDefinition.create(replacement, definition.getServerCallHandler());
        }
        return definition;
    }
}