io.grpc.netty.NettyHandlerTestBase.java Source code

Java tutorial

Introduction

Here is the source code for io.grpc.netty.NettyHandlerTestBase.java

Source

/*
 * Copyright 2014 The gRPC Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.grpc.netty;

import static com.google.common.base.Charsets.UTF_8;
import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_WINDOW_SIZE;
import static org.junit.Assert.assertEquals;
import static org.mockito.AdditionalAnswers.delegatesTo;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import com.google.errorprone.annotations.CanIgnoreReturnValue;
import io.grpc.InternalChannelz.TransportStats;
import io.grpc.internal.FakeClock;
import io.grpc.internal.MessageFramer;
import io.grpc.internal.StatsTraceContext;
import io.grpc.internal.TransportTracer;
import io.grpc.internal.WritableBuffer;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.ByteBufUtil;
import io.netty.buffer.CompositeByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.buffer.UnpooledByteBufAllocator;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.channel.EventLoop;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http2.DefaultHttp2FrameReader;
import io.netty.handler.codec.http2.DefaultHttp2FrameWriter;
import io.netty.handler.codec.http2.Http2CodecUtil;
import io.netty.handler.codec.http2.Http2Connection;
import io.netty.handler.codec.http2.Http2ConnectionHandler;
import io.netty.handler.codec.http2.Http2Exception;
import io.netty.handler.codec.http2.Http2FrameReader;
import io.netty.handler.codec.http2.Http2FrameWriter;
import io.netty.handler.codec.http2.Http2Headers;
import io.netty.handler.codec.http2.Http2HeadersDecoder;
import io.netty.handler.codec.http2.Http2LocalFlowController;
import io.netty.handler.codec.http2.Http2Settings;
import io.netty.handler.codec.http2.Http2Stream;
import io.netty.util.concurrent.DefaultPromise;
import io.netty.util.concurrent.Promise;
import io.netty.util.concurrent.ScheduledFuture;
import java.io.ByteArrayInputStream;
import java.util.concurrent.Delayed;
import java.util.concurrent.TimeUnit;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import org.mockito.verification.VerificationMode;

/**
 * Base class for Netty handler unit tests.
 */
@RunWith(JUnit4.class)
public abstract class NettyHandlerTestBase<T extends Http2ConnectionHandler> {

    private ByteBuf content;

    private EmbeddedChannel channel;

    private ChannelHandlerContext ctx;

    private Http2FrameWriter frameWriter;

    private Http2FrameReader frameReader;

    private T handler;

    private WriteQueue writeQueue;

    /**
     * Does additional setup jobs. Call it manually when necessary.
     */
    protected void manualSetUp() throws Exception {
    }

    protected final TransportTracer transportTracer = new TransportTracer();
    protected int flowControlWindow = DEFAULT_WINDOW_SIZE;

    private final FakeClock fakeClock = new FakeClock();

    FakeClock fakeClock() {
        return fakeClock;
    }

    /**
     * Must be called by subclasses to initialize the handler and channel.
     */
    protected final void initChannel(Http2HeadersDecoder headersDecoder) throws Exception {
        content = Unpooled.copiedBuffer("hello world", UTF_8);
        frameWriter = mock(Http2FrameWriter.class, delegatesTo(new DefaultHttp2FrameWriter()));
        frameReader = new DefaultHttp2FrameReader(headersDecoder);

        channel = new FakeClockSupportedChanel();
        handler = newHandler();
        channel.pipeline().addLast(handler);
        ctx = channel.pipeline().context(handler);

        writeQueue = initWriteQueue();
    }

    private final class FakeClockSupportedChanel extends EmbeddedChannel {
        EventLoop eventLoop;

        FakeClockSupportedChanel(ChannelHandler... handlers) {
            super(handlers);
        }

        @Override
        public EventLoop eventLoop() {
            if (eventLoop == null) {
                createEventLoop();
            }
            return eventLoop;
        }

        void createEventLoop() {
            EventLoop realEventLoop = super.eventLoop();
            if (realEventLoop == null) {
                return;
            }
            eventLoop = mock(EventLoop.class, delegatesTo(realEventLoop));
            doAnswer(new Answer<ScheduledFuture<Void>>() {
                @Override
                public ScheduledFuture<Void> answer(InvocationOnMock invocation) throws Throwable {
                    Runnable command = (Runnable) invocation.getArguments()[0];
                    Long delay = (Long) invocation.getArguments()[1];
                    TimeUnit timeUnit = (TimeUnit) invocation.getArguments()[2];
                    return new FakeClockScheduledNettyFuture(eventLoop, command, delay, timeUnit);
                }
            }).when(eventLoop).schedule(any(Runnable.class), anyLong(), any(TimeUnit.class));
        }
    }

    private final class FakeClockScheduledNettyFuture extends DefaultPromise<Void>
            implements ScheduledFuture<Void> {
        final java.util.concurrent.ScheduledFuture<?> future;

        FakeClockScheduledNettyFuture(EventLoop eventLoop, final Runnable command, long delay, TimeUnit timeUnit) {
            super(eventLoop);
            Runnable wrap = new Runnable() {
                @Override
                public void run() {
                    try {
                        command.run();
                    } catch (Throwable t) {
                        setFailure(t);
                        return;
                    }
                    if (!isDone()) {
                        Promise<Void> unused = setSuccess(null);
                    }
                    // else: The command itself, such as a shutdown task, might have cancelled all the
                    // scheduled tasks already.
                }
            };
            future = fakeClock.getScheduledExecutorService().schedule(wrap, delay, timeUnit);
        }

        @Override
        public boolean cancel(boolean mayInterruptIfRunning) {
            if (future.cancel(mayInterruptIfRunning)) {
                return super.cancel(mayInterruptIfRunning);
            }
            return false;
        }

        @Override
        public long getDelay(TimeUnit unit) {
            return Math.max(future.getDelay(unit), 1L); // never return zero or negative delay.
        }

        @Override
        public int compareTo(Delayed o) {
            return future.compareTo(o);
        }
    }

    protected final T handler() {
        return handler;
    }

    protected final EmbeddedChannel channel() {
        return channel;
    }

    protected final ChannelHandlerContext ctx() {
        return ctx;
    }

    protected final Http2FrameWriter frameWriter() {
        return frameWriter;
    }

    protected final Http2FrameReader frameReader() {
        return frameReader;
    }

    protected final ByteBuf content() {
        return content;
    }

    protected final byte[] contentAsArray() {
        return ByteBufUtil.getBytes(content());
    }

    protected final Http2FrameWriter verifyWrite() {
        return verify(frameWriter);
    }

    protected final Http2FrameWriter verifyWrite(VerificationMode verificationMode) {
        return verify(frameWriter, verificationMode);
    }

    protected final void channelRead(Object obj) throws Exception {
        channel.writeInbound(obj);
    }

    protected ByteBuf grpcDataFrame(int streamId, boolean endStream, byte[] content) {
        final ByteBuf compressionFrame = Unpooled.buffer(content.length);
        MessageFramer framer = new MessageFramer(new MessageFramer.Sink() {
            @Override
            public void deliverFrame(WritableBuffer frame, boolean endOfStream, boolean flush, int numMessages) {
                if (frame != null) {
                    ByteBuf bytebuf = ((NettyWritableBuffer) frame).bytebuf();
                    compressionFrame.writeBytes(bytebuf);
                }
            }
        }, new NettyWritableBufferAllocator(ByteBufAllocator.DEFAULT), StatsTraceContext.NOOP);
        framer.writePayload(new ByteArrayInputStream(content));
        framer.flush();
        ChannelHandlerContext ctx = newMockContext();
        new DefaultHttp2FrameWriter().writeData(ctx, streamId, compressionFrame, 0, endStream, newPromise());
        return captureWrite(ctx);
    }

    protected final ByteBuf dataFrame(int streamId, boolean endStream, ByteBuf content) {
        // Need to retain the content since the frameWriter releases it.
        content.retain();

        ChannelHandlerContext ctx = newMockContext();
        new DefaultHttp2FrameWriter().writeData(ctx, streamId, content, 0, endStream, newPromise());
        return captureWrite(ctx);
    }

    protected final ByteBuf pingFrame(boolean ack, long payload) {
        ChannelHandlerContext ctx = newMockContext();
        new DefaultHttp2FrameWriter().writePing(ctx, ack, payload, newPromise());
        return captureWrite(ctx);
    }

    protected final ByteBuf headersFrame(int streamId, Http2Headers headers) {
        ChannelHandlerContext ctx = newMockContext();
        new DefaultHttp2FrameWriter().writeHeaders(ctx, streamId, headers, 0, false, newPromise());
        return captureWrite(ctx);
    }

    protected final ByteBuf goAwayFrame(int lastStreamId) {
        return goAwayFrame(lastStreamId, 0, Unpooled.EMPTY_BUFFER);
    }

    protected final ByteBuf goAwayFrame(int lastStreamId, int errorCode, ByteBuf data) {
        ChannelHandlerContext ctx = newMockContext();
        new DefaultHttp2FrameWriter().writeGoAway(ctx, lastStreamId, errorCode, data, newPromise());
        return captureWrite(ctx);
    }

    protected final ByteBuf rstStreamFrame(int streamId, int errorCode) {
        ChannelHandlerContext ctx = newMockContext();
        new DefaultHttp2FrameWriter().writeRstStream(ctx, streamId, errorCode, newPromise());
        return captureWrite(ctx);
    }

    protected final ByteBuf serializeSettings(Http2Settings settings) {
        ChannelHandlerContext ctx = newMockContext();
        new DefaultHttp2FrameWriter().writeSettings(ctx, settings, newPromise());
        return captureWrite(ctx);
    }

    protected final ByteBuf windowUpdate(int streamId, int delta) {
        ChannelHandlerContext ctx = newMockContext();
        new DefaultHttp2FrameWriter().writeWindowUpdate(ctx, 0, delta, newPromise());
        return captureWrite(ctx);
    }

    protected final ChannelPromise newPromise() {
        return channel.newPromise();
    }

    protected final Http2Connection connection() {
        return handler().connection();
    }

    @CanIgnoreReturnValue
    protected final ChannelFuture enqueue(WriteQueue.QueuedCommand command) {
        ChannelFuture future = writeQueue.enqueue(command, true);
        channel.runPendingTasks();
        return future;
    }

    protected final ChannelHandlerContext newMockContext() {
        ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
        when(ctx.alloc()).thenReturn(UnpooledByteBufAllocator.DEFAULT);
        EventLoop eventLoop = mock(EventLoop.class);
        when(ctx.executor()).thenReturn(eventLoop);
        when(ctx.channel()).thenReturn(channel);
        return ctx;
    }

    protected final ByteBuf captureWrite(ChannelHandlerContext ctx) {
        ArgumentCaptor<ByteBuf> captor = ArgumentCaptor.forClass(ByteBuf.class);
        verify(ctx, atLeastOnce()).write(captor.capture(), any(ChannelPromise.class));
        CompositeByteBuf composite = Unpooled.compositeBuffer();
        for (ByteBuf buf : captor.getAllValues()) {
            composite.addComponent(buf);
            composite.writerIndex(composite.writerIndex() + buf.readableBytes());
        }
        return composite;
    }

    protected abstract T newHandler() throws Http2Exception;

    protected abstract WriteQueue initWriteQueue();

    protected abstract void makeStream() throws Exception;

    @Test
    public void dataPingSentOnHeaderRecieved() throws Exception {
        manualSetUp();
        makeStream();
        AbstractNettyHandler handler = (AbstractNettyHandler) handler();
        handler.setAutoTuneFlowControl(true);

        channelRead(dataFrame(3, false, content()));

        assertEquals(1, handler.flowControlPing().getPingCount());
    }

    @Test
    public void dataPingAckIsRecognized() throws Exception {
        manualSetUp();
        makeStream();
        AbstractNettyHandler handler = (AbstractNettyHandler) handler();
        handler.setAutoTuneFlowControl(true);

        channelRead(dataFrame(3, false, content()));
        long pingData = handler.flowControlPing().payload();
        channelRead(pingFrame(true, pingData));

        assertEquals(1, handler.flowControlPing().getPingCount());
        assertEquals(1, handler.flowControlPing().getPingReturn());
    }

    @Test
    public void dataSizeSincePingAccumulates() throws Exception {
        manualSetUp();
        makeStream();
        AbstractNettyHandler handler = (AbstractNettyHandler) handler();
        handler.setAutoTuneFlowControl(true);
        long frameData = 123456;
        ByteBuf buff = ctx().alloc().buffer(16);
        buff.writeLong(frameData);
        int length = buff.readableBytes();

        channelRead(dataFrame(3, false, buff.copy()));
        channelRead(dataFrame(3, false, buff.copy()));
        channelRead(dataFrame(3, false, buff.copy()));

        assertEquals(length * 3, handler.flowControlPing().getDataSincePing());
    }

    @Test
    public void windowUpdateMatchesTarget() throws Exception {
        manualSetUp();
        Http2Stream connectionStream = connection().connectionStream();
        Http2LocalFlowController localFlowController = connection().local().flowController();
        makeStream();
        AbstractNettyHandler handler = (AbstractNettyHandler) handler();
        handler.setAutoTuneFlowControl(true);

        ByteBuf data = ctx().alloc().buffer(1024);
        while (data.isWritable()) {
            data.writeLong(1111);
        }
        int length = data.readableBytes();
        ByteBuf frame = dataFrame(3, false, data.copy());
        channelRead(frame);
        int accumulator = length;
        // 40 is arbitrary, any number large enough to trigger a window update would work
        for (int i = 0; i < 40; i++) {
            channelRead(dataFrame(3, false, data.copy()));
            accumulator += length;
        }
        long pingData = handler.flowControlPing().payload();
        channelRead(pingFrame(true, pingData));

        assertEquals(accumulator, handler.flowControlPing().getDataSincePing());
        assertEquals(2 * accumulator, localFlowController.initialWindowSize(connectionStream));
    }

    @Test
    public void windowShouldNotExceedMaxWindowSize() throws Exception {
        manualSetUp();
        makeStream();
        AbstractNettyHandler handler = (AbstractNettyHandler) handler();
        handler.setAutoTuneFlowControl(true);
        Http2Stream connectionStream = connection().connectionStream();
        Http2LocalFlowController localFlowController = connection().local().flowController();
        int maxWindow = handler.flowControlPing().maxWindow();

        handler.flowControlPing().setDataSizeAndSincePing(maxWindow);
        long payload = handler.flowControlPing().payload();
        channelRead(pingFrame(true, payload));

        assertEquals(maxWindow, localFlowController.initialWindowSize(connectionStream));
    }

    @Test
    public void transportTracer_windowSizeDefault() throws Exception {
        manualSetUp();
        TransportStats transportStats = transportTracer.getStats();
        assertEquals(Http2CodecUtil.DEFAULT_WINDOW_SIZE, transportStats.remoteFlowControlWindow);
        assertEquals(flowControlWindow, transportStats.localFlowControlWindow);
    }

    @Test
    public void transportTracer_windowSize() throws Exception {
        flowControlWindow = 1024 * 1024;
        manualSetUp();
        TransportStats transportStats = transportTracer.getStats();
        assertEquals(Http2CodecUtil.DEFAULT_WINDOW_SIZE, transportStats.remoteFlowControlWindow);
        assertEquals(flowControlWindow, transportStats.localFlowControlWindow);
    }

    @Test
    public void transportTracer_windowUpdate_remote() throws Exception {
        manualSetUp();
        TransportStats before = transportTracer.getStats();
        assertEquals(Http2CodecUtil.DEFAULT_WINDOW_SIZE, before.remoteFlowControlWindow);
        assertEquals(Http2CodecUtil.DEFAULT_WINDOW_SIZE, before.localFlowControlWindow);

        ByteBuf serializedSettings = windowUpdate(0, 1000);
        channelRead(serializedSettings);
        TransportStats after = transportTracer.getStats();
        assertEquals(Http2CodecUtil.DEFAULT_WINDOW_SIZE + 1000, after.remoteFlowControlWindow);
        assertEquals(flowControlWindow, after.localFlowControlWindow);
    }

    @Test
    public void transportTracer_windowUpdate_local() throws Exception {
        manualSetUp();
        TransportStats before = transportTracer.getStats();
        assertEquals(Http2CodecUtil.DEFAULT_WINDOW_SIZE, before.remoteFlowControlWindow);
        assertEquals(flowControlWindow, before.localFlowControlWindow);

        // If the window size is below a certain threshold, netty will wait to apply the update.
        // Use a large increment to be sure that it exceeds the threshold.
        connection().local().flowController().incrementWindowSize(connection().connectionStream(),
                8 * Http2CodecUtil.DEFAULT_WINDOW_SIZE);

        TransportStats after = transportTracer.getStats();
        assertEquals(Http2CodecUtil.DEFAULT_WINDOW_SIZE, after.remoteFlowControlWindow);
        assertEquals(flowControlWindow + 8 * Http2CodecUtil.DEFAULT_WINDOW_SIZE,
                connection().local().flowController().windowSize(connection().connectionStream()));
    }
}