cyril.server.io.AuthHandler.java Source code

Java tutorial

Introduction

Here is the source code for cyril.server.io.AuthHandler.java

Source

// Northerner Cyril, an online, token-based implementation of the game
// Carcassonne
// This is part of the "server" module
// Copyright (C) 2012 Ben Wiederhake
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as
// published by the Free Software Foundation, either version 3 of the
// License, or (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

package cyril.server.io;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.DefaultCompositeByteBuf;
import io.netty.buffer.HeapByteBuf;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;

import java.nio.charset.CharacterCodingException;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;

import cyril.common.ProtocolConstants;
import cyril.server.auth.Authentication;
import cyril.server.auth.Login;
import cyril.server.io.SpamInitializer.ConnectionState;
import cyril.server.io.SpamInitializer.SpamListener;

public final class AuthHandler<C> extends ChannelInitializableInboundByteHandlerAdapter {
    // Bit flags kept in timeState.

    /** Bit in timeState that is set while inboundBufferUpdated is working. */
    private static final int DATA_MASK = 1;
    /** Bit in timeState that marks the connection as doomed (abort or timeout). */
    private static final int KILL_MASK = 2;

    // Collaborators handed in through the constructor.
    private final NorthernerService service;
    private final Authentication<C> auth;
    final SpamListener listener;

    // Flipped exactly once (via compareAndSet) right before the listener is
    // told that this connection has been aborted.
    final AtomicBoolean killed = new AtomicBoolean();
    // Combination of DATA_MASK and KILL_MASK; 0 means "idle and alive".
    final AtomicInteger timeState = new AtomicInteger(0);

    // Also written by LoginTransmission, therefore volatile
    volatile PaketPosition paketState = PaketPosition.START;

    // == No synchronization needed -- assigned only once (in initialize):

    private C wrapped;
    ScheduledFuture<?> killTask;

    // == No synchronization needed -- only touched by the reading thread:

    // Buffer for the username bytes while they trickle in
    private byte[] username;
    // Read position, shared by the username read and the token read
    private int index;
    // The login the client is trying to authenticate as
    private Login login;

    /**
     * Creates a handler that authenticates a single connection.
     *
     * @param service
     *            Receives the session once the client has authenticated
     * @param auth
     *            Wraps the channel, creates and looks up logins and validates
     *            username lengths
     * @param listener
     *            Is informed about state changes of the connection
     */
    public AuthHandler(NorthernerService service, Authentication<C> auth, SpamListener listener) {
        this.listener = listener;
        this.auth = auth;
        this.service = service;
    }

    /**
     * Feeds the freshly arrived bytes into the authentication state machine.
     * While this method is working, DATA_MASK is set in timeState so that the
     * timeout task knows that the reader is responsible for closing the
     * connection.
     *
     * @param ctx
     *            The current context
     * @param in
     *            The buffer holding the bytes received so far
     */
    @Override
    public void inboundBufferUpdated(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
        if (!timeState.compareAndSet(0, DATA_MASK)) {
            // Other thread is reading (?!) or the kill bit has been set in
            // the past.
            return;
        }

        boolean goOn = true;
        // The kill bit is set when something goes wrong.
        // That's why it is tested with every cycle.
        while (goOn && in.readable() && ((timeState.get() & KILL_MASK) == 0)) {
            PaketPosition s = paketState;
            switch (s) {
            case START:
                goOn = handleIncomingStart(ctx, in);
                break;

            case CREATE_PENDING:
                // Client tried to send something while *we* were sending the
                // login data.
                // Client is trying to troll / spam us.
                readerAborts(ctx);
                goOn = false;
                break;

            case MUST_LOGIN:
                // After receiving fresh login data, only a login is allowed.
                if (in.readByte() == '!') {
                    paketState = PaketPosition.LOGIN;
                } else {
                    readerAborts(ctx);
                    goOn = false;
                }
                break;

            case LOGIN:
                goOn = handleIncomingUsernameLength(ctx, in);
                break;

            case LOGIN_USER:
                goOn = handleIncomingUsername(ctx, in);
                break;

            case LOGIN_TOKEN_LENGTH:
                goOn = handleIncomingTokenLength(ctx, in);
                break;

            case LOGIN_TOKEN:
                goOn = handleIncomingToken(ctx, in);
                break;

            case DECLINED_WAIT:
                goOn = handleIncomingDeclined(ctx, in);
                break;

            case AUTHENTICATED:
            default:
                throw new InternalError("Dafuq: Illegal / Unknown" + " PaketPosition " + s);
            }
        }

        if (paketState == PaketPosition.AUTHENTICATED) {
            timeState.set(0);
            return;
        }

        // Hand the "reader is busy" bit back and find out whether the kill
        // bit has been set in the meantime.
        int oldState;
        do {
            oldState = timeState.get();
        } while (!timeState.compareAndSet(oldState, oldState & ~DATA_MASK));
        if ((oldState & KILL_MASK) != 0) {
            // Must be a bitwise AND: an OR with KILL_MASK is never zero and
            // would abort every single connection after its first read.
            // Calling readerAborts again after an abort inside the loop is
            // harmless: the listener notification is guarded by "killed".
            readerAborts(ctx, true);
        }
    }

    /**
     * Handles the length field that precedes the username.
     *
     * @param ctx
     *            The current context
     * @param in
     *            The buffer to be read from (not guaranteed to have both bytes)
     * @return whether further processing may happen
     */
    private boolean handleIncomingUsernameLength(ChannelHandlerContext ctx, ByteBuf in) {
        // Awaiting the first two bytes
        if (in.readableBytes() < 2) {
            return false;
        }

        // readShort() is signed, so guard against negative lengths before
        // allocating the buffer.
        int usernameLength = in.readShort();
        if (usernameLength < 0 || !auth.isValidUsernameLength(wrapped, usernameLength)) {
            readerAborts(ctx);
            return false;
        }

        username = new byte[usernameLength];
        index = 0;
        // Without this transition the username bytes would be misread as
        // yet another length field.
        paketState = PaketPosition.LOGIN_USER;
        return true;
    }

    /**
     * Handles the first incoming byte and returns whether further processing
     * may happen. A '?' requests a new account, a '!' starts a login, anything
     * else aborts the connection.
     * 
     * @param ctx
     *            The current context
     * @param in
     *            The buffer to read from (guaranteed to have at least one
     *            readable byte)
     * @return whether further processing may happen
     */
    private boolean handleIncomingStart(ChannelHandlerContext ctx, ByteBuf in) {
        byte b = in.readByte();
        switch (b) {
        case '?':
            // Requesting a new account
            Login l = auth.createLogin(wrapped);
            if (l == null) {
                // Not allowed, at least not currently.
                // Must return here: there is nothing to transmit, and
                // LoginTransmission dereferences the login.
                readerAborts(ctx, false);
                return false;
            }
            paketState = PaketPosition.CREATE_PENDING;

            // Trigger transmission of token
            new LoginTransmission(ctx, l).start();

            return false;

        case '!':
            paketState = PaketPosition.LOGIN;
            return true;

        default: // Illegal paket at this stage
            readerAborts(ctx, false);
            return false;
        }
    }

    /**
     * Collects the bytes of the username and, once it is complete, decodes it
     * and looks up the matching login. Note that Auth already agreed to
     * storing a username of this length, so flooding wouldn't be very
     * effective.
     * 
     * @param ctx
     *            The current context
     * @param in
     *            The buffer to read from (guaranteed to have at least one
     *            readable byte)
     * @return whether further processing may happen
     */
    private boolean handleIncomingUsername(ChannelHandlerContext ctx, ByteBuf in) {
        // Username AND token may arrive in one big chunk, so never consume
        // more than what is still missing from the username.
        final int missing = username.length - index;
        final int chunk = Math.min(missing, in.readableBytes());
        in.readBytes(username, index, chunk);
        index += chunk;

        final boolean complete = index >= username.length;
        if (!complete) {
            // Still waiting for the rest. The buffer should be drained by
            // now, but there is no need to rely on that.
            return false;
        }

        // All bytes are here; release the field right away (helps GC).
        final byte[] rawName = username;
        username = null;

        final String decodedName;
        try {
            decodedName = ProtocolConstants.decode(rawName);
        } catch (CharacterCodingException e) {
            readerAborts(ctx);
            return false;
        }

        login = auth.getLogin(wrapped, decodedName);
        if (login != null) {
            // Known name, the token length comes next
            paketState = PaketPosition.LOGIN_TOKEN_LENGTH;
            return true;
        }

        readerAborts(ctx);
        return false;
    }

    /**
     * Reads the two-byte length that precedes the token and compares it with
     * the length of the stored token. A differing length aborts the
     * connection.
     * 
     * @param ctx
     *            The current context
     * @param in
     *            The buffer to be read from (not guaranteed to have both bytes)
     * @return whether further processing may happen
     */
    private boolean handleIncomingTokenLength(ChannelHandlerContext ctx, ByteBuf in) {
        final boolean headerComplete = in.readableBytes() >= 2;
        if (headerComplete) {
            final int announced = in.readShort();
            if (announced == login.token.length) {
                // Start comparing the token from its first byte
                index = 0;
                paketState = PaketPosition.LOGIN_TOKEN;
                return true;
            }
            readerAborts(ctx);
        }
        // Either still waiting for the second byte or already aborted
        return false;
    }

    /**
     * Handles and compares the incoming token. Automatically switches to
     * "discard mode" on the first wrong byte, so the connection is only
     * declined after the announced number of token bytes has arrived and the
     * position of the wrong byte is not revealed.
     * If the transmission of the token is complete (and both tokens match), the
     * login process is complete: the session is handed to the service and this
     * handler removes itself from the pipeline.
     * 
     * @param ctx
     *            The current context
     * @param in
     *            The buffer to read from (guaranteed to have at least one
     *            readable byte)
     * @return whether further processing may happen
     */
    private boolean handleIncomingToken(ChannelHandlerContext ctx, ByteBuf in) {
        // This *does* work with 0 length tokens, but only if the next packet
        // starts right away.
        // This doesn't pose a security threat -- if you ever hand out 0 byte
        // tokens, there's no security anymore that could be threatened.

        final byte[] expected = login.token;
        boolean mismatch = false;
        int checkable = Math.min(expected.length - index, in.readableBytes());
        while (checkable > 0 && !mismatch) {
            mismatch = expected[index] != in.readByte();
            // Count the byte even if it was wrong: it has been consumed, and
            // discard mode must only wait for the bytes that are still due.
            index++;
            checkable--;
        }

        if (mismatch) {
            paketState = PaketPosition.DECLINED_WAIT;
            if (index >= expected.length) {
                // The wrong byte was the very last one, nothing left to
                // discard.
                readerAborts(ctx);
                return false;
            }
            // Let the main loop discard whatever is already buffered.
            return true;
        }

        if (index >= expected.length) {
            // Token is correct!
            // Dispensing product
            killTask.cancel(false);
            paketState = PaketPosition.AUTHENTICATED;
            listener.stateChanged(ctx.channel(), ConnectionState.GOOD_AUTH);
            service.handleNewSession(ctx.channel(), login);
            ctx.pipeline().remove(this);
        }

        return false;
    }

    /**
     * Discard mode: skips token bytes of an already declined login and aborts
     * the connection once the read position has reached the token length.
     * 
     * @param ctx
     *            The current context
     * @param in
     *            The buffer to skip bytes in
     * @return always false, no further processing may happen
     */
    private boolean handleIncomingDeclined(ChannelHandlerContext ctx, ByteBuf in) {
        final int outstanding = login.token.length - index;
        final int discardable = Math.min(outstanding, in.readableBytes());
        in.skipBytes(discardable);
        index += discardable;

        // Don't abort halfway, that might be exploited.
        final boolean tokenConsumed = index >= login.token.length;
        if (tokenConsumed) {
            readerAborts(ctx);
        }
        return false;
    }

    /**
     * Called by the reader to abort this connection because of something the
     * client sent (i.e. not because of a timeout).
     * 
     * @param ctx
     *            The current context
     */
    private final void readerAborts(ChannelHandlerContext ctx) {
        readerAborts(ctx, false);
    }

    /**
     * Called by the reader to abort this connection. The "timeout" parameter is
     * only for feedback to the SpamListener
     * 
     * @param ctx
     *            The current context
     * @param timeout
     *            Whether this was triggered by an timeout
     */
    private final void readerAborts(ChannelHandlerContext ctx, boolean timeout) {
        killTask.cancel(false);
        timeState.set(KILL_MASK);
        ctx.close();
        // Only the first abort is reported to the listener.
        if (killed.compareAndSet(false, true)) {
            // Report what actually happened instead of always claiming a
            // timeout.
            ConnectionState reason = timeout ? ConnectionState.TIMEOUT : ConnectionState.WRONG_DATA;
            listener.stateChanged(ctx.channel(), reason);
        }
    }

    /**
     * Wraps the channel through the Authentication and schedules the kill task
     * on the executor of the context, to run after
     * ProtocolConstants.AUTH_TIMEOUT seconds.
     * 
     * @param ctx
     *            The current context
     */
    @Override
    protected void initialize(ChannelHandlerContext ctx) {
        wrapped = auth.wrap(ctx.channel());

        final Runnable timeoutTask = new KillTask(ctx);
        killTask = ctx.executor().schedule(timeoutTask, ProtocolConstants.AUTH_TIMEOUT, TimeUnit.SECONDS);
    }

    /**
     * Scheduled by initialize; sets the kill bit in timeState when it fires.
     * If no reader is active at that moment, the connection is closed right
     * here, otherwise the reader notices the kill bit and closes it.
     */
    private final class KillTask implements Runnable {
        private final ChannelHandlerContext ctx;

        public KillTask(ChannelHandlerContext ctx) {
            this.ctx = ctx;
        }

        @Override
        public void run() {
            int oldState;
            do {
                oldState = timeState.get();
                // Bitwise AND: an OR with KILL_MASK is never zero, which
                // made this branch unconditional and skipped the listener
                // notification below.
                if ((oldState & KILL_MASK) != 0) {
                    // Somebody else has killed the connection already.
                    // Whatever o.O
                    ctx.close();
                    return;
                }
            } while (!timeState.compareAndSet(oldState, oldState | KILL_MASK));

            if ((oldState & DATA_MASK) == 0 && paketState != PaketPosition.AUTHENTICATED) {
                // No reader is active, kill this connection directly
                ctx.close();
                if (killed.compareAndSet(false, true)) {
                    listener.stateChanged(ctx.channel(), ConnectionState.TIMEOUT);
                }
            }
            // Otherwise, the reader has to care about that (or not)
        }
    }

    /**
     * Writes a freshly created login to the client: two bytes token length,
     * the token, two bytes username length, the username. Writes as much as
     * the outbound buffer accepts and re-schedules itself as listener of the
     * flush until everything has been written.
     */
    private final class LoginTransmission implements ChannelFutureListener {
        private final ChannelHandlerContext ctx;
        private final ByteBuf buf;
        private final int total;

        // Number of bytes of buf that have been written so far
        private int index;

        /**
         * @param ctx
         *            The current context
         * @param l
         *            The login to transmit, must not be null
         */
        public LoginTransmission(ChannelHandlerContext ctx, Login l) {
            this.ctx = ctx;
            ByteBuf tokenHeader = new HeapByteBuf(2, 2);
            tokenHeader.writeShort(l.token.length);
            ByteBuf tokenBuf = new HeapByteBuf(l.token, l.token.length);

            ByteBuf nameHeader = new HeapByteBuf(2, 2);
            // This has to go into nameHeader: tokenHeader is already full,
            // and nameHeader would otherwise stay empty.
            nameHeader.writeShort(l.usernameBytes.length);
            ByteBuf nameBuf = new HeapByteBuf(l.usernameBytes, l.usernameBytes.length);

            buf = new DefaultCompositeByteBuf(4, tokenHeader, tokenBuf, nameHeader, nameBuf);
            total = 4 + l.token.length + l.usernameBytes.length;
        }

        /**
         * Starts the transmission on the calling thread.
         */
        public void start() {
            // Actually, this is exactly what we want.
            // Write as much as you can as soon as you can, then schedule again
            // later.
            run();
        }

        @Override
        public void operationComplete(ChannelFuture future) throws Exception {
            if (future.isSuccess()) {
                run();
            } else {
                // Otherwise, let the channel die *now*
                ctx.close();
                killTask.cancel(false);
                timeState.set(KILL_MASK);
                if (killed.compareAndSet(false, true)) {
                    listener.stateChanged(ctx.channel(), ConnectionState.WRONG_DATA);
                }
            }
        }

        private final void run() {
            if (killed.get()) {
                return;
            }

            final ByteBuf dst = ctx.nextOutboundByteBuffer();
            int transferable = Math.min(dst.writableBytes(), total - index);
            boolean repeat = true;

            if (transferable > 0) {
                if (total == index + transferable) {
                    // Set BEFORE writing, so no race conditions can arise
                    paketState = PaketPosition.MUST_LOGIN;
                    repeat = false;
                }
                dst.writeBytes(buf, index, transferable);
                index += transferable;
            }

            // Always flush, otherwise the final chunk would stay in the
            // outbound buffer. Only re-schedule while data is left; listening
            // to the last flush would call run() again and loop forever.
            ChannelFuture flushed = ctx.flush();
            if (repeat) {
                flushed.addListener(this);
            }
        }
    }

    /**
     * The positions of the authentication state machine.
     */
    private enum PaketPosition {
        /** Nothing read yet; the first byte decides between '?' and '!'. */
        START,
        /** A new login was requested and is being transmitted to the client. */
        CREATE_PENDING,
        /** The new login has been written; only a '!' is accepted now. */
        MUST_LOGIN,
        /** Waiting for the two-byte length of the username. */
        LOGIN,
        /** Reading the bytes of the username. */
        LOGIN_USER,
        /** Waiting for the two-byte length of the token. */
        LOGIN_TOKEN_LENGTH,
        /** Comparing the incoming token bytes with the stored token. */
        LOGIN_TOKEN,
        /** A wrong token byte was seen; the rest of the token is discarded. */
        DECLINED_WAIT,
        /** The token matched; this handler is done. */
        AUTHENTICATED
    }
}