org.apache.spark.network.util.TransportFrameDecoderSuite.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.spark.network.util.TransportFrameDecoderSuite.java

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.network.util;

import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.atomic.AtomicInteger;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext;
import org.junit.AfterClass;
import org.junit.Test;
import static org.junit.Assert.*;
import static org.mockito.Mockito.*;

/**
 * Tests for {@link TransportFrameDecoder}: reassembly of length-prefixed frames from arbitrarily
 * split input, hand-off to an interceptor, frames retained by downstream handlers, and rejection
 * of invalid frame lengths.
 */
public class TransportFrameDecoderSuite {

    /** Source of random frame and chunk sizes shared by all tests in this suite. */
    private static Random RND = new Random();

    /** Drops the static {@link Random} reference once every test in the class has run. */
    @AfterClass
    public static void cleanup() {
        RND = null;
    }

    /**
     * Feeds 100 randomly sized frames in randomly sized chunks and checks that every frame is
     * fired downstream and that no dangling reference to the input data remains.
     */
    @Test
    public void testFrameDecoding() throws Exception {
        TransportFrameDecoder decoder = new TransportFrameDecoder();
        ChannelHandlerContext ctx = mockChannelHandlerContext();
        ByteBuf data = createAndFeedFrames(100, decoder, ctx);
        verifyAndCloseDecoder(decoder, ctx, data);
    }

    /**
     * Installs an interceptor that consumes the first {@code interceptedReads} reads. It then
     * checks that data arriving afterwards is decoded as a regular frame and that the decoder
     * releases every buffer it is given.
     */
    @Test
    public void testInterception() throws Exception {
        int interceptedReads = 3;
        TransportFrameDecoder decoder = new TransportFrameDecoder();
        TransportFrameDecoder.Interceptor interceptor = spy(new MockInterceptor(interceptedReads));
        ChannelHandlerContext ctx = mockChannelHandlerContext();

        byte[] data = new byte[8];
        // The frame length includes the 8-byte length field itself.
        ByteBuf len = Unpooled.copyLong(8 + data.length);
        ByteBuf dataBuf = Unpooled.wrappedBuffer(data);

        try {
            decoder.setInterceptor(interceptor);
            for (int i = 0; i < interceptedReads; i++) {
                decoder.channelRead(ctx, dataBuf);
                assertEquals(0, dataBuf.refCnt());
                dataBuf = Unpooled.wrappedBuffer(data);
            }
            decoder.channelRead(ctx, len);
            decoder.channelRead(ctx, dataBuf);
            verify(interceptor, times(interceptedReads)).handle(any(ByteBuf.class));
            // The decoder emits Netty ByteBuf frames. Matching on java.nio.ByteBuffer (an unrelated
            // type) only passed with Mockito versions whose any(Class) skips the type check.
            verify(ctx).fireChannelRead(any(ByteBuf.class));
            assertEquals(0, len.refCnt());
            assertEquals(0, dataBuf.refCnt());
        } finally {
            release(len);
            release(dataBuf);
        }
    }

    /**
     * Simulates a downstream handler that keeps every second frame and releases the rest. It then
     * checks that the kept frames can still be read and that, once they are released, no
     * reference to the input data remains.
     */
    @Test
    public void testRetainedFrames() throws Exception {
        TransportFrameDecoder decoder = new TransportFrameDecoder();

        AtomicInteger count = new AtomicInteger();
        List<ByteBuf> retained = new ArrayList<>();

        ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
        when(ctx.fireChannelRead(any())).thenAnswer(in -> {
            // Retain a few frames but not others.
            ByteBuf buf = (ByteBuf) in.getArguments()[0];
            if (count.incrementAndGet() % 2 == 0) {
                retained.add(buf);
            } else {
                buf.release();
            }
            return null;
        });

        ByteBuf data = createAndFeedFrames(100, decoder, ctx);
        try {
            // Verify all retained buffers are readable.
            for (ByteBuf b : retained) {
                byte[] tmp = new byte[b.readableBytes()];
                b.readBytes(tmp);
                b.release();
            }
            verifyAndCloseDecoder(decoder, ctx, data);
        } finally {
            for (ByteBuf b : retained) {
                release(b);
            }
        }
    }

    /**
     * Splits the 8-byte length header across two reads: the first read carries 0 to 6 bytes of
     * it. No frame may be fired until the rest arrives, and exactly one frame is fired afterwards.
     */
    @Test
    public void testSplitLengthField() throws Exception {
        byte[] frame = new byte[1024 * (RND.nextInt(31) + 1)];
        ByteBuf buf = Unpooled.buffer(frame.length + 8);
        buf.writeLong(frame.length + 8);
        buf.writeBytes(frame);

        TransportFrameDecoder decoder = new TransportFrameDecoder();
        ChannelHandlerContext ctx = mockChannelHandlerContext();
        try {
            decoder.channelRead(ctx, buf.readSlice(RND.nextInt(7)).retain());
            verify(ctx, never()).fireChannelRead(any(ByteBuf.class));
            decoder.channelRead(ctx, buf);
            verify(ctx).fireChannelRead(any(ByteBuf.class));
            assertEquals(0, buf.refCnt());
        } finally {
            decoder.channelInactive(ctx);
            release(buf);
        }
    }

    /** A negative frame length must be rejected. */
    @Test(expected = IllegalArgumentException.class)
    public void testNegativeFrameSize() throws Exception {
        testInvalidFrame(-1);
    }

    /** A frame that contains only its own length field (no payload) must be rejected. */
    @Test(expected = IllegalArgumentException.class)
    public void testEmptyFrame() throws Exception {
        // 8 because frame size includes the frame length.
        testInvalidFrame(8);
    }

    /**
     * Creates {@code frameCount} frames and feeds them to the decoder, then verifies that exactly
     * {@code frameCount} frames were fired downstream.
     *
     * <p>Payload sizes are random multiples of 1 KiB between 1 KiB and 31 KiB. The frames are fed
     * in random chunks of 256 to 4351 bytes.
     *
     * @return the buffer holding all encoded frames; the caller owns it and must release it
     * @throws Exception if the decoder fails; the data buffer is released before rethrowing
     */
    private ByteBuf createAndFeedFrames(int frameCount, TransportFrameDecoder decoder, ChannelHandlerContext ctx)
            throws Exception {
        ByteBuf data = Unpooled.buffer();
        for (int i = 0; i < frameCount; i++) {
            byte[] frame = new byte[1024 * (RND.nextInt(31) + 1)];
            data.writeLong(frame.length + 8);
            data.writeBytes(frame);
        }

        try {
            while (data.isReadable()) {
                int size = RND.nextInt(4 * 1024) + 256;
                // Each slice shares data's reference count, and the decoder releases every
                // buffer it is given, so retain before handing the slice over.
                decoder.channelRead(ctx, data.readSlice(Math.min(data.readableBytes(), size)).retain());
            }

            verify(ctx, times(frameCount)).fireChannelRead(any(ByteBuf.class));
        } catch (Exception | AssertionError e) {
            // Mockito verification failures are AssertionErrors rather than Exceptions. Release
            // the data in both cases so a failing test does not leak the buffer.
            release(data);
            throw e;
        }
        return data;
    }

    /**
     * Tells the decoder the channel went inactive. It then asserts that a single release of
     * {@code data} drops its reference count to zero, meaning nothing else still references it.
     * {@code data} is always fully released on exit.
     */
    private void verifyAndCloseDecoder(TransportFrameDecoder decoder, ChannelHandlerContext ctx, ByteBuf data)
            throws Exception {
        try {
            decoder.channelInactive(ctx);
            assertTrue("There shouldn't be dangling references to the data.", data.release());
        } finally {
            release(data);
        }
    }

    /**
     * Feeds a fresh decoder a lone 8-byte length header holding {@code size}. The calling test
     * declares the exception it expects.
     */
    private void testInvalidFrame(long size) throws Exception {
        TransportFrameDecoder decoder = new TransportFrameDecoder();
        ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
        ByteBuf frame = Unpooled.copyLong(size);
        try {
            decoder.channelRead(ctx, frame);
        } finally {
            release(frame);
        }
    }

    /**
     * Returns a mock context whose {@code fireChannelRead} releases the fired buffer. It stands in
     * for a downstream handler that consumes every frame.
     */
    private ChannelHandlerContext mockChannelHandlerContext() {
        ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
        when(ctx.fireChannelRead(any())).thenAnswer(in -> {
            ByteBuf buf = (ByteBuf) in.getArguments()[0];
            buf.release();
            return null;
        });
        return ctx;
    }

    /** Drops every remaining reference to {@code buf}; does nothing if it is already released. */
    private void release(ByteBuf buf) {
        if (buf.refCnt() > 0) {
            buf.release(buf.refCnt());
        }
    }

    /**
     * Interceptor that consumes the entire contents of each buffer for a fixed number of reads.
     * {@link #handle} returns {@code true} while more reads remain to be intercepted and
     * {@code false} on the last one.
     */
    private static class MockInterceptor implements TransportFrameDecoder.Interceptor {

        /** Number of reads still to be intercepted. */
        private int remainingReads;

        MockInterceptor(int readCount) {
            this.remainingReads = readCount;
        }

        @Override
        public boolean handle(ByteBuf data) throws Exception {
            // Consume everything readable in this buffer.
            data.readerIndex(data.readerIndex() + data.readableBytes());
            assertFalse(data.isReadable());
            remainingReads -= 1;
            return remainingReads != 0;
        }

        @Override
        public void exceptionCaught(Throwable cause) throws Exception {
            // Not exercised by these tests.
        }

        @Override
        public void channelInactive() throws Exception {
            // Not exercised by these tests.
        }

    }

}