// Java tutorial
/* * NormalPeriodPriorDistribution.java * * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.evomodel.epidemiology.casetocase.periodpriors; import dr.inference.loggers.LogColumn; import dr.inference.model.Parameter; import dr.math.distributions.NormalDistribution; import dr.math.distributions.NormalGammaDistribution; import dr.math.functionEval.GammaFunction; import dr.xml.*; import org.apache.commons.math.MathException; import org.apache.commons.math.distribution.TDistributionImpl; import java.util.ArrayList; import java.util.Arrays; /** The assumption here is that the periods are drawn from a normal distribution with unknown mean and variance. The hyperprior is the conjugate, normal-gamma distribution. 
@author Matthew Hall */ public class NormalPeriodPriorDistribution extends AbstractPeriodPriorDistribution { public static final String NORMAL = "normalPeriodPriorDistribution"; public static final String LOG = "log"; public static final String ID = "id"; public static final String MU = "mu"; public static final String LAMBDA = "lambda"; public static final String ALPHA = "alpha"; public static final String BETA = "beta"; private NormalGammaDistribution hyperprior; private Parameter posteriorMean; private Parameter posteriorBeta; private Parameter posteriorExpectedPrecision; double normalApproximationThreshold = 30; private ArrayList<Double> dataValues; private double[] currentParameters; public NormalPeriodPriorDistribution(String name, boolean log, NormalGammaDistribution hyperprior) { super(name, log); this.hyperprior = hyperprior; posteriorBeta = new Parameter.Default(1); posteriorMean = new Parameter.Default(1); posteriorExpectedPrecision = new Parameter.Default(1); addVariable(posteriorBeta); addVariable(posteriorMean); addVariable(posteriorExpectedPrecision); } public NormalPeriodPriorDistribution(String name, boolean log, double mu_0, double lambda_0, double alpha_0, double beta_0) { this(name, log, new NormalGammaDistribution(mu_0, lambda_0, alpha_0, beta_0)); reset(); } public void reset() { dataValues = new ArrayList<Double>(); currentParameters = hyperprior.getParameters(); logL = 0; } // this returns the posterior predictive probability of the new value, and updates the total public double calculateLogPosteriorProbability(double newValue, double minValue) { double out = calculateLogPosteriorPredictiveProbability(newValue); if (minValue != Double.NEGATIVE_INFINITY) { out -= calculateLogPosteriorPredictiveCDF(minValue, true); } logL += out; update(newValue); return out; } public double calculateLogPosteriorCDF(double limit, boolean upper) { return calculateLogPosteriorPredictiveCDF(limit, upper); } public double 
calculateLogPosteriorPredictiveProbability(double value) { double mean = currentParameters[0]; double sd = Math.sqrt( currentParameters[3] * (currentParameters[1] + 1) / (currentParameters[2] * currentParameters[1])); double scaledValue = (value - mean) / sd; double out; if (2 * currentParameters[2] <= normalApproximationThreshold) { TDistributionImpl tDist = new TDistributionImpl(2 * currentParameters[2]); out = Math.log(tDist.density(scaledValue)); } else { out = NormalDistribution.logPdf(scaledValue, 0, 1); } return out; } public double calculateLogPosteriorPredictiveCDF(double value, boolean upperTail) { double mean = currentParameters[0]; double sd = Math.sqrt( currentParameters[3] * (currentParameters[1] + 1) / (currentParameters[2] * currentParameters[1])); double scaledValue = (value - mean) / sd; double out; if (2 * currentParameters[2] <= normalApproximationThreshold) { TDistributionImpl tDist = new TDistributionImpl(2 * currentParameters[2]); try { out = upperTail ? Math.log(tDist.cumulativeProbability(-scaledValue)) : Math.log(tDist.cumulativeProbability(scaledValue)); } catch (MathException e) { throw new RuntimeException(e.toString()); } } else { out = upperTail ? 
NormalDistribution.standardCDF(-scaledValue, true) : NormalDistribution.standardCDF(scaledValue, true); } return out; } private void update(double newData) { dataValues.add(newData); double[] originalParameters = hyperprior.getParameters(); double lambda_0 = originalParameters[1]; double oldMu = currentParameters[0]; double oldLambda = currentParameters[1]; double oldAlpha = currentParameters[2]; double oldBeta = currentParameters[3]; double count = dataValues.size(); double newMu = (newData - oldMu) / (lambda_0 + count) + oldMu; double newLambda = oldLambda + 1; double newAlpha = oldAlpha + 0.5; double newBeta = oldBeta + oldLambda * Math.pow(newData - oldMu, 2) / (2 * (oldLambda + 1)); posteriorMean.setParameterValue(0, newMu); posteriorBeta.setParameterValue(0, newBeta); posteriorExpectedPrecision.setParameterValue(0, newAlpha / newBeta); currentParameters = new double[] { newMu, newLambda, newAlpha, newBeta }; } public double calculateLogLikelihood(double[] values) { int count = values.length; double[] infPredictiveDistributionParameters = hyperprior.getParameters(); double mu_0 = infPredictiveDistributionParameters[0]; double lambda_0 = infPredictiveDistributionParameters[1]; double alpha_0 = infPredictiveDistributionParameters[2]; double beta_0 = infPredictiveDistributionParameters[3]; double lambda_n = lambda_0 + count; double alpha_n = alpha_0 + count / 2; double sum = 0; for (Double infPeriod : values) { sum += infPeriod; } double mean = sum / count; double sumOfDifferences = 0; for (Double infPeriod : values) { sumOfDifferences += Math.pow(infPeriod - mean, 2); } posteriorMean.setParameterValue(0, (lambda_0 * mu_0 + sum) / (lambda_0 + count)); double beta_n = beta_0 + 0.5 * sumOfDifferences + lambda_0 * count * Math.pow(mean - mu_0, 2) / (2 * (lambda_0 + count)); posteriorBeta.setParameterValue(0, beta_n); posteriorExpectedPrecision.setParameterValue(0, alpha_n / beta_n); logL = GammaFunction.logGamma(alpha_n) - GammaFunction.logGamma(alpha_0) + alpha_0 * 
Math.log(beta_0) - alpha_n * Math.log(beta_n) + 0.5 * Math.log(lambda_0) - 0.5 * Math.log(lambda_n) - (count / 2) * Math.log(2 * Math.PI); return logL; } public LogColumn[] getColumns() { ArrayList<LogColumn> columns = new ArrayList<LogColumn>(Arrays.asList(super.getColumns())); columns.add(new LogColumn.Abstract(getModelName() + "_posteriorMean") { protected String getFormattedValue() { return String.valueOf(posteriorMean.getParameterValue(0)); } }); columns.add(new LogColumn.Abstract(getModelName() + "_posteriorBeta") { protected String getFormattedValue() { return String.valueOf(posteriorBeta.getParameterValue(0)); } }); columns.add(new LogColumn.Abstract(getModelName() + "_posteriorExpectedPrecision") { protected String getFormattedValue() { return String.valueOf(posteriorExpectedPrecision.getParameterValue(0)); } }); return columns.toArray(new LogColumn[columns.size()]); } public static XMLObjectParser PARSER = new AbstractXMLObjectParser() { public String getParserName() { return NORMAL; } public Object parseXMLObject(XMLObject xo) throws XMLParseException { String id = (String) xo.getAttribute(ID); double mu = xo.getDoubleAttribute(MU); double lambda = xo.getDoubleAttribute(LAMBDA); double alpha = xo.getDoubleAttribute(ALPHA); double beta = xo.getDoubleAttribute(BETA); boolean log; log = xo.hasAttribute(LOG) ? 
xo.getBooleanAttribute(LOG) : false; return new NormalPeriodPriorDistribution(id, log, mu, lambda, alpha, beta); } public XMLSyntaxRule[] getSyntaxRules() { return rules; } private final XMLSyntaxRule[] rules = { AttributeRule.newBooleanRule(LOG, true), AttributeRule.newStringRule(ID, false), AttributeRule.newDoubleRule(MU, false), AttributeRule.newDoubleRule(LAMBDA, false), AttributeRule.newDoubleRule(ALPHA, false), AttributeRule.newDoubleRule(BETA, false) }; public String getParserDescription() { return "Calculates the probability of a set of doubles being drawn from the prior posterior distribution" + "of a normal distribution of unknown mean and variance"; } public Class getReturnType() { return NormalPeriodPriorDistribution.class; } }; }