com.linecorp.armeria.internal.grpc.GrpcMessageMarshaller.java Source code

Java tutorial

Introduction

Here is the source code for com.linecorp.armeria.internal.grpc.GrpcMessageMarshaller.java

Source

/*
 * Copyright 2017 LINE Corporation
 *
 * LINE Corporation licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */

package com.linecorp.armeria.internal.grpc;

import static com.google.common.base.Preconditions.checkArgument;
import static java.util.Objects.requireNonNull;

import java.io.IOException;
import java.io.InputStream;

import javax.annotation.Nullable;

import org.curioswitch.common.protobuf.json.MessageMarshaller;

import com.google.common.io.ByteStreams;
import com.google.protobuf.CodedInputStream;
import com.google.protobuf.CodedOutputStream;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Message;
import com.google.protobuf.UnsafeByteOperations;

import com.linecorp.armeria.common.SerializationFormat;
import com.linecorp.armeria.common.grpc.GrpcSerializationFormats;
import com.linecorp.armeria.internal.grpc.ArmeriaMessageDeframer.ByteBufOrStream;

import io.grpc.MethodDescriptor;
import io.grpc.MethodDescriptor.Marshaller;
import io.grpc.MethodDescriptor.PrototypeMarshaller;
import io.grpc.Status;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.ByteBufInputStream;
import io.netty.buffer.ByteBufOutputStream;
import io.netty.buffer.CompositeByteBuf;

/**
 * Marshaller for gRPC method request or response messages to and from {@link ByteBuf}. Will attempt to use
 * optimized code paths for known message types, and otherwise delegates to the gRPC stub.
 */
public class GrpcMessageMarshaller<I, O> {

    /**
     * How the messages of one direction (request or response) are handled: {@link #PROTOBUF} when the stub
     * marshaller is a {@link PrototypeMarshaller} whose prototype is a protobuf {@link Message} (handled
     * directly by this class), {@link #UNKNOWN} otherwise (delegated to the stub marshaller's streams).
     */
    private enum MessageType {
        UNKNOWN, PROTOBUF
    }

    private final ByteBufAllocator alloc;
    private final SerializationFormat serializationFormat;
    private final MethodDescriptor<I, O> method;
    // Guaranteed non-null by the constructor whenever serializationFormat is JSON.
    @Nullable
    private final MessageMarshaller jsonMarshaller;
    private final MessageType requestType;
    private final MessageType responseType;
    // If true, parsed protobuf messages may alias the input buffer, so the input buffer is NOT released by
    // this class and remains the caller's responsibility.
    private final boolean unsafeWrapDeserializedBuffer;

    /**
     * Creates a new marshaller.
     *
     * @param alloc the allocator used for serialized output buffers
     * @param serializationFormat the wire format used for protobuf messages; formats other than gRPC proto
     *                            or JSON fail with {@link IllegalStateException} when a protobuf message is
     *                            (de)serialized
     * @param method the gRPC method whose stub marshallers are used for non-protobuf messages
     * @param jsonMarshaller the JSON marshaller; must be non-null when {@code serializationFormat} is JSON
     * @param unsafeWrapDeserializedBuffer whether parsed protobuf messages may alias the input buffer instead
     *                                     of copying it; when {@code true} the input buffer is not released
     * @throws IllegalArgumentException if the format is JSON and {@code jsonMarshaller} is {@code null}
     */
    public GrpcMessageMarshaller(ByteBufAllocator alloc, SerializationFormat serializationFormat,
            MethodDescriptor<I, O> method, @Nullable MessageMarshaller jsonMarshaller,
            boolean unsafeWrapDeserializedBuffer) {
        this.alloc = requireNonNull(alloc, "alloc");
        this.serializationFormat = requireNonNull(serializationFormat, "serializationFormat");
        this.method = requireNonNull(method, "method");
        this.unsafeWrapDeserializedBuffer = unsafeWrapDeserializedBuffer;
        checkArgument(!GrpcSerializationFormats.isJson(serializationFormat) || jsonMarshaller != null,
                "jsonMarshaller must be non-null when serializationFormat is JSON.");
        this.jsonMarshaller = jsonMarshaller;
        requestType = marshallerType(method.getRequestMarshaller());
        responseType = marshallerType(method.getResponseMarshaller());
    }

    /**
     * Serializes a request message into a newly allocated buffer, which the caller must release.
     */
    public ByteBuf serializeRequest(I message) throws IOException {
        return serialize(message, requestType, method.getRequestMarshaller());
    }

    /**
     * Deserializes a request message. If {@code message} carries a buffer, the reference it holds is released
     * before returning unless {@code unsafeWrapDeserializedBuffer} was set.
     */
    public I deserializeRequest(ByteBufOrStream message) throws IOException {
        return deserialize(message, requestType, method.getRequestMarshaller());
    }

    /**
     * Serializes a response message into a newly allocated buffer, which the caller must release.
     */
    public ByteBuf serializeResponse(O message) throws IOException {
        return serialize(message, responseType, method.getResponseMarshaller());
    }

    /**
     * Deserializes a response message. If {@code message} carries a buffer, the reference it holds is
     * released before returning unless {@code unsafeWrapDeserializedBuffer} was set.
     */
    public O deserializeResponse(ByteBufOrStream message) throws IOException {
        return deserialize(message, responseType, method.getResponseMarshaller());
    }

    /**
     * Serializes {@code message} with the fast protobuf path when possible, otherwise with the stub
     * marshaller of the same direction.
     */
    private <T> ByteBuf serialize(T message, MessageType type, Marshaller<T> stubMarshaller)
            throws IOException {
        if (type == MessageType.PROTOBUF) {
            return serializeProto((Message) message);
        }
        return serializeWithStub(message, stubMarshaller);
    }

    /**
     * Serializes {@code message} by copying the stub marshaller's stream into a composite buffer. The stub's
     * stream is always closed, and the output buffer is released if anything fails.
     */
    private <T> ByteBuf serializeWithStub(T message, Marshaller<T> stubMarshaller) throws IOException {
        final CompositeByteBuf out = alloc.compositeBuffer();
        boolean success = false;
        try {
            // Close the stub's stream once it has been drained so any resources it holds are freed.
            try (InputStream in = stubMarshaller.stream(message);
                 ByteBufOutputStream os = new ByteBufOutputStream(out)) {
                ByteStreams.copy(in, os);
            }
            success = true;
            return out;
        } finally {
            if (!success) {
                out.release();
            }
        }
    }

    /**
     * Deserializes {@code message}, preferring the fast protobuf path for buffer-backed messages and
     * otherwise parsing through the stub marshaller of the same direction.
     */
    private <T> T deserialize(ByteBufOrStream message, MessageType type, Marshaller<T> stubMarshaller)
            throws IOException {
        InputStream messageStream = message.stream();
        final ByteBuf buf = message.buf();
        if (buf != null) {
            try {
                if (type == MessageType.PROTOBUF) {
                    final PrototypeMarshaller<T> marshaller = (PrototypeMarshaller<T>) stubMarshaller;
                    // PrototypeMarshaller<T>.getMessagePrototype will always parse to T
                    @SuppressWarnings("unchecked")
                    final T msg = (T) deserializeProto(buf, (Message) marshaller.getMessagePrototype());
                    return msg;
                }
                // Fallback to using the stub's stream marshaller. The extra reference taken here is
                // released when the stream is closed below, independently of the release in 'finally'.
                messageStream = new ByteBufInputStream(buf.retain(), true);
            } finally {
                if (!unsafeWrapDeserializedBuffer) {
                    buf.release();
                }
            }
        }
        try (InputStream msg = messageStream) {
            return stubMarshaller.parse(msg);
        }
    }

    /**
     * Serializes a protobuf message in the configured format into a newly allocated buffer.
     *
     * @throws IllegalStateException if the serialization format is neither gRPC proto nor JSON
     */
    private ByteBuf serializeProto(Message message) throws IOException {
        if (GrpcSerializationFormats.isProto(serializationFormat)) {
            final int serializedSize = message.getSerializedSize();
            final ByteBuf buf = alloc.buffer(serializedSize);
            boolean success = false;
            try {
                // Write straight into the buffer's memory and publish exactly serializedSize bytes; do not
                // depend on capacity() being equal to the requested size.
                message.writeTo(CodedOutputStream.newInstance(buf.nioBuffer(0, serializedSize)));
                buf.writerIndex(serializedSize);
                success = true;
            } finally {
                if (!success) {
                    buf.release();
                }
            }
            return buf;
        }

        if (GrpcSerializationFormats.isJson(serializationFormat)) {
            final ByteBuf buf = alloc.buffer();
            boolean success = false;
            try (ByteBufOutputStream os = new ByteBufOutputStream(buf)) {
                jsonMarshaller.writeValue(message, os);
                success = true;
            } finally {
                if (!success) {
                    buf.release();
                }
            }
            return buf;
        }
        throw new IllegalStateException("Unknown serialization format: " + serializationFormat);
    }

    /**
     * Parses {@code buf} into a message of the same type as {@code prototype} using the configured format.
     * Does not release {@code buf}. When {@code unsafeWrapDeserializedBuffer} is set, the parsed proto message
     * may alias the buffer's memory, so the buffer must outlive the message.
     *
     * @throws io.grpc.StatusRuntimeException with {@link Status#INTERNAL} if the protobuf bytes are invalid
     * @throws IllegalStateException if the serialization format is neither gRPC proto nor JSON
     */
    private Message deserializeProto(ByteBuf buf, Message prototype) throws IOException {
        if (GrpcSerializationFormats.isProto(serializationFormat)) {
            final CodedInputStream stream;
            if (unsafeWrapDeserializedBuffer) {
                stream = UnsafeByteOperations.unsafeWrap(buf.nioBuffer()).newCodedInput();
                stream.enableAliasing(true);
            } else {
                stream = CodedInputStream.newInstance(buf.nioBuffer());
            }
            try {
                final Message msg = prototype.getParserForType().parseFrom(stream);
                try {
                    // Reject input that ended on an end-group tag rather than at end of stream.
                    stream.checkLastTagWas(0);
                } catch (InvalidProtocolBufferException e) {
                    e.setUnfinishedMessage(msg);
                    throw e;
                }
                return msg;
            } catch (InvalidProtocolBufferException e) {
                throw Status.INTERNAL.withDescription("Invalid protobuf byte sequence").withCause(e)
                        .asRuntimeException();
            }
        }

        if (GrpcSerializationFormats.isJson(serializationFormat)) {
            final Message.Builder builder = prototype.newBuilderForType();
            try (ByteBufInputStream is = new ByteBufInputStream(buf, /* releaseOnClose */ false)) {
                jsonMarshaller.mergeValue(is, builder);
            }
            return builder.build();
        }
        throw new IllegalStateException("Unknown serialization format: " + serializationFormat);
    }

    /**
     * Classifies a stub marshaller. Only a {@link PrototypeMarshaller} whose prototype is a full protobuf
     * {@link Message} can use the fast path; anything else (including lite prototypes, which would fail the
     * {@code (Message)} casts) falls back to the stub marshaller.
     */
    private static MessageType marshallerType(Marshaller<?> marshaller) {
        if (marshaller instanceof PrototypeMarshaller) {
            final Object prototype = ((PrototypeMarshaller<?>) marshaller).getMessagePrototype();
            if (prototype instanceof Message) {
                return MessageType.PROTOBUF;
            }
        }
        return MessageType.UNKNOWN;
    }
}