org.apache.flink.runtime.io.network.netty.OutboundConnectionQueueTest.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.flink.runtime.io.network.netty.OutboundConnectionQueueTest.java

Source

/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.runtime.io.network.netty;

import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.ChannelOutboundHandlerAdapter;
import io.netty.channel.ChannelPromise;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.timeout.IdleStateEvent;
import org.apache.flink.runtime.io.network.Envelope;
import org.apache.flink.runtime.io.network.NetworkConnectionManager;
import org.apache.flink.runtime.io.network.RemoteReceiver;
import org.apache.flink.runtime.io.network.channels.ChannelID;
import org.apache.flink.runtime.jobgraph.JobID;
import org.junit.Assert;
import org.mockito.Mockito;
import org.powermock.reflect.Whitebox;

import java.net.InetAddress;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Random;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.CountDownLatch;

public class OutboundConnectionQueueTest {

    /** Fixed seed so that the randomized concurrent test is reproducible. */
    private static final long RANDOM_SEED = 520346508276087L;

    /** Guards {@link #channel} and {@link #queue} while producers replace a closed channel. */
    private final Object lock = new Object();

    /** Channel under test: an {@link EmbeddedChannel} spy or a real TCP channel. */
    private Channel channel;

    /** Mocked connection manager handed to the queue under test. */
    private NetworkConnectionManager connectionManager;

    /** Mocked remote receiver handed to the queue under test. */
    private RemoteReceiver receiver;

    /** The queue under test. */
    private OutboundConnectionQueue queue;

    /** Fires idle events and, optionally, holds back user events until triggered manually. */
    private TestControlHandler controller;

    /** Records written envelopes and close events. */
    private TestVerificationHandler verifier;

    /** Last exception caught by the tail handler installed in {@code initTest}, or {@code null}. */
    private Throwable exception;

    /**
     * Creates fresh handlers, mocks and an {@link EmbeddedChannel}-backed queue under test.
     *
     * <p>The resulting pipeline (head to tail) is:
     * <ul>
     * <li>Test Verification Handler [OUT]</li>
     * <li>Test Control Handler [IN]</li>
     * <li>Idle State Handler [IN/OUT] (added by OutboundConnectionQueue)</li>
     * <li>Outbound queue, the system under test [IN] (added by OutboundConnectionQueue)</li>
     * <li>Exception setter [IN] (passed to the EmbeddedChannel constructor)</li>
     * </ul>
     *
     * @param autoTriggerWrite whether the control handler forwards user events immediately
     *                         ({@code true}) or buffers them until {@code triggerWrite()} is called
     */
    private void initTest(boolean autoTriggerWrite) {
        controller = new TestControlHandler(autoTriggerWrite);
        verifier = new TestVerificationHandler();

        // Tail handler that remembers the last exception travelling through the pipeline.
        final ChannelInboundHandlerAdapter exceptionRecorder = new ChannelInboundHandlerAdapter() {
            @Override
            public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
                exception = cause;
                super.exceptionCaught(ctx, cause);
            }
        };
        channel = Mockito.spy(new EmbeddedChannel(exceptionRecorder));

        connectionManager = Mockito.mock(NetworkConnectionManager.class);
        receiver = Mockito.mock(RemoteReceiver.class);
        queue = new OutboundConnectionQueue(channel, receiver, connectionManager, 0);

        // addFirst prepends, so the handler added last ends up at the head of the pipeline.
        channel.pipeline().addFirst("Test Control Handler", controller);
        channel.pipeline().addFirst("Test Verification Handler", verifier);

        exception = null;
    }

    /**
     * Verifies that an idle event closes the channel once the queued envelopes
     * have been written, and that the connection manager is asked to close
     * the receiver exactly once.
     */
    public void testClose() throws Exception {
        initTest(false);

        final JobID jobId = new JobID();
        final ChannelID channelId = new ChannelID();

        for (int seq = 1; seq <= 3; seq++) {
            Assert.assertTrue(queue.enqueue(new Envelope(seq, jobId, channelId)));
        }

        // initTest(false) makes the controller hold back user events; release the first one.
        controller.triggerWrite();

        controller.fireIdle();
        verifier.waitForClose();

        verifier.verifyEnvelopeReceived(channelId, 3);
        Mockito.verify(connectionManager, Mockito.times(1)).close(Mockito.any(RemoteReceiver.class));
    }

    /**
     * Verifies that an idle event does not close the channel while envelopes are
     * still queued, and that a later idle event closes it once they were written.
     */
    public void testCloseWithQueuedEnvelopes() throws Exception {
        initTest(true);

        final JobID jobId = new JobID();
        final ChannelID channelId = new ChannelID();
        final CountDownLatch allWritten = verifier.waitForEnvelopes(3, channelId);

        // An unwritable channel makes the queue hold on to the envelopes.
        Mockito.when(channel.isWritable()).thenReturn(false);

        for (int seq = 1; seq <= 3; seq++) {
            Assert.assertTrue(queue.enqueue(new Envelope(seq, jobId, channelId)));
        }

        // While envelopes are pending, the idle event must not close anything.
        controller.fireIdle();
        Mockito.verify(connectionManager, Mockito.times(0)).close(Mockito.any(RemoteReceiver.class));

        final Boolean hasRequestedClose = Whitebox.<Boolean>getInternalState(queue, "hasRequestedClose");
        Assert.assertFalse("Close request while envelope in flight.", hasRequestedClose);

        // Flip writability back and notify the pipeline so the queue drains.
        Mockito.when(channel.isWritable()).thenReturn(true);
        channel.pipeline().fireChannelWritabilityChanged();

        while (allWritten.getCount() > 0) {
            allWritten.await();
        }
        verifier.verifyEnvelopeReceived(channelId, 3);

        // Nothing is queued any more, so this idle event closes the channel.
        controller.fireIdle();
        verifier.waitForClose();

        Mockito.verify(connectionManager, Mockito.times(1)).close(Mockito.any(RemoteReceiver.class));
    }

    /**
     * Verifies that once the channel has been closed, {@code enqueue} rejects new
     * envelopes by returning {@code false}. The caller can then route them to a
     * new connection.
     */
    public void testEnqueueAfterClose() throws Exception {
        initTest(true);

        // Close the still-empty channel right away.
        controller.fireIdle();
        verifier.waitForClose();

        final Envelope lateEnvelope = new Envelope(1, new JobID(), new ChannelID());
        Assert.assertFalse(queue.enqueue(lateEnvelope));
    }

    /**
     * Verifies that an idle event arriving after the channel was already closed
     * is handled without producing an exception.
     */
    public void testMultipleIdleEvents() throws Exception {
        initTest(true);

        controller.fireIdle();
        verifier.waitForClose();

        // The verification handler throws on a second close(); the second idle
        // event must not lead to a recorded exception.
        controller.fireIdle();
        Assert.assertNull(exception);
    }

    /**
     * Verifies that a user event the queue does not recognize ends up as an
     * {@link IllegalStateException} in the exception-recording tail handler.
     */
    public void testUnknownUserEvent() throws Exception {
        initTest(true);
        Assert.assertNull(exception);

        // Fire from the control handler's context so the event travels towards the queue.
        final ChannelHandlerContext controlContext = controller.context;
        controlContext.fireUserEventTriggered("Unknown user event");

        Assert.assertNotNull(exception);
        Assert.assertTrue(exception instanceof IllegalStateException);
    }

    // ------------------------------------------------------------------------

    /**
     * Runs {@code doTestConcurrentEnqueueAndClose} for a series of configurations
     * and prints the runtime of each run. Each configuration is given as
     * {number of producers, envelopes per producer, min sleep ms, max sleep ms}.
     */
    public void testConcurrentEnqueueAndClose() throws Exception {
        final int[][] configs = {
                { 1, 512, 0, 0 },
                { 1, 512, 40, 80 },
                { 2, 512, 40, 80 },
                { 4, 512, 40, 80 },
                { 8, 512, 40, 80 },
                { 32, 512, 40, 80 },
                { 128, 512, 40, 80 },
                { 256, 512, 40, 80 },
                { 512, 512, 40, 80 } };

        for (int[] cfg : configs) {
            final int producers = cfg[0];
            final int envelopesPerProducer = cfg[1];
            final int minSleepMs = cfg[2];
            final int maxSleepMs = cfg[3];

            System.out.println(String.format(
                    "Running %s with config: %d producers, %d envelopes to send per producer, "
                            + "%d ms min sleep time, %d ms max sleep time.",
                    "testConcurrentEnqueueAndClose", producers, envelopesPerProducer, minSleepMs, maxSleepMs));

            final long startMillis = System.currentTimeMillis();
            doTestConcurrentEnqueueAndClose(producers, envelopesPerProducer, minSleepMs, maxSleepMs);
            final long elapsedMillis = System.currentTimeMillis() - startMillis;

            System.out.println(String.format("Runtime: %d ms.", elapsedMillis));
        }
    }

    /**
     * Verifies that concurrent enqueue and close events are handled correctly.
     *
     * <p>{@code numProducers} threads each enqueue {@code numEnvelopesPerProducer} envelopes
     * (sequence numbers {@code 0 .. numEnvelopesPerProducer - 1}) over a real TCP channel.
     * Meanwhile, a separate closer thread keeps firing idle events. When {@code enqueue}
     * rejects an envelope, the producer replaces the channel and queue and retries; in
     * production the NetworkConnectionManager does this. The run succeeds when the last
     * envelope of every producer reaches the verification handler in order and every
     * replaced channel is closed.
     *
     * @param numProducers            number of concurrent producer threads
     * @param numEnvelopesPerProducer number of envelopes each producer sends
     * @param minSleepTimeMs          minimum sleep of a producer between two enqueues
     * @param maxSleepTimeMs          maximum sleep of a producer between two enqueues
     */
    private void doTestConcurrentEnqueueAndClose(final int numProducers, final int numEnvelopesPerProducer,
            final int minSleepTimeMs, final int maxSleepTimeMs) throws Exception {

        final InetAddress bindHost = InetAddress.getLocalHost();
        final int bindPort = 20000;

        // Create fresh mocks instead of relying on the ones left behind by an earlier
        // initTest() call. Those fields are null when this test runs on its own, and
        // re-used mocks would keep recording invocations across all configurations.
        connectionManager = Mockito.mock(NetworkConnectionManager.class);
        receiver = Mockito.mock(RemoteReceiver.class);

        // Testing concurrent enqueue and close requires real TCP channels,
        // because Netty's testing EmbeddedChannel does not implement the
        // same threading model as the NioEventLoopGroup (for example there
        // is no difference between being IN and OUTSIDE of the event loop
        // thread).
        final NioEventLoopGroup serverGroup = new NioEventLoopGroup(1);
        final NioEventLoopGroup clientGroup = new NioEventLoopGroup(1);

        try {
            final ServerBootstrap in = new ServerBootstrap();
            in.group(serverGroup).channel(NioServerSocketChannel.class).localAddress(bindHost, bindPort)
                    .childHandler(new ChannelInitializer<SocketChannel>() {
                        @Override
                        public void initChannel(SocketChannel channel) throws Exception {
                            channel.pipeline().addLast(new ChannelInboundHandlerAdapter());
                        }
                    });

            final Bootstrap out = new Bootstrap();
            out.group(clientGroup).channel(NioSocketChannel.class)
                    .handler(new ChannelInitializer<SocketChannel>() {
                        @Override
                        public void initChannel(SocketChannel channel) throws Exception {
                            channel.pipeline().addLast(new ChannelOutboundHandlerAdapter());
                        }
                    }).option(ChannelOption.TCP_NODELAY, false).option(ChannelOption.SO_KEEPALIVE, true);

            in.bind().sync();

            // ----------------------------------------------------------------

            // The testing pipeline looks as follows:
            // - Test Verification Handler [OUT]
            // - Test Control Handler [IN]
            // - Idle State Handler [IN/OUT] [added by OutboundConnectionQueue]
            // - Outbound queue (SUT) [IN] [added by OutboundConnectionQueue]

            channel = out.connect(bindHost, bindPort).sync().channel();

            queue = new OutboundConnectionQueue(channel, receiver, connectionManager, 0);

            controller = new TestControlHandler(true);
            verifier = new TestVerificationHandler();

            channel.pipeline().addFirst("Test Control Handler", controller);
            channel.pipeline().addFirst("Test Verification Handler", verifier);

            // ----------------------------------------------------------------

            final Random rand = new Random(RANDOM_SEED);

            // Every producer works on its local reference of the queue and only
            // updates it to the new channel when enqueue returns false, which
            // should only happen if the channel has been closed.
            final ConcurrentMap<ChannelID, OutboundConnectionQueue> producerQueues =
                    new ConcurrentHashMap<ChannelID, OutboundConnectionQueue>();

            final ChannelID[] ids = new ChannelID[numProducers];

            for (int i = 0; i < numProducers; i++) {
                ids[i] = new ChannelID();

                producerQueues.put(ids[i], queue);
            }

            // Reaches zero once the last envelope (sequence number
            // numEnvelopesPerProducer - 1) of every producer has been written.
            final CountDownLatch receivedAllEnvelopesLatch =
                    verifier.waitForEnvelopes(numEnvelopesPerProducer - 1, ids);

            // Guarded by lock.
            final List<Channel> closedChannels = new ArrayList<Channel>();

            // ----------------------------------------------------------------

            final Runnable closer = new Runnable() {
                @Override
                public void run() {
                    while (receivedAllEnvelopesLatch.getCount() != 0) {
                        try {
                            controller.fireIdle();

                            // Test two idle events arriving "closely"
                            // after each other
                            if (rand.nextBoolean()) {
                                controller.fireIdle();
                            }

                            Thread.sleep(minSleepTimeMs / 2);
                        } catch (InterruptedException e) {
                            // Do not swallow the interrupt. Stop closing instead;
                            // the main thread fires a final idle event itself.
                            Thread.currentThread().interrupt();
                            return;
                        }
                    }
                }
            };

            final Runnable[] producers = new Runnable[numProducers];

            for (int i = 0; i < numProducers; i++) {
                final int index = i;

                producers[i] = new Runnable() {
                    @Override
                    public void run() {
                        final JobID jid = new JobID();
                        final ChannelID cid = ids[index];

                        for (int j = 0; j < numEnvelopesPerProducer; j++) {
                            OutboundConnectionQueue localQueue = producerQueues.get(cid);

                            try {
                                // This code path is handled by the NetworkConnectionManager
                                // in production to enqueue the envelope either to the current
                                // channel or a new one if it was closed.
                                while (!localQueue.enqueue(new Envelope(j, jid, cid))) {
                                    synchronized (lock) {
                                        // Only the first producer that sees the current queue
                                        // reject an envelope replaces the channel; the others
                                        // pick up the replacement below.
                                        if (localQueue == queue) {
                                            closedChannels.add(channel);

                                            channel = out.connect(bindHost, bindPort).sync().channel();

                                            queue = new OutboundConnectionQueue(channel, receiver,
                                                    connectionManager, 0);

                                            channel.pipeline().addFirst("Test Control Handler", controller);
                                            channel.pipeline().addFirst("Test Verification Handler", verifier);
                                        }

                                        // Read the shared field under the same lock that guards its writes.
                                        localQueue = queue;
                                    }

                                    producerQueues.put(cid, localQueue);
                                }

                                int sleepTime = rand.nextInt((maxSleepTimeMs - minSleepTimeMs) + 1) + minSleepTimeMs;
                                Thread.sleep(sleepTime);
                            } catch (InterruptedException e) {
                                throw new RuntimeException(e);
                            }
                        }
                    }
                };
            }

            for (int i = 0; i < numProducers; i++) {
                new Thread(producers[i], "Producer " + i).start();
            }

            new Thread(closer, "Closer").start();

            // ----------------------------------------------------------------

            while (receivedAllEnvelopesLatch.getCount() != 0) {
                receivedAllEnvelopesLatch.await();
            }

            // Final close, if the last close didn't make it.
            synchronized (lock) {
                if (channel != null) {
                    controller.fireIdle();
                }
            }

            verifier.waitForClose();

            // Snapshot the producers' list under the lock that guarded its updates.
            final List<Channel> closed;
            synchronized (lock) {
                closed = new ArrayList<Channel>(closedChannels);
            }

            // If the producers do not sleep after each envelope, the close
            // should not make it through and no channel should have been
            // added to the list of closed channels
            if (minSleepTimeMs == 0 && maxSleepTimeMs == 0) {
                Assert.assertEquals(0, closed.size());
            }

            for (Channel ch : closed) {
                Assert.assertFalse(ch.isOpen());
            }

            System.out.println(closed.size() + " channels were closed during execution.");
        } finally {
            // Release the event loop threads (and the bound port) even if the run fails.
            clientGroup.shutdownGracefully().sync();
            serverGroup.shutdownGracefully().sync();
        }
    }

    // ------------------------------------------------------------------------

    // Handler to control the flow of Netty's custom user events. Used
    // to fire idle events and manually trigger write events.
    @ChannelHandler.Sharable
    private static class TestControlHandler extends ChannelInboundHandlerAdapter {

        /** User events held back when {@link #forwardUserEvents} is false; guarded by {@code this}. */
        private final Queue<Object> events = new ArrayDeque<Object>();

        /** Whether user events are passed on immediately instead of being buffered. */
        private final boolean forwardUserEvents;

        /** Context from the most recent {@code handlerAdded} call; guarded by {@code this}. */
        private ChannelHandlerContext context;

        private TestControlHandler(boolean forwardUserEvents) {
            this.forwardUserEvents = forwardUserEvents;
        }

        @Override
        public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
            // The handler is sharable; remember the context of the latest pipeline.
            synchronized (this) {
                context = ctx;
            }
        }

        @Override
        public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
            if (forwardUserEvents) {
                ctx.fireUserEventTriggered(evt);
            } else {
                // Same monitor as triggerWrite(), which drains this queue.
                synchronized (this) {
                    events.add(evt);
                }
            }
        }

        /** Fires the oldest buffered user event, if any, towards the next handler. */
        public void triggerWrite() {
            synchronized (this) {
                if (!events.isEmpty()) {
                    context.fireUserEventTriggered(events.remove());
                }
            }
        }

        /** Fires an all-idle event towards the next handler of the latest pipeline. */
        public void fireIdle() {
            synchronized (this) {
                context.fireUserEventTriggered(IdleStateEvent.ALL_IDLE_STATE_EVENT);
            }
        }
    }

    // Handler to verify writes and close events.
    @ChannelHandler.Sharable
    private static class TestVerificationHandler extends ChannelOutboundHandlerAdapter {

        /** Most recently written sequence number per source channel. */
        private final Map<ChannelID, Integer> receivedSequenceNums = new HashMap<ChannelID, Integer>();

        /** Per source, the sequence number whose write counts down the source's latch. */
        private final Map<ChannelID, Integer> expectedSequenceNums = new HashMap<ChannelID, Integer>();

        /** Pending latches per source; an entry is removed once its expected envelope was written. */
        private final Map<ChannelID, CountDownLatch> envelopeLatches = new HashMap<ChannelID, CountDownLatch>();

        /**
         * Counted down on the first close of the pipeline this handler was most recently added to.
         * Volatile because it is replaced in handlerAdded and awaited in waitForClose, possibly
         * from different threads.
         */
        private volatile CountDownLatch closeLatch;

        @Override
        public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
            // Fresh latch per pipeline: the sharable handler is re-added to replacement channels.
            closeLatch = new CountDownLatch(1);

            super.handlerAdded(ctx);
        }

        /**
         * Checks that envelopes of each source arrive in sequence order and completes the
         * write without forwarding it.
         *
         * <p>NOTE(review): messages other than {@link Envelope} are dropped and their promise is
         * never completed. Confirm that nothing else is written through this pipeline.
         */
        @Override
        public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
            if (closeLatch.getCount() == 0) {
                throw new IllegalStateException("Write on closed channel.");
            }

            if (msg.getClass() == Envelope.class) {
                Envelope env = (Envelope) msg;

                final int currentSeqNum = env.getSequenceNumber();
                final ChannelID source = env.getSource();

                Integer previousSeqNum = receivedSequenceNums.put(source, currentSeqNum);

                // Envelopes of one source must arrive strictly in order.
                if (previousSeqNum != null) {
                    String errMsg = String.format("Received %s with unexpected sequence number.", env);
                    Assert.assertEquals(errMsg, previousSeqNum + 1, currentSeqNum);
                }

                promise.setSuccess();

                Integer expectedSeqNum = expectedSequenceNums.get(source);
                if (expectedSeqNum != null && expectedSeqNum.equals(currentSeqNum)) {
                    // The latch may already be gone if the expected number was seen before.
                    CountDownLatch latch = envelopeLatches.remove(source);
                    if (latch != null) {
                        latch.countDown();
                    }
                }
            }
        }

        /** Passes the close on and releases {@link #waitForClose()}; rejects a second close. */
        @Override
        public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
            if (closeLatch.getCount() == 0) {
                throw new IllegalStateException("Received multiple close events.");
            }

            super.close(ctx, promise);

            closeLatch.countDown();
        }

        /** Blocks until the current pipeline has been closed. */
        public void waitForClose() throws InterruptedException {
            while (closeLatch.getCount() != 0) {
                closeLatch.await();
            }
        }

        /**
         * Verifies that envelope with expected sequence number has been
         * processed by the handler (or no envelope has been received by
         * the given source, if expected sequence number is null).
         */
        public void verifyEnvelopeReceived(ChannelID source, Integer expectedSequenceNum) {
            if (expectedSequenceNum == null) {
                // Contract: null means "nothing from this source", which must pass.
                if (receivedSequenceNums.containsKey(source)) {
                    Assert.fail("Received unexpected envelope from channel " + source);
                }
            } else if (receivedSequenceNums.containsKey(source)) {
                Assert.assertEquals(expectedSequenceNum, receivedSequenceNums.get(source));
            } else {
                Assert.fail("Did not receive any envelope from channel " + source);
            }
        }

        /**
         * Registers {@code expectedSequenceNum} for every given source and returns a latch that
         * is counted down once per source when that source's envelope with this number is written.
         *
         * @param expectedSequenceNum sequence number to wait for
         * @param ids                 sources to wait for
         * @return a latch with count {@code ids.length}
         */
        public CountDownLatch waitForEnvelopes(Integer expectedSequenceNum, ChannelID... ids) {
            CountDownLatch latch = new CountDownLatch(ids.length);

            for (ChannelID id : ids) {
                expectedSequenceNums.put(id, expectedSequenceNum);
                envelopeLatches.put(id, latch);
            }

            return latch;
        }
    }

    // --------------------------------------------------------------------

    /**
     * Runs every test of this class in sequence; used by {@link #main(String[])}.
     * The concurrent test, which uses real TCP channels, runs last.
     */
    public void runAllTests() throws Exception {
        // EmbeddedChannel-based tests.
        testClose();
        testCloseWithQueuedEnvelopes();
        testEnqueueAfterClose();
        testUnknownUserEvent();
        testMultipleIdleEvents();

        // Real TCP channels.
        testConcurrentEnqueueAndClose();
    }

    /** Entry point: runs all tests on a fresh instance. */
    public static void main(String[] args) throws Exception {
        final OutboundConnectionQueueTest suite = new OutboundConnectionQueueTest();
        suite.runAllTests();
    }
}