Java tutorial
/******************************************************************************* * Copyright 2015 Analog Devices, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. ********************************************************************************/ package com.analog.lyric.dimple.test.solvers.sumproduct; import static com.analog.lyric.dimple.model.sugar.ModelSyntacticSugar.*; import static java.util.Objects.*; import static org.junit.Assert.*; import org.apache.commons.math3.linear.DefaultRealMatrixChangingVisitor; import org.apache.commons.math3.linear.EigenDecomposition; import org.apache.commons.math3.linear.MatrixUtils; import org.apache.commons.math3.linear.RealMatrix; import org.apache.commons.math3.linear.RealVector; import org.apache.commons.math3.random.CorrelatedRandomVectorGenerator; import org.apache.commons.math3.random.GaussianRandomGenerator; import org.apache.commons.math3.stat.correlation.StorelessCovariance; import org.junit.Test; import com.analog.lyric.dimple.factorfunctions.MultivariateNormal; import com.analog.lyric.dimple.model.core.FactorGraph; import com.analog.lyric.dimple.model.sugar.ModelSyntacticSugar.CurrentModel; import com.analog.lyric.dimple.model.variables.Complex; import com.analog.lyric.dimple.solvers.core.parameterizedMessages.MultivariateNormalParameters; import com.analog.lyric.dimple.solvers.gibbs.GibbsOptions; import com.analog.lyric.dimple.solvers.sumproduct.SumProductSolver; import com.analog.lyric.dimple.solvers.sumproduct.SumProductSolverGraph; import com.analog.lyric.dimple.test.DimpleTestBase; import com.google.common.primitives.Doubles; /** * Tests for sampled factors in SumProduct solver */ public class TestSampledFactors extends DimpleTestBase { /** * Adapted from MATLAB test4 in tests/algoGaussian/testSampledFactors.m */ @Test public void sampledComplexProduct() { // NOTE: test may fail if seed is changed! We keep the number of samples down so that the test doesn't // take too long. Increasing the samples produces better results. testRand.setSeed(42); try (CurrentModel cur = using(new FactorGraph())) { final Complex a = complex("a"); final Complex b = complex("b"); final Complex c = product(a, b); double[] aMean = new double[] { 10, 10 }; RealMatrix aCovariance = randCovariance(2); a.setPrior(new MultivariateNormal(aMean, aCovariance.getData())); double[] bMean = new double[] { -20, 20 }; RealMatrix bCovariance = randCovariance(2); b.setPrior(new MultivariateNormalParameters(bMean, bCovariance.getData())); GaussianRandomGenerator normalGenerator = new GaussianRandomGenerator(testRand); CorrelatedRandomVectorGenerator aGenerator = new CorrelatedRandomVectorGenerator(aMean, aCovariance, 1e-12, normalGenerator); CorrelatedRandomVectorGenerator bGenerator = new CorrelatedRandomVectorGenerator(bMean, bCovariance, 1e-12, normalGenerator); StorelessCovariance expectedCov = new StorelessCovariance(2); final int nSamples = 10000; RealVector expectedMean = MatrixUtils.createRealVector(new double[2]); double[] cSample = new double[2]; for (int i = 0; i < nSamples; ++i) { double[] aSample = aGenerator.nextVector(); double[] bSample = bGenerator.nextVector(); // Compute complex product cSample[0] = aSample[0] * bSample[0] - aSample[1] * bSample[1]; cSample[1] = aSample[0] * bSample[1] + aSample[1] * bSample[0]; expectedMean.addToEntry(0, cSample[0]); expectedMean.addToEntry(1, cSample[1]); expectedCov.increment(cSample); } expectedMean.mapDivideToSelf(nSamples); // normalize SumProductSolverGraph sfg = requireNonNull(cur.graph.setSolverFactory(new SumProductSolver())); sfg.setOption(GibbsOptions.numSamples, nSamples); sfg.solve(); MultivariateNormalParameters cBelief = requireNonNull(c.getBelief()); RealVector observedMean = MatrixUtils.createRealVector(cBelief.getMean()); double scaledMeanDistance = expectedMean.getDistance(observedMean) / expectedMean.getNorm(); // System.out.format("expectedMean = %s\n", expectedMean); // System.out.format("observedMean = %s\n", observedMean); // System.out.println(scaledMeanDistance); assertEquals(0.0, scaledMeanDistance, .02); RealMatrix expectedCovariance = expectedCov.getCovarianceMatrix(); RealMatrix observedCovariance = MatrixUtils.createRealMatrix(cBelief.getCovariance()); RealMatrix diffCovariance = expectedCovariance.subtract(observedCovariance); double scaledCovarianceDistance = diffCovariance.getNorm() / expectedCovariance.getNorm(); // System.out.println(expectedCovariance); // System.out.println(expectedCovariance.getNorm()); // System.out.println(diffCovariance); // System.out.println(diffCovariance.getNorm()); // System.out.println(diffCovariance.getNorm() / expectedCovariance.getNorm()); assertEquals(0.0, scaledCovarianceDistance, .2); } } /** * Generates a random covariance matrix with given dimension. */ RealMatrix randCovariance(int n) { RealMatrix A = MatrixUtils.createRealMatrix(n, n); // Randomize A.walkInOptimizedOrder(new DefaultRealMatrixChangingVisitor() { @Override public double visit(int row, int column, double value) { return testRand.nextDouble(); } }); RealMatrix B = A.add(A.transpose()); // B is symmetric double minEig = Doubles.min(new EigenDecomposition(B).getRealEigenvalues()); double r = testRand.nextGaussian(); r *= r; r += Math.ulp(1.0); RealMatrix I = MatrixUtils.createRealIdentityMatrix(n); RealMatrix C = B.add(I.scalarMultiply(r - minEig)); return C; } }