cyril.server.io.SpamHandler.java Source code

Java tutorial

Introduction

Here is the source code for cyril.server.io.SpamHandler.java

Source

// Northerner Cyril, an online, token-based implementation of the game
// Carcassonne
// This is part of the "server" module
// Copyright (C) 2012 Ben Wiederhake
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as
// published by the Free Software Foundation, either version 3 of the
// License, or (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

package cyril.server.io;

import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.socket.SocketChannel;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;

import cyril.common.ProtocolConstants;
import cyril.server.io.SpamInitializer.ConnectionState;
import cyril.server.io.SpamInitializer.SpamListener;

/**
 * Phase 1 of the Northerner-Over-TCP protocol, server-side.
 * 
 * For the reasoning behind it and the exact specification of what this does,
 * please refer to the official wiki, page "Protocol":
 * https://github.com/BenWiederhake/northerner-cyril/wiki/Protocol
 * 
 * The handler walks through a small state machine:
 * <ol>
 * <li>expect the single byte {@code MAGIC_START},</li>
 * <li>expect the remaining {@code MAGIC_REST} bytes,</li>
 * <li>write a random two-int challenge and expect the client to answer with
 * the (wrapping) sum of those two ints.</li>
 * </ol>
 * Every state has its own timeout. If the client does not leave a state in
 * time, or sends wrong data, the channel is closed and the listener is told
 * why. Once the challenge has been answered correctly, the {@code next}
 * initializer is applied to the channel and this handler removes itself from
 * the pipeline.
 * 
 * This is intended to be used on SocketChannels ONLY. I will definitely need to
 * revise this class again later, when we add SSL support.
 */

public final class SpamHandler extends ChannelInitializableInboundByteHandlerAdapter {
    private final ChannelInitializer<SocketChannel> next;
    final SpamListener listener;

    /**
     * The state the channel is currently in, or {@code null} while an input
     * is being processed, after the connection was rejected/timed out, or
     * after this phase completed.
     */
    final AtomicReference<SpamState> spamState = new AtomicReference<>();

    public SpamHandler(ChannelInitializer<SocketChannel> next, SpamListener listener) {
        this.next = next;
        this.listener = listener;
    }

    @Override
    protected final void initialize(ChannelHandlerContext ctx) {
        listener.stateChanged(ctx.channel(), ConnectionState.START);
        spamState.set(SPAM_START);
        prepareFor(ctx, SPAM_START);
    }

    /**
     * Prepare timeouts for the current SpamState.
     * 
     * @param ctx
     *            The context whose channel is killed if the client stays in
     *            {@code s} for too long
     * @param s
     *            The SpamState containing the applying values
     */
    private final void prepareFor(ChannelHandlerContext ctx, SpamState s) {
        new KillTask(ctx, s).start();
    }

    /**
     * A task, scheduled on the channel's executor, that kills the connection
     * if the client hasn't yet successfully moved the channel out of the
     * given SpamState.
     */
    private final class KillTask implements Runnable {
        private final SocketChannel ch;
        private final SpamState reason;
        private final ChannelHandlerContext ctx;

        public KillTask(ChannelHandlerContext ctx, SpamState reason) {
            this.ctx = ctx;
            this.ch = (SocketChannel) ctx.channel();
            this.reason = reason;
        }

        @Override
        public void run() {
            if (spamState.compareAndSet(reason, null)) {
                // Connection has to be aborted, client took too much time
                listener.stateChanged(ch, ConnectionState.TIMEOUT);
                ch.close();
            }
            // Otherwise the channel has already left this SpamState (and the
            // new state scheduled its own KillTask), or the connection was
            // already rejected or handed over => do nothing.
        }

        /** Schedules this task to run after the state's timeout. */
        public void start() {
            ctx.executor().schedule(this, reason.getTimeout(), TimeUnit.MILLISECONDS);
        }
    }

    @Override
    public void inboundBufferUpdated(ChannelHandlerContext ctx, ByteBuf in) throws IOException {
        final SocketChannel ch = (SocketChannel) ctx.channel();

        // A single read may carry the input for several states at once (for
        // example a client sending the complete START_MAGIC in one packet),
        // so keep advancing while the buffer holds enough bytes for the
        // current state. Otherwise the remaining bytes would sit unread until
        // the next read, which may never come.
        while (true) {
            final SpamState s = spamState.getAndSet(null);

            if (s == null) {
                // KillTask has just triggered, or the connection was already
                // rejected: ignore
                return;
            }

            if (in.readableBytes() < s.getExpectedLength()) {
                // Not enough bytes yet, wait for more input
                spamState.set(s);
                return;
            }

            if (!s.isBufOkay(in)) {
                // Connection has to be aborted, client sent wrong data
                listener.stateChanged(ch, ConnectionState.WRONG_DATA);
                ch.close();
                return;
            }

            // Buf was okay, go to the next state:
            final SpamState following;
            switch (s.getConnectionState()) {
            case START:
                following = SPAM_MAGIC;
                break;
            case FIRST_BYTE:
                // TODO: Correct buffer? Size? Usage?
                // NOTE(review): the challenge is only written into the next
                // outbound buffer; nothing in this class flushes it. Confirm
                // that the adapter/pipeline flushes, otherwise the client
                // never receives the challenge.
                following = new SpamInteraction(ctx.nextOutboundByteBuffer());
                break;
            case GOOD_MAGIC:
                handOver(ctx, ch);
                return;
            default: // Bad. Disconnect, so if it happens too often,
                // the server admin will notice.
                listener.stateChanged(ch, ConnectionState.WRONG_DATA);
                ch.close();
                return;
            }

            // Stay in this handler
            spamState.set(following);
            listener.stateChanged(ch, following.getConnectionState());
            prepareFor(ctx, following);
        }
    }

    /**
     * Hands the channel over to the next stage and removes this handler from
     * the pipeline (even if initializing the next stage fails).
     * 
     * NOTE(review): bytes still readable in this handler's inbound buffer are
     * not explicitly passed on here. Confirm that the pipeline forwards them
     * on removal.
     */
    private void handOver(ChannelHandlerContext ctx, SocketChannel ch) throws IOException {
        // spamState has already been cleared by getAndSet(null), so no
        // KillTask can fire and nothing is retained any more.
        try {
            next.initChannel(ch);
        } catch (Exception e) {
            throw new IOException("Could not initialize next stage", e);
        } finally {
            ctx.pipeline().remove(this);
        }
    }

    private static abstract class SpamState {
        // Creation time in milliseconds; currently not read anywhere here.
        protected final long reached;

        public SpamState() {
            this.reached = System.currentTimeMillis();
        }

        /**
         * Returns the amount of bytes required to check whether the input is
         * okay or not.
         * 
         * @return the amount of bytes required to check whether the input is
         *         okay or not.
         */
        public abstract int getExpectedLength();

        /**
         * Returns the amount of milliseconds the client has to send the data.
         * 
         * @return the amount of milliseconds the client has to send the data.
         */
        public abstract int getTimeout();

        /**
         * Checks whether the buffer looks okay, consuming the inspected bytes.
         * 
         * @param bb
         *            the inbound buffer, holding at least
         *            {@link #getExpectedLength()} readable bytes
         * @return whether the buffer looks okay
         */
        public abstract boolean isBufOkay(ByteBuf bb);

        /**
         * Returns the connection state, assuming that no input has yet been
         * seen.
         * 
         * @return the connection state, assuming that no input has yet been
         *         seen.
         */
        public abstract ConnectionState getConnectionState();
    }

    // https://github.com/BenWiederhake/northerner-cyril/wiki/Protocol

    private static final byte MAGIC_START = 0x47; // 'G'
    // Explicit charset: the platform default might not be ASCII-compatible,
    // which would break the byte-by-byte comparison with START_MAGIC below.
    static final byte[] MAGIC_REST = "ET Carcassonne\n\n".getBytes(StandardCharsets.US_ASCII);

    private static final int MAGIC_REST_LENGTH = 16; // Coincidence

    static {
        if (MAGIC_REST.length != MAGIC_REST_LENGTH) {
            throw new InternalError(
                    "SpamHandler.MAGIC_REST <-> SpamHandler." + "MAGIC_REST_LENGTH are not consistent");
        }
        if (ProtocolConstants.START_MAGIC.length() != MAGIC_REST_LENGTH + 1) {
            throw new InternalError(
                    "ProtocolConstants.START_MAGIC <->" + " SpamHandler.MAGIC_REST_LENGTH are not consistent");
        }
        if (ProtocolConstants.START_MAGIC.charAt(0) != MAGIC_START) {
            throw new InternalError(
                    "ProtocolConstants.START_MAGIC <->" + " SpamHandler.MAGIC_START are not consistent");
        }
        for (int i = 0; i < MAGIC_REST_LENGTH; i++) {
            if (ProtocolConstants.START_MAGIC.charAt(i + 1) != MAGIC_REST[i]) {
                throw new InternalError("ProtocolConstants.START_MAGIC <->"
                        + " SpamHandler.MAGIC_REST are not consistent" + " (index " + i + " in MAGIC_REST)");
            }
        }
    }

    private static final SpamState SPAM_START = new SpamState() {
        @Override
        public int getExpectedLength() {
            return 1;
        }

        @Override
        public int getTimeout() {
            return ProtocolConstants.TIMEOUT_START;
        }

        @Override
        public boolean isBufOkay(ByteBuf bb) {
            return bb.readByte() == MAGIC_START;
        }

        @Override
        public ConnectionState getConnectionState() {
            return ConnectionState.START;
        }
    };

    private static final SpamState SPAM_MAGIC = new SpamState() {
        @Override
        public int getExpectedLength() {
            return MAGIC_REST_LENGTH;
        }

        @Override
        public int getTimeout() {
            return ProtocolConstants.TIMEOUT_REST;
        }

        @Override
        public boolean isBufOkay(ByteBuf bb) {
            for (byte b : MAGIC_REST) {
                if (!bb.readable() || bb.readByte() != b) {
                    return false;
                }
            }
            return true;
        }

        @Override
        public ConnectionState getConnectionState() {
            return ConnectionState.FIRST_BYTE;
        }
    };

    /**
     * The challenge state: writes two random ints whose (wrapping) sum is the
     * expected 4-byte answer. Static, because it needs nothing from the
     * enclosing handler.
     */
    private static final class SpamInteraction extends SpamState {
        private final int result;

        public SpamInteraction(ByteBuf out) {
            result = ProtocolConstants.RANDOM.nextInt();

            final int diff = ProtocolConstants.RANDOM.nextInt();
            // Never used that buffer before -- it's definitely going to hold 8
            // bytes. int arithmetic wraps, so (result - diff) + diff == result
            // holds for every pair of values.
            out.writeInt(result - diff);
            out.writeInt(diff);
        }

        @Override
        public int getExpectedLength() {
            return 4;
        }

        @Override
        public int getTimeout() {
            return ProtocolConstants.TIMEOUT_INTERACTION;
        }

        @Override
        public boolean isBufOkay(ByteBuf bb) {
            return bb.readableBytes() >= 4 && bb.readInt() == result;
        }

        @Override
        public ConnectionState getConnectionState() {
            return ConnectionState.GOOD_MAGIC;
        }
    }
}