com.anhth12.optimize.solvers.BackTrackLineSearch.java Source code

Java tutorial

Introduction

Here is the source code for com.anhth12.optimize.solvers.BackTrackLineSearch.java

Source

/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package com.anhth12.optimize.solvers;

import com.anhth12.nn.api.Model;
import com.anhth12.nn.exception.InvalidStepException;
import com.anhth12.optimize.api.LineOptimizer;
import com.anhth12.optimize.api.StepFunction;
import com.anhth12.optimize.stepfunction.DefaultStepFunction;
import org.apache.commons.math3.util.FastMath;
import org.nd4j.linalg.api.complex.IComplexNumber;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.conditions.Condition;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.indexing.conditions.Or;
import org.nd4j.linalg.indexing.functions.Value;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.util.LinAlgExceptions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Backtracking line search for a score that is being <em>maximised</em>.
 *
 * <p>Starting from a full step ({@code alam = 1}) along a search direction, the step
 * fraction is shrunk until the score has increased by at least
 * {@code ALF * alam * slope} (a sufficient-increase test), or until the step becomes
 * too small to matter, in which case the original parameters are restored and
 * {@code 0.0} is returned.
 *
 * @author anhth12
 */
public class BackTrackLineSearch implements LineOptimizer {

    private static final Logger logger = LoggerFactory.getLogger(BackTrackLineSearch.class.getName());

    /** Model whose parameters are set and whose score is evaluated at each trial step. */
    Model function;
    /** Strategy that applies a trial step to the parameter vector; lazily defaulted if null. */
    StepFunction stepFunction = new DefaultStepFunction();
    /** Optimizer supplying the score at the starting point. */
    BaseOptimizer optimizer;
    /** Upper bound on the number of backtracking iterations. */
    final int maxIterations = 100;
    /** Maximum allowed 2-norm of the search direction; longer directions are rescaled. */
    double stpmax = 100;
    // Not referenced inside this class; kept because it is package-visible.
    final double EPS = 3.0e-12f;

    // termination conditions: either
    //   a) abs(delta x/x) < REL_TOLX for all coordinates
    //   b) abs(delta x) < ABS_TOLX for all coordinates
    //   c) sufficient function increase (uses ALF)
    private double relTolx = 1e-10f;
    private double absTolx = 1e-4f; // tolerance on absolute value difference
    final double ALF = 1e-4f;

    /**
     * Creates a line search using an explicit step function.
     *
     * @param function     model to evaluate
     * @param stepFunction step strategy; if {@code null} a {@link DefaultStepFunction}
     *                     is substituted when {@code optimize} runs
     * @param optimizer    optimizer providing the starting score
     */
    public BackTrackLineSearch(Model function, StepFunction stepFunction, BaseOptimizer optimizer) {
        this.function = function;
        this.stepFunction = stepFunction;
        this.optimizer = optimizer;
    }

    /**
     * Creates a line search using a {@link DefaultStepFunction}.
     *
     * @param optimizable model to evaluate
     * @param optimizer   optimizer providing the starting score
     */
    public BackTrackLineSearch(Model optimizable, BaseOptimizer optimizer) {
        this(optimizable, new DefaultStepFunction(), optimizer);
    }

    /**
     * Sets the maximum 2-norm allowed for the search direction.
     *
     * @param stpmax new maximum step length
     */
    public void setStpmax(double stpmax) {
        this.stpmax = stpmax;
    }

    /**
     * @return the maximum 2-norm allowed for the search direction
     */
    public double getStpmax() {
        return stpmax;
    }

    /**
     * Sets the tolerance on the relative change in the parameters.
     * The line search gives up once the step fraction is so small that
     * abs(delta x / max(abs(x), 1)) would fall below {@code tolx} for all coordinates.
     *
     * @param tolx relative tolerance
     */
    public void setRelTolx(double tolx) {
        relTolx = tolx;
    }

    /**
     * Sets the tolerance on the absolute change in the parameters.
     * The line search gives up once abs(delta x) does not exceed {@code tolx}
     * for any coordinate.
     *
     * @param tolx absolute tolerance
     */
    public void setAbsTolx(double tolx) {
        absTolx = tolx;
    }

    /**
     * Searches along {@code line} for a step fraction that sufficiently increases the score.
     *
     * <p>initialStep is ignored.  This is b/c if the initial step is not 1.0,
     * it sometimes confuses the backtracking for reasons I don't
     * understand.  (That is, the jump gets LARGER on iteration 1.)
     *
     * <p>Side effects visible to the caller: NaN/infinite entries of {@code g} are
     * overwritten with {@code Nd4j.EPS_THRESHOLD}; {@code line} is rescaled in place when
     * its 2-norm exceeds {@code stpmax}; {@code x} is handed to the step function on every
     * iteration (which is expected to move it along {@code line}).
     *
     * @param line                search direction
     * @param lineSearchIteration not used by this implementation
     * @param initialStep         ignored, see above
     * @param x                   current parameters
     * @param g                   gradient at {@code x}
     * @return the accepted step fraction (alam), or {@code 0.0} if no acceptable step was
     *         found; in the latter case the model is reset to the starting parameters
     * @throws InvalidStepException  if the directional derivative {@code g . line} is
     *                               negative or zero (not an ascent direction)
     * @throws IllegalStateException if an accepted step nevertheless lowers the score, or
     *                               if the two most recent step fractions coincide during
     *                               interpolation
     */
    public double optimize(INDArray line, int lineSearchIteration, double initialStep, INDArray x, INDArray g)
            throws InvalidStepException {
        double slope, test, alamin, alam, alam2, tmplam;
        double rhs1, rhs2, a, b, disc, oldAlam;
        double f, fold, f2;
        INDArray oldParameters = x.dup();

        alam2 = 0.0;
        f2 = fold = optimizer.score();
        // The messages below are emitted at trace level, so guard on trace: this avoids
        // computing the norms when only debug logging is on.
        if (logger.isTraceEnabled()) {
            logger.trace("ENTERING BACKTRACK\n");
            logger.trace("Entering BackTrackLinnSearch, value = " + fold + ",\ndirection.oneNorm:"
                    + line.norm1(Integer.MAX_VALUE) + "  direction.infNorm:" + FastMath.max(Float.NEGATIVE_INFINITY,
                            Transforms.abs(line).max(Integer.MAX_VALUE).getDouble(0)));
        }

        // Sanitize the gradient in place before it is used for the slope.
        BooleanIndexing.applyWhere(g, new Or(Conditions.isNan(), Conditions.isInfinite()),
                new Value(Nd4j.EPS_THRESHOLD));
        LinAlgExceptions.assertValidNum(g);

        // Cap the length of the search direction at stpmax.
        double sum = line.norm2(Integer.MAX_VALUE).getDouble(0);
        if (sum > stpmax) {
            logger.warn("attempted step too big. scaling: sum= " + sum + ", stpmax= " + stpmax);
            line.muli(stpmax / sum);
        }

        // Directional derivative along the search direction (dot product).
        slope = Nd4j.getBlasWrapper().dot(g, line);
        logger.debug("slope = " + slope);

        if (slope < 0) {
            throw new InvalidStepException("Slope = " + slope + " is negative");
        }

        if (slope == 0) {
            throw new InvalidStepException("Slope = " + slope + " is zero");
        }

        // find maximum lambda
        // converge when (delta x) / x < REL_TOLX for all coordinates.
        //  the largest step size that triggers this threshold is
        //  precomputed and saved in alamin
        INDArray maxOldParams = Transforms.abs(oldParameters);
        // Floor every |x_i| at 1.0 so small parameters do not blow up the relative change.
        BooleanIndexing.applyWhere(maxOldParams, new Condition() {
            @Override
            public Boolean apply(Number input) {
                return input.doubleValue() < 1.0;
            }

            @Override
            public Boolean apply(IComplexNumber input) {
                return false;
            }
        }, new Value(1.0));

        test = Transforms.abs(line).divi(maxOldParams).max(Integer.MAX_VALUE).getDouble(0);
        alamin = relTolx / test;

        alam = 1.0f;
        oldAlam = 0.0f;

        // Loop-invariant: make sure a step function exists before iterating.
        if (stepFunction == null) {
            stepFunction = new DefaultStepFunction();
        }

        // look for step size in direction given by "line"
        for (int iteration = 0; iteration < maxIterations; iteration++) {
            // x = oldParameters + alam*line
            // initially, alam = 1.0, i.e. take full Newton step
            if (logger.isTraceEnabled()) {
                logger.trace("BackTrack loop iteration " + iteration + " : alam=" + alam + " oldAlam=" + oldAlam);
                logger.trace("before step, x.1norm: " + x.norm1(Integer.MAX_VALUE) + "\nalam: " + alam
                        + "\noldAlam: " + oldAlam);
            }
            assert (alam != oldAlam) : "alam == oldAlam";

            stepFunction.step(x, line, new Object[] { alam, oldAlam }); //step

            if (logger.isDebugEnabled()) {
                double norm1 = x.norm1(Integer.MAX_VALUE).getDouble(0);
                logger.debug("after step, x.1norm: " + norm1);
            }

            // check for convergence
            //convergence on delta x
            if ((alam < alamin) || smallAbsDiff(oldParameters, x)) {
                f = restoreAndScore(oldParameters);
                logger.trace("EXITING BACKTRACK: Jump too small (alamin = " + alamin
                        + "). Exiting and using xold. Value = " + f);
                return 0.0f;
            }

            function.setParams(x);
            oldAlam = alam;
            f = function.score();

            logger.debug("value = " + f);

            // sufficient function increase (Armijo-style first Wolfe condition)
            if (f >= fold + ALF * alam * slope) {

                logger.debug("EXITING BACKTRACK: value=" + f);

                if (f < fold) {
                    throw new IllegalStateException(
                            "Function did not increase: f = " + f + " < " + fold + " = fold");
                }
                return alam;
            }

            // if value is infinite or NaN, i.e. we've
            // jumped to unstable territory, then scale down jump.
            // NaN must be caught here too: otherwise it poisons the interpolation
            // below and alam itself becomes NaN for the rest of the search.
            // No "alam < alamin" re-check is needed in this branch: that condition
            // was already tested above for the same alam and would have returned.
            if (isNotFinite(f) || isNotFinite(f2)) {
                logger.warn("Value is infinite after jump " + oldAlam + ". f=" + f + ", f2=" + f2
                        + ". Scaling back step size...");
                tmplam = .2f * alam;
            } else if (alam == 1.0) { // backtrack, first time through: quadratic model
                tmplam = -slope / (2.0f * (f - fold - slope));
            } else { // backtrack, subsequent times: model through the last two trial points
                rhs1 = f - fold - alam * slope;
                rhs2 = f2 - fold - alam2 * slope;
                if ((alam - alam2) == 0) {
                    throw new IllegalStateException("FAILURE: dividing by alam-alam2. alam=" + alam);
                }
                a = (rhs1 / (FastMath.pow(alam, 2)) - rhs2 / (FastMath.pow(alam2, 2))) / (alam - alam2);
                b = (-alam2 * rhs1 / (alam * alam) + alam * rhs2 / (alam2 * alam2)) / (alam - alam2);
                if (a == 0.0) {
                    tmplam = -slope / (2.0f * b);
                } else {
                    disc = b * b - 3.0f * a * slope;
                    if (disc < 0.0) {
                        tmplam = .5f * alam;
                    } else if (b <= 0.0) {
                        tmplam = (-b + FastMath.sqrt(disc)) / (3.0f * a);
                    } else {
                        tmplam = -slope / (b + FastMath.sqrt(disc));
                    }
                }
                if (tmplam > .5f * alam) {
                    tmplam = .5f * alam; // lambda <= .5 lambda_1
                }
            }

            alam2 = alam;
            f2 = f;
            logger.debug("tmplam:" + tmplam);
            alam = Math.max(tmplam, .1f * alam); // lambda >= .1*Lambda_1

        }

        // Iteration budget exhausted without an acceptable step. Put the model back on
        // the starting parameters, exactly like the other "no step" exits, so that a
        // 0.0 return never leaves the model sitting on a rejected trial point.
        f = restoreAndScore(oldParameters);
        logger.warn("EXITING BACKTRACK: no acceptable step after " + maxIterations
                + " iterations. Exiting and using xold. Value=" + f);
        return 0.0f;
    }

    // resets the model to the given parameters and returns the score it then reports
    private double restoreAndScore(INDArray oldParameters) {
        function.setParams(oldParameters);
        return function.score();
    }

    // returns true iff the value is NaN or +/- infinity
    private static boolean isNotFinite(double value) {
        return Double.isNaN(value) || Double.isInfinite(value);
    }

    // returns true iff we've converged based on absolute x difference,
    // i.e. no coordinate differs by more than absTolx (symmetric in its arguments)
    private boolean smallAbsDiff(INDArray xold, INDArray x) {

        for (int i = 0; i < x.length(); i++) {
            double comp = Math.abs(x.getDouble(i) - xold.getDouble(i));
            if (comp > absTolx) {
                return false;
            }
        }
        return true;
    }

}