jlibs.wamp4j.netty.NettyServerEndpoint.java Source code

Java tutorial

Introduction

Here is the source code for jlibs.wamp4j.netty.NettyServerEndpoint.java

Source

/**
 * Copyright 2015 Santhosh Kumar Tekuri
 *
 * The JLibs authors license this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */

package jlibs.wamp4j.netty;

import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.PooledByteBufAllocator;
import io.netty.channel.*;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.codec.http.*;
import io.netty.handler.codec.http.websocketx.WebSocketFrameAggregator;
import io.netty.handler.codec.http.websocketx.WebSocketServerHandshaker;
import io.netty.handler.codec.http.websocketx.WebSocketServerHandshakerFactory;
import io.netty.handler.ssl.ClientAuth;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.handler.ssl.util.SelfSignedCertificate;
import io.netty.util.AttributeKey;
import jlibs.wamp4j.SSLSettings;
import jlibs.wamp4j.Util;
import jlibs.wamp4j.spi.AcceptListener;
import jlibs.wamp4j.spi.NamedThreadFactory;
import jlibs.wamp4j.spi.WAMPServerEndPoint;

import java.net.URI;

/**
 * @author Santhosh Kumar Tekuri
 */
public class NettyServerEndpoint extends NettyEndpoint implements WAMPServerEndPoint {
    private static final AttributeKey<AcceptListener> ACCEPT_LISTENER = AttributeKey
            .newInstance(AcceptListener.class.getName());
    private Channel channel;

    public NettyServerEndpoint() {
        super(NamedThreadFactory.ROUTER_THREAD_FACTORY);
    }

    @Override
    public void bind(final URI uri, final String subProtocols[], final AcceptListener listener) {
        final SslContext sslContext;
        if ("wss".equals(uri.getScheme())) {
            try {
                if (sslSettings == null) {
                    SelfSignedCertificate ssc = new SelfSignedCertificate();
                    sslSettings = new SSLSettings().keyFile(ssc.privateKey()).certificateFile(ssc.certificate());
                }
                ClientAuth clientAuth = ClientAuth.values()[sslSettings.clientAuthentication.ordinal()];
                sslContext = SslContextBuilder
                        .forServer(sslSettings.certificateFile, sslSettings.keyFile, sslSettings.keyPassword)
                        .clientAuth(clientAuth).trustManager(sslSettings.trustCertChainFile).build();
            } catch (Throwable thr) {
                listener.onError(thr);
                return;
            }
        } else if ("ws".equals(uri.getScheme()))
            sslContext = null;
        else
            throw new IllegalArgumentException("invalid protocol: " + uri.getScheme());

        int port = uri.getPort();
        if (port == -1)
            port = sslContext == null ? 80 : 443;
        ServerBootstrap bootstrap = new ServerBootstrap().group(eventLoopGroup)
                .channel(NioServerSocketChannel.class)
                .childOption(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT)
                .childOption(ChannelOption.MAX_MESSAGES_PER_READ, 50000)
                .childOption(ChannelOption.WRITE_SPIN_COUNT, 50000)
                .childHandler(new ChannelInitializer<SocketChannel>() {
                    @Override
                    protected void initChannel(SocketChannel ch) throws Exception {
                        if (sslContext != null)
                            ch.pipeline().addLast(sslContext.newHandler(ch.alloc()));
                        ch.pipeline().addLast(new HttpServerCodec(), new HttpObjectAggregator(65536),
                                new Handshaker(uri, listener, subProtocols));
                    }
                });
        bootstrap.bind(uri.getHost(), port).addListener(new ChannelFutureListener() {
            @Override
            public void operationComplete(ChannelFuture future) throws Exception {
                if (future.isSuccess()) {
                    channel = future.channel();
                    channel.attr(ACCEPT_LISTENER).set(listener);
                    listener.onBind(NettyServerEndpoint.this);
                } else
                    listener.onError(future.cause());
            }
        });
    }

    @Override
    public void close() {
        channel.close().addListener(new ChannelFutureListener() {
            @Override
            public void operationComplete(ChannelFuture future) throws Exception {
                AcceptListener acceptListener = channel.attr(ACCEPT_LISTENER).get();
                if (!future.isSuccess())
                    acceptListener.onError(future.cause());
                acceptListener.onClose(NettyServerEndpoint.this);
            }
        });
    }

    private static class Handshaker extends SimpleChannelInboundHandler<FullHttpRequest> {
        private final String scheme;
        private final String path;
        private final AcceptListener acceptListener;
        private final String subProtocols[];

        public Handshaker(URI uri, AcceptListener acceptListener, String subProtocols[]) {
            scheme = uri.getScheme();
            path = uri.getPath().isEmpty() ? "/" : uri.getPath();
            this.acceptListener = acceptListener;
            this.subProtocols = subProtocols;
        }

        @Override
        protected void channelRead0(ChannelHandlerContext ctx, FullHttpRequest request) throws Exception {
            String host = request.headers().get("Host");
            String connection = request.headers().get("Connection");
            if (request.getMethod() != HttpMethod.GET || !request.getUri().equals(path) || host == null
                    || !"Upgrade".equals(connection)) {
                DefaultFullHttpResponse response = new DefaultFullHttpResponse(request.getProtocolVersion(),
                        HttpResponseStatus.NOT_FOUND);
                HttpHeaders.setContentLength(response, 0);
                ChannelFuture future = ctx.writeAndFlush(response);
                if (!HttpHeaders.isKeepAlive(request))
                    future.addListener(ChannelFutureListener.CLOSE);
                return;
            }
            String url = scheme + "://" + host + path;
            WebSocketServerHandshaker handshaker = new WebSocketServerHandshakerFactory(url,
                    Util.toString(subProtocols), false).newHandshaker(request);
            if (handshaker == null) {
                WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel());
                return;
            }

            ChannelFuture future = handshaker.handshake(ctx.channel(), request);
            for (String subProtocol : subProtocols) {
                if (subProtocol.equals(handshaker.selectedSubprotocol())) {
                    NettyWebSocket webSocket = new NettyWebSocket(handshaker, subProtocol);
                    ctx.pipeline().addLast("ws-aggregator", new WebSocketFrameAggregator(16 * 1024 * 1024));
                    ctx.pipeline().addLast("websocket", webSocket);
                    ctx.pipeline().remove(this);
                    webSocket.channelActive(ctx);
                    acceptListener.onAccept(webSocket);
                    return;
                }
            }
            future.addListener(ChannelFutureListener.CLOSE);
        }
    }
}