com.spotify.netty4.handler.codec.zmtp.ZMTPSocket.java Source code

Java tutorial

Introduction

Here is the source code for com.spotify.netty4.handler.codec.zmtp.ZMTPSocket.java

Source

/*
* Copyright (c) 2012-2015 Spotify AB
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/

package com.spotify.netty4.handler.codec.zmtp;

import com.google.common.base.Function;
import com.google.common.base.Splitter;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.util.concurrent.AsyncFunction;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.ListeningExecutorService;
import com.google.common.util.concurrent.MoreExecutors;
import com.google.common.util.concurrent.ThreadFactoryBuilder;

import com.spotify.netty4.util.BatchFlusher;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.Closeable;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.URI;
import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException;
import java.security.SecureRandom;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.atomic.AtomicInteger;

import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.group.ChannelGroup;
import io.netty.channel.group.DefaultChannelGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.util.concurrent.GlobalEventExecutor;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.util.concurrent.Futures.immediateFailedFuture;
import static com.google.common.util.concurrent.Futures.immediateFuture;
import static com.google.common.util.concurrent.Futures.transform;
import static com.spotify.netty4.handler.codec.zmtp.ListenableFutureAdapter.listenable;

/**
 * A simple ZMTP socket implementation for testing purposes.
 */
public class ZMTPSocket implements Closeable {

    private static final Logger log = LoggerFactory.getLogger(ZMTPSocket.class);

    private static final ListeningExecutorService EXECUTOR = MoreExecutors
            .listeningDecorator(GlobalEventExecutor.INSTANCE);

    /**
     * Represents a connected peer.
     */
    public interface ZMTPPeer {

        /**
         * Get the ZMTP session for this peer.
         */
        ZMTPSession session();

        /**
         * Send a message to this peer.
         */
        ListenableFuture<Void> send(ZMTPMessage message);
    }

    /**
     * Handles incoming messages and connection events.
     */
    public interface Handler {

        /**
         * A peer connected.
         */
        void connected(ZMTPSocket socket, ZMTPPeer peer);

        /**
         * A peer disconnected.
         */
        void disconnected(ZMTPSocket socket, ZMTPPeer peer);

        /**
         * A message was received from a peer.
         */
        void message(ZMTPSocket socket, ZMTPPeer peer, ZMTPMessage message);
    }

    private interface Sender {

        ListenableFuture<Void> send(ZMTPMessage message);
    }

    private interface Receiver {

        void receive(final ZMTPPeer peer, ZMTPMessage message);
    }

    private static final ThreadFactory DAEMON = new ThreadFactoryBuilder().setDaemon(true).build();

    private static final NioEventLoopGroup GROUP = new NioEventLoopGroup(1, DAEMON);

    private final ChannelGroup channelGroup = new DefaultChannelGroup(GlobalEventExecutor.INSTANCE);

    private volatile List<ZMTPPeer> peers = ImmutableList.of();
    private volatile Map<ByteBuf, ZMTPPeer> routing = ImmutableMap.of();
    private final Object lock = new Object();

    private final Sender sender;
    private final Receiver receiver;

    private final Handler handler;
    private final ZMTPConfig config;

    private volatile boolean closed;

    /**
     * Create a new socket.
     */
    private ZMTPSocket(final Handler handler, final ZMTPConfig config) {
        this.handler = checkNotNull(handler, "handler");
        this.config = config.toBuilder().identityGenerator(new IdentityGenerator())
                .decoder(decoder(config.socketType())).encoder(encoder(config.socketType())).build();
        this.sender = sender(config.socketType());
        this.receiver = receiver(config.socketType());
    }

    /**
     * Bind this socket to an endpoint.
     */
    public ListenableFuture<InetSocketAddress> bind(final String endpoint) {
        return transform(address(endpoint), new AsyncFunction<InetSocketAddress, InetSocketAddress>() {
            @Override
            public ListenableFuture<InetSocketAddress> apply(
                    @SuppressWarnings("NullableProblems") final InetSocketAddress input) throws Exception {
                return bind(input);
            }
        });
    }

    /**
     * Bind this socket to an address.
     */
    public ListenableFuture<InetSocketAddress> bind(final InetSocketAddress address) {
        final ServerBootstrap b = new ServerBootstrap().channel(NioServerSocketChannel.class).group(GROUP)
                .childHandler(new ChannelInitializer());
        final ChannelFuture f = b.bind(address);
        channelGroup.add(f.channel());
        if (closed) {
            f.channel().close();
            return immediateFailedFuture(new ClosedChannelException());
        }
        return transform(listenable(f), new Function<Void, InetSocketAddress>() {
            @Override
            public InetSocketAddress apply(final Void input) {
                return (InetSocketAddress) f.channel().localAddress();
            }
        });
    }

    /**
     * Connect this socket to an endpoint.
     */
    public ListenableFuture<Void> connect(final String endpoint) {
        return transform(address(endpoint), new AsyncFunction<InetSocketAddress, Void>() {
            @Override
            public ListenableFuture<Void> apply(
                    @SuppressWarnings("NullableProblems") final InetSocketAddress address) throws Exception {
                return connect(address);
            }
        });
    }

    /**
     * Connect this socket to an address.
     */
    public ListenableFuture<Void> connect(final InetSocketAddress address) {
        final Bootstrap b = new Bootstrap().group(GROUP).channel(NioSocketChannel.class)
                .handler(new ChannelInitializer());
        final ChannelFuture f = b.connect(address);
        if (closed) {
            f.channel().close();
            return immediateFailedFuture(new ClosedChannelException());
        }
        return listenable(f);
    }

    /**
     * Send a message on this socket.
     */
    public ListenableFuture<Void> send(final ZMTPMessage message) {
        return sender.send(message);
    }

    /**
     * Close this socket.
     */
    @Override
    public void close() {
        closed = true;
        channelGroup.close().awaitUninterruptibly();
    }

    /**
     * Get a list of all connected peers.
     */
    public List<ZMTPPeer> peers() {
        return peers;
    }

    /**
     * Get a sender for a socket type.
     */
    private ZMTPEncoder.Factory encoder(final ZMTPSocketType socketType) {
        switch (socketType) {
        case ROUTER:
            return RoutingEncoder.FACTORY;
        default:
            return ZMTPMessageEncoder.FACTORY;
        }
    }

    /**
     * Get a decoder for a socket type.
     */
    private ZMTPDecoder.Factory decoder(final ZMTPSocketType socketType) {
        switch (socketType) {
        case ROUTER:
            return RoutingDecoder.FACTORY;
        default:
            return ZMTPMessageDecoder.FACTORY;
        }
    }

    /**
     * Get a receiver that implements the appropriate socket type behavior.
     */
    private Receiver receiver(final ZMTPSocketType socketType) {
        switch (socketType) {
        case PUSH:
            return new DropReceiver();
        case DEALER:
        case ROUTER:
        case SUB:
        case PUB:
        case PULL:
            return new PassReceiver();
        default:
            throw new IllegalArgumentException("Unsupported socket type: " + socketType);
        }
    }

    /**
     * Get a sender that implements the appropriate socket type behavior.
     */
    private Sender sender(final ZMTPSocketType socketType) {
        switch (socketType) {
        case PUSH:
        case DEALER:
            return new RoundRobinSender();
        case ROUTER:
            return new RoutingSender();
        case SUB:
        case PUB:
            return new BroadcastSender();
        case PULL:
            return new UnsupportedOperationSender();
        default:
            throw new IllegalArgumentException("Unsupported socket type: " + socketType);
        }
    }

    /**
     * Resolve an endpoint into an address. Async to avoid blocking on DNS resolution.
     */
    private static ListenableFuture<InetSocketAddress> address(final String endpoint) {
        return EXECUTOR.submit(new Callable<InetSocketAddress>() {
            @Override
            public InetSocketAddress call() throws Exception {
                final URI uri = URI.create(endpoint);
                checkArgument("tcp".equals(uri.getScheme()), "Unsupported endpoint type: %s", uri.getScheme());
                final List<String> parts = Splitter.on(':').splitToList(uri.getAuthority());
                final String hostString = parts.get(0);
                final InetAddress host = hostString.equals("*") ? null : InetAddress.getByName(hostString);
                final String portString = parts.get(1);
                final int port = portString.equals("*") ? 0 : Integer.valueOf(portString);
                return new InetSocketAddress(host, port);
            }
        });
    }

    /**
     * Handles a single connected peer.
     */
    private class Peer extends ChannelInboundHandlerAdapter implements ZMTPPeer {

        private final Channel ch;
        private final BatchFlusher flusher;
        private final ZMTPSession session;

        public Peer(final Channel ch, final ZMTPSession session) {
            this.ch = ch;
            this.session = session;
            this.flusher = new BatchFlusher(ch);
        }

        @Override
        public ZMTPSession session() {
            return session;
        }

        @Override
        public ListenableFuture<Void> send(final ZMTPMessage message) {
            final ChannelFuture f = ch.write(message);
            flusher.flush();
            return listenable(f);
        }

        @Override
        public void userEventTriggered(final ChannelHandlerContext ctx, final Object evt) throws Exception {
            if (evt instanceof ZMTPHandshakeSuccess) {
                register(session.peerIdentity());
                try {
                    handler.connected(ZMTPSocket.this, this);
                } catch (Exception e) {
                    log.error("handler threw exception", e);
                }
            }
        }

        @Override
        public void channelInactive(final ChannelHandlerContext ctx) throws Exception {
            deregister(session.peerIdentity());
            try {
                handler.disconnected(ZMTPSocket.this, this);
            } catch (Exception e) {
                log.error("handler threw exception", e);
            }
        }

        private void register(final ByteBuffer identity) {
            synchronized (lock) {
                peers = ImmutableList.<ZMTPPeer>builder().addAll(peers).add(this).build();
                routing = ImmutableMap.<ByteBuf, ZMTPPeer>builder().putAll(routing)
                        .put(Unpooled.wrappedBuffer(identity), this).build();
            }
        }

        private void deregister(final ByteBuffer identity) {
            synchronized (lock) {
                final ImmutableList.Builder<ZMTPPeer> newPeers = ImmutableList.builder();
                for (final ZMTPPeer handler : peers) {
                    if (handler != this) {
                        newPeers.add(handler);
                    }
                }
                peers = newPeers.build();
                final ImmutableMap.Builder<ByteBuf, ZMTPPeer> newRouting = ImmutableMap.builder();
                final ByteBuf id = Unpooled.wrappedBuffer(identity);
                for (final Map.Entry<ByteBuf, ZMTPPeer> entry : routing.entrySet()) {
                    if (!entry.getKey().equals(id)) {
                        newRouting.put(entry.getKey(), entry.getValue());
                    }
                }
                routing = newRouting.build();
            }
        }

        @Override
        public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception {
            if (msg instanceof ZMTPMessage) {
                try {
                    receiver.receive(this, (ZMTPMessage) msg);
                } catch (Exception e) {
                    log.error("handler threw exception", e);
                }
            }
        }

        @Override
        public void exceptionCaught(final ChannelHandlerContext ctx, final Throwable cause) throws Exception {
            log.warn("exception", cause);
            ctx.close();
        }
    }

    private class ChannelInitializer extends io.netty.channel.ChannelInitializer {

        @Override
        protected void initChannel(final Channel ch) throws Exception {
            channelGroup.add(ch);
            final ZMTPCodec codec = ZMTPCodec.from(config);
            final Peer peer = new Peer(ch, codec.session());
            ch.pipeline().addLast(codec, peer);
        }
    }

    /**
     * A sender that round robin load balances messages over all connected peers.
     */
    private class RoundRobinSender implements Sender {

        private final AtomicInteger i = new AtomicInteger();

        @Override
        public ListenableFuture<Void> send(final ZMTPMessage message) {
            final List<ZMTPPeer> channels = peers();
            if (channels.size() == 0) {
                return immediateFailedFuture(new ClosedChannelException());
            }
            final ZMTPPeer handler = next(channels);
            return handler.send(message);
        }

        private ZMTPPeer next(final List<ZMTPPeer> channels) {
            assert !channels.isEmpty();
            int next;
            int prev;
            do {
                prev = i.get();
                next = prev + 1;
                if (next >= channels.size()) {
                    next = 0;
                }
            } while (!i.compareAndSet(prev, next));
            return channels.get(next);
        }
    }

    /**
     * A sender that routes message to the appropriate peer using the front identity frame.
     */
    private class RoutingSender implements Sender {

        @Override
        public ListenableFuture<Void> send(final ZMTPMessage message) {
            if (message.size() == 0) {
                return immediateFailedFuture(new IllegalArgumentException("empty message"));
            }
            final ByteBuf identity = message.frame(0);
            final ZMTPPeer peer = routing.get(identity);
            if (peer == null) {
                message.release();
                return immediateFailedFuture(new ClosedChannelException());
            }
            return peer.send(message);
        }
    }

    /**
     * A sender that broadcasts messages to all connected peers.
     */
    private class BroadcastSender implements Sender {

        @Override
        public ListenableFuture<Void> send(final ZMTPMessage message) {
            final List<ZMTPPeer> channels = peers();
            for (final ZMTPPeer handler : channels) {
                handler.send(message);
            }
            return immediateFuture(null);
        }
    }

    /**
     * A sender that immediately fails all send operations.
     */
    private class UnsupportedOperationSender implements Sender {

        @Override
        public ListenableFuture<Void> send(final ZMTPMessage message) {
            return immediateFailedFuture(new UnsupportedOperationException());
        }
    }

    /**
     * A receiver that drops all incoming messages.
     */
    private class DropReceiver implements Receiver {

        @Override
        public void receive(final ZMTPPeer peer, final ZMTPMessage message) {
            message.release();
        }
    }

    /**
     * A receive that passes on all incoming messages to the {@link Handler}.
     */
    private class PassReceiver implements Receiver {

        @Override
        public void receive(final ZMTPPeer peer, final ZMTPMessage message) {
            try {
                handler.message(ZMTPSocket.this, peer, message);
            } catch (Exception e) {
                log.error("handler threw exception", e);
            }
        }
    }

    /**
     * A {@link ZMTPMessage} decoder that pushes the peer identity onto the front of the message.
     */
    private static class RoutingDecoder implements ZMTPDecoder {

        private static final ZMTPDecoder.Factory FACTORY = new Factory() {
            @Override
            public ZMTPDecoder decoder(final ZMTPSession session) {
                return new RoutingDecoder(session.peerIdentity());
            }
        };

        private final ByteBuf DELIMITER = Unpooled.EMPTY_BUFFER;

        private final ByteBuffer identity;

        private final List<ByteBuf> frames = new ArrayList<ByteBuf>();
        private int frameLength;

        RoutingDecoder(final ByteBuffer identity) {
            this.identity = identity;
            reset();
        }

        private void reset() {
            frames.clear();
            frames.add(Unpooled.wrappedBuffer(identity));
            frameLength = 0;
        }

        @Override
        public void header(final ChannelHandlerContext ctx, final long length, final boolean more,
                final List<Object> out) {
            frameLength = (int) length;
        }

        @Override
        public void content(final ChannelHandlerContext ctx, final ByteBuf data, final List<Object> out) {
            if (data.readableBytes() < frameLength) {
                return;
            }
            if (frameLength == 0) {
                frames.add(DELIMITER);
                return;
            }
            final ByteBuf frame = data.readSlice(frameLength);
            frame.retain();
            frames.add(frame);
        }

        @Override
        public void finish(final ChannelHandlerContext ctx, final List<Object> out) {
            final ZMTPMessage message = ZMTPMessage.from(frames);
            reset();
            out.add(message);
        }

        @Override
        public void close() {
            for (final ByteBuf frame : frames) {
                frame.release();
            }
            frames.clear();
        }
    }

    /**
     * A {@link ZMTPMessage} encoder that pops the peer identity from the front of the message.
     */
    private static class RoutingEncoder implements ZMTPEncoder {

        private static final ZMTPEncoder.Factory FACTORY = new Factory() {
            @Override
            public ZMTPEncoder encoder(final ZMTPSession session) {
                return new RoutingEncoder();
            }
        };

        @Override
        public void estimate(final Object msg, final ZMTPEstimator estimator) {
            final ZMTPMessage message = (ZMTPMessage) msg;
            for (int i = 1; i < message.size(); i++) {
                final ByteBuf frame = message.frame(i);
                estimator.frame(frame.readableBytes());
            }
        }

        @Override
        public void encode(final Object msg, final ZMTPWriter writer) {
            final ZMTPMessage message = (ZMTPMessage) msg;
            for (int i = 1; i < message.size(); i++) {
                final ByteBuf frame = message.frame(i);
                final boolean more = i < message.size() - 1;
                final ByteBuf dst = writer.frame(frame.readableBytes(), more);
                dst.writeBytes(frame, frame.readerIndex(), frame.readableBytes());
            }
        }

        @Override
        public void close() {
        }
    }

    public static Builder builder() {
        return new Builder();
    }

    public static class Builder {

        private final ZMTPConfig.Builder config = ZMTPConfig.builder();
        private Handler handler;

        public Builder handler(final Handler handler) {
            this.handler = handler;
            return this;
        }

        public Builder protocol(final ZMTPProtocol protocol) {
            config.protocol(protocol);
            return this;
        }

        public Builder interop(final boolean interop) {
            config.interop(interop);
            return this;
        }

        public Builder type(final ZMTPSocketType socketType) {
            config.socketType(socketType);
            return this;
        }

        public Builder identity(final CharSequence identity) {
            config.localIdentity(identity);
            return this;
        }

        public Builder identity(final byte[] identity) {
            config.localIdentity(identity);
            return this;
        }

        public Builder identity(final ByteBuffer localIdentity) {
            config.localIdentity(localIdentity);
            return this;
        }

        public ZMTPSocket build() {
            return new ZMTPSocket(handler, config.build());
        }
    }

    /**
     * An identity generator that keeps an integer counter per {@link ZMTPSocket}.
     */
    private static class IdentityGenerator implements ZMTPIdentityGenerator {

        private final AtomicInteger peerIdCounter = new AtomicInteger(new SecureRandom().nextInt());

        @Override
        public ByteBuffer generateIdentity(final ZMTPSession session) {
            final ByteBuffer generated = ByteBuffer.allocate(5);
            generated.put((byte) 0);
            generated.putInt(peerIdCounter.incrementAndGet());
            generated.flip();
            return generated;
        }
    }
}