WorkQueue.java Source code

Java tutorial

Introduction

Here is the source code for WorkQueue.java

Source

/*
 * Copyright 2011 David Jurgens
 *
 * This file is part of the S-Space package and is covered under the terms and
 * conditions therein.
 *
 * The S-Space package is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License version 2 as published
 * by the Free Software Foundation and distributed hereunder to you.
 *
 * THIS SOFTWARE IS PROVIDED "AS IS" AND NO REPRESENTATIONS OR WARRANTIES,
 * EXPRESS OR IMPLIED ARE MADE.  BY WAY OF EXAMPLE, BUT NOT LIMITATION, WE MAKE
 * NO REPRESENTATIONS OR WARRANTIES OF MERCHANT- ABILITY OR FITNESS FOR ANY
 * PARTICULAR PURPOSE OR THAT THE USE OF THE LICENSED SOFTWARE OR DOCUMENTATION
 * WILL NOT INFRINGE ANY THIRD PARTY PATENTS, COPYRIGHTS, TRADEMARKS OR OTHER
 * RIGHTS.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program. If not, see <http://www.gnu.org/licenses/>.
 */

//package edu.ucla.sspace.util;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Queue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;

/**
 * A utility class that receives a collection of tasks to execute internally and
 * then distributes the tasks among a thread pool.  This class offers two methods
 * of use.  In the first, a user can pass in a collection of tasks to run and
 * then wait until the tasks are finished.
 *<pre>
 *Collection<Runnable> tasks = new LinkedList<Runnable>();
 *WorkQueue q = new WorkQueue();
 *for (int i = 0; i < numTasks; ++i)
 *    tasks.add(new Runnable() { }); // job to do goes here
 *q.run(tasks);
 *</pre>
 * <br>
 *
 * Alternately, a user may register a task group identifier and then iteratively
 * add new tasks associated with that identifier.  At some point in the future,
 * the user can then wait for all the tasks associated with that identifier to
 * finish.  This second method allows for the iterative construction of tasks,
 * or for cases where not all of the data for the tasks is available at once
 * (although the number of tasks is known).
 *<pre>
 *WorkQueue q = new WorkQueue();
 *Object taskGroupId = Thread.currentThread(); // a unique id
 *q.registerTaskGroup(taskGroupId, numTasks);
 *for (int i = 0; i < numTasks; ++i)
 *    q.add(taskGroupId, new Runnable() { }); // job to do goes here
 *q.await(taskGroupId);
 *</pre>
 *
 * In the above example, the current thread is used as the group identifier,
 * which ensures that any other thread executing the same code won't use the
 * same identifier, which could result in either thread returning prematurely
 * before its tasks have finished.  However, a <i>shared</i> group identifier
 * can allow multiple threads to add tasks for a common goal, with each being
 * able to await until all the tasks are finished.
 *
 * @author David Jurgens
 */
public class WorkQueue {

    /**
     * The worker threads that draw tasks from {@link #workQueue}.  The list is
     * filled once by the constructor and is not modified afterwards.
     */
    private final List<Thread> threads;

    /**
     * The shared queue of pending tasks from which the worker threads dequeue.
     */
    private final BlockingQueue<Runnable> workQueue;

    /**
     * A mapping from a group identifier to the latch that counts down as the
     * tasks added on behalf of that group finish.
     */
    private final ConcurrentMap<Object, CountDownLatch> taskKeyToLatch;

    /**
     * Creates a new work queue with the number of threads executing tasks the
     * same as the number of processors on the system.
     */
    public WorkQueue() {
        this(Runtime.getRuntime().availableProcessors());
    }

    /**
     * Creates a new work queue with the specified number of threads executing
     * tasks.  The worker threads are started immediately.
     *
     * @param numThreads the number of worker threads to start.  If this value
     *        is not positive no workers are started, so submitted tasks will
     *        never be executed.
     */
    public WorkQueue(int numThreads) {
        workQueue = new LinkedBlockingQueue<Runnable>();
        threads = new ArrayList<Thread>();
        taskKeyToLatch = new ConcurrentHashMap<Object, CountDownLatch>();
        for (int i = 0; i < numThreads; ++i) {
            Thread t = new WorkerThread(workQueue);
            threads.add(t);
            t.start();
        }
    }

    /**
     * Adds the provided task to the work queue on behalf of the task group
     * identifier.  Note that unlike the {@link #run(Collection) run} method,
     * this method returns immediately without waiting for the task to finish.
     *
     * @param taskGroupId an identifier associated with a set of tasks.
     * @param task a task to run
     *
     * @throws IllegalArgumentException if the {@code taskGroupId} is not
     *         currently associated with any active taskGroup
     */
    public void add(Object taskGroupId, Runnable task) {
        CountDownLatch latch = taskKeyToLatch.get(taskGroupId);
        if (latch == null) {
            throw new IllegalArgumentException("Unknown task id: " + taskGroupId);
        }
        workQueue.offer(new CountingRunnable(task, latch));
    }

    /**
     * Waits until all the tasks associated with the group identifier have
     * finished.  Once a task group has been successfully waited upon, the group
     * identifier is removed from the queue and is valid to be reused for a new
     * task group.
     *
     * @param taskGroupId the identifier of the task group to wait for
     *
     * @throws IllegalArgumentException if the {@code taskGroupId} is not
     *         currently associated with any active taskGroup
     * @throws IllegalStateException if interrupted while waiting for the tasks
     *         to finish; the calling thread's interrupt status is restored
     *         before the exception is thrown
     */
    public void await(Object taskGroupId) {
        CountDownLatch latch = taskKeyToLatch.get(taskGroupId);
        if (latch == null) {
            throw new IllegalArgumentException("Unknown task group: " + taskGroupId);
        }
        try {
            latch.await();
        } catch (InterruptedException ie) {
            // Keep the interrupt visible to callers further up the stack
            Thread.currentThread().interrupt();
            throw new IllegalStateException("Not all tasks finished", ie);
        }
        // Once finished, remove the key so it can be associated with a new
        // task group.  The two-argument remove only deletes the mapping if it
        // still refers to the latch that was waited on, so a group that another
        // thread has already re-registered under the same identifier is left
        // intact.
        taskKeyToLatch.remove(taskGroupId, latch);
    }

    /**
     * Waits until all the tasks associated with the group identifier have
     * finished or the timeout elapses, whichever comes first.  Once a task
     * group has been successfully waited upon, the group identifier is removed
     * from the queue and is valid to be reused for a new task group.  If the
     * timeout elapses first, the group remains registered.
     *
     * @param taskGroupId the identifier of the task group to wait for
     * @param timeout the maximum time to wait
     * @param unit the time unit of {@code timeout}
     *
     * @return {@code true} if all the tasks finished before the timeout
     *         elapsed, {@code false} otherwise
     *
     * @throws IllegalArgumentException if the {@code taskGroupId} is not
     *         currently associated with any active taskGroup
     * @throws IllegalStateException if interrupted while waiting for the tasks
     *         to finish; the calling thread's interrupt status is restored
     *         before the exception is thrown
     */
    public boolean await(Object taskGroupId, long timeout, TimeUnit unit) {
        CountDownLatch latch = taskKeyToLatch.get(taskGroupId);
        if (latch == null) {
            throw new IllegalArgumentException("Unknown task group: " + taskGroupId);
        }
        boolean finished;
        try {
            finished = latch.await(timeout, unit);
        } catch (InterruptedException ie) {
            // Keep the interrupt visible to callers further up the stack
            Thread.currentThread().interrupt();
            throw new IllegalStateException("Not all tasks finished", ie);
        }
        if (finished) {
            // Only remove the mapping if it still refers to the latch that was
            // waited on (see await(Object))
            taskKeyToLatch.remove(taskGroupId, latch);
        }
        return finished;
    }

    /**
     * Registers a new task group with the specified number of tasks to execute and
     * returns a task group identifier to use when registering its tasks.
     *
     * @param numTasks the number of tasks that will be eventually run as a part
     *        of this group.
     *
     * @return a newly created identifier associated with the group of tasks
     *
     * @throws IllegalArgumentException if {@code numTasks} is negative
     */
    public Object registerTaskGroup(int numTasks) {
        // A fresh Object cannot collide with any existing key
        Object key = new Object();
        taskKeyToLatch.putIfAbsent(key, new CountDownLatch(numTasks));
        return key;
    }

    /**
     * Registers a new task group with the specified number of tasks to execute,
     * or returns {@code false} if a task group with the same identifier has
     * already been registered.  This identifier will remain valid in the queue
     * until {@link #await(Object) await} has been called.
     *
     * @param taskGroupId an identifier to be associated with a group of tasks
     * @param numTasks the number of tasks that will be eventually run as a part
     *        of this group.
     *
     * @return {@code true} if a new task group was registered or {@code false}
     *         if a task group with the same identifier had already been
     *         registered.
     *
     * @throws IllegalArgumentException if {@code numTasks} is negative
     */
    public boolean registerTaskGroup(Object taskGroupId, int numTasks) {
        return taskKeyToLatch.putIfAbsent(taskGroupId, new CountDownLatch(numTasks)) == null;
    }

    /**
     * Executes the tasks using a thread pool and returns once all tasks have
     * finished.
     *
     * @param tasks the tasks to execute
     *
     * @throws IllegalStateException if interrupted while waiting for the tasks
     *         to finish
     */
    public void run(Runnable... tasks) {
        run(Arrays.asList(tasks));
    }

    /**
     * Executes the tasks using a thread pool and returns once all tasks have
     * finished.
     *
     * @param tasks the tasks to execute
     *
     * @throws IllegalStateException if interrupted while waiting for the tasks
     *         to finish; the calling thread's interrupt status is restored
     *         before the exception is thrown.  Tasks that were already
     *         enqueued are not withdrawn.
     */
    public void run(Collection<Runnable> tasks) {
        // Create a latch that each wrapped runnable counts down when it ends
        int numTasks = tasks.size();
        CountDownLatch latch = new CountDownLatch(numTasks);
        for (Runnable r : tasks) {
            workQueue.offer(new CountingRunnable(r, latch));
        }
        try {
            // Wait until all the tasks have finished
            latch.await();
        } catch (InterruptedException ie) {
            // Keep the interrupt visible to callers further up the stack
            Thread.currentThread().interrupt();
            throw new IllegalStateException("Not all tasks finished", ie);
        }
    }

    /**
     * Returns the number of threads being used to process the enqueued tasks.
     */
    public int numThreads() {
        return threads.size();
    }

    /**
     * A utility class that wraps an existing runnable and updates the latch
     * when the task has finished.
     */
    private static class CountingRunnable implements Runnable {

        /**
         * The task to execute
         */
        private final Runnable task;

        /**
         * The latch to update once the task has finished
         */
        private final CountDownLatch latch;

        public CountingRunnable(Runnable task, CountDownLatch latch) {
            this.task = task;
            this.latch = latch;
        }

        /**
         * Executes the task and counts down once finished.  The count down
         * happens in a {@code finally} block, so the latch is also updated
         * when the task throws.
         */
        @Override
        public void run() {
            try {
                task.run();
            } finally {
                latch.countDown();
            }
        }
    }
}

/**
 * A daemon thread that continuously dequeues {@code Runnable} instances from a
 * queue and executes them.  This class is intended to be used with a {@link
 * java.util.concurrent.Semaphore Semaphore}, whereby work is added to the
 * queue and the semaphore indicates when processing has finished.
 *
 * @author David Jurgens
 */
class WorkerThread extends Thread {

    /**
     * A static counter used to give each instance of this class a distinct
     * thread name.  Guarded by the {@code WorkerThread.class} monitor.
     */
    private static int threadInstanceCount;

    /**
     * The queue from which work items will be taken
     */
    private final BlockingQueue<Runnable> workQueue;

    /**
     * An internal queue that holds thread-local tasks.  This queue is intended
     * to hold multiple tasks to avoid thread contention on the work queue.
     */
    private final Queue<Runnable> internalQueue;

    /**
     * The number of items that should be queued to run by this thread at once.
     */
    private final int threadLocalItems;

    /**
     * Creates a thread that continuously dequeues one item from the {@code
     * workQueue} at a time and executes it.
     *
     * @param workQueue the queue from which work items will be taken
     */
    public WorkerThread(BlockingQueue<Runnable> workQueue) {
        this(workQueue, 1);
    }

    /**
     * Creates a daemon thread that continuously dequeues {@code
     * threadLocalItems} from {@code workQueue} at once and executes them
     * sequentially.
     *
     * @param workQueue the queue from which work items will be taken
     * @param threadLocalItems the number of items this thread should dequeue
     *        from the work queue at one time.  Setting this value too high can
     *        result in a loss of concurrency; setting it too low can result in
     *        high contention on the work queue if the time per task is also
     *        low.
     */
    public WorkerThread(BlockingQueue<Runnable> workQueue, int threadLocalItems) {
        this.workQueue = workQueue;
        this.threadLocalItems = threadLocalItems;
        internalQueue = new ArrayDeque<Runnable>();
        setDaemon(true);
        synchronized (WorkerThread.class) {
            setName("WorkerThread-" + (threadInstanceCount++));
        }
    }

    /**
     * Continuously dequeues {@code Runnable} instances from the work queue and
     * executes them.
     *
     * <p>A task that throws does not terminate this thread: the throwable is
     * handed to this thread's uncaught exception handler and the thread moves
     * on to the next task.  Otherwise a single failing task would permanently
     * remove a worker and drop any tasks it had already dequeued.
     *
     * <p>The loop ends when the thread is interrupted while waiting for work;
     * the interrupt status is left set when the method returns.
     */
    @Override
    public void run() {
        while (true) {
            // Try to drain the maximum capacity of thread-local items, checking
            // whether any were available
            if (workQueue.drainTo(internalQueue, threadLocalItems) == 0) {
                // block until a work item is available
                try {
                    internalQueue.offer(workQueue.take());
                } catch (InterruptedException ie) {
                    // Interruption is the shutdown signal.  The internal queue
                    // is empty at this point, so no dequeued work is lost.
                    Thread.currentThread().interrupt();
                    return;
                }
            }
            // Execute all of the thread-local items
            Runnable r;
            while ((r = internalQueue.poll()) != null) {
                try {
                    r.run();
                } catch (Throwable t) {
                    // Report the failure the same way an uncaught exception
                    // would be reported, but keep the worker alive
                    getUncaughtExceptionHandler().uncaughtException(this, t);
                }
            }
        }
    }
}
/*
 * Copyright 2011 David Jurgens
 *
 * This file is part of the S-Space package and is covered under the terms and
 * conditions therein.
 *
 * The S-Space package is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License version 2 as published
 * by the Free Software Foundation and distributed hereunder to you.
 *
 * THIS SOFTWARE IS PROVIDED "AS IS" AND NO REPRESENTATIONS OR WARRANTIES,
 * EXPRESS OR IMPLIED ARE MADE.  BY WAY OF EXAMPLE, BUT NOT LIMITATION, WE MAKE
 * NO REPRESENTATIONS OR WARRANTIES OF MERCHANT- ABILITY OR FITNESS FOR ANY
 * PARTICULAR PURPOSE OR THAT THE USE OF THE LICENSED SOFTWARE OR DOCUMENTATION
 * WILL NOT INFRINGE ANY THIRD PARTY PATENTS, COPYRIGHTS, TRADEMARKS OR OTHER
 * RIGHTS.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program. If not, see <http://www.gnu.org/licenses/>.
 */
/*
package edu.ucla.sspace.util;
    
import java.util.Arrays;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;
    
import java.util.concurrent.*;
import java.util.concurrent.atomic.*;
    
import org.junit.Ignore;
import org.junit.Test;
    
import static org.junit.Assert.*;
    
    
// * A collection of unit tests for {@link WorkQueue} 
public class WorkQueueTests {
    
@Test public void testSingleOp() {
    WorkQueue w = new WorkQueue(4);
    final AtomicInteger j = new AtomicInteger();
    w.run(new Runnable() {
            public void run() {
                System.out.println(j.incrementAndGet());
            }
        });
    assertEquals(1, j.get());
}
    
@Test public void testSingleThreadedOp() {
    WorkQueue w = new WorkQueue(1);
    final AtomicInteger j = new AtomicInteger();
    w.run(new Runnable() {
            public void run() {
                j.incrementAndGet();
            }
        });
    assertEquals(1, j.get());
}
    
    
@Test public void testMultiple() {
    WorkQueue w = new WorkQueue(4);
    final AtomicInteger j = new AtomicInteger();
    Collection<Runnable> c = new ArrayList<Runnable>();
    for (int i = 0; i < 100; ++i)
        c.add(new Runnable() {
            public void run() {
                j.incrementAndGet();
            }
        });
    w.run(c);
    assertEquals(100, j.get());
}
    
@Test public void testSingleThreadMultipleOps() {
    WorkQueue w = new WorkQueue(1);
    final AtomicInteger j = new AtomicInteger();
    Collection<Runnable> c = new ArrayList<Runnable>();
    for (int i = 0; i < 100; ++i)
        c.add(new Runnable() {
            public void run() {
                j.incrementAndGet();
            }
        });
    w.run(c);
    assertEquals(100, j.get());
}
    
}
*/