com.analog.lyric.dimple.solvers.gibbs.samplers.generic.SuwaTodoSampler.java Source code

Java tutorial

Introduction

Here is the source code for com.analog.lyric.dimple.solvers.gibbs.samplers.generic.SuwaTodoSampler.java.

Source

/*******************************************************************************
*   Copyright 2013 Analog Devices, Inc.
*
*   Licensed under the Apache License, Version 2.0 (the "License");
*   you may not use this file except in compliance with the License.
*   You may obtain a copy of the License at
*
*       http://www.apache.org/licenses/LICENSE-2.0
*
*   Unless required by applicable law or agreed to in writing, software
*   distributed under the License is distributed on an "AS IS" BASIS,
*   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*   See the License for the specific language governing permissions and
*   limitations under the License.
********************************************************************************/

package com.analog.lyric.dimple.solvers.gibbs.samplers.generic;

import static com.analog.lyric.dimple.environment.DimpleEnvironment.*;

import org.apache.commons.math3.random.RandomGenerator;

import com.analog.lyric.collect.ArrayUtil;
import com.analog.lyric.dimple.model.domains.DiscreteDomain;
import com.analog.lyric.dimple.model.domains.Domain;
import com.analog.lyric.dimple.model.values.DiscreteValue;
import com.analog.lyric.math.Utilities;

/**
 * Discrete direct sampler using a Suwa-Todo style circularly shifted update.
 * <p>
 * The unnormalized probabilities {@code exp(minEnergy - energy[i])} of the domain values are laid
 * end to end on a circle of circumference equal to their sum. A point is drawn uniformly from the
 * previous value's interval, shifted forward by the largest interval (which is 1 on this scale,
 * assuming {@code minEnergy} is the minimum of the energies), and wrapped around the circle. The
 * value whose interval contains that point becomes the next sample.
 * <p>
 * For a two-element domain the update reduces to a Metropolis-Hastings style flip, which is
 * special-cased for speed.
 */
public class SuwaTodoSampler extends AbstractGenericSampler implements IDiscreteDirectSampler {
    /** Scratch buffer for cumulative weights; allocated with {@link #_lengthRoundedUp} entries. */
    protected double[] _samplerScratch = ArrayUtil.EMPTY_DOUBLE_ARRAY;
    /**
     * Domain size as rounded by {@code Utilities.nextPow2} (presumably the smallest power of two
     * that is at least the size — TODO confirm). The bitwise binary search relies on this.
     */
    protected int _lengthRoundedUp = 0;
    /** Domain size recorded by the most recent {@link #initialize(Domain)} call. */
    protected int _length = 0;

    /**
     * Records the size of the discrete domain and allocates the cumulative-weight scratch buffer.
     *
     * @param variableDomain the variable's domain; must be a {@link DiscreteDomain}
     */
    @Override
    public void initialize(Domain variableDomain) {
        final int length = ((DiscreteDomain) variableDomain).size();
        _length = length;
        _lengthRoundedUp = Utilities.nextPow2(length);
        _samplerScratch = new double[_lengthRoundedUp];
    }

    /**
     * Draws the next sample index and passes it to {@code samplerClient}.
     *
     * @param sampleValue the current value; supplies the previous index and the domain size
     * @param energy per-value energies; may be longer than the domain, extra entries are ignored
     * @param minEnergy energy offset subtracted before exponentiation (expected to be the minimum
     *        energy, so that the largest unnormalized weight is 1)
     * @param samplerClient receives the selected index via {@code setNextSampleIndex}
     */
    @Override
    public void nextSample(DiscreteValue sampleValue, double[] energy, double minEnergy,
            IDiscreteSamplerClient samplerClient) {
        final RandomGenerator rand = activeRandom();
        final int length = sampleValue.getDomain().size(); // energy may be longer than domain size
        final int previousIndex = sampleValue.getIndex();

        final int sampleIndex;
        if (length == 2) {
            // Special-case length 2 for speed; this case is equivalent to MH.
            sampleIndex = sampleTwoValues(previousIndex, energy, minEnergy, rand);
        } else {
            sampleIndex = sampleCircularShift(length, previousIndex, energy, minEnergy, rand);
        }

        samplerClient.setNextSampleIndex(sampleIndex);
    }

    /** Metropolis-Hastings style flip between the two values of a size-2 domain. */
    private static int sampleTwoValues(int previousIndex, double[] energy, double minEnergy,
            RandomGenerator rand) {
        final double pdf0 = Math.exp(minEnergy - energy[0]);
        final double pdf1 = Math.exp(minEnergy - energy[1]);
        final boolean wasZero = previousIndex == 0;
        final int stayIndex = wasZero ? 0 : 1;
        final int flipIndex = wasZero ? 1 : 0;
        final double rejectProb = wasZero ? pdf0 - pdf1 : pdf1 - pdf0;

        if (rejectProb < 0) {
            // The other value is at least as likely: always flip, without consuming a random draw.
            return flipIndex;
        }
        return (rand.nextDouble() < rejectProb) ? stayIndex : flipIndex;
    }

    /**
     * General case: samples from the previous value's interval, circularly shifted by the largest
     * interval, using a binary search over the cumulative weights.
     */
    private int sampleCircularShift(int length, int previousIndex, double[] energy, double minEnergy,
            RandomGenerator rand) {
        // cumulative[m] holds the (unnormalized) sum of the first m weights, so value i owns the
        // interval [cumulative[i], cumulative[i] + weight_i).
        final double[] cumulative = _samplerScratch;
        double sum = 0;
        double previousIntervalValue = 0;
        cumulative[0] = 0;
        for (int m = 1; m < length; m++) {
            final int mm1 = m - 1;
            final double weight = Math.exp(minEnergy - energy[mm1]);
            if (mm1 == previousIndex) {
                previousIntervalValue = weight;
            }
            sum += weight;
            cumulative[m] = sum;
        }
        final int lm1 = length - 1;
        final double lastWeight = Math.exp(minEnergy - energy[lm1]);
        if (previousIndex == lm1) {
            previousIntervalValue = lastWeight;
        }
        sum += lastWeight;

        // Pad past the domain with +infinity so the binary search never selects those slots.
        for (int m = length; m < _lengthRoundedUp; m++) {
            cumulative[m] = Double.POSITIVE_INFINITY;
        }

        // Sample from the previous value's interval, circularly shifted by the largest interval.
        // In this scale, the largest interval is always 1.
        double randomValue = cumulative[previousIndex] + 1 + previousIntervalValue * rand.nextDouble();
        randomValue = randomValue % sum; // Circularly wrap

        // Bitwise binary search for the largest index whose cumulative start is below randomValue.
        int sampleIndex = 0;
        for (int bit = _lengthRoundedUp >> 1; bit > 0; bit >>= 1) {
            final int probe = sampleIndex | bit;
            if (randomValue > cumulative[probe]) {
                sampleIndex = probe;
            }
        }
        return sampleIndex;
    }
}