Java tutorial
/** * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package com.ibm.mqlight.api.impl.network; import io.netty.bootstrap.Bootstrap; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import io.netty.channel.Channel; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelOption; import io.netty.channel.EventLoopGroup; import io.netty.channel.nio.NioEventLoopGroup; import io.netty.channel.socket.SocketChannel; import io.netty.channel.socket.nio.NioSocketChannel; import io.netty.handler.ssl.JdkSslContext; import io.netty.handler.ssl.SslContext; import io.netty.handler.ssl.SslHandler; import io.netty.util.concurrent.GenericFutureListener; import java.io.FileInputStream; import java.io.IOException; import java.nio.ByteBuffer; import java.nio.channels.UnresolvedAddressException; import java.security.KeyManagementException; import java.security.KeyStore; import java.security.KeyStoreException; import java.security.NoSuchAlgorithmException; import java.security.cert.CertificateException; import java.util.Arrays; import java.util.LinkedList; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.regex.Pattern; import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLException; import javax.net.ssl.SSLParameters; import javax.net.ssl.TrustManagerFactory; import com.ibm.mqlight.api.ClientException; import com.ibm.mqlight.api.NetworkException; import com.ibm.mqlight.api.Promise; import com.ibm.mqlight.api.endpoint.Endpoint; import com.ibm.mqlight.api.impl.LogbackLogging; import com.ibm.mqlight.api.logging.FFDCProbeId; import com.ibm.mqlight.api.logging.Logger; import com.ibm.mqlight.api.logging.LoggerFactory; import com.ibm.mqlight.api.network.NetworkChannel; import com.ibm.mqlight.api.network.NetworkListener; import com.ibm.mqlight.api.network.NetworkService; public class NettyNetworkService implements NetworkService { private static final Logger logger = LoggerFactory.getLogger(NettyNetworkService.class); static { LogbackLogging.setup(); } private static Object bootstrapSync = new Object(); private static Bootstrap bootstrap; static class NettyInboundHandler extends ChannelInboundHandlerAdapter implements NetworkChannel { private static final Logger logger = LoggerFactory.getLogger(NettyInboundHandler.class); private final SocketChannel channel; private NetworkListener listener = null; private final AtomicBoolean closed = new AtomicBoolean(false); protected NettyInboundHandler(SocketChannel channel) { final String methodName = "<init>"; logger.entry(this, methodName, channel); this.channel = channel; logger.exit(this, methodName); } @Override public void channelRead(ChannelHandlerContext ctx, Object msg) { final String methodName = "channelRead"; logger.entry(this, methodName, ctx, msg); if (listener != null) listener.onRead(this, (ByteBuf) msg); logger.exit(this, methodName); } @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { final String methodName = "exceptionCaught"; logger.entry(this, methodName, cause); try { ctx.close(); Exception exception; if (cause instanceof Exception) { exception = (Exception) cause; } else { logger.ffdc(methodName, FFDCProbeId.PROBE_001, cause, this); exception = new NetworkException("unexpected error", cause); } // if we have a nested chain of causes, walk it until we have at // most a single pair of Exception and cause while (exception.getCause() != null && exception.getCause() instanceof Exception) { if (exception.getCause().getCause() == null) { break; } exception = (Exception) exception.getCause(); } // rewrap security-related exceptions final String condition = exception.getClass().getName(); if (condition.contains("javax.net.ssl.") || condition.contains("java.security.") || condition.contains("com.ibm.jsse2.") || condition.contains("sun.security.")) { exception = new com.ibm.mqlight.api.SecurityException(exception.getMessage(), exception.getCause()); } if (listener != null) { listener.onError(this, exception); } } catch (Throwable t) { logger.error("An exception was thrown during " + methodName + "() handling of " + cause.toString(), t); } logger.exit(this, methodName); } @Override public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exception { final String methodName = "channelWritabilityChanged"; logger.entry(this, methodName, ctx); doWrite(); logger.exit(this, methodName); } @Override public void channelInactive(ChannelHandlerContext ctx) throws Exception { final String methodName = "channelInactive"; logger.entry(this, methodName, ctx); boolean alreadyClosed = closed.getAndSet(true); if (!alreadyClosed) { if (listener != null) { listener.onClose(this); } decrementUseCount(); } logger.exit(this, methodName); } protected void setListener(NetworkListener listener) { final String methodName = "setListener"; logger.entry(this, methodName, listener); this.listener = listener; logger.exit(this, methodName); } @Override public void close(final Promise<Void> nwfuture) { final String methodName = "close"; logger.entry(this, methodName, nwfuture); boolean alreadyClosed = closed.getAndSet(true); if (!alreadyClosed) { final ChannelFuture f = channel.disconnect(); if (nwfuture != null) { f.addListener(new GenericFutureListener<ChannelFuture>() { @Override public void operationComplete(ChannelFuture future) throws Exception { nwfuture.setSuccess(null); decrementUseCount(); } }); } else { decrementUseCount(); } } else if (nwfuture != null) { nwfuture.setSuccess(null); } logger.exit(this, methodName); } private static class WriteRequest { protected final ByteBuf buffer; protected final Promise<Boolean> promise; protected WriteRequest(ByteBuf buffer, Promise<Boolean> promise) { this.buffer = buffer; this.promise = promise; } } @Override public void write(ByteBuffer buffer, Promise<Boolean> promise) { final String methodName = "write"; logger.entry(this, methodName, buffer, promise); doWrite(buffer, promise); logger.exit(this, methodName); } LinkedList<WriteRequest> pendingWrites = new LinkedList<>(); boolean writeInProgress = false; private void processWriteRequest(WriteRequest toProcess) { final String methodName = "processWriteRequest"; logger.entry(this, methodName, toProcess); final Promise<Boolean> writeCompletePromise = toProcess.promise; logger.data(this, methodName, "writeAndFlush {}", toProcess); final ChannelFuture f = channel.writeAndFlush(toProcess.buffer); f.addListener(new GenericFutureListener<ChannelFuture>() { @Override public void operationComplete(ChannelFuture future) throws Exception { boolean havePendingWrites = false; synchronized (pendingWrites) { writeInProgress = false; havePendingWrites = !pendingWrites.isEmpty(); } logger.data(this, methodName, "doWrite (complete)"); writeCompletePromise.setSuccess(!havePendingWrites); doWrite(); } }); logger.exit(this, methodName); } private void doWrite() { final String methodName = "doWrite"; logger.entry(this, methodName); WriteRequest toProcess = null; synchronized (pendingWrites) { if (!writeInProgress && channel.isWritable() && !pendingWrites.isEmpty()) { toProcess = pendingWrites.removeFirst(); writeInProgress = true; } } if (toProcess != null) processWriteRequest(toProcess); logger.exit(this, methodName); } private void doWrite(ByteBuffer buffer, Promise<Boolean> promise) { final String methodName = "doWrite"; logger.entry(this, methodName, buffer, promise); WriteRequest toProcess = null; synchronized (pendingWrites) { if (!writeInProgress && channel.isWritable()) { if (pendingWrites.isEmpty()) { // Ideally here we should be able to use Unpooled.wrappedBuffer, to save copying. But network // writes can become deferred under load, hence we must make a copy of the buffer to protect the // data (as the caller may need to reuse the buffer when we return) toProcess = new WriteRequest(Unpooled.copiedBuffer(buffer), promise); } else { pendingWrites.addLast(new WriteRequest(Unpooled.copiedBuffer(buffer), promise)); toProcess = pendingWrites.removeFirst(); } writeInProgress = true; } else { pendingWrites.addLast(new WriteRequest(Unpooled.copiedBuffer(buffer), promise)); } } if (toProcess != null) processWriteRequest(toProcess); logger.exit(this, methodName); } private Object context; @Override public synchronized void setContext(Object context) { this.context = context; } @Override public synchronized Object getContext() { return context; } } protected class ConnectListener implements GenericFutureListener<ChannelFuture> { private final Logger logger = LoggerFactory.getLogger(ConnectListener.class); private final Endpoint endpoint; private final Promise<NetworkChannel> promise; private final NetworkListener listener; protected ConnectListener(Endpoint endpoint, ChannelFuture cFuture, Promise<NetworkChannel> promise, NetworkListener listener) { final String methodName = "<init>"; logger.entry(this, methodName, endpoint, cFuture, promise, listener); this.endpoint = endpoint; this.promise = promise; this.listener = listener; logger.exit(this, methodName); } @Override public void operationComplete(ChannelFuture cFuture) throws Exception { final String methodName = "operationComplete"; logger.entry(this, methodName, cFuture); if (cFuture.isSuccess()) { NettyInboundHandler handler = (NettyInboundHandler) cFuture.channel().pipeline().last(); handler.setListener(listener); promise.setSuccess(handler); } else { String message = cFuture.cause().getMessage(); if (message == null || message.length() == 0) { if (cFuture.cause() instanceof UnresolvedAddressException) { message = "unresolved address " + endpoint.getURI(); } else { message = cFuture.cause().toString() + " for address " + endpoint.getURI(); } } final ClientException cause = new NetworkException("Could not connect to server: " + message, cFuture.cause()); promise.setFailure(cause); decrementUseCount(); } logger.exit(this, methodName); } } /** Pattern of protocols to disable */ final Pattern disabledProtocolPattern = Pattern.compile("(SSLv2|SSLv3).*"); /** Pattern of cipher suites to disable */ final Pattern disabledCipherPattern = Pattern.compile(".*_(NULL|EXPORT|DES|RC4|MD5|PSK|SRP|CAMELLIA)_.*"); @Override public void connect(Endpoint endpoint, NetworkListener listener, Promise<NetworkChannel> promise) { final String methodName = "connect"; logger.entry(this, methodName, endpoint, listener, promise); SslContext sslCtx = null; try { if (endpoint.getCertChainFile() != null && endpoint.getCertChainFile().exists()) { try (FileInputStream fileInputStream = new FileInputStream(endpoint.getCertChainFile())) { KeyStore jks = KeyStore.getInstance("JKS"); jks.load(fileInputStream, null); TrustManagerFactory trustManagerFactory = TrustManagerFactory .getInstance(TrustManagerFactory.getDefaultAlgorithm()); trustManagerFactory.init(jks); sslCtx = SslContext.newClientContext(); if (sslCtx instanceof JdkSslContext) { ((JdkSslContext) sslCtx).context().init(null, trustManagerFactory.getTrustManagers(), null); } } catch (IOException | NoSuchAlgorithmException | CertificateException | KeyStoreException | KeyManagementException e) { logger.data(this, methodName, e.toString()); } } // fallback to passing as .PEM file (or null, which loads default cacerts) if (sslCtx == null) { sslCtx = SslContext.newClientContext(endpoint.getCertChainFile()); } final SSLEngine sslEngine = sslCtx.newEngine(null, endpoint.getHost(), endpoint.getPort()); sslEngine.setUseClientMode(true); final LinkedList<String> enabledProtocols = new LinkedList<String>() { private static final long serialVersionUID = 7838479468739671083L; { for (String protocol : sslEngine.getSupportedProtocols()) { if (!disabledProtocolPattern.matcher(protocol).matches()) { add(protocol); } } } }; sslEngine.setEnabledProtocols(enabledProtocols.toArray(new String[0])); logger.data(this, methodName, "enabledProtocols", Arrays.toString(sslEngine.getEnabledProtocols())); final LinkedList<String> enabledCipherSuites = new LinkedList<String>() { private static final long serialVersionUID = 7838479468739671083L; { for (String cipher : sslEngine.getSupportedCipherSuites()) { if (!disabledCipherPattern.matcher(cipher).matches()) { add(cipher); } } } }; sslEngine.setEnabledCipherSuites(enabledCipherSuites.toArray(new String[0])); logger.data(this, methodName, "enabledCipherSuites", Arrays.toString(sslEngine.getEnabledCipherSuites())); if (endpoint.getVerifyName()) { SSLParameters sslParams = sslEngine.getSSLParameters(); sslParams.setEndpointIdentificationAlgorithm("HTTPS"); sslEngine.setSSLParameters(sslParams); } // The listener must be added to the ChannelFuture before the bootstrap channel initialisation completes (i.e. // before the NettyInboundHandler is added to the channel pipeline) otherwise the listener may not be able to // see the NettyInboundHandler, when its operationComplete() method is called (there is a small window where // the socket connection fails just after initChannel has complete but before ConnectListener is added, with // the ConnectListener.operationComplete() being called as though the connection was successful) // Hence we synchronise here and within the ChannelInitializer.initChannel() method. synchronized (bootstrapSync) { final ChannelHandler handler; if (endpoint.useSsl()) { handler = new ChannelInitializer<SocketChannel>() { @Override public void initChannel(SocketChannel ch) throws Exception { synchronized (bootstrapSync) { ch.pipeline().addFirst(new SslHandler(sslEngine)); ch.pipeline().addLast(new NettyInboundHandler(ch)); } } }; } else { handler = new ChannelInitializer<SocketChannel>() { @Override public void initChannel(SocketChannel ch) throws Exception { synchronized (bootstrapSync) { ch.pipeline().addLast(new NettyInboundHandler(ch)); } } }; } final Bootstrap bootstrap = getBootstrap(endpoint.useSsl(), sslEngine, handler); final ChannelFuture f = bootstrap.connect(endpoint.getHost(), endpoint.getPort()); f.addListener(new ConnectListener(endpoint, f, promise, listener)); } } catch (SSLException e) { if (e.getCause() == null) { promise.setFailure(new SecurityException(e.getMessage(), e)); } else { promise.setFailure(new SecurityException(e.getCause().getMessage(), e.getCause())); } } logger.exit(this, methodName); } private static int useCount = 0; /** * Request a {@link Bootstrap} for obtaining a {@link Channel} and track * that the workerGroup is being used. * * @param secure * a {@code boolean} indicating whether or not a secure channel * will be required * @param sslEngine * an {@link SSLEngine} if one should be used to secure the channel * @param handler a {@link ChannelHandler} to use for serving the requests. * @return a netty {@link Bootstrap} object suitable for obtaining a * {@link Channel} from */ private static synchronized Bootstrap getBootstrap(final boolean secure, final SSLEngine sslEngine, final ChannelHandler handler) { final String methodName = "getBootstrap"; logger.entry(methodName, secure, sslEngine); ++useCount; if (useCount == 1) { EventLoopGroup workerGroup = new NioEventLoopGroup(); bootstrap = new Bootstrap(); bootstrap.group(workerGroup); bootstrap.channel(NioSocketChannel.class); bootstrap.option(ChannelOption.SO_KEEPALIVE, true); bootstrap.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, 30000); bootstrap.handler(handler); } final Bootstrap result; if (secure) { result = bootstrap.clone(); result.handler(handler); } else { result = bootstrap; } logger.exit(methodName, result); return result; } /** * Decrement the use count of the workerGroup and request a graceful * shutdown once it is no longer being used by anyone. */ private static synchronized void decrementUseCount() { final String methodName = "decrementUseCount"; logger.entry(methodName); --useCount; if (useCount <= 0) { if (bootstrap != null) { bootstrap.group().shutdownGracefully(0, 500, TimeUnit.MILLISECONDS); } bootstrap = null; useCount = 0; } logger.exit(methodName); } /** * Waits for the underlying network service to terminate. * * @param timeout Maximum time to wait in seconds. * @return {@code true} if the underlying network service has terminated, {@code false} if the underlying network * service is still active after waiting the specified time. * @throws InterruptedException */ public boolean awaitTermination(long timeout) throws InterruptedException { final String methodName = "awaitTermination"; logger.entry(methodName); final boolean terminated; if (bootstrap != null) { terminated = bootstrap.group().awaitTermination(timeout, TimeUnit.SECONDS); } else { terminated = true; } logger.exit(methodName, terminated); return terminated; } }