com.tinspx.util.concurrent.TimedSemaphoreTest.java Source code

Java tutorial

Introduction

Here is the source code for com.tinspx.util.concurrent.TimedSemaphoreTest.java

Source

/* Copyright (C) 2013-2014 Ian Teune <ian.teune@gmail.com>
 * 
 * Permission is hereby granted, free of charge, to any person obtaining
 * a copy of this software and associated documentation files (the
 * "Software"), to deal in the Software without restriction, including
 * without limitation the rights to use, copy, modify, merge, publish,
 * distribute, sublicense, and/or sell copies of the Software, and to
 * permit persons to whom the Software is furnished to do so, subject to
 * the following conditions:
 * 
 * The above copyright notice and this permission notice shall be
 * included in all copies or substantial portions of the Software.
 * 
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
 * LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
 * OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
 * WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 */
package com.tinspx.util.concurrent;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState;
import com.google.common.base.Throwables;
import com.google.common.base.Ticker;
import com.google.common.primitives.Ints;
import com.google.common.util.concurrent.Uninterruptibles;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import lombok.ToString;
import lombok.experimental.Builder;
import org.apache.commons.lang3.mutable.MutableInt;
import org.junit.After;
import org.junit.AfterClass;
import static org.junit.Assert.*;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;

/**
 * Tests {@link TimedSemaphore}
 * 
 * @author Ian
 */
public class TimedSemaphoreTest {

    /** Number of runTest invocations so far; drives the periodic progress line. */
    private static long testCount;
    /**
     * Running total of permits taken by timed multi-permit acquisitions that
     * returned fewer permits than were requested; printed at the end of testAll.
     */
    private static final AtomicLong stolen = new AtomicLong();

    /** Public no-arg constructor for the JUnit runner; the class keeps no per-instance state. */
    public TimedSemaphoreTest() {
        // intentionally empty
    }

    /** Class-level setup hook; nothing to prepare. */
    @BeforeClass
    public static void setUpClass() {
        // intentionally empty
    }

    /** Class-level teardown hook; nothing to release. */
    @AfterClass
    public static void tearDownClass() {
        // intentionally empty
    }

    /** Per-test setup hook; nothing to prepare. */
    @Before
    public void setUp() {
        // intentionally empty
    }

    /** Per-test teardown hook; nothing to release. */
    @After
    public void tearDown() {
        // intentionally empty
    }

    /**
     * Checks the configuration getters and the acquire/period counters of a new
     * semaphore, then how they change after reconfiguring it and taking permits.
     * <p>
     * NOTE(review): several assertions require more than 10 ms to remain in a
     * 15 ms period, so this test appears timing-sensitive and may be flaky on a
     * heavily loaded machine.
     */
    @Test
    public void testInit() throws InterruptedException {
        // minutes(limit, period): 3 permits per 17 minutes, using the system ticker
        TimedSemaphore ts = TimedSemaphore.minutes(3, 17);
        assertEquals(17, ts.getPeriod());
        assertSame(TimeUnit.MINUTES, ts.getUnit());
        assertEquals(3, ts.getLimit());
        assertSame(Ticker.systemTicker(), ts.getTicker());
        assertEquals(3, ts.availablePermits());

        // reconfigure to 2 permits per 15 ms; availablePermits follows the new limit
        ts.setPeriod(15, TimeUnit.MILLISECONDS);
        ts.setLimit(2);
        assertEquals(15, ts.getPeriod());
        assertSame(TimeUnit.MILLISECONDS, ts.getUnit());
        assertEquals(2, ts.getLimit());
        assertEquals(2, ts.availablePermits());

        // nothing acquired yet: all counters are zero and there is no wait
        assertEquals(0, ts.getAcquireCount());
        assertEquals(0, ts.getPeriodCount());
        assertEquals(0, ts.getAverageAcquiresPerPeriod(), 0);
        assertEquals(0, ts.estimatedWait());
        assertTrue(ts.nanosRemaining() > 10000000); //more than 10ms

        // 3 permits exceed the limit of 2; presumably this blocks into the next
        // period, leaving one permit of that period available afterwards
        ts.acquire(3);
        assertEquals(3, ts.getAcquireCount());
        assertEquals(1, ts.getPeriodCount());
        assertEquals(2, ts.getAverageAcquiresPerPeriod(), 0);
        assertEquals(0, ts.estimatedWait());
        assertTrue(ts.nanosRemaining() > 10000000); //more than 10ms
        assertEquals(1, ts.availablePermits());

        // take the last permit of the period; any further acquire would now have
        // to wait for the next period (estimatedWait > 10 ms)
        assertTrue(ts.tryAcquire());
        assertEquals(4, ts.getAcquireCount());
        assertEquals(1, ts.getPeriodCount());
        assertEquals(2, ts.getAverageAcquiresPerPeriod(), 0);
        assertTrue(ts.estimatedWait() > 10000000);
        assertTrue(ts.nanosRemaining() > 10000000); //more than 10ms
        assertEquals(0, ts.availablePermits());
    }

    /**
     * Checks that invalid constructor, setter and acquire arguments are rejected
     * with IllegalArgumentException (bad counts/periods) or NullPointerException
     * (null units), and that zero or negative timeouts are accepted.
     */
    @Test
    @SuppressWarnings("ResultOfObjectAllocationIgnored")
    public void testInvalidArguments() throws InterruptedException {
        // constructor (limit, period, unit): limit and period must be positive, unit non-null
        try {
            new TimedSemaphore(0, 1, TimeUnit.MILLISECONDS);
            fail();
        } catch (IllegalArgumentException ex) {
        }
        try {
            new TimedSemaphore(-1, 1, TimeUnit.MILLISECONDS);
            fail();
        } catch (IllegalArgumentException ex) {
        }
        try {
            new TimedSemaphore(1, 0, TimeUnit.MILLISECONDS);
            fail();
        } catch (IllegalArgumentException ex) {
        }
        try {
            new TimedSemaphore(1, 1, null);
            fail();
        } catch (NullPointerException ex) {
        }

        // setters: same rules as the constructor (limit = 3, period = 100 ns from here on)
        TimedSemaphore ts = new TimedSemaphore(3, 100, TimeUnit.NANOSECONDS);
        try {
            ts.setLimit(0);
            fail();
        } catch (IllegalArgumentException ex) {
        }
        try {
            ts.setLimit(-1);
            fail();
        } catch (IllegalArgumentException ex) {
        }
        try {
            ts.setPeriod(0, TimeUnit.NANOSECONDS);
            fail();
        } catch (IllegalArgumentException ex) {
        }
        try {
            ts.setPeriod(-1, TimeUnit.NANOSECONDS);
            fail();
        } catch (IllegalArgumentException ex) {
        }
        try {
            ts.setPeriod(1, null);
            fail();
        } catch (NullPointerException ex) {
        }

        //acquiring too many permits
        // without a timeout, more than the limit (3) is rejected; exactly the limit is allowed
        try {
            ts.tryAcquire(4);
            fail();
        } catch (IllegalArgumentException ex) {
        }
        ts.tryAcquire(3);

        // with a timeout, larger requests are allowed once the timeout spans
        // enough periods; the cases below imply n permits need a timeout of at
        // least ((n - 1) / limit) * period = ((n - 1) / 3) * 100 ns
        try {
            ts.tryAcquire(4, 99, TimeUnit.NANOSECONDS);
            fail();
        } catch (IllegalArgumentException ex) {
        }
        try {
            ts.tryAcquire(6, 99, TimeUnit.NANOSECONDS);
            fail();
        } catch (IllegalArgumentException ex) {
        }
        ts.tryAcquire(6, 100, TimeUnit.NANOSECONDS); //should not throw
        try {
            ts.tryAcquire(7, 100, TimeUnit.NANOSECONDS);
            fail();
        } catch (IllegalArgumentException ex) {
        }
        try {
            ts.tryAcquire(7, 199, TimeUnit.NANOSECONDS);
            fail();
        } catch (IllegalArgumentException ex) {
        }
        ts.tryAcquire(7, 200, TimeUnit.NANOSECONDS);
        ts.tryAcquire(9, 200, TimeUnit.NANOSECONDS);

        //acquire invalid args
        try {
            ts.acquire(-1);
            fail();
        } catch (IllegalArgumentException ex) {
        }
        try {
            ts.acquire(0);
            fail();
        } catch (IllegalArgumentException ex) {
        }

        //tryAcquire invalid
        try {
            ts.tryAcquire(-1);
            fail();
        } catch (IllegalArgumentException ex) {
        }
        try {
            ts.tryAcquire(0);
            fail();
        } catch (IllegalArgumentException ex) {
        }

        //tryAcquire with timeout, invalid arguments
        try {
            ts.tryAcquire(0, 10, TimeUnit.NANOSECONDS);
            fail();
        } catch (IllegalArgumentException ex) {
        }
        try {
            ts.tryAcquire(-1, 10, TimeUnit.NANOSECONDS);
            fail();
        } catch (IllegalArgumentException ex) {
        }
        try {
            ts.tryAcquire(1, 10, null);
            fail();
        } catch (NullPointerException ex) {
        }
        try {
            ts.tryAcquire(2, 10, null);
            fail();
        } catch (NullPointerException ex) {
        }

        //0 and negative timeout should not throw
        ts.tryAcquire(0, TimeUnit.NANOSECONDS);
        ts.tryAcquire(-1, TimeUnit.NANOSECONDS);
        ts.tryAcquire(1, 0, TimeUnit.NANOSECONDS);
        ts.tryAcquire(1, -1, TimeUnit.NANOSECONDS);
        ts.tryAcquire(2, 0, TimeUnit.NANOSECONDS);
        ts.tryAcquire(2, -1, TimeUnit.NANOSECONDS);
    }

    /**
     * Strategy for obtaining permits from a {@link TimedSemaphore} during a
     * stress run.
     */
    interface Acquire {
        /**
         * Attempts to obtain {@code permits} permits from {@code ts}.
         *
         * @return true if all of the requested permits were obtained
         * @throws InterruptedException if a blocking acquire is interrupted
         */
        boolean acquire(TimedSemaphore ts, int permits) throws InterruptedException;

        /** Returns the permit-count strategies that may be paired with this strategy. */
        Iterable<Permits> permits();
    }

    /** Strategy deciding how many permits a worker requests next. */
    interface Permits {
        /** Returns the number of permits to request next from {@code ts}; expected to be positive. */
        int permits(TimedSemaphore ts);
    }

    /**
     * Rotates between {@link #ACQUIRE_TRY_IMMEDIATE}, {@link #ACQUIRE_TRY_TIMEOUT}
     * and {@link #ACQUIRE} on successive calls. Because the non-blocking strategy
     * is in the rotation, only bounded permit counts are offered.
     */
    static final Acquire ACQUIRE_INCREMENT = new Acquire() {
        /** Round-robin cursor shared by every thread using this strategy. */
        final AtomicLong counter = new AtomicLong();

        @Override
        public boolean acquire(TimedSemaphore ts, int permits) throws InterruptedException {
            final int slot = (int) (counter.incrementAndGet() % 3);
            final Acquire delegate;
            if (slot == 0) {
                delegate = ACQUIRE_TRY_IMMEDIATE;
            } else if (slot == 1) {
                delegate = ACQUIRE_TRY_TIMEOUT;
            } else if (slot == 2) {
                delegate = ACQUIRE;
            } else {
                throw new AssertionError();
            }
            return delegate.acquire(ts, permits);
        }

        @Override
        public Iterable<Permits> permits() {
            return PERMITS_BOUNDED;
        }

        @Override
        public String toString() {
            return "ACQUIRE_INCREMENT";
        }
    };

    /**
     * Uses the non-blocking {@code tryAcquire} overloads. {@code tryAcquire(int)}
     * rejects counts above the limit, so only bounded permit counts are offered.
     */
    static final Acquire ACQUIRE_TRY_IMMEDIATE = new Acquire() {
        /** Alternates single-permit requests between the two overloads. */
        final AtomicLong counter = new AtomicLong();

        @Override
        public boolean acquire(TimedSemaphore ts, int permits) {
            // The counter is only advanced for single-permit requests.
            final boolean useNoArgOverload = permits == 1 && counter.incrementAndGet() % 2 == 0;
            return useNoArgOverload ? ts.tryAcquire() : ts.tryAcquire(permits);
        }

        @Override
        public Iterable<Permits> permits() {
            return PERMITS_BOUNDED;
        }

        @Override
        public String toString() {
            return "ACQUIRE_TRY_IMMEDIATE";
        }
    };

    /**
     * Converts the semaphore's configured period ({@code getPeriod()} in
     * {@code getUnit()}) to nanoseconds.
     */
    static long nanos(TimedSemaphore ts) {
        final TimeUnit unit = ts.getUnit();
        final long period = ts.getPeriod();
        return unit.toNanos(period);
    }

    /**
     * Acquires through the timed {@code tryAcquire} overloads with a timeout
     * that varies from call to call. Fails the test if a call takes more than
     * 30 ms longer than its timeout.
     * <p>
     * When {@code permits} exceeds the limit, the timeout is a random fraction of
     * 1-3 periods plus {@code (permits - 1) / limit} whole periods. That extra
     * term appears to be the minimum timeout for which such a request is not
     * rejected as invalid (compare testInvalidArguments).
     * <p>
     * Permits taken by a multi-permit call that returned fewer than requested are
     * added to {@code stolen}.
     */
    static final Acquire ACQUIRE_TRY_TIMEOUT = new Acquire() {
        final Random random = new Random();
        // selects which timeout formula is used; advanced on every call
        final AtomicLong delayCounter = new AtomicLong();
        // alternates single-permit calls between the two tryAcquire overloads
        final AtomicLong counter = new AtomicLong();

        @Override
        public boolean acquire(TimedSemaphore ts, int permits) throws InterruptedException {
            long timeout;
            if (permits > ts.limit) {
                // request cannot fit in one period: random base timeout of up to
                // 2, 3 or 1 periods ...
                switch ((int) (delayCounter.incrementAndGet() % 3)) {
                case 0:
                    timeout = random.nextInt(Ints.saturatedCast(ts.periodNanos * 2));
                    break;
                case 1:
                    timeout = random.nextInt(Ints.saturatedCast(ts.periodNanos * 3));
                    break;
                case 2:
                    timeout = random.nextInt(Ints.saturatedCast(ts.periodNanos));
                    break;
                default:
                    throw new AssertionError();
                }
                // ... plus the whole periods needed beyond the first to cover all permits
                int mult = (permits - 1) / ts.limit;
                timeout += ts.periodNanos * mult;
            } else {
                // request fits in one period: fixed timeout of 2, 3, 1, 1/2 or 1/3 periods
                switch ((int) (delayCounter.incrementAndGet() % 5)) {
                case 0:
                    timeout = ts.periodNanos * 2;
                    break;
                case 1:
                    timeout = ts.periodNanos * 3;
                    break;
                case 2:
                    timeout = ts.periodNanos;
                    break;
                case 3:
                    timeout = ts.periodNanos / 2;
                    break;
                case 4:
                    timeout = ts.periodNanos / 3;
                    break;
                default:
                    throw new AssertionError();
                }
            }

            long time = ts.getTicker().read();
            boolean result;
            if (permits == 1 && counter.incrementAndGet() % 2 == 0) {
                result = ts.tryAcquire(timeout, TimeUnit.NANOSECONDS);
            } else {
                // NOTE(review): the int result is treated as the number of permits
                // actually taken, which may be fewer than requested (possibly zero)
                int acount = ts.tryAcquire(permits, timeout, TimeUnit.NANOSECONDS);
                if (acount < permits) {
                    // partial result: count the permits taken without a full success
                    stolen.addAndGet(acount);
                } else if (acount > permits) {
                    // more permits granted than requested is always a bug
                    fail(String.format("acount: %d, permits: %d", acount, permits));
                }
                result = permits == acount;
            }
            // the call must honor its timeout, with slack for scheduling delays
            time = ts.getTicker().read() - time;
            if (time > timeout + 30000000) { //30ms wiggle room
                fail(String.format("elapsed: %d, timeout: %d, result: %b, permits: %d, ds: %s", time, timeout,
                        result, permits, ts));
            }
            return result;
        }

        @Override
        public Iterable<Permits> permits() {
            return PERMITS_UNBOUNDED;
        }

        @Override
        public String toString() {
            return "ACQUIRE_TRY_TIMEOUT";
        }
    };

    /**
     * Uses the blocking {@code acquire} overloads, so every call reports success
     * once it returns. Requests above the limit are allowed.
     */
    static final Acquire ACQUIRE = new Acquire() {
        /** Alternates single-permit requests between the two overloads. */
        final AtomicLong counter = new AtomicLong();

        @Override
        public boolean acquire(TimedSemaphore ts, int permits) throws InterruptedException {
            // The counter is only advanced for single-permit requests.
            if (permits != 1 || counter.incrementAndGet() % 2 != 0) {
                ts.acquire(permits);
            } else {
                ts.acquire();
            }
            return true;
        }

        @Override
        public Iterable<Permits> permits() {
            return PERMITS_UNBOUNDED;
        }

        @Override
        public String toString() {
            return "ACQUIRE";
        }
    };

    /** Always requests exactly the semaphore's per-period limit. */
    static final Permits PERMITS_ALL = new Permits() {
        @Override
        public int permits(TimedSemaphore ts) {
            final int limit = ts.limit;
            return limit;
        }

        @Override
        public String toString() {
            return "PERMITS_ALL";
        }
    };

    /** Always requests a single permit. */
    static final Permits PERMITS_ONE = new Permits() {
        @Override
        public int permits(TimedSemaphore ignored) {
            final int single = 1;
            return single;
        }

        @Override
        public String toString() {
            return "PERMITS_ONE";
        }
    };

    /** Requests a uniformly random count in {@code [1, limit]}. */
    static final Permits PERMITS_RANDOM = new Permits() {
        final Random random = new Random();

        @Override
        public int permits(TimedSemaphore ts) {
            return 1 + random.nextInt(ts.limit);
        }

        @Override
        public String toString() {
            return "PERMITS_RANDOM";
        }
    };

    /**
     * Rotates between {@link #PERMITS_ALL}, {@link #PERMITS_ONE},
     * {@link #PERMITS_RANDOM} and {@link #PERMITS_MULITPLE} on successive calls.
     * Because PERMITS_MULITPLE exceeds the limit, this strategy may request more
     * than the limit. It must only be paired with acquire strategies that accept
     * such requests.
     */
    static final Permits PERMITS_INCREMENT_UNBOUNDED = new Permits() {
        /** Round-robin cursor shared by every thread using this strategy. */
        final AtomicLong counter = new AtomicLong();

        @Override
        public int permits(TimedSemaphore ts) {
            switch ((int) (counter.incrementAndGet() % 4)) {
            case 0:
                return PERMITS_ALL.permits(ts);
            case 1:
                return PERMITS_ONE.permits(ts);
            case 2:
                return PERMITS_RANDOM.permits(ts);
            case 3:
                return PERMITS_MULITPLE.permits(ts);
            default:
                throw new AssertionError();
            }
        }

        @Override
        public String toString() {
            // Must match the field name so failure reports (which print the
            // strategy via DelayTest's toString) identify it unambiguously.
            return "PERMITS_INCREMENT_UNBOUNDED";
        }
    };

    /**
     * Rotates between {@link #PERMITS_ALL}, {@link #PERMITS_ONE} and
     * {@link #PERMITS_RANDOM}; never requests more than the semaphore's limit.
     */
    static final Permits PERMITS_INCREMENT_BOUNDED = new Permits() {
        /** Round-robin cursor shared by every thread using this strategy. */
        final AtomicLong counter = new AtomicLong();

        @Override
        public int permits(TimedSemaphore ts) {
            final int slot = (int) (counter.incrementAndGet() % 3);
            final Permits next;
            if (slot == 0) {
                next = PERMITS_ALL;
            } else if (slot == 1) {
                next = PERMITS_ONE;
            } else if (slot == 2) {
                next = PERMITS_RANDOM;
            } else {
                throw new AssertionError();
            }
            return next.permits(ts);
        }

        @Override
        public String toString() {
            return "PERMITS_INCREMENT_BOUNDED";
        }
    };

    /**
     * Requests more than one period's worth of permits: k whole multiples of
     * the limit (k cycling through 1, 2, 3) plus a random remainder in
     * {@code [0, limit)}. The result is therefore in {@code [limit, 4 * limit - 1]}.
     */
    static final Permits PERMITS_MULITPLE = new Permits() {
        final Random random = new Random();
        /** Cycles the whole-period multiplier through 1, 2, 3. */
        final AtomicLong counter = new AtomicLong();

        @Override
        public int permits(TimedSemaphore ts) {
            final int limit = ts.limit;
            int permits = ((int) (counter.incrementAndGet() % 3) + 1) * limit;
            permits += random.nextInt(limit);
            return permits;
        }

        @Override
        public String toString() {
            // Must match the field name (spelling included) so failure reports
            // identify this strategy; it previously duplicated another strategy's name.
            return "PERMITS_MULITPLE";
        }
    };

    /**
     * Returns a {@link Permits} strategy that always requests exactly
     * {@code permits} permits.
     *
     * @throws IllegalArgumentException if {@code permits} is not positive
     */
    static Permits constant(int permits) {
        final Permits fixed = new PermitConstant(permits);
        return fixed;
    }

    /** {@link Permits} strategy that always returns the same positive count. */
    @ToString
    static final class PermitConstant implements Permits {
        /** The count returned by every call to {@link #permits(TimedSemaphore)}. */
        final int permits;

        public PermitConstant(int permits) {
            checkArgument(permits >= 1);
            this.permits = permits;
        }

        @Override
        public int permits(TimedSemaphore ignored) {
            return this.permits;
        }
    }

    /** Every acquire strategy exercised by {@link #testAll()}. */
    static final List<Acquire> ACQUIRES = Arrays.asList(
            ACQUIRE,
            ACQUIRE_TRY_IMMEDIATE,
            ACQUIRE_TRY_TIMEOUT,
            ACQUIRE_INCREMENT);

    /** Permit-count strategies that never request more than the limit. */
    static final List<Permits> PERMITS_BOUNDED = Arrays.asList(
            PERMITS_ONE,
            PERMITS_ALL,
            PERMITS_RANDOM,
            PERMITS_INCREMENT_BOUNDED);

    /** Permit-count strategies, including ones that request more than the limit. */
    static final List<Permits> PERMITS_UNBOUNDED = Arrays.asList(
            PERMITS_ONE,
            PERMITS_ALL,
            PERMITS_RANDOM,
            PERMITS_MULITPLE,
            PERMITS_INCREMENT_UNBOUNDED);

    /**
     * Reference model of the rate limit, used to check that acquisitions made
     * through a {@link TimedSemaphore} never exceed {@code limit} permits per
     * period.
     * <p>
     * Time is split into consecutive frames of {@code periodNanos} starting at
     * the semaphore's {@code frame} when this history was created.
     * {@code permits[i]} counts the permits accounted to frame {@code i}. Not
     * thread-safe; callers must synchronize externally.
     */
    @ToString(exclude = "permits")
    static class History {
        final long periodNanos;
        final int limit;
        final long initFrame;
        /** Permits reserved per frame; grown on demand by {@link #reserve}. */
        int[] permits = new int[1024 * 16];

        public History(TimedSemaphore ts) {
            this.periodNanos = ts.periodNanos;
            this.limit = ts.getLimit();
            this.initFrame = ts.frame;
            // sanity check: the raw periodNanos field agrees with getPeriod()/getUnit()
            assertEquals(nanos(ts), periodNanos);
        }

        /**
         * Attempts to acquire/reserve count permits. start is the time that
         * the acquisition began. now is the current time. Permits are taken
         * from the earliest frames between start and now that still have
         * capacity. This method must be externally synchronized.
         *
         * @return if all count permits can be acquired
         */
        boolean reserve(long start, int count, long now) {
            assertTrue(count > 0);
            int pos = Ints.checkedCast((start - initFrame) / periodNanos);
            final int max = Ints.checkedCast((now - initFrame) / periodNanos);
            int acquired = 0;
            for (; acquired < count && pos <= max; pos++) {
                if (pos >= permits.length) {
                    growToInclude(pos, start, count, now);
                }

                int remaining = limit - permits[pos];
                if (remaining > 0) {
                    int take = Math.min(remaining, count - acquired);
                    permits[pos] += take;
                    acquired += take;
                }
            }
            return acquired == count;
        }

        /**
         * Grows {@link #permits} by repeated doubling until {@code index} is a
         * valid position. A single doubling is not enough when {@code index} is
         * already past twice the current length (e.g. when {@code start} is many
         * frames after {@code initFrame}); that previously caused an
         * ArrayIndexOutOfBoundsException.
         */
        private void growToInclude(int index, long start, int count, long now) {
            long newLength = Math.max(1L, permits.length);
            do {
                newLength *= 2;
            } while (newLength <= index);
            System.out.printf("exapanding to %d; start: %d, count: %d, now: %d; this: %s\n",
                    newLength, start, count, now, this);
            permits = Arrays.copyOf(permits, Ints.checkedCast(newLength));
        }
    }

    /**
     * One worker of a {@code runTest} run. Once every worker has started, each
     * performs {@code acquisitions} rounds. In each round it picks a permit
     * count, retries the acquire strategy until it reports success, and then
     * records the acquisition in the shared {@link History} while holding
     * {@code lock}.
     * <p>
     * A failure (an exception, or an acquisition the history rejects) is stored
     * in {@link #fail} and signalled through the shared {@link #stop} flag. Every
     * worker checks that flag after each round.
     */
    @ToString
    static class DelayTest implements Runnable {
        /** Released when this worker finishes, whether it succeeded or not. */
        final CountDownLatch complete = new CountDownLatch(1);
        final int totalThreads;
        /** Shared by the workers of one run; set once any worker fails. */
        final AtomicBoolean stop;
        /**
         * Failure description, or null. Written only by this worker's thread;
         * other threads may read it after awaiting {@link #complete}, which
         * provides the happens-before edge.
         */
        String fail;
        /** Guards against running the same instance twice. */
        boolean started;
        /** Shared start gate so that all workers begin acquiring together. */
        final CountDownLatch start;
        final int thread;
        /** Serializes updates to {@link #tests} and {@link #history}. */
        final Lock lock;
        final int acquisitions;

        final Ticker ticker;
        final TimedSemaphore ts;
        final Acquire acquire;
        final Permits permits;
        /** Count of recorded acquisitions; guarded by {@link #lock}. */
        final MutableInt tests;
        final History history;

        @Builder
        public DelayTest(int totalThreads, CountDownLatch start, int thread, Lock lock, int acquisitions,
                Ticker ticker, TimedSemaphore ts, Acquire acquire, Permits permits, MutableInt tests,
                AtomicBoolean stop, History history) {
            this.totalThreads = totalThreads;
            this.start = checkNotNull(start);
            this.thread = thread;
            this.lock = checkNotNull(lock);
            checkArgument(acquisitions > 0);
            this.acquisitions = acquisitions;
            this.ts = checkNotNull(ts);
            this.ticker = checkNotNull(ticker);
            assertSame(ticker, ts.getTicker());
            this.acquire = checkNotNull(acquire);
            this.permits = checkNotNull(permits);
            this.tests = checkNotNull(tests);
            this.stop = checkNotNull(stop);
            this.history = checkNotNull(history);
        }

        synchronized void checkStart() {
            checkState(!started);
            started = true;
        }

        @Override
        @SuppressWarnings({ "BroadCatchBlock", "TooBroadCatch" })
        public void run() {
            try {
                tryRun();
            } catch (Throwable t) {
                // Record anything, including assertion errors, and stop the other workers.
                fail = this + "\n" + Throwables.getStackTraceAsString(t);
                stop.set(true);
            } finally {
                complete.countDown();
            }
        }

        /** Body of {@link #run()}; any throwable it raises is recorded by the caller. */
        public void tryRun() throws InterruptedException {
            checkStart();
            // Announce this worker, then wait until every worker has done the same.
            start.countDown();
            Uninterruptibles.awaitUninterruptibly(start);

            for (int i = 0; i < acquisitions; i++) {
                final int p = permits.permits(ts);
                final long startTime = ticker.read();
                while (!acquire.acquire(ts, p)) {
                    //continue until acquired
                }
                final long acquireTime = ticker.read();
                boolean success;
                lock.lock();
                try {
                    tests.increment();
                    // The acquisition must fit in the frames between startTime and acquireTime.
                    success = history.reserve(startTime, p, acquireTime);
                } finally {
                    lock.unlock();
                }
                if (!success) {
                    fail = String.format("TS: %s\np: %s, startTime: %d, acquireTime: %d\nthis: %s\n", ts, p,
                            startTime, acquireTime, this);
                    stop.set(true);
                }
                if (stop.get()) {
                    return;
                }
            }
        }
    }

    /**
     * Runs {@code threadCount} {@link DelayTest} workers against {@code ts} on
     * {@code executor} and waits for all of them. Fails if any worker reported
     * a failure, or if the number of recorded acquisitions is not
     * {@code threadCount * acquisitions}. All workers share one start latch,
     * lock, stop flag, acquisition counter and {@link History}.
     */
    @SuppressWarnings("UnnecessaryUnboxing")
    static void runTest(Executor executor, TimedSemaphore ts, Ticker ticker, int threadCount, int acquisitions,
            Acquire acquire, Permits permits) throws InterruptedException {
        checkArgument(threadCount > 0);

        final DelayTest.DelayTestBuilder builder = DelayTest.builder()
                .stop(new AtomicBoolean())
                .start(new CountDownLatch(threadCount))
                .lock(new ReentrantLock())
                .acquisitions(acquisitions)
                .ticker(ticker)
                .ts(ts)
                .acquire(acquire)
                .permits(permits)
                .tests(new MutableInt())
                .totalThreads(threadCount)
                .history(new History(ts));

        final DelayTest[] workers = new DelayTest[threadCount];
        for (int t = 0; t < threadCount; t++) {
            final DelayTest worker = builder.thread(t).build();
            workers[t] = worker;
            executor.execute(worker);
        }
        for (DelayTest worker : workers) {
            worker.complete.await();
        }

        // Print every failure, but report the last one.
        String lastFailure = null;
        for (DelayTest worker : workers) {
            if (worker.fail == null) {
                continue;
            }
            lastFailure = worker.fail;
            System.out.println(lastFailure);
            System.out.println();
        }
        if (lastFailure != null) {
            fail(lastFailure);
        }

        final int expectedTests = threadCount * acquisitions;
        assertEquals(expectedTests, builder.tests.getValue().intValue());
        testCount++;
        if (testCount % 10 == 0) {
            System.out.printf("%d, Tests: %s\n", testCount, builder.tests);
        }
    }

    /**
     * Stress-tests every combination of thread count (1, 2, 3, 4, 6, 8),
     * permit limit (1-6, and at most threads + 3), acquire strategy from
     * {@link #ACQUIRES}, and each strategy's compatible permit-count strategies.
     * Each run uses {@code A} acquisitions per thread and a 20 ms period.
     * <p>
     * It is not normally run (its {@code @Test} annotation is commented out)
     * because it takes a long time: about 30 minutes on the author's machine
     * with only the 20 ms period enabled.
     */
    //    @Test
    public void testAll() throws InterruptedException {
        // The pool must have at least as many threads as the largest thread
        // count below. runTest's workers wait on a shared start latch until all
        // of them are running, so a smaller pool would deadlock.
        final ExecutorService executor = Executors.newFixedThreadPool(8, ThreadUtils.daemonThreadFactory());
        try {
            final int A = 10; // acquisitions per thread
            // Longer periods (e.g. 50000000 and 100000000 ns) can be added here
            // at the cost of a much longer run.
            for (int period : Arrays.asList(20000000)) {
                for (int threads : Arrays.asList(1, 2, 3, 4, 6, 8)) {
                    for (int limit = 1; limit <= 6 && limit <= threads + 3; limit++) {
                        for (Acquire acquire : ACQUIRES) {
                            for (Permits permits : acquire.permits()) {
                                runTest(executor, new TimedSemaphore(limit, period, TimeUnit.NANOSECONDS),
                                        Ticker.systemTicker(), threads, A, acquire, permits);
                            } //permits
                        } //acquires
                    } //limit
                } //threads
            } //period
        } finally {
            // Release the pool threads instead of leaking them for the rest of
            // the JVM's life. shutdownNow also interrupts any workers still
            // running if a run was aborted.
            executor.shutdownNow();
        }

        System.out.println("stolen permits: " + stolen.get());
    }

}