io.grpc.netty.NettyClientStreamTest.java Source code

Java tutorial

Introduction

Below is the source code of io/grpc/netty/NettyClientStreamTest.java, the unit tests for gRPC's Netty client stream class NettyClientStream.

Source

/*
 * Copyright 2014 The gRPC Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.grpc.netty;

import static com.google.common.truth.Truth.assertThat;
import static io.grpc.internal.ClientStreamListener.RpcProgress.PROCESSED;
import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE;
import static io.grpc.netty.NettyTestUtil.messageFrame;
import static io.grpc.netty.Utils.CONTENT_TYPE_GRPC;
import static io.grpc.netty.Utils.CONTENT_TYPE_HEADER;
import static io.grpc.netty.Utils.STATUS_OK;
import static io.netty.util.CharsetUtil.UTF_8;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.isA;
import static org.mockito.ArgumentMatchers.same;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;

import com.google.common.collect.ImmutableListMultimap;
import com.google.common.io.BaseEncoding;
import io.grpc.CallOptions;
import io.grpc.InternalStatus;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.Status;
import io.grpc.internal.ClientStreamListener;
import io.grpc.internal.GrpcUtil;
import io.grpc.internal.StatsTraceContext;
import io.grpc.internal.StreamListener;
import io.grpc.internal.TransportTracer;
import io.grpc.netty.WriteQueue.QueuedCommand;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelPromise;
import io.netty.channel.DefaultChannelPromise;
import io.netty.handler.codec.http2.DefaultHttp2Headers;
import io.netty.handler.codec.http2.Http2Headers;
import io.netty.util.AsciiString;
import java.io.BufferedInputStream;
import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.util.LinkedList;
import java.util.Queue;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor;
import org.mockito.ArgumentMatchers;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;

/**
 * Tests for {@link NettyClientStream}.
 */
@RunWith(JUnit4.class)
public class NettyClientStreamTest extends NettyStreamTestBase<NettyClientStream> {
    @Mock
    protected ClientStreamListener listener;

    @Mock
    protected NettyClientHandler handler;

    @SuppressWarnings("unchecked")
    private MethodDescriptor.Marshaller<Void> marshaller = mock(MethodDescriptor.Marshaller.class);
    private final Queue<InputStream> listenerMessageQueue = new LinkedList<>();

    // Must be initialized before @Before, because it is used by createStream()
    private MethodDescriptor<?, ?> methodDescriptor = MethodDescriptor.<Void, Void>newBuilder()
            .setType(MethodDescriptor.MethodType.UNARY).setFullMethodName("testService/test")
            .setRequestMarshaller(marshaller).setResponseMarshaller(marshaller).build();

    private final TransportTracer transportTracer = new TransportTracer();

    /** Set up for test. */
    @Before
    @Override
    public void setUp() {
        super.setUp();

        doAnswer(new Answer<Void>() {
            @Override
            public Void answer(InvocationOnMock invocation) throws Throwable {
                StreamListener.MessageProducer producer = (StreamListener.MessageProducer) invocation
                        .getArguments()[0];
                InputStream message;
                while ((message = producer.next()) != null) {
                    listenerMessageQueue.add(message);
                }
                return null;
            }
        }).when(listener).messagesAvailable(ArgumentMatchers.<StreamListener.MessageProducer>any());
    }

    @Override
    protected ClientStreamListener listener() {
        return listener;
    }

    @Override
    protected Queue<InputStream> listenerMessageQueue() {
        return listenerMessageQueue;
    }

    @Test
    public void closeShouldSucceed() {
        // Force stream creation.
        stream().transportState().setId(STREAM_ID);
        stream().halfClose();
        verifyNoMoreInteractions(listener);
    }

    @Test
    public void cancelShouldSendCommand() {
        // Set stream id to indicate it has been created
        stream().transportState().setId(STREAM_ID);
        stream().cancel(Status.CANCELLED);
        ArgumentCaptor<CancelClientStreamCommand> commandCaptor = ArgumentCaptor
                .forClass(CancelClientStreamCommand.class);
        verify(writeQueue).enqueue(commandCaptor.capture(), eq(true));
        assertEquals(commandCaptor.getValue().reason(), Status.CANCELLED);
    }

    @Test
    public void deadlineExceededCancelShouldSendCommand() {
        // Set stream id to indicate it has been created
        stream().transportState().setId(STREAM_ID);
        stream().cancel(Status.DEADLINE_EXCEEDED);
        ArgumentCaptor<CancelClientStreamCommand> commandCaptor = ArgumentCaptor
                .forClass(CancelClientStreamCommand.class);
        verify(writeQueue).enqueue(commandCaptor.capture(), eq(true));
        assertEquals(commandCaptor.getValue().reason(), Status.DEADLINE_EXCEEDED);
    }

    @Test
    public void cancelShouldStillSendCommandIfStreamNotCreatedToCancelCreation() {
        stream().cancel(Status.CANCELLED);
        verify(writeQueue).enqueue(isA(CancelClientStreamCommand.class), eq(true));
    }

    @Test
    public void writeMessageShouldSendRequest() throws Exception {
        // Force stream creation.
        stream().transportState().setId(STREAM_ID);
        byte[] msg = smallMessage();
        stream.writeMessage(new ByteArrayInputStream(msg));
        stream.flush();
        verify(writeQueue).enqueue(
                eq(new SendGrpcFrameCommand(stream.transportState(), messageFrame(MESSAGE), false)), eq(true));
    }

    @Test
    public void writeMessageShouldSendRequestUnknownLength() throws Exception {
        // Force stream creation.
        stream().transportState().setId(STREAM_ID);
        byte[] msg = smallMessage();
        stream.writeMessage(new BufferedInputStream(new ByteArrayInputStream(msg)));
        stream.flush();
        // Two writes occur, one for the GRPC frame header and the second with the payload
        // The framer reports the message count when the payload is completely written
        verify(writeQueue).enqueue(
                eq(new SendGrpcFrameCommand(stream.transportState(), messageFrame(MESSAGE).slice(0, 5), false)),
                eq(false));
        verify(writeQueue).enqueue(
                eq(new SendGrpcFrameCommand(stream.transportState(), messageFrame(MESSAGE).slice(5, 11), false)),
                eq(true));
    }

    @Test
    public void setStatusWithOkShouldCloseStream() {
        stream().transportState().setId(STREAM_ID);
        stream().transportState().transportReportStatus(Status.OK, true, new Metadata());
        verify(listener).closed(same(Status.OK), same(PROCESSED), any(Metadata.class));
    }

    @Test
    public void setStatusWithErrorShouldCloseStream() {
        Status errorStatus = Status.INTERNAL;
        stream().transportState().transportReportStatus(errorStatus, true, new Metadata());
        verify(listener).closed(eq(errorStatus), same(PROCESSED), any(Metadata.class));
    }

    @Test
    public void setStatusWithOkShouldNotOverrideError() {
        Status errorStatus = Status.INTERNAL;
        stream().transportState().transportReportStatus(errorStatus, true, new Metadata());
        stream().transportState().transportReportStatus(Status.OK, true, new Metadata());
        verify(listener).closed(any(Status.class), same(PROCESSED), any(Metadata.class));
    }

    @Test
    public void setStatusWithErrorShouldNotOverridePreviousError() {
        Status errorStatus = Status.INTERNAL;
        stream().transportState().transportReportStatus(errorStatus, true, new Metadata());
        stream().transportState().transportReportStatus(Status.fromThrowable(new RuntimeException("fake")), true,
                new Metadata());
        verify(listener).closed(any(Status.class), same(PROCESSED), any(Metadata.class));
    }

    @Override
    @Test
    public void inboundMessageShouldCallListener() throws Exception {
        // Receive headers first so that it's a valid GRPC response.
        stream().transportState().setId(STREAM_ID);
        stream().transportState().transportHeadersReceived(grpcResponseHeaders(), false);
        super.inboundMessageShouldCallListener();
    }

    @Test
    public void inboundHeadersShouldCallListenerHeadersRead() throws Exception {
        stream().transportState().setId(STREAM_ID);
        Http2Headers headers = grpcResponseHeaders();
        stream().transportState().transportHeadersReceived(headers, false);
        verify(listener).headersRead(any(Metadata.class));
    }

    @Test
    public void inboundTrailersClosesCall() throws Exception {
        stream().transportState().setId(STREAM_ID);
        stream().transportState().transportHeadersReceived(grpcResponseHeaders(), false);
        super.inboundMessageShouldCallListener();
        stream().transportState().transportHeadersReceived(grpcResponseTrailers(Status.OK), true);
    }

    @Test
    public void inboundTrailersBeforeHalfCloseSendsRstStream() {
        stream().transportState().setId(STREAM_ID);
        stream().transportState().transportHeadersReceived(grpcResponseHeaders(), false);
        stream().transportState().transportHeadersReceived(grpcResponseTrailers(Status.OK), true);

        // Verify a cancel stream with reason=null is sent to the handler.
        ArgumentCaptor<CancelClientStreamCommand> captor = ArgumentCaptor.forClass(CancelClientStreamCommand.class);
        verify(writeQueue).enqueue(captor.capture(), eq(true));
        assertNull(captor.getValue().reason());
    }

    @Test
    public void inboundTrailersAfterHalfCloseDoesNotSendRstStream() {
        stream().transportState().setId(STREAM_ID);
        stream().transportState().transportHeadersReceived(grpcResponseHeaders(), false);
        stream.halfClose();
        stream().transportState().transportHeadersReceived(grpcResponseTrailers(Status.OK), true);
        verify(writeQueue, never()).enqueue(isA(CancelClientStreamCommand.class), eq(true));
    }

    @Test
    public void inboundStatusShouldSetStatus() throws Exception {
        stream().transportState().setId(STREAM_ID);

        // Receive headers first so that it's a valid GRPC response.
        stream().transportState().transportHeadersReceived(grpcResponseHeaders(), false);

        stream().transportState().transportHeadersReceived(grpcResponseTrailers(Status.INTERNAL), true);
        ArgumentCaptor<Status> captor = ArgumentCaptor.forClass(Status.class);
        verify(listener).closed(captor.capture(), same(PROCESSED), any(Metadata.class));
        assertEquals(Status.INTERNAL.getCode(), captor.getValue().getCode());
    }

    @Test
    public void invalidInboundHeadersCancelStream() throws Exception {
        stream().transportState().setId(STREAM_ID);
        Http2Headers headers = grpcResponseHeaders();
        headers.set("random", "4");
        headers.remove(CONTENT_TYPE_HEADER);
        // Remove once b/16290036 is fixed.
        headers.status(new AsciiString("500"));
        stream().transportState().transportHeadersReceived(headers, false);
        verify(listener, never()).closed(any(Status.class), any(Metadata.class));

        // We are now waiting for 100 bytes of error context on the stream, cancel has not yet been
        // sent
        verify(channel, never()).writeAndFlush(any(CancelClientStreamCommand.class));
        stream().transportState().transportDataReceived(Unpooled.buffer(100).writeZero(100), false);
        verify(channel, never()).writeAndFlush(any(CancelClientStreamCommand.class));
        stream().transportState().transportDataReceived(Unpooled.buffer(1000).writeZero(1000), false);

        // Now verify that cancel is sent and an error is reported to the listener
        verify(writeQueue).enqueue(isA(CancelClientStreamCommand.class), eq(true));
        ArgumentCaptor<Status> captor = ArgumentCaptor.forClass(Status.class);
        ArgumentCaptor<Metadata> metadataCaptor = ArgumentCaptor.forClass(Metadata.class);
        verify(listener).closed(captor.capture(), same(PROCESSED), metadataCaptor.capture());
        assertEquals(Status.UNKNOWN.getCode(), captor.getValue().getCode());
        assertEquals("4",
                metadataCaptor.getValue().get(Metadata.Key.of("random", Metadata.ASCII_STRING_MARSHALLER)));

    }

    @Test
    public void invalidInboundContentTypeShouldCancelStream() {
        // Set stream id to indicate it has been created
        stream().transportState().setId(STREAM_ID);
        Http2Headers headers = new DefaultHttp2Headers().status(STATUS_OK).set(CONTENT_TYPE_HEADER,
                new AsciiString("application/bad", UTF_8));
        stream().transportState().transportHeadersReceived(headers, false);
        Http2Headers trailers = new DefaultHttp2Headers().set(new AsciiString("grpc-status", UTF_8),
                new AsciiString("0", UTF_8));
        stream().transportState().transportHeadersReceived(trailers, true);
        ArgumentCaptor<Status> captor = ArgumentCaptor.forClass(Status.class);
        ArgumentCaptor<Metadata> metadataCaptor = ArgumentCaptor.forClass(Metadata.class);
        verify(listener).closed(captor.capture(), same(PROCESSED), metadataCaptor.capture());
        Status status = captor.getValue();
        assertEquals(Status.Code.UNKNOWN, status.getCode());
        assertTrue(status.getDescription().contains("content-type"));
        assertEquals("application/bad",
                metadataCaptor.getValue().get(Metadata.Key.of("Content-Type", Metadata.ASCII_STRING_MARSHALLER)));
    }

    @Test
    public void nonGrpcResponseShouldSetStatus() throws Exception {
        stream().transportState().transportDataReceived(Unpooled.copiedBuffer(MESSAGE, UTF_8), true);
        ArgumentCaptor<Status> captor = ArgumentCaptor.forClass(Status.class);
        verify(listener).closed(captor.capture(), same(PROCESSED), any(Metadata.class));
        assertEquals(Status.Code.INTERNAL, captor.getValue().getCode());
    }

    @Test
    public void deframedDataAfterCancelShouldBeIgnored() throws Exception {
        stream().transportState().setId(STREAM_ID);
        // Receive headers first so that it's a valid GRPC response.
        stream().transportState().transportHeadersReceived(grpcResponseHeaders(), false);

        // Receive 2 consecutive empty frames. Only one is delivered at a time to the listener.
        stream().transportState().transportDataReceived(simpleGrpcFrame(), false);
        stream().transportState().transportDataReceived(simpleGrpcFrame(), false);

        // Only allow the first to be delivered.
        stream().request(1);

        // Receive error trailers. The server status will not be processed until after all of the
        // data frames have been processed. Since cancellation will interrupt message delivery,
        // this status will never be processed and the listener will instead only see the
        // cancellation.
        stream().transportState().transportHeadersReceived(grpcResponseTrailers(Status.INTERNAL), true);

        // Verify that the first was delivered.
        assertNotNull("message expected", listenerMessageQueue.poll());
        assertNull("no additional message expected", listenerMessageQueue.poll());

        // Now set the error status.
        Metadata trailers = Utils.convertTrailers(grpcResponseTrailers(Status.CANCELLED));
        stream().transportState().transportReportStatus(Status.CANCELLED, true, trailers);

        // Now allow the delivery of the second.
        stream().request(1);

        // Verify that the listener was only notified of the first message, not the second.
        assertNull("no additional message expected", listenerMessageQueue.poll());
        verify(listener).closed(eq(Status.CANCELLED), same(PROCESSED), eq(trailers));
    }

    @Test
    public void dataFrameWithEosShouldDeframeAndThenFail() {
        stream().transportState().setId(STREAM_ID);
        stream().request(1);

        // Receive headers first so that it's a valid GRPC response.
        stream().transportState().transportHeadersReceived(grpcResponseHeaders(), false);

        // Receive a DATA frame with EOS set.
        stream().transportState().transportDataReceived(simpleGrpcFrame(), true);

        // Verify that the message was delivered.
        assertNotNull("message expected", listenerMessageQueue.poll());
        assertNull("no additional message expected", listenerMessageQueue.poll());

        ArgumentCaptor<Status> captor = ArgumentCaptor.forClass(Status.class);
        verify(listener).closed(captor.capture(), same(PROCESSED), any(Metadata.class));
        assertEquals(Status.Code.INTERNAL, captor.getValue().getCode());
    }

    @Test
    public void setHttp2StreamShouldNotifyReady() {
        listener = mock(ClientStreamListener.class);

        stream = new NettyClientStream(new TransportStateImpl(handler, DEFAULT_MAX_MESSAGE_SIZE), methodDescriptor,
                new Metadata(), channel, AsciiString.of("localhost"), AsciiString.of("http"),
                AsciiString.of("agent"), StatsTraceContext.NOOP, transportTracer, CallOptions.DEFAULT);
        stream.start(listener);
        stream().transportState().setId(STREAM_ID);
        verify(listener, never()).onReady();
        assertFalse(stream.isReady());
        stream().transportState().setHttp2Stream(http2Stream);
        verify(listener).onReady();
        assertTrue(stream.isReady());
    }

    @Test
    public void removeUserAgentFromApplicationHeaders() {
        Metadata metadata = new Metadata();
        metadata.put(GrpcUtil.USER_AGENT_KEY, "bad agent");
        listener = mock(ClientStreamListener.class);
        Mockito.reset(writeQueue);
        ChannelPromise completedPromise = new DefaultChannelPromise(channel).setSuccess();
        when(writeQueue.enqueue(any(QueuedCommand.class), anyBoolean())).thenReturn(completedPromise);

        stream = new NettyClientStream(new TransportStateImpl(handler, DEFAULT_MAX_MESSAGE_SIZE), methodDescriptor,
                new Metadata(), channel, AsciiString.of("localhost"), AsciiString.of("http"),
                AsciiString.of("good agent"), StatsTraceContext.NOOP, transportTracer, CallOptions.DEFAULT);
        stream.start(listener);

        ArgumentCaptor<CreateStreamCommand> cmdCap = ArgumentCaptor.forClass(CreateStreamCommand.class);
        verify(writeQueue).enqueue(cmdCap.capture(), eq(false));
        assertThat(ImmutableListMultimap.copyOf(cmdCap.getValue().headers())).containsEntry(Utils.USER_AGENT,
                AsciiString.of("good agent"));
    }

    @Test
    public void getRequestSentThroughHeader() {
        // Creating a GET method
        MethodDescriptor<?, ?> descriptor = MethodDescriptor.<Void, Void>newBuilder()
                .setType(MethodDescriptor.MethodType.UNARY).setFullMethodName("testService/test")
                .setRequestMarshaller(marshaller).setResponseMarshaller(marshaller).setIdempotent(true)
                .setSafe(true).build();
        NettyClientStream stream = new NettyClientStream(new TransportStateImpl(handler, DEFAULT_MAX_MESSAGE_SIZE),
                descriptor, new Metadata(), channel, AsciiString.of("localhost"), AsciiString.of("http"),
                AsciiString.of("agent"), StatsTraceContext.NOOP, transportTracer, CallOptions.DEFAULT);
        stream.start(listener);
        stream.transportState().setId(STREAM_ID);
        stream.transportState().setHttp2Stream(http2Stream);

        byte[] msg = smallMessage();
        stream.writeMessage(new ByteArrayInputStream(msg));
        stream.flush();
        stream.halfClose();
        ArgumentCaptor<CreateStreamCommand> cmdCap = ArgumentCaptor.forClass(CreateStreamCommand.class);
        verify(writeQueue).enqueue(cmdCap.capture(), eq(true));
        ImmutableListMultimap<CharSequence, CharSequence> headers = ImmutableListMultimap
                .copyOf(cmdCap.getValue().headers());
        assertThat(headers).containsEntry(AsciiString.of(":method"), Utils.HTTP_GET_METHOD);
        assertThat(headers).containsEntry(AsciiString.of(":path"),
                AsciiString.of("/testService/test?" + BaseEncoding.base64().encode(msg)));
    }

    @Override
    protected NettyClientStream createStream() {
        when(handler.getWriteQueue()).thenReturn(writeQueue);
        NettyClientStream stream = new NettyClientStream(new TransportStateImpl(handler, DEFAULT_MAX_MESSAGE_SIZE),
                methodDescriptor, new Metadata(), channel, AsciiString.of("localhost"), AsciiString.of("http"),
                AsciiString.of("agent"), StatsTraceContext.NOOP, transportTracer, CallOptions.DEFAULT);
        stream.start(listener);
        stream.transportState().setHttp2Stream(http2Stream);
        reset(listener);
        return stream;
    }

    @Override
    protected void sendHeadersIfServer() {
    }

    @Override
    protected void closeStream() {
        stream().cancel(Status.CANCELLED);
    }

    private ByteBuf simpleGrpcFrame() {
        return Unpooled.wrappedBuffer(new byte[] { 0, 0, 0, 0, 2, 3, 14 });
    }

    private NettyClientStream stream() {
        return stream;
    }

    private Http2Headers grpcResponseHeaders() {
        return new DefaultHttp2Headers().status(STATUS_OK).set(CONTENT_TYPE_HEADER, CONTENT_TYPE_GRPC);
    }

    private Http2Headers grpcResponseTrailers(Status status) {
        Metadata trailers = new Metadata();
        trailers.put(InternalStatus.CODE_KEY, status);
        return Utils.convertTrailers(trailers, true);
    }

    private class TransportStateImpl extends NettyClientStream.TransportState {
        public TransportStateImpl(NettyClientHandler handler, int maxMessageSize) {
            super(handler, channel.eventLoop(), maxMessageSize, StatsTraceContext.NOOP, transportTracer,
                    "methodName");
        }

        @Override
        protected Status statusFromFailedFuture(ChannelFuture f) {
            return Utils.statusFromThrowable(f.cause());
        }
    }
}