com.analog.lyric.dimple.solvers.gibbs.samplers.generic.CDFSampler.java Source code

Java tutorial

Introduction

Here is the source code for com.analog.lyric.dimple.solvers.gibbs.samplers.generic.CDFSampler.java

Source

/*******************************************************************************
*   Copyright 2013 Analog Devices, Inc.
*
*   Licensed under the Apache License, Version 2.0 (the "License");
*   you may not use this file except in compliance with the License.
*   You may obtain a copy of the License at
*
*       http://www.apache.org/licenses/LICENSE-2.0
*
*   Unless required by applicable law or agreed to in writing, software
*   distributed under the License is distributed on an "AS IS" BASIS,
*   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*   See the License for the specific language governing permissions and
*   limitations under the License.
********************************************************************************/

package com.analog.lyric.dimple.solvers.gibbs.samplers.generic;

import static com.analog.lyric.dimple.environment.DimpleEnvironment.*;

import org.apache.commons.math3.random.RandomGenerator;

import com.analog.lyric.collect.ArrayUtil;
import com.analog.lyric.dimple.exceptions.DimpleException;
import com.analog.lyric.dimple.model.domains.DiscreteDomain;
import com.analog.lyric.dimple.model.domains.Domain;
import com.analog.lyric.dimple.model.values.DiscreteValue;
import com.analog.lyric.math.Utilities;

public class CDFSampler extends AbstractGenericSampler implements IDiscreteDirectSampler {
    protected double[] _samplerScratch = ArrayUtil.EMPTY_DOUBLE_ARRAY;
    protected int _lengthRoundedUp = 0;
    protected int _length = 0;

    @Override
    public void initialize(Domain variableDomain) {
        int length = ((DiscreteDomain) variableDomain).size();
        _length = length;
        _lengthRoundedUp = Utilities.nextPow2(length);
        _samplerScratch = new double[_lengthRoundedUp];
    }

    @Override
    public void nextSample(DiscreteValue sampleValue, double[] energy, double minEnergy,
            IDiscreteSamplerClient samplerClient) {
        final RandomGenerator rand = activeRandom();
        final int length = sampleValue.getDomain().size(); //energy may be longer than domain size
        int sampleIndex;

        // Special-case lengths 2, 3, and 4 for speed
        switch (length) {
        case 2: {
            sampleIndex = (rand.nextDouble() * (1 + Math.exp(energy[1] - energy[0])) > 1) ? 0 : 1;
            break;
        }
        case 3: {
            final double cumulative1 = Math.exp(minEnergy - energy[0]);
            final double cumulative2 = cumulative1 + Math.exp(minEnergy - energy[1]);
            final double sum = cumulative2 + Math.exp(minEnergy - energy[2]);
            final double randomValue = sum * rand.nextDouble();
            sampleIndex = (randomValue > cumulative2) ? 2 : (randomValue > cumulative1) ? 1 : 0;
            break;
        }
        case 4: {
            final double cumulative1 = Math.exp(minEnergy - energy[0]);
            final double cumulative2 = cumulative1 + Math.exp(minEnergy - energy[1]);
            final double cumulative3 = cumulative2 + Math.exp(minEnergy - energy[2]);
            final double sum = cumulative3 + Math.exp(minEnergy - energy[3]);
            final double randomValue = sum * rand.nextDouble();
            sampleIndex = (randomValue > cumulative2) ? ((randomValue > cumulative3) ? 3 : 2)
                    : ((randomValue > cumulative1) ? 1 : 0);
            break;
        }
        default: // For all other lengths
        {
            // Calculate cumulative conditional probability (unnormalized)
            double sum = 0;
            final double[] samplerScratch = _samplerScratch;
            samplerScratch[0] = 0;
            for (int m = 1; m < length; m++) {
                sum += expApprox(minEnergy - energy[m - 1]);
                samplerScratch[m] = sum;
            }
            sum += expApprox(minEnergy - energy[length - 1]);
            for (int m = length; m < _lengthRoundedUp; m++)
                samplerScratch[m] = Double.POSITIVE_INFINITY;

            final int half = _lengthRoundedUp >> 1;
            while (true) {
                // Sample from the distribution using a binary search.
                final double randomValue = sum * rand.nextDouble();
                sampleIndex = 0;
                for (int bitValue = half; bitValue > 0; bitValue >>= 1) {
                    final int testIndex = sampleIndex | bitValue;
                    if (randomValue > samplerScratch[testIndex])
                        sampleIndex = testIndex;
                }

                // Rejection sampling, since the approximation of the exponential function is so coarse
                final double logp = minEnergy - energy[sampleIndex];
                if (Double.isNaN(logp))
                    throw new DimpleException(
                            "The energy for all values of this variable is infinite. This may indicate a state inconsistent with the model.");
                if (rand.nextDouble() * expApprox(logp) <= Math.exp(logp))
                    break;
            }
        }
        }

        samplerClient.setNextSampleIndex(sampleIndex);
    }

    // This is an approximation to the exponential function; inputs must be non-positive
    // To facilitate subsequent rejection sampling, the error versus the correct exponential function needs to be always positive
    // This is true except for very large negative inputs, for values just as the output approaches zero
    // To ensure rejection is never in an infinite loop, this must reach 0 for large negative inputs before the Math.exp function does
    public final static double expApprox(double value) {
        // Convert input to base2 log, then convert integer part into IEEE754 exponent
        final long expValue = (long) ((int) (1512775.395195186 * value) + 0x3FF00000) << 32; // 1512775.395195186 = 2^20/log(2)
        return Double.longBitsToDouble(expValue & ~(expValue >> 63)); // Clip result if negative and convert to a double
    }
}