// Java tutorial
/** * Licensed to the zk1931 under one or more contributor license * agreements. See the NOTICE file distributed with this work * for additional information regarding copyright ownership. * The ASF licenses this file to you under the Apache License, * Version 2.0 (the "License"); you may not use this file except * in compliance with the License. You may obtain a copy of the * License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.github.zk1931.jzab.transport; import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.TextFormat; import io.netty.bootstrap.Bootstrap; import io.netty.bootstrap.ServerBootstrap; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import io.netty.channel.Channel; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelOption; import io.netty.channel.DefaultFileRegion; import io.netty.channel.EventLoopGroup; import io.netty.channel.FileRegion; import io.netty.channel.nio.NioEventLoopGroup; import io.netty.channel.socket.SocketChannel; import io.netty.channel.socket.nio.NioServerSocketChannel; import io.netty.channel.socket.nio.NioSocketChannel; import io.netty.handler.codec.ByteToMessageDecoder; import io.netty.handler.codec.LengthFieldPrepender; import io.netty.handler.ssl.SslHandler; import io.netty.handler.stream.ChunkedFile; import io.netty.handler.stream.ChunkedWriteHandler; import 
io.netty.handler.timeout.ReadTimeoutHandler; import io.netty.util.AttributeKey; import io.netty.util.ReferenceCountUtil; import java.io.File; import java.io.FileInputStream; import java.io.FileOutputStream; import java.io.IOException; import java.io.RandomAccessFile; import java.nio.ByteBuffer; import java.security.GeneralSecurityException; import java.security.KeyStore; import java.util.List; import java.util.Map; import java.util.concurrent.BlockingDeque; import java.util.concurrent.Callable; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.ExecutionException; import java.util.concurrent.Executors; import java.util.concurrent.ExecutorService; import java.util.concurrent.Future; import java.util.concurrent.LinkedBlockingDeque; import java.util.concurrent.RejectedExecutionException; import java.util.concurrent.TimeUnit; import javax.net.ssl.KeyManagerFactory; import javax.net.ssl.SSLContext; import javax.net.ssl.SSLEngine; import javax.net.ssl.TrustManagerFactory; import com.github.zk1931.jzab.MessageBuilder; import com.github.zk1931.jzab.proto.ZabMessage.Message; import com.github.zk1931.jzab.SslParameters; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import static com.github.zk1931.jzab.proto.ZabMessage.Message.MessageType; /** * Netty-based transport. 
*/ public class NettyTransport extends Transport { private static final Logger LOG = LoggerFactory.getLogger(NettyTransport.class); static final AttributeKey<String> REMOTE_ID = AttributeKey.valueOf("remote"); private final String hostPort; private final EventLoopGroup bossGroup = new NioEventLoopGroup(); private final EventLoopGroup workerGroup = new NioEventLoopGroup(); Channel channel; private final File keyStore; private final char[] keyStorePassword; private final File trustStore; private final char[] trustStorePassword; private SSLContext clientContext; private SSLContext serverContext; private final File dir; // remote id => sender map. ConcurrentMap<String, Sender> senders = new ConcurrentHashMap<String, Sender>(); public NettyTransport(String hostPort, final Receiver receiver, final File dir) throws InterruptedException, GeneralSecurityException, IOException { this(hostPort, receiver, new SslParameters(), dir); } /** * Constructs a NettyTransport object. * * @param hostPort "hostname:port" string. The netty transport binds to the * port specified in the string. * @param receiver receiver callback. * @param sslParam Ssl parameters. * @param dir the directory used to store the received file. */ public NettyTransport(String hostPort, final Receiver receiver, SslParameters sslParam, final File dir) throws InterruptedException, GeneralSecurityException, IOException { super(receiver); this.keyStore = sslParam.getKeyStore(); this.trustStore = sslParam.getTrustStore(); this.keyStorePassword = sslParam.getKeyStorePassword() != null ? sslParam.getKeyStorePassword().toCharArray() : null; this.trustStorePassword = sslParam.getTrustStorePassword() != null ? 
sslParam.getTrustStorePassword().toCharArray() : null; this.dir = dir; if (isSslEnabled()) { initSsl(); } this.hostPort = hostPort; String[] address = hostPort.split(":", 2); int port = Integer.parseInt(address[1]); ServerBootstrap b = new ServerBootstrap(); b.group(bossGroup, workerGroup).channel(NioServerSocketChannel.class).option(ChannelOption.SO_BACKLOG, 128) .option(ChannelOption.SO_REUSEADDR, true).childOption(ChannelOption.SO_KEEPALIVE, true) .childOption(ChannelOption.TCP_NODELAY, true).childHandler(new ChannelInitializer<SocketChannel>() { @Override public void initChannel(SocketChannel ch) throws Exception { if (isSslEnabled()) { SSLEngine engine = serverContext.createSSLEngine(); engine.setUseClientMode(false); engine.setNeedClientAuth(true); ch.pipeline().addLast(new SslHandler(engine)); } // Incoming handlers ch.pipeline().addLast(new MainHandler()); ch.pipeline().addLast(new ServerHandshakeHandler()); ch.pipeline().addLast(new NotifyHandler()); ch.pipeline().addLast(new ErrorHandler()); // Outgoing handlers. ch.pipeline().addLast("frameEncoder", new LengthFieldPrepender(4)); } }); // Travis build fails once in a while because it fails to bind to a port. // This is most likely a transient failure. Retry binding for 5 times with // 1 second sleep in between before giving up. int bindRetryCount = 5; for (int i = 0;; i++) { try { channel = b.bind(port).sync().channel(); LOG.info("Server started: {}", hostPort); return; } catch (Exception ex) { if (i >= bindRetryCount) { throw ex; } LOG.debug("Failed to bind to {}. Retrying after 1 second.", hostPort); Thread.sleep(1000); } } } private boolean isSslEnabled() { return keyStore != null && trustStore != null; } private void initSsl() throws IOException, GeneralSecurityException { String kmAlgorithm = KeyManagerFactory.getDefaultAlgorithm(); String tmAlgorithm = TrustManagerFactory.getDefaultAlgorithm(); // TODO make the protocol and keystore type configurable. 
String protocol = "TLS"; KeyStore ks = KeyStore.getInstance("JKS"); KeyStore ts = KeyStore.getInstance("JKS"); try (FileInputStream keyStoreStream = new FileInputStream(keyStore); FileInputStream trustStoreStream = new FileInputStream(trustStore)) { ks.load(keyStoreStream, keyStorePassword); ts.load(trustStoreStream, trustStorePassword); } KeyManagerFactory kmf = KeyManagerFactory.getInstance(kmAlgorithm); TrustManagerFactory tmf = TrustManagerFactory.getInstance(tmAlgorithm); kmf.init(ks, keyStorePassword); tmf.init(ts); serverContext = SSLContext.getInstance(protocol); clientContext = SSLContext.getInstance(protocol); serverContext.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null); clientContext.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null); } /** * Destroys the transport. */ @Override public void shutdown() throws InterruptedException { try { channel.close(); for (Map.Entry<String, Sender> entry : senders.entrySet()) { LOG.debug("Shutting down the sender({})", entry.getKey()); entry.getValue().shutdown(); } senders.clear(); LOG.debug("Shutdown complete"); } finally { try { long quietPeriodSec = 0; long timeoutSec = 10; io.netty.util.concurrent.Future wf = workerGroup.shutdownGracefully(quietPeriodSec, timeoutSec, TimeUnit.SECONDS); io.netty.util.concurrent.Future bf = bossGroup.shutdownGracefully(quietPeriodSec, timeoutSec, TimeUnit.SECONDS); wf.await(); bf.await(); LOG.debug("Shutdown complete"); } catch (InterruptedException ex) { LOG.debug("Interrupted while shutting down NioEventLoopGroup", ex); } } } /** * Handles server-side handshake. */ private class ServerHandshakeHandler extends ChannelInboundHandlerAdapter { @Override public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { try { Message message = (Message) msg; // Make sure it's a handshake message. 
if (message.getType() != MessageType.HANDSHAKE) { LOG.debug("The first message from {} was not a handshake", ctx.channel().remoteAddress()); ctx.close(); return; } String remoteId = message.getHandshake().getNodeId(); LOG.debug("{} received handshake from {}", hostPort, remoteId); Sender sender = new Sender(remoteId, ctx.channel()); // Attach the remote node id to this channel. Subsequent handlers use // this information to determine origins of messages. ctx.channel().attr(REMOTE_ID).set(remoteId); Sender currentSender = senders.putIfAbsent(remoteId, sender); if (currentSender != null) { LOG.debug("Rejecting a handshake from {}", remoteId); ctx.close(); return; } // Send a response and remove the handler from the pipeline. LOG.debug("Server-side handshake completed from {} to {}", hostPort, remoteId); Message response = MessageBuilder.buildHandshake(hostPort); ByteBuffer buf = ByteBuffer.wrap(response.toByteArray()); ctx.channel().writeAndFlush(Unpooled.wrappedBuffer(buf)); sender.start(); ctx.pipeline().remove(this); } finally { ReferenceCountUtil.release(msg); } } } private class MainHandler extends ByteToMessageDecoder { private FileReceiver fileReceiver = null; private Message decodeToMessage(ByteBuf in) { if (in.readableBytes() < 4) { return null; } in.markReaderIndex(); int messageLength = in.readInt(); if (in.readableBytes() < messageLength) { in.resetReaderIndex(); return null; } byte[] buffer = new byte[messageLength]; in.readBytes(buffer); try { Message msg = Message.parseFrom(buffer); return msg; } catch (InvalidProtocolBufferException e) { LOG.error("Exception when parse protocol buffer.", e); Message msg = MessageBuilder.buildInvalidMessage(buffer); return msg; } } @Override protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception { if (fileReceiver == null) { Message msg = decodeToMessage(in); if (msg == null) { return; } else if (msg.getType() == MessageType.FILE_HEADER) { LOG.debug("Got FILE_HEADER."); 
fileReceiver = new FileReceiver(msg.getFileHeader().getLength()); } else { out.add(msg); } } else { fileReceiver.process(in); if (fileReceiver.isDone()) { String filePath = fileReceiver.file.getPath(); Message msg = MessageBuilder.buildFileReceived(filePath); out.add(msg); // Resets it to null to switch back to normal decode mode. fileReceiver = null; } } } class FileReceiver { final long fileLength; long receivedLength = 0; final File file; final FileOutputStream fout; public FileReceiver(long length) throws IOException { this.file = File.createTempFile("transport", "", dir); this.fileLength = length; this.fout = new FileOutputStream(this.file); } public void process(ByteBuf in) throws IOException { long readableBytes = in.readableBytes(); long remainingBytes = fileLength - receivedLength; long bytesToRead = (remainingBytes < readableBytes) ? remainingBytes : readableBytes; byte[] buffer = new byte[(int) bytesToRead]; in.readBytes(buffer); fout.write(buffer); receivedLength += bytesToRead; if (receivedLength == fileLength) { fout.getChannel().force(false); fout.close(); } } boolean isDone() { return receivedLength == fileLength; } } } private class NotifyHandler extends ChannelInboundHandlerAdapter { @Override public void channelRead(ChannelHandlerContext ctx, Object obj) { String remoteId = ctx.channel().attr(NettyTransport.REMOTE_ID).get(); Message msg = (Message) obj; receiver.onReceived(remoteId, msg); } } /** * Handles errors. */ private class ErrorHandler extends ChannelInboundHandlerAdapter { @Override public void channelInactive(ChannelHandlerContext ctx) throws Exception { String remoteId = ctx.channel().attr(NettyTransport.REMOTE_ID).get(); ctx.close(); if (remoteId != null) { LOG.debug("Got disconnected from {}.", remoteId); // This must not be null. 
Sender sender = senders.get(remoteId); if (sender != null) { sender.shutdown(); } receiver.onDisconnected(remoteId); } } @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { // Don't handle errors here. Call ctx.close() and let channelInactive() // handle all the errrors. LOG.debug("Caught an exception", cause); ctx.close(); } } @Override public void send(final String destination, Message message) { if (destination.equals(hostPort)) { // The message is being sent to itself. Don't bother going over TCP. // Directly call onReceived. receiver.onReceived(destination, message); return; } ByteBuffer bytes = ByteBuffer.wrap(message.toByteArray()); Sender currentSender = senders.get(destination); if (currentSender != null) { currentSender.requests.add(bytes); } else { // no connection exists. LOG.debug("No connection from {} to {}. Creating a new one", hostPort, destination); Sender newSender = new Sender(hostPort, destination); currentSender = senders.putIfAbsent(destination, newSender); if (currentSender == null) { newSender.requests.add(bytes); newSender.startHandshake(); } else { currentSender.requests.add(bytes); } } } @Override public void send(final String destination, File file) { if (destination.equals(hostPort)) { LOG.error("Can't send file to itself."); throw new RuntimeException("Can't send file to itself."); } Sender currentSender = senders.get(destination); if (currentSender != null) { currentSender.requests.add(file); } else { // no connection exists. LOG.debug("No connection from {} to {}. 
Creating a new one", hostPort, destination); Sender newSender = new Sender(hostPort, destination); currentSender = senders.putIfAbsent(destination, newSender); if (currentSender == null) { newSender.requests.add(file); newSender.startHandshake(); } else { currentSender.requests.add(file); } } } @Override public void clear(String destination) { LOG.debug("Closing the connection to {}", destination); Sender sender = senders.remove(destination); if (sender != null) { sender.shutdown(); } } /** * sender thread. */ private class Sender implements Callable<Void> { private final String destination; private Bootstrap bootstrap = null; private Channel channel; private Future<Void> future; final BlockingDeque<Object> requests = new LinkedBlockingDeque<>(); public Sender(String destination, Channel channel) { this.destination = destination; this.channel = channel; } public Sender(final String source, final String destination) { this.destination = destination; bootstrap = new Bootstrap(); bootstrap.group(workerGroup); bootstrap.channel(NioSocketChannel.class); bootstrap.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, 1000); bootstrap.option(ChannelOption.SO_KEEPALIVE, true); bootstrap.option(ChannelOption.TCP_NODELAY, true); bootstrap.handler(new ChannelInitializer<SocketChannel>() { @Override public void initChannel(SocketChannel ch) throws Exception { if (isSslEnabled()) { SSLEngine engine = serverContext.createSSLEngine(); engine.setUseClientMode(true); ch.pipeline().addLast(new SslHandler(engine)); } // Inbound handlers. ch.pipeline().addLast(new ReadTimeoutHandler(2)); ch.pipeline().addLast(new MainHandler()); ch.pipeline().addLast(new ClientHandshakeHandler()); // Outbound handlers. 
ch.pipeline().addLast("frameEncoder", new LengthFieldPrepender(4)); } }); } public void startHandshake() { String[] address = destination.split(":", 2); String host = address[0]; int port = Integer.parseInt(address[1]); LOG.debug("host: {}, port: {}", host, port); bootstrap.connect(host, port).addListener(new ChannelFutureListener() { @Override public void operationComplete(ChannelFuture cfuture) { if (cfuture.isSuccess()) { LOG.debug("{} connected to {}. Sending a handshake", hostPort, destination); Message msg = MessageBuilder.buildHandshake(hostPort); ByteBuffer bb = ByteBuffer.wrap(msg.toByteArray()); channel = cfuture.channel(); channel.writeAndFlush(Unpooled.wrappedBuffer(bb)); } else { LOG.debug("Failed to connect to {}: {}", destination, cfuture.cause().getMessage()); handshakeFailed(false); } } }); } public void handshakeCompleted() { LOG.debug("Client handshake completed: {} => {}", hostPort, destination); Sender sender = senders.get(destination); assert sender == this; sender.channel.attr(REMOTE_ID).set(destination); sender.channel.pipeline().remove(ReadTimeoutHandler.class); sender.channel.pipeline().addLast(new NotifyHandler()); sender.channel.pipeline().addLast(new ErrorHandler()); sender.start(); } public void handshakeFailed(boolean tie) { LOG.debug("Client handshake failed: {} => {}", hostPort, destination); Sender sender = senders.get(destination); if (sender != null) { sender.shutdown(); } if (tie) { try { // If the handshake failure is caused by the tie, does the random // sleep. 
Thread.sleep((int) (Math.random() * 300)); } catch (InterruptedException ex) { Thread.currentThread().interrupt(); } } receiver.onDisconnected(destination); } void sendFile(File file) throws Exception { long length = file.length(); LOG.debug("Got request of sending file {} of length {}.", file, length); Message handshake = MessageBuilder.buildFileHeader(length); byte[] bytes = handshake.toByteArray(); // Sends HANDSHAKE first before transferring actual file data, the // HANDSHAKE will tell the peer's channel to prepare for the file // transferring. channel.writeAndFlush(Unpooled.wrappedBuffer(bytes)).sync(); ChannelHandler prepender = channel.pipeline().get("frameEncoder"); // Removes length prepender, we don't need this handler for file // transferring. channel.pipeline().remove(prepender); // Adds ChunkedWriteHandler for file transferring. ChannelHandler cwh = new ChunkedWriteHandler(); channel.pipeline().addLast(cwh); // Begins file transferring. RandomAccessFile raf = new RandomAccessFile(file, "r"); if (channel.pipeline().get(SslHandler.class) != null) { // Zero-Copy file transferring is not supported for ssl. channel.writeAndFlush(new ChunkedFile(raf, 0, length, 8912)); } else { // Use Zero-Copy file transferring in non-ssl mode. FileRegion region = new DefaultFileRegion(raf.getChannel(), 0, length); channel.writeAndFlush(region); } // Restores pipeline to original state. 
channel.pipeline().remove(cwh); channel.pipeline().addLast("frameEncoder", prepender); } @Override public Void call() throws Exception { LOG.debug("Started the sender: {} => {}", hostPort, destination); try { while (true) { Object req = requests.take(); if (req instanceof ByteBuffer) { ByteBuffer buf = (ByteBuffer) req; channel.writeAndFlush(Unpooled.wrappedBuffer(buf)); } else if (req instanceof File) { File file = (File) req; sendFile(file); } else if (req instanceof Shutdown) { LOG.debug("Got shutdown request."); break; } } } catch (InterruptedException ex) { LOG.debug("Sender to {} got interrupted", destination); return null; } catch (Exception ex) { LOG.warn("Sender failed with an exception", ex); throw ex; } finally { channel.close(); } return null; } public void start() { ExecutorService es = Executors.newSingleThreadExecutor(); future = es.submit(this); es.shutdown(); } public void shutdown() { LOG.debug("Shutting down the sender: {} => {}", hostPort, destination); try { if (future != null) { try { this.requests.add(new Shutdown()); future.get(); } catch (InterruptedException | ExecutionException ex) { LOG.debug("Ignore the exception", ex); } } if (channel != null) { channel.close().syncUninterruptibly(); } } catch (RejectedExecutionException ex) { LOG.debug("Ignoring rejected execution exception", ex); } } class Shutdown { // We use it to shutdown the sender thread. } /** * Handles client-side handshake. */ public class ClientHandshakeHandler extends ChannelInboundHandlerAdapter { @Override public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { try { Message message = (Message) msg; if (message.getType() != MessageType.HANDSHAKE) { // Server responded with an invalid message. 
LOG.error("The first message from %s was not a handshake: %s", ctx.channel().remoteAddress(), TextFormat.shortDebugString(message)); ctx.close(); return; } String response = message.getHandshake().getNodeId(); if (!response.equals(destination)) { // Handshake response doesn't match server's node ID. LOG.error("Invalid handshake response from %s: %s", destination, response); ctx.close(); return; } // Handshake is finished. Remove the handler from the pipeline. ctx.pipeline().remove(this); handshakeCompleted(); } finally { ReferenceCountUtil.release(msg); } } @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { // Don't call the handshake callback here. Simply close the context and // let channelInactive() call the handshake callback. LOG.debug("Caught an exception", cause); ctx.close(); } @Override public void channelInactive(ChannelHandlerContext ctx) throws Exception { LOG.debug("Got disconnected from {}", destination); ctx.close(); handshakeFailed(true); } } } }