/*-----------------------------------------------------------------------------
 * GDSC SMLM Software
 *
 * Copyright (C) 2014 Alex Herbert
 * Genome Damage and Stability Centre
 * University of Sussex, UK
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 3 of the License, or
 * (at your option) any later version.
 *
 * This code is based on the ideas expressed in Numerical Recipes in C++,
 * The Art of Scientific Computing, Second Edition, W.H. Press,
 * S.A. Teukolsky, W.T. Vetterling, B.P. Flannery (Cambridge University Press,
 * Cambridge, 2002).
 *---------------------------------------------------------------------------*/
package org.apache.commons.math3.optim.nonlinear.scalar.gradient;

import java.util.Locale;

import org.apache.commons.math3.exception.MathUnsupportedOperationException;
import org.apache.commons.math3.exception.TooManyEvaluationsException;
import org.apache.commons.math3.exception.util.Localizable;
import org.apache.commons.math3.optim.ConvergenceChecker;
import org.apache.commons.math3.optim.OptimizationData;
import org.apache.commons.math3.optim.PointValuePair;
import org.apache.commons.math3.optim.PositionChecker;
import org.apache.commons.math3.optim.nonlinear.scalar.GradientMultivariateOptimizer;
import org.apache.commons.math3.util.FastMath;

/**
 * Implementation of the Broyden-Fletcher-Goldfarb-Shanno (BFGS) variant of the Davidon-Fletcher-Powell (DFP)
 * minimisation.
 * <p>
 * This is not part of the Apache Commons Math library but extends the same base classes to allow an easy swap with
 * existing code based on the Apache library.
 * <p>
 * Note that, although rare, the algorithm may converge prematurely because the search direction no longer leads
 * downhill. In case of doubt, restarting the algorithm should overcome this issue.
 * <p>
 * The implementation is based upon that presented in: Numerical Recipes in C++, The Art of Scientific Computing,
 * Second Edition, W.H. Press, S.A. Teukolsky, W.T. Vetterling, B.P. Flannery (Cambridge University Press,
 * Cambridge, 2002). The algorithm has been updated to support a bounded search and convergence checking on position
 * and gradient.
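 * <p>
 * A minimal usage sketch. The {@code value} function (a {@code MultivariateFunction}), the {@code gradient}
 * function (a {@code MultivariateVectorFunction}) and all numeric settings below are illustrative placeholders
 * supplied by the caller:
 *
 * <pre>
 * {@code
 * BFGSOptimizer optimizer = new BFGSOptimizer();
 * PointValuePair result = optimizer.optimize(
 *         new MaxEval(10000),
 *         new ObjectiveFunction(value),
 *         new ObjectiveFunctionGradient(gradient),
 *         GoalType.MINIMIZE,
 *         new InitialGuess(new double[] { 1, 1 }),
 *         new SimpleBounds(new double[] { 0, 0 }, new double[] { 10, 10 }),
 *         new BFGSOptimizer.StepLength(new double[] { 0.5, 0.5 }),
 *         new BFGSOptimizer.GradientTolerance(1e-6));
 * double[] solution = result.getPoint();
 * }
 * </pre>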
 */
public class BFGSOptimizer extends GradientMultivariateOptimizer {
    /** Maximum step length used in line search */
    private double[] maximumStepLength = null;
    /** Convergence tolerance on gradient */
    private double gradientTolerance;
    /** Maximum number of restarts on the convergence point */
    private int restarts = 0;
    /** Maximum number of restarts in the event of roundoff error */
    private int roundoffRestarts = 3;
    /** Convergence tolerance on position */
    private PositionChecker positionChecker = null;
    /** Flags to indicate if bounds are present */
    private boolean isLower, isUpper;
    private double[] lower, upper;
    private double sign = 0;

    /**
     * Specify the maximum step length in each dimension
     */
    public static class StepLength implements OptimizationData {
        private double[] step;

        /**
         * Build an instance
         *
         * @param step
         *            The maximum step size in each dimension
         */
        public StepLength(double[] step) {
            this.step = step;
        }

        public double[] getStep() {
            return step;
        }
    }

    /**
     * Specify the tolerance for convergence of the gradient to zero
     */
    public static class GradientTolerance implements OptimizationData {
        private double tolerance;

        /**
         * Build an instance
         *
         * @param tolerance
         *            The tolerance on the gradient
         */
        public GradientTolerance(double tolerance) {
            this.tolerance = tolerance;
        }

        public double getTolerance() {
            return tolerance;
        }
    }

    /**
     * Specify the maximum number of restarts on the converged point in the event that the gradient has not yet
     * converged to zero.
     */
    public static class MaximumRestarts implements OptimizationData {
        private int restarts;

        /**
         * Build an instance
         *
         * @param restarts
         *            The maximum number of restarts
         */
        public MaximumRestarts(int restarts) {
            this.restarts = restarts;
        }

        public int getRestarts() {
            return restarts;
        }
    }

    /**
     * Specify the maximum number of restarts in the event of roundoff error.
     */
    public static class MaximumRoundoffRestarts extends MaximumRestarts {
        /**
         * Build an instance
         *
         * @param restarts
         *            The maximum number of restarts
         */
        public MaximumRoundoffRestarts(int restarts) {
            super(restarts);
        }
    }

    /**
     * Constructor
     */
    public BFGSOptimizer() {
        super(null);
    }

    /**
     * @param checker
     *            Convergence checker.
     */
    public BFGSOptimizer(ConvergenceChecker<PointValuePair> checker) {
        super(checker);
    }

    /**
     * {@inheritDoc}
     *
     * @param optData
     *            Optimization data. In addition to those documented in
     *            {@link GradientMultivariateOptimizer#parseOptimizationData(OptimizationData[])
     *            GradientMultivariateOptimizer}, this method will register the following data:
     *            <ul>
     *            <li>{@link StepLength}</li>
     *            <li>{@link GradientTolerance}</li>
     *            <li>{@link MaximumRestarts}</li>
     *            <li>{@link MaximumRoundoffRestarts}</li>
     *            <li>{@link PositionChecker}</li>
     *            </ul>
     * @return {@inheritDoc}
     * @throws TooManyEvaluationsException
     *             if the maximal number of evaluations (of the objective function) is exceeded.
     */
    @Override
    public PointValuePair optimize(OptimizationData... optData) throws TooManyEvaluationsException {
        // Set up base class and perform computation.
        return super.optimize(optData);
    }

    private int converged;
    private static final int CHECKER = 0;
    private static final int POSITION = 1;
    private static final int GRADIENT = 2;
    private static final int ROUNDOFF_ERROR = 3;

    /** {@inheritDoc} */
    @Override
    protected PointValuePair doOptimize() {
        final ConvergenceChecker<PointValuePair> checker = getConvergenceChecker();
        double[] p = getStartPoint();

        // Assume minimisation
        sign = -1;

        LineStepSearch lineSearch = new LineStepSearch();

        // In case there are no restarts
        if (restarts <= 0)
            return bfgsWithRoundoffCheck(checker, p, lineSearch);

        PointValuePair lastResult = null;
        PointValuePair result = null;
        //int lastConverge = 0;
        int iteration = 0;
        //int initialConvergenceIteration = 0;
        //int[] count = new int[3];
        while (iteration <= restarts) {
            iteration++;
            result = bfgsWithRoundoffCheck(checker, p, lineSearch);
            //count[converged]++;
            //if (lastResult == null)
            //    initialConvergenceIteration = getIterations();
            if (converged == GRADIENT) {
                // If no gradient remains then we cannot move anywhere so return
                break;
            }

            if (lastResult != null) {
                //// Check if the optimum was improved using the last convergence criteria
                //if (lastConverge == CHECKER)
                //{
                //    if (checker.converged(getIterations(), lastResult, result))
                //    {
                //        break;
                //    }
                //}
                //else
                //{
                //    if (positionChecker.converged(lastResult.getPointRef(), result.getPointRef()))
                //    {
                //        break;
                //    }
                //}

                // Check if the optimum was improved using the convergence criteria
                if (checker != null && checker.converged(getIterations(), lastResult, result)) {
                    break;
                }
                if (positionChecker.converged(lastResult.getPointRef(), result.getPointRef())) {
                    break;
                }
            }

            // Store the new optimum and repeat
            lastResult = result;
            //lastConverge = converged;
            p = lastResult.getPointRef();
        }

        //System.out.printf("Iter=%d (%d > %d): %s\n", iteration, initialConvergenceIteration, getIterations(),
        //        java.util.Arrays.toString(count));

        return result;
    }

    /**
     * Repeat the BFGS algorithm until it converges without roundoff error on the search direction
     *
     * @param checker
     *            the convergence checker
     * @param p
     *            the start point
     * @param lineSearch
     *            the line search
     * @return the optimum
     */
    protected PointValuePair bfgsWithRoundoffCheck(ConvergenceChecker<PointValuePair> checker, double[] p,
            LineStepSearch lineSearch) {
        // Note: The position might converge if the Hessian becomes singular or non-positive-definite.
        // In this case the simple fix is to restart the algorithm.
        int iteration = 0;
        PointValuePair result = bfgs(checker, p, lineSearch);

        // Allow restarts in the case of roundoff convergence
        while (converged == ROUNDOFF_ERROR && iteration < roundoffRestarts) {
            iteration++;
            p = result.getPointRef();
            result = bfgs(checker, p, lineSearch);
        }

        // If restarts did not work then this is a failure
        if (converged == ROUNDOFF_ERROR)
            throw new LineSearchRoundoffException();

        //if (iteration > 0)
        //    System.out.printf("Restarts for roundoff error = %d\n", iteration);

        return result;
    }

    protected PointValuePair bfgs(ConvergenceChecker<PointValuePair> checker, double[] p, LineStepSearch lineSearch) {
        final int n = p.length;

        final double EPS = epsilon;

        double[] hdg = new double[n];
        double[] xi = new double[n];
        double[][] hessian = new double[n][n];

        // Get the gradient for the bounded point
        applyBounds(p);
        double[] g = computeObjectiveGradient(p);
        checkGradients(g, p);

        // Initialise the Hessian and search direction
        for (int i = 0; i < n; i++) {
            hessian[i][i] = 1.0;
            xi[i] = -g[i];
        }

        PointValuePair current = null;

        while (true) {
            incrementIterationCount();

            // Get the value of the point
            double fp = computeObjectiveValue(p);

            if (checker != null) {
                PointValuePair previous = current;
                current = new PointValuePair(p, fp);
                if (previous != null && checker.converged(getIterations(), previous, current)) {
                    // We have found an optimum.
                    converged = CHECKER;
                    return current;
                }
            }

            // Move along the search direction.
            final double[] pnew;
            try {
                pnew = lineSearch.lineSearch(p, fp, g, xi);
            } catch (LineSearchRoundoffException e) {
                // This can happen if the Hessian is nearly singular or non-positive-definite.
                // In this case the algorithm should be restarted.
                converged = ROUNDOFF_ERROR;
                //System.out.printf("Roundoff error, iter=%d\n", getIterations());
                return new PointValuePair(p, fp);
            }

            // We assume the new point is on/within the bounds since the line search is constrained
            double fret = lineSearch.f;

            // Test for convergence on change in position
            if (positionChecker.converged(p, pnew)) {
                converged = POSITION;
                return new PointValuePair(pnew, fret);
            }

            // Update the line direction
            for (int i = 0; i < n; i++) {
                xi[i] = pnew[i] - p[i];
            }
            p = pnew;

            // Save the old gradient
            double[] dg = g;

            // Get the gradient for the new point
            g = computeObjectiveGradient(p);
            checkGradients(g, p);

            // If necessary recompute the function value.
            // Doing this after the gradient evaluation allows the value to be cached when
            // computing the objective gradient
            fp = fret;

            // Test for convergence on zero gradient.
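            // The test is the infinity norm of the gradient, scaled by the magnitude of the position and
            // made relative to the objective value:
            //   max_i( |g[i]| * max(|p[i]|, 1) ) / max(|f(p)|, 1) < gradientTolerance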
            double test = 0;
            for (int i = 0; i < n; i++) {
                final double temp = Math.abs(g[i]) * FastMath.max(Math.abs(p[i]), 1);
                //final double temp = Math.abs(g[i]);
                if (test < temp)
                    test = temp;
            }
            // Compute the biggest gradient relative to the objective function
            test /= FastMath.max(Math.abs(fp), 1);
            if (test < gradientTolerance) {
                converged = GRADIENT;
                return new PointValuePair(p, fp);
            }

            // Compute the difference of gradients ...
            for (int i = 0; i < n; i++)
                dg[i] = g[i] - dg[i];
            // ... and the difference times the current Hessian approximation
            for (int i = 0; i < n; i++) {
                hdg[i] = 0.0;
                for (int j = 0; j < n; j++)
                    hdg[i] += hessian[i][j] * dg[j];
            }
            // Calculate the dot products for the denominators of the update
            double fac = 0, fae = 0, sumdg = 0, sumxi = 0;
            for (int i = 0; i < n; i++) {
                fac += dg[i] * xi[i];
                fae += dg[i] * hdg[i];
                sumdg += dg[i] * dg[i];
                sumxi += xi[i] * xi[i];
            }
            // Skip the update if fac is not sufficiently positive
            if (fac > Math.sqrt(EPS * sumdg * sumxi)) {
                fac = 1.0 / fac;
                final double fad = 1.0 / fae;
                // The vector that makes BFGS different from DFP
                for (int i = 0; i < n; i++)
                    dg[i] = fac * xi[i] - fad * hdg[i];
                // BFGS update of the inverse Hessian approximation
                for (int i = 0; i < n; i++) {
                    for (int j = i; j < n; j++) {
                        hessian[i][j] += fac * xi[i] * xi[j] - fad * hdg[i] * hdg[j] + fae * dg[i] * dg[j];
                        hessian[j][i] = hessian[i][j];
                    }
                }
            }
            // Calculate the next search direction: -H * g
            for (int i = 0; i < n; i++) {
                xi[i] = 0.0;
                for (int j = 0; j < n; j++)
                    xi[i] -= hessian[i][j] * g[j];
            }
        }
    }

    /**
     * Scans the list of (required and optional) optimization data that
     * characterize the problem.
     *
     * @param optData
     *            Optimization data.
     *            The following data will be looked for:
     *            <ul>
     *            <li>{@link PositionChecker}</li>
     *            <li>{@link StepLength}</li>
     *            <li>{@link GradientTolerance}</li>
     *            <li>{@link MaximumRestarts}</li>
     *            <li>{@link MaximumRoundoffRestarts}</li>
     *            </ul>
     */
    @Override
    protected void parseOptimizationData(OptimizationData... optData) {
        // Allow base class to register its own data.
        super.parseOptimizationData(optData);

        // The existing values (as set by the previous call) are reused if
        // not provided in the argument list.
        for (OptimizationData data : optData) {
            if (data instanceof PositionChecker) {
                positionChecker = (PositionChecker) data;
            } else if (data instanceof StepLength) {
                maximumStepLength = ((StepLength) data).getStep();
            } else if (data instanceof GradientTolerance) {
                gradientTolerance = ((GradientTolerance) data).getTolerance();
            } else if (data instanceof MaximumRoundoffRestarts) {
                // Check the subclass before its parent MaximumRestarts so the correct field is set
                roundoffRestarts = ((MaximumRoundoffRestarts) data).getRestarts();
            } else if (data instanceof MaximumRestarts) {
                restarts = ((MaximumRestarts) data).getRestarts();
            }
        }

        checkParameters();
    }

    /**
     * The relative machine precision for the double type (machine epsilon)
     */
    private static double epsilon = calculateMachineEpsilonDouble();

    /**
     * @return The relative machine precision for the double type (machine epsilon)
     * @see <a href="http://en.wikipedia.org/wiki/Machine_epsilon#Approximation_using_C.2B.2B">Machine epsilon</a>
     */
    private static double calculateMachineEpsilonDouble() {
        double machEps = 1.0;

        do
            machEps /= 2.0;
        while ((1.0 + (machEps / 2.0)) != 1.0);

        // ISO standard is 2^-52 = 2.220446049e-16
        return machEps;
    }

    public static class LineSearchRoundoffException extends RuntimeException {
        private static final long serialVersionUID = -8974644703023090107L;
        private final double slope;

        public LineSearchRoundoffException(double slope) {
            super();
            this.slope = slope;
        }

        public LineSearchRoundoffException() {
            super();
            this.slope = 0;
        }

        @Override
        public String getMessage() {
            return (slope != 0) ? "Round-off problem. Slope = " + slope : "Round-off problem";
        }
    }

    /**
     * Internal class for a line search with backtracking
     * <p>
     * Adapted from NR::lnsrch, as discussed in Numerical Recipes section 9.7. The algorithm has been changed to
     * support bounds on the point, limits on the search direction in all dimensions and checking for bad function
     * evaluations when backtracking.
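     * <p>
     * A step is accepted when it satisfies the sufficient-decrease (Armijo) condition
     * {@code f(xNew) <= f(xOld) + ALF * lambda * slope}; otherwise the step length lambda is reduced using a
     * quadratic model of the function on the first backtrack and a cubic model thereafter.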
     */
    private class LineStepSearch {
        /**
         * Set to true when the new point is too close to the old point. In a minimisation algorithm this signifies
         * convergence.
         */
        @SuppressWarnings("unused")
        boolean check;
        /**
         * The function value at the new point
         */
        double f;

        /**
         * Given an n-dimensional point, and the function value and gradient at that point, find a new point
         * along the given search direction so that the function value has decreased sufficiently.
         *
         * @param xOld
         *            The old point
         * @param fOld
         *            The old point function value
         * @param gradient
         *            The old point function gradient
         * @param searchDirection
         *            The search direction
         * @return The new point
         * @throws LineSearchRoundoffException
         *             if the slope of the line search is positive
         */
        double[] lineSearch(double[] xOld, final double fOld, double[] gradient, double[] searchDirection)
                throws LineSearchRoundoffException {
            final double ALF = 1.0e-4, TOLX = epsilon;
            double alam2 = 0.0, f2 = 0.0;

            // New point
            double[] x = new double[xOld.length];

            final int n = xOld.length;
            check = false;

            // Limit the search step size for each dimension
            if (maximumStepLength != null) {
                double scale = 1;
                for (int i = 0; i < n; i++) {
                    if (Math.abs(searchDirection[i]) * scale > maximumStepLength[i])
                        scale = maximumStepLength[i] / Math.abs(searchDirection[i]);
                }
                if (scale < 1) {
                    // Scale the entire search direction
                    for (int i = 0; i < n; i++)
                        searchDirection[i] *= scale;
                }
            }

            double slope = 0.0;
            for (int i = 0; i < n; i++)
                slope += gradient[i] * searchDirection[i];
            if (slope >= 0.0) {
                throw new LineSearchRoundoffException(slope);
            }

            // Compute lambda min
            double test = 0.0;
            for (int i = 0; i < n; i++) {
                final double temp = Math.abs(searchDirection[i]) / FastMath.max(Math.abs(xOld[i]), 1.0);
                if (temp > test)
                    test = temp;
            }
            double alamin = TOLX / test;

            // Always try the full step first
            double alam = 1.0;
            // Count the number of backtracking steps
            int backtracking = 0;
            for (;;) {
                if (alam < alamin) {
                    // Convergence (insignificant step).
                    // Since we use the old f and x then we do not need to compute the objective value
                    check = true;
                    f = fOld;
                    //System.out.printf("alam %f < alamin %f\n", alam, alamin);
                    return xOld;
                }

                for (int i = 0; i < n; i++)
                    x[i] = xOld[i] + alam * searchDirection[i];
                applyBounds(x);
                f = BFGSOptimizer.this.computeObjectiveValue(x);
                //System.out.printf("f=%f @ %f : %s\n", f, alam, java.util.Arrays.toString(x));
                if (f <= fOld + ALF * alam * slope) {
                    // Sufficient function decrease
                    //System.out.printf("f=%f < %f\n", f, fOld + ALF * alam * slope);
                    return x;
                } else {
                    // Check for bad function evaluation
                    if (f == Double.POSITIVE_INFINITY) {
                        // Reset backtracking
                        backtracking = 0;

                        alam *= 0.1;
                        continue;
                    }

                    // Backtrack
                    double tmplam;
                    if (backtracking++ == 0) {
                        // First backtrack iteration
                        tmplam = -slope / (2.0 * (f - fOld - slope));
                        // Ensure the lambda is reduced, i.e. we take a step smaller than last time
                        if (tmplam > 0.9 * alam)
                            tmplam = 0.9 * alam;
                    } else {
                        // Subsequent backtracks
                        final double rhs1 = f - fOld - alam * slope;
                        final double rhs2 = f2 - fOld - alam2 * slope;
                        final double a = (rhs1 / (alam * alam) - rhs2 / (alam2 * alam2)) / (alam - alam2);
                        final double b = (-alam2 * rhs1 / (alam * alam) + alam * rhs2 / (alam2 * alam2)) /
                                (alam - alam2);
                        if (a == 0.0)
                            tmplam = -slope / (2.0 * b);
                        else {
                            final double disc = b * b - 3.0 * a * slope;
                            if (disc < 0.0)
                                tmplam = 0.5 * alam;
                            else if (b <= 0.0)
                                tmplam = (-b + Math.sqrt(disc)) / (3.0 * a);
                            else
                                tmplam = -slope / (b + Math.sqrt(disc));
                        }
                        // Ensure the lambda is <= 0.5 * lambda1, i.e. we take a step smaller than last time
                        if (tmplam > 0.5 * alam)
                            tmplam = 0.5 * alam;
                    }
                    alam2 = alam;
                    f2 = f;
                    // Ensure the lambda is >= 0.1 * lambda1, i.e. we take a reasonable step
                    alam = FastMath.max(tmplam, 0.1 * alam);
                }
            }
        }
    }

    /**
     * Checks if there are lower or upper bounds that are not -Infinity or +Infinity, validates them and sets
     * default parameters where none have been configured.
     *
     * @throws MathUnsupportedOperationException
     *             if invalid bounds or step lengths were passed to the {@link #optimize(OptimizationData[]) optimize}
     *             method.
     */
    private void checkParameters() {
        lower = getLowerBound();
        upper = getUpperBound();
        isLower = checkArray(lower, Double.NEGATIVE_INFINITY);
        isUpper = checkArray(upper, Double.POSITIVE_INFINITY);
        // Check that the upper bound is above the lower bound
        if (isUpper && isLower) {
            for (int i = 0; i < lower.length; i++)
                if (lower[i] > upper[i])
                    throw new MathUnsupportedOperationException(createError("Lower bound must be below upper bound"));
        }

        // Numerical Recipes set the position convergence very low
        if (positionChecker == null)
            positionChecker = new PositionChecker(4 * epsilon, 0);

        // Ensure that the step length is strictly positive
        if (maximumStepLength != null) {
            for (int i = 0; i < maximumStepLength.length; i++) {
                if (maximumStepLength[i] <= 0)
                    throw new MathUnsupportedOperationException(
                            createError("Maximum step length must be strictly positive"));
            }
        }

        // Set a tolerance? If not then the routine will iterate until position convergence
        //if (gradientTolerance == 0)
        //    gradientTolerance = 1e-6;
    }

    private Localizable createError(final String message) {
        return new Localizable() {
            private static final long serialVersionUID = 1L;

            public String getSourceString() {
                return message;
            }

            public String getLocalizedString(Locale locale) {
                return message;
            }
        };
    }

    /**
     * Check if the array contains anything other than value
     *
     * @param array
     *            the array to check
     * @param value
     *            the value
     * @return True if the array has another value
     */
    private boolean checkArray(double[] array, double value) {
        if (array == null)
            return false;
        for (double v : array)
            if (v != value)
                return true;
        return false;
    }

    /**
     * Check that the point falls within the configured bounds, truncating if necessary
     *
     * @param point
     *            the point
     * @return true if the point was truncated
     */
    private boolean applyBounds(double[] point) {
        boolean truncated = false;
        if (isUpper) {
            for (int i = 0; i < point.length; i++)
                if (point[i] > upper[i]) {
                    point[i] = upper[i];
                    truncated = true;
                }
        }
        if (isLower) {
            for (int i = 0; i < point.length; i++)
                if (point[i] < lower[i]) {
                    point[i] = lower[i];
                    truncated = true;
                }
        }
        return truncated;
    }

    /**
     * Check if the point falls on or outside the configured bounds, truncating the gradient to zero
     * if it is moving further outside the bounds
     *
     * @param r
     *            the gradient
     * @param point
     *            the point
     */
    private void checkGradients(double[] r, double[] point) {
        checkGradients(r, point, sign);
    }

    /**
     * Check if the point falls on or outside the configured bounds, truncating the gradient to zero
     * if it is moving further outside the bounds (defined by the sign of the search direction)
     *
     * @param r
     *            the gradient
     * @param point
     *            the point
     * @param sign
     *            the sign of the search direction
     */
    private void checkGradients(double[] r, double[] point, final double sign) {
        if (isUpper) {
            for (int i = 0; i < point.length; i++)
                if (point[i] >= upper[i] && Math.signum(r[i]) == sign)
                    r[i] = 0;
        }
        if (isLower) {
            for (int i = 0; i < point.length; i++)
                if (point[i] <= lower[i] && Math.signum(r[i]) == -sign)
                    r[i] = 0;
        }

        //boolean isNaN = false;
        //for (int i = 0; i < point.length; i++)
        //    if (Double.isNaN(r[i]))
        //    {
        //        isNaN = true;
        //        r[i] = 0;
        //    }
    }
}