gedi.util.math.stat.distributions.NormalMixtureDistribution.java Source code

Java tutorial

Introduction

Here is the source code for gedi.util.math.stat.distributions.NormalMixtureDistribution.java

Source

/**
 * 
 *    Copyright 2017 Florian Erhard
 *
 *   Licensed under the Apache License, Version 2.0 (the "License");
 *   you may not use this file except in compliance with the License.
 *   You may obtain a copy of the License at
 *
 *       http://www.apache.org/licenses/LICENSE-2.0
 *
 *   Unless required by applicable law or agreed to in writing, software
 *   distributed under the License is distributed on an "AS IS" BASIS,
 *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *   See the License for the specific language governing permissions and
 *   limitations under the License.
 * 
 */
package gedi.util.math.stat.distributions;

import gedi.util.ArrayUtils;

import java.util.Arrays;

import org.apache.commons.math3.distribution.AbstractRealDistribution;
import org.apache.commons.math3.distribution.NormalDistribution;
import org.apache.commons.math3.exception.ConvergenceException;
import org.apache.commons.math3.exception.DimensionMismatchException;
import org.apache.commons.math3.exception.NotFiniteNumberException;
import org.apache.commons.math3.exception.NotPositiveException;
import org.apache.commons.math3.exception.NotStrictlyPositiveException;
import org.apache.commons.math3.exception.NumberIsTooLargeException;
import org.apache.commons.math3.exception.NumberIsTooSmallException;
import org.apache.commons.math3.random.Well19937c;
import org.apache.commons.math3.stat.descriptive.moment.Mean;
import org.apache.commons.math3.stat.descriptive.moment.StandardDeviation;
import org.apache.commons.math3.util.FastMath;

/**
 * A finite mixture of univariate normal distributions.
 * <p>
 * The density is {@code f(x) = sum_i w_i * f_i(x)}, where {@code f_i} is the density of the i-th
 * {@link NormalDistribution} component and {@code w_i} its mixing weight. Weights passed to the
 * constructor are copied and normalized to sum to one. A mixture can be fitted to data with the
 * expectation-maximization (EM) algorithm via {@link #fit(double[], int)}.
 */
public class NormalMixtureDistribution extends AbstractRealDistribution {

    private final NormalDistribution[] components;
    /** Normalized mixing weights (sum to one). */
    private final double[] mixing;
    /** Cumulative sums of {@link #mixing}; used by {@link #sample()} for inverse-CDF selection. */
    private final double[] mixingSum;

    /**
     * Creates a mixture from the given components and (possibly unnormalized) mixing weights.
     * <p>
     * Both arrays are copied; the caller's arrays are never modified.
     *
     * @param components the mixture components
     * @param mixing non-negative mixing weights, one per component; normalized to sum to one
     * @throws NotFiniteNumberException if a weight is NaN or the weights sum to infinity
     * @throws NotPositiveException if any weight is negative
     * @throws DimensionMismatchException if {@code components} and {@code mixing} differ in length
     * @throws NotStrictlyPositiveException if the weights sum to zero (including an empty mixture)
     */
    public NormalMixtureDistribution(NormalDistribution[] components, double[] mixing) {
        super(new Well19937c());

        double sum = 0;
        for (double w : mixing) {
            if (Double.isNaN(w))
                throw new NotFiniteNumberException(w);
            if (w < 0)
                throw new NotPositiveException(w);
            sum += w;
        }
        if (components.length != mixing.length)
            throw new DimensionMismatchException(mixing.length, components.length);
        if (Double.isInfinite(sum))
            throw new NotFiniteNumberException(sum);
        // An all-zero (or empty) weight vector cannot be normalized: 1/0 would turn every weight into NaN.
        if (sum == 0)
            throw new NotStrictlyPositiveException(sum);

        final double scale = 1 / sum;
        this.components = components.clone();
        this.mixing = new double[mixing.length];
        this.mixingSum = new double[mixing.length];
        double cumulative = 0;
        for (int i = 0; i < mixing.length; i++) {
            this.mixing[i] = mixing[i] * scale;
            cumulative += this.mixing[i];
            this.mixingSum[i] = cumulative;
        }
    }

    /** @return the number of mixture components */
    public int getNumComponents() {
        return components.length;
    }

    /**
     * @param c component index
     * @return the c-th normal component
     */
    public NormalDistribution getComponent(int c) {
        return components[c];
    }

    /**
     * @param c component index
     * @return the normalized mixing weight of the c-th component
     */
    public double getMixing(int c) {
        return mixing[c];
    }

    /** Mixture density: the weighted sum of the component densities at {@code x}. */
    @Override
    public double density(double x) {
        double re = 0;
        int n = getNumComponents();
        for (int i = 0; i < n; i++)
            re += mixing[i] * components[i].density(x);
        return re;
    }

    /** Mixture CDF: the weighted sum of the component CDFs at {@code x}. */
    @Override
    public double cumulativeProbability(double x) {
        double re = 0;
        int n = getNumComponents();
        for (int i = 0; i < n; i++)
            re += mixing[i] * components[i].cumulativeProbability(x);
        return re;
    }

    /** Mixture mean: the weighted sum of the component means. */
    @Override
    public double getNumericalMean() {
        double re = 0;
        int n = getNumComponents();
        for (int i = 0; i < n; i++)
            re += mixing[i] * components[i].getNumericalMean();
        return re;
    }

    /**
     * Mixture variance via the law of total variance:
     * {@code sum_i w_i * (var_i + (mean_i - mean)^2)}.
     */
    @Override
    public double getNumericalVariance() {
        double re = 0;
        int n = getNumComponents();
        double mean = getNumericalMean();
        for (int i = 0; i < n; i++) {
            double mm = mean - components[i].getNumericalMean();
            re += mixing[i] * (mm * mm + components[i].getNumericalVariance());
        }
        return re;
    }

    @Override
    public double getSupportLowerBound() {
        return Double.NEGATIVE_INFINITY;
    }

    /**
     * Draws a value by first choosing a component with probability equal to its mixing weight
     * (using this distribution's random generator) and then sampling from that component
     * (using the component's own generator).
     */
    @Override
    public double sample() {
        final double r = random.nextDouble();

        // r lies in [0, 1). Select the first component whose cumulative weight strictly exceeds r.
        // The strict comparison guarantees that a zero-weight component (whose cumulative weight
        // equals its predecessor's) is never chosen, even when r == 0.
        for (int i = 0; i < mixingSum.length; i++) {
            if (r < mixingSum[i])
                return components[i].sample();
        }

        // Only reachable if rounding left the final cumulative weight slightly below r.
        return components[components.length - 1].sample();
    }

    @Override
    public double getSupportUpperBound() {
        return Double.POSITIVE_INFINITY;
    }

    @Override
    public boolean isSupportLowerBoundInclusive() {
        return false;
    }

    @Override
    public boolean isSupportUpperBoundInclusive() {
        return false;
    }

    @Override
    public boolean isSupportConnected() {
        return true;
    }

    /**
     * Computes the total (not averaged) log-likelihood of {@code data} under this mixture.
     *
     * @param data observations
     * @return the sum of {@code log(density(x))} over all observations; {@code -Infinity} if any
     *         observation has zero density
     */
    public double logLikelihood(double[] data) {
        int n = data.length;
        double re = 0;
        for (int i = 0; i < n; i++) {
            final double rowDensity = density(data[i]);
            re += FastMath.log(rowDensity);
        }
        return re;
    }

    /**
     * Default maximum number of iterations allowed per fitting process.
     */
    private static final int DEFAULT_MAX_ITERATIONS = 1000;
    /**
     * Default convergence threshold for fitting, applied to the change of the mean
     * per-observation log-likelihood between iterations.
     */
    private static final double DEFAULT_THRESHOLD = 1E-5;

    /**
     * Fits a mixture with the given number of components to {@code data}. The starting point comes
     * from {@link #init(double[], int)}, and the defaults {@code DEFAULT_MAX_ITERATIONS} and
     * {@code DEFAULT_THRESHOLD} are used.
     *
     * @param data observations
     * @param components number of mixture components
     * @return the fitted mixture
     * @throws ConvergenceException if EM does not converge or degenerates
     */
    public static NormalMixtureDistribution fit(double[] data, int components) {
        return fit(init(data, components), data, DEFAULT_MAX_ITERATIONS, DEFAULT_THRESHOLD);
    }

    /**
     * Fits a normal mixture to {@code data} with the expectation-maximization algorithm, starting
     * from {@code initialMixture}.
     * <p>
     * Iteration stops when the mean per-observation log-likelihood changes by at most
     * {@code threshold}. With a single component, the sample mean and bias-corrected standard
     * deviation are returned directly and no iterations are performed.
     *
     * @param initialMixture starting point; its number of components determines the result's
     * @param data observations
     * @param maxIterations maximum number of EM iterations (at least 1)
     * @param threshold convergence threshold (strictly positive)
     * @return the fitted mixture
     * @throws NotStrictlyPositiveException if {@code maxIterations < 1}, if {@code threshold} is not
     *         strictly positive (including NaN), if {@code data} is empty, or if a component collapses
     *         to zero variance (rejected by {@link NormalDistribution})
     * @throws ConvergenceException if the log-likelihood has not converged within
     *         {@code maxIterations} iterations, or if EM degenerates: an observation gets zero, infinite
     *         or NaN mixture density, or the log-likelihood becomes NaN
     */
    public static NormalMixtureDistribution fit(NormalMixtureDistribution initialMixture, double[] data,
            final int maxIterations, final double threshold) {

        if (maxIterations < 1) {
            throw new NotStrictlyPositiveException(maxIterations);
        }

        // Written negated so that a NaN threshold is rejected as well.
        if (!(threshold >= Double.MIN_VALUE)) {
            throw new NotStrictlyPositiveException(threshold);
        }

        final int n = data.length;
        if (n == 0) {
            throw new NotStrictlyPositiveException(n);
        }

        final int k = initialMixture.getNumComponents();

        if (k == 1)
            return singleComponent(data);

        int numIterations = 0;
        double previousLogLikelihood = 0d;

        double logLikelihood = Double.NEGATIVE_INFINITY;

        // Initialize model to fit to initial mixture (the constructor copies the arrays).
        NormalMixtureDistribution fittedModel = new NormalMixtureDistribution(initialMixture.components,
                initialMixture.mixing);

        while (numIterations++ < maxIterations
                && FastMath.abs(previousLogLikelihood - logLikelihood) > threshold) {
            previousLogLikelihood = logLikelihood;
            logLikelihood = 0d;

            // E-step: gamma[i][j] is the responsibility of component j for observation i,
            // i.e. component j's share of the total mixture density at data[i].
            final double[][] gamma = new double[n][k];
            // Sum of gamma for each component
            final double[] gammaSums = new double[k];

            for (int i = 0; i < n; i++) {
                final double rowDensity = fittedModel.density(data[i]);
                // A zero, infinite or NaN density would make the responsibilities 0/0 or inf/inf and
                // silently poison the whole model with NaNs; report it as a failure to converge instead.
                if (!(rowDensity > 0 && rowDensity < Double.POSITIVE_INFINITY)) {
                    throw new ConvergenceException();
                }
                logLikelihood += FastMath.log(rowDensity);

                for (int j = 0; j < k; j++) {
                    gamma[i][j] = fittedModel.mixing[j] * fittedModel.components[j].density(data[i]) / rowDensity;
                    gammaSums[j] += gamma[i][j];
                }
            }
            // Average over observations so that the threshold does not depend on the sample size.
            logLikelihood /= n;

            // M-step: re-estimate weights, means and variances from the responsibilities.
            final double[] newWeights = gammaSums.clone();
            ArrayUtils.mult(newWeights, 1.0 / n);

            NormalDistribution[] comp = new NormalDistribution[k];
            for (int j = 0; j < k; j++) {
                double m = 0;
                for (int i = 0; i < n; i++) {
                    m += gamma[i][j] * data[i];
                }
                m /= gammaSums[j];

                double var = 0;
                for (int i = 0; i < n; i++) {
                    double d = m - data[i];
                    var += gamma[i][j] * d * d;
                }
                var /= gammaSums[j];

                comp[j] = new NormalDistribution(m, Math.sqrt(var));
            }

            // Update current model
            fittedModel = new NormalMixtureDistribution(comp, newWeights);
        }

        // Written negated so that a NaN log-likelihood counts as "not converged" instead of slipping through.
        if (!(FastMath.abs(previousLogLikelihood - logLikelihood) <= threshold)) {
            // Did not converge before the maximum number of iterations
            throw new ConvergenceException();
        }

        return fittedModel;
    }

    /**
     * Creates a heuristic starting point for {@link #fit}. The sorted data is split into
     * {@code numComponents} contiguous bins of (almost) equal size. Each bin gives one component
     * with the bin's mean and standard deviation, and all components get equal weight. A bin whose
     * own standard deviation is zero (for example a single-element bin) falls back to the standard
     * deviation of the whole sample.
     *
     * @param data observations
     * @param numComponents number of components
     * @return the initial mixture
     * @throws NotStrictlyPositiveException if {@code numComponents > 1} and {@code data} has fewer
     *         than two observations, or if the data (and hence the fallback deviation) is constant
     * @throws NumberIsTooSmallException if {@code numComponents < 1}
     * @throws NumberIsTooLargeException if {@code numComponents > data.length}
     */
    public static NormalMixtureDistribution init(final double[] data, final int numComponents)
            throws NotStrictlyPositiveException, DimensionMismatchException {

        if (numComponents == 1)
            return singleComponent(data);

        if (data.length < 2) {
            throw new NotStrictlyPositiveException(data.length);
        }
        if (numComponents < 1) {
            // 1 is a valid component count (handled above), so it is the reported minimum.
            throw new NumberIsTooSmallException(numComponents, 1, true);
        }
        if (numComponents > data.length) {
            throw new NumberIsTooLargeException(numComponents, data.length, true);
        }

        final int numRows = data.length;
        final double[] sortedData = data.clone();
        Arrays.sort(sortedData);

        // Used for bins whose own standard deviation is zero; otherwise NormalDistribution would throw.
        final double fallbackSd = new StandardDeviation().evaluate(sortedData);

        // components of mixture model to be created
        final double[] mixing = new double[numComponents];
        final NormalDistribution[] comp = new NormalDistribution[numComponents];

        // create a component based on data in each bin
        for (int k = 0; k < numComponents; k++) {
            // long arithmetic: k * numRows can exceed Integer.MAX_VALUE for large inputs.
            // minimum index (inclusive) from sorted data for this bin
            final int minIndex = (int) ((long) k * numRows / numComponents);

            // maximum index (exclusive) from sorted data for this bin
            final int maxIndex = (int) Math.min(numRows, (long) (k + 1) * numRows / numComponents);

            final int length = maxIndex - minIndex;
            final double m = new Mean().evaluate(sortedData, minIndex, length);
            double sd = new StandardDeviation().evaluate(sortedData, minIndex, length);
            if (!(sd > 0))
                sd = fallbackSd;
            mixing[k] = 1d / numComponents;
            comp[k] = new NormalDistribution(m, sd);
        }

        return new NormalMixtureDistribution(comp, mixing);
    }

    /**
     * Builds a one-component mixture from the sample mean and the bias-corrected sample standard
     * deviation of {@code data}.
     */
    private static NormalMixtureDistribution singleComponent(double[] data) {
        return new NormalMixtureDistribution(new NormalDistribution[] {
                new NormalDistribution(new Mean().evaluate(data), new StandardDeviation().evaluate(data)) },
                new double[] { 1 });
    }
}