com.hypersocket.netty.NettyServer.java Source code

Java tutorial

Introduction

Here is the source code for com.hypersocket.netty.NettyServer.java

Source

/*******************************************************************************
 * Copyright (c) 2013 Hypersocket Limited.
 * All rights reserved. This program and the accompanying materials
 * are made available under the terms of the GNU Public License v3.0
 * which accompanies this distribution, and is available at
 * http://www.gnu.org/licenses/gpl.html
 ******************************************************************************/
package com.hypersocket.netty;

import java.io.IOException;
import java.net.Inet6Address;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.NetworkInterface;
import java.net.SocketAddress;
import java.net.UnknownHostException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadFactory;

import javax.annotation.PostConstruct;

import org.jboss.netty.bootstrap.ClientBootstrap;
import org.jboss.netty.bootstrap.ServerBootstrap;
import org.jboss.netty.channel.Channel;
import org.jboss.netty.channel.ChannelException;
import org.jboss.netty.channel.ChannelHandler;
import org.jboss.netty.channel.ChannelHandlerContext;
import org.jboss.netty.channel.ChannelPipeline;
import org.jboss.netty.channel.ChannelPipelineFactory;
import org.jboss.netty.channel.ChannelStateEvent;
import org.jboss.netty.channel.Channels;
import org.jboss.netty.channel.SimpleChannelHandler;
import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory;
import org.jboss.netty.channel.socket.nio.NioServerSocketChannelFactory;
import org.jboss.netty.handler.ipfilter.IpFilterRule;
import org.jboss.netty.handler.ipfilter.IpFilterRuleHandler;
import org.jboss.netty.handler.ipfilter.IpSubnetFilterRule;
import org.jboss.netty.handler.logging.LoggingHandler;
import org.jboss.netty.logging.InternalLogLevel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

import com.hypersocket.config.SystemConfigurationService;
import com.hypersocket.ip.IPRestrictionListener;
import com.hypersocket.ip.IPRestrictionService;
import com.hypersocket.netty.forwarding.SocketForwardingWebsocketClientHandler;
import com.hypersocket.server.HypersocketServerImpl;
import com.hypersocket.server.websocket.TCPForwardingClientCallback;

/**
 * Netty 3 implementation of the Hypersocket server.
 * <p>
 * A {@link ServerBootstrap} accepts connections on the HTTP and HTTPS ports. It listens either on
 * every interface or only on the local addresses listed in the {@code listening.interfaces}
 * configuration value. A {@link ClientBootstrap} opens outbound connections for TCP forwarding.
 * <p>
 * As an {@link IPRestrictionListener}, the server turns block/unblock notifications into deny
 * rules on a single {@link IpFilterRuleHandler}, which every accepted connection's pipeline
 * shares. When an address is blocked, its already-open connections are closed.
 */
@Component
public class NettyServer extends HypersocketServerImpl implements IPRestrictionListener {

    static Logger log = LoggerFactory.getLogger(NettyServer.class);

    /** System property that, when true, adds a DEBUG {@link LoggingHandler} to server pipelines. */
    private static final String DEBUG_PROPERTY = "hypersocket.netty.debug";

    private ClientBootstrap clientBootstrap = null;
    private ServerBootstrap serverBootstrap = null;

    /** Listening channels on the HTTP port; {@code null} before start and after stop. */
    Set<Channel> httpChannels;
    /** Listening channels on the HTTPS port; {@code null} before start and after stop. */
    Set<Channel> httpsChannels;

    /** Executors behind the client bootstrap's NIO channel factory. */
    ExecutorService bossExecutor;
    ExecutorService workerExecutors;

    /** Deny rules built from IP restriction callbacks; one instance shared by all server pipelines. */
    IpFilterRuleHandler ipFilterHandler = new IpFilterRuleHandler();
    /** Records accepted channels in {@link #channelsByIPAddress}; shared by all server pipelines. */
    MonitorChannelHandler monitorChannelHandler = new MonitorChannelHandler();
    /** Accepted channels keyed by remote host address; guarded by synchronizing on the map itself. */
    Map<String, List<Channel>> channelsByIPAddress = new HashMap<String, List<Channel>>();

    @Autowired
    IPRestrictionService ipRestrictionService;

    @Autowired
    SystemConfigurationService configurationService;

    public NettyServer() {

    }

    /** Subscribes this server to IP block/unblock notifications once dependencies are injected. */
    @PostConstruct
    private void postConstruct() {
        ipRestrictionService.registerListener(this);
    }

    public ClientBootstrap getClientBootstrap() {
        return clientBootstrap;
    }

    public ServerBootstrap getServerBootstrap() {
        return serverBootstrap;
    }

    /**
     * Creates the client and server bootstraps and binds the HTTP and HTTPS ports.
     *
     * @throws IOException if neither port could be bound, or the network interfaces cannot be
     *         enumerated
     */
    @Override
    protected void doStart() throws IOException {

        // Debug logging stays on by default, but an explicit -Dhypersocket.netty.debug=false is
        // honoured instead of being overwritten unconditionally.
        if (System.getProperty(DEBUG_PROPERTY) == null) {
            System.setProperty(DEBUG_PROPERTY, "true");
        }

        bossExecutor = Executors.newCachedThreadPool(new NettyThreadFactory());
        workerExecutors = Executors.newCachedThreadPool(new NettyThreadFactory());
        clientBootstrap = new ClientBootstrap(
                new NioClientSocketChannelFactory(bossExecutor, workerExecutors));

        clientBootstrap.setPipelineFactory(new ChannelPipelineFactory() {
            public ChannelPipeline getPipeline() throws Exception {
                ChannelPipeline pipeline = Channels.pipeline();
                pipeline.addLast("handler", new SocketForwardingWebsocketClientHandler());
                return pipeline;
            }
        });

        // Configure the server.
        serverBootstrap = new ServerBootstrap(new NioServerSocketChannelFactory(
                Executors.newCachedThreadPool(), Executors.newCachedThreadPool()));

        // Set up the event pipeline factory.
        serverBootstrap.setPipelineFactory(new ChannelPipelineFactory() {
            public ChannelPipeline getPipeline() throws Exception {
                ChannelPipeline pipeline = Channels.pipeline();
                if (Boolean.getBoolean(DEBUG_PROPERTY)) {
                    pipeline.addLast("logger", new LoggingHandler(InternalLogLevel.DEBUG));
                }
                // The monitor must come before the IP filter. Once Netty 3's IpFilteringHandlerImpl
                // refuses a connection, it stops forwarding that channel's later OPEN/BOUND-state
                // events (UNBOUND included). A monitor placed behind it would therefore never
                // untrack refused channels.
                pipeline.addLast("channelMonitor", monitorChannelHandler);
                pipeline.addLast("ipFilter", ipFilterHandler);
                pipeline.addLast("switcherA",
                        new SSLSwitchingHandler(NettyServer.this, getHttpPort(), getHttpsPort()));
                return pipeline;
            }
        });

        serverBootstrap.setOption("child.receiveBufferSize", 1048576);
        serverBootstrap.setOption("child.sendBufferSize", 1048576);
        serverBootstrap.setOption("backlog", 5000);

        httpChannels = new HashSet<Channel>();
        httpsChannels = new HashSet<Channel>();

        bindInterface(getHttpPort(), httpChannels);
        bindInterface(getHttpsPort(), httpsChannels);

        if (httpChannels.isEmpty() && httpsChannels.isEmpty()) {
            throw new IOException("Failed to startup any listening interfaces!");
        }
    }

    @Override
    public int getActualHttpPort() {
        return getActualPort(httpChannels);
    }

    @Override
    public int getActualHttpsPort() {
        // Previously read httpChannels by mistake and so reported the HTTP port.
        return getActualPort(httpsChannels);
    }

    /**
     * Returns the local port of the first listening channel in {@code channels}.
     *
     * @throws IllegalStateException if the server is not started or nothing was bound for the port
     */
    private int getActualPort(Set<Channel> channels) {
        if (channels == null) {
            throw new IllegalStateException(
                    "You cannot get the actual port in use because the server is not started");
        }
        if (channels.isEmpty()) {
            throw new IllegalStateException("No listening channel is bound for this port");
        }
        return ((InetSocketAddress) channels.iterator().next().getLocalAddress()).getPort();
    }

    /**
     * Binds {@code port} and adds every bound listening channel to {@code channels}.
     * <p>
     * If the {@code listening.interfaces} configuration value is empty, the port is bound on all
     * interfaces. Otherwise it is bound on each local address whose host address appears in that
     * list. Bind failures are logged rather than thrown, so one port can still start without the
     * other; {@link #doStart()} fails only when nothing at all was bound.
     *
     * @throws IOException if the local network interfaces cannot be enumerated
     */
    protected void bindInterface(Integer port, Set<Channel> channels) throws IOException {

        Set<String> interfacesToBind = new HashSet<String>(
                Arrays.asList(configurationService.getValues("listening.interfaces")));

        if (interfacesToBind.isEmpty()) {

            if (log.isInfoEnabled()) {
                log.info("Binding server to all interfaces on port " + port);
            }
            try {
                Channel ch = serverBootstrap.bind(new InetSocketAddress(port));
                channels.add(ch);

                if (log.isInfoEnabled()) {
                    log.info("Bound to port " + ((InetSocketAddress) ch.getLocalAddress()).getPort());
                }
            } catch (ChannelException ex) {
                // Same treatment as the per-interface branch: log and let doStart decide.
                log.error("Failed to bind port " + port, ex);
            }
            return;
        }

        // Only needed when specific addresses were configured.
        Enumeration<NetworkInterface> interfaces = NetworkInterface.getNetworkInterfaces();
        if (interfaces == null) {
            log.error("No network interfaces found; nothing bound on port " + port);
            return;
        }

        for (NetworkInterface i : Collections.list(interfaces)) {
            for (InetAddress inetAddress : Collections.list(i.getInetAddresses())) {

                if (!interfacesToBind.contains(inetAddress.getHostAddress())) {
                    continue;
                }
                try {
                    if (log.isInfoEnabled()) {
                        log.info("Binding server to interface " + i.getDisplayName() + " "
                                + inetAddress.getHostAddress() + ":" + port);
                    }

                    Channel ch = serverBootstrap.bind(new InetSocketAddress(inetAddress, port));
                    channels.add(ch);

                    if (log.isInfoEnabled()) {
                        log.info("Bound to " + inetAddress.getHostAddress() + ":"
                                + ((InetSocketAddress) ch.getLocalAddress()).getPort());
                    }

                } catch (ChannelException ex) {
                    log.error("Failed to bind port", ex);
                }
            }
        }
    }

    /**
     * Closes the listening channels and every tracked accepted channel, and marks the server as
     * not started for {@link #getActualHttpPort()} / {@link #getActualHttpsPort()}.
     * <p>
     * NOTE(review): the bootstrap executors are not released here. Releasing them safely would
     * first require closing every open channel, including outbound forwarding channels, which
     * this class does not track. Confirm before adding releaseExternalResources().
     */
    @Override
    protected void doStop() {
        closeAll(httpChannels);
        closeAll(httpsChannels);
        httpChannels = null;
        httpsChannels = null;

        List<Channel> accepted = new ArrayList<Channel>();
        synchronized (channelsByIPAddress) {
            for (List<Channel> perAddress : channelsByIPAddress.values()) {
                accepted.addAll(perAddress);
            }
        }
        // Closed outside the lock and from a snapshot. A close may run channelUnbound, which
        // modifies the per-address lists.
        closeAll(accepted);
    }

    /** Requests a close of every channel in {@code channels}; a {@code null} argument is ignored. */
    private static void closeAll(Iterable<Channel> channels) {
        if (channels == null) {
            return;
        }
        for (Channel channel : channels) {
            channel.close();
        }
    }

    /**
     * Returns the remote IP address of {@code channel}, or {@code null} if it has no resolved
     * {@link InetSocketAddress} remote address.
     */
    private static InetAddress remoteAddressOf(Channel channel) {
        SocketAddress remote = channel.getRemoteAddress();
        return remote instanceof InetSocketAddress ? ((InetSocketAddress) remote).getAddress() : null;
    }

    /** Opens an outbound connection to the callback's host and port and reports the outcome to it. */
    @Override
    public void connect(TCPForwardingClientCallback callback) {

        clientBootstrap.connect(new InetSocketAddress(callback.getHostname(), callback.getPort()))
                .addListener(new ClientConnectCallbackImpl(callback));

    }

    /**
     * Waits {@code delay} seconds on a new thread, then calls {@code Main.getInstance().restartServer()}.
     * Any failure, including a {@code null} delay, is logged.
     */
    @Override
    public void restart(final Long delay) {

        Thread t = new Thread() {
            public void run() {
                if (log.isInfoEnabled()) {
                    log.info("Restarting the server in " + delay + " seconds");
                }

                try {
                    Thread.sleep(delay * 1000);

                    if (log.isInfoEnabled()) {
                        log.info("Restarting...");
                    }
                    Main.getInstance().restartServer();
                } catch (Exception e) {
                    log.error("Failed to restart", e);
                }
            }
        };

        t.start();

    }

    /**
     * Waits {@code delay} seconds on a new thread, then calls {@code Main.getInstance().shutdownServer()}.
     * Any failure, including a {@code null} delay, is logged.
     */
    @Override
    public void shutdown(final Long delay) {

        Thread t = new Thread() {
            public void run() {
                if (log.isInfoEnabled()) {
                    log.info("Shutting down the server in " + delay + " seconds");
                }
                try {
                    Thread.sleep(delay * 1000);

                    if (log.isInfoEnabled()) {
                        log.info("Shutting down");
                    }
                    Main.getInstance().shutdownServer();
                } catch (Exception e) {
                    log.error("Failed to shutdown", e);
                }
            }
        };

        t.start();

    }

    /** Creates threads whose context class loader is {@code Main.getInstance().getClassLoader()}. */
    class NettyThreadFactory implements ThreadFactory {

        @Override
        public Thread newThread(Runnable run) {
            Thread t = new Thread(run);
            t.setContextClassLoader(Main.getInstance().getClassLoader());
            return t;
        }

    }

    public ChannelHandler getIpFilter() {
        return ipFilterHandler;
    }

    /**
     * Keeps {@link #channelsByIPAddress} in sync with bound/unbound accepted channels and forwards
     * both events to the rest of the pipeline.
     */
    class MonitorChannelHandler extends SimpleChannelHandler {

        @Override
        public void channelBound(ChannelHandlerContext ctx, ChannelStateEvent e) throws Exception {
            InetAddress addr = remoteAddressOf(ctx.getChannel());

            if (addr != null) {
                if (log.isDebugEnabled()) {
                    log.debug("Opening channel from " + addr.toString());
                }

                String key = addr.getHostAddress();
                synchronized (channelsByIPAddress) {
                    List<Channel> perAddress = channelsByIPAddress.get(key);
                    if (perAddress == null) {
                        perAddress = new ArrayList<Channel>();
                        channelsByIPAddress.put(key, perAddress);
                    }
                    perAddress.add(ctx.getChannel());
                }
            }

            // The default implementation forwards the event; previously it was swallowed here.
            super.channelBound(ctx, e);
        }

        @Override
        public void channelUnbound(ChannelHandlerContext ctx, ChannelStateEvent e) throws Exception {
            InetAddress addr = remoteAddressOf(ctx.getChannel());

            if (addr != null) {
                if (log.isDebugEnabled()) {
                    log.debug("Closing channel from " + addr.toString());
                }

                String key = addr.getHostAddress();
                synchronized (channelsByIPAddress) {
                    List<Channel> perAddress = channelsByIPAddress.get(key);
                    // Null-safe: the channel may never have been recorded.
                    if (perAddress != null) {
                        perAddress.remove(ctx.getChannel());
                        if (perAddress.isEmpty()) {
                            channelsByIPAddress.remove(key);
                        }
                    }
                }
            }

            super.channelUnbound(ctx, e);
        }
    }

    /**
     * Converts a single address into a host-sized CIDR: /32 for IPv4, /128 for IPv6.
     * A string that already carries a prefix length is returned unchanged.
     *
     * @throws UnknownHostException if {@code addr} cannot be turned into an address
     */
    private static String toCidr(String addr) throws UnknownHostException {
        if (addr.indexOf('/') != -1) {
            return addr;
        }
        // Appending /32 to an IPv6 address would block a whole /32 subnet, not one host.
        return addr + (InetAddress.getByName(addr) instanceof Inet6Address ? "/128" : "/32");
    }

    /**
     * Adds a deny rule for {@code addr} (a single address or a CIDR) and closes every tracked
     * channel whose remote address the rule contains.
     */
    @Override
    public void onBlockIP(String addr) {

        IpFilterRule rule;
        try {
            rule = new IpSubnetFilterRule(false, toCidr(addr));
        } catch (UnknownHostException e) {
            log.error("Failed to block IP " + addr, e);
            return;
        }

        // Avoid stacking duplicate rules. onUnblockIP removes only one matching rule, so a
        // duplicate would keep the address blocked.
        ipFilterHandler.addIfAbsent(rule);

        // Use the rule's own contains() check so CIDR blocks also close matching channels.
        List<Channel> toClose = new ArrayList<Channel>();
        synchronized (channelsByIPAddress) {
            for (List<Channel> perAddress : channelsByIPAddress.values()) {
                for (Channel c : perAddress) {
                    InetAddress remote = remoteAddressOf(c);
                    if (remote != null && rule.contains(remote)) {
                        toClose.add(c);
                    }
                }
            }
        }
        // Close outside the lock. A close may run channelUnbound, which modifies the lists above.
        closeAll(toClose);
    }

    /** Removes the deny rule previously added for {@code addr} (a single address or a CIDR). */
    @Override
    public void onUnblockIP(String addr) {

        try {
            IpFilterRule rule = new IpSubnetFilterRule(false, toCidr(addr));
            ipFilterHandler.remove(rule);
        } catch (UnknownHostException e) {
            log.error("Failed to unblock IP " + addr, e);
        }

    }
}