io.grpc.netty.NettyServerHandlerTest.java Source code

Java tutorial

Introduction

Here is the source code for io.grpc.netty.NettyServerHandlerTest.java

Source

/*
 * Copyright 2014 The gRPC Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.grpc.netty;

import static com.google.common.base.Charsets.UTF_8;
import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE;
import static io.grpc.internal.GrpcUtil.DEFAULT_SERVER_KEEPALIVE_TIMEOUT_NANOS;
import static io.grpc.internal.GrpcUtil.DEFAULT_SERVER_KEEPALIVE_TIME_NANOS;
import static io.grpc.netty.NettyServerBuilder.MAX_CONNECTION_AGE_GRACE_NANOS_INFINITE;
import static io.grpc.netty.NettyServerBuilder.MAX_CONNECTION_AGE_NANOS_DISABLED;
import static io.grpc.netty.NettyServerBuilder.MAX_CONNECTION_IDLE_NANOS_DISABLED;
import static io.grpc.netty.Utils.CONTENT_TYPE_GRPC;
import static io.grpc.netty.Utils.CONTENT_TYPE_HEADER;
import static io.grpc.netty.Utils.HTTP_METHOD;
import static io.grpc.netty.Utils.TE_HEADER;
import static io.grpc.netty.Utils.TE_TRAILERS;
import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_PRIORITY_WEIGHT;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertTrue;
import static org.mockito.AdditionalAnswers.delegatesTo;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.isA;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;

import com.google.common.io.ByteStreams;
import com.google.common.truth.Truth;
import io.grpc.Attributes;
import io.grpc.InternalStatus;
import io.grpc.Metadata;
import io.grpc.ServerStreamTracer;
import io.grpc.Status;
import io.grpc.Status.Code;
import io.grpc.StreamTracer;
import io.grpc.internal.GrpcUtil;
import io.grpc.internal.KeepAliveManager;
import io.grpc.internal.ServerStream;
import io.grpc.internal.ServerStreamListener;
import io.grpc.internal.ServerTransportListener;
import io.grpc.internal.StatsTraceContext;
import io.grpc.internal.StreamListener;
import io.grpc.internal.testing.TestServerStreamTracer;
import io.grpc.netty.GrpcHttp2HeadersUtils.GrpcHttp2ServerHeadersDecoder;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufUtil;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http2.DefaultHttp2Headers;
import io.netty.handler.codec.http2.Http2CodecUtil;
import io.netty.handler.codec.http2.Http2Error;
import io.netty.handler.codec.http2.Http2Headers;
import io.netty.handler.codec.http2.Http2LocalFlowController;
import io.netty.handler.codec.http2.Http2Settings;
import io.netty.handler.codec.http2.Http2Stream;
import io.netty.util.AsciiString;
import java.io.InputStream;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import java.util.Queue;
import java.util.concurrent.TimeUnit;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.DisableOnDebug;
import org.junit.rules.TestRule;
import org.junit.rules.Timeout;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;

/**
 * Unit tests for {@link NettyServerHandler}.
 */
@RunWith(JUnit4.class)
public class NettyServerHandlerTest extends NettyHandlerTestBase<NettyServerHandler> {

    @Rule
    public final TestRule globalTimeout = new DisableOnDebug(Timeout.seconds(10));

    private static final int STREAM_ID = 3;

    private static final AsciiString HTTP_FAKE_METHOD = AsciiString.of("FAKE");

    @Mock
    private ServerStreamListener streamListener;

    @Mock
    private ServerStreamTracer.Factory streamTracerFactory;

    private final ServerTransportListener transportListener = mock(ServerTransportListener.class,
            delegatesTo(new ServerTransportListenerImpl()));
    private final TestServerStreamTracer streamTracer = new TestServerStreamTracer();

    private NettyServerStream stream;
    private KeepAliveManager spyKeepAliveManager;

    final Queue<InputStream> streamListenerMessageQueue = new LinkedList<>();

    private int maxConcurrentStreams = Integer.MAX_VALUE;
    private int maxHeaderListSize = Integer.MAX_VALUE;
    private boolean permitKeepAliveWithoutCalls = true;
    private long permitKeepAliveTimeInNanos = 0;
    private long maxConnectionIdleInNanos = MAX_CONNECTION_IDLE_NANOS_DISABLED;
    private long maxConnectionAgeInNanos = MAX_CONNECTION_AGE_NANOS_DISABLED;
    private long maxConnectionAgeGraceInNanos = MAX_CONNECTION_AGE_GRACE_NANOS_INFINITE;
    private long keepAliveTimeInNanos = DEFAULT_SERVER_KEEPALIVE_TIME_NANOS;
    private long keepAliveTimeoutInNanos = DEFAULT_SERVER_KEEPALIVE_TIMEOUT_NANOS;

    private class ServerTransportListenerImpl implements ServerTransportListener {

        @Override
        public void streamCreated(ServerStream stream, String method, Metadata headers) {
            stream.setListener(streamListener);
        }

        @Override
        public Attributes transportReady(Attributes attributes) {
            return Attributes.EMPTY;
        }

        @Override
        public void transportTerminated() {
        }
    }

    @Before
    public void setUp() {
        MockitoAnnotations.initMocks(this);

        when(streamTracerFactory.newServerStreamTracer(anyString(), any(Metadata.class))).thenReturn(streamTracer);

        doAnswer(new Answer<Void>() {
            @Override
            public Void answer(InvocationOnMock invocation) throws Throwable {
                StreamListener.MessageProducer producer = (StreamListener.MessageProducer) invocation
                        .getArguments()[0];
                InputStream message;
                while ((message = producer.next()) != null) {
                    streamListenerMessageQueue.add(message);
                }
                return null;
            }
        }).when(streamListener).messagesAvailable(any(StreamListener.MessageProducer.class));
    }

    @Override
    protected void manualSetUp() throws Exception {
        assertNull("manualSetUp should not run more than once", handler());

        initChannel(new GrpcHttp2ServerHeadersDecoder(GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE));

        // replace the keepAliveManager with spyKeepAliveManager
        spyKeepAliveManager = mock(KeepAliveManager.class, delegatesTo(handler().getKeepAliveManagerForTest()));
        handler().setKeepAliveManagerForTest(spyKeepAliveManager);

        // Simulate receipt of the connection preface
        handler().handleProtocolNegotiationCompleted(Attributes.EMPTY, /*securityInfo=*/ null);
        channelRead(Http2CodecUtil.connectionPrefaceBuf());
        // Simulate receipt of initial remote settings.
        ByteBuf serializedSettings = serializeSettings(new Http2Settings());
        channelRead(serializedSettings);
    }

    @Test
    public void transportReadyDelayedUntilConnectionPreface() throws Exception {
        initChannel(new GrpcHttp2ServerHeadersDecoder(GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE));

        handler().handleProtocolNegotiationCompleted(Attributes.EMPTY, /*securityInfo=*/ null);
        verify(transportListener, never()).transportReady(any(Attributes.class));

        // Simulate receipt of the connection preface
        channelRead(Http2CodecUtil.connectionPrefaceBuf());
        channelRead(serializeSettings(new Http2Settings()));
        verify(transportListener).transportReady(any(Attributes.class));
    }

    @Test
    public void sendFrameShouldSucceed() throws Exception {
        manualSetUp();
        createStream();

        // Send a frame and verify that it was written.
        ChannelFuture future = enqueue(new SendGrpcFrameCommand(stream.transportState(), content(), false));
        assertTrue(future.isSuccess());
        verifyWrite().writeData(eq(ctx()), eq(STREAM_ID), eq(content()), eq(0), eq(false),
                any(ChannelPromise.class));
    }

    @Test
    public void streamTracerCreated() throws Exception {
        manualSetUp();
        createStream();

        verify(streamTracerFactory).newServerStreamTracer(eq("foo/bar"), any(Metadata.class));
        StatsTraceContext statsTraceCtx = stream.statsTraceContext();
        List<StreamTracer> tracers = statsTraceCtx.getTracersForTest();
        assertEquals(1, tracers.size());
        assertSame(streamTracer, tracers.get(0));
    }

    @Test
    public void inboundDataWithEndStreamShouldForwardToStreamListener() throws Exception {
        manualSetUp();
        inboundDataShouldForwardToStreamListener(true);
    }

    @Test
    public void inboundDataShouldForwardToStreamListener() throws Exception {
        manualSetUp();
        inboundDataShouldForwardToStreamListener(false);
    }

    private void inboundDataShouldForwardToStreamListener(boolean endStream) throws Exception {
        createStream();
        stream.request(1);

        // Create a data frame and then trigger the handler to read it.
        ByteBuf frame = grpcDataFrame(STREAM_ID, endStream, contentAsArray());
        channelRead(frame);
        verify(streamListener, atLeastOnce()).messagesAvailable(any(StreamListener.MessageProducer.class));
        InputStream message = streamListenerMessageQueue.poll();
        assertArrayEquals(ByteBufUtil.getBytes(content()), ByteStreams.toByteArray(message));
        message.close();
        assertNull("no additional message expected", streamListenerMessageQueue.poll());

        if (endStream) {
            verify(streamListener).halfClosed();
        }
        verify(streamListener, atLeastOnce()).onReady();
        verifyNoMoreInteractions(streamListener);
    }

    @Test
    public void clientHalfCloseShouldForwardToStreamListener() throws Exception {
        manualSetUp();
        createStream();
        stream.request(1);

        channelRead(emptyGrpcFrame(STREAM_ID, true));

        verify(streamListener, atLeastOnce()).messagesAvailable(any(StreamListener.MessageProducer.class));
        InputStream message = streamListenerMessageQueue.poll();
        assertArrayEquals(new byte[0], ByteStreams.toByteArray(message));
        assertNull("no additional message expected", streamListenerMessageQueue.poll());
        verify(streamListener).halfClosed();
        verify(streamListener, atLeastOnce()).onReady();
        verifyNoMoreInteractions(streamListener);
    }

    @Test
    public void clientCancelShouldForwardToStreamListener() throws Exception {
        manualSetUp();
        createStream();

        channelRead(rstStreamFrame(STREAM_ID, (int) Http2Error.CANCEL.code()));

        ArgumentCaptor<Status> statusCap = ArgumentCaptor.forClass(Status.class);
        verify(streamListener).closed(statusCap.capture());
        assertEquals(Code.CANCELLED, statusCap.getValue().getCode());
        Truth.assertThat(statusCap.getValue().getDescription()).contains("RST_STREAM");
        verify(streamListener, atLeastOnce()).onReady();
        assertNull("no messages expected", streamListenerMessageQueue.poll());
    }

    @Test
    public void streamErrorShouldNotCloseChannel() throws Exception {
        manualSetUp();
        createStream();
        stream.request(1);

        // When a DATA frame is read, throw an exception. It will be converted into an
        // Http2StreamException.
        RuntimeException e = new RuntimeException("Fake Exception");
        doThrow(e).when(streamListener).messagesAvailable(any(StreamListener.MessageProducer.class));

        // Read a DATA frame to trigger the exception.
        channelRead(emptyGrpcFrame(STREAM_ID, true));

        // Verify that the channel was NOT closed.
        assertTrue(channel().isOpen());

        // Verify the stream was closed.
        ArgumentCaptor<Status> captor = ArgumentCaptor.forClass(Status.class);
        verify(streamListener).closed(captor.capture());
        assertEquals(e, captor.getValue().asException().getCause());
        assertEquals(Code.UNKNOWN, captor.getValue().getCode());
    }

    @Test
    public void closeShouldGracefullyCloseChannel() throws Exception {
        manualSetUp();
        handler().close(ctx(), newPromise());

        verifyWrite().writeGoAway(eq(ctx()), eq(Integer.MAX_VALUE), eq(Http2Error.NO_ERROR.code()),
                isA(ByteBuf.class), any(ChannelPromise.class));
        verifyWrite().writePing(eq(ctx()), eq(false), eq(NettyServerHandler.GRACEFUL_SHUTDOWN_PING),
                isA(ChannelPromise.class));
        channelRead(pingFrame(/*ack=*/ true, NettyServerHandler.GRACEFUL_SHUTDOWN_PING));

        verifyWrite().writeGoAway(eq(ctx()), eq(0), eq(Http2Error.NO_ERROR.code()), isA(ByteBuf.class),
                any(ChannelPromise.class));

        // Verify that the channel was closed.
        assertFalse(channel().isOpen());
    }

    @Test
    public void exceptionCaughtShouldCloseConnection() throws Exception {
        manualSetUp();
        handler().exceptionCaught(ctx(), new RuntimeException("fake exception"));

        // TODO(nmittler): EmbeddedChannel does not currently invoke the channelInactive processing,
        // so exceptionCaught() will not close streams properly in this test.
        // Once https://github.com/netty/netty/issues/4316 is resolved, we should also verify that
        // any open streams are closed properly.
        assertFalse(channel().isOpen());
    }

    @Test
    public void channelInactiveShouldCloseStreams() throws Exception {
        manualSetUp();
        createStream();
        handler().channelInactive(ctx());
        ArgumentCaptor<Status> captor = ArgumentCaptor.forClass(Status.class);
        verify(streamListener).closed(captor.capture());
        assertFalse(captor.getValue().isOk());
    }

    @Test
    public void shouldAdvertiseMaxConcurrentStreams() throws Exception {
        maxConcurrentStreams = 314;
        manualSetUp();

        ArgumentCaptor<Http2Settings> captor = ArgumentCaptor.forClass(Http2Settings.class);
        verifyWrite().writeSettings(any(ChannelHandlerContext.class), captor.capture(), any(ChannelPromise.class));

        assertEquals(maxConcurrentStreams, captor.getValue().maxConcurrentStreams().intValue());
    }

    @Test
    public void shouldAdvertiseMaxHeaderListSize() throws Exception {
        maxHeaderListSize = 123;
        manualSetUp();

        ArgumentCaptor<Http2Settings> captor = ArgumentCaptor.forClass(Http2Settings.class);
        verifyWrite().writeSettings(any(ChannelHandlerContext.class), captor.capture(), any(ChannelPromise.class));

        assertEquals(maxHeaderListSize, captor.getValue().maxHeaderListSize().intValue());
    }

    @Test
    public void connectionWindowShouldBeOverridden() throws Exception {
        flowControlWindow = 1048576; // 1MiB
        manualSetUp();

        Http2Stream connectionStream = connection().connectionStream();
        Http2LocalFlowController localFlowController = connection().local().flowController();
        int actualInitialWindowSize = localFlowController.initialWindowSize(connectionStream);
        int actualWindowSize = localFlowController.windowSize(connectionStream);
        assertEquals(flowControlWindow, actualWindowSize);
        assertEquals(flowControlWindow, actualInitialWindowSize);
    }

    @Test
    public void cancelShouldSendRstStream() throws Exception {
        manualSetUp();
        createStream();
        enqueue(new CancelServerStreamCommand(stream.transportState(), Status.DEADLINE_EXCEEDED));
        verifyWrite().writeRstStream(eq(ctx()), eq(stream.transportState().id()), eq(Http2Error.CANCEL.code()),
                any(ChannelPromise.class));
    }

    @Test
    public void headersWithInvalidContentTypeShouldFail() throws Exception {
        manualSetUp();
        Http2Headers headers = new DefaultHttp2Headers().method(HTTP_METHOD)
                .set(CONTENT_TYPE_HEADER, new AsciiString("application/bad", UTF_8)).set(TE_HEADER, TE_TRAILERS)
                .path(new AsciiString("/foo/bar"));
        ByteBuf headersFrame = headersFrame(STREAM_ID, headers);
        channelRead(headersFrame);
        Http2Headers responseHeaders = new DefaultHttp2Headers()
                .set(InternalStatus.CODE_KEY.name(), String.valueOf(Code.INTERNAL.value()))
                .set(InternalStatus.MESSAGE_KEY.name(), "Content-Type 'application/bad' is not supported")
                .status("" + 415).set(CONTENT_TYPE_HEADER, "text/plain; encoding=utf-8");

        verifyWrite().writeHeaders(eq(ctx()), eq(STREAM_ID), eq(responseHeaders), eq(0),
                eq(DEFAULT_PRIORITY_WEIGHT), eq(false), eq(0), eq(false), any(ChannelPromise.class));
    }

    @Test
    public void headersWithInvalidMethodShouldFail() throws Exception {
        manualSetUp();
        Http2Headers headers = new DefaultHttp2Headers().method(HTTP_FAKE_METHOD)
                .set(CONTENT_TYPE_HEADER, CONTENT_TYPE_GRPC).path(new AsciiString("/foo/bar"));
        ByteBuf headersFrame = headersFrame(STREAM_ID, headers);
        channelRead(headersFrame);
        Http2Headers responseHeaders = new DefaultHttp2Headers()
                .set(InternalStatus.CODE_KEY.name(), String.valueOf(Code.INTERNAL.value()))
                .set(InternalStatus.MESSAGE_KEY.name(), "Method 'FAKE' is not supported").status("" + 405)
                .set(CONTENT_TYPE_HEADER, "text/plain; encoding=utf-8");

        verifyWrite().writeHeaders(eq(ctx()), eq(STREAM_ID), eq(responseHeaders), eq(0),
                eq(DEFAULT_PRIORITY_WEIGHT), eq(false), eq(0), eq(false), any(ChannelPromise.class));
    }

    @Test
    public void headersWithMissingPathShouldFail() throws Exception {
        manualSetUp();
        Http2Headers headers = new DefaultHttp2Headers().method(HTTP_METHOD).set(CONTENT_TYPE_HEADER,
                CONTENT_TYPE_GRPC);
        ByteBuf headersFrame = headersFrame(STREAM_ID, headers);
        channelRead(headersFrame);
        Http2Headers responseHeaders = new DefaultHttp2Headers()
                .set(InternalStatus.CODE_KEY.name(), String.valueOf(Code.UNIMPLEMENTED.value()))
                .set(InternalStatus.MESSAGE_KEY.name(), "Expected path but is missing").status("" + 404)
                .set(CONTENT_TYPE_HEADER, "text/plain; encoding=utf-8");

        verifyWrite().writeHeaders(eq(ctx()), eq(STREAM_ID), eq(responseHeaders), eq(0),
                eq(DEFAULT_PRIORITY_WEIGHT), eq(false), eq(0), eq(false), any(ChannelPromise.class));
    }

    @Test
    public void headersWithInvalidPathShouldFail() throws Exception {
        manualSetUp();
        Http2Headers headers = new DefaultHttp2Headers().method(HTTP_METHOD)
                .set(CONTENT_TYPE_HEADER, CONTENT_TYPE_GRPC).path(new AsciiString("foo/bar"));
        ByteBuf headersFrame = headersFrame(STREAM_ID, headers);
        channelRead(headersFrame);
        Http2Headers responseHeaders = new DefaultHttp2Headers()
                .set(InternalStatus.CODE_KEY.name(), String.valueOf(Code.UNIMPLEMENTED.value()))
                .set(InternalStatus.MESSAGE_KEY.name(), "Expected path to start with /: foo/bar").status("" + 404)
                .set(CONTENT_TYPE_HEADER, "text/plain; encoding=utf-8");

        verifyWrite().writeHeaders(eq(ctx()), eq(STREAM_ID), eq(responseHeaders), eq(0),
                eq(DEFAULT_PRIORITY_WEIGHT), eq(false), eq(0), eq(false), any(ChannelPromise.class));
    }

    @Test
    public void headersSupportExtensionContentType() throws Exception {
        manualSetUp();
        Http2Headers headers = new DefaultHttp2Headers().method(HTTP_METHOD)
                .set(CONTENT_TYPE_HEADER, new AsciiString("application/grpc+json", UTF_8))
                .set(TE_HEADER, TE_TRAILERS).path(new AsciiString("/foo/bar"));
        ByteBuf headersFrame = headersFrame(STREAM_ID, headers);
        channelRead(headersFrame);

        ArgumentCaptor<NettyServerStream> streamCaptor = ArgumentCaptor.forClass(NettyServerStream.class);
        ArgumentCaptor<String> methodCaptor = ArgumentCaptor.forClass(String.class);
        verify(transportListener).streamCreated(streamCaptor.capture(), methodCaptor.capture(),
                any(Metadata.class));
        stream = streamCaptor.getValue();
    }

    @Test
    public void keepAliveManagerOnDataReceived_headersRead() throws Exception {
        manualSetUp();
        ByteBuf headersFrame = headersFrame(STREAM_ID, new DefaultHttp2Headers());
        channelRead(headersFrame);

        verify(spyKeepAliveManager).onDataReceived();
        verify(spyKeepAliveManager, never()).onTransportTermination();
    }

    @Test
    public void keepAliveManagerOnDataReceived_dataRead() throws Exception {
        manualSetUp();
        createStream();
        verify(spyKeepAliveManager).onDataReceived(); // received headers

        channelRead(grpcDataFrame(STREAM_ID, false, contentAsArray()));

        verify(spyKeepAliveManager, times(2)).onDataReceived();

        channelRead(grpcDataFrame(STREAM_ID, false, contentAsArray()));

        verify(spyKeepAliveManager, times(3)).onDataReceived();
        verify(spyKeepAliveManager, never()).onTransportTermination();
    }

    @Test
    public void keepAliveManagerOnDataReceived_rstStreamRead() throws Exception {
        manualSetUp();
        createStream();
        verify(spyKeepAliveManager).onDataReceived(); // received headers

        channelRead(rstStreamFrame(STREAM_ID, (int) Http2Error.CANCEL.code()));

        verify(spyKeepAliveManager, times(2)).onDataReceived();
        verify(spyKeepAliveManager, never()).onTransportTermination();
    }

    @Test
    public void keepAliveManagerOnDataReceived_pingRead() throws Exception {
        manualSetUp();
        channelRead(pingFrame(false /* isAck */, 1234L));

        verify(spyKeepAliveManager).onDataReceived();
        verify(spyKeepAliveManager, never()).onTransportTermination();
    }

    @Test
    public void keepAliveManagerOnDataReceived_pingActRead() throws Exception {
        manualSetUp();
        channelRead(pingFrame(true /* isAck */, 1234L));

        verify(spyKeepAliveManager).onDataReceived();
        verify(spyKeepAliveManager, never()).onTransportTermination();
    }

    @Test
    public void keepAliveManagerOnTransportTermination() throws Exception {
        manualSetUp();
        handler().channelInactive(handler().ctx());

        verify(spyKeepAliveManager).onTransportTermination();
    }

    @Test
    public void keepAliveManager_pingSent() throws Exception {
        keepAliveTimeInNanos = TimeUnit.MILLISECONDS.toNanos(10L);
        keepAliveTimeoutInNanos = TimeUnit.MINUTES.toNanos(30L);
        manualSetUp();

        assertEquals(0, transportTracer.getStats().keepAlivesSent);
        fakeClock().forwardNanos(keepAliveTimeInNanos);
        assertEquals(1, transportTracer.getStats().keepAlivesSent);

        verifyWrite().writePing(eq(ctx()), eq(false), eq(0xDEADL), any(ChannelPromise.class));

        spyKeepAliveManager.onDataReceived();
        fakeClock().forwardTime(10L, TimeUnit.MILLISECONDS);

        verifyWrite(times(2)).writePing(eq(ctx()), eq(false), eq(0xDEADL), any(ChannelPromise.class));
        assertTrue(channel().isOpen());
    }

    @Test
    public void keepAliveManager_pingTimeout() throws Exception {
        keepAliveTimeInNanos = 123L /* nanoseconds */;
        keepAliveTimeoutInNanos = 456L /* nanoseconds */;
        manualSetUp();

        fakeClock().forwardNanos(keepAliveTimeInNanos);

        assertTrue(channel().isOpen());

        fakeClock().forwardNanos(keepAliveTimeoutInNanos);

        assertTrue(!channel().isOpen());
    }

    @Test
    public void keepAliveEnforcer_enforcesPings() throws Exception {
        permitKeepAliveWithoutCalls = false;
        permitKeepAliveTimeInNanos = TimeUnit.HOURS.toNanos(1);
        manualSetUp();

        for (int i = 0; i < KeepAliveEnforcer.MAX_PING_STRIKES + 1; i++) {
            channelRead(pingFrame(false /* isAck */, 1L));
        }
        verifyWrite().writeGoAway(eq(ctx()), eq(0), eq(Http2Error.ENHANCE_YOUR_CALM.code()), any(ByteBuf.class),
                any(ChannelPromise.class));
        assertFalse(channel().isActive());
    }

    @Test
    public void keepAliveEnforcer_sendingDataResetsCounters() throws Exception {
        permitKeepAliveWithoutCalls = false;
        permitKeepAliveTimeInNanos = TimeUnit.HOURS.toNanos(1);
        manualSetUp();

        createStream();
        Http2Headers headers = Utils.convertServerHeaders(new Metadata());
        ChannelFuture future = enqueue(SendResponseHeadersCommand.createHeaders(stream.transportState(), headers));
        future.get();
        for (int i = 0; i < 10; i++) {
            future = enqueue(new SendGrpcFrameCommand(stream.transportState(), content().retainedSlice(), false));
            future.get();
            channel().releaseOutbound();
            channelRead(pingFrame(false /* isAck */, 1L));
        }
        verifyWrite(never()).writeGoAway(eq(ctx()), eq(STREAM_ID), eq(Http2Error.ENHANCE_YOUR_CALM.code()),
                any(ByteBuf.class), any(ChannelPromise.class));
    }

    @Test
    public void keepAliveEnforcer_initialIdle() throws Exception {
        permitKeepAliveWithoutCalls = false;
        permitKeepAliveTimeInNanos = 0;
        manualSetUp();

        for (int i = 0; i < KeepAliveEnforcer.MAX_PING_STRIKES + 1; i++) {
            channelRead(pingFrame(false /* isAck */, 1L));
        }
        verifyWrite().writeGoAway(eq(ctx()), eq(0), eq(Http2Error.ENHANCE_YOUR_CALM.code()), any(ByteBuf.class),
                any(ChannelPromise.class));
        assertFalse(channel().isActive());
    }

    @Test
    public void keepAliveEnforcer_noticesActive() throws Exception {
        permitKeepAliveWithoutCalls = false;
        permitKeepAliveTimeInNanos = 0;
        manualSetUp();

        createStream();
        for (int i = 0; i < 10; i++) {
            channelRead(pingFrame(false /* isAck */, 1L));
        }
        verifyWrite(never()).writeGoAway(eq(ctx()), eq(STREAM_ID), eq(Http2Error.ENHANCE_YOUR_CALM.code()),
                any(ByteBuf.class), any(ChannelPromise.class));
    }

    @Test
    public void keepAliveEnforcer_noticesInactive() throws Exception {
        permitKeepAliveWithoutCalls = false;
        permitKeepAliveTimeInNanos = 0;
        manualSetUp();

        createStream();
        channelRead(rstStreamFrame(STREAM_ID, (int) Http2Error.CANCEL.code()));
        for (int i = 0; i < KeepAliveEnforcer.MAX_PING_STRIKES + 1; i++) {
            channelRead(pingFrame(false /* isAck */, 1L));
        }
        verifyWrite().writeGoAway(eq(ctx()), eq(STREAM_ID), eq(Http2Error.ENHANCE_YOUR_CALM.code()),
                any(ByteBuf.class), any(ChannelPromise.class));
        assertFalse(channel().isActive());
    }

    @Test
    public void noGoAwaySentBeforeMaxConnectionIdleReached() throws Exception {
        maxConnectionIdleInNanos = TimeUnit.MINUTES.toNanos(30L);
        manualSetUp();

        fakeClock().forwardTime(20, TimeUnit.MINUTES);

        // GO_AWAY not sent yet
        verifyWrite(never()).writeGoAway(any(ChannelHandlerContext.class), anyInt(), anyLong(), any(ByteBuf.class),
                any(ChannelPromise.class));
        assertTrue(channel().isOpen());
    }

    @Test
    public void maxConnectionIdle_goAwaySent_pingAck() throws Exception {
        maxConnectionIdleInNanos = TimeUnit.MILLISECONDS.toNanos(10L);
        manualSetUp();
        assertTrue(channel().isOpen());

        fakeClock().forwardNanos(maxConnectionIdleInNanos);

        // first GO_AWAY sent
        verifyWrite().writeGoAway(eq(ctx()), eq(Integer.MAX_VALUE), eq(Http2Error.NO_ERROR.code()),
                any(ByteBuf.class), any(ChannelPromise.class));
        // ping sent
        verifyWrite().writePing(eq(ctx()), eq(false), eq(0x97ACEF001L), any(ChannelPromise.class));
        verifyWrite(never()).writeGoAway(eq(ctx()), eq(0), eq(Http2Error.NO_ERROR.code()), any(ByteBuf.class),
                any(ChannelPromise.class));

        channelRead(pingFrame(true /* isAck */, 0xDEADL)); // irrelevant ping Ack
        verifyWrite(never()).writeGoAway(eq(ctx()), eq(0), eq(Http2Error.NO_ERROR.code()), any(ByteBuf.class),
                any(ChannelPromise.class));
        assertTrue(channel().isOpen());

        channelRead(pingFrame(true /* isAck */, 0x97ACEF001L));

        // second GO_AWAY sent
        verifyWrite().writeGoAway(eq(ctx()), eq(0), eq(Http2Error.NO_ERROR.code()), any(ByteBuf.class),
                any(ChannelPromise.class));
        // channel closed
        assertTrue(!channel().isOpen());
    }

    @Test
    public void maxConnectionIdle_goAwaySent_pingTimeout() throws Exception {
        maxConnectionIdleInNanos = TimeUnit.MILLISECONDS.toNanos(10L);
        manualSetUp();
        assertTrue(channel().isOpen());

        fakeClock().forwardNanos(maxConnectionIdleInNanos);

        // first GO_AWAY sent
        verifyWrite().writeGoAway(eq(ctx()), eq(Integer.MAX_VALUE), eq(Http2Error.NO_ERROR.code()),
                any(ByteBuf.class), any(ChannelPromise.class));
        // ping sent
        verifyWrite().writePing(eq(ctx()), eq(false), eq(0x97ACEF001L), any(ChannelPromise.class));
        verifyWrite(never()).writeGoAway(eq(ctx()), eq(0), eq(Http2Error.NO_ERROR.code()), any(ByteBuf.class),
                any(ChannelPromise.class));
        assertTrue(channel().isOpen());

        fakeClock().forwardTime(10, TimeUnit.SECONDS);

        // second GO_AWAY sent
        verifyWrite().writeGoAway(eq(ctx()), eq(0), eq(Http2Error.NO_ERROR.code()), any(ByteBuf.class),
                any(ChannelPromise.class));
        // channel closed
        assertTrue(!channel().isOpen());
    }

    @Test
    public void maxConnectionIdle_activeThenRst_pingAck() throws Exception {
        maxConnectionIdleInNanos = TimeUnit.MILLISECONDS.toNanos(10L);
        manualSetUp();
        createStream();

        fakeClock().forwardNanos(maxConnectionIdleInNanos);

        // GO_AWAY not sent when active
        verifyWrite(never()).writeGoAway(any(ChannelHandlerContext.class), anyInt(), anyLong(), any(ByteBuf.class),
                any(ChannelPromise.class));
        assertTrue(channel().isOpen());

        channelRead(rstStreamFrame(STREAM_ID, (int) Http2Error.CANCEL.code()));

        fakeClock().forwardNanos(maxConnectionIdleInNanos);

        // first GO_AWAY sent
        verifyWrite().writeGoAway(eq(ctx()), eq(Integer.MAX_VALUE), eq(Http2Error.NO_ERROR.code()),
                any(ByteBuf.class), any(ChannelPromise.class));
        // ping sent
        verifyWrite().writePing(eq(ctx()), eq(false), eq(0x97ACEF001L), any(ChannelPromise.class));
        verifyWrite(never()).writeGoAway(eq(ctx()), eq(STREAM_ID), eq(Http2Error.NO_ERROR.code()),
                any(ByteBuf.class), any(ChannelPromise.class));
        assertTrue(channel().isOpen());

        fakeClock().forwardTime(10, TimeUnit.SECONDS);

        // second GO_AWAY sent
        verifyWrite().writeGoAway(eq(ctx()), eq(STREAM_ID), eq(Http2Error.NO_ERROR.code()), any(ByteBuf.class),
                any(ChannelPromise.class));
        // channel closed
        assertTrue(!channel().isOpen());
    }

    @Test
    public void maxConnectionIdle_activeThenRst_pingTimeoutk() throws Exception {
        maxConnectionIdleInNanos = TimeUnit.MILLISECONDS.toNanos(10L);
        manualSetUp();
        createStream();

        fakeClock().forwardNanos(maxConnectionIdleInNanos);

        // GO_AWAY not sent when active
        verifyWrite(never()).writeGoAway(any(ChannelHandlerContext.class), anyInt(), anyLong(), any(ByteBuf.class),
                any(ChannelPromise.class));
        assertTrue(channel().isOpen());

        channelRead(rstStreamFrame(STREAM_ID, (int) Http2Error.CANCEL.code()));

        fakeClock().forwardNanos(maxConnectionIdleInNanos);

        // first GO_AWAY sent
        verifyWrite().writeGoAway(eq(ctx()), eq(Integer.MAX_VALUE), eq(Http2Error.NO_ERROR.code()),
                any(ByteBuf.class), any(ChannelPromise.class));
        // ping sent
        verifyWrite().writePing(eq(ctx()), eq(false), eq(0x97ACEF001L), any(ChannelPromise.class));
        verifyWrite(never()).writeGoAway(eq(ctx()), eq(STREAM_ID), eq(Http2Error.NO_ERROR.code()),
                any(ByteBuf.class), any(ChannelPromise.class));
        assertTrue(channel().isOpen());

        channelRead(pingFrame(true /* isAck */, 0x97ACEF001L));

        // second GO_AWAY sent
        verifyWrite().writeGoAway(eq(ctx()), eq(STREAM_ID), eq(Http2Error.NO_ERROR.code()), any(ByteBuf.class),
                any(ChannelPromise.class));
        // channel closed
        assertTrue(!channel().isOpen());
    }

    @Test
    public void noGoAwaySentBeforeMaxConnectionAgeReached() throws Exception {
        maxConnectionAgeInNanos = TimeUnit.MINUTES.toNanos(30L);
        manualSetUp();

        fakeClock().forwardTime(20, TimeUnit.MINUTES);

        // GO_AWAY not sent yet
        verifyWrite(never()).writeGoAway(any(ChannelHandlerContext.class), anyInt(), anyLong(), any(ByteBuf.class),
                any(ChannelPromise.class));
        assertTrue(channel().isOpen());
    }

    @Test
    public void maxConnectionAge_goAwaySent_pingAck() throws Exception {

        maxConnectionAgeInNanos = TimeUnit.MILLISECONDS.toNanos(10L);
        manualSetUp();
        assertTrue(channel().isOpen());

        fakeClock().forwardNanos(maxConnectionAgeInNanos);

        // first GO_AWAY sent
        verifyWrite().writeGoAway(eq(ctx()), eq(Integer.MAX_VALUE), eq(Http2Error.NO_ERROR.code()),
                any(ByteBuf.class), any(ChannelPromise.class));
        // ping sent
        verifyWrite().writePing(eq(ctx()), eq(false), eq(0x97ACEF001L), any(ChannelPromise.class));
        verifyWrite(never()).writeGoAway(eq(ctx()), eq(0), eq(Http2Error.NO_ERROR.code()), any(ByteBuf.class),
                any(ChannelPromise.class));

        channelRead(pingFrame(true /* isAck */, 0xDEADL)); // irrelevant ping Ack
        verifyWrite(never()).writeGoAway(eq(ctx()), eq(0), eq(Http2Error.NO_ERROR.code()), any(ByteBuf.class),
                any(ChannelPromise.class));
        assertTrue(channel().isOpen());

        channelRead(pingFrame(true /* isAck */, 0x97ACEF001L));

        // second GO_AWAY sent
        verifyWrite().writeGoAway(eq(ctx()), eq(0), eq(Http2Error.NO_ERROR.code()), any(ByteBuf.class),
                any(ChannelPromise.class));
        // channel closed
        assertTrue(!channel().isOpen());
    }

    @Test
    public void maxConnectionAge_goAwaySent_pingTimeout() throws Exception {

        maxConnectionAgeInNanos = TimeUnit.MILLISECONDS.toNanos(10L);
        manualSetUp();
        assertTrue(channel().isOpen());

        fakeClock().forwardNanos(maxConnectionAgeInNanos);

        // first GO_AWAY sent
        verifyWrite().writeGoAway(eq(ctx()), eq(Integer.MAX_VALUE), eq(Http2Error.NO_ERROR.code()),
                any(ByteBuf.class), any(ChannelPromise.class));
        // ping sent
        verifyWrite().writePing(eq(ctx()), eq(false), eq(0x97ACEF001L), any(ChannelPromise.class));
        verifyWrite(never()).writeGoAway(eq(ctx()), eq(0), eq(Http2Error.NO_ERROR.code()), any(ByteBuf.class),
                any(ChannelPromise.class));
        assertTrue(channel().isOpen());

        fakeClock().forwardTime(10, TimeUnit.SECONDS);

        // second GO_AWAY sent
        verifyWrite().writeGoAway(eq(ctx()), eq(0), eq(Http2Error.NO_ERROR.code()), any(ByteBuf.class),
                any(ChannelPromise.class));
        // channel closed
        assertTrue(!channel().isOpen());
    }

    @Test
    public void maxConnectionAgeGrace_channelStillOpenDuringGracePeriod() throws Exception {
        maxConnectionAgeInNanos = TimeUnit.MILLISECONDS.toNanos(10L);
        maxConnectionAgeGraceInNanos = TimeUnit.MINUTES.toNanos(30L);
        manualSetUp();
        createStream();

        fakeClock().forwardNanos(maxConnectionAgeInNanos);

        // first GO_AWAY sent
        verifyWrite().writeGoAway(eq(ctx()), eq(Integer.MAX_VALUE), eq(Http2Error.NO_ERROR.code()),
                any(ByteBuf.class), any(ChannelPromise.class));
        // ping sent
        verifyWrite().writePing(eq(ctx()), eq(false), eq(0x97ACEF001L), any(ChannelPromise.class));
        verifyWrite(never()).writeGoAway(eq(ctx()), eq(STREAM_ID), eq(Http2Error.NO_ERROR.code()),
                any(ByteBuf.class), any(ChannelPromise.class));

        fakeClock().forwardTime(20, TimeUnit.MINUTES);

        // second GO_AWAY sent
        verifyWrite().writeGoAway(eq(ctx()), eq(STREAM_ID), eq(Http2Error.NO_ERROR.code()), any(ByteBuf.class),
                any(ChannelPromise.class));
        // channel not closed yet
        assertTrue(channel().isOpen());
    }

    @Test
    public void maxConnectionAgeGrace_channelClosedAfterGracePeriod_withPingTimeout() throws Exception {
        maxConnectionAgeInNanos = TimeUnit.MILLISECONDS.toNanos(10L);
        maxConnectionAgeGraceInNanos = TimeUnit.MINUTES.toNanos(30L); // greater than ping timeout
        manualSetUp();
        createStream();

        fakeClock().forwardNanos(maxConnectionAgeInNanos);

        // first GO_AWAY sent
        verifyWrite().writeGoAway(eq(ctx()), eq(Integer.MAX_VALUE), eq(Http2Error.NO_ERROR.code()),
                any(ByteBuf.class), any(ChannelPromise.class));
        // ping sent
        verifyWrite().writePing(eq(ctx()), eq(false), eq(0x97ACEF001L), any(ChannelPromise.class));
        verifyWrite(never()).writeGoAway(eq(ctx()), eq(STREAM_ID), eq(Http2Error.NO_ERROR.code()),
                any(ByteBuf.class), any(ChannelPromise.class));

        fakeClock().forwardNanos(TimeUnit.SECONDS.toNanos(10));

        // second GO_AWAY sent
        verifyWrite().writeGoAway(eq(ctx()), eq(STREAM_ID), eq(Http2Error.NO_ERROR.code()), any(ByteBuf.class),
                any(ChannelPromise.class));

        fakeClock().forwardNanos(maxConnectionAgeGraceInNanos - 2);

        assertTrue(channel().isOpen());

        fakeClock().forwardTime(2, TimeUnit.MILLISECONDS);

        // channel closed
        assertTrue(!channel().isOpen());
    }

    @Test
    public void maxConnectionAgeGrace_channelClosedAfterGracePeriod_withPingAck() throws Exception {
        maxConnectionAgeInNanos = TimeUnit.MILLISECONDS.toNanos(10L);
        maxConnectionAgeGraceInNanos = TimeUnit.MINUTES.toNanos(30L); // greater than ping timeout
        manualSetUp();
        createStream();

        fakeClock().forwardNanos(maxConnectionAgeInNanos);

        // first GO_AWAY sent
        verifyWrite().writeGoAway(eq(ctx()), eq(Integer.MAX_VALUE), eq(Http2Error.NO_ERROR.code()),
                any(ByteBuf.class), any(ChannelPromise.class));
        // ping sent
        verifyWrite().writePing(eq(ctx()), eq(false), eq(0x97ACEF001L), any(ChannelPromise.class));
        verifyWrite(never()).writeGoAway(eq(ctx()), eq(STREAM_ID), eq(Http2Error.NO_ERROR.code()),
                any(ByteBuf.class), any(ChannelPromise.class));

        long pingRoundTripMillis = 100; // less than ping timeout
        fakeClock().forwardTime(pingRoundTripMillis, TimeUnit.MILLISECONDS);
        channelRead(pingFrame(true /* isAck */, 0x97ACEF001L));

        // second GO_AWAY sent
        verifyWrite().writeGoAway(eq(ctx()), eq(STREAM_ID), eq(Http2Error.NO_ERROR.code()), any(ByteBuf.class),
                any(ChannelPromise.class));

        fakeClock().forwardNanos(maxConnectionAgeGraceInNanos - TimeUnit.MILLISECONDS.toNanos(2));

        assertTrue(channel().isOpen());

        fakeClock().forwardTime(2, TimeUnit.MILLISECONDS);

        // channel closed
        assertTrue(!channel().isOpen());
    }

    private void createStream() throws Exception {
        Http2Headers headers = new DefaultHttp2Headers().method(HTTP_METHOD)
                .set(CONTENT_TYPE_HEADER, CONTENT_TYPE_GRPC).set(TE_HEADER, TE_TRAILERS)
                .path(new AsciiString("/foo/bar"));
        ByteBuf headersFrame = headersFrame(STREAM_ID, headers);
        channelRead(headersFrame);

        ArgumentCaptor<NettyServerStream> streamCaptor = ArgumentCaptor.forClass(NettyServerStream.class);
        ArgumentCaptor<String> methodCaptor = ArgumentCaptor.forClass(String.class);
        verify(transportListener).streamCreated(streamCaptor.capture(), methodCaptor.capture(),
                any(Metadata.class));
        stream = streamCaptor.getValue();
    }

    private ByteBuf emptyGrpcFrame(int streamId, boolean endStream) throws Exception {
        ByteBuf buf = NettyTestUtil.messageFrame("");
        try {
            return dataFrame(streamId, endStream, buf);
        } finally {
            buf.release();
        }
    }

    @Override
    protected NettyServerHandler newHandler() {
        return NettyServerHandler.newHandler(/* channelUnused= */ channel().newPromise(), frameReader(),
                frameWriter(), transportListener, Arrays.asList(streamTracerFactory), transportTracer,
                maxConcurrentStreams, flowControlWindow, maxHeaderListSize, DEFAULT_MAX_MESSAGE_SIZE,
                keepAliveTimeInNanos, keepAliveTimeoutInNanos, maxConnectionIdleInNanos, maxConnectionAgeInNanos,
                maxConnectionAgeGraceInNanos, permitKeepAliveWithoutCalls, permitKeepAliveTimeInNanos);
    }

    @Override
    protected WriteQueue initWriteQueue() {
        return handler().getWriteQueue();
    }

    @Override
    protected void makeStream() throws Exception {
        createStream();
    }
}