cf.dropsonde.firehose.NettyFirehoseOnSubscribe.java Source code

Java tutorial

Introduction

Here is the source code for cf.dropsonde.firehose.NettyFirehoseOnSubscribe.java

Source

/*
 *   Copyright (c) 2015 Intellectual Reserve, Inc.  All rights reserved.
 *
 *   Licensed under the Apache License, Version 2.0 (the "License");
 *   you may not use this file except in compliance with the License.
 *   You may obtain a copy of the License at
 *
 *       http://www.apache.org/licenses/LICENSE-2.0
 *
 *   Unless required by applicable law or agreed to in writing, software
 *   distributed under the License is distributed on an "AS IS" BASIS,
 *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *   See the License for the specific language governing permissions and
 *   limitations under the License.
 *
 */
package cf.dropsonde.firehose;

import org.cloudfoundry.dropsonde.events.Envelope;
import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.ByteBufInputStream;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpClientCodec;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
import io.netty.handler.codec.http.websocketx.PingWebSocketFrame;
import io.netty.handler.codec.http.websocketx.PongWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketClientHandshaker;
import io.netty.handler.codec.http.websocketx.WebSocketClientHandshakerFactory;
import io.netty.handler.codec.http.websocketx.WebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketFrameAggregator;
import io.netty.handler.codec.http.websocketx.WebSocketVersion;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
import io.netty.handler.timeout.ReadTimeoutHandler;
import io.netty.util.CharsetUtil;
import rx.Subscriber;

import javax.net.ssl.SSLException;
import javax.net.ssl.TrustManagerFactory;
import java.io.Closeable;
import java.net.URI;
import java.security.KeyStore;
import java.security.KeyStoreException;
import java.security.NoSuchAlgorithmException;

/**
 * RxJava {@code OnSubscribe} implementation that connects to a firehose WebSocket endpoint with Netty and
 * delivers every binary frame it receives to the subscriber as a decoded {@link Envelope}.
 *
 * <p>Each call to {@link #call(Subscriber)} opens a new connection using the bootstrap prepared by the
 * constructor. {@link #close()} closes the most recently initialized channel and, when this instance created
 * its own {@link EventLoopGroup}, shuts that group down.
 *
 * @author Mike Heath
 */
class NettyFirehoseOnSubscribe implements rx.Observable.OnSubscribe<Envelope>, Closeable {

    public static final String HANDLER_NAME = "handler";

    // Written from the channel initializer (Netty event-loop thread) and read from whichever thread calls
    // close()/isConnected(); volatile so those readers see the latest assignment.
    private volatile Channel channel;

    // Only set when no EventLoopGroup was provided by the caller. This group is owned by this instance and
    // needs to be shut down when the firehose client shuts down.
    private final EventLoopGroup eventLoopGroup;
    private final Bootstrap bootstrap;

    /**
     * Prepares (but does not open) the connection to the firehose.
     *
     * @param uri               base URI of the firehose endpoint; a missing host defaults to {@code 127.0.0.1}
     *                          and a missing scheme to {@code ws}
     * @param token             value sent in the {@code Authorization} header of the WebSocket handshake
     * @param subscriptionId    appended to {@code /firehose/} to form the request path
     * @param skipTlsValidation when {@code true} and the scheme is {@code wss}, server certificates are not
     *                          validated
     * @param eventLoopGroup    event loop group to use, or {@code null} to have this instance create (and
     *                          later shut down) its own {@link NioEventLoopGroup}
     * @param channelClass      channel implementation to use, or {@code null} for {@link NioSocketChannel}
     * @throws RuntimeException wrapping any failure to set up the TLS context
     */
    public NettyFirehoseOnSubscribe(URI uri, String token, String subscriptionId, boolean skipTlsValidation,
            EventLoopGroup eventLoopGroup, Class<? extends SocketChannel> channelClass) {
        try {
            final String host = uri.getHost() == null ? "127.0.0.1" : uri.getHost();
            final String scheme = uri.getScheme() == null ? "ws" : uri.getScheme();
            final int port = getPort(scheme, uri.getPort());
            final URI fullUri = uri.resolve("/firehose/" + subscriptionId);

            // null means plain-text "ws"; only "wss" gets a TLS handler in the pipeline.
            final SslContext sslContext = "wss".equalsIgnoreCase(scheme)
                    ? buildSslContext(skipTlsValidation)
                    : null;

            bootstrap = new Bootstrap();
            if (eventLoopGroup == null) {
                this.eventLoopGroup = new NioEventLoopGroup();
                bootstrap.group(this.eventLoopGroup);
            } else {
                // The caller owns the provided group, so it must not be shut down by close().
                this.eventLoopGroup = null;
                bootstrap.group(eventLoopGroup);
            }
            bootstrap.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, 15000)
                    .channel(channelClass == null ? NioSocketChannel.class : channelClass).remoteAddress(host, port)
                    .handler(new ChannelInitializer<SocketChannel>() {
                        @Override
                        protected void initChannel(SocketChannel c) throws Exception {
                            final HttpHeaders headers = new DefaultHttpHeaders();
                            headers.add(HttpHeaders.Names.AUTHORIZATION, token);
                            final WebSocketClientHandler handler = new WebSocketClientHandler(
                                    WebSocketClientHandshakerFactory.newHandshaker(fullUri, WebSocketVersion.V13,
                                            null, false, headers));
                            final ChannelPipeline pipeline = c.pipeline();
                            if (sslContext != null) {
                                pipeline.addLast(sslContext.newHandler(c.alloc(), host, port));
                            }
                            pipeline.addLast(new ReadTimeoutHandler(30));
                            pipeline.addLast(new HttpClientCodec(), new HttpObjectAggregator(8192));
                            pipeline.addLast(HANDLER_NAME, handler);

                            channel = c;
                        }
                    });
        } catch (NoSuchAlgorithmException | SSLException | KeyStoreException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Builds the client-side TLS context used for {@code wss} connections.
     *
     * @param skipTlsValidation {@code true} to trust any server certificate, {@code false} to use the
     *                          platform's default trust managers
     */
    private static SslContext buildSslContext(boolean skipTlsValidation)
            throws NoSuchAlgorithmException, KeyStoreException, SSLException {
        final SslContextBuilder sslContextBuilder = SslContextBuilder.forClient();
        if (skipTlsValidation) {
            sslContextBuilder.trustManager(InsecureTrustManagerFactory.INSTANCE);
        } else {
            final TrustManagerFactory trustManagerFactory = TrustManagerFactory
                    .getInstance(TrustManagerFactory.getDefaultAlgorithm());
            // A null KeyStore initializes the factory with the default trust store.
            trustManagerFactory.init((KeyStore) null);
            sslContextBuilder.trustManager(trustManagerFactory);
        }
        return sslContextBuilder.build();
    }

    /**
     * Opens a connection for the given subscriber. On a successful connect the subscriber is attached to the
     * channel's {@link WebSocketClientHandler}; on failure the cause is reported through
     * {@code subscriber.onError}.
     */
    @Override
    public void call(Subscriber<? super Envelope> subscriber) {
        bootstrap.connect().addListener((ChannelFutureListener) future -> {
            if (future.isSuccess()) {
                future.channel().pipeline().get(WebSocketClientHandler.class).setSubscriber(subscriber);
            } else {
                subscriber.onError(future.cause());
            }
        });
    }

    /**
     * Closes the most recently initialized channel, if any, and shuts down the event loop group when it was
     * created by this instance.
     */
    @Override
    public void close() {
        // Read the volatile field once so the null check and the close act on the same channel.
        final Channel current = channel;
        if (current != null) {
            current.close();
        }
        if (eventLoopGroup != null) {
            eventLoopGroup.shutdownGracefully();
        }
    }

    /**
     * Resolves the port to connect to.
     *
     * @param scheme URI scheme, compared case-insensitively
     * @param port   port taken from the URI, {@code -1} when the URI did not specify one
     * @return {@code port} when it is specified; otherwise 80 for {@code ws}, 443 for {@code wss} and
     *         {@code -1} for any other scheme
     */
    private int getPort(String scheme, int port) {
        if (port != -1) {
            return port;
        }
        if ("ws".equalsIgnoreCase(scheme)) {
            return 80;
        }
        if ("wss".equalsIgnoreCase(scheme)) {
            return 443;
        }
        return -1;
    }

    /**
     * @return {@code true} when a channel has been initialized and is currently active
     */
    public boolean isConnected() {
        final Channel current = channel;
        return current != null && current.isActive();
    }

    /**
     * Channel handler that performs the WebSocket handshake and then forwards decoded envelopes and terminal
     * events to the attached subscriber.
     */
    private static class WebSocketClientHandler extends SimpleChannelInboundHandler<Object> {

        private final WebSocketClientHandshaker handshaker;

        private Subscriber<? super Envelope> subscriber;

        // Set once onCompleted or onError has been delivered so the subscriber never receives a second
        // terminal notification (e.g. onError from exceptionCaught followed by onCompleted when the close it
        // triggered makes the channel inactive). Only accessed from this handler's callbacks.
        private boolean terminated;

        public WebSocketClientHandler(WebSocketClientHandshaker handshaker) {
            this.handshaker = handshaker;
        }

        /** Starts the WebSocket handshake as soon as the connection is established. */
        @Override
        public void channelActive(ChannelHandlerContext context) throws Exception {
            handshaker.handshake(context.channel());
        }

        /** Completes the subscriber when the connection goes away, unless it already failed. */
        @Override
        public void channelInactive(ChannelHandlerContext context) throws Exception {
            if (terminated) {
                return;
            }
            terminated = true;
            subscriber.onCompleted();
        }

        /** Closes the connection and reports the failure to the subscriber at most once. */
        @Override
        public void exceptionCaught(ChannelHandlerContext context, Throwable cause) throws Exception {
            context.close();
            if (terminated) {
                return;
            }
            terminated = true;
            subscriber.onError(cause);
        }

        /**
         * Handles the handshake response first, then WebSocket frames: pings are answered with a pong
         * carrying the same payload, binary frames are decoded into an {@link Envelope} and passed to
         * {@code subscriber.onNext}; other frame types are ignored.
         *
         * @throws IllegalStateException if a {@link FullHttpResponse} arrives after the handshake completed
         */
        @Override
        protected void channelRead0(ChannelHandlerContext context, Object message) throws Exception {
            final Channel channel = context.channel();
            if (!handshaker.isHandshakeComplete()) {
                handshaker.finishHandshake(channel, (FullHttpResponse) message);
                // Reassemble fragmented frames (up to 64 KiB) before they reach this handler.
                channel.pipeline().addBefore(HANDLER_NAME, "websocket-frame-aggregator",
                        new WebSocketFrameAggregator(64 * 1024));
                subscriber.onStart();
                return;
            }

            if (message instanceof FullHttpResponse) {
                final FullHttpResponse response = (FullHttpResponse) message;
                throw new IllegalStateException("Unexpected FullHttpResponse (getStatus=" + response.getStatus()
                        + ", content=" + response.content().toString(CharsetUtil.UTF_8) + ')');
            }

            final WebSocketFrame frame = (WebSocketFrame) message;
            if (frame instanceof PingWebSocketFrame) {
                // retain() because the payload is handed to the outgoing pong while SimpleChannelInboundHandler
                // releases the incoming message once this method returns.
                context.writeAndFlush(new PongWebSocketFrame(((PingWebSocketFrame) frame).retain().content()));
            } else if (frame instanceof BinaryWebSocketFrame) {
                final ByteBufInputStream input = new ByteBufInputStream(((BinaryWebSocketFrame) message).content());
                final Envelope envelope = Envelope.ADAPTER.decode(input);
                subscriber.onNext(envelope);
            }
        }

        /** Attaches the subscriber that receives envelopes and terminal events from this connection. */
        public void setSubscriber(Subscriber<? super Envelope> subscriber) {
            this.subscriber = subscriber;
        }
    }

}