// Java tutorial
/* * Copyright (c) 2011 The S4 Project, http://s4.io. * All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, * either express or implied. See the License for the specific * language governing permissions and limitations under the * License. See accompanying LICENSE file. */ package io.s4.model; import org.apache.commons.lang.NotImplementedException; import org.ejml.data.D1Matrix64F; import org.ejml.data.DenseMatrix64F; import org.ejml.ops.CommonOps; import io.s4.util.MatrixOps; /** * A multivariate Gaussian mixture model. Only diagonal covariance matrices are * supported. * * @author Leo Neumeyer * */ public class GaussianMixtureModel extends Model { /** Supported algorithms for training this model. */ public enum TrainMethod { /** * Estimate mean and variance of Gaussian distribution in the first * iteration. Create the target number of Gaussian components * (numComponents) in the mixture at the end of the first iteration * using the estimated mean and variance. */ STEP, /** * Double the number of Gaussian components at the end of each * iteration. */ DOUBLE, /** Do not allocate structures for training. 
*/ NONE } final private int numElements; final private TrainMethod trainMethod; private int numComponents; private double numSamples; private D1Matrix64F posteriorSum; private D1Matrix64F weights; private D1Matrix64F logWeights; private D1Matrix64F tmpProbs1; private D1Matrix64F tmpProbs2; private double totalLikelihood; private GaussianModel[] components; private int iteration = 0; public GaussianMixtureModel(int numElements, int numComponents, TrainMethod trainMethod) { super(); this.numComponents = numComponents; this.numElements = numElements; this.trainMethod = trainMethod; if (trainMethod == TrainMethod.DOUBLE) throw new NotImplementedException("Want this? Join as a contributor at http://s4.io"); /* Allocate arrays needed for estimation. */ isTrain = false; if (trainMethod != TrainMethod.NONE) { setTrain(true); if (trainMethod == TrainMethod.STEP) { /* Set up for first iteration using a single Gaussian. */ allocateTrainDataStructures(1); } } } /* * TODO: we use this method when the pattern is: create a prototype of the * model and use it to create instances. Notice that we are allocating data * structures in the prototype that we will not use in this case. Not a big * deal but we may want to optimize this somehow later. */ public Model create() { return new GaussianMixtureModel(numElements, numComponents, trainMethod); } private void allocateTrainDataStructures(int numComp) { components = new GaussianModel[numComp]; for (int i = 0; i < numComp; i++) { this.components[i] = new GaussianModel(numElements, true); this.weights = new DenseMatrix64F(numComp, 1); CommonOps.set(this.weights, 1.0 / numComp); this.logWeights = new DenseMatrix64F(numComp, 1); CommonOps.set(this.logWeights, Math.log(1.0 / numComp)); posteriorSum = new DenseMatrix64F(numComp, 1); tmpProbs1 = new DenseMatrix64F(numComp, 1); tmpProbs2 = new DenseMatrix64F(numComp, 1); } } /* * This method is used in {@link TrainMethod#STEP} Convert to a mixture with * N components. 
This method guarantees that the data structures are created * and that all the variables are set for starting a new training iteration. */ private void increaseNumComponents(int newNumComponents) { /* * We use the Gaussian distribution of the parent GMM to create the * children. */ /* Get mean and variance from parent before we allocate resized data structures. */ D1Matrix64F mean = MatrixOps.doubleArrayToMatrix(this.components[0].getMean()); D1Matrix64F variance = MatrixOps.doubleArrayToMatrix(this.components[0].getVariance()); /* Throw away all previous data structures. */ allocateTrainDataStructures(newNumComponents); /* * Create new mixture components. Abandon the old ones. We already got * the mean and variance in the previous step. */ for (int i = 0; i < newNumComponents; i++) { components[i].setMean(MatrixOps.createRandom(i, mean, variance)); components[i].setVariance(new DenseMatrix64F(variance)); } } /* Thread safe internal logProb method. Must pass temp array. */ private double logProbInternal(D1Matrix64F obs, D1Matrix64F probs) { /* Compute log probabilities for this observation. */ for (int i = 0; i < components.length; i++) { probs.set(i, components[i].logProb(obs) + logWeights.get(i)); } /* * To simplify computation, use the max prob in the denominator instead * of the sum. */ return CommonOps.elementMax(probs); } public double logProb(D1Matrix64F obs) { return logProbInternal(obs, tmpProbs1); } public double logProb(double[] obs) { return logProb(MatrixOps.doubleArrayToMatrix(obs)); } public double logProb(float[] obs) { return logProb(MatrixOps.floatArrayToMatrix(obs)); } /** * @param obs * the observed data vector. * @return the probability. */ public double prob(D1Matrix64F obs) { return Math.exp(logProb(obs)); } /** * @param obs * the observed data vector. * @return the probability. */ public double prob(double[] obs) { return prob(MatrixOps.doubleArrayToMatrix(obs)); } /** * @param obs * the observed data vector. * @return the probability. 
*/ public double prob(float[] obs) { return prob(MatrixOps.floatArrayToMatrix(obs)); } /** Update using Matrix array. */ public void update(D1Matrix64F obs) { if (isTrain() == true) { /* Compute log probabilities for this observation. */ double maxProb = logProbInternal(obs, tmpProbs2); totalLikelihood += maxProb; CommonOps.add(tmpProbs2, -maxProb); /* Compute posterior probabilities. */ MatrixOps.elementExp(tmpProbs2); /* Update posterior sum, needed to compute mixture weights. */ CommonOps.addEquals(posteriorSum, tmpProbs2); for (int i = 0; i < components.length; i++) { components[i].update(obs, tmpProbs2.get(i)); } /* Count number of observations. */ numSamples++; } } public void update(double[] obs) { update(MatrixOps.doubleArrayToMatrix(obs)); } /** Update using float array. */ public void update(float[] obs) { update(MatrixOps.floatArrayToMatrix(obs)); } @Override public void estimate() { if (isTrain() == true) { /* Estimate mixture weights. */ // double sum = CommonOps.elementSum(posteriorSum); // CommonOps.scale(1.0/sum, posteriorSum, weights); CommonOps.scale(1.0 / numSamples, posteriorSum, weights); MatrixOps.elementLog(weights, logWeights); /* Estimate component density. */ for (int i = 0; i < components.length; i++) { components[i].estimate(); } /* * After the first iteration, we can estimate the target number of * mixture components. */ if (iteration == 0 && trainMethod == TrainMethod.STEP) { increaseNumComponents(numComponents); } iteration++; } } @Override public void clearStatistics() { if (isTrain() == true) { for (int i = 0; i < components.length; i++) { components[i].clearStatistics(); } CommonOps.set(posteriorSum, 0.0); numSamples = 0; totalLikelihood = 0; } } /** Number of Gaussian components in the mixture. */ public int getNumComponents() { return this.numComponents; } /** Data vector size. 
*/ public int getNumElements() { return this.numElements; } /** * @return the value of the parameters and sufficient statistics of this * model in a printable format. */ public String toString() { StringBuilder sb = new StringBuilder(""); sb.append("Gaussian Mixture Model\n"); sb.append("num samples: " + numSamples + "\n"); sb.append("num components: " + components.length + "\n"); sb.append("weights: " + weights.toString() + "\n"); sb.append("log weights: " + logWeights.toString() + "\n"); sb.append("total log likelihood: " + totalLikelihood + "\n"); for (int i = 0; i < components.length; i++) { sb.append(components[i].toString()); } return sb.toString(); } }