// Java tutorial
/** * Licensed to the zk1931 under one or more contributor license * agreements. See the NOTICE file distributed with this work * for additional information regarding copyright ownership. * The ASF licenses this file to you under the Apache License, * Version 2.0 (the "License"); you may not use this file except * in compliance with the License. You may obtain a copy of the * License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.github.zk1931.jzab.transport; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; import java.io.BufferedReader; import java.io.File; import java.io.FileInputStream; import java.io.InputStreamReader; import java.nio.ByteBuffer; import java.util.Arrays; import java.util.ArrayList; import java.util.LinkedList; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import com.github.zk1931.jzab.MessageBuilder; import com.github.zk1931.jzab.TestBase; import com.github.zk1931.jzab.proto.ZabMessage.Message; import com.github.zk1931.jzab.proto.ZabMessage.Message.MessageType; import com.github.zk1931.jzab.SslParameters; import com.github.zk1931.jzab.Zxid; import org.junit.Assert; import org.junit.Test; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * Test NettyTransport. 
*/ public class NettyTransportTest extends TestBase { private static final Logger LOG = LoggerFactory.getLogger(NettyTransportTest.class); private static final Transport.Receiver NOOP = new Transport.Receiver() { @Override public void onReceived(String source, Message message) { } @Override public void onDisconnected(String source) { } }; public static ByteBuffer createByteBuffer(int num) { ByteBuffer bb = ByteBuffer.allocate(4); bb.putInt(num); bb.flip(); return bb; } public static Message createAck(Zxid zxid) { return MessageBuilder.buildAck(zxid); } /** * Make sure the constructor fails when the port is invalid. */ @Test(timeout = 10000, expected = IllegalArgumentException.class) public void testInvalidPort() throws Exception { NettyTransport transport = new NettyTransport(getHostPort(-1), NOOP, getDirectory()); } @Test(timeout = 1000) public void testLocalSend() throws Exception { final String localId = getUniqueHostPort(); // receiver simply appends messages to a list. final LinkedList<Zxid> messages = new LinkedList<>(); Transport.Receiver receiver = new Transport.Receiver() { public void onReceived(String source, Message message) { Assert.assertEquals(localId, source); Zxid zxid = MessageBuilder.fromProtoZxid(message.getAck().getZxid()); messages.add(zxid); } public void onDisconnected(String source) { } }; // send messages to itself. NettyTransport transport = new NettyTransport(localId, receiver, getDirectory()); for (int i = 0; i < 20; i++) { transport.send(localId, createAck(new Zxid(0, i))); } // receive messages. 
for (int i = 0; i < 20; i++) { Zxid zxid = messages.pop(); LOG.debug("Received a message: {}", zxid); Assert.assertEquals(new Zxid(0, i), zxid); } Assert.assertTrue(messages.isEmpty()); } @Test(timeout = 5000) public void testConnectFailed() throws Exception { final String peerA = getUniqueHostPort(); final String peerB = getUniqueHostPort(); final CountDownLatch disconnected = new CountDownLatch(1); Transport.Receiver receiverA = new Transport.Receiver() { public void onReceived(String source, Message message) { } public void onDisconnected(String source) { disconnected.countDown(); } }; NettyTransport transportA = new NettyTransport(peerA, receiverA, getDirectory()); transportA.send(peerB, createAck(new Zxid(0, 0))); disconnected.await(); } @Test(timeout = 10000) public void testSend() throws Exception { final String peerA = getUniqueHostPort(); final String peerB = getUniqueHostPort(); // receiver simply appends messages to a list. int messageCount = 20; final LinkedList<Zxid> messagesA = new LinkedList<>(); final LinkedList<Zxid> messagesB = new LinkedList<>(); final CountDownLatch latchA = new CountDownLatch(messageCount); final CountDownLatch latchB = new CountDownLatch(messageCount); Transport.Receiver receiverA = new Transport.Receiver() { public void onReceived(String source, Message message) { Assert.assertEquals(peerB, source); Zxid zxid = MessageBuilder.fromProtoZxid(message.getAck().getZxid()); messagesA.add(zxid); latchA.countDown(); } public void onDisconnected(String source) { } }; Transport.Receiver receiverB = new Transport.Receiver() { public void onReceived(String source, Message message) { Assert.assertEquals(peerA, source); Zxid zxid = MessageBuilder.fromProtoZxid(message.getAck().getZxid()); messagesB.add(zxid); latchB.countDown(); } public void onDisconnected(String source) { } }; NettyTransport transportA = new NettyTransport(peerA, receiverA, getDirectory()); NettyTransport transportB = new NettyTransport(peerB, receiverB, getDirectory()); 
// send messages from A to B. for (int i = 0; i < messageCount; i++) { transportA.send(peerB, createAck(new Zxid(0, i))); } latchB.await(); for (int i = 0; i < messageCount; i++) { Zxid zxid = messagesB.pop(); LOG.debug("Received a message: {}", zxid); Assert.assertEquals(new Zxid(0, i), zxid); } Assert.assertTrue(messagesB.isEmpty()); // send messages from B to A. for (int i = 0; i < messageCount; i++) { transportB.send(peerA, createAck(new Zxid(0, i))); } latchA.await(); for (int i = 0; i < messageCount; i++) { Zxid zxid = messagesA.pop(); LOG.debug("Received a message: {}", zxid); Assert.assertEquals(new Zxid(0, i), zxid); } Assert.assertTrue(messagesA.isEmpty()); transportA.shutdown(); transportB.shutdown(); } @Test(timeout = 5000) public void testDisconnectClient() throws Exception { final String peerA = getUniqueHostPort(); final String peerB = getUniqueHostPort(); final CountDownLatch latchA = new CountDownLatch(1); final CountDownLatch latchB = new CountDownLatch(1); final CountDownLatch disconnectedB = new CountDownLatch(1); // receiver simply decrement the latch. Transport.Receiver receiverA = new Transport.Receiver() { public void onReceived(String source, Message message) { Assert.assertEquals(peerB, source); latchA.countDown(); } public void onDisconnected(String source) { } }; Transport.Receiver receiverB = new Transport.Receiver() { public void onReceived(String source, Message message) { Assert.assertEquals(peerA, source); latchB.countDown(); } public void onDisconnected(String source) { disconnectedB.countDown(); } }; NettyTransport transportA = new NettyTransport(peerA, receiverA, getDirectory()); NettyTransport transportB = new NettyTransport(peerB, receiverB, getDirectory()); // A initiates a handshake. transportA.send(peerB, createAck(new Zxid(0, 0))); latchB.await(); // shutdown A and make sure B removes the channel to A. 
transportA.shutdown(); Assert.assertTrue(transportA.senders.isEmpty()); disconnectedB.await(); Assert.assertTrue(transportB.senders.containsKey(peerA)); transportB.shutdown(); } @Test(timeout = 10000) public void testDisconnectServer() throws Exception { final String peerA = getUniqueHostPort(); final String peerB = getUniqueHostPort(); final CountDownLatch latchA = new CountDownLatch(1); final CountDownLatch latchB = new CountDownLatch(1); // receiver simply decrement the latch. Transport.Receiver receiverA = new Transport.Receiver() { public void onReceived(String source, Message message) { } public void onDisconnected(String source) { latchA.countDown(); } }; Transport.Receiver receiverB = new Transport.Receiver() { public void onReceived(String source, Message message) { Assert.assertEquals(peerA, source); latchB.countDown(); } public void onDisconnected(String source) { } }; NettyTransport transportA = new NettyTransport(peerA, receiverA, getDirectory()); NettyTransport transportB = new NettyTransport(peerB, receiverB, getDirectory()); // A initiates a handshake. transportA.send(peerB, createAck(new Zxid(0, 0))); latchB.await(); // shutdown B and make sure A removes the channel to B. transportB.shutdown(); Assert.assertTrue(transportB.senders.isEmpty()); // A should get onDisconnected event, but B should still be in the map. 
latchA.await(); Assert.assertTrue(transportA.senders.containsKey(peerB)); transportA.shutdown(); } @Test(timeout = 10000) public void testTieBreaker() throws Exception { final String peerA = getUniqueHostPort(); final String peerB = getUniqueHostPort(); final CountDownLatch disconnectedA = new CountDownLatch(1); final CountDownLatch disconnectedB = new CountDownLatch(1); Transport.Receiver receiverA = new Transport.Receiver() { public void onReceived(String source, Message message) { Assert.fail("Handshake should have failed"); } public void onDisconnected(String source) { LOG.debug("Got disconnected from {}", source); disconnectedA.countDown(); } }; Transport.Receiver receiverB = new Transport.Receiver() { public void onReceived(String source, Message message) { Assert.fail("Handshake should have failed"); } public void onDisconnected(String source) { LOG.debug("Got disconnected from {}", source); disconnectedB.countDown(); } }; final NettyTransport transportA = new NettyTransport(peerA, receiverA, getDirectory()); final NettyTransport transportB = new NettyTransport(peerB, receiverB, getDirectory()); transportB.channel.pipeline().addFirst(new ChannelInboundHandlerAdapter() { @Override public void channelRead(ChannelHandlerContext ctx, Object msg) { // B initiates another handshake before responding to A's handshake. transportB.send(peerA, createAck(new Zxid(0, 0))); ctx.pipeline().remove(this); ctx.fireChannelRead(msg); } }); // A initiates a handshake. 
transportA.send(peerB, createAck(new Zxid(0, 0))); disconnectedA.await(); disconnectedB.await(); transportA.shutdown(); transportB.shutdown(); } @Test(timeout = 5000) public void testCloseClient() throws Exception { final String peerA = getUniqueHostPort(); final String peerB = getUniqueHostPort(); final int messageCount = 100; final CountDownLatch latchB = new CountDownLatch(messageCount); final CountDownLatch disconnectedB = new CountDownLatch(1); Transport.Receiver receiverA = new Transport.Receiver() { public void onReceived(String source, Message message) { } public void onDisconnected(String source) { } }; Transport.Receiver receiverB = new Transport.Receiver() { public void onReceived(String source, Message message) { Assert.assertEquals(peerA, source); latchB.countDown(); LOG.debug("Received a message from {}: {}: {}", source, message, latchB.getCount()); } public void onDisconnected(String source) { LOG.debug("Got disconnected from {}", source); disconnectedB.countDown(); } }; final NettyTransport transportA = new NettyTransport(peerA, receiverA, getDirectory()); final NettyTransport transportB = new NettyTransport(peerB, receiverB, getDirectory()); for (int i = 0; i < messageCount; i++) { transportA.send(peerB, createAck(new Zxid(0, i))); } latchB.await(); // A should remove B from the map after clear() is called. transportA.clear(peerB); Assert.assertFalse(transportA.senders.containsKey(peerB)); // B should get onDisconnected event, but A should be in the map. 
disconnectedB.await(); Assert.assertTrue(transportB.senders.containsKey(peerA)); transportA.shutdown(); transportB.shutdown(); } @Test(timeout = 10000) public void testCloseServer() throws Exception { final String peerA = getUniqueHostPort(); final String peerB = getUniqueHostPort(); final int messageCount = 100; final CountDownLatch latchB = new CountDownLatch(messageCount); final CountDownLatch disconnectedA = new CountDownLatch(1); final CountDownLatch disconnectedB = new CountDownLatch(1); Transport.Receiver receiverA = new Transport.Receiver() { public void onReceived(String source, Message message) { } public void onDisconnected(String source) { LOG.debug("Got disconnected from {}", source); disconnectedA.countDown(); } }; Transport.Receiver receiverB = new Transport.Receiver() { public void onReceived(String source, Message message) { Assert.assertEquals(peerA, source); latchB.countDown(); LOG.debug("Received a message from {}: {}: {}", source, message, latchB.getCount()); } public void onDisconnected(String source) { } }; final NettyTransport transportA = new NettyTransport(peerA, receiverA, getDirectory()); final NettyTransport transportB = new NettyTransport(peerB, receiverB, getDirectory()); for (int i = 0; i < messageCount; i++) { transportA.send(peerB, createAck(new Zxid(0, i))); } latchB.await(); // B should remove A from the map after clear() is called. transportB.clear(peerA); Assert.assertFalse(transportB.senders.containsKey(peerA)); // A should get onDisconnected event, but B should be in the map. 
disconnectedA.await(); Assert.assertTrue(transportA.senders.containsKey(peerB)); transportA.shutdown(); transportB.shutdown(); } @Test(timeout = 10000) public void testHandshakeTimeout() throws Exception { final String peerA = getUniqueHostPort(); final String peerB = getUniqueHostPort(); final CountDownLatch disconnectedA = new CountDownLatch(1); Transport.Receiver receiverA = new Transport.Receiver() { public void onReceived(String source, Message message) { } public void onDisconnected(String source) { LOG.debug("Got disconnected from {}", source); disconnectedA.countDown(); } }; Transport.Receiver receiverB = new Transport.Receiver() { public void onReceived(String source, Message message) { } public void onDisconnected(String source) { } }; final NettyTransport transportA = new NettyTransport(peerA, receiverA, getDirectory()); final NettyTransport transportB = new NettyTransport(peerB, receiverB, getDirectory()); // Discard the handshake message. transportB.channel.pipeline().addFirst(new ChannelInboundHandlerAdapter() { @Override public void channelRead(ChannelHandlerContext ctx, Object msg) { // discard the message. 
} }); transportA.send(peerB, createAck(new Zxid(0, 0))); disconnectedA.await(); transportA.shutdown(); transportB.shutdown(); } @Test(timeout = 10000) public void testBroadcast() throws Exception { final int messageCount = 100; final String peerA = getUniqueHostPort(); final String peerB = getUniqueHostPort(); final String peerC = getUniqueHostPort(); final CountDownLatch latchA = new CountDownLatch(messageCount); final CountDownLatch latchB = new CountDownLatch(messageCount); final CountDownLatch latchC = new CountDownLatch(messageCount); Transport.Receiver receiverA = new Transport.Receiver() { public void onReceived(String source, Message message) { latchA.countDown(); } public void onDisconnected(String source) { } }; Transport.Receiver receiverB = new Transport.Receiver() { public void onReceived(String source, Message message) { latchB.countDown(); } public void onDisconnected(String source) { } }; Transport.Receiver receiverC = new Transport.Receiver() { public void onReceived(String source, Message message) { latchC.countDown(); } public void onDisconnected(String source) { } }; final NettyTransport transportA = new NettyTransport(peerA, receiverA, getDirectory()); final NettyTransport transportB = new NettyTransport(peerB, receiverB, getDirectory()); final NettyTransport transportC = new NettyTransport(peerC, receiverC, getDirectory()); for (int i = 0; i < messageCount; i++) { transportA.broadcast(Arrays.asList(peerA, peerB, peerC).listIterator(), createAck(new Zxid(0, 0))); } latchA.await(); latchB.await(); latchC.await(); transportA.shutdown(); transportB.shutdown(); transportC.shutdown(); } @Test(timeout = 20000) public void testSsl() throws Exception { String peerA = getUniqueHostPort(); String peerB = getUniqueHostPort(); String peerC = getUniqueHostPort(); String peerD = getUniqueHostPort(); final CountDownLatch latchA = new CountDownLatch(1); Transport.Receiver receiverA = new Transport.Receiver() { public void onReceived(String source, Message 
message) { latchA.countDown(); } public void onDisconnected(String source) { } }; Transport.Receiver receiverB = new Transport.Receiver() { public void onReceived(String source, Message message) { } public void onDisconnected(String source) { } }; Transport.Receiver receiverC = new Transport.Receiver() { public void onReceived(String source, Message message) { } public void onDisconnected(String source) { } }; Transport.Receiver receiverD = new Transport.Receiver() { public void onReceived(String source, Message message) { } public void onDisconnected(String source) { } }; String password = "pa55w0rd"; String sslDir = "target" + File.separator + "generated-resources" + File.separator + "ssl"; File trustStore = new File(sslDir, "truststore.jks"); File keyStoreA = new File(sslDir, "keystore_a.jks"); File keyStoreB = new File(sslDir, "keystore_b.jks"); File keyStoreC = new File(sslDir, "keystore_c.jks"); SslParameters sslParam1 = new SslParameters(keyStoreA, password, trustStore, password); SslParameters sslParam2 = new SslParameters(keyStoreB, password, trustStore, password); SslParameters sslParam3 = new SslParameters(keyStoreC, password, trustStore, password); NettyTransport transportA = new NettyTransport(peerA, receiverA, sslParam1, getDirectory()); NettyTransport transportB = new NettyTransport(peerB, receiverB, sslParam2, getDirectory()); NettyTransport transportC = new NettyTransport(peerC, receiverC, sslParam3, getDirectory()); NettyTransport transportD = new NettyTransport(peerD, receiverD, getDirectory()); // D doesn't use SSL transportD.send(peerA, createAck(new Zxid(0, 0))); Assert.assertFalse(latchA.await(2, TimeUnit.SECONDS)); // C uses untrusted cert transportC.send(peerA, createAck(new Zxid(0, 0))); Assert.assertFalse(latchA.await(2, TimeUnit.SECONDS)); // B uses trusted cert transportB.send(peerA, createAck(new Zxid(0, 0))); latchA.await(); transportA.shutdown(); transportB.shutdown(); transportC.shutdown(); } @Test(timeout = 10000) public void 
testSendFile() throws Exception { int messageCount = 3; final String peerA = getUniqueHostPort(); final String peerB = getUniqueHostPort(); final CountDownLatch latchA = new CountDownLatch(messageCount); final CountDownLatch latchB = new CountDownLatch(messageCount); final ArrayList<File> receivedFiles = new ArrayList<>(); Transport.Receiver receiverA = new Transport.Receiver() { public void onReceived(String source, Message message) { LOG.debug("onReceived {}", message); latchA.countDown(); } public void onDisconnected(String source) { } }; Transport.Receiver receiverB = new Transport.Receiver() { public void onReceived(String source, Message message) { if (message.getType() == MessageType.FILE_RECEIVED) { receivedFiles.add(new File(message.getFileReceived().getFullPath())); } LOG.debug("onReceived {}", message); latchB.countDown(); } public void onDisconnected(String source) { } }; final NettyTransport transportA = new NettyTransport(peerA, receiverA, getDirectory()); final NettyTransport transportB = new NettyTransport(peerB, receiverB, getDirectory()); transportA.send(peerB, createAck(new Zxid(0, 0))); File file = new File("./pom.xml"); transportA.send(peerB, file); transportA.send(peerB, createAck(new Zxid(0, 1))); latchB.await(); Assert.assertTrue(compareFiles(file, receivedFiles.get(0))); transportA.shutdown(); transportB.shutdown(); } @Test(timeout = 10000) public void testSendFileSsl() throws Exception { int messageCount = 3; final String peerA = getUniqueHostPort(); final String peerB = getUniqueHostPort(); final CountDownLatch latchA = new CountDownLatch(messageCount); final CountDownLatch latchB = new CountDownLatch(messageCount); final ArrayList<File> receivedFiles = new ArrayList<>(); String password = "pa55w0rd"; String sslDir = "target" + File.separator + "generated-resources" + File.separator + "ssl"; File trustStore = new File(sslDir, "truststore.jks"); File keyStoreA = new File(sslDir, "keystore_a.jks"); File keyStoreB = new File(sslDir, 
"keystore_b.jks"); Transport.Receiver receiverA = new Transport.Receiver() { public void onReceived(String source, Message message) { LOG.debug("onReceived {}", message); latchA.countDown(); } public void onDisconnected(String source) { } }; Transport.Receiver receiverB = new Transport.Receiver() { public void onReceived(String source, Message message) { if (message.getType() == MessageType.FILE_RECEIVED) { receivedFiles.add(new File(message.getFileReceived().getFullPath())); } LOG.debug("onReceived {}", message); latchB.countDown(); } public void onDisconnected(String source) { } }; SslParameters sslParam1 = new SslParameters(keyStoreA, password, trustStore, password); SslParameters sslParam2 = new SslParameters(keyStoreB, password, trustStore, password); NettyTransport transportA = new NettyTransport(peerA, receiverA, sslParam1, getDirectory()); NettyTransport transportB = new NettyTransport(peerB, receiverB, sslParam2, getDirectory()); transportA.send(peerB, createAck(new Zxid(0, 0))); File file = new File("./pom.xml"); transportA.send(peerB, file); transportA.send(peerB, createAck(new Zxid(0, 1))); latchB.await(); Assert.assertTrue(compareFiles(file, receivedFiles.get(0))); transportA.shutdown(); transportB.shutdown(); } // Compare if two files are same. static boolean compareFiles(File file1, File file2) throws Exception { Assert.assertTrue(file1.exists()); Assert.assertTrue(file2.exists()); if (file1.length() != file2.length()) { return false; } try (FileInputStream fin1 = new FileInputStream(file1); FileInputStream fin2 = new FileInputStream(file2)) { BufferedReader reader = new BufferedReader(new InputStreamReader(fin1)); StringBuilder sb1 = new StringBuilder(); String line = null; while ((line = reader.readLine()) != null) { sb1.append(line); } StringBuilder sb2 = new StringBuilder(); reader = new BufferedReader(new InputStreamReader(fin2)); while ((line = reader.readLine()) != null) { sb2.append(line); } return sb1.toString().equals(sb2.toString()); } } }