com.barchart.netty.client.base.ConnectableBase.java Source code

Java tutorial

Introduction

Here is the source code for com.barchart.netty.client.base.ConnectableBase.java

Source

/**
 * Copyright (C) 2011-2014 Barchart, Inc. <http://www.barchart.com/>
 *
 * All rights reserved. Licensed under the OSI BSD License.
 *
 * http://www.opensource.org/licenses/bsd-license.php
 */
package com.barchart.netty.client.base;

import io.netty.bootstrap.Bootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.handler.timeout.ReadTimeoutException;
import io.netty.handler.timeout.ReadTimeoutHandler;

import java.net.URI;
import java.util.Collection;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.TimeUnit;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import rx.Observable;
import rx.subjects.PublishSubject;
import rx.subjects.ReplaySubject;

import com.barchart.netty.client.BootstrapInitializer;
import com.barchart.netty.client.Connectable;
import com.barchart.netty.client.policy.ReconnectPolicy;
import com.barchart.netty.client.transport.TransportFactory;
import com.barchart.netty.client.transport.TransportProtocol;
import com.barchart.netty.common.pipeline.PipelineInitializer;

/**
 * A base Connectable implementation which provides basic configuration,
 * connection workflow, status monitoring, and message subscriptions.
 */
public abstract class ConnectableBase<T extends Connectable<T>> implements Connectable<T>, PipelineInitializer {

    protected final Logger log = LoggerFactory.getLogger(getClass());

    /**
     * Base builder for ConnectableBase subclasses. Collects the common
     * configuration (transport, event loop, bootstrap customizer, read
     * timeout) and applies it to a newly constructed client through
     * {@link #configure(ConnectableBase)}.
     *
     * @param <B> the concrete builder type returned by the fluent setters
     * @param <C> the concrete client type produced by {@link #build()}
     */
    protected abstract static class Builder<B extends Builder<B, C>, C extends ConnectableBase<C>> {

        protected Builder() {
        }

        /* Standard fields */
        protected TransportProtocol transport;
        protected EventLoopGroup eventLoop = null;
        protected BootstrapInitializer bootstrapper = null;

        /* Implementation specific */
        protected long timeout = 0;

        /**
         * Set the remote host address to connect to.
         *
         * @param url the remote address, parsed by TransportFactory.create
         * @return this builder
         * @see com.barchart.netty.client.transport.TransportFactory#create(URI)
         */
        @SuppressWarnings("unchecked")
        public B host(final String url) {
            transport = TransportFactory.create(url);
            return (B) this;
        }

        /**
         * Retrieve the host TransportProtocol for this connectable for subclass
         * builders.
         *
         * @return the transport set by {@link #host(String)}, or null if unset
         */
        protected TransportProtocol host() {
            return transport;
        }

        /**
         * Set the connection read timeout. If the specified time elapses
         * between inbound messages, the connection will terminate. To
         * automatically reconnect after a timeout, set a
         * {@link ReconnectPolicy}.
         *
         * @param timeout_ the timeout amount; 0 disables the timeout
         * @param unit_ the unit of {@code timeout_}
         * @return this builder
         */
        @SuppressWarnings("unchecked")
        public B timeout(final long timeout_, final TimeUnit unit_) {
            // Stored internally in milliseconds
            timeout = TimeUnit.MILLISECONDS.convert(timeout_, unit_);
            return (B) this;
        }

        /**
         * Set the Netty EventLoopGroup for this Connectable. If never set, the
         * client creates its own NioEventLoopGroup on first connect.
         *
         * @param group_ the event loop group to use for transport operations
         * @return this builder
         */
        @SuppressWarnings("unchecked")
        public B eventLoop(final EventLoopGroup group_) {
            eventLoop = group_;
            return (B) this;
        }

        /**
         * Roll-your-own Netty bootstrap for additional flexibility in
         * configuring channel options. You should only call options() on the
         * provided Bootstrap, as other values (remote host, channel type,
         * channel initializer, etc) may be overwritten by the default
         * bootstrapping process.
         *
         * @param bootstrapper_ callback invoked with each new Bootstrap
         * @return this builder
         */
        @SuppressWarnings("unchecked")
        public B bootstrapper(final BootstrapInitializer bootstrapper_) {
            bootstrapper = bootstrapper_;
            return (B) this;
        }

        /**
         * Apply this builder's event loop, bootstrapper and timeout to the
         * given client. Null event loop / bootstrapper values are skipped so
         * the client's defaults remain in effect. The transport is not applied
         * here; presumably subclasses pass {@link #host()} to the
         * ConnectableBase constructor (verify in subclass builders).
         *
         * @param client the client to configure
         * @return the same client instance
         */
        protected C configure(final C client) {

            if (eventLoop != null) {
                client.eventLoopGroup(eventLoop);
            }

            if (bootstrapper != null) {
                client.bootstrapper(bootstrapper);
            }

            client.timeout(timeout);

            return client;

        }

        /**
         * Build a new Connectable client with the current configuration.
         */
        protected abstract C build();

    }

    /* Message subscriptions, keyed by message type */
    private final ConcurrentMap<Class<?>, MessageSubscription<?>> subscriptions =
            new ConcurrentHashMap<Class<?>, MessageSubscription<?>>();

    /* Connection state */
    private final PublishSubject<Connectable.StateChange<T>> stateChanges = PublishSubject.create();
    // Written both from caller threads (connect/disconnect) and from Netty
    // event loop threads (ConnectionStateHandler), so it must be volatile.
    private volatile Connectable.State lastState = null;

    /* Netty resources */
    // Replaced by connect() and cleared by the event loop when that channel
    // goes inactive; volatile for cross-thread visibility.
    protected volatile Channel channel;

    private final TransportProtocol transport;
    private final ChannelInitializer<Channel> channelInitializer;

    // Null until supplied via eventLoopGroup(); a default NioEventLoopGroup is
    // created lazily on first connect() so an unused default is never allocated.
    private EventLoopGroup group;
    private BootstrapInitializer bootstrapper = null;

    /* Read timeout in milliseconds; 0 disables the timeout */
    private long timeout = 0;

    /**
     * Create a new Connectable client. This method is intended to be called by
     * subclass Builder implementations.
     *
     * @param transport_ The transport type, including the remote peer address
     */
    protected ConnectableBase(final TransportProtocol transport_) {

        transport = transport_;

        channelInitializer = new ClientPipelineInitializer();

    }

    /**
     * The current read timeout in milliseconds. 0 indicates no timeout.
     */
    protected long timeout() {
        return timeout;
    }

    /**
     * Set the read timeout in milliseconds. Set to 0 to disable timeout. Takes
     * effect for channels created by subsequent connect() calls.
     */
    protected void timeout(final long millis) {
        timeout = millis;
    }

    /**
     * Obtain a Bootstrap from the transport and let the user-supplied
     * bootstrapper (if any) customize it.
     */
    private Bootstrap bootstrap() {

        final Bootstrap bootstrap = transport.bootstrap();

        if (bootstrapper != null) {
            bootstrapper.initBootstrap(bootstrap);
        }

        return bootstrap;

    }

    /**
     * Set a callback used to customize the Bootstrap on each connect().
     */
    protected void bootstrapper(final BootstrapInitializer bi) {
        bootstrapper = bi;
    }

    /**
     * Set the Netty EventLoopGroup used for subsequent connect() calls.
     */
    protected void eventLoopGroup(final EventLoopGroup group_) {
        group = group_;
    }

    /**
     * Return the configured event loop group, creating a default
     * NioEventLoopGroup the first time one is needed.
     */
    private EventLoopGroup resolveEventLoopGroup() {
        if (group == null) {
            group = new NioEventLoopGroup();
        }
        return group;
    }

    /**
     * Start connecting to the remote peer. The state changes to CONNECTING
     * immediately, to CONNECTED when the channel becomes active, or to
     * CONNECT_FAIL if the connect attempt fails.
     *
     * @return an observable that emits this client and completes on success,
     *         or errors with the connect failure cause
     * @throws IllegalArgumentException if no transport or channel initializer
     *             is configured
     */
    @Override
    public Observable<T> connect() {

        if (transport == null) {
            throw new IllegalArgumentException("Transport cannot be null");
        }

        if (channelInitializer == null) {
            throw new IllegalArgumentException("Channel initializer cannot be null");
        }

        log.debug("Client connecting to {}", transport.address());
        changeState(Connectable.State.CONNECTING);

        // ChannelInitializer is @Sharable, so the single validated instance
        // can be reused across connect attempts.
        final ChannelFuture future = bootstrap() //
                .group(resolveEventLoopGroup()) //
                .handler(channelInitializer) //
                .connect();

        channel = future.channel();

        // ReplaySubject so late subscribers still observe the connect result
        final ReplaySubject<T> connectObs = ReplaySubject.create();

        future.addListener(new ChannelFutureListener() {

            @SuppressWarnings("unchecked")
            @Override
            public void operationComplete(final ChannelFuture completed) throws Exception {

                if (!completed.isSuccess()) {
                    changeState(Connectable.State.CONNECT_FAIL);
                    connectObs.onError(completed.cause());
                } else {
                    connectObs.onNext((T) ConnectableBase.this);
                    connectObs.onCompleted();
                }

            }

        });

        return connectObs;

    }

    /**
     * Close the current channel if it is active. If there is no active
     * channel, returns an observable that immediately emits this client.
     */
    @SuppressWarnings("unchecked")
    @Override
    public Observable<T> disconnect() {

        // Snapshot the field: the event loop may clear it concurrently in
        // channelInactive(), which would otherwise cause an NPE below.
        final Channel current = channel;

        if (current != null && current.isActive()) {

            changeState(Connectable.State.DISCONNECTING);

            return ChannelFutureObservable.create(current.close(), (T) this);

        }

        return Observable.<T>just((T) this);

    }

    @Override
    public Observable<Connectable.StateChange<T>> stateChanges() {
        return stateChanges;
    }

    @Override
    public Connectable.State state() {
        return lastState;
    }

    /**
     * Send a message to the connected peer. The message type must be supported
     * by the internal Netty pipeline.
     *
     * @param message An object to encode and send to the remote peer
     * @return an observable that emits {@code message} once the write
     *         succeeds, or errors with the write failure cause
     * @throws IllegalStateException if there is no channel or it is not active
     */
    protected <U> Observable<U> send(final U message) {

        // Snapshot the field; it is null before the first connect() and after
        // the channel goes inactive.
        final Channel current = channel;

        if (current == null || !current.isActive()) {
            throw new IllegalStateException("Channel is not active");
        }

        return ChannelFutureObservable.create(current.writeAndFlush(message), message);

    }

    /**
     * Receive messages of a specific type from the connected peer.
     *
     * The message type must be supported by the internal Netty pipeline.
     * Channel handlers to decode different message types should be provided by
     * the subclass by overriding the initPipeline() method, otherwise the only
     * message type available will be ByteBuf.class.
     *
     * Subscriptions are held in a concurrent map whose live values view is
     * shared with the pipeline's message router, so a subscription registered
     * after connect() still receives subsequent messages. A message that is
     * being routed at the same moment a subscription is registered may not be
     * delivered to that new subscription.
     *
     * @param type The message type
     * @return an observable of inbound messages that are instances of
     *         {@code type}
     */
    @SuppressWarnings("unchecked")
    protected <U> Observable<U> receive(final Class<U> type) {

        MessageSubscription<U> subscription = (MessageSubscription<U>) subscriptions.get(type);

        if (subscription == null) {

            subscription = new MessageSubscription<U>(type);

            // Another thread may have registered the same type in the meantime
            final MessageSubscription<?> existing = subscriptions.putIfAbsent(type, subscription);

            if (existing != null) {
                subscription = (MessageSubscription<U>) existing;
            }

        }

        return subscription.observable();

    }

    /**
     * Record a new connection state and publish a StateChange carrying the
     * new and previous states to {@link #stateChanges()} subscribers.
     */
    protected final void changeState(final Connectable.State state) {

        final Connectable.State previous = lastState;
        lastState = state;

        stateChanges.onNext(new Connectable.StateChange<T>() {

            @SuppressWarnings("unchecked")
            @Override
            public T connectable() {
                return (T) ConnectableBase.this;
            }

            @Override
            public Connectable.State state() {
                return state;
            }

            @Override
            public Connectable.State previous() {
                return previous;
            }

        });

    }

    /**
     * Pipeline handler that translates channel lifecycle events into
     * Connectable state changes.
     */
    private class ConnectionStateHandler extends ChannelInboundHandlerAdapter {

        @Override
        public void channelActive(final ChannelHandlerContext ctx) throws Exception {

            changeState(Connectable.State.CONNECTED);

            super.channelActive(ctx);

        }

        @Override
        public void channelInactive(final ChannelHandlerContext ctx) throws Exception {

            // Only clear the field if it still refers to this channel; a later
            // connect() may already have replaced it with a new channel.
            if (channel == ctx.channel()) {
                channel = null;
            }

            super.channelInactive(ctx);

            changeState(Connectable.State.DISCONNECTED);

        }

        @Override
        public void exceptionCaught(final ChannelHandlerContext ctx, final Throwable cause) {

            if (cause instanceof ReadTimeoutException) {

                // No activity from peer
                changeState(Connectable.State.TIMEOUT);
                ctx.close();

            } else {

                // Logged here with the full stack trace; intentionally not
                // forwarded further down the pipeline.
                log.warn(cause.getClass().getName() + ": " + cause.getMessage(), cause);

            }

        }

    }

    /**
     * Builds each new channel's pipeline: subclass codecs, transport handlers,
     * an optional read timeout, state tracking and message routing.
     */
    private class ClientPipelineInitializer extends ChannelInitializer<Channel> {

        @Override
        public void initChannel(final Channel ch) throws Exception {

            final ChannelPipeline pipeline = ch.pipeline();

            // User-specified pipeline handlers (message codecs)
            initPipeline(pipeline);

            // Transport-required pipeline handlers
            transport.initPipeline(pipeline);

            // Connection read timeout handler
            if (timeout > 0) {
                pipeline.addFirst(new ReadTimeoutHandler(timeout, TimeUnit.MILLISECONDS));
            }

            // Monitor connection state
            pipeline.addLast(new ConnectionStateHandler());

            // Process messages and route to observers. values() is a live view,
            // so subscriptions added later are still seen by the router.
            pipeline.addLast(new MessageRouter(subscriptions.values()));

        }

    }

    /**
     * Adapts a Netty ChannelFuture into an Observable that emits a fixed
     * result value on success or errors with the future's cause on failure.
     */
    protected static class ChannelFutureObservable {

        public static <T> Observable<T> create(final ChannelFuture future, final T result) {

            // ReplaySubject so subscribers arriving after completion still see it
            final ReplaySubject<T> subject = ReplaySubject.create();

            future.addListener(new ChannelFutureListener() {

                @Override
                public void operationComplete(final ChannelFuture completed) throws Exception {

                    if (!completed.isSuccess()) {
                        subject.onError(completed.cause());
                    } else {
                        subject.onNext(result);
                        subject.onCompleted();
                    }

                }

            });

            return subject;

        }

    }

    /**
     * Terminal inbound handler that offers every decoded message to all
     * registered subscriptions.
     */
    private static class MessageRouter extends SimpleChannelInboundHandler<Object> {

        private final Collection<MessageSubscription<?>> subscriptions;

        public MessageRouter(final Collection<MessageSubscription<?>> subscriptions_) {
            super(Object.class);
            subscriptions = subscriptions_;
        }

        @Override
        public void channelRead0(final ChannelHandlerContext ctx, final Object msg) throws Exception {

            for (final MessageSubscription<?> subscription : subscriptions) {
                subscription.route(msg);
            }

        }

    }

    /**
     * A per-type message stream: forwards messages that are instances of
     * {@code type} to its PublishSubject.
     */
    private static class MessageSubscription<M> {

        private final Class<M> type;
        private final PublishSubject<M> publish;

        public MessageSubscription(final Class<M> type_) {
            type = type_;
            publish = PublishSubject.create();
        }

        public Observable<M> observable() {
            return publish;
        }

        public void route(final Object msg) throws Exception {
            if (type.isInstance(msg)) {
                publish.onNext(type.cast(msg));
            }
        }

    }

}