com.spotify.netty4.handler.codec.zmtp.benchmarks.CustomReqRepBenchmark.java Source code

Java tutorial

Introduction

Here is the source code for com.spotify.netty4.handler.codec.zmtp.benchmarks.CustomReqRepBenchmark.java

Source

/*
 * Copyright (c) 2012-2014 Spotify AB
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License. You may obtain a copy of
 * the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */

package com.spotify.netty4.handler.codec.zmtp.benchmarks;

import com.google.common.collect.FluentIterable;
import com.google.common.collect.Maps;
import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.SettableFuture;

import com.spotify.netty4.handler.codec.zmtp.ZMTPCodec;
import com.spotify.netty4.handler.codec.zmtp.ZMTPDecoder;
import com.spotify.netty4.handler.codec.zmtp.ZMTPEncoder;
import com.spotify.netty4.handler.codec.zmtp.ZMTPEstimator;
import com.spotify.netty4.handler.codec.zmtp.ZMTPHandshakeSuccess;
import com.spotify.netty4.handler.codec.zmtp.ZMTPWriter;
import com.spotify.netty4.util.BatchFlusher;

import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.nio.ByteBuffer;
import java.security.SecureRandom;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Executor;

import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.PooledByteBufAllocator;
import io.netty.channel.Channel;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.ChannelPromise;
import io.netty.channel.DefaultChannelPromise;
import io.netty.channel.MessageSizeEstimator;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.util.internal.chmv8.ForkJoinPool;

import static com.spotify.netty4.handler.codec.zmtp.ZMTPSocketType.DEALER;
import static com.spotify.netty4.handler.codec.zmtp.ZMTPSocketType.ROUTER;
import static io.netty.util.CharsetUtil.UTF_8;
import static java.util.Arrays.asList;

public class CustomReqRepBenchmark {

    private static final ByteBuffer EMPTY_BUFFER = ByteBuffer.allocate(0);

    private static final InetSocketAddress ANY_PORT = new InetSocketAddress("127.0.0.1", 0);
    private static final Thread.UncaughtExceptionHandler UNCAUGHT_EXCEPTION_HANDLER = new Thread.UncaughtExceptionHandler() {
        @Override
        public void uncaughtException(final Thread thread, final Throwable throwable) {
            throwable.printStackTrace();
        }
    };

    public static void main(final String... args) throws InterruptedException {
        final ProgressMeter meter = new ProgressMeter("requests", true);

        // Codecs
        final ZMTPCodec serverCodec = ZMTPCodec.builder().socketType(ROUTER).encoder(ReplyEncoder.class)
                .decoder(RequestDecoder.class).build();

        final ZMTPCodec clientCodec = ZMTPCodec.builder().socketType(DEALER).encoder(RequestEncoder.class)
                .decoder(ReplyDecoder.class).build();

        // Server
        final Executor serverExecutor = new ForkJoinPool(1, ForkJoinPool.defaultForkJoinWorkerThreadFactory,
                UNCAUGHT_EXCEPTION_HANDLER, true);
        final ServerBootstrap serverBootstrap = new ServerBootstrap()
                .group(new NioEventLoopGroup(1), new NioEventLoopGroup()).channel(NioServerSocketChannel.class)
                .childOption(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT)
                .childOption(ChannelOption.MESSAGE_SIZE_ESTIMATOR, ByteBufSizeEstimator.INSTANCE)
                .childHandler(new ChannelInitializer<NioSocketChannel>() {
                    @Override
                    protected void initChannel(final NioSocketChannel ch) throws Exception {
                        ch.pipeline().addLast(serverCodec);
                        ch.pipeline().addLast(new ServerRequestTracker());
                        ch.pipeline().addLast(new ServerHandler(serverExecutor));
                    }
                });
        final Channel server = serverBootstrap.bind(ANY_PORT).awaitUninterruptibly().channel();

        // Client
        final Executor clientExecutor = new ForkJoinPool(1, ForkJoinPool.defaultForkJoinWorkerThreadFactory,
                UNCAUGHT_EXCEPTION_HANDLER, true);
        final SocketAddress address = server.localAddress();
        final Bootstrap clientBootstrap = new Bootstrap().group(new NioEventLoopGroup())
                .channel(NioSocketChannel.class).option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT)
                .option(ChannelOption.MESSAGE_SIZE_ESTIMATOR, ByteBufSizeEstimator.INSTANCE)
                .handler(new ChannelInitializer<NioSocketChannel>() {
                    @Override
                    protected void initChannel(final NioSocketChannel ch) throws Exception {
                        ch.pipeline().addLast(clientCodec);
                        ch.pipeline().addLast(new ClientRequestTracker());
                        ch.pipeline().addLast(new ClientHandler(meter, clientExecutor));
                    }
                });
        final Channel client = clientBootstrap.connect(address).awaitUninterruptibly().channel();

        // Run until client is closed
        client.closeFuture().await();
    }

    private static class ServerHandler extends ChannelInboundHandlerAdapter {

        public static final ByteBuffer REPLY_PAYLOAD = UTF_8.encode("hello world");

        private final Executor executor;

        private BatchFlusher flusher;

        public ServerHandler(final Executor executor) {
            this.executor = executor;
        }

        @Override
        public void channelRegistered(final ChannelHandlerContext ctx) throws Exception {
            super.channelRegistered(ctx);
            this.flusher = new BatchFlusher(ctx.channel());
        }

        @Override
        public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception {
            executor.execute(new Runnable() {
                @Override
                public void run() {
                    final Request request = (Request) msg;
                    ctx.write(request.reply(200, REPLY_PAYLOAD));
                    flusher.flush();
                }
            });
        }
    }

    private static class ClientHandler extends ChannelInboundHandlerAdapter {

        private static final int CONCURRENCY = 1000;

        private final ProgressMeter meter;
        private final Executor executor;

        private BatchFlusher flusher;
        private ChannelHandlerContext ctx;

        private long seq = new SecureRandom().nextLong();

        public ClientHandler(final ProgressMeter meter, final Executor executor) {
            this.meter = meter;
            this.executor = executor;
        }

        @Override
        public void channelRegistered(final ChannelHandlerContext ctx) throws Exception {
            this.ctx = ctx;
            this.flusher = new BatchFlusher(ctx.channel());
        }

        @Override
        public void userEventTriggered(final ChannelHandlerContext ctx, final Object evt) throws Exception {
            if (evt instanceof ZMTPHandshakeSuccess) {
                for (int i = 0; i < CONCURRENCY; i++) {
                    send(ctx);
                }
            }
        }

        private void send(final ChannelHandlerContext ctx) {
            final RequestPromise promise = new RequestPromise(this.ctx.channel());
            ctx.write(req(), promise);
            flusher.flush();
            Futures.addCallback(promise.replyFuture(), new FutureCallback<Reply>() {
                @Override
                public void onSuccess(final Reply reply) {
                    final long latency = System.nanoTime() - reply.id().timestamp();
                    meter.inc(1, latency);
                    send(ctx);
                }

                @Override
                public void onFailure(final Throwable t) {
                    System.err.println("failure");
                }
            }, executor);
        }

        public static long rand(long x) {
            x ^= (x << 21);
            x ^= (x >>> 35);
            x ^= (x << 4);
            return x;
        }

        private Request req() {
            seq = rand(seq);
            final MessageId id = new MessageId(seq, System.nanoTime());
            return new Request(id, "foo://bar/some/resource", "GET", EMPTY_BUFFER);
        }
    }

    private static class RequestPromise extends DefaultChannelPromise {

        private final SettableFuture<Reply> replyFuture = SettableFuture.create();

        public RequestPromise(final Channel channel) {
            super(channel);
        }

        public SettableFuture<Reply> replyFuture() {
            return replyFuture;
        }
    }

    private static class Message {

        private final MessageId id;
        private final CharSequence uri;
        private final CharSequence method;
        private final ByteBuffer payload;

        public Message(final MessageId id, final CharSequence method, final CharSequence uri,
                final ByteBuffer payload) {
            this.id = id;
            this.method = method;
            this.uri = uri;
            this.payload = payload;
        }

        public MessageId id() {
            return id;
        }

        public CharSequence uri() {
            return uri;
        }

        public CharSequence method() {
            return method;
        }

        public ByteBuffer payload() {
            return payload;
        }
    }

    private static class Request extends Message {

        public Request(final MessageId id, final CharSequence uri, final CharSequence method,
                final ByteBuffer payload) {
            super(id, method, uri, payload);
        }

        public Reply reply(final int code, final ByteBuffer payload) {
            return new Reply(id(), uri(), method(), code, payload);
        }
    }

    private static class Reply extends Message {

        private final int code;

        public Reply(final MessageId id, final CharSequence uri, final CharSequence method, final int code,
                final ByteBuffer payload) {
            super(id, method, uri, payload);
            this.code = code;
        }

        public int statusCode() {
            return code;
        }
    }

    private static class MessageId {

        private final long seq;
        private final long timestamp;

        public MessageId(final long seq, final long timestamp) {
            this.seq = seq;
            this.timestamp = timestamp;
        }

        public long seq() {
            return seq;
        }

        public long timestamp() {
            return timestamp;
        }

        public static MessageId from(final long seq, final long timestamp) {
            return new MessageId(seq, timestamp);
        }

        @Override
        public boolean equals(final Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }

            final MessageId messageId = (MessageId) o;

            if (seq != messageId.seq) {
                return false;
            }
            return timestamp == messageId.timestamp;
        }

        @Override
        public int hashCode() {
            int result = (int) (seq ^ (seq >>> 32));
            result = 31 * result + (int) (timestamp ^ (timestamp >>> 32));
            return result;
        }

        @Override
        public String toString() {
            return "MessageId{" + "seq=" + seq + ", timestamp=" + timestamp + '}';
        }
    }

    private static class RequestEncoder implements ZMTPEncoder {

        @Override
        public void estimate(final Object message, final ZMTPEstimator estimator) {
            final Request request = (Request) message;
            estimator.frame(request.uri().length());
            estimator.frame(request.method().length());
            estimator.frame(16);
            estimator.frame(request.payload().remaining());
        }

        @Override
        public void encode(final Object message, final ZMTPWriter writer) {
            final Request request = (Request) message;
            writeAscii(writer, request.uri());
            writeAscii(writer, request.method());
            writeId(writer, request.id());
            writePayload(writer, request.payload());
        }

        @Override
        public void close() {
        }
    }

    private static class ReplyEncoder implements ZMTPEncoder {

        @Override
        public void estimate(final Object message, final ZMTPEstimator estimator) {
            final Reply reply = (Reply) message;
            estimator.frame(reply.uri().length());
            estimator.frame(reply.method().length());
            estimator.frame(16);
            estimator.frame(4);
            estimator.frame(reply.payload().remaining());
        }

        @Override
        public void encode(final Object message, final ZMTPWriter writer) {
            final Reply reply = (Reply) message;
            writeAscii(writer, reply.uri());
            writeAscii(writer, reply.method());
            writeId(writer, reply.id());
            writer.frame(4, true).writeInt(reply.statusCode());
            writePayload(writer, reply.payload());
        }

        @Override
        public void close() {
        }
    }

    private static void writeAscii(final ZMTPWriter writer, final CharSequence s) {
        final ByteBuf frame = writer.frame(s.length(), true);
        if (s instanceof AsciiString) {
            ((AsciiString) s).write(frame);
        } else {
            for (int i = 0; i < s.length(); i++) {
                frame.writeByte(s.charAt(i));
            }
        }
    }

    private static void writeId(final ZMTPWriter writer, final MessageId id) {
        writer.frame(16, true).writeLong(id.seq()).writeLong(id.timestamp());
    }

    private static void writePayload(final ZMTPWriter writer, final ByteBuffer payload) {
        final ByteBuf buf = writer.frame(payload.remaining(), false);
        if (payload.hasArray()) {
            buf.writeBytes(payload.array(), payload.arrayOffset() + payload.position(), payload.remaining());
        } else {
            final int pos = payload.position();
            for (int i = 0; i < payload.remaining(); i++) {
                buf.writeByte(payload.get(pos + i));
            }
            payload.position(pos);
        }
    }

    private static ByteBuffer readPayload(final ByteBuf data, final int size) {
        if (size == 0) {
            return EMPTY_BUFFER;
        }
        final ByteBuffer buffer = ByteBuffer.allocate(size);
        data.readBytes(buffer);
        buffer.flip();
        return buffer;
    }

    private static int readStatusCode(final ByteBuf data, final int size) {
        if (size != 4) {
            throw new IllegalArgumentException();
        }
        return data.readInt();
    }

    private static CharSequence readAscii(final ByteBuf data, final int size) {
        final byte[] chars = new byte[size];
        data.readBytes(chars);
        return new AsciiString(chars);
    }

    private static final AsciiString[] METHODS = FluentIterable
            .from(asList("GET", "POST", "PUT", "DELETE", "PATCH")).transform(AsciiString.ASCII_STRING_FROM_STRING)
            .toArray(AsciiString.class);

    private static CharSequence readMethod(final ByteBuf data, final int size) {
        for (final AsciiString method : METHODS) {
            if (asciiEquals(method, data, size)) {
                data.skipBytes(size);
                return method;
            }
        }
        return readAscii(data, size);
    }

    private static boolean asciiEquals(final AsciiString s, final ByteBuf data, final int size) {
        final int ix = data.readerIndex();
        if (size != s.length()) {
            return false;
        }
        for (int i = 0; i < size; i++) {
            char c = (char) data.getByte(ix + i);
            if (c != s.charAt(i)) {
                return false;
            }
        }
        return true;
    }

    private static MessageId readId(final ByteBuf data, final int size) {
        if (size != 16) {
            throw new IllegalArgumentException();
        }
        final long seq = data.readLong();
        final long timestamp = data.readLong();
        return MessageId.from(seq, timestamp);
    }

    private static class RequestDecoder implements ZMTPDecoder {

        enum State {
            URI, METHOD, ID, PAYLOAD
        }

        private State state = State.URI;

        private CharSequence uri;
        private CharSequence method;
        private MessageId id;
        private ByteBuffer payload;

        private int frameLength;

        @Override
        public void header(final ChannelHandlerContext ctx, final long length, final boolean more,
                final List<Object> out) {
            if (length > Integer.MAX_VALUE) {
                throw new IllegalArgumentException("length");
            }
            frameLength = (int) length;
        }

        @Override
        public void content(final ChannelHandlerContext ctx, final ByteBuf data, final List<Object> out) {
            if (data.readableBytes() < frameLength) {
                return;
            }
            switch (state) {
            case URI:
                uri = readAscii(data, frameLength);
                state = State.METHOD;
                break;
            case METHOD:
                method = readMethod(data, frameLength);
                state = State.ID;
                break;
            case ID:
                id = readId(data, frameLength);
                state = State.PAYLOAD;
                break;
            case PAYLOAD:
                payload = readPayload(data, frameLength);
                state = State.URI;
                break;
            }
        }

        @Override
        public void finish(final ChannelHandlerContext ctx, final List<Object> out) {
            out.add(new Request(id, uri, method, payload));
        }

        @Override
        public void close() {
        }
    }

    private static class ReplyDecoder implements ZMTPDecoder {

        enum State {
            URI, METHOD, ID, STATUSCODE, PAYLOAD
        }

        private State state = State.URI;
        private int frameLength;

        private CharSequence uri;
        private CharSequence method;
        private MessageId id;
        private int statusCode;
        private ByteBuffer payload;

        @Override
        public void header(final ChannelHandlerContext ctx, final long length, final boolean more,
                final List<Object> out) {
            if (length > Integer.MAX_VALUE) {
                throw new IllegalArgumentException("length");
            }
            frameLength = (int) length;
        }

        @Override
        public void content(final ChannelHandlerContext ctx, final ByteBuf data, final List<Object> out) {
            if (data.readableBytes() < frameLength) {
                return;
            }
            switch (state) {
            case URI:
                uri = readAscii(data, frameLength);
                state = State.METHOD;
                break;
            case METHOD:
                method = readMethod(data, frameLength);
                state = State.ID;
                break;
            case ID:
                id = readId(data, frameLength);
                state = State.STATUSCODE;
                break;
            case STATUSCODE:
                statusCode = readStatusCode(data, frameLength);
                state = State.PAYLOAD;
                break;
            case PAYLOAD:
                payload = readPayload(data, frameLength);
                state = State.URI;
                break;
            }
        }

        @Override
        public void finish(final ChannelHandlerContext ctx, final List<Object> out) {
            out.add(new Reply(id, uri, method, statusCode, payload));
        }

        @Override
        public void close() {
        }
    }

    private static class ServerRequestTracker extends ChannelDuplexHandler {

        private final Map<MessageId, Request> outstanding = Maps.newHashMap();

        @Override
        public void write(final ChannelHandlerContext ctx, final Object msg, final ChannelPromise promise)
                throws Exception {
            final Reply reply = (Reply) msg;
            final Request request = outstanding.remove(reply.id());
            if (request == null) {
                System.err.println("Unexpected reply: " + reply);
            } else {
                super.write(ctx, msg, promise);
            }
        }

        public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception {
            final Request request = (Request) msg;
            outstanding.put(request.id(), request);
            super.channelRead(ctx, msg);
        }
    }

    private static class ClientRequestTracker extends ChannelDuplexHandler {

        private final Map<MessageId, SettableFuture<Reply>> outstanding = Maps.newHashMap();

        @Override
        public void write(final ChannelHandlerContext ctx, final Object msg, final ChannelPromise promise)
                throws Exception {
            final Request request = (Request) msg;
            final RequestPromise requestPromise = (RequestPromise) promise;
            outstanding.put(request.id(), requestPromise.replyFuture);
            super.write(ctx, msg, promise);
        }

        @Override
        public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception {
            final Reply reply = (Reply) msg;
            final SettableFuture<Reply> future = outstanding.remove(reply.id());
            if (future == null) {
                System.err.println("unexpected reply: " + reply);
            } else {
                future.set(reply);
            }
        }
    }

    private static class ByteBufSizeEstimator implements MessageSizeEstimator, MessageSizeEstimator.Handle {

        public static final ByteBufSizeEstimator INSTANCE = new ByteBufSizeEstimator();

        @Override
        public Handle newHandle() {
            return this;
        }

        @Override
        public int size(final Object msg) {
            if (msg instanceof ByteBuf) {
                return ((ByteBuf) msg).readableBytes();
            }
            return 0;
        }
    }
}