com.analog.lyric.dimple.test.solvers.gibbs.TestGibbsConjugateSampling.java Source code

Java tutorial

Introduction

Here is the source code for com.analog.lyric.dimple.test.solvers.gibbs.TestGibbsConjugateSampling.java

Source

/*******************************************************************************
*   Copyright 2015 Analog Devices, Inc.
*
*   Licensed under the Apache License, Version 2.0 (the "License");
*   you may not use this file except in compliance with the License.
*   You may obtain a copy of the License at
*
*       http://www.apache.org/licenses/LICENSE-2.0
*
*   Unless required by applicable law or agreed to in writing, software
*   distributed under the License is distributed on an "AS IS" BASIS,
*   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*   See the License for the specific language governing permissions and
*   limitations under the License.
********************************************************************************/

package com.analog.lyric.dimple.test.solvers.gibbs;

import static com.analog.lyric.dimple.model.sugar.ModelSyntacticSugar.*;
import static java.util.Objects.*;
import static org.junit.Assert.*;

import java.util.Random;

import org.apache.commons.math3.stat.StatUtils;
import org.junit.Test;

import com.analog.lyric.dimple.factorfunctions.Normal;
import com.analog.lyric.dimple.model.core.FactorGraph;
import com.analog.lyric.dimple.model.sugar.ModelSyntacticSugar.CurrentModel;
import com.analog.lyric.dimple.model.variables.Real;
import com.analog.lyric.dimple.options.DimpleOptions;
import com.analog.lyric.dimple.solvers.gibbs.GibbsOptions;
import com.analog.lyric.dimple.solvers.gibbs.GibbsReal;
import com.analog.lyric.dimple.solvers.gibbs.GibbsSolver;
import com.analog.lyric.dimple.solvers.gibbs.GibbsSolverGraph;
import com.analog.lyric.dimple.solvers.gibbs.samplers.generic.MHSampler;
import com.analog.lyric.dimple.solvers.gibbs.samplers.generic.SliceSampler;
import com.analog.lyric.dimple.test.DimpleTestBase;

/**
 * Checks Gibbs sampling of a single real-valued "mean" variable that carries a
 * {@link Normal} prior and is tied to a set of data values through a {@link Normal}
 * factor. The sample mean and standard deviation are compared against the values
 * computed in closed form from the prior and the data, first with the sampler chosen
 * by default and then with the slice and MH samplers requested explicitly.
 *
 * @since 0.08
 * @author Christopher Barber
 */
public class TestGibbsConjugateSampling extends DimpleTestBase {
    @Test
    public void test3() {
        // Adapted from MATLAB testConjugateSampling.m/test3

        final double priorMean = 3.0;
        final double priorPrecision = 0.01;
        final double dataMean = 10;
        final double dataPrecision = 0.001;
        final int numDatapoints = 100;

        // Seeded generator so the data values are the same on every run.
        // NOTE(review): the Gaussian noise is scaled by dataPrecision itself; the expected
        // posterior below is derived from the actual sum of the data, so the assertions do
        // not depend on how the data was spread.
        final Random generator = new Random(42);
        final double[] observations = new double[numDatapoints];
        double observationTotal = 0.0;
        for (int index = 0; index < observations.length; index++) {
            final double value = dataMean + generator.nextGaussian() * dataPrecision;
            observations[index] = value;
            observationTotal += value;
        }

        // Expected posterior parameters computed from the prior and the data.
        final double expectedPrecision = priorPrecision + numDatapoints * dataPrecision;
        final double expectedStd = 1 / Math.sqrt(expectedPrecision);
        final double expectedMean =
            (priorMean * priorPrecision + observationTotal * dataPrecision) / expectedPrecision;

        // Build the model: one "mean" variable with a Normal prior, connected to the
        // "x" variables created from the data through a Normal factor.
        final FactorGraph graph = new FactorGraph();
        final Real meanVariable;
        final Real[] dataVariables;
        try (CurrentModel scope = using(graph)) {
            meanVariable = real("mean");
            meanVariable.setPrior(new Normal(priorMean, priorPrecision));
            dataVariables = fixed("x", observations);
            addFactor(new Normal(), meanVariable, dataPrecision, dataVariables);
        }

        graph.setOption(DimpleOptions.randomSeed, 1L);
        graph.setOption(GibbsOptions.numSamples, 1000);
        graph.setOption(GibbsOptions.saveAllSamples, true);
        graph.setOption(GibbsOptions.saveAllScores, true);

        final GibbsSolverGraph solverGraph = requireNonNull(graph.setSolverFactory(new GibbsSolver()));
        final GibbsReal solverMean = solverGraph.getReal(meanVariable);

        // Pass 1: no sampler requested; the test expects NormalSampler to be reported.
        assertEquals("NormalSampler", solverMean.getSamplerName());
        graph.solve();
        assertSampleStatistics(solverMean.getAllSamples(), expectedMean, 0.01, expectedStd, .05);

        // Pass 2: try again with slice sampler
        meanVariable.setOption(GibbsOptions.realSampler, SliceSampler.class);
        graph.setOption(GibbsOptions.numSamples, 2000);
        graph.solve();

        final double[] sliceSamples = solverMean.getAllSamples();
        assertEquals("SliceSampler", solverMean.getSamplerName());
        // The option was set on the mean variable only; the first data variable is
        // expected to still report NormalSampler.
        assertEquals("NormalSampler", solverGraph.getReal(dataVariables[0]).getSamplerName());
        assertSampleStatistics(sliceSamples, expectedMean, 0.1, expectedStd, 0.1);

        // Pass 3: try again with MH
        meanVariable.setOption(GibbsOptions.realSampler, MHSampler.class);
        graph.setOption(GibbsOptions.scansPerSample, 10);
        graph.setOption(GibbsOptions.numSamples, 1000);
        graph.solve();

        final double[] mhSamples = solverMean.getAllSamples();
        assertEquals("MHSampler", solverMean.getSamplerName());
        assertEquals("NormalSampler", solverGraph.getReal(dataVariables[0]).getSamplerName());
        assertSampleStatistics(mhSamples, expectedMean, 0.25, expectedStd, 0.25);
    }

    /**
     * Asserts that the mean and the standard deviation (square root of the variance
     * reported by {@link StatUtils}) of {@code samples} lie within the given tolerances
     * of the expected values. The mean is checked first, then the standard deviation.
     *
     * @param samples sample values to summarize
     * @param expectedMean expected sample mean
     * @param meanTolerance allowed absolute deviation of the sample mean
     * @param expectedStd expected sample standard deviation
     * @param stdTolerance allowed absolute deviation of the sample standard deviation
     */
    private static void assertSampleStatistics(double[] samples, double expectedMean, double meanTolerance,
        double expectedStd, double stdTolerance) {
        assertEquals(expectedMean, StatUtils.mean(samples), meanTolerance);
        assertEquals(expectedStd, Math.sqrt(StatUtils.variance(samples)), stdTolerance);
    }
}