com.github.zk1931.jzab.transport.NettyTransportTest.java Source code

Java tutorial

Introduction

Here is the source code for com.github.zk1931.jzab.transport.NettyTransportTest.java

Source

/**
 * Licensed to the zk1931 under one or more contributor license
 * agreements. See the NOTICE file distributed with this work
 * for additional information regarding copyright ownership.
 * The ASF licenses this file to you under the Apache License,
 * Version 2.0 (the "License"); you may not use this file except
 * in compliance with the License.  You may obtain a copy of the
 * License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.github.zk1931.jzab.transport;

import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.InputStreamReader;
import java.nio.ByteBuffer;
import java.nio.file.Files;
import java.util.Arrays;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import com.github.zk1931.jzab.MessageBuilder;
import com.github.zk1931.jzab.TestBase;
import com.github.zk1931.jzab.proto.ZabMessage.Message;
import com.github.zk1931.jzab.proto.ZabMessage.Message.MessageType;
import com.github.zk1931.jzab.SslParameters;
import com.github.zk1931.jzab.Zxid;
import org.junit.Assert;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Test NettyTransport.
 */
public class NettyTransportTest extends TestBase {
    private static final Logger LOG = LoggerFactory.getLogger(NettyTransportTest.class);

    /**
     * Receiver whose callbacks do nothing. Useful where a Transport.Receiver
     * must be supplied but the test does not look at any delivered event.
     */
    private static final Transport.Receiver NOOP = new Transport.Receiver() {
        @Override
        public void onReceived(String source, Message message) {
            // Intentionally empty: received messages are ignored.
        }

        @Override
        public void onDisconnected(String source) {
            // Intentionally empty: disconnect notifications are ignored.
        }
    };

    /**
     * Wraps a single int in a freshly allocated buffer.
     *
     * @param num the value to store.
     * @return a 4-byte buffer holding {@code num}, flipped so that it is
     *         ready to be read from position 0.
     */
    public static ByteBuffer createByteBuffer(int num) {
        // An int occupies Integer.SIZE / Byte.SIZE == 4 bytes.
        final ByteBuffer buffer = ByteBuffer.allocate(Integer.SIZE / Byte.SIZE);
        buffer.putInt(num).flip();
        return buffer;
    }

    /**
     * Builds an ACK message for the given zxid by delegating to
     * MessageBuilder.buildAck.
     *
     * @param zxid the zxid to acknowledge.
     * @return the message produced by MessageBuilder.buildAck.
     */
    public static Message createAck(Zxid zxid) {
        final Message ack = MessageBuilder.buildAck(zxid);
        return ack;
    }

    /**
     * Constructing a transport for an address with an invalid port is
     * expected to be rejected with an IllegalArgumentException.
     */
    @Test(timeout = 10000, expected = IllegalArgumentException.class)
    public void testInvalidPort() throws Exception {
        final String invalidHostPort = getHostPort(-1);
        // Only the exception matters, so the instance is not kept.
        new NettyTransport(invalidHostPort, NOOP, getDirectory());
    }

    /**
     * Sends a series of ACKs from a transport to its own address and checks
     * that the receiver sees all of them, in order, with the local id
     * reported as the source.
     */
    @Test(timeout = 1000)
    public void testLocalSend() throws Exception {
        final String localId = getUniqueHostPort();

        // receiver simply appends messages to a list.
        final LinkedList<Zxid> messages = new LinkedList<>();
        Transport.Receiver receiver = new Transport.Receiver() {
            @Override
            public void onReceived(String source, Message message) {
                Assert.assertEquals(localId, source);
                Zxid zxid = MessageBuilder.fromProtoZxid(message.getAck().getZxid());
                messages.add(zxid);
            }

            @Override
            public void onDisconnected(String source) {
            }
        };

        NettyTransport transport = new NettyTransport(localId, receiver, getDirectory());
        try {
            // send messages to itself.
            for (int i = 0; i < 20; i++) {
                transport.send(localId, createAck(new Zxid(0, i)));
            }

            // receive messages. The list is drained right away, without any
            // waiting, so this relies on a message sent to the local id having
            // reached the receiver by the time send() returns.
            for (int i = 0; i < 20; i++) {
                Zxid zxid = messages.pop();
                LOG.debug("Received a message: {}", zxid);
                Assert.assertEquals(new Zxid(0, i), zxid);
            }
            Assert.assertTrue(messages.isEmpty());
        } finally {
            // Always release the transport, even when an assertion fails, so
            // it does not outlive this test.
            transport.shutdown();
        }
    }

    /**
     * Sends a message to a peer address for which no transport was started
     * and waits for the sender's receiver to get an onDisconnected callback.
     */
    @Test(timeout = 5000)
    public void testConnectFailed() throws Exception {
        final String peerA = getUniqueHostPort();
        // No transport is created for peerB in this test.
        final String peerB = getUniqueHostPort();
        final CountDownLatch disconnected = new CountDownLatch(1);
        Transport.Receiver receiverA = new Transport.Receiver() {
            @Override
            public void onReceived(String source, Message message) {
            }

            @Override
            public void onDisconnected(String source) {
                disconnected.countDown();
            }
        };
        NettyTransport transportA = new NettyTransport(peerA, receiverA, getDirectory());
        try {
            transportA.send(peerB, createAck(new Zxid(0, 0)));
            // If the callback never arrives, the @Test timeout fails the test.
            disconnected.await();
        } finally {
            // Always release the transport so it does not outlive this test.
            transportA.shutdown();
        }
    }

    /**
     * Exchanges a batch of ACKs between two transports, first from A to B and
     * then from B to A, and checks that each side receives the whole batch in
     * sending order from the expected peer.
     */
    @Test(timeout = 10000)
    public void testSend() throws Exception {
        final String peerA = getUniqueHostPort();
        final String peerB = getUniqueHostPort();

        final int total = 20;
        // Zxids seen by each side, in arrival order.
        final LinkedList<Zxid> inboxA = new LinkedList<>();
        final LinkedList<Zxid> inboxB = new LinkedList<>();
        // Each latch opens once its side has received the whole batch.
        final CountDownLatch pendingA = new CountDownLatch(total);
        final CountDownLatch pendingB = new CountDownLatch(total);

        Transport.Receiver receiverA = new Transport.Receiver() {
            @Override
            public void onReceived(String source, Message message) {
                Assert.assertEquals(peerB, source);
                inboxA.add(MessageBuilder.fromProtoZxid(message.getAck().getZxid()));
                pendingA.countDown();
            }

            @Override
            public void onDisconnected(String source) {
            }
        };
        Transport.Receiver receiverB = new Transport.Receiver() {
            @Override
            public void onReceived(String source, Message message) {
                Assert.assertEquals(peerA, source);
                inboxB.add(MessageBuilder.fromProtoZxid(message.getAck().getZxid()));
                pendingB.countDown();
            }

            @Override
            public void onDisconnected(String source) {
            }
        };

        NettyTransport transportA = new NettyTransport(peerA, receiverA, getDirectory());
        NettyTransport transportB = new NettyTransport(peerB, receiverB, getDirectory());

        // Direction A -> B.
        for (int seq = 0; seq < total; seq++) {
            transportA.send(peerB, createAck(new Zxid(0, seq)));
        }
        pendingB.await();
        for (int seq = 0; seq < total; seq++) {
            Zxid received = inboxB.removeFirst();
            LOG.debug("Received a message: {}", received);
            Assert.assertEquals(new Zxid(0, seq), received);
        }
        Assert.assertTrue(inboxB.isEmpty());

        // Direction B -> A.
        for (int seq = 0; seq < total; seq++) {
            transportB.send(peerA, createAck(new Zxid(0, seq)));
        }
        pendingA.await();
        for (int seq = 0; seq < total; seq++) {
            Zxid received = inboxA.removeFirst();
            LOG.debug("Received a message: {}", received);
            Assert.assertEquals(new Zxid(0, seq), received);
        }
        Assert.assertTrue(inboxA.isEmpty());

        transportA.shutdown();
        transportB.shutdown();
    }

    /**
     * A opens the connection to B and is then shut down. A's sender map must
     * be empty afterwards; B must be notified of the disconnection while
     * still holding an entry for A in its sender map.
     */
    @Test(timeout = 5000)
    public void testDisconnectClient() throws Exception {
        final String peerA = getUniqueHostPort();
        final String peerB = getUniqueHostPort();
        final CountDownLatch receivedByA = new CountDownLatch(1);
        final CountDownLatch receivedByB = new CountDownLatch(1);
        final CountDownLatch lostByB = new CountDownLatch(1);

        // The receivers only check the source and release their latches.
        Transport.Receiver receiverA = new Transport.Receiver() {
            @Override
            public void onReceived(String source, Message message) {
                Assert.assertEquals(peerB, source);
                receivedByA.countDown();
            }

            @Override
            public void onDisconnected(String source) {
            }
        };
        Transport.Receiver receiverB = new Transport.Receiver() {
            @Override
            public void onReceived(String source, Message message) {
                Assert.assertEquals(peerA, source);
                receivedByB.countDown();
            }

            @Override
            public void onDisconnected(String source) {
                lostByB.countDown();
            }
        };
        NettyTransport transportA = new NettyTransport(peerA, receiverA, getDirectory());
        NettyTransport transportB = new NettyTransport(peerB, receiverB, getDirectory());

        // A is the side that initiates the handshake; wait until B got the message.
        transportA.send(peerB, createAck(new Zxid(0, 0)));
        receivedByB.await();

        // Shutting A down must leave A without any senders.
        transportA.shutdown();
        Assert.assertTrue(transportA.senders.isEmpty());

        // B should get onDisconnected event, but A should still be in B's map.
        lostByB.await();
        Assert.assertTrue(transportB.senders.containsKey(peerA));
        transportB.shutdown();
    }

    /**
     * A opens the connection to B and then B is shut down. B's sender map
     * must be empty afterwards; A must be notified of the disconnection while
     * still holding an entry for B in its sender map.
     */
    @Test(timeout = 10000)
    public void testDisconnectServer() throws Exception {
        final String peerA = getUniqueHostPort();
        final String peerB = getUniqueHostPort();
        final CountDownLatch lostByA = new CountDownLatch(1);
        final CountDownLatch receivedByB = new CountDownLatch(1);

        // A only cares about being disconnected.
        Transport.Receiver receiverA = new Transport.Receiver() {
            @Override
            public void onReceived(String source, Message message) {
            }

            @Override
            public void onDisconnected(String source) {
                lostByA.countDown();
            }
        };
        // B only cares about receiving A's message.
        Transport.Receiver receiverB = new Transport.Receiver() {
            @Override
            public void onReceived(String source, Message message) {
                Assert.assertEquals(peerA, source);
                receivedByB.countDown();
            }

            @Override
            public void onDisconnected(String source) {
            }
        };
        NettyTransport transportA = new NettyTransport(peerA, receiverA, getDirectory());
        NettyTransport transportB = new NettyTransport(peerB, receiverB, getDirectory());

        // A is the side that initiates the handshake; wait until B got the message.
        transportA.send(peerB, createAck(new Zxid(0, 0)));
        receivedByB.await();

        // Shutting B down must leave B without any senders.
        transportB.shutdown();
        Assert.assertTrue(transportB.senders.isEmpty());

        // A is notified through onDisconnected, yet keeps B in its map.
        lostByA.await();
        Assert.assertTrue(transportA.senders.containsKey(peerB));
        transportA.shutdown();
    }

    /**
     * Makes A and B start handshakes towards each other at the same time: B
     * sends to A from inside its pipeline, before A's first inbound data is
     * passed on. Neither receiver may get a message, and both must be told
     * about a disconnection.
     */
    @Test(timeout = 10000)
    public void testTieBreaker() throws Exception {
        final String peerA = getUniqueHostPort();
        final String peerB = getUniqueHostPort();
        final CountDownLatch lostByA = new CountDownLatch(1);
        final CountDownLatch lostByB = new CountDownLatch(1);

        Transport.Receiver receiverA = new Transport.Receiver() {
            @Override
            public void onReceived(String source, Message message) {
                Assert.fail("Handshake should have failed");
            }

            @Override
            public void onDisconnected(String source) {
                LOG.debug("Got disconnected from {}", source);
                lostByA.countDown();
            }
        };
        Transport.Receiver receiverB = new Transport.Receiver() {
            @Override
            public void onReceived(String source, Message message) {
                Assert.fail("Handshake should have failed");
            }

            @Override
            public void onDisconnected(String source) {
                LOG.debug("Got disconnected from {}", source);
                lostByB.countDown();
            }
        };

        final NettyTransport transportA = new NettyTransport(peerA, receiverA, getDirectory());
        final NettyTransport transportB = new NettyTransport(peerB, receiverB, getDirectory());

        // One-shot handler at the head of B's pipeline.
        ChannelInboundHandlerAdapter competingHandshake = new ChannelInboundHandlerAdapter() {
            @Override
            public void channelRead(ChannelHandlerContext ctx, Object msg) {
                // B initiates another handshake before responding to A's handshake.
                transportB.send(peerA, createAck(new Zxid(0, 0)));
                // Remove itself so it fires only once, then pass the data on.
                ctx.pipeline().remove(this);
                ctx.fireChannelRead(msg);
            }
        };
        transportB.channel.pipeline().addFirst(competingHandshake);

        // A starts its handshake, which triggers the handler above on B.
        transportA.send(peerB, createAck(new Zxid(0, 0)));
        lostByA.await();
        lostByB.await();
        transportA.shutdown();
        transportB.shutdown();
    }

    /**
     * A sends a batch of messages to B and then calls clear() for B. A must
     * no longer have B in its sender map; B must be notified of the
     * disconnection while still holding an entry for A.
     */
    @Test(timeout = 5000)
    public void testCloseClient() throws Exception {
        final String peerA = getUniqueHostPort();
        final String peerB = getUniqueHostPort();
        final int total = 100;
        final CountDownLatch pendingB = new CountDownLatch(total);
        final CountDownLatch lostByB = new CountDownLatch(1);

        // A ignores every callback.
        Transport.Receiver receiverA = new Transport.Receiver() {
            @Override
            public void onReceived(String source, Message message) {
            }

            @Override
            public void onDisconnected(String source) {
            }
        };
        Transport.Receiver receiverB = new Transport.Receiver() {
            @Override
            public void onReceived(String source, Message message) {
                Assert.assertEquals(peerA, source);
                pendingB.countDown();
                LOG.debug("Received a message from {}: {}: {}", source, message, pendingB.getCount());
            }

            @Override
            public void onDisconnected(String source) {
                LOG.debug("Got disconnected from {}", source);
                lostByB.countDown();
            }
        };

        final NettyTransport transportA = new NettyTransport(peerA, receiverA, getDirectory());
        final NettyTransport transportB = new NettyTransport(peerB, receiverB, getDirectory());

        // Deliver the whole batch before closing anything.
        for (int seq = 0; seq < total; seq++) {
            transportA.send(peerB, createAck(new Zxid(0, seq)));
        }
        pendingB.await();

        // clear() on the initiating side drops B from A's map.
        transportA.clear(peerB);
        Assert.assertFalse(transportA.senders.containsKey(peerB));

        // B is notified through onDisconnected, yet keeps A in its map.
        lostByB.await();
        Assert.assertTrue(transportB.senders.containsKey(peerA));

        transportA.shutdown();
        transportB.shutdown();
    }

    /**
     * A sends a batch of messages to B and then B calls clear() for A. B must
     * no longer have A in its sender map; A must be notified of the
     * disconnection while still holding an entry for B.
     */
    @Test(timeout = 10000)
    public void testCloseServer() throws Exception {
        final String peerA = getUniqueHostPort();
        final String peerB = getUniqueHostPort();
        final int total = 100;
        final CountDownLatch pendingB = new CountDownLatch(total);
        final CountDownLatch lostByA = new CountDownLatch(1);
        // Created but never counted down or awaited in this test.
        final CountDownLatch lostByB = new CountDownLatch(1);

        Transport.Receiver receiverA = new Transport.Receiver() {
            @Override
            public void onReceived(String source, Message message) {
            }

            @Override
            public void onDisconnected(String source) {
                LOG.debug("Got disconnected from {}", source);
                lostByA.countDown();
            }
        };
        Transport.Receiver receiverB = new Transport.Receiver() {
            @Override
            public void onReceived(String source, Message message) {
                Assert.assertEquals(peerA, source);
                pendingB.countDown();
                LOG.debug("Received a message from {}: {}: {}", source, message, pendingB.getCount());
            }

            @Override
            public void onDisconnected(String source) {
            }
        };

        final NettyTransport transportA = new NettyTransport(peerA, receiverA, getDirectory());
        final NettyTransport transportB = new NettyTransport(peerB, receiverB, getDirectory());

        // Deliver the whole batch before closing anything.
        for (int seq = 0; seq < total; seq++) {
            transportA.send(peerB, createAck(new Zxid(0, seq)));
        }
        pendingB.await();

        // clear() on the accepting side drops A from B's map.
        transportB.clear(peerA);
        Assert.assertFalse(transportB.senders.containsKey(peerA));

        // A is notified through onDisconnected, yet keeps B in its map.
        lostByA.await();
        Assert.assertTrue(transportA.senders.containsKey(peerB));

        transportA.shutdown();
        transportB.shutdown();
    }

    /**
     * B silently drops everything it reads, so A's handshake is never passed
     * further down B's pipeline. A is expected to end up with an
     * onDisconnected callback (presumably through the transport's handshake
     * timeout, as the test name suggests).
     */
    @Test(timeout = 10000)
    public void testHandshakeTimeout() throws Exception {
        final String peerA = getUniqueHostPort();
        final String peerB = getUniqueHostPort();
        final CountDownLatch lostByA = new CountDownLatch(1);

        Transport.Receiver receiverA = new Transport.Receiver() {
            @Override
            public void onReceived(String source, Message message) {
            }

            @Override
            public void onDisconnected(String source) {
                LOG.debug("Got disconnected from {}", source);
                lostByA.countDown();
            }
        };
        // B ignores every callback.
        Transport.Receiver receiverB = new Transport.Receiver() {
            @Override
            public void onReceived(String source, Message message) {
            }

            @Override
            public void onDisconnected(String source) {
            }
        };

        final NettyTransport transportA = new NettyTransport(peerA, receiverA, getDirectory());
        final NettyTransport transportB = new NettyTransport(peerB, receiverB, getDirectory());

        // Handler at the head of B's pipeline that swallows all inbound data,
        // including the handshake message.
        ChannelInboundHandlerAdapter blackHole = new ChannelInboundHandlerAdapter() {
            @Override
            public void channelRead(ChannelHandlerContext ctx, Object msg) {
                // discard the message.
            }
        };
        transportB.channel.pipeline().addFirst(blackHole);

        transportA.send(peerB, createAck(new Zxid(0, 0)));
        lostByA.await();

        transportA.shutdown();
        transportB.shutdown();
    }

    /**
     * A repeatedly broadcasts an ACK to A, B and C and waits until each of
     * the three receivers has seen as many messages as were broadcast.
     */
    @Test(timeout = 10000)
    public void testBroadcast() throws Exception {
        final int rounds = 100;
        final String peerA = getUniqueHostPort();
        final String peerB = getUniqueHostPort();
        final String peerC = getUniqueHostPort();
        // One latch per peer; each opens after that peer received every round.
        final CountDownLatch pendingA = new CountDownLatch(rounds);
        final CountDownLatch pendingB = new CountDownLatch(rounds);
        final CountDownLatch pendingC = new CountDownLatch(rounds);

        Transport.Receiver receiverA = new Transport.Receiver() {
            @Override
            public void onReceived(String source, Message message) {
                pendingA.countDown();
            }

            @Override
            public void onDisconnected(String source) {
            }
        };
        Transport.Receiver receiverB = new Transport.Receiver() {
            @Override
            public void onReceived(String source, Message message) {
                pendingB.countDown();
            }

            @Override
            public void onDisconnected(String source) {
            }
        };
        Transport.Receiver receiverC = new Transport.Receiver() {
            @Override
            public void onReceived(String source, Message message) {
                pendingC.countDown();
            }

            @Override
            public void onDisconnected(String source) {
            }
        };

        final NettyTransport transportA = new NettyTransport(peerA, receiverA, getDirectory());
        final NettyTransport transportB = new NettyTransport(peerB, receiverB, getDirectory());
        final NettyTransport transportC = new NettyTransport(peerC, receiverC, getDirectory());

        // The broadcaster itself (A) is part of the recipient list; a fresh
        // iterator over the recipients is handed over on every round.
        int round = 0;
        while (round < rounds) {
            transportA.broadcast(Arrays.asList(peerA, peerB, peerC).listIterator(), createAck(new Zxid(0, 0)));
            round++;
        }

        pendingA.await();
        pendingB.await();
        pendingC.await();

        transportA.shutdown();
        transportB.shutdown();
        transportC.shutdown();
    }

    /**
     * Checks which peers can deliver a message to an SSL-enabled transport A.
     *
     * A, B and C are created with SSL parameters (key stores a, b and c and a
     * shared trust store); D is created without any. Per the scenario
     * comments below, a message from D (no SSL) and from C (untrusted cert)
     * must not reach A within two seconds, while a message from B (trusted
     * cert) must be delivered.
     */
    @Test(timeout = 20000)
    public void testSsl() throws Exception {
        String peerA = getUniqueHostPort();
        String peerB = getUniqueHostPort();
        String peerC = getUniqueHostPort();
        String peerD = getUniqueHostPort();
        // Opens as soon as A receives any message, whoever sent it.
        final CountDownLatch latchA = new CountDownLatch(1);
        Transport.Receiver receiverA = new Transport.Receiver() {
            @Override
            public void onReceived(String source, Message message) {
                latchA.countDown();
            }

            @Override
            public void onDisconnected(String source) {
            }
        };
        Transport.Receiver receiverB = new Transport.Receiver() {
            @Override
            public void onReceived(String source, Message message) {
            }

            @Override
            public void onDisconnected(String source) {
            }
        };
        Transport.Receiver receiverC = new Transport.Receiver() {
            @Override
            public void onReceived(String source, Message message) {
            }

            @Override
            public void onDisconnected(String source) {
            }
        };
        Transport.Receiver receiverD = new Transport.Receiver() {
            @Override
            public void onReceived(String source, Message message) {
            }

            @Override
            public void onDisconnected(String source) {
            }
        };
        String password = "pa55w0rd";
        // The stores are read from the build output directory, so they are
        // expected to exist before the test runs.
        String sslDir = "target" + File.separator + "generated-resources" + File.separator + "ssl";

        File trustStore = new File(sslDir, "truststore.jks");
        File keyStoreA = new File(sslDir, "keystore_a.jks");
        File keyStoreB = new File(sslDir, "keystore_b.jks");
        File keyStoreC = new File(sslDir, "keystore_c.jks");

        SslParameters sslParam1 = new SslParameters(keyStoreA, password, trustStore, password);
        SslParameters sslParam2 = new SslParameters(keyStoreB, password, trustStore, password);
        SslParameters sslParam3 = new SslParameters(keyStoreC, password, trustStore, password);

        NettyTransport transportA = new NettyTransport(peerA, receiverA, sslParam1, getDirectory());
        NettyTransport transportB = new NettyTransport(peerB, receiverB, sslParam2, getDirectory());
        NettyTransport transportC = new NettyTransport(peerC, receiverC, sslParam3, getDirectory());
        NettyTransport transportD = new NettyTransport(peerD, receiverD, getDirectory());

        try {
            // D doesn't use SSL
            transportD.send(peerA, createAck(new Zxid(0, 0)));
            Assert.assertFalse(latchA.await(2, TimeUnit.SECONDS));

            // C uses untrusted cert
            transportC.send(peerA, createAck(new Zxid(0, 0)));
            Assert.assertFalse(latchA.await(2, TimeUnit.SECONDS));

            // B uses trusted cert
            transportB.send(peerA, createAck(new Zxid(0, 0)));
            latchA.await();
        } finally {
            // Release every transport, including D, even when an assertion fails.
            transportA.shutdown();
            transportB.shutdown();
            transportC.shutdown();
            transportD.shutdown();
        }
    }

    /**
     * A sends an ACK, then a file, then another ACK to B. Once B has been
     * called back three times, the first file B reported as received is
     * compared with the file that was sent.
     */
    @Test(timeout = 10000)
    public void testSendFile() throws Exception {
        int expectedCallbacks = 3;
        final String peerA = getUniqueHostPort();
        final String peerB = getUniqueHostPort();
        final CountDownLatch pendingA = new CountDownLatch(expectedCallbacks);
        final CountDownLatch pendingB = new CountDownLatch(expectedCallbacks);
        // Files announced to B through FILE_RECEIVED messages.
        final ArrayList<File> filesSeenByB = new ArrayList<>();

        Transport.Receiver receiverA = new Transport.Receiver() {
            @Override
            public void onReceived(String source, Message message) {
                LOG.debug("onReceived {}", message);
                pendingA.countDown();
            }

            @Override
            public void onDisconnected(String source) {
            }
        };
        Transport.Receiver receiverB = new Transport.Receiver() {
            @Override
            public void onReceived(String source, Message message) {
                boolean isFileNotice = message.getType() == MessageType.FILE_RECEIVED;
                if (isFileNotice) {
                    String path = message.getFileReceived().getFullPath();
                    filesSeenByB.add(new File(path));
                }
                LOG.debug("onReceived {}", message);
                pendingB.countDown();
            }

            @Override
            public void onDisconnected(String source) {
            }
        };
        final NettyTransport transportA = new NettyTransport(peerA, receiverA, getDirectory());
        final NettyTransport transportB = new NettyTransport(peerB, receiverB, getDirectory());

        // The file travels between two ordinary messages.
        File sentFile = new File("./pom.xml");
        transportA.send(peerB, createAck(new Zxid(0, 0)));
        transportA.send(peerB, sentFile);
        transportA.send(peerB, createAck(new Zxid(0, 1)));
        pendingB.await();

        File receivedFile = filesSeenByB.get(0);
        Assert.assertTrue(compareFiles(sentFile, receivedFile));

        transportA.shutdown();
        transportB.shutdown();
    }

    /**
     * Same scenario as sending a file between two transports, but with both
     * transports created with SSL parameters: A sends an ACK, a file and
     * another ACK to B, and the first file B reported as received is compared
     * with the file that was sent.
     */
    @Test(timeout = 10000)
    public void testSendFileSsl() throws Exception {
        int expectedCallbacks = 3;
        final String peerA = getUniqueHostPort();
        final String peerB = getUniqueHostPort();
        final CountDownLatch pendingA = new CountDownLatch(expectedCallbacks);
        final CountDownLatch pendingB = new CountDownLatch(expectedCallbacks);
        // Files announced to B through FILE_RECEIVED messages.
        final ArrayList<File> filesSeenByB = new ArrayList<>();

        // Key and trust stores are read from the build output directory.
        String password = "pa55w0rd";
        String sslDir = "target" + File.separator + "generated-resources" + File.separator + "ssl";
        File trustStore = new File(sslDir, "truststore.jks");
        File keyStoreA = new File(sslDir, "keystore_a.jks");
        File keyStoreB = new File(sslDir, "keystore_b.jks");

        Transport.Receiver receiverA = new Transport.Receiver() {
            @Override
            public void onReceived(String source, Message message) {
                LOG.debug("onReceived {}", message);
                pendingA.countDown();
            }

            @Override
            public void onDisconnected(String source) {
            }
        };
        Transport.Receiver receiverB = new Transport.Receiver() {
            @Override
            public void onReceived(String source, Message message) {
                boolean isFileNotice = message.getType() == MessageType.FILE_RECEIVED;
                if (isFileNotice) {
                    String path = message.getFileReceived().getFullPath();
                    filesSeenByB.add(new File(path));
                }
                LOG.debug("onReceived {}", message);
                pendingB.countDown();
            }

            @Override
            public void onDisconnected(String source) {
            }
        };

        SslParameters sslForA = new SslParameters(keyStoreA, password, trustStore, password);
        SslParameters sslForB = new SslParameters(keyStoreB, password, trustStore, password);

        NettyTransport transportA = new NettyTransport(peerA, receiverA, sslForA, getDirectory());
        NettyTransport transportB = new NettyTransport(peerB, receiverB, sslForB, getDirectory());

        // The file travels between two ordinary messages.
        File sentFile = new File("./pom.xml");
        transportA.send(peerB, createAck(new Zxid(0, 0)));
        transportA.send(peerB, sentFile);
        transportA.send(peerB, createAck(new Zxid(0, 1)));
        pendingB.await();

        File receivedFile = filesSeenByB.get(0);
        Assert.assertTrue(compareFiles(sentFile, receivedFile));

        transportA.shutdown();
        transportB.shutdown();
    }

    /**
     * Checks whether two files have exactly the same content.
     *
     * Both files must exist, otherwise a JUnit assertion fails. The contents
     * are compared byte by byte. Decoding them as text with the platform
     * charset and dropping line terminators, as was done before, could report
     * two different files of equal length as identical.
     *
     * @param file1 the first file.
     * @param file2 the second file.
     * @return true if both files have the same length and the same bytes.
     * @throws Exception if either file cannot be read.
     */
    static boolean compareFiles(File file1, File file2) throws Exception {
        Assert.assertTrue(file1.exists());
        Assert.assertTrue(file2.exists());
        // Cheap check first: files of different sizes can never be equal.
        if (file1.length() != file2.length()) {
            return false;
        }
        byte[] content1 = Files.readAllBytes(file1.toPath());
        byte[] content2 = Files.readAllBytes(file2.toPath());
        return Arrays.equals(content1, content2);
    }
}