com.spotify.heroic.rpc.nativerpc.NativeRpcServerSessionInitializer.java Source code

Java tutorial

Introduction

Here is the source code for com.spotify.heroic.rpc.nativerpc.NativeRpcServerSessionInitializer.java

Source

/*
 * Copyright (c) 2015 Spotify AB.
 *
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.spotify.heroic.rpc.nativerpc;

import java.nio.charset.Charset;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.spotify.heroic.rpc.nativerpc.message.NativeRpcError;
import com.spotify.heroic.rpc.nativerpc.message.NativeRpcHeartBeat;
import com.spotify.heroic.rpc.nativerpc.message.NativeRpcRequest;
import com.spotify.heroic.rpc.nativerpc.message.NativeRpcResponse;

import eu.toolchain.async.AsyncFuture;
import eu.toolchain.async.FutureDone;
import eu.toolchain.async.Transform;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import io.netty.handler.codec.LengthFieldPrepender;
import io.netty.util.Timeout;
import io.netty.util.Timer;
import io.netty.util.TimerTask;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;

@Slf4j
@RequiredArgsConstructor
public class NativeRpcServerSessionInitializer extends ChannelInitializer<SocketChannel> {
    private static final Charset UTF8 = Charset.forName("UTF-8");

    private final Timer timer;
    private final ObjectMapper mapper;
    private final NativeRpcContainer container;
    private final int maxFrameSize;

    @Override
    protected void initChannel(final SocketChannel ch) throws Exception {
        final AtomicReference<Timeout> heartbeatTimeout = new AtomicReference<>();

        final ChannelPipeline pipeline = ch.pipeline();
        // first four bytes are length prefix of message, strip first four bytes.
        pipeline.addLast(new LengthFieldBasedFrameDecoder(maxFrameSize, 0, 4, 0, 4));
        pipeline.addLast(new NativeRpcDecoder());
        pipeline.addLast(new ChannelHandler(heartbeatTimeout));

        pipeline.addLast(new LengthFieldPrepender(4));
        pipeline.addLast(new NativeRpcEncoder());
    }

    /**
     * Stop any current timeout, if pending.
     */
    private void stopCurrentTimeout(final AtomicReference<Timeout> heartbeatTimeout) {
        final Timeout old = heartbeatTimeout.getAndSet(null);

        if (old != null) {
            old.cancel();
        }
    }

    @RequiredArgsConstructor
    private class ChannelHandler extends SimpleChannelInboundHandler<Object> {
        private final AtomicReference<Timeout> heartbeatTimeout;

        @Override
        protected void channelRead0(final ChannelHandlerContext ctx, final Object msg) throws Exception {
            if (msg instanceof NativeRpcRequest) {
                try {
                    handleRequest(ctx.channel(), (NativeRpcRequest) msg);
                } catch (Exception e) {
                    log.error("Failed to handle request", e);
                    sendError(ctx.channel(), e.getMessage()).addListener(closeListener());
                    return;
                }

                return;
            }

            throw new IllegalArgumentException("Invalid request: " + msg);
        }

        private void handleRequest(final Channel ch, NativeRpcRequest msg) throws Exception {
            final NativeRpcRequest request = (NativeRpcRequest) msg;
            final NativeRpcContainer.EndpointSpec<Object, Object> handle = container.get(request.getEndpoint());

            if (handle == null) {
                sendError(ch, "No such endpoint: " + request.getEndpoint()).addListener(closeListener());
                return;
            }

            if (log.isTraceEnabled()) {
                log.trace("request[{}:{}ms] {}", request.getEndpoint(), request.getHeartbeatInterval(),
                        new String(request.getBody(), UTF8));
            }

            final long heartbeatInterval = calculcateHeartbeatInterval(msg);

            if (heartbeatInterval > 0) {
                // start sending heartbeat since we are now processing a request.
                setupHeartbeat(ch, heartbeatInterval);
            }

            final Object body = mapper.readValue(request.getBody(), handle.requestType());

            final AsyncFuture<Object> handleFuture = handle.handle(body);

            // Serialize in a separate thread on the async thread pool.
            // this also neatly catches errors for us in the next step.
            // Stop sending heartbeats immediately when the future has been finished.
            // this will cause the other end to time out if a response is available, but its unable
            // to pass the
            // network.
            handleFuture.directTransform(serialize(request)).onFinished(() -> stopCurrentTimeout(heartbeatTimeout))
                    .onDone(sendResponseHandle(ch));
        }

        private long calculcateHeartbeatInterval(NativeRpcRequest msg) {
            if (msg.getHeartbeatInterval() <= 0) {
                return 0;
            }

            return msg.getHeartbeatInterval() / 2;
        }

        private void setupHeartbeat(final Channel ch, final long heartbeatInterval) {
            scheduleHeartbeat(ch, heartbeatInterval);

            ch.closeFuture().addListener(new ChannelFutureListener() {
                @Override
                public void operationComplete(final ChannelFuture future) throws Exception {
                    stopCurrentTimeout(heartbeatTimeout);
                }
            });
        }

        private void scheduleHeartbeat(final Channel ch, final long heartbeatInterval) {
            final Timeout timeout = timer.newTimeout(new TimerTask() {
                @Override
                public void run(final Timeout timeout) throws Exception {
                    sendHeartbeat(ch).addListener(new ChannelFutureListener() {
                        @Override
                        public void operationComplete(final ChannelFuture future) throws Exception {
                            scheduleHeartbeat(ch, heartbeatInterval);
                        }
                    });
                }
            }, heartbeatInterval, TimeUnit.MILLISECONDS);

            final Timeout old = heartbeatTimeout.getAndSet(timeout);

            if (old != null) {
                old.cancel();
            }
        }

        private Transform<Object, NativeRpcResponse> serialize(final NativeRpcRequest request) {
            return new Transform<Object, NativeRpcResponse>() {
                @Override
                public NativeRpcResponse transform(Object result) throws Exception {
                    final byte[] response = mapper.writeValueAsBytes(result);
                    if (log.isTraceEnabled()) {
                        log.trace("response[{}]: {}", request.getEndpoint(), new String(response, UTF8));
                    }

                    return new NativeRpcResponse(response);
                }
            };
        }

        private FutureDone<NativeRpcResponse> sendResponseHandle(final Channel ch) {
            return new FutureDone<NativeRpcResponse>() {
                @Override
                public void cancelled() throws Exception {
                    log.error("{}: request cancelled", ch);
                    ch.writeAndFlush(new NativeRpcError("request cancelled")).addListener(closeListener());
                }

                @Override
                public void failed(final Throwable e) throws Exception {
                    log.error("{}: request failed", ch, e);
                    ch.writeAndFlush(new NativeRpcError(e.getMessage())).addListener(closeListener());
                }

                @Override
                public void resolved(final NativeRpcResponse response) throws Exception {
                    sendHeartbeat(ch).addListener(new ChannelFutureListener() {
                        @Override
                        public void operationComplete(ChannelFuture f) throws Exception {
                            if (!f.isSuccess()) {
                                sendError(ch, f.cause() == null ? "send of tail heartbeat failed"
                                        : f.cause().getMessage()).addListener(closeListener());
                                return;
                            }

                            ch.writeAndFlush(response).addListener(closeListener());
                        }
                    });
                }
            };
        }

        private ChannelFuture sendHeartbeat(final Channel ch) {
            return ch.writeAndFlush(new NativeRpcHeartBeat());
        }

        private ChannelFuture sendError(final Channel ch, final String error) {
            return ch.writeAndFlush(new NativeRpcError(error));
        }

        /**
         * Listener that shuts down the connection after its promise has been resolved.
         */
        private ChannelFutureListener closeListener() {
            // immediately stop sending heartbeats.
            stopCurrentTimeout(heartbeatTimeout);

            return new ChannelFutureListener() {
                @Override
                public void operationComplete(ChannelFuture future) throws Exception {
                    future.channel().close();
                }
            };
        }
    }
};