Java tutorial
/* * Software License Agreement (BSD License) * * Copyright 2013 Marc Pujol <mpujol@iiia.csic.es>. * * Redistribution and use of this software in source and binary forms, with or * without modification, are permitted provided that the following conditions * are met: * * Redistributions of source code must retain the above * copyright notice, this list of conditions and the * following disclaimer. * * Redistributions in binary form must reproduce the above * copyright notice, this list of conditions and the * following disclaimer in the documentation and/or other * materials provided with the distribution. * * Neither the name of IIIA-CSIC, Artificial Intelligence Research Institute * nor the names of its contributors may be used to * endorse or promote products derived from this * software without specific prior written permission of * IIIA-CSIC, Artificial Intelligence Research Institute * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE * POSSIBILITY OF SUCH DAMAGE. 
*/

package es.csic.iiia.planes.util;

import java.util.Arrays;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.apache.commons.math3.distribution.GammaDistribution;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.CholeskyDecomposition;
import org.apache.commons.math3.linear.LUDecomposition;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.SingularMatrixException;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.random.Well19937c;

/**
 * Inverse Wishart distribution implementation, to sample random covariance matrices for
 * multivariate Gaussian distributions.
 * <p/>
 * Each sample is produced by drawing a matrix {@code W} from a Wishart distribution with the
 * given scale matrix and degrees of freedom, following the procedure described by
 * Odell &amp; Feiveson (1966), dividing it by the degrees of freedom, and returning the inverse
 * of the resulting matrix.
 *
 * @author Marc Pujol <mpujol@iiia.csic.es>
 */
public class InverseWishartDistribution {

    private static final Logger LOG = Logger.getLogger(InverseWishartDistribution.class.getName());

    /** Number of Wishart draws attempted in {@link #sample()} before giving up on singular ones. */
    private static final int MAX_SAMPLE_ATTEMPTS = 100;

    /** One gamma (chi-square) distribution per diagonal term V_i. */
    private GammaDistribution[] gammas;

    /** Degrees of freedom. */
    private double df;

    /** Scale matrix of the underlying Wishart distribution. */
    private RealMatrix scaleMatrix;

    /** Cholesky decomposition of the scale matrix ({@code scaleMatrix = L * L^T}). */
    private CholeskyDecomposition cholesky;

    /** Generator used for the standard normal off-diagonal terms N_{ij}. */
    private RandomGenerator random;

    /**
     * Builds a new Inverse Wishart distribution with the given scale and degrees of freedom.
     * <p/>
     * Exceptions raised while building the internal {@link GammaDistribution}s (when
     * {@code df <= dimension - 1}) or the {@link CholeskyDecomposition} of the scale matrix
     * propagate to the caller unchanged.
     *
     * @param scaleMatrix scale matrix; must be square and admit a Cholesky decomposition.
     * @param df degrees of freedom; must be greater than {@code dimension - 1}.
     * @throws IllegalArgumentException if {@code scaleMatrix} is not square.
     */
    public InverseWishartDistribution(RealMatrix scaleMatrix, double df) {
        if (!scaleMatrix.isSquare()) {
            throw new IllegalArgumentException("scaleMatrix must be square.");
        }
        this.scaleMatrix = scaleMatrix;
        this.df = df;
        this.random = new Well19937c();
        initialize();
    }

    /**
     * Builds the gamma distributions for the diagonal terms and the Cholesky factor of the
     * scale matrix.
     */
    private void initialize() {
        final int dim = scaleMatrix.getColumnDimension();

        // Odell & Feiveson: V_i ~ chi^2(df - i + 1) for the 1-based index i, i.e. chi^2(df - i)
        // for the 0-based index used here. A chi^2(k) variate is Gamma(shape = k / 2, scale = 2).
        // The parentheses matter: the shape is the whole (df - i) halved.
        gammas = new GammaDistribution[dim];
        for (int i = 0; i < dim; i++) {
            gammas[i] = new GammaDistribution((df - i) / 2, 2);
        }

        cholesky = new CholeskyDecomposition(scaleMatrix);
    }

    /**
     * Reseeds the random generators: the normal generator gets {@code seed} and the gamma
     * distribution for diagonal term {@code i} gets {@code seed + i}.
     *
     * @param seed new random seed.
     */
    public void reseedRandomGenerator(long seed) {
        random.setSeed(seed);
        for (int i = 0; i < gammas.length; i++) {
            gammas[i].reseedRandomGenerator(seed + i);
        }
    }

    /**
     * Returns a sample matrix from this distribution.
     * <p/>
     * Wishart draws whose inverse cannot be computed (singular matrices) are discarded and
     * redrawn, up to {@value #MAX_SAMPLE_ATTEMPTS} times.
     *
     * @return sampled matrix, i.e. the inverse of {@code W / df}.
     * @throws IllegalStateException if every attempt produced a singular matrix.
     */
    public RealMatrix sample() {
        for (int attempt = 0; attempt < MAX_SAMPLE_ATTEMPTS; attempt++) {
            try {
                RealMatrix A = sampleWishart();
                RealMatrix result = new LUDecomposition(A).getSolver().getInverse();
                LOG.log(Level.FINE, "Cov = {0}", result);
                return result;
            } catch (SingularMatrixException ex) {
                LOG.finer("Discarding singular matrix generated by the wishart distribution.");
            }
        }
        throw new IllegalStateException("Unable to generate inverse wishart samples!");
    }

    /**
     * Draws {@code W / df}, where {@code W = L * B * L^T}, {@code L} is the Cholesky factor of
     * the scale matrix and {@code B} is built from N and V as in Odell &amp; Feiveson.
     */
    private RealMatrix sampleWishart() {
        final int dim = scaleMatrix.getColumnDimension();

        // Draw order (all N first, then all V) is kept so seeded streams stay aligned.
        final double[][] N = sampleOffDiagonal(dim);
        final double[] V = sampleDiagonal(dim);
        final double[][] B = buildB(N, V);

        RealMatrix BMat = new Array2DRowRealMatrix(B);
        RealMatrix A = cholesky.getL().multiply(BMat).multiply(cholesky.getLT());
        if (LOG.isLoggable(Level.FINER)) {
            LOG.log(Level.FINER, "A* = {0}", Arrays.deepToString(A.getData()));
        }
        return A.scalarMultiply(1 / df);
    }

    /** Draws N_{ij} ~ N(0, 1) for i &lt; j; every other entry is left at 0. */
    private double[][] sampleOffDiagonal(int dim) {
        double[][] N = new double[dim][dim];
        for (int j = 0; j < dim; j++) {
            for (int i = 0; i < j; i++) {
                N[i][j] = random.nextGaussian();
            }
        }
        if (LOG.isLoggable(Level.FINEST)) {
            LOG.log(Level.FINEST, "N = {0}", Arrays.deepToString(N));
        }
        return N;
    }

    /** Draws the diagonal terms V_i from their chi-square (gamma) distributions. */
    private double[] sampleDiagonal(int dim) {
        double[] V = new double[dim];
        for (int i = 0; i < dim; i++) {
            V[i] = gammas[i].sample();
        }
        if (LOG.isLoggable(Level.FINEST)) {
            LOG.log(Level.FINEST, "V = {0}", Arrays.toString(V));
        }
        return V;
    }

    /** Assembles the symmetric matrix B from the N and V terms (formulas use 1-based indices). */
    private static double[][] buildB(double[][] N, double[] V) {
        final int dim = V.length;
        double[][] B = new double[dim][dim];

        // b_{jj} = V_j + \sum_{i=1}^{j-1} N_{ij}^2; for j = 1 the sum is empty, so b_{11} = V_1.
        for (int j = 0; j < dim; j++) {
            double sum = 0;
            for (int i = 0; i < j; i++) {
                sum += N[i][j] * N[i][j];
            }
            B[j][j] = V[j] + sum;
        }

        // b_{ij} = N_{ij} * \sqrt V_i + \sum_{k=1}^{i-1} N_{ki} * N_{kj}, for i < j.
        // For i = 1 the sum is empty, giving b_{1j} = N_{1j} * \sqrt V_1.
        // Each value is mirrored into b_{ji} to keep B symmetric.
        for (int j = 1; j < dim; j++) {
            for (int i = 0; i < j; i++) {
                double sum = 0;
                for (int k = 0; k < i; k++) {
                    sum += N[k][i] * N[k][j];
                }
                B[i][j] = N[i][j] * Math.sqrt(V[i]) + sum;
                B[j][i] = B[i][j];
            }
        }
        if (LOG.isLoggable(Level.FINEST)) {
            LOG.log(Level.FINEST, "B* = {0}", Arrays.deepToString(B));
        }
        return B;
    }
}