Java tutorial
/* * Copyright 2017 LINE Corporation * * LINE Corporation licenses this file to you under the Apache License, * version 2.0 (the "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at: * * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations * under the License. */ package com.linecorp.armeria.server.grpc; import static com.google.common.base.MoreObjects.firstNonNull; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static com.linecorp.armeria.common.util.Functions.voidFunction; import static io.netty.util.AsciiString.c2b; import static java.util.Objects.requireNonNull; import java.io.IOException; import java.io.UncheckedIOException; import java.util.List; import java.util.Map; import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; import javax.annotation.Nullable; import org.curioswitch.common.protobuf.json.MessageMarshaller; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Splitter; import com.google.common.base.Strings; import com.linecorp.armeria.common.DefaultHttpHeaders; import com.linecorp.armeria.common.HttpData; import com.linecorp.armeria.common.HttpHeaderNames; import com.linecorp.armeria.common.HttpHeaders; import com.linecorp.armeria.common.HttpObject; import com.linecorp.armeria.common.HttpResponseWriter; import com.linecorp.armeria.common.HttpStatus; import com.linecorp.armeria.common.RequestContext; import com.linecorp.armeria.common.SerializationFormat; import com.linecorp.armeria.common.grpc.GrpcSerializationFormats; import com.linecorp.armeria.common.util.SafeCloseable; import com.linecorp.armeria.internal.grpc.ArmeriaMessageDeframer; import com.linecorp.armeria.internal.grpc.ArmeriaMessageDeframer.ByteBufOrStream; import com.linecorp.armeria.internal.grpc.ArmeriaMessageFramer; import com.linecorp.armeria.internal.grpc.GrpcHeaderNames; import com.linecorp.armeria.internal.grpc.GrpcLogUtil; import com.linecorp.armeria.internal.grpc.GrpcMessageMarshaller; import com.linecorp.armeria.internal.grpc.HttpStreamReader; import com.linecorp.armeria.internal.grpc.StatusMessageEscaper; import com.linecorp.armeria.internal.grpc.TransportStatusListener; import com.linecorp.armeria.server.ServiceRequestContext; import com.linecorp.armeria.unsafe.ByteBufHttpData; import com.linecorp.armeria.unsafe.grpc.GrpcUnsafeBufferUtil; import io.grpc.Codec; import io.grpc.Codec.Identity; import io.grpc.Compressor; import io.grpc.CompressorRegistry; import io.grpc.Decompressor; import io.grpc.DecompressorRegistry; import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.MethodDescriptor.MethodType; import io.grpc.ServerCall; import io.grpc.Status; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufUtil; import io.netty.util.AsciiString; /** * Encapsulates the state of a single server call, reading messages from the client, passing to business logic * via {@link ServerCall.Listener}, and writing messages passed back to the response. */ class ArmeriaServerCall<I, O> extends ServerCall<I, O> implements ArmeriaMessageDeframer.Listener, TransportStatusListener { private static final Logger logger = LoggerFactory.getLogger(ArmeriaServerCall.class); @SuppressWarnings("rawtypes") private static final AtomicIntegerFieldUpdater<ArmeriaServerCall> pendingMessagesUpdater = AtomicIntegerFieldUpdater .newUpdater(ArmeriaServerCall.class, "pendingMessages"); // Only most significant bit of a byte is set. @VisibleForTesting static final byte TRAILERS_FRAME_HEADER = (byte) (1 << 7); private static final Metadata EMPTY_METADATA = new Metadata(); private static final Splitter ACCEPT_ENCODING_SPLITTER = Splitter.on(',').trimResults(); private final MethodDescriptor<I, O> method; private final HttpStreamReader messageReader; private final ArmeriaMessageFramer messageFramer; private final HttpResponseWriter res; private final CompressorRegistry compressorRegistry; private final DecompressorRegistry decompressorRegistry; private final ServiceRequestContext ctx; private final SerializationFormat serializationFormat; private final GrpcMessageMarshaller<I, O> marshaller; private final boolean unsafeWrapRequestBuffers; // Only set once. @Nullable private ServerCall.Listener<I> listener; @Nullable private final String clientAcceptEncoding; @Nullable private Compressor compressor; private boolean messageCompression; private boolean messageReceived; // state private volatile boolean cancelled; private volatile boolean clientStreamClosed; private volatile boolean listenerClosed; private boolean sendHeadersCalled; private boolean closeCalled; private volatile int pendingMessages; ArmeriaServerCall(HttpHeaders clientHeaders, MethodDescriptor<I, O> method, CompressorRegistry compressorRegistry, DecompressorRegistry decompressorRegistry, HttpResponseWriter res, int maxInboundMessageSizeBytes, int maxOutboundMessageSizeBytes, ServiceRequestContext ctx, SerializationFormat serializationFormat, MessageMarshaller jsonMarshaller, boolean unsafeWrapRequestBuffers) { requireNonNull(clientHeaders, "clientHeaders"); this.method = requireNonNull(method, "method"); this.ctx = requireNonNull(ctx, "ctx"); this.serializationFormat = requireNonNull(serializationFormat, "serializationFormat"); messageReader = new HttpStreamReader(requireNonNull(decompressorRegistry, "decompressorRegistry"), new ArmeriaMessageDeframer(this, maxInboundMessageSizeBytes, ctx.alloc()) .decompressor(clientDecompressor(clientHeaders, decompressorRegistry)), this); messageFramer = new ArmeriaMessageFramer(ctx.alloc(), maxOutboundMessageSizeBytes); this.res = requireNonNull(res, "res"); this.compressorRegistry = requireNonNull(compressorRegistry, "compressorRegistry"); clientAcceptEncoding = Strings.emptyToNull(clientHeaders.get(GrpcHeaderNames.GRPC_ACCEPT_ENCODING)); this.decompressorRegistry = requireNonNull(decompressorRegistry, "decompressorRegistry"); marshaller = new GrpcMessageMarshaller<>(ctx.alloc(), serializationFormat, method, jsonMarshaller, unsafeWrapRequestBuffers); this.unsafeWrapRequestBuffers = unsafeWrapRequestBuffers; res.completionFuture().handleAsync(voidFunction((unused, t) -> { if (!closeCalled) { // Closed by client, not by server. cancelled = true; close(Status.CANCELLED, EMPTY_METADATA); } }), ctx.contextAwareEventLoop()); } @Override public void request(int numMessages) { if (ctx.eventLoop().inEventLoop()) { messageReader.request(numMessages); } else { ctx.eventLoop().submit(() -> messageReader.request(numMessages)); } } @Override public void sendHeaders(Metadata unusedGrpcMetadata) { if (ctx.eventLoop().inEventLoop()) { doSendHeaders(unusedGrpcMetadata); } else { ctx.eventLoop().submit(() -> doSendHeaders(unusedGrpcMetadata)); } } private void doSendHeaders(Metadata unusedGrpcMetadata) { checkState(!sendHeadersCalled, "sendHeaders already called"); checkState(!closeCalled, "call is closed"); final HttpHeaders headers = HttpHeaders.of(HttpStatus.OK); headers.contentType(serializationFormat.mediaType()); if (compressor == null || !messageCompression || clientAcceptEncoding == null) { compressor = Codec.Identity.NONE; } else { final List<String> acceptedEncodingsList = ACCEPT_ENCODING_SPLITTER.splitToList(clientAcceptEncoding); if (!acceptedEncodingsList.contains(compressor.getMessageEncoding())) { // resort to using no compression. compressor = Codec.Identity.NONE; } } messageFramer.setCompressor(compressor); // Always put compressor, even if it's identity. headers.add(GrpcHeaderNames.GRPC_ENCODING, compressor.getMessageEncoding()); final String advertisedEncodings = String.join(",", decompressorRegistry.getAdvertisedMessageEncodings()); if (!advertisedEncodings.isEmpty()) { headers.add(GrpcHeaderNames.GRPC_ACCEPT_ENCODING, advertisedEncodings); } sendHeadersCalled = true; res.write(headers); } @Override public void sendMessage(O message) { pendingMessagesUpdater.incrementAndGet(this); if (ctx.eventLoop().inEventLoop()) { doSendMessage(message); } else { ctx.eventLoop().submit(() -> doSendMessage(message)); } } private void doSendMessage(O message) { checkState(sendHeadersCalled, "sendHeaders has not been called"); checkState(!closeCalled, "call is closed"); try { res.write(messageFramer.writePayload(marshaller.serializeResponse(message))); res.onDemand(() -> { if (pendingMessagesUpdater.decrementAndGet(this) == 0) { try { listener.onReady(); } catch (Throwable t) { close(Status.fromThrowable(t), EMPTY_METADATA); } } }); } catch (RuntimeException e) { close(Status.fromThrowable(e), EMPTY_METADATA); throw e; } catch (Throwable t) { close(Status.fromThrowable(t), EMPTY_METADATA); throw new RuntimeException(t); } } @Override public boolean isReady() { return !closeCalled && pendingMessages == 0; } @Override public void close(Status status, Metadata unusedGrpcMetadata) { if (ctx.eventLoop().inEventLoop()) { doClose(status, unusedGrpcMetadata); } else { ctx.eventLoop().submit(() -> doClose(status, unusedGrpcMetadata)); } } private void doClose(Status status, Metadata unusedGrpcMetadata) { checkState(!closeCalled, "call already closed"); closeCalled = true; if (cancelled) { // No need to write anything to client if cancelled already. closeListener(status); return; } final HttpHeaders trailers = statusToTrailers(status, sendHeadersCalled); final HttpObject trailersObj; if (sendHeadersCalled && GrpcSerializationFormats.isGrpcWeb(serializationFormat)) { // Normal trailers are not supported in grpc-web and must be encoded as a message. // Message compression is not supported in grpc-web, so we don't bother using the normal // ArmeriaMessageFramer. trailersObj = serializeTrailersAsMessage(trailers); } else { trailersObj = trailers; } try { res.write(trailersObj); res.close(); } finally { closeListener(status); } } @Override public boolean isCancelled() { return cancelled; } @Override public synchronized void setMessageCompression(boolean messageCompression) { messageFramer.setMessageCompression(messageCompression); this.messageCompression = messageCompression; } @Override public synchronized void setCompression(String compressorName) { checkState(!sendHeadersCalled, "sendHeaders has been called"); compressor = compressorRegistry.lookupCompressor(compressorName); checkArgument(compressor != null, "Unable to find compressor by name %s", compressorName); messageFramer.setCompressor(compressor); } @Override public MethodDescriptor<I, O> getMethodDescriptor() { return method; } @Override public void messageRead(ByteBufOrStream message) { final I request; boolean success = false; try { // Special case for unary calls. if (messageReceived && method.getType() == MethodType.UNARY) { closeListener(Status.INTERNAL .withDescription("More than one request messages for unary call or server streaming call")); return; } messageReceived = true; if (isCancelled()) { return; } success = true; } finally { if (message.buf() != null && !success) { message.buf().release(); } } try { request = marshaller.deserializeRequest(message); } catch (IOException e) { throw new UncheckedIOException(e); } if (unsafeWrapRequestBuffers && message.buf() != null) { GrpcUnsafeBufferUtil.storeBuffer(message.buf(), request, ctx); } try (SafeCloseable ignored = RequestContext.push(ctx)) { listener.onMessage(request); } catch (Throwable t) { close(Status.fromThrowable(t), EMPTY_METADATA); } } @Override public void endOfStream() { clientStreamClosed = true; if (!closeCalled) { try (SafeCloseable ignored = RequestContext.push(ctx)) { listener.onHalfClose(); } catch (Throwable t) { close(Status.fromThrowable(t), EMPTY_METADATA); } } } @Override public void transportReportStatus(Status status) { if (closeCalled) { // We've already called close on the server-side and will close the listener with the server-side // status, so we ignore client transport status's at this point (it's usually the RST_STREAM // corresponding to a successful stream ending in practice, but even if it was an actual transport // failure there's no need to notify the server listener of it). return; } closeListener(status); } private void closeListener(Status newStatus) { if (!listenerClosed) { listenerClosed = true; if (!clientStreamClosed) { messageReader().cancel(); clientStreamClosed = true; } messageFramer.close(); ctx.logBuilder().responseContent(GrpcLogUtil.rpcResponse(newStatus), null); if (newStatus.isOk()) { try (SafeCloseable ignored = RequestContext.push(ctx)) { listener.onComplete(); } catch (Throwable t) { // This should not be possible with normal generated stubs which do not implement // onComplete, but is conceivable for a completely manually constructed stub. logger.warn("Error in gRPC onComplete handler.", t); } } else { cancelled = true; try (SafeCloseable ignored = RequestContext.push(ctx)) { listener.onCancel(); } catch (Throwable t) { if (!closeCalled) { // A custom error when dealing with client cancel or transport issues should be // returned. We have already closed the listener, so it will not receive any more // callbacks as designed. close(Status.fromThrowable(t), EMPTY_METADATA); } } // Transport error, not business logic error, so reset the stream. if (!closeCalled) { res.close(newStatus.asException()); } } } } static HttpHeaders statusToTrailers(Status status, boolean headersSent) { final HttpHeaders trailers; if (headersSent) { // Normal trailers. trailers = new DefaultHttpHeaders(); } else { // Trailers only response trailers = new DefaultHttpHeaders(true, 3, true).status(HttpStatus.OK).set(HttpHeaderNames.CONTENT_TYPE, "application/grpc+proto"); } trailers.add(GrpcHeaderNames.GRPC_STATUS, Integer.toString(status.getCode().value())); if (status.getDescription() != null) { trailers.add(GrpcHeaderNames.GRPC_MESSAGE, StatusMessageEscaper.escape(status.getDescription())); } return trailers; } HttpStreamReader messageReader() { return messageReader; } void setListener(Listener<I> listener) { checkState(this.listener == null, "listener already set"); this.listener = requireNonNull(listener, "listener"); } private HttpData serializeTrailersAsMessage(HttpHeaders trailers) { final ByteBuf serialized = ctx.alloc().buffer(); boolean success = false; try { serialized.writeByte(TRAILERS_FRAME_HEADER); // Skip, we'll set this after serializing the headers. serialized.writeInt(0); for (Map.Entry<AsciiString, String> trailer : trailers) { encodeHeader(trailer.getKey(), trailer.getValue(), serialized); } final int messageSize = serialized.readableBytes() - 5; serialized.setInt(1, messageSize); success = true; } finally { if (!success) { serialized.release(); } } return new ByteBufHttpData(serialized, true); } private static Decompressor clientDecompressor(HttpHeaders headers, DecompressorRegistry registry) { final String encoding = headers.get(GrpcHeaderNames.GRPC_ENCODING); if (encoding == null) { return Identity.NONE; } final Decompressor decompressor = registry.lookupDecompressor(encoding); return firstNonNull(decompressor, Identity.NONE); } // Copied from io.netty.handler.codec.http.HttpHeadersEncoder private static void encodeHeader(CharSequence name, CharSequence value, ByteBuf buf) { final int nameLen = name.length(); final int valueLen = value.length(); final int entryLen = nameLen + valueLen + 4; buf.ensureWritable(entryLen); int offset = buf.writerIndex(); writeAscii(buf, offset, name, nameLen); offset += nameLen; buf.setByte(offset++, ':'); buf.setByte(offset++, ' '); writeAscii(buf, offset, value, valueLen); offset += valueLen; buf.setByte(offset++, '\r'); buf.setByte(offset++, '\n'); buf.writerIndex(offset); } private static void writeAscii(ByteBuf buf, int offset, CharSequence value, int valueLen) { if (value instanceof AsciiString) { ByteBufUtil.copy((AsciiString) value, 0, buf, offset, valueLen); } else { writeCharSequence(buf, offset, value, valueLen); } } private static void writeCharSequence(ByteBuf buf, int offset, CharSequence value, int valueLen) { for (int i = 0; i < valueLen; ++i) { buf.setByte(offset++, c2b(value.charAt(i))); } } }