// Sums an array of numbers log(x1)...log(xn)
/* Copyright (C) 2003 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
//package cc.mallet.util;
/**
 * Numerical helpers for values stored in log space.
 *
 * @author <a href="mailto:casutton@cs.umass.edu">Charles Sutton</a>
 * @version $Id: ArrayUtils.java,v 1.1 2007/10/22 21:37:40 mccallum Exp $
 */
public class Util {
    /**
     * Default tolerance for {@link #sumLogProb(double[])}: terms more than this
     * far below the maximum (in log space) are dropped from the sum.
     */
    private static final double LOGTOLERANCE = 30.0;

    /**
     * Sums an array of numbers log(x1)...log(xn). This saves some of
     * the unnecessary calls to Math.log in the two-argument version.
     * <p>
     * Note that this implementation IGNORES elements of the input
     * array that are more than LOGTOLERANCE (currently 30.0) less
     * than the maximum element.
     * <p>
     * Cursory testing makes me wonder if this is actually much faster than
     * repeated use of the 2-argument version, however -cas.
     *
     * @param vals An array log(x1), log(x2), ..., log(xn)
     * @return log(x1+x2+...+xn); {@code -Infinity} for an empty array or when
     *         every element is {@code -Infinity}
     * @throws NullPointerException if {@code vals} is null
     * @see #sumLogProb(double[], double)
     */
    public static double sumLogProb(double[] vals) {
        return sumLogProb(vals, LOGTOLERANCE);
    }

    /**
     * Computes log(x1+x2+...+xn) from log(x1)...log(xn), factoring out the
     * maximum element so that the exponentials cannot overflow.
     * <p>
     * Elements more than {@code tolerance} below the maximum are ignored. Pass
     * {@link Double#POSITIVE_INFINITY} to include every finite element.
     * <p>
     * Edge cases:
     * <ul>
     *   <li>An empty array, or one whose elements are all {@code -Infinity},
     *       yields {@code -Infinity} (the log of a zero sum).</li>
     *   <li>If any element is {@code +Infinity}, the result is {@code +Infinity}.</li>
     *   <li>NaN elements are skipped: they are never selected as the maximum
     *       and never pass the cutoff test.</li>
     * </ul>
     *
     * @param vals      an array log(x1), log(x2), ..., log(xn)
     * @param tolerance non-negative log-space distance below the maximum beyond
     *                  which terms are dropped
     * @return log(x1+x2+...+xn), up to the dropped terms
     * @throws NullPointerException     if {@code vals} is null
     * @throws IllegalArgumentException if {@code tolerance} is negative or NaN
     */
    public static double sumLogProb(double[] vals, double tolerance) {
        // Written as !(>=) so that NaN is rejected as well as negative values.
        if (!(tolerance >= 0.0)) {
            throw new IllegalArgumentException("tolerance must be non-negative, got " + tolerance);
        }
        int len = vals.length;
        double max = Double.NEGATIVE_INFINITY;
        int maxidx = 0;
        for (int i = 0; i < len; i++) {
            if (vals[i] > max) {
                max = vals[i];
                maxidx = i;
            }
        }
        // Without this guard, exp(max - max) = exp(inf - inf) = exp(NaN) would
        // poison the sum when several elements share an infinite maximum.
        // A -inf maximum means every term is zero; a +inf maximum dominates.
        if (Double.isInfinite(max)) {
            return max;
        }
        double cutoff = max - tolerance;
        double intermediate = 0.0;
        for (int i = 0; i < len; i++) {
            // The maximum itself contributes the implicit "1.0" in log1p below.
            if (i != maxidx && vals[i] >= cutoff) {
                intermediate += Math.exp(vals[i] - max);
            }
        }
        // log1p stays accurate when intermediate is much smaller than 1.
        return max + Math.log1p(intermediate);
    }
}
// Related examples in the same category