Java tutorial
/* * Copyright 2014 The gRPC Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package io.grpc.netty; import static com.google.common.base.Charsets.UTF_8; import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_WINDOW_SIZE; import static org.junit.Assert.assertEquals; import static org.mockito.AdditionalAnswers.delegatesTo; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import com.google.errorprone.annotations.CanIgnoreReturnValue; import io.grpc.InternalChannelz.TransportStats; import io.grpc.internal.FakeClock; import io.grpc.internal.MessageFramer; import io.grpc.internal.StatsTraceContext; import io.grpc.internal.TransportTracer; import io.grpc.internal.WritableBuffer; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.ByteBufUtil; import io.netty.buffer.CompositeByteBuf; import io.netty.buffer.Unpooled; import io.netty.buffer.UnpooledByteBufAllocator; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPromise; import io.netty.channel.EventLoop; import io.netty.channel.embedded.EmbeddedChannel; import io.netty.handler.codec.http2.DefaultHttp2FrameReader; import io.netty.handler.codec.http2.DefaultHttp2FrameWriter; import io.netty.handler.codec.http2.Http2CodecUtil; import io.netty.handler.codec.http2.Http2Connection; import io.netty.handler.codec.http2.Http2ConnectionHandler; import io.netty.handler.codec.http2.Http2Exception; import io.netty.handler.codec.http2.Http2FrameReader; import io.netty.handler.codec.http2.Http2FrameWriter; import io.netty.handler.codec.http2.Http2Headers; import io.netty.handler.codec.http2.Http2HeadersDecoder; import io.netty.handler.codec.http2.Http2LocalFlowController; import io.netty.handler.codec.http2.Http2Settings; import io.netty.handler.codec.http2.Http2Stream; import io.netty.util.concurrent.DefaultPromise; import io.netty.util.concurrent.Promise; import io.netty.util.concurrent.ScheduledFuture; import java.io.ByteArrayInputStream; import java.util.concurrent.Delayed; import java.util.concurrent.TimeUnit; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; import org.mockito.verification.VerificationMode; /** * Base class for Netty handler unit tests. */ @RunWith(JUnit4.class) public abstract class NettyHandlerTestBase<T extends Http2ConnectionHandler> { private ByteBuf content; private EmbeddedChannel channel; private ChannelHandlerContext ctx; private Http2FrameWriter frameWriter; private Http2FrameReader frameReader; private T handler; private WriteQueue writeQueue; /** * Does additional setup jobs. Call it manually when necessary. */ protected void manualSetUp() throws Exception { } protected final TransportTracer transportTracer = new TransportTracer(); protected int flowControlWindow = DEFAULT_WINDOW_SIZE; private final FakeClock fakeClock = new FakeClock(); FakeClock fakeClock() { return fakeClock; } /** * Must be called by subclasses to initialize the handler and channel. */ protected final void initChannel(Http2HeadersDecoder headersDecoder) throws Exception { content = Unpooled.copiedBuffer("hello world", UTF_8); frameWriter = mock(Http2FrameWriter.class, delegatesTo(new DefaultHttp2FrameWriter())); frameReader = new DefaultHttp2FrameReader(headersDecoder); channel = new FakeClockSupportedChanel(); handler = newHandler(); channel.pipeline().addLast(handler); ctx = channel.pipeline().context(handler); writeQueue = initWriteQueue(); } private final class FakeClockSupportedChanel extends EmbeddedChannel { EventLoop eventLoop; FakeClockSupportedChanel(ChannelHandler... handlers) { super(handlers); } @Override public EventLoop eventLoop() { if (eventLoop == null) { createEventLoop(); } return eventLoop; } void createEventLoop() { EventLoop realEventLoop = super.eventLoop(); if (realEventLoop == null) { return; } eventLoop = mock(EventLoop.class, delegatesTo(realEventLoop)); doAnswer(new Answer<ScheduledFuture<Void>>() { @Override public ScheduledFuture<Void> answer(InvocationOnMock invocation) throws Throwable { Runnable command = (Runnable) invocation.getArguments()[0]; Long delay = (Long) invocation.getArguments()[1]; TimeUnit timeUnit = (TimeUnit) invocation.getArguments()[2]; return new FakeClockScheduledNettyFuture(eventLoop, command, delay, timeUnit); } }).when(eventLoop).schedule(any(Runnable.class), anyLong(), any(TimeUnit.class)); } } private final class FakeClockScheduledNettyFuture extends DefaultPromise<Void> implements ScheduledFuture<Void> { final java.util.concurrent.ScheduledFuture<?> future; FakeClockScheduledNettyFuture(EventLoop eventLoop, final Runnable command, long delay, TimeUnit timeUnit) { super(eventLoop); Runnable wrap = new Runnable() { @Override public void run() { try { command.run(); } catch (Throwable t) { setFailure(t); return; } if (!isDone()) { Promise<Void> unused = setSuccess(null); } // else: The command itself, such as a shutdown task, might have cancelled all the // scheduled tasks already. } }; future = fakeClock.getScheduledExecutorService().schedule(wrap, delay, timeUnit); } @Override public boolean cancel(boolean mayInterruptIfRunning) { if (future.cancel(mayInterruptIfRunning)) { return super.cancel(mayInterruptIfRunning); } return false; } @Override public long getDelay(TimeUnit unit) { return Math.max(future.getDelay(unit), 1L); // never return zero or negative delay. } @Override public int compareTo(Delayed o) { return future.compareTo(o); } } protected final T handler() { return handler; } protected final EmbeddedChannel channel() { return channel; } protected final ChannelHandlerContext ctx() { return ctx; } protected final Http2FrameWriter frameWriter() { return frameWriter; } protected final Http2FrameReader frameReader() { return frameReader; } protected final ByteBuf content() { return content; } protected final byte[] contentAsArray() { return ByteBufUtil.getBytes(content()); } protected final Http2FrameWriter verifyWrite() { return verify(frameWriter); } protected final Http2FrameWriter verifyWrite(VerificationMode verificationMode) { return verify(frameWriter, verificationMode); } protected final void channelRead(Object obj) throws Exception { channel.writeInbound(obj); } protected ByteBuf grpcDataFrame(int streamId, boolean endStream, byte[] content) { final ByteBuf compressionFrame = Unpooled.buffer(content.length); MessageFramer framer = new MessageFramer(new MessageFramer.Sink() { @Override public void deliverFrame(WritableBuffer frame, boolean endOfStream, boolean flush, int numMessages) { if (frame != null) { ByteBuf bytebuf = ((NettyWritableBuffer) frame).bytebuf(); compressionFrame.writeBytes(bytebuf); } } }, new NettyWritableBufferAllocator(ByteBufAllocator.DEFAULT), StatsTraceContext.NOOP); framer.writePayload(new ByteArrayInputStream(content)); framer.flush(); ChannelHandlerContext ctx = newMockContext(); new DefaultHttp2FrameWriter().writeData(ctx, streamId, compressionFrame, 0, endStream, newPromise()); return captureWrite(ctx); } protected final ByteBuf dataFrame(int streamId, boolean endStream, ByteBuf content) { // Need to retain the content since the frameWriter releases it. content.retain(); ChannelHandlerContext ctx = newMockContext(); new DefaultHttp2FrameWriter().writeData(ctx, streamId, content, 0, endStream, newPromise()); return captureWrite(ctx); } protected final ByteBuf pingFrame(boolean ack, long payload) { ChannelHandlerContext ctx = newMockContext(); new DefaultHttp2FrameWriter().writePing(ctx, ack, payload, newPromise()); return captureWrite(ctx); } protected final ByteBuf headersFrame(int streamId, Http2Headers headers) { ChannelHandlerContext ctx = newMockContext(); new DefaultHttp2FrameWriter().writeHeaders(ctx, streamId, headers, 0, false, newPromise()); return captureWrite(ctx); } protected final ByteBuf goAwayFrame(int lastStreamId) { return goAwayFrame(lastStreamId, 0, Unpooled.EMPTY_BUFFER); } protected final ByteBuf goAwayFrame(int lastStreamId, int errorCode, ByteBuf data) { ChannelHandlerContext ctx = newMockContext(); new DefaultHttp2FrameWriter().writeGoAway(ctx, lastStreamId, errorCode, data, newPromise()); return captureWrite(ctx); } protected final ByteBuf rstStreamFrame(int streamId, int errorCode) { ChannelHandlerContext ctx = newMockContext(); new DefaultHttp2FrameWriter().writeRstStream(ctx, streamId, errorCode, newPromise()); return captureWrite(ctx); } protected final ByteBuf serializeSettings(Http2Settings settings) { ChannelHandlerContext ctx = newMockContext(); new DefaultHttp2FrameWriter().writeSettings(ctx, settings, newPromise()); return captureWrite(ctx); } protected final ByteBuf windowUpdate(int streamId, int delta) { ChannelHandlerContext ctx = newMockContext(); new DefaultHttp2FrameWriter().writeWindowUpdate(ctx, 0, delta, newPromise()); return captureWrite(ctx); } protected final ChannelPromise newPromise() { return channel.newPromise(); } protected final Http2Connection connection() { return handler().connection(); } @CanIgnoreReturnValue protected final ChannelFuture enqueue(WriteQueue.QueuedCommand command) { ChannelFuture future = writeQueue.enqueue(command, true); channel.runPendingTasks(); return future; } protected final ChannelHandlerContext newMockContext() { ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); when(ctx.alloc()).thenReturn(UnpooledByteBufAllocator.DEFAULT); EventLoop eventLoop = mock(EventLoop.class); when(ctx.executor()).thenReturn(eventLoop); when(ctx.channel()).thenReturn(channel); return ctx; } protected final ByteBuf captureWrite(ChannelHandlerContext ctx) { ArgumentCaptor<ByteBuf> captor = ArgumentCaptor.forClass(ByteBuf.class); verify(ctx, atLeastOnce()).write(captor.capture(), any(ChannelPromise.class)); CompositeByteBuf composite = Unpooled.compositeBuffer(); for (ByteBuf buf : captor.getAllValues()) { composite.addComponent(buf); composite.writerIndex(composite.writerIndex() + buf.readableBytes()); } return composite; } protected abstract T newHandler() throws Http2Exception; protected abstract WriteQueue initWriteQueue(); protected abstract void makeStream() throws Exception; @Test public void dataPingSentOnHeaderRecieved() throws Exception { manualSetUp(); makeStream(); AbstractNettyHandler handler = (AbstractNettyHandler) handler(); handler.setAutoTuneFlowControl(true); channelRead(dataFrame(3, false, content())); assertEquals(1, handler.flowControlPing().getPingCount()); } @Test public void dataPingAckIsRecognized() throws Exception { manualSetUp(); makeStream(); AbstractNettyHandler handler = (AbstractNettyHandler) handler(); handler.setAutoTuneFlowControl(true); channelRead(dataFrame(3, false, content())); long pingData = handler.flowControlPing().payload(); channelRead(pingFrame(true, pingData)); assertEquals(1, handler.flowControlPing().getPingCount()); assertEquals(1, handler.flowControlPing().getPingReturn()); } @Test public void dataSizeSincePingAccumulates() throws Exception { manualSetUp(); makeStream(); AbstractNettyHandler handler = (AbstractNettyHandler) handler(); handler.setAutoTuneFlowControl(true); long frameData = 123456; ByteBuf buff = ctx().alloc().buffer(16); buff.writeLong(frameData); int length = buff.readableBytes(); channelRead(dataFrame(3, false, buff.copy())); channelRead(dataFrame(3, false, buff.copy())); channelRead(dataFrame(3, false, buff.copy())); assertEquals(length * 3, handler.flowControlPing().getDataSincePing()); } @Test public void windowUpdateMatchesTarget() throws Exception { manualSetUp(); Http2Stream connectionStream = connection().connectionStream(); Http2LocalFlowController localFlowController = connection().local().flowController(); makeStream(); AbstractNettyHandler handler = (AbstractNettyHandler) handler(); handler.setAutoTuneFlowControl(true); ByteBuf data = ctx().alloc().buffer(1024); while (data.isWritable()) { data.writeLong(1111); } int length = data.readableBytes(); ByteBuf frame = dataFrame(3, false, data.copy()); channelRead(frame); int accumulator = length; // 40 is arbitrary, any number large enough to trigger a window update would work for (int i = 0; i < 40; i++) { channelRead(dataFrame(3, false, data.copy())); accumulator += length; } long pingData = handler.flowControlPing().payload(); channelRead(pingFrame(true, pingData)); assertEquals(accumulator, handler.flowControlPing().getDataSincePing()); assertEquals(2 * accumulator, localFlowController.initialWindowSize(connectionStream)); } @Test public void windowShouldNotExceedMaxWindowSize() throws Exception { manualSetUp(); makeStream(); AbstractNettyHandler handler = (AbstractNettyHandler) handler(); handler.setAutoTuneFlowControl(true); Http2Stream connectionStream = connection().connectionStream(); Http2LocalFlowController localFlowController = connection().local().flowController(); int maxWindow = handler.flowControlPing().maxWindow(); handler.flowControlPing().setDataSizeAndSincePing(maxWindow); long payload = handler.flowControlPing().payload(); channelRead(pingFrame(true, payload)); assertEquals(maxWindow, localFlowController.initialWindowSize(connectionStream)); } @Test public void transportTracer_windowSizeDefault() throws Exception { manualSetUp(); TransportStats transportStats = transportTracer.getStats(); assertEquals(Http2CodecUtil.DEFAULT_WINDOW_SIZE, transportStats.remoteFlowControlWindow); assertEquals(flowControlWindow, transportStats.localFlowControlWindow); } @Test public void transportTracer_windowSize() throws Exception { flowControlWindow = 1024 * 1024; manualSetUp(); TransportStats transportStats = transportTracer.getStats(); assertEquals(Http2CodecUtil.DEFAULT_WINDOW_SIZE, transportStats.remoteFlowControlWindow); assertEquals(flowControlWindow, transportStats.localFlowControlWindow); } @Test public void transportTracer_windowUpdate_remote() throws Exception { manualSetUp(); TransportStats before = transportTracer.getStats(); assertEquals(Http2CodecUtil.DEFAULT_WINDOW_SIZE, before.remoteFlowControlWindow); assertEquals(Http2CodecUtil.DEFAULT_WINDOW_SIZE, before.localFlowControlWindow); ByteBuf serializedSettings = windowUpdate(0, 1000); channelRead(serializedSettings); TransportStats after = transportTracer.getStats(); assertEquals(Http2CodecUtil.DEFAULT_WINDOW_SIZE + 1000, after.remoteFlowControlWindow); assertEquals(flowControlWindow, after.localFlowControlWindow); } @Test public void transportTracer_windowUpdate_local() throws Exception { manualSetUp(); TransportStats before = transportTracer.getStats(); assertEquals(Http2CodecUtil.DEFAULT_WINDOW_SIZE, before.remoteFlowControlWindow); assertEquals(flowControlWindow, before.localFlowControlWindow); // If the window size is below a certain threshold, netty will wait to apply the update. // Use a large increment to be sure that it exceeds the threshold. connection().local().flowController().incrementWindowSize(connection().connectionStream(), 8 * Http2CodecUtil.DEFAULT_WINDOW_SIZE); TransportStats after = transportTracer.getStats(); assertEquals(Http2CodecUtil.DEFAULT_WINDOW_SIZE, after.remoteFlowControlWindow); assertEquals(flowControlWindow + 8 * Http2CodecUtil.DEFAULT_WINDOW_SIZE, connection().local().flowController().windowSize(connection().connectionStream())); } }