RandomChooser.java Source code

Java tutorial

Introduction

Here is the source code for RandomChooser.java

Source

//package gr.forth.ics.util;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;
import java.util.Random;

/**
 * A random choice maker, where each choice is associated with a probability. This implementation
 * is based on the fast <em>alias</em> method, where for each random choice two random
 * numbers are generated and only a single table lookup performed; building the table takes
 * time linear in the number of choices.
 *
 * <p>Use {@link #newInstance()} to obtain a {@link RandomChooserBuilder}, register the choices
 * with their relative probabilities, then call {@link RandomChooserBuilder#build()}.
 *
 * @param <T> the type of the choices to be made
 * @see <a href="http://cg.scs.carleton.ca/~luc/rnbookindex.html">L. Devroye, Non-Uniform Random Variate Generation, 1986, p. 107</a>
 * @author  Andreou Dimitris, email: jim.andreou (at) gmail (dot) com 
 */
public class RandomChooser<T> {
    /**
     * Acceptance threshold of each slot: slot {@code i} yields {@code events.get(i)} when the
     * uniform draw is below {@code probs[i]} (every value lies in {@code [0, 1]} after construction).
     */
    private final double[] probs;
    /** Alias of each slot: slot {@code i} yields {@code events.get(indexes[i])} otherwise. */
    private final int[] indexes;
    private final List<T> events;

    private final Random random;

    private RandomChooser(List<Double> weights, List<T> events, Random random) {
        this.probs = normalize(weights);
        this.indexes = buildAliasTable(probs);
        this.events = events;
        this.random = random;
    }

    /**
     * Scales the weights so that their average is {@code 1.0}, i.e. entry {@code i} becomes
     * {@code weights[i] * n / sum(weights)}.
     *
     * @param weights non-negative, finite weights (validated by the builder)
     * @return the scaled weights
     * @throws IllegalStateException if every weight is zero
     */
    private static double[] normalize(List<Double> weights) {
        int n = weights.size();
        // Divide by the largest weight first: every term of the sum below is then in [0, 1],
        // so the sum cannot overflow to infinity even for very large finite weights.
        double max = 0.0;
        for (double w : weights) {
            max = Math.max(max, w);
        }
        if (max == 0.0) {
            throw new IllegalStateException("All probabilities are zero");
        }
        double sum = 0.0;
        for (double w : weights) {
            sum += w / max;
        }
        double[] scaled = new double[n];
        for (int i = 0; i < n; i++) {
            scaled[i] = weights.get(i) / max / sum * n; // average = 1.0
        }
        return scaled;
    }

    /**
     * Turns normalized weights (average {@code 1.0}) into an alias table, updating {@code probs}
     * in place (Vose's construction).
     *
     * <p>Each under-full slot (value below 1) is topped up to 1 by an over-full slot. That slot
     * becomes its alias and loses the donated mass. On return every {@code probs[i]} lies in
     * {@code [0, 1]}, and the returned array holds the alias of each slot.
     *
     * @param probs the normalized weights; overwritten with the acceptance thresholds
     * @return the alias index of every slot
     */
    private static int[] buildAliasTable(double[] probs) {
        int n = probs.length;
        int[] aliases = new int[n];
        Deque<Integer> smaller = new ArrayDeque<Integer>(n / 2 + 2);
        Deque<Integer> greater = new ArrayDeque<Integer>(n / 2 + 2);
        for (int i = 0; i < n; i++) {
            if (probs[i] < 1.0) {
                smaller.push(i);
            } else {
                greater.push(i);
            }
        }
        while (!smaller.isEmpty() && !greater.isEmpty()) {
            int small = smaller.pop();
            int large = greater.pop();
            aliases[small] = large;
            // The large slot donates exactly the mass needed to fill the small slot up to 1.0.
            probs[large] -= 1.0 - probs[small];
            if (probs[large] < 1.0) {
                smaller.push(large);
            } else {
                greater.push(large);
            }
        }
        // In exact arithmetic both lists empty out together, and every leftover slot sits at
        // exactly 1.0. Rounding can leave slots behind (in either list) whose value drifted
        // slightly from 1.0. Pin them so they always yield their own event rather than falling
        // through to an alias that was never assigned.
        for (int slot : greater) {
            probs[slot] = 1.0;
            aliases[slot] = slot;
        }
        for (int slot : smaller) {
            probs[slot] = 1.0;
            aliases[slot] = slot;
        }
        return aliases;
    }

    /**
     * Returns a random choice.
     *
     * @return a random choice
     * @see RandomChooserBuilder about how to configure the available choices
     */
    public T choose() {
        int index = random.nextInt(probs.length);
        double x = random.nextDouble();
        return x < probs[index] ? events.get(index) : events.get(indexes[index]);
    }

    /**
     * Creates a builder of a {@link RandomChooser} instance. The builder is responsible
     * for configuring the choices and probabilities of the random chooser.
     *
     * @param <T> the type of the choices that will be randomly made
     * @return a builder of a {@code RandomChooser} object
     */
    public static <T> RandomChooserBuilder<T> newInstance() {
        return new RandomChooserBuilder<T>();
    }

    /**
     * A builder of {@link RandomChooser}.
     * 
     * @param <T> the type of the choices that the created {@code RandomChooser} will make
     */
    public static class RandomChooserBuilder<T> {
        private final List<Double> probs = new ArrayList<Double>();
        private final List<T> events = new ArrayList<T>();
        // Default generator has a fixed seed (0), so unless setRandom is called the sequence of
        // choices is reproducible. Choosers built by the same builder share this instance.
        private Random random = new Random(0);

        private RandomChooserBuilder() {
        }

        /**
         * Adds the possibility of a given choice, weighted by a relative probability.
         * (Relative means that it is not needed that all probabilities have sum {@code 1.0}).
         *
         * @param choice a possible choice; must not be {@code null}
         * @param prob the relative probability of the choice; must be finite and {@code >= 0}
         * @return this
         * @throws IllegalArgumentException if {@code prob} is negative, NaN or infinite, or if
         *         {@code choice} is {@code null}
         */
        public RandomChooserBuilder<T> choice(T choice, double prob) {
            Args.gte(prob, 0.0);
            // An infinite weight would turn every normalized probability into NaN.
            Args.lte(prob, Double.MAX_VALUE);
            Args.notNull(choice);
            probs.add(prob);
            events.add(choice);
            return this;
        }

        /**
         * Specifies the random number generator to be used by the created {@link RandomChooser}.
         *
         * @param random the random number generator to use; must not be {@code null}
         * @return this
         * @throws IllegalArgumentException if {@code random} is {@code null}
         */
        public RandomChooserBuilder<T> setRandom(Random random) {
            Args.notNull(random);
            this.random = random;
            return this;
        }

        /**
         * Builds a {@link RandomChooser} instance, ready to make random choices based on the
         * probabilities configured by this builder. The choices are copied, so later calls to
         * {@link #choice(Object, double)} do not affect the returned instance.
         *
         * @return a {@code RandomChooser}
         * @throws IllegalStateException if no choice was defined, or if all probabilities are zero
         */
        public RandomChooser<T> build() {
            if (probs.isEmpty()) {
                throw new IllegalStateException("No choice was defined");
            }
            return new RandomChooser<T>(new ArrayList<Double>(probs), new ArrayList<T>(events), random);
        }
    }
}

/**
 * Static argument-validation helpers.
 *
 * <p>Unless stated otherwise, a failed check throws {@link IllegalArgumentException}. The message
 * has the form {@code "<value> must be <relation> <bound>"} for comparisons and equality, and
 * {@code "<name> is null"}, {@code "<name> is empty"} or {@code "<name> contains null"} for the
 * null and emptiness checks. {@link #isTrue(String, boolean)}, {@link #isTrue(boolean)} and
 * {@link #check(boolean, String)} throw a plain {@link RuntimeException} instead.
 *
 * <p>Range checks are named by the kind of their lower and upper bound, in that order:
 * {@code I} = inclusive, {@code E} = exclusive. For example, {@code inRangeIE} checks
 * {@code from <= value < to}.
 */
class Args {
    private static final String GT = " must be greater than ";
    private static final String GTE = " must be greater or equal to ";
    private static final String LT = " must be less than ";
    private static final String LTE = " must be less or equal to ";

    private static final String EQUALS = " must be equal to ";

    // ---- Null, emptiness and boolean checks ----

    /**
     * Checks that {@code iterable} is not {@code null} and contains no {@code null} element.
     * Performs the same check, with the same messages, as {@link #hasNoNull(Iterable)}.
     */
    public static void doesNotContainNull(Iterable<?> iterable) {
        hasNoNull(iterable);
    }

    /** Throws a {@link RuntimeException} with message "Condition failed" unless {@code condition}. */
    public static void isTrue(boolean condition) {
        isTrue("Condition failed", condition);
    }

    /** Throws a {@link RuntimeException} with message {@code msg} unless {@code condition}. */
    public static void isTrue(String msg, boolean condition) {
        if (!condition) {
            throw new RuntimeException(msg);
        }
    }

    /** Checks that {@code o} is not {@code null} (failure message: "Argument is null"). */
    public static void notNull(Object o) {
        notNull(null, o);
    }

    /**
     * Checks that {@code o} is not {@code null}.
     *
     * @param arg name used in the failure message {@code "<arg> is null"}; {@code "Argument"} if null
     * @param o the value to check
     */
    public static void notNull(String arg, Object o) {
        if (arg == null) {
            arg = "Argument";
        }
        if (o == null) {
            throw new IllegalArgumentException(arg + " is null");
        }
    }

    /** Checks that no element of {@code args} is {@code null} (message: "Some argument is null"). */
    public static void notNull(Object... args) {
        for (Object o : args) {
            notNull("Some argument", o);
        }
    }

    /**
     * Checks that no element of {@code args} is {@code null}.
     *
     * @param message name used in the failure message; {@code "Some argument"} if null
     * @param args the values to check
     */
    public static void notNull(String message, Object... args) {
        String name = (message == null) ? "Some argument" : message;
        for (Object o : args) {
            notNull(name, o);
        }
    }

    /** Checks that {@code iter} is not {@code null} and yields at least one element. */
    public static void notEmpty(Iterable<?> iter) {
        notEmpty(null, iter);
    }

    /**
     * Checks that {@code iter} is not {@code null} and yields at least one element.
     *
     * @param arg name used in the failure message {@code "<arg> is empty"}; {@code "Iterable"} if null
     * @param iter the iterable to check
     */
    public static void notEmpty(String arg, Iterable<?> iter) {
        if (arg == null) {
            arg = "Iterable";
        }
        notNull(iter);
        if (iter.iterator().hasNext())
            return;
        throw new IllegalArgumentException(arg + " is empty");
    }

    /** Checks that {@code iter} is not {@code null} and contains no {@code null} element. */
    public static void hasNoNull(Iterable<?> iter) {
        hasNoNull(null, iter);
    }

    /**
     * Checks that {@code iter} is not {@code null} and contains no {@code null} element.
     *
     * @param arg name used in the failure message {@code "<arg> contains null"};
     *        {@code "Iterable"} if null
     * @param iter the iterable to check
     */
    public static void hasNoNull(String arg, Iterable<?> iter) {
        notNull(iter);
        if (arg == null) {
            arg = "Iterable";
        }
        for (Object o : iter) {
            if (o == null) {
                throw new IllegalArgumentException(arg + " contains null");
            }
        }
    }

    // ---- Equality checks: fail with "<value> must be equal to <expected>" ----

    public static void equals(int value, int expected) {
        if (value == expected)
            return;
        throw new IllegalArgumentException(value + EQUALS + expected);
    }

    public static void equals(long value, long expected) {
        if (value == expected)
            return;
        throw new IllegalArgumentException(value + EQUALS + expected);
    }

    public static void equals(double value, double expected) {
        if (value == expected)
            return;
        throw new IllegalArgumentException(value + EQUALS + expected);
    }

    public static void equals(float value, float expected) {
        if (value == expected)
            return;
        throw new IllegalArgumentException(value + EQUALS + expected);
    }

    public static void equals(char value, char expected) {
        if (value == expected)
            return;
        throw new IllegalArgumentException(value + EQUALS + expected);
    }

    public static void equals(short value, short expected) {
        if (value == expected)
            return;
        throw new IllegalArgumentException(value + EQUALS + expected);
    }

    public static void equals(byte value, byte expected) {
        if (value == expected)
            return;
        throw new IllegalArgumentException(value + EQUALS + expected);
    }

    /**
     * Checks that {@code value} equals {@code expected} according to {@link Object#equals}.
     * {@code null} is equal only to {@code null}; a {@code null} value never causes a
     * {@link NullPointerException}.
     */
    public static void equals(Object value, Object expected) {
        if (value == expected || (value != null && value.equals(expected)))
            return;
        throw new IllegalArgumentException(value + EQUALS + expected);
    }

    // ---- Ordering checks: gt (>), lt (<), gte (>=), lte (<=) against the bound "from" ----

    public static void gt(int value, int from) {
        if (value > from)
            return;
        throw new IllegalArgumentException(value + GT + from);
    }

    public static void lt(int value, int from) {
        if (value < from)
            return;
        throw new IllegalArgumentException(value + LT + from);
    }

    public static void gte(int value, int from) {
        if (value >= from)
            return;
        throw new IllegalArgumentException(value + GTE + from);
    }

    public static void lte(int value, int from) {
        if (value <= from)
            return;
        throw new IllegalArgumentException(value + LTE + from);
    }

    public static void gt(long value, long from) {
        if (value > from)
            return;
        throw new IllegalArgumentException(value + GT + from);
    }

    public static void lt(long value, long from) {
        if (value < from)
            return;
        throw new IllegalArgumentException(value + LT + from);
    }

    public static void gte(long value, long from) {
        if (value >= from)
            return;
        throw new IllegalArgumentException(value + GTE + from);
    }

    public static void lte(long value, long from) {
        if (value <= from)
            return;
        throw new IllegalArgumentException(value + LTE + from);
    }

    public static void gt(short value, short from) {
        if (value > from)
            return;
        throw new IllegalArgumentException(value + GT + from);
    }

    public static void lt(short value, short from) {
        if (value < from)
            return;
        throw new IllegalArgumentException(value + LT + from);
    }

    public static void gte(short value, short from) {
        if (value >= from)
            return;
        throw new IllegalArgumentException(value + GTE + from);
    }

    public static void lte(short value, short from) {
        if (value <= from)
            return;
        throw new IllegalArgumentException(value + LTE + from);
    }

    public static void gt(byte value, byte from) {
        if (value > from)
            return;
        throw new IllegalArgumentException(value + GT + from);
    }

    public static void lt(byte value, byte from) {
        if (value < from)
            return;
        throw new IllegalArgumentException(value + LT + from);
    }

    public static void gte(byte value, byte from) {
        if (value >= from)
            return;
        throw new IllegalArgumentException(value + GTE + from);
    }

    public static void lte(byte value, byte from) {
        if (value <= from)
            return;
        throw new IllegalArgumentException(value + LTE + from);
    }

    // char overloads print the characters themselves (not their code points) in messages.

    public static void gt(char value, char from) {
        if (value > from)
            return;
        throw new IllegalArgumentException(value + GT + from);
    }

    public static void lt(char value, char from) {
        if (value < from)
            return;
        throw new IllegalArgumentException(value + LT + from);
    }

    public static void gte(char value, char from) {
        if (value >= from)
            return;
        throw new IllegalArgumentException(value + GTE + from);
    }

    public static void lte(char value, char from) {
        if (value <= from)
            return;
        throw new IllegalArgumentException(value + LTE + from);
    }

    // Floating-point overloads: any comparison involving NaN is false, so NaN always fails.

    public static void gt(double value, double from) {
        if (value > from)
            return;
        throw new IllegalArgumentException(value + GT + from);
    }

    public static void lt(double value, double from) {
        if (value < from)
            return;
        throw new IllegalArgumentException(value + LT + from);
    }

    public static void gte(double value, double from) {
        if (value >= from)
            return;
        throw new IllegalArgumentException(value + GTE + from);
    }

    public static void lte(double value, double from) {
        if (value <= from)
            return;
        throw new IllegalArgumentException(value + LTE + from);
    }

    public static void gt(float value, float from) {
        if (value > from)
            return;
        throw new IllegalArgumentException(value + GT + from);
    }

    public static void lt(float value, float from) {
        if (value < from)
            return;
        throw new IllegalArgumentException(value + LT + from);
    }

    public static void gte(float value, float from) {
        if (value >= from)
            return;
        throw new IllegalArgumentException(value + GTE + from);
    }

    public static void lte(float value, float from) {
        if (value <= from)
            return;
        throw new IllegalArgumentException(value + LTE + from);
    }

    // Comparable overloads: ordering is decided by c1.compareTo(c2).

    public static <T> void gt(Comparable<T> c1, T c2) {
        if (c1.compareTo(c2) > 0)
            return;
        throw new IllegalArgumentException(c1 + GT + c2);
    }

    public static <T> void lt(Comparable<T> c1, T c2) {
        if (c1.compareTo(c2) < 0)
            return;
        throw new IllegalArgumentException(c1 + LT + c2);
    }

    public static <T> void gte(Comparable<T> c1, T c2) {
        if (c1.compareTo(c2) >= 0)
            return;
        throw new IllegalArgumentException(c1 + GTE + c2);
    }

    public static <T> void lte(Comparable<T> c1, T c2) {
        if (c1.compareTo(c2) <= 0)
            return;
        throw new IllegalArgumentException(c1 + LTE + c2);
    }

    // ---- Range checks: suffix = (lower, upper) bound kind, I = inclusive, E = exclusive ----

    /** Checks {@code from <= value <= to}. */
    public static <T> void inRangeII(Comparable<T> value, T from, T to) {
        gte(value, from);
        lte(value, to);
    }

    /** Checks {@code from < value < to}. */
    public static <T> void inRangeEE(Comparable<T> value, T from, T to) {
        gt(value, from);
        lt(value, to);
    }

    /** Checks {@code from <= value < to}. */
    public static <T> void inRangeIE(Comparable<T> value, T from, T to) {
        // Inclusive lower bound: gte, matching the primitive inRangeIE overloads.
        gte(value, from);
        lt(value, to);
    }

    /** Checks {@code from < value <= to}. */
    public static <T> void inRangeEI(Comparable<T> value, T from, T to) {
        gt(value, from);
        lte(value, to);
    }

    public static void inRangeII(int value, int from, int to) {
        gte(value, from);
        lte(value, to);
    }

    public static void inRangeEE(int value, int from, int to) {
        gt(value, from);
        lt(value, to);
    }

    public static void inRangeIE(int value, int from, int to) {
        gte(value, from);
        lt(value, to);
    }

    public static void inRangeEI(int value, int from, int to) {
        gt(value, from);
        lte(value, to);
    }

    public static void inRangeII(long value, long from, long to) {
        gte(value, from);
        lte(value, to);
    }

    public static void inRangeEE(long value, long from, long to) {
        gt(value, from);
        lt(value, to);
    }

    public static void inRangeIE(long value, long from, long to) {
        gte(value, from);
        lt(value, to);
    }

    public static void inRangeEI(long value, long from, long to) {
        gt(value, from);
        lte(value, to);
    }

    public static void inRangeII(short value, short from, short to) {
        gte(value, from);
        lte(value, to);
    }

    public static void inRangeEE(short value, short from, short to) {
        gt(value, from);
        lt(value, to);
    }

    public static void inRangeIE(short value, short from, short to) {
        gte(value, from);
        lt(value, to);
    }

    public static void inRangeEI(short value, short from, short to) {
        gt(value, from);
        lte(value, to);
    }

    public static void inRangeII(byte value, byte from, byte to) {
        gte(value, from);
        lte(value, to);
    }

    public static void inRangeEE(byte value, byte from, byte to) {
        gt(value, from);
        lt(value, to);
    }

    public static void inRangeIE(byte value, byte from, byte to) {
        gte(value, from);
        lt(value, to);
    }

    public static void inRangeEI(byte value, byte from, byte to) {
        gt(value, from);
        lte(value, to);
    }

    /** Throws a {@link RuntimeException} with {@code messageIfFailed} unless {@code assertion}. */
    public static void check(boolean assertion, String messageIfFailed) {
        if (!assertion) {
            throw new RuntimeException(messageIfFailed);
        }
    }
}