org.apache.mahout.math.random.MultiNormal.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.mahout.math.random.MultiNormal.java

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.math.random;

import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.DiagonalMatrix;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.function.DoubleFunction;

import java.util.Random;

/**
 * Samples from a multi-variate normal distribution.
 * <p/>
 * This is done by sampling from several independent unit normal distributions to get a vector u.
 * The sample value that is returned is then A u + m where A is derived from the covariance matrix
 * and m is the mean of the result.
 * <p/>
 * If \Sigma is the desired covariance matrix, then you can use any value of A such that A' A =
 * \Sigma.  The Cholesky decomposition can be used to compute A if \Sigma is positive definite.
 * Slightly more expensive is to use the SVD U S V' = \Sigma and then set A = U \sqrt{S}.
 *
 * Useful special cases occur when \Sigma is diagonal so that A = \sqrt(\Sigma) or where \Sigma = r I.
 *
 * Another special case is where m = 0.
 */
public class MultiNormal implements Sampler<Vector> {
    private final Random gen;
    private final int dimension;
    private final Matrix scale;
    private final Vector mean;

    /**
     * Constructs a sampler with diagonal scale matrix.
     * @param diagonal The diagonal elements of the scale matrix.
     */
    public MultiNormal(Vector diagonal) {
        this(new DiagonalMatrix(diagonal), null);
    }

    /**
     * Constructs a sampler with diagonal scale matrix and (potentially)
     * non-zero mean.
     * @param diagonal The scale matrix's principal diagonal.
     * @param mean The desired mean.  Set to null if zero mean is desired.
     */
    public MultiNormal(Vector diagonal, Vector mean) {
        this(new DiagonalMatrix(diagonal), mean);
    }

    /**
     * Constructs a sampler with non-trivial scale matrix and mean.
     */
    public MultiNormal(Matrix a, Vector mean) {
        this(a, mean, a.columnSize());
    }

    public MultiNormal(int dimension) {
        this(null, null, dimension);
    }

    public MultiNormal(double radius, Vector mean) {
        this(new DiagonalMatrix(radius, mean.size()), mean);
    }

    private MultiNormal(Matrix scale, Vector mean, int dimension) {
        gen = RandomUtils.getRandom();
        this.dimension = dimension;
        this.scale = scale;
        this.mean = mean;
    }

    @Override
    public Vector sample() {
        Vector v = new DenseVector(dimension).assign(new DoubleFunction() {
            @Override
            public double apply(double ignored) {
                return gen.nextGaussian();
            }
        });
        if (mean != null) {
            if (scale != null) {
                return scale.times(v).plus(mean);
            } else {
                return v.plus(mean);
            }
        } else {
            if (scale != null) {
                return scale.times(v);
            } else {
                return v;
            }
        }
    }

    public Vector getScale() {
        return mean;
    }
}