Java tutorial
/* * IndianBuffetProcessPrior.java * * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.inference.model; import dr.math.Poisson; import dr.math.distributions.PoissonDistribution; import org.apache.commons.math.special.Beta; /** * @author Max Tolkoff */ public class IndianBuffetProcessPrior extends AbstractModelLikelihood implements MatrixSizePrior { public IndianBuffetProcessPrior(Parameter alpha, Parameter beta, AdaptableSizeFastMatrixParameter data) { super(null); this.alpha = alpha; alpha.addBounds(new Parameter.DefaultBounds(Double.POSITIVE_INFINITY, 0, 1)); addVariable(alpha); this.beta = beta; beta.addBounds(new Parameter.DefaultBounds(Double.POSITIVE_INFINITY, 0, 1)); addVariable(beta); this.data = data; addVariable(data); for (int i = 0; i < data.getRowDimension(); i++) { if (data.getParameterValue(i, 0) != 0) containsNonZeroElements[0] = true; } for (int i = 0; i < data.getColumnDimension(); i++) { for (int j = 0; j < data.getRowDimension(); j++) { rowCount[i] += Math.abs(data.getParameterValue(j, i)); } } ncols = data.getColumnDimension(); } private int factorial(int num) { if (num < 0) { throw new RuntimeException("Cannot take a negative factorial"); } else if (num == 0) { return 1; } else { int fac = 1; for (int i = 0; i < num; i++) { fac *= (i + 1); } return fac; } } private double H() { if (!betaKnown) { H = 0; for (int i = 0; i < data.getRowDimension(); i++) { H += beta.getParameterValue(0) / (beta.getParameterValue(0) + i); } } return H; } @Override protected void handleModelChangedEvent(Model model, Object object, int index) { } @Override protected void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) { if (ncols != data.getColumnDimension()) { int sum = 0; for (int i = 0; i < data.getRowDimension(); i++) { sum += data.getParameterValue(i, data.getColumnDimension() - 1); } rowCount[data.getColumnDimension() - 1] = sum; ncols = data.getColumnDimension(); } else { double value = data.getParameterValue(index); int col = index / data.getRowDimension(); if (value == 0.0) { rowCount[col] -= 1; if (rowCount[col] == 0) { containsNonZeroElements[col] = false; } } else { rowCount[col] += 1; containsNonZeroElements[col] = true; } } likelihoodKnown = false; if (variable == beta) betaKnown = false; if (variable == data) dataKnown = false; } @Override protected void storeState() { storedBetaKnown = betaKnown; storedContainsNonZeroElements = containsNonZeroElements; storedDataKnown = dataKnown; storedLikelihoodKnown = likelihoodKnown; storedLogLikelihood = logLikelihood; storedRowCount = rowCount; storedKPlus = KPlus; storedH = H; storedBottom = bottom; storedSum2 = sum2; storedncols = ncols; } @Override protected void restoreState() { betaKnown = storedBetaKnown; containsNonZeroElements = storedContainsNonZeroElements; dataKnown = storedDataKnown; likelihoodKnown = storedLikelihoodKnown; logLikelihood = storedLogLikelihood; rowCount = storedRowCount; KPlus = storedKPlus; H = storedH; bottom = storedBottom; sum2 = storedSum2; ncols = storedncols; } @Override protected void acceptState() { } @Override public Model getModel() { return this; } @Override public double getLogLikelihood() { if (!likelihoodKnown) { logLikelihood = calculateLogLikelihood(); likelihoodKnown = true; } return logLikelihood; } private double calculateLogLikelihood() { int sum; if (!dataKnown) { bottom = 1; boolean[] isExplored = new boolean[data.getColumnDimension()]; containsNonZeroElements = new boolean[data.getColumnDimension()]; rowCount = new int[data.getColumnDimension()]; boolean same; for (int i = 0; i < data.getColumnDimension(); i++) { sum = 1; if (!isExplored[i]) { for (int j = i + 1; j < data.getColumnDimension(); j++) { same = true; if (!isExplored[j]) { for (int k = 0; k < data.getRowDimension(); k++) { if (Math.abs(data.getParameterValue(k, i)) != Math .abs(data.getParameterValue(k, j))) same = false; // if (data.getParameterValue(k, j) != 0) { // containsNonZeroElements[j] = true; // } // rowCount[j]+=data.getParameterValue(k,j); } } if (same && containsNonZeroElements[j]) { isExplored[j] = true; sum += 1; } else if (!containsNonZeroElements[j]) { isExplored[j] = true; } } } bottom *= factorial(sum); } } if (!dataKnown || !betaKnown) { sum2 = 0; KPlus = 0; for (int i = 0; i < data.getColumnDimension(); i++) { if (containsNonZeroElements[i]) { KPlus++; sum2 += Beta.logBeta(rowCount[i], data.getRowDimension() + beta.getParameterValue(0) - rowCount[i]); } } } double p1 = KPlus * Math.log(alpha.getParameterValue(0) * beta.getParameterValue(0) / bottom); double p2 = -alpha.getParameterValue(0) * H(); double p3 = sum2; betaKnown = true; dataKnown = true; return p1 + p2 + p3; } @Override public double getSizeLogLikelihood() { PoissonDistribution poisson = new PoissonDistribution(alpha.getParameterValue(0) * H()); calculateLogLikelihood(); return poisson.logPdf(KPlus) - Math.log(1 - Math.exp(-poisson.mean())); } public int[] getRowCount() { return rowCount; } public AdaptableSizeFastMatrixParameter getData() { return data; } @Override public void makeDirty() { betaKnown = false; dataKnown = false; } boolean likelihoodKnown; boolean storedLikelihoodKnown; double logLikelihood; double storedLogLikelihood; boolean betaKnown = false; boolean dataKnown = false; boolean storedDataKnown; boolean storedBetaKnown; int[] rowCount; int[] storedRowCount; int KPlus; int storedKPlus; boolean[] containsNonZeroElements; boolean[] storedContainsNonZeroElements; double H; double storedH; int bottom; int storedBottom; double sum2; double storedSum2; int ncols; int storedncols; AdaptableSizeFastMatrixParameter data; Parameter alpha; Parameter beta; }