com.opengamma.analytics.math.minimization.QuasiNewtonVectorMinimizer.java Source code

Java tutorial

Introduction

Here is the source code for com.opengamma.analytics.math.minimization.QuasiNewtonVectorMinimizer.java

Source

/**
 * Copyright (C) 2009 - present by OpenGamma Inc. and the OpenGamma group of companies
 * 
 * Please see distribution for license.
 */
package com.opengamma.analytics.math.minimization;

import org.apache.commons.lang.NotImplementedException;

import com.opengamma.analytics.math.MathException;
import com.opengamma.analytics.math.function.Function1D;
import com.opengamma.analytics.math.matrix.DoubleMatrix1D;
import com.opengamma.analytics.math.matrix.DoubleMatrix2D;
import com.opengamma.analytics.math.matrix.DoubleMatrixUtils;
import com.opengamma.analytics.math.matrix.MatrixAlgebra;
import com.opengamma.analytics.math.matrix.MatrixAlgebraFactory;
import com.opengamma.util.ArgumentChecker;

/**
 * Quasi-Newton minimizer for a scalar function of a vector argument that requires the gradient to be supplied.
 * <p>
 * Every iteration steps along {@code -H * grad}, where {@code H} is the inverse Hessian estimate held in the
 * {@link DataBundle}. The step length is found by a backtracking search (quadratic, then cubic), after which
 * {@code H} is refreshed by the configured {@link QuasiNewtonInverseHessianUpdate}. The estimate is reset to the
 * identity periodically and whenever the line search fails.
 * <p>
 * NOTE(review): the line search compares the <em>square</em> of the function value ({@code y * y}) rather than the
 * function value itself - confirm that this is intended for a minimizer.
 */
public class QuasiNewtonVectorMinimizer implements
        MinimizerWithGradient<Function1D<DoubleMatrix1D, Double>, Function1D<DoubleMatrix1D, DoubleMatrix1D>, DoubleMatrix1D> {

    /** The inverse Hessian estimate is reset to the identity when the reset counter reaches a multiple of this. */
    private static final int RESET_FREQ = 200;
    /** Coefficient of the step length in the acceptance threshold {@code g0 / (1 + ALPHA * lambda0)}. */
    private static final double ALPHA = 1e-4;
    /** Factor by which the trial step length grows when the previous step length was at least one. */
    private static final double BETA = 1.5;
    /** Default absolute and relative tolerance. */
    private static final double EPS = 1e-8;
    /** Default maximum number of steps. */
    private static final int DEF_MAX_STEPS = 200;
    /** Cubic backtracking is abandoned once more than this many attempts have been made in one line search. */
    private static final int MAX_CUBIC_BACKTRACKS = 5;
    /** Maximum number of ten-fold step reductions tried while looking for a finite function value. */
    private static final int MAX_BISECT_STEPS = 100;
    private static final MatrixAlgebra MA = MatrixAlgebraFactory.getMatrixAlgebra("OG");
    private static final QuasiNewtonInverseHessianUpdate DEF_UPDATER = new BroydenFletcherGoldfarbShannoInverseHessianUpdate();

    private final double _absoluteTol, _relativeTol;
    private final int _maxSteps;
    private final QuasiNewtonInverseHessianUpdate _hessianUpdater;

    /**
     * Creates a minimizer with the default tolerances, the default maximum number of steps and the default
     * inverse Hessian updater.
     */
    public QuasiNewtonVectorMinimizer() {
        this(EPS, EPS, DEF_MAX_STEPS);
    }

    /**
     * Creates a minimizer with the default inverse Hessian updater.
     * @param absTolerance The absolute tolerance, not negative
     * @param relTolerance The relative tolerance, not negative
     * @param maxInterations The maximum number of steps, not negative
     */
    public QuasiNewtonVectorMinimizer(final double absTolerance, final double relTolerance,
            final int maxInterations) {
        this(absTolerance, relTolerance, maxInterations, DEF_UPDATER);
    }

    /**
     * Creates a minimizer.
     * @param absoluteTol The absolute tolerance, not negative
     * @param relativeTol The relative tolerance, not negative
     * @param maxInterations The maximum number of steps, not negative
     * @param hessianUpdater The scheme used to update the inverse Hessian estimate, not null
     */
    public QuasiNewtonVectorMinimizer(final double absoluteTol, final double relativeTol, final int maxInterations,
            final QuasiNewtonInverseHessianUpdate hessianUpdater) {
        ArgumentChecker.notNull(hessianUpdater, "null updater");
        ArgumentChecker.notNegative(absoluteTol, "absolute tolerance");
        ArgumentChecker.notNegative(relativeTol, "relative tolerance");
        ArgumentChecker.notNegative(maxInterations, "maxSteps");
        _absoluteTol = absoluteTol;
        _relativeTol = relativeTol;
        _maxSteps = maxInterations;
        _hessianUpdater = hessianUpdater;
    }

    /**
     * Disabled because not working properly (see JIRA issue). Always throws; use the overload that takes a
     * gradient function instead.
     * @param function The function
     * @param startPosition The start position
     * @return The minimum (never returned)
     * @throws NotImplementedException always
     */
    @Override
    public DoubleMatrix1D minimize(final Function1D<DoubleMatrix1D, Double> function,
            final DoubleMatrix1D startPosition) {
        throw new NotImplementedException("Please supply gradient function or use ConjugateGradient");
    }

    /**
     * Searches for a minimum of the function starting from the given position.
     * @param function The function
     * @param grad The gradient of the function
     * @param startPosition The start position
     * @return The position at which the convergence test was satisfied
     * @throws MathException If no acceptable step can be found from the start position, if the line search fails
     *   even after the inverse Hessian estimate has been reset, if no finite function value can be found along
     *   the search direction, or if the maximum number of steps is exceeded
     */
    @Override
    public DoubleMatrix1D minimize(final Function1D<DoubleMatrix1D, Double> function,
            final Function1D<DoubleMatrix1D, DoubleMatrix1D> grad, final DoubleMatrix1D startPosition) {
        final DataBundle data = new DataBundle();
        final double y = function.evaluate(startPosition);
        data.setX(startPosition);
        data.setG0(y * y);
        data.setGrad(grad.evaluate(startPosition));
        data.setInverseHessianEsimate(getInitializedMatrix(startPosition));

        if (!getNextPosition(function, grad, data)) {
            throw new MathException("Cannot work with this starting position. Please choose another point");
        }

        int count = 0;
        int resetCount = 1;

        while (!isConverged(data)) {
            if ((resetCount) % RESET_FREQ == 0) {
                // periodic restart from the identity
                data.setInverseHessianEsimate(getInitializedMatrix(startPosition));
                resetCount = 1;
            } else {
                _hessianUpdater.update(data);
            }
            if (!getNextPosition(function, grad, data)) {
                // the line search failed with the current estimate: retry once from the identity
                data.setInverseHessianEsimate(getInitializedMatrix(startPosition));
                resetCount = 1;
                if (!getNextPosition(function, grad, data)) {
                    throw new MathException("Failed to converge in backtracking");
                }
            }
            count++;
            resetCount++;
            if (count > _maxSteps) {
                throw new MathException("Failed to converge after " + _maxSteps
                        + " iterations. Final point reached: " + data.getX().toString());
            }
        }
        return data.getX();
    }

    /**
     * Creates the initial inverse Hessian estimate: an identity matrix sized to the start position.
     */
    private DoubleMatrix2D getInitializedMatrix(final DoubleMatrix1D startPosition) {
        return DoubleMatrixUtils.getIdentityMatrix2D(startPosition.getNumberOfElements());
    }

    /**
     * Computes the search direction: the inverse Hessian estimate multiplied by the negated gradient.
     */
    private DoubleMatrix1D getDirection(final DataBundle data) {
        return (DoubleMatrix1D) MA.multiply(data.getInverseHessianEsimate(), MA.scale(data.getGrad(), -1.0));
    }

    /**
     * Performs one line search along the current direction and, if a step is accepted, moves the bundle to the
     * new position (updating x, g0, the gradient and the gradient change).
     * @return false if no acceptable step was found by backtracking, in which case the position is unchanged
     */
    private boolean getNextPosition(final Function1D<DoubleMatrix1D, Double> function,
            final Function1D<DoubleMatrix1D, DoubleMatrix1D> grad, final DataBundle data) {
        final DoubleMatrix1D p = getDirection(data);
        // start from a unit step, or a step larger than the last accepted one if that was at least one
        if (data.getLambda0() < 1.0) {
            data.setLambda0(1.0);
        } else {
            data.setLambda0(data.getLambda0() * BETA);
        }
        updatePosition(p, function, data);
        final double g1 = data.getG1();
        // the function is invalid at the new position, try to recover
        if (Double.isInfinite(g1) || Double.isNaN(g1)) {
            bisectBacktrack(p, function, data);
        }
        if (isStepRejected(data)) {
            quadraticBacktrack(p, function, data);
            int count = 0;
            while (isStepRejected(data)) {
                if (count > MAX_CUBIC_BACKTRACKS) {
                    return false;
                }
                cubicBacktrack(p, function, data);
                count++;
            }
        }
        final DoubleMatrix1D deltaX = data.getDeltaX();
        data.setX((DoubleMatrix1D) MA.add(data.getX(), deltaX));
        data.setG0(data.getG1());
        final DoubleMatrix1D gradNew = grad.evaluate(data.getX());
        data.setDeltaGrad((DoubleMatrix1D) MA.subtract(gradNew, data.getGrad()));
        data.setGrad(gradNew);
        return true;
    }

    /**
     * Tests whether the latest trial step must be rejected. A NaN trial value is always rejected: every
     * comparison with NaN is false, so without the explicit check a NaN would pass the threshold test and be
     * accepted as the new position.
     */
    private boolean isStepRejected(final DataBundle data) {
        final double g1 = data.getG1();
        return Double.isNaN(g1) || g1 > data.getG0() / (1 + ALPHA * data.getLambda0());
    }

    /**
     * Evaluates the function at the trial point {@code x + lambda0 * p}. The trial step is stored as deltaX, the
     * previous trial value is moved from g1 to g2 and the square of the new function value is stored as g1.
     * The position x itself is not changed.
     * @param p The search direction
     * @param function The function
     * @param data The data bundle
     */
    protected void updatePosition(final DoubleMatrix1D p, final Function1D<DoubleMatrix1D, Double> function,
            final DataBundle data) {
        final double lambda0 = data.getLambda0();
        final DoubleMatrix1D deltaX = (DoubleMatrix1D) MA.scale(p, lambda0);
        final DoubleMatrix1D xNew = (DoubleMatrix1D) MA.add(data.getX(), deltaX);
        data.setDeltaX(deltaX);
        data.setG2(data.getG1());
        final double y = function.evaluate(xNew);
        data.setG1(y * y);
    }

    /**
     * Shrinks the step length ten-fold at a time until the two most recent trial values are both finite.
     * @throws MathException If that does not happen within {@code MAX_BISECT_STEPS} reductions; without the
     *   limit the loop would never end when the function is not finite at the current position
     */
    private void bisectBacktrack(final DoubleMatrix1D p, final Function1D<DoubleMatrix1D, Double> function,
            final DataBundle data) {
        int count = 0;
        do {
            if (count >= MAX_BISECT_STEPS) {
                throw new MathException("Failed to find a finite function value after " + MAX_BISECT_STEPS
                        + " reductions of the step length");
            }
            data.setLambda0(data.getLambda0() * 0.1);
            updatePosition(p, function, data);
            count++;
        } while (Double.isNaN(data.getG1()) || Double.isInfinite(data.getG1()) || Double.isNaN(data.getG2())
                || Double.isInfinite(data.getG2()));
    }

    /**
     * Chooses a new step length from g0 and the latest trial value g1, never less than 1% of the current step
     * length, and evaluates the function there.
     */
    private void quadraticBacktrack(final DoubleMatrix1D p, final Function1D<DoubleMatrix1D, Double> function,
            final DataBundle data) {
        final double lambda0 = data.getLambda0();
        final double g0 = data.getG0();
        final double lambda = Math.max(0.01 * lambda0,
                g0 * lambda0 * lambda0 / (data.getG1() + g0 * (2 * lambda0 - 1)));
        data.swapLambdaAndReplace(lambda);
        updatePosition(p, function, data);
    }

    /**
     * Chooses a new step length from g0 and the two most recent trial values (g1 at lambda0, g2 at lambda1) and
     * evaluates the function there.
     */
    private void cubicBacktrack(final DoubleMatrix1D p, final Function1D<DoubleMatrix1D, Double> function,
            final DataBundle data) {
        double temp1, temp2, temp3, temp4, temp5;
        final double lambda0 = data.getLambda0();
        final double lambda1 = data.getLambda1();
        final double g0 = data.getG0();
        temp1 = 1.0 / lambda0 / lambda0;
        temp2 = 1.0 / lambda1 / lambda1;
        temp3 = data.getG1() + g0 * (2 * lambda0 - 1.0);
        temp4 = data.getG2() + g0 * (2 * lambda1 - 1.0);
        temp5 = 1.0 / (lambda0 - lambda1);
        final double a = temp5 * (temp1 * temp3 - temp2 * temp4);
        final double b = temp5 * (-lambda1 * temp1 * temp3 + lambda0 * temp2 * temp4);
        double lambda = (-b + Math.sqrt(b * b + 6 * a * g0)) / 3 / a;
        if (Double.isNaN(lambda)) {
            // Negative discriminant or 0/0: Math.max/Math.min propagate NaN, so the clamp below would not
            // repair it and the step would become NaN. Fall back to halving the current step length.
            lambda = 0.5 * lambda0;
        }
        lambda = Math.min(Math.max(lambda, 0.01 * lambda0), 0.75 * lambda1); // make sure new lambda is between 1% & 75% of old value
        data.swapLambdaAndReplace(lambda);
        updatePosition(p, function, data);
    }

    /**
     * Convergence test: every component of the last step must satisfy
     * {@code |deltaX_i| <= absoluteTol + |x_i| * relativeTol} and the norm of the gradient (as returned by
     * {@code MA.getNorm2}) must be below the absolute tolerance.
     */
    private boolean isConverged(final DataBundle data) {
        final DoubleMatrix1D deltaX = data.getDeltaX();
        final DoubleMatrix1D x = data.getX();
        final int n = deltaX.getNumberOfElements();
        double diff, scale;
        for (int i = 0; i < n; i++) {
            diff = Math.abs(deltaX.getEntry(i));
            scale = Math.abs(x.getEntry(i));
            if (diff > _absoluteTol + scale * _relativeTol) {
                return false;
            }
        }
        return (MA.getNorm2(data.getGrad()) < _absoluteTol);
    }

    /**
     * Data bundle for intermediate data: the mutable state shared between the line search and the inverse
     * Hessian update.
     */
    public static class DataBundle {
        /** Squared function value at the current position. */
        private double _g0;
        /** Squared function value at the most recent trial point. */
        private double _g1;
        /** Squared function value at the trial point before the most recent one. */
        private double _g2;
        /** Step length of the most recent trial. */
        private double _lambda0;
        /** Step length that was current before the last call to {@link #swapLambdaAndReplace(double)}. */
        private double _lambda1;
        /** Change in the gradient over the last accepted step. */
        private DoubleMatrix1D _deltaGrad;
        /** Gradient at the current position. */
        private DoubleMatrix1D _grad;
        /** Most recent trial step; after a step is accepted, the step that was taken. */
        private DoubleMatrix1D _deltaX;
        /** Current position. */
        private DoubleMatrix1D _x;
        /** Inverse Hessian estimate. */
        private DoubleMatrix2D _h;

        /** @return The squared function value at the current position */
        public double getG0() {
            return _g0;
        }

        /** @return The squared function value at the most recent trial point */
        public double getG1() {
            return _g1;
        }

        /** @return The squared function value at the trial point before the most recent one */
        public double getG2() {
            return _g2;
        }

        /** @return The step length of the most recent trial */
        public double getLambda0() {
            return _lambda0;
        }

        /** @return The step length stored by the last call to {@link #swapLambdaAndReplace(double)} */
        public double getLambda1() {
            return _lambda1;
        }

        /** @return The change in the gradient over the last accepted step */
        public DoubleMatrix1D getDeltaGrad() {
            return _deltaGrad;
        }

        /** @return The gradient at the current position */
        public DoubleMatrix1D getGrad() {
            return _grad;
        }

        /** @return The most recent trial step */
        public DoubleMatrix1D getDeltaX() {
            return _deltaX;
        }

        /** @return The current position */
        public DoubleMatrix1D getX() {
            return _x;
        }

        /** @param g0 The squared function value at the current position */
        public void setG0(final double g0) {
            _g0 = g0;
        }

        /** @param g1 The squared function value at the most recent trial point */
        public void setG1(final double g1) {
            _g1 = g1;
        }

        /** @param g2 The squared function value at the trial point before the most recent one */
        public void setG2(final double g2) {
            _g2 = g2;
        }

        /** @param lambda0 The step length of the most recent trial */
        public void setLambda0(final double lambda0) {
            _lambda0 = lambda0;
        }

        /** @param deltaGrad The change in the gradient over the last accepted step */
        public void setDeltaGrad(final DoubleMatrix1D deltaGrad) {
            _deltaGrad = deltaGrad;
        }

        /** @param grad The gradient at the current position */
        public void setGrad(final DoubleMatrix1D grad) {
            _grad = grad;
        }

        /** @param deltaX The trial step */
        public void setDeltaX(final DoubleMatrix1D deltaX) {
            _deltaX = deltaX;
        }

        /** @param x The current position */
        public void setX(final DoubleMatrix1D x) {
            _x = x;
        }

        /**
         * Inverse Hessian matrix. The misspelt method name is kept so that existing callers continue to work.
         * @return The inverse Hessian Matrix
         */
        public DoubleMatrix2D getInverseHessianEsimate() {
            return _h;
        }

        /**
         * Replaces the inverse Hessian estimate. The misspelt method name is kept so that existing callers
         * continue to work.
         * @param estimate The new inverse Hessian estimate
         */
        public void setInverseHessianEsimate(final DoubleMatrix2D estimate) {
            _h = estimate;
        }

        /**
         * Moves the current step length into lambda1 and stores the new step length as lambda0.
         * @param lambda0 The new step length
         */
        public void swapLambdaAndReplace(final double lambda0) {
            _lambda1 = _lambda0;
            _lambda0 = lambda0;
        }
    }

}