jgnash.engine.concurrent.DistributedLockServer.java Source code

Java tutorial

Introduction

Here is the source code for jgnash.engine.concurrent.DistributedLockServer.java

Source

/*
 * jGnash, a personal finance application
 * Copyright (C) 2001-2015 Craig Cavanaugh
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */
package jgnash.engine.concurrent;

import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.group.ChannelGroup;
import io.netty.channel.group.DefaultChannelGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.codec.DelimiterBasedFrameDecoder;
import io.netty.handler.codec.Delimiters;
import io.netty.handler.codec.string.StringDecoder;
import io.netty.handler.codec.string.StringEncoder;
import io.netty.util.CharsetUtil;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.concurrent.GlobalEventExecutor;

import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.logging.Level;
import java.util.logging.Logger;

import jgnash.util.EncodeDecode;
import jgnash.util.EncryptionManager;

/**
 * Distributed Lock Server
 *
 * @author Craig Cavanaugh
 */
public class DistributedLockServer {

    private static final Logger logger = Logger.getLogger(DistributedLockServer.class.getName());

    private final ExecutorService executorService = Executors.newCachedThreadPool();

    private final ChannelGroup channelGroup = new DefaultChannelGroup("lock-server", GlobalEventExecutor.INSTANCE);

    private NioEventLoopGroup eventLoopGroup;

    private final int port;

    private final Map<String, ReadWriteLock> lockMap = new HashMap<>();

    private final Map<ChannelHandlerContext, String> handlerContextMap = new HashMap<>();

    static final String LOCK = "lock";

    static final String UNLOCK = "unlock";

    static final String LOCK_TYPE_READ = "READ";

    static final String LOCK_TYPE_WRITE = "WRITE";

    private static final String EOL_DELIMITER = "\r\n";

    private EncryptionManager encryptionManager = null;

    public DistributedLockServer(final int port) {
        this.port = port;
    }

    private String encrypt(final String message) {
        if (encryptionManager != null) {
            return encryptionManager.encrypt(message);
        }
        return message;
    }

    private void processMessage(final ChannelHandlerContext ctx, final String msg) {

        final String message;

        if (encryptionManager != null) {
            message = encryptionManager.decrypt(msg);
        } else {
            message = msg;
        }

        // Look for a uuid announcement for a channel
        if (message.startsWith(DistributedLockManager.UUID_PREFIX)) {
            handlerContextMap.put(ctx, message.substring(DistributedLockManager.UUID_PREFIX.length()));
            return;
        }

        /** lock_action, lock_id, thread_id, lock_type */
        // unlock,account,1194917570,read
        // lock,account,1194917570,write

        // decode the message into it's parts
        final String[] strings = EncodeDecode.decodeStringCollection(message).toArray(new String[4]);

        final String action = strings[0];
        final String lockId = strings[1];
        final String remoteThread = strings[2];
        final String lockType = strings[3];

        final ReadWriteLock lock = getLock(lockId);

        try {

            // request a lock or unlock.  This may block
            switch (action) {
            case LOCK:
                switch (lockType) {
                case LOCK_TYPE_READ:
                    lock.lockForRead(remoteThread);
                    break;
                case LOCK_TYPE_WRITE:
                    lock.lockForWrite(remoteThread);
                    break;
                default:
                    break;
                }
                break;
            case UNLOCK:
                switch (lockType) {
                case LOCK_TYPE_READ:
                    lock.unlockRead(remoteThread);
                    break;
                case LOCK_TYPE_WRITE:
                    lock.unlockWrite(remoteThread);
                    break;
                default:
                    break;
                }
                break;
            }

            // return the message as an acknowledgment lock state has changed
            if (ctx.channel().isOpen()) {
                ctx.writeAndFlush(encrypt(message) + EOL_DELIMITER).sync();
            }
        } catch (final Exception e) {
            logger.log(Level.SEVERE, e.getLocalizedMessage(), e);
        }
    }

    private ReadWriteLock getLock(final String lockId) {
        ReadWriteLock readWriteLock = lockMap.get(lockId);

        if (readWriteLock == null) {
            readWriteLock = new ReadWriteLock(lockId);
            lockMap.put(lockId, readWriteLock);
        }

        return readWriteLock;
    }

    public boolean startServer(final char[] password) {
        boolean result = false;

        // If a password has been specified, create an EncryptionManager
        if (password != null && password.length > 0) {
            encryptionManager = new EncryptionManager(password);
        }

        eventLoopGroup = new NioEventLoopGroup();

        final ServerBootstrap bootstrap = new ServerBootstrap();

        try {
            bootstrap.group(eventLoopGroup).channel(NioServerSocketChannel.class).childHandler(new Initializer())
                    .childOption(ChannelOption.SO_KEEPALIVE, true);

            final ChannelFuture future = bootstrap.bind(port);
            future.sync();

            if (future.isDone() && future.isSuccess()) {
                logger.info("Distributed Lock Server started successfully");
                result = true;
            } else {
                logger.info("Failed to start the Distributed Lock Server");
            }
        } catch (final InterruptedException e) {
            logger.log(Level.SEVERE, e.getLocalizedMessage(), e);
            stopServer();
        }

        return result;
    }

    public void stopServer() {
        try {
            channelGroup.close().sync();
            executorService.shutdown();
            eventLoopGroup.shutdownGracefully();

            eventLoopGroup = null;

            logger.info("Distributed Lock Server Stopped");
        } catch (final InterruptedException e) {
            logger.log(Level.SEVERE, e.getLocalizedMessage(), e);
        }
    }

    @ChannelHandler.Sharable
    private class ServerHandler extends ChannelInboundHandlerAdapter {

        @Override
        public void channelActive(final ChannelHandlerContext ctx) throws Exception {
            channelGroup.add(ctx.channel()); // maintain channels

            logger.log(Level.INFO, "Remote connection from: {0}", ctx.channel().remoteAddress().toString());
        }

        @Override
        public void channelInactive(final ChannelHandlerContext ctx) throws Exception {
            logger.log(Level.INFO, "Remote connection {0} closed", ctx.channel().remoteAddress().toString());

            final String uuid = handlerContextMap.get(ctx);

            // Search through the lock map and remove any stale locks
            if (uuid != null) {
                for (ReadWriteLock readWriteLock : lockMap.values()) { // look at every lock

                    // if the remoteThread starts with the uuid, request a cleanup
                    // cleanup a stale lock
                    readWriteLock.readingThreads.keySet().stream()
                            .filter(remoteThread -> remoteThread.startsWith(uuid))
                            .forEach(readWriteLock::cleanupStaleThread);

                    if (readWriteLock.hasWriteThread(uuid)) {
                        readWriteLock.cleanupStaleWriteThread();
                    }
                }
            }

            handlerContextMap.remove(ctx);
            channelGroup.remove(ctx.channel());
            super.channelInactive(ctx);
        }

        @Override
        public void channelRead(final ChannelHandlerContext ctx, final Object msg) {
            executorService.submit(() -> {
                processMessage(ctx, msg.toString());
                ReferenceCountUtil.release(msg);
            });
        }

        @Override
        public void exceptionCaught(final ChannelHandlerContext ctx, final Throwable cause) throws Exception {
            logger.log(Level.WARNING, "Unexpected exception from downstream.", cause);
            ctx.close();
        }
    }

    private class Initializer extends ChannelInitializer<SocketChannel> {

        @Override
        public void initChannel(final SocketChannel ch) throws Exception {
            ChannelPipeline pipeline = ch.pipeline();

            // Add the text line codec combination first,
            pipeline.addLast("framer", new DelimiterBasedFrameDecoder(8192, true, Delimiters.lineDelimiter()));

            // the encoder and decoder are static as these are sharable
            pipeline.addLast("decoder", new StringDecoder(CharsetUtil.UTF_8));
            pipeline.addLast("encoder", new StringEncoder(CharsetUtil.UTF_8));

            // and then business logic.
            pipeline.addLast("handler", new ServerHandler());
        }
    }

    /**
     * Reentrant Read Write lock.
     * <p/>
     * A unique integer must be supplied to identify the thread instead of the current thread.
     */
    private static class ReadWriteLock {

        private final String id;

        /**
         * The key is the uuid of the manager plus the remote thread id
         * <p/>
         * uuid-integer
         */
        private final Map<String, Integer> readingThreads = new ConcurrentHashMap<>();

        private int writeAccesses = 0;
        private int writeRequests = 0;
        private String writingThread = null;

        private ReadWriteLock(final String id) {
            this.id = id;
        }

        synchronized boolean hasWriteThread(final String id) {
            boolean result = false;

            if (writingThread != null) {
                result = writingThread.startsWith(id);
            }

            return result;
        }

        synchronized void cleanupStaleThread(final String remoteThread) {
            if (readingThreads.containsKey(remoteThread)) {
                unlockRead(remoteThread);
                logger.log(Level.WARNING, "Removed a stale read lock for: {0}", id);
            }

            if (writingThread != null && writingThread.equals(remoteThread)) {
                unlockWrite(remoteThread);
                logger.log(Level.WARNING, "Removed a stale write lock for: {0}", id);
            }
        }

        synchronized void cleanupStaleWriteThread() {
            if (readingThreads.containsKey(writingThread)) {
                unlockRead(writingThread);
                logger.log(Level.WARNING, "Removed a stale read lock for: {0}", id);
            }

            if (writingThread != null) {
                unlockWrite(writingThread);
                logger.log(Level.WARNING, "Removed a stale write lock for: {0}", id);
            }
        }

        synchronized void lockForRead(final String remoteThread) throws InterruptedException {

            while (!canGrantReadAccess(remoteThread)) {
                wait();
            }

            readingThreads.put(remoteThread, (getReadHoldCount(remoteThread) + 1));
        }

        synchronized void lockForWrite(final String remoteThread) throws InterruptedException {
            writeRequests++;

            while (!canGrantWriteAccess(remoteThread)) {
                wait();
            }

            writeRequests--;
            writeAccesses++; // bump, if greater than 1, then the lock is reentrant
            writingThread = remoteThread;
        }

        synchronized void unlockRead(final String remoteThread) {

            if (!isReadLockedByCurrentThread(remoteThread)) {
                throw new IllegalMonitorStateException(
                        "Remote Thread: " + remoteThread + " does not hold a read lock for: " + id);
            }

            int holdCount = getReadHoldCount(remoteThread);

            if (holdCount == 1) {
                readingThreads.remove(remoteThread);
            } else {
                readingThreads.put(remoteThread, (holdCount - 1));
            }

            notifyAll();
        }

        synchronized void unlockWrite(final String remoteThread) {

            if (!isWriteLockedByCurrentThread(remoteThread)) {
                throw new IllegalMonitorStateException(
                        "Remote Thread: " + remoteThread + " does not hold the write lock for: " + id);
            }

            writeAccesses--;

            if (writeAccesses == 0) {
                writingThread = null;
            }

            notifyAll();
        }

        private synchronized boolean canGrantReadAccess(final String remoteThread) {

            if (isWriteLockedByCurrentThread(remoteThread)) { // lock down grade is allowed
                return true;
            }

            if (writingThread != null) {
                return false;
            }
            if (isReadLockedByCurrentThread(remoteThread)) {
                return true;
            }

            return writeRequests <= 0;
        }

        private synchronized boolean canGrantWriteAccess(final String remoteThread) {

            if (!readingThreads.isEmpty()) {
                return false;
            }
            if (writingThread == null) {
                return true;
            }
            return isWriteLockedByCurrentThread(remoteThread); // reentrant write
        }

        private synchronized int getReadHoldCount(final String remoteThread) {
            final Integer accessCount = readingThreads.get(remoteThread);

            if (accessCount == null) {
                return 0;
            }

            return accessCount;
        }

        private synchronized boolean isReadLockedByCurrentThread(final String remoteThread) {
            return readingThreads.get(remoteThread) != null;
        }

        private synchronized boolean isWriteLockedByCurrentThread(final String remoteThread) {
            if (writingThread != null) {
                return writingThread.equals(remoteThread);
            }
            return false;
        }
    }
}