com.analog.lyric.dimple.test.model.TestJunctionTreeTransform.java Source code

Java tutorial

Introduction

Below is the source code for com.analog.lyric.dimple.test.model.TestJunctionTreeTransform.java, a JUnit test suite for Dimple's JunctionTreeTransform.

Source

/*******************************************************************************
*   Copyright 2014-2015 Analog Devices, Inc.
*
*   Licensed under the Apache License, Version 2.0 (the "License");
*   you may not use this file except in compliance with the License.
*   You may obtain a copy of the License at
*
*       http://www.apache.org/licenses/LICENSE-2.0
*
*   Unless required by applicable law or agreed to in writing, software
*   distributed under the License is distributed on an "AS IS" BASIS,
*   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*   See the License for the specific language governing permissions and
*   limitations under the License.
********************************************************************************/

package com.analog.lyric.dimple.test.model;

import static org.junit.Assert.*;

import org.apache.commons.math3.stat.descriptive.moment.Variance;
import org.eclipse.jdt.annotation.Nullable;
import org.junit.Test;

import com.analog.lyric.dimple.model.core.FactorGraph;
import com.analog.lyric.dimple.model.domains.DiscreteDomain;
import com.analog.lyric.dimple.model.factors.Factor;
import com.analog.lyric.dimple.model.transform.JunctionTreeTransform;
import com.analog.lyric.dimple.model.transform.JunctionTreeTransformMap;
import com.analog.lyric.dimple.model.transform.JunctionTreeTransformMap.AddedJointVariable;
import com.analog.lyric.dimple.model.values.Value;
import com.analog.lyric.dimple.model.variables.Discrete;
import com.analog.lyric.dimple.model.variables.Variable;
import com.analog.lyric.dimple.model.variables.VariableList;
import com.analog.lyric.dimple.solvers.gibbs.GibbsSolver;
import com.analog.lyric.dimple.solvers.gibbs.GibbsSolverGraph;
import com.analog.lyric.dimple.solvers.gibbs.ISolverVariableGibbs;
import com.analog.lyric.dimple.test.DimpleTestBase;
import com.analog.lyric.util.misc.Misc;

/**
 * Tests for {@link JunctionTreeTransform}
 */
public class TestJunctionTreeTransform extends DimpleTestBase {
    private final RandomGraphGenerator _graphGenerator = new RandomGraphGenerator(testRand);

    /*
     * Discrete domains used to vary variable cardinality in generated graphs. The numeric suffix
     * is the domain size: dN holds the N integers 0 .. N-1 (assuming DiscreteDomain.range bounds
     * are inclusive, as the d2..d4 definitions imply).
     */
    private static final DiscreteDomain d2 = DiscreteDomain.range(0, 1);
    private static final DiscreteDomain d3 = DiscreteDomain.range(0, 2);
    private static final DiscreteDomain d4 = DiscreteDomain.range(0, 3);
    private static final DiscreteDomain d5 = DiscreteDomain.range(0, 4);

    /** Upper bound on random draws when searching for a discrete variable to condition on. */
    private static final int MAX_CONDITIONING_ATTEMPTS = 100000;

    @Test
    public void testTrivialLoop() {
        testGraph(_graphGenerator.buildTrivialLoop());
    }

    @Test
    public void testTriangle() {
        testGraph(_graphGenerator.buildTriangle());
    }

    @Test
    public void testGrid2() {
        testGraph(_graphGenerator.buildGrid(2));
    }

    @Test
    public void testGrid3() {
        testGraph(_graphGenerator.domains(d2, d3, d4).buildGrid(3));
    }

    @Test
    public void testGrid4() {
        testGraph(_graphGenerator.buildGrid(4));
    }

    @Test
    public void testGrid2by20() {
        testGraph(_graphGenerator.domains(d2, d3, d5).buildGrid(2, 20));
    }

    @Test
    public void testGrid1by100() {
        FactorGraph model = _graphGenerator.domains(d2, d3, d4).buildGrid(1, 100);
        assertTrue(model.isTree());
        testTree(model);
    }

    @Test
    public void testRandomGraphs() {
        final int nGraphs = 20;
        final int maxSize = 1000;
        RandomGraphGenerator gen = _graphGenerator.maxBranches(2).maxTreeWidth(5);

        for (int i = 0; i < nGraphs; ++i) {
            // Random graphs may or may not already be trees, so no identity expectation.
            testGraph(gen.buildRandomGraph(testRand.nextInt(maxSize) + 10), null);
        }
    }

    @Test
    public void testRandomTree() {
        FactorGraph tree = _graphGenerator.maxBranches(5).domains(d2, d3, d4, d5).buildRandomTree(500);
        assertTrue(tree.isTree());
        testTree(tree);
    }

    /**
     * @see RandomGraphGenerator#buildStudentNetwork()
     */
    @Test
    public void testStudentNetwork() {
        testGraph(_graphGenerator.buildStudentNetwork());
    }

    /*-----------------
     * Helper methods
     */

    /**
     * Assert that source and target graphs in {@code transformMap} represent the same
     * joint distribution, up to a constant energy offset.
     * <p>
     * Draws {@code nSamples} Gibbs samples on the source graph, copies each sample into the
     * target graph (also updating the added joint variables from their inputs), and records
     * the difference of the two sample scores. The assertion is that the variance of those
     * differences is (numerically) zero, so the scores may differ by a constant but must
     * otherwise agree. Does nothing if the transform is the identity.
     * 
     * @param transformMap the result of a {@link JunctionTreeTransform#transform} call
     */
    @SuppressWarnings("null")
    private void assertModelsEquivalent(JunctionTreeTransformMap transformMap) {
        if (transformMap.isIdentity()) {
            return;
        }

        final FactorGraph source = transformMap.source();
        final FactorGraph target = transformMap.target();

        GibbsSolver gibbs = new GibbsSolver();
        GibbsSolverGraph sourceGibbs = source.setSolverFactory(gibbs);
        GibbsSolverGraph targetGibbs = target.setSolverFactory(gibbs);
        targetGibbs.initialize();

        final int nSamples = 100;

        final double[] differences = new double[nSamples];
        for (int n = 0; n < nSamples; ++n) {
            // Generate a sample on the source graph
            source.solve();

            // Copy sample values to new graph
            for (Variable sourceVar : source.getVariables()) {
                Variable targetVar = transformMap.sourceToTargetVariable(sourceVar);

                ISolverVariableGibbs sourceSVar = sourceGibbs.getSolverVariable(sourceVar);
                ISolverVariableGibbs targetSVar = targetGibbs.getSolverVariable(targetVar);

                targetSVar.setCurrentSample(sourceSVar.getCurrentSampleValue());
            }

            // Update values of added variables
            for (AddedJointVariable<?> added : transformMap.addedJointVariables()) {
                final ISolverVariableGibbs addedSVar = targetGibbs.getSolverVariable(added.getVariable());
                final Value value = addedSVar.getCurrentSampleValue();
                final Value[] inputs = new Value[added.getInputCount()];
                for (int i = inputs.length; --i >= 0;) {
                    final Variable inputVar = added.getInput(i);
                    final ISolverVariableGibbs inputSVar = targetGibbs.getSolverVariable(inputVar);
                    inputs[i] = inputSVar.getCurrentSampleValue();
                }

                added.updateValue(value, inputs);
            }

            // Compare the joint likelihoods
            final double sourceEnergy = sourceGibbs.getSampleScore();
            final double targetEnergy = targetGibbs.getSampleScore();
            final double difference = sourceEnergy - targetEnergy;
            if (Math.abs(difference) > 1e-10) {
                // Debugging hook only: a nonzero but constant difference is acceptable,
                // since the real check below is on the variance of the differences.
                Misc.breakpoint();
            }
            differences[n] = difference;
        }

        double variance = new Variance().evaluate(differences);
        assertEquals(0.0, variance, 1e-10);
    }

    /** Tests a graph whose junction tree transform is expected to be non-identity. */
    private void testGraph(FactorGraph model) {
        testGraph(model, false);
    }

    /** Tests a tree-structured graph, whose junction tree transform is expected to be identity. */
    private void testTree(FactorGraph model) {
        testGraph(model, true);
    }

    /**
     * Runs {@link #testGraphImpl}, and on any failure reports the random seed so the failing
     * case can be reproduced.
     * 
     * @param model graph to transform
     * @param expectIdentity if non-null, the expected value of
     *        {@link JunctionTreeTransformMap#isIdentity()} for the unconditioned transform
     */
    private void testGraph(FactorGraph model, @Nullable Boolean expectIdentity) {
        try {
            testGraphImpl(model, expectIdentity);
        } catch (Throwable ex) {
            String msg = String.format("%s. TestJunctionTreeTransform._seed==%dL", ex.toString(),
                    testRand.getSeed());
            ex.printStackTrace(System.err);
            System.err.format(">>> TestJunctionTreeTransform._seed==%dL;<<<\n", testRand.getSeed());
            throw new RuntimeException(msg, ex);
        }
    }

    /**
     * Transforms {@code model} without and then with conditioning, checking each result.
     * <p>
     * Before the conditioned pass, a prior is set on a randomly chosen discrete variable
     * of {@code model} (this mutates {@code model}).
     */
    private void testGraphImpl(FactorGraph model, @Nullable Boolean expectIdentity) {
        JunctionTreeTransform jt = new JunctionTreeTransform().random(testRand);
        assertSame(testRand, jt.random());
        assertFalse(jt.useConditioning());

        JunctionTreeTransformMap transformMap = jt.transform(model);

        if (expectIdentity != null) {
            assertEquals(expectIdentity, transformMap.isIdentity());
        }
        if (transformMap.isIdentity()) {
            assertTrue(model.isForest());
        }
        checkTransformTarget(transformMap);

        // Try with conditioning
        model.setSolverFactory(null);
        conditionRandomDiscreteVariable(model);
        jt.useConditioning(true);
        assertTrue(jt.useConditioning());
        checkTransformTarget(jt.transform(model));
    }

    /**
     * Labels the target's factors (as a debugging aid), asserts the target is a forest, and
     * asserts the source and target models are equivalent.
     */
    private void checkTransformTarget(JunctionTreeTransformMap transformMap) {
        final FactorGraph target = transformMap.target();
        for (Factor factor : target.getFactors()) {
            // Name target factors as a debugging aid
            RandomGraphGenerator.labelFactor(factor);
        }
        assertTrue(target.isForest());
        assertModelsEquivalent(transformMap);
    }

    /**
     * Sets a random prior index on one randomly chosen {@link Discrete} variable of
     * {@code model}. Gives up silently after {@link #MAX_CONDITIONING_ATTEMPTS} random picks
     * that hit no discrete variable, and does nothing if the model has no variables.
     */
    private void conditionRandomDiscreteVariable(FactorGraph model) {
        final VariableList variables = model.getVariables();
        final int nVariables = variables.size();
        if (nVariables == 0) {
            // nextInt(0) would throw; there is nothing to condition on anyway.
            return;
        }
        for (int attempt = 0; attempt < MAX_CONDITIONING_ATTEMPTS; ++attempt) {
            Variable variable = variables.getByIndex(testRand.nextInt(nVariables));
            if (variable instanceof Discrete) {
                Discrete discrete = (Discrete) variable;
                discrete.setPriorIndex(testRand.nextInt(discrete.getDomain().size()));
                return;
            }
        }
    }

}