alluxio.client.block.stream.NettyPacketWriterTest.java Source code

Java tutorial

Introduction

Here is the source code for alluxio.client.block.stream.NettyPacketWriterTest.java

Source

/*
 * The Alluxio Open Foundation licenses this work under the Apache License, version 2.0
 * (the "License"). You may not use this work except in compliance with the License, which is
 * available at www.apache.org/licenses/LICENSE-2.0
 *
 * This software is distributed on an "AS IS" basis, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
 * either express or implied, as more fully set forth in the License.
 *
 * See the NOTICE file distributed with this work for information regarding copyright ownership.
 */

package alluxio.client.block.stream;

import alluxio.Constants;
import alluxio.EmbeddedChannels;
import alluxio.client.file.FileSystemContext;
import alluxio.network.protocol.RPCProtoMessage;
import alluxio.network.protocol.databuffer.DataBuffer;
import alluxio.network.protocol.databuffer.DataNettyBufferV2;
import alluxio.proto.dataserver.Protocol;
import alluxio.util.CommonUtils;
import alluxio.util.ThreadFactoryUtils;
import alluxio.util.WaitForOptions;
import alluxio.util.io.BufferUtils;

import com.google.common.base.Function;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.embedded.EmbeddedChannel;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mockito;
import org.powermock.api.mockito.PowerMockito;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.net.InetSocketAddress;
import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

@RunWith(PowerMockRunner.class)
@PrepareForTest({ FileSystemContext.class })
public final class NettyPacketWriterTest {
    private static final Logger LOG = LoggerFactory.getLogger(NettyPacketWriterTest.class);

    private static final int PACKET_SIZE = 1024;
    private static final ExecutorService EXECUTOR = Executors.newFixedThreadPool(4,
            ThreadFactoryUtils.build("test-executor-%d", true));

    private static final Random RANDOM = new Random();
    private static final long BLOCK_ID = 1L;
    private static final long SESSION_ID = 2L;
    private static final int TIER = 0;

    private FileSystemContext mContext;
    private InetSocketAddress mAddress;
    private EmbeddedChannels.EmbeddedEmptyCtorChannel mChannel;

    @Before
    public void before() throws Exception {
        mContext = PowerMockito.mock(FileSystemContext.class);
        mAddress = Mockito.mock(InetSocketAddress.class);

        mChannel = new EmbeddedChannels.EmbeddedEmptyCtorChannel();
        PowerMockito.when(mContext.acquireNettyChannel(mAddress)).thenReturn(mChannel);
        PowerMockito.doNothing().when(mContext).releaseNettyChannel(mAddress, mChannel);
    }

    @After
    public void after() throws Exception {
        mChannel.close();
    }

    /**
     * Writes an empty file.
     */
    @Test(timeout = 1000 * 60)
    public void writeEmptyFile() throws Exception {
        Future<Long> checksumActual;
        try (PacketWriter writer = create(10)) {
            checksumActual = verifyWriteRequests(mChannel, 0, 10);
        }
        Assert.assertEquals(0, checksumActual.get().longValue());
    }

    /**
     * Writes a file with file length matches what is given and verifies the checksum of the whole
     * file.
     */
    @Test(timeout = 1000 * 60)
    public void writeFullFile() throws Exception {
        Future<Long> checksumActual;
        Future<Long> checksumExpected;
        long length = PACKET_SIZE * 1024 + PACKET_SIZE / 3;
        try (PacketWriter writer = create(length)) {
            checksumExpected = writeFile(writer, length, 0, length - 1);
            checksumActual = verifyWriteRequests(mChannel, 0, length - 1);
            checksumExpected.get();
        }
        Assert.assertEquals(checksumExpected.get(), checksumActual.get());
    }

    /**
     * Writes a file with file length matches what is given and verifies the checksum of the whole
     * file.
     */
    @Test(timeout = 1000 * 60)
    public void writeFileChecksumOfPartialFile() throws Exception {
        Future<Long> checksumActual;
        Future<Long> checksumExpected;
        long length = PACKET_SIZE * 1024 + PACKET_SIZE / 3;
        try (PacketWriter writer = create(length)) {
            checksumExpected = writeFile(writer, length, 10, length / 3);
            checksumActual = verifyWriteRequests(mChannel, 10, length / 3);
            checksumExpected.get();
        }
        Assert.assertEquals(checksumExpected.get(), checksumActual.get());
    }

    /**
     * Writes a file with unknown length.
     */
    @Test(timeout = 1000 * 60)
    public void writeFileUnknownLength() throws Exception {
        Future<Long> checksumActual;
        Future<Long> checksumExpected;
        long length = PACKET_SIZE * 1024;
        try (PacketWriter writer = create(Long.MAX_VALUE)) {
            checksumExpected = writeFile(writer, length, 10, length / 3);
            checksumActual = verifyWriteRequests(mChannel, 10, length / 3);
            checksumExpected.get();
        }
        Assert.assertEquals(checksumExpected.get(), checksumActual.get());
    }

    /**
     * Writes lots of packets.
     */
    @Test(timeout = 1000 * 60)
    public void writeFileManyPackets() throws Exception {
        Future<Long> checksumActual;
        Future<Long> checksumExpected;
        long length = PACKET_SIZE * 30000 + PACKET_SIZE / 3;
        try (PacketWriter writer = create(Long.MAX_VALUE)) {
            checksumExpected = writeFile(writer, length, 10, length / 3);
            checksumActual = verifyWriteRequests(mChannel, 10, length / 3);
            checksumExpected.get();
        }
        Assert.assertEquals(checksumExpected.get(), checksumActual.get());
    }

    /**
     * Creates a {@link PacketWriter}.
     *
     * @param length the length
     * @return the packet writer instance
     * @throws Exception if it fails to create the packet writer
     */
    private PacketWriter create(long length) throws Exception {
        PacketWriter writer = new NettyPacketWriter(mContext, mAddress, BLOCK_ID, length, SESSION_ID, TIER,
                Protocol.RequestType.ALLUXIO_BLOCK);
        mChannel.finishChannelCreation();
        return writer;
    }

    /**
     * Writes packets via the given packet writer and returns a checksum for a region of the data
     * written.
     *
     * @param length the length
     * @param start the start position to calculate the checksum
     * @param end the end position to calculate the checksum
     * @return the checksum
     */
    private Future<Long> writeFile(final PacketWriter writer, final long length, final long start, final long end)
            throws Exception {
        return EXECUTOR.submit(new Callable<Long>() {
            @Override
            public Long call() throws IOException {
                try {
                    long checksum = 0;
                    long pos = 0;

                    long remaining = length;
                    while (remaining > 0) {
                        int bytesToWrite = (int) Math.min(remaining, PACKET_SIZE);
                        byte[] data = new byte[bytesToWrite];
                        RANDOM.nextBytes(data);
                        ByteBuf buf = Unpooled.wrappedBuffer(data);
                        try {
                            writer.writePacket(buf);
                        } catch (IOException e) {
                            Assert.fail(e.getMessage());
                            throw e;
                        }
                        remaining -= bytesToWrite;

                        for (int i = 0; i < data.length; i++) {
                            if (pos >= start && pos <= end) {
                                checksum += BufferUtils.byteToInt(data[i]);
                            }
                            pos++;
                        }
                    }
                    return checksum;
                } catch (Throwable throwable) {
                    LOG.error("Failed to write file.", throwable);
                    Assert.fail();
                    throw throwable;
                }
            }
        });
    }

    /**
     * Verifies the packets written. After receiving the last packet, it will also send an EOF to
     * the channel.
     *
     * @param checksumStart the start position to calculate the checksum
     * @param checksumEnd the end position to calculate the checksum
     * @return the checksum of the data read starting from checksumStart
     */
    private Future<Long> verifyWriteRequests(final EmbeddedChannel channel, final long checksumStart,
            final long checksumEnd) {
        return EXECUTOR.submit(new Callable<Long>() {
            @Override
            public Long call() {
                try {
                    long checksum = 0;
                    long pos = 0;
                    while (true) {
                        RPCProtoMessage request = (RPCProtoMessage) CommonUtils.waitForResult("wrtie request",
                                new Function<Void, Object>() {
                                    @Override
                                    public Object apply(Void v) {
                                        return channel.readOutbound();
                                    }
                                }, WaitForOptions.defaults().setTimeout(Constants.MINUTE_MS));
                        validateWriteRequest(request.getMessage().<Protocol.WriteRequest>getMessage(), pos);

                        DataBuffer buffer = request.getPayloadDataBuffer();
                        // Last packet.
                        if (buffer == null) {
                            channel.writeInbound(RPCProtoMessage.createOkResponse(null));
                            return checksum;
                        }
                        try {
                            Assert.assertTrue(buffer instanceof DataNettyBufferV2);
                            ByteBuf buf = (ByteBuf) buffer.getNettyOutput();
                            while (buf.readableBytes() > 0) {
                                if (pos >= checksumStart && pos <= checksumEnd) {
                                    checksum += BufferUtils.byteToInt(buf.readByte());
                                } else {
                                    buf.readByte();
                                }
                                pos++;
                            }
                        } finally {
                            buffer.release();
                        }
                    }
                } catch (Throwable throwable) {
                    LOG.error("Failed to verify write requests.", throwable);
                    Assert.fail();
                    throw throwable;
                }
            }
        });
    }

    /**
     * Validates the read request sent.
     *
     * @param request the request
     * @param offset the offset
     */
    private void validateWriteRequest(Protocol.WriteRequest request, long offset) {
        Assert.assertEquals(BLOCK_ID, request.getId());
        Assert.assertEquals(offset, request.getOffset());
    }
}