Source code of io.grpc.netty.NettyServerStreamTest (NettyServerStreamTest.java)

Java tutorial

Introduction

Here is the source code for the class io.grpc.netty.NettyServerStreamTest (file NettyServerStreamTest.java), which contains the unit tests for NettyServerStream.

Source

/*
 * Copyright 2014 The gRPC Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.grpc.netty;

import static com.google.common.truth.Truth.assertThat;
import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE;
import static io.grpc.netty.NettyTestUtil.messageFrame;
import static org.junit.Assert.assertNull;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.same;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.verifyZeroInteractions;
import static org.mockito.Mockito.when;

import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ListMultimap;
import io.grpc.Attributes;
import io.grpc.Metadata;
import io.grpc.Status;
import io.grpc.internal.ServerStreamListener;
import io.grpc.internal.StatsTraceContext;
import io.grpc.internal.StreamListener;
import io.grpc.internal.TransportTracer;
import io.netty.buffer.EmptyByteBuf;
import io.netty.buffer.UnpooledByteBufAllocator;
import io.netty.handler.codec.http2.DefaultHttp2Headers;
import io.netty.util.AsciiString;
import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.util.LinkedList;
import java.util.Queue;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor;
import org.mockito.ArgumentMatchers;
import org.mockito.Mock;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;

/**
 * Unit tests for {@link NettyServerStream}.
 *
 * <p>Outbound writes made by the stream under test are observed through the mocked {@code writeQueue}
 * (the handler returns it from {@code getWriteQueue()}, see {@link #createStream()}); callbacks to the
 * application are observed through the mocked {@link ServerStreamListener}, whose delivered messages are
 * drained into {@code listenerMessageQueue}.
 */
@RunWith(JUnit4.class)
public class NettyServerStreamTest extends NettyStreamTestBase<NettyServerStream> {
    @Mock
    protected ServerStreamListener serverListener;

    @Mock
    private NettyServerHandler handler;

    /** Trailers passed to {@code close()}; JUnit 4 creates a fresh test instance (and thus a fresh value) per test. */
    private final Metadata trailers = new Metadata();

    /** Messages pulled from the {@link StreamListener.MessageProducer} handed to {@code messagesAvailable}. */
    private final Queue<InputStream> listenerMessageQueue = new LinkedList<>();

    @Before
    @Override
    public void setUp() {
        super.setUp();

        // Presumably super.setUp() built the stream via createStream(), which delivers onReady().
        // Verify that notification and then reset the mock so every test starts from a clean slate.
        verify(listener()).onReady();
        reset(listener());

        // Whenever the stream announces available messages, drain all of them into listenerMessageQueue
        // so tests can assert on what (if anything) was delivered.
        doAnswer(new Answer<Void>() {
            @Override
            public Void answer(InvocationOnMock invocation) throws Throwable {
                StreamListener.MessageProducer producer =
                        (StreamListener.MessageProducer) invocation.getArguments()[0];
                InputStream message;
                while ((message = producer.next()) != null) {
                    listenerMessageQueue.add(message);
                }
                return null;
            }
        }).when(serverListener).messagesAvailable(ArgumentMatchers.<StreamListener.MessageProducer>any());
    }

    @Test
    public void writeMessageShouldSendResponse() throws Exception {
        ListMultimap<CharSequence, CharSequence> expectedHeaders = ImmutableListMultimap
                .copyOf(new DefaultHttp2Headers().status(Utils.STATUS_OK).set(Utils.CONTENT_TYPE_HEADER,
                        Utils.CONTENT_TYPE_GRPC));

        stream().writeHeaders(new Metadata());
        verifyResponseHeaders(expectedHeaders, false);

        byte[] msg = smallMessage();
        stream().writeMessage(new ByteArrayInputStream(msg));
        stream().flush();

        verify(writeQueue).enqueue(
                eq(new SendGrpcFrameCommand(stream().transportState(), messageFrame(MESSAGE), false)), eq(true));
    }

    @Test
    public void writeHeadersShouldSendHeaders() throws Exception {
        Metadata headers = new Metadata();
        ListMultimap<CharSequence, CharSequence> expectedHeaders = ImmutableListMultimap
                .copyOf(Utils.convertServerHeaders(headers));

        stream().writeHeaders(headers);

        verifyResponseHeaders(expectedHeaders, false);
    }

    @Test
    public void closeBeforeClientHalfCloseShouldSucceed() throws Exception {
        stream().close(Status.OK, new Metadata());

        verifyResponseHeaders(expectedTrailers("0"), true);
        verifyZeroInteractions(serverListener);

        // Sending complete. Listener gets closed()
        stream().transportState().complete();

        verify(serverListener).closed(Status.OK);
        assertNull("no message expected", listenerMessageQueue.poll());
    }

    @Test
    public void closeWithErrorBeforeClientHalfCloseShouldSucceed() throws Exception {
        // Error is sent on wire and ends the stream
        stream().close(Status.CANCELLED, trailers);

        verifyResponseHeaders(expectedTrailers("1"), true);
        verifyZeroInteractions(serverListener);

        // Sending complete. Listener gets closed(). OK is expected even though CANCELLED was written to the
        // peer: presumably closed() reports the transport outcome (the write finished) rather than the RPC
        // status that was sent. NOTE(review): confirm against TransportState.complete().
        stream().transportState().complete();
        verify(serverListener).closed(Status.OK);
        assertNull("no message expected", listenerMessageQueue.poll());
    }

    @Test
    public void closeAfterClientHalfCloseShouldSucceed() throws Exception {
        // Client half-closes. Listener gets halfClosed()
        stream().transportState().inboundDataReceived(new EmptyByteBuf(UnpooledByteBufAllocator.DEFAULT), true);

        verify(serverListener).halfClosed();

        // Server closes. Status sent
        stream().close(Status.OK, trailers);
        assertNull("no message expected", listenerMessageQueue.poll());

        verifyResponseHeaders(expectedTrailers("0"), true);

        // Sending and receiving complete. Listener gets closed()
        stream().transportState().complete();
        verify(serverListener).closed(Status.OK);
        assertNull("no message expected", listenerMessageQueue.poll());
    }

    @Test
    public void abortStreamAndNotSendStatus() throws Exception {
        Status status = Status.INTERNAL.withCause(new Throwable());
        stream().transportState().transportReportStatus(status);
        verify(serverListener).closed(same(status));
        // The stream writes through writeQueue (see createStream()), so checking only the channel would
        // pass even if a status or frame were sent. Assert on both paths.
        verify(writeQueue, never()).enqueue(any(SendResponseHeadersCommand.class), ArgumentMatchers.anyBoolean());
        verify(writeQueue, never()).enqueue(any(SendGrpcFrameCommand.class), ArgumentMatchers.anyBoolean());
        verify(channel, never()).writeAndFlush(any(SendResponseHeadersCommand.class));
        verify(channel, never()).writeAndFlush(any(SendGrpcFrameCommand.class));
        assertNull("no message expected", listenerMessageQueue.poll());
    }

    @Test
    public void abortStreamAfterClientHalfCloseShouldCallClose() {
        Status status = Status.INTERNAL.withCause(new Throwable());
        // Client half-closes. Listener gets halfClosed()
        stream().transportState().inboundDataReceived(new EmptyByteBuf(UnpooledByteBufAllocator.DEFAULT), true);
        verify(serverListener).halfClosed();
        // Abort from the transport layer
        stream().transportState().transportReportStatus(status);
        verify(serverListener).closed(same(status));
        assertNull("no message expected", listenerMessageQueue.poll());
    }

    @Test
    public void emptyFramerShouldSendNoPayload() {
        stream().close(Status.OK, new Metadata());

        verifyResponseHeaders(expectedTrailers("0"), true);
    }

    @Test
    public void cancelStreamShouldSucceed() {
        stream().cancel(Status.DEADLINE_EXCEEDED);
        verify(writeQueue)
                .enqueue(new CancelServerStreamCommand(stream().transportState(), Status.DEADLINE_EXCEEDED), true);
    }

    /**
     * Builds the header map expected when the server closes the stream: HTTP status 200, the gRPC content
     * type, and the given {@code grpc-status} value.
     *
     * @param grpcStatus the numeric gRPC status code as a string, e.g. {@code "0"} for OK
     */
    private static ListMultimap<CharSequence, CharSequence> expectedTrailers(String grpcStatus) {
        return ImmutableListMultimap.copyOf(new DefaultHttp2Headers().status(new AsciiString("200"))
                .set(new AsciiString("content-type"), new AsciiString("application/grpc"))
                .set(new AsciiString("grpc-status"), new AsciiString(grpcStatus)));
    }

    /**
     * Verifies that {@code writeQueue.enqueue(..., true)} was called exactly once so far, and that the
     * captured {@link SendResponseHeadersCommand} targets this stream, carries exactly
     * {@code expectedHeaders}, and has the given end-of-stream flag.
     */
    private void verifyResponseHeaders(ListMultimap<CharSequence, CharSequence> expectedHeaders,
            boolean endOfStream) {
        ArgumentCaptor<SendResponseHeadersCommand> cmdCap = ArgumentCaptor
                .forClass(SendResponseHeadersCommand.class);
        verify(writeQueue).enqueue(cmdCap.capture(), eq(true));
        SendResponseHeadersCommand cmd = cmdCap.getValue();
        assertThat(cmd.stream()).isSameInstanceAs(stream().transportState());
        assertThat(ImmutableListMultimap.copyOf(cmd.headers())).containsExactlyEntriesIn(expectedHeaders);
        assertThat(cmd.endOfStream()).isEqualTo(endOfStream);
    }

    /**
     * Creates the stream under test: the mocked handler hands out {@code writeQueue}, the stream is wired to
     * {@code serverListener}, and allocation is signalled so the listener receives {@code onReady()}.
     * Fails if the listener saw anything other than {@code onReady()} during construction.
     */
    @Override
    protected NettyServerStream createStream() {
        when(handler.getWriteQueue()).thenReturn(writeQueue);
        StatsTraceContext statsTraceCtx = StatsTraceContext.NOOP;
        TransportTracer transportTracer = new TransportTracer();
        NettyServerStream.TransportState state = new NettyServerStream.TransportState(handler, channel.eventLoop(),
                http2Stream, DEFAULT_MAX_MESSAGE_SIZE, statsTraceCtx, transportTracer, "method");
        NettyServerStream stream = new NettyServerStream(channel, state, Attributes.EMPTY, "test-authority",
                statsTraceCtx, transportTracer);
        stream.transportState().setListener(serverListener);
        state.onStreamAllocated();
        verify(serverListener, atLeastOnce()).onReady();
        verifyNoMoreInteractions(serverListener);
        return stream;
    }

    /** Servers must send headers before messages; writes empty response headers. */
    @Override
    protected void sendHeadersIfServer() {
        stream().writeHeaders(new Metadata());
    }

    /** Closes the stream from the server side with {@code ABORTED} and empty trailers. */
    @Override
    protected void closeStream() {
        stream().close(Status.ABORTED, new Metadata());
    }

    @Override
    protected ServerStreamListener listener() {
        return serverListener;
    }

    @Override
    protected Queue<InputStream> listenerMessageQueue() {
        return listenerMessageQueue;
    }

    /** Returns the stream under test (the base class's {@code stream} field). */
    private NettyServerStream stream() {
        return stream;
    }
}