com.github.sinsinpub.pero.manual.proxyhandler.ProxyServer.java Source code

Java tutorial

Introduction

Here is the source code for com.github.sinsinpub.pero.manual.proxyhandler.ProxyServer.java

Source

/*
 * Copyright 2014 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */

package com.github.sinsinpub.pero.manual.proxyhandler;

import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.EventLoop;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.socket.ServerSocketChannel;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.logging.LogLevel;
import io.netty.handler.logging.LoggingHandler;
import io.netty.util.CharsetUtil;
import io.netty.util.NetUtil;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.internal.PlatformDependent;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;

import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.ArrayDeque;
import java.util.Queue;
import java.util.concurrent.LinkedBlockingQueue;

abstract class ProxyServer {

    protected final InternalLogger logger = InternalLoggerFactory.getInstance(getClass());

    private final ServerSocketChannel ch;
    private final Queue<Throwable> recordedExceptions = new LinkedBlockingQueue<Throwable>();
    protected final TestMode testMode;
    protected final String username;
    protected final String password;
    protected final InetSocketAddress destination;

    /**
     * Starts a new proxy server with disabled authentication for testing purpose.
     * 
     * @param useSsl {@code true} if and only if implicit SSL is enabled
     * @param testMode the test mode
     * @param destination the expected destination. If the client requests proxying to a different
     *            destination, this server will reject the connection request.
     */
    protected ProxyServer(boolean useSsl, TestMode testMode, InetSocketAddress destination) {
        this(useSsl, testMode, destination, null, null, 0, false);
    }

    protected ProxyServer(final boolean useSsl, TestMode testMode, InetSocketAddress destination, String username,
            String password) {
        this(useSsl, testMode, destination, null, null, 0, false);
    }

    /**
     * Starts a new proxy server with disabled authentication for testing purpose.
     * 
     * @param useSsl {@code true} if and only if implicit SSL is enabled
     * @param testMode the test mode
     * @param username the expected username. If the client tries to authenticate with a different
     *            username, this server will fail the authentication request.
     * @param password the expected password. If the client tries to authenticate with a different
     *            password, this server will fail the authentication request.
     * @param destination the expected destination. If the client requests proxying to a different
     *            destination, this server will reject the connection request.
     */
    protected ProxyServer(final boolean useSsl, TestMode testMode, InetSocketAddress destination, String username,
            String password, int bindPort, boolean logging) {

        this.testMode = testMode;
        this.destination = destination;
        this.username = username;
        this.password = password;

        ServerBootstrap b = new ServerBootstrap();
        b.channel(NioServerSocketChannel.class);
        if (logging) {
            b.handler(new LoggingHandler(LogLevel.INFO));
        }
        b.group(StaticContextProvider.group);
        b.childHandler(new ChannelInitializer<SocketChannel>() {
            @Override
            protected void initChannel(SocketChannel ch) throws Exception {
                ChannelPipeline p = ch.pipeline();
                if (useSsl) {
                    p.addLast(StaticContextProvider.serverSslCtx.newHandler(ch.alloc()));
                }

                configure(ch);
            }
        });

        ch = (ServerSocketChannel) b.bind(NetUtil.LOCALHOST, bindPort).syncUninterruptibly().channel();
    }

    public final InetSocketAddress address() {
        return new InetSocketAddress(NetUtil.LOCALHOST, ch.localAddress().getPort());
    }

    protected abstract void configure(SocketChannel ch) throws Exception;

    final void recordException(Throwable t) {
        logger.warn("Unexpected exception from proxy server:", t);
        recordedExceptions.add(t);
    }

    /**
     * Clears all recorded exceptions.
     */
    public final void clearExceptions() {
        recordedExceptions.clear();
    }

    /**
     * Logs all recorded exceptions and raises the last one so that the caller can fail.
     */
    public final void checkExceptions() {
        Throwable t;
        for (;;) {
            t = recordedExceptions.poll();
            if (t == null) {
                break;
            }

            logger.warn("Unexpected exception:", t);
        }

        if (t != null) {
            PlatformDependent.throwException(t);
        }
    }

    public final Channel channel() {
        return ch;
    }

    public final void stop() {
        ch.close();
    }

    protected abstract class IntermediaryHandler extends SimpleChannelInboundHandler<Object> {

        private final Queue<Object> received = new ArrayDeque<Object>();

        private boolean finished;
        private Channel backend;

        @Override
        protected final void channelRead0(final ChannelHandlerContext ctx, Object msg) throws Exception {
            if (finished) {
                received.add(ReferenceCountUtil.retain(msg));
                flush();
                return;
            }

            boolean finished = handleProxyProtocol(ctx, msg);
            if (finished) {
                this.finished = true;
                ChannelFuture f = connectToDestination(ctx.channel().eventLoop(), new BackendHandler(ctx));
                f.addListener(new ChannelFutureListener() {
                    @Override
                    public void operationComplete(ChannelFuture future) throws Exception {
                        if (!future.isSuccess()) {
                            recordException(future.cause());
                            ctx.close();
                        } else {
                            backend = future.channel();
                            flush();
                        }
                    }
                });
            }
        }

        private void flush() {
            if (backend != null) {
                boolean wrote = false;
                for (;;) {
                    Object msg = received.poll();
                    if (msg == null) {
                        break;
                    }
                    backend.write(msg);
                    wrote = true;
                }

                if (wrote) {
                    backend.flush();
                }
            }
        }

        protected abstract boolean handleProxyProtocol(ChannelHandlerContext ctx, Object msg) throws Exception;

        protected abstract SocketAddress intermediaryDestination();

        private ChannelFuture connectToDestination(EventLoop loop, ChannelHandler handler) {
            Bootstrap b = new Bootstrap();
            b.channel(NioSocketChannel.class);
            b.group(loop);
            b.handler(handler);
            return b.connect(intermediaryDestination());
        }

        @Override
        public final void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
            ctx.flush();
        }

        @Override
        public void channelInactive(ChannelHandlerContext ctx) throws Exception {
            if (backend != null) {
                backend.close();
            }
        }

        @Override
        public final void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
            recordException(cause);
            ctx.close();
        }

        private final class BackendHandler extends ChannelInboundHandlerAdapter {

            private final ChannelHandlerContext frontend;

            BackendHandler(ChannelHandlerContext frontend) {
                this.frontend = frontend;
            }

            @Override
            public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
                frontend.write(msg);
            }

            @Override
            public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
                frontend.flush();
            }

            @Override
            public void channelInactive(ChannelHandlerContext ctx) throws Exception {
                frontend.close();
            }

            @Override
            public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
                recordException(cause);
                ctx.close();
            }
        }
    }

    protected abstract class TerminalHandler extends SimpleChannelInboundHandler<Object> {

        private boolean finished;

        @Override
        protected final void channelRead0(ChannelHandlerContext ctx, Object msg) throws Exception {
            if (finished) {
                String str = ((ByteBuf) msg).toString(CharsetUtil.US_ASCII);
                if ("A\n".equals(str)) {
                    ctx.write(Unpooled.copiedBuffer("1\n", CharsetUtil.US_ASCII));
                } else if ("B\n".equals(str)) {
                    ctx.write(Unpooled.copiedBuffer("2\n", CharsetUtil.US_ASCII));
                } else if ("C\n".equals(str)) {
                    ctx.write(Unpooled.copiedBuffer("3\n", CharsetUtil.US_ASCII))
                            .addListener(ChannelFutureListener.CLOSE);
                } else {
                    throw new IllegalStateException("unexpected message: " + str);
                }
                return;
            }

            boolean finished = handleProxyProtocol(ctx, msg);
            if (finished) {
                this.finished = finished;
            }
        }

        protected abstract boolean handleProxyProtocol(ChannelHandlerContext ctx, Object msg) throws Exception;

        @Override
        public final void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
            ctx.flush();
        }

        @Override
        public final void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
            recordException(cause);
            ctx.close();
        }
    }
}