Java tutorial
/*
 * Copyright 2015 Skymind, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.deeplearning4j.optimize.solvers;

import org.apache.commons.math3.util.FastMath;
import org.deeplearning4j.exception.InvalidStepException;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.stepfunctions.NegativeGradientStepFunction;
import org.deeplearning4j.optimize.api.ConvexOptimizer;
import org.deeplearning4j.optimize.api.LineOptimizer;
import org.deeplearning4j.optimize.api.StepFunction;
import org.deeplearning4j.optimize.stepfunctions.NegativeDefaultStepFunction;
import org.nd4j.linalg.api.blas.Level1;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarSetValue;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.Eps;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.indexing.functions.Value;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import static org.nd4j.linalg.ops.transforms.Transforms.abs;

// "Line Searches and Backtracking", p. 385, "Numerical Recipes in C"

/**
 * A simple backtracking line search ("lnsrch", Numerical Recipes in C, p. 385).
 * No attempt at accurately finding the true minimum is made; the goal is only
 * to ensure that BackTrackLineSearch returns a position of improved score
 * (lower when minimizing the objective, higher when maximizing).
 *
 * Adapted from mallet, with the original authors credited below. Modified to be
 * a vectorized version that uses jblas matrices for computation rather than the
 * mallet ops.
 *
 * @author Aron Culotta <a href="mailto:culotta@cs.umass.edu">culotta@cs.umass.edu</a>
 * @author Adam Gibson
 */
public class BackTrackLineSearch implements LineOptimizer {
    private static final Logger log = LoggerFactory.getLogger(BackTrackLineSearch.class);

    private Model layer;
    private StepFunction stepFunction;
    private ConvexOptimizer optimizer;
    private int maxIterations;
    double stepMax = 100;
    private boolean minObjectiveFunction = true;

    // termination conditions: either
    // a) abs(delta x / x) < REL_TOLX for all coordinates
    // b) abs(delta x) < ABS_TOLX for all coordinates
    // c) sufficient function increase (uses ALF)
    private double relTolx = 1e-7f;
    private double absTolx = 1e-4f; // tolerance on absolute value difference
    protected final double ALF = 1e-4f;

    /**
     * @param layer        the model whose parameters are being optimized
     * @param stepFunction the step function used to apply a step along the search direction
     * @param optimizer    the enclosing convex optimizer
     */
    public BackTrackLineSearch(Model layer, StepFunction stepFunction, ConvexOptimizer optimizer) {
        this.layer = layer;
        this.stepFunction = stepFunction;
        this.optimizer = optimizer;
        this.maxIterations = layer.conf().getMaxNumLineSearchIterations();
    }

    /**
     * @param optimizable the model whose parameters are being optimized
     * @param optimizer   the enclosing convex optimizer
     */
    public BackTrackLineSearch(Model optimizable, ConvexOptimizer optimizer) {
        this(optimizable, new NegativeDefaultStepFunction(), optimizer);
    }

    public void setStepMax(double stepMax) {
        this.stepMax = stepMax;
    }

    public double getStepMax() {
        return stepMax;
    }

    /**
     * Sets the tolerance of relative diff in function value.
     * Line search converges if abs(delta x / x) < tolx
     * for all coordinates.
     */
    public void setRelTolx(double tolx) {
        relTolx = tolx;
    }

    /**
     * Sets the tolerance of absolute diff in function value.
     * Line search converges if abs(delta x) < tolx
     * for all coordinates.
     */
    public void setAbsTolx(double tolx) {
        absTolx = tolx;
    }

    public int getMaxIterations() {
        return maxIterations;
    }

    public void setMaxIterations(int maxIterations) {
        this.maxIterations = maxIterations;
    }

    public double setScoreFor(INDArray parameters) {
        if (Nd4j.ENFORCE_NUMERICAL_STABILITY) {
            BooleanIndexing.applyWhere(parameters, Conditions.isNan(), new Value(Nd4j.EPS_THRESHOLD));
        }

        layer.setParams(parameters);
        layer.computeGradientAndScore();
        return layer.score();
    }

    // returns a fraction of the step size if a good step was found
    // returns 0.0 if it could not step in the given direction
    // step == alam and score == f in the book
    /**
     * @param parameters      the parameters to optimize
     * @param gradients       the gradient of the score with respect to the parameters
     * @param searchDirection the direction along which the line search steps
     * @return the next step size
     * @throws InvalidStepException
     */
    @Override
    public double optimize(INDArray parameters, INDArray gradients, INDArray searchDirection)
                    throws InvalidStepException {
        double test, stepMin, step, step2, oldStep, tmpStep;
        double rhs1, rhs2, a, b, disc, score, scoreAtStart, score2;
        minObjectiveFunction = (stepFunction instanceof NegativeDefaultStepFunction
                        || stepFunction instanceof NegativeGradientStepFunction);

        Level1 l1Blas = Nd4j.getBlasWrapper().level1();

        double sum = l1Blas.nrm2(searchDirection);
        double slope = -1f * Nd4j.getBlasWrapper().dot(searchDirection, gradients);

        log.debug("slope = {}", slope);

        INDArray maxOldParams = abs(parameters);
        Nd4j.getExecutioner().exec(new ScalarSetValue(maxOldParams, 1));
        INDArray testMatrix = abs(gradients).divi(maxOldParams);

        test = testMatrix.max(Integer.MAX_VALUE).getDouble(0);

        step = 1.0; // initially, step = 1.0, i.e. take the full Newton step
        stepMin = relTolx / test; // relative convergence tolerance
        oldStep = 0.0;
        step2 = 0.0;

        score = score2 = scoreAtStart = layer.score();
        double bestScore = score;
        double bestStepSize = 1.0;

        if (log.isTraceEnabled()) {
            double norm1 = l1Blas.asum(searchDirection);
            int infNormIdx = l1Blas.iamax(searchDirection);
            double infNorm = FastMath.max(Float.NEGATIVE_INFINITY, searchDirection.getDouble(infNormIdx));
            log.trace("ENTERING BACKTRACK\n");
            log.trace("Entering BackTrackLineSearch, value = " + scoreAtStart + ",\ndirection.oneNorm:" + norm1
                            + " direction.infNorm:" + infNorm);
        }

        if (sum > stepMax) {
            log.warn("Attempted step too big. scaling: sum= {}, stepMax= {}", sum, stepMax);
            searchDirection.muli(stepMax / sum);
        }

        //        if (slope >= 0.0) {
        //            throw new InvalidStepException("Slope " + slope + " is >= 0.0. Expect slope < 0.0 when minimizing objective function");
        //        }

        // find maximum lambda
        // converge when (delta x) / x < REL_TOLX for all coordinates.
        // the largest step size that triggers this threshold is precomputed and saved in stepMin
        // look for a step size in the direction given by "line"
        INDArray candidateParameters = null;
        for (int iteration = 0; iteration < maxIterations; iteration++) {
            if (log.isTraceEnabled()) {
                log.trace("BackTrack loop iteration {} : step={}, oldStep={}", iteration, step, oldStep);
                log.trace("before step, x.1norm: {} \nstep: {} \noldStep: {}",
                                parameters.norm1(Integer.MAX_VALUE), step, oldStep);
            }

            if (step == oldStep)
                throw new IllegalArgumentException("Current step == oldStep");

            // step
            candidateParameters = parameters.dup('f');
            stepFunction.step(candidateParameters, searchDirection, step);
            oldStep = step;

            if (log.isTraceEnabled()) {
                double norm1 = l1Blas.asum(candidateParameters);
                log.trace("after step, x.1norm: " + norm1);
            }

            // check for convergence on delta x
            if ((step < stepMin) || Nd4j.getExecutioner()
                            .execAndReturn(new Eps(parameters, candidateParameters,
                                            Shape.toOffsetZeroCopy(candidateParameters, 'f'),
                                            candidateParameters.length()))
                            .sum(Integer.MAX_VALUE).getDouble(0) == candidateParameters.length()) {
                score = setScoreFor(parameters);
                log.debug("EXITING BACKTRACK: Jump too small (stepMin = {}). Exiting and using original params. Score = {}",
                                stepMin, score);
                return 0.0;
            }

            score = setScoreFor(candidateParameters);
            log.debug("Model score after step = {}", score);

            // track the best step size for use if we terminate on maxIterations
            if ((minObjectiveFunction && score < bestScore) || (!minObjectiveFunction && score > bestScore)) {
                bestScore = score;
                bestStepSize = step;
            }

            // sufficient decrease in cost/loss function (Wolfe condition / Armijo condition)
            if (minObjectiveFunction && score <= scoreAtStart + ALF * step * slope) {
                log.debug("Sufficient decrease (Wolfe cond.), exiting backtrack on iter {}: score={}, scoreAtStart={}",
                                iteration, score, scoreAtStart);
                if (score > scoreAtStart)
                    throw new IllegalStateException("Function did not decrease: score = " + score + " > "
                                    + scoreAtStart + " = oldScore");
                return step;
            }

            // sufficient increase in cost/loss function (Wolfe condition / Armijo condition)
            if (!minObjectiveFunction && score >= scoreAtStart + ALF * step * slope) {
                log.debug("Sufficient increase (Wolfe cond.), exiting backtrack on iter {}: score={}, bestScore={}",
                                iteration, score, scoreAtStart);
                if (score < scoreAtStart)
                    throw new IllegalStateException("Function did not increase: score = " + score + " < "
                                    + scoreAtStart + " = scoreAtStart");
                return step;
            }

            // if the value is infinite, i.e.
            // we've jumped into unstable territory, then scale down the jump
            else if (Double.isInfinite(score) || Double.isInfinite(score2) || Double.isNaN(score)
                            || Double.isNaN(score2)) {
                log.warn("Value is infinite after jump. oldStep={}. score={}, score2={}. Scaling back step size...",
                                oldStep, score, score2);
                tmpStep = .2 * step;
                if (step < stepMin) { // convergence on delta x
                    score = setScoreFor(parameters);
                    log.warn("EXITING BACKTRACK: Jump too small (step={} < stepMin={}). Exiting and using previous parameters. Value={}",
                                    step, stepMin, score);
                    return 0.0;
                }
            }
            // backtrack
            else if (minObjectiveFunction) {
                if (step == 1.0) // first time through
                    tmpStep = -slope / (2.0 * (score - scoreAtStart - slope));
                else {
                    rhs1 = score - scoreAtStart - step * slope;
                    rhs2 = score2 - scoreAtStart - step2 * slope;
                    if (step == step2)
                        throw new IllegalStateException("FAILURE: dividing by step-step2 which equals 0. step=" + step);
                    double stepSquared = step * step;
                    double step2Squared = step2 * step2;
                    a = (rhs1 / stepSquared - rhs2 / step2Squared) / (step - step2);
                    b = (-step2 * rhs1 / stepSquared + step * rhs2 / step2Squared) / (step - step2);
                    if (a == 0.0)
                        tmpStep = -slope / (2.0 * b);
                    else {
                        disc = b * b - 3.0 * a * slope;
                        if (disc < 0.0) {
                            tmpStep = 0.5 * step;
                        } else if (b <= 0.0)
                            tmpStep = (-b + FastMath.sqrt(disc)) / (3.0 * a);
                        else
                            tmpStep = -slope / (b + FastMath.sqrt(disc));
                    }
                    if (tmpStep > 0.5 * step)
                        tmpStep = 0.5 * step; // lambda <= 0.5 lambda_1
                }
            } else {
                if (step == 1.0) // first time through
                    tmpStep = -slope / (2.0 * (scoreAtStart - score - slope));
                else {
                    rhs1 = scoreAtStart - score - step * slope;
                    rhs2 = scoreAtStart - score2 - step2 * slope;
                    if (step == step2)
                        throw new IllegalStateException("FAILURE: dividing by step-step2 which equals 0. step=" + step);
                    double stepSquared = step * step;
                    double step2Squared = step2 * step2;
                    a = (rhs1 / stepSquared - rhs2 / step2Squared) / (step - step2);
                    b = (-step2 * rhs1 / stepSquared + step * rhs2 / step2Squared) / (step - step2);
                    if (a == 0.0)
                        tmpStep = -slope / (2.0 * b);
                    else {
                        disc = b * b - 3.0 * a * slope;
                        if (disc < 0.0) {
                            tmpStep = 0.5 * step;
                        } else if (b <= 0.0)
                            tmpStep = (-b + FastMath.sqrt(disc)) / (3.0 * a);
                        else
                            tmpStep = -slope / (b + FastMath.sqrt(disc));
                    }
                    if (tmpStep > 0.5 * step)
                        tmpStep = 0.5 * step; // lambda <= 0.5 lambda_1
                }
            }

            step2 = step;
            score2 = score;
            log.debug("tmpStep: {}", tmpStep);
            step = Math.max(tmpStep, .1f * step); // lambda >= .1*Lambda_1
        }

        if ((minObjectiveFunction && bestScore < scoreAtStart)
                        || (!minObjectiveFunction && bestScore > scoreAtStart)) {
            // hit maxIterations, but some step improved on the starting score: return the best step size seen
            log.debug("Exited line search after maxIterations termination condition; bestStepSize={}, bestScore={}, scoreAtStart={}",
                            bestStepSize, bestScore, scoreAtStart);
            return bestStepSize;
        } else {
            log.debug("Exited line search after maxIterations termination condition; score did not improve (bestScore={}, scoreAtStart={}). Resetting parameters",
                            bestScore, scoreAtStart);
            setScoreFor(parameters);
            return 0.0;
        }
    }
}
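
To see the core idea in isolation, below is a minimal, self-contained sketch of a backtracking line search with the same Armijo sufficient-decrease test on a two-dimensional quadratic. Everything in it (the class name ArmijoDemo, the functions f and gradF, and the plain step-halving) is illustrative only and is not part of the class above. In particular, the sketch simply halves the step on each failed trial, whereas BackTrackLineSearch interpolates the next step: on the first backtrack it uses tmpStep = -slope / (2 * (score - scoreAtStart - slope)), which is the minimizer of the quadratic in lambda that matches the starting value, the starting slope, and the failed step's value, and on later iterations it uses the cubic model from Numerical Recipes.

// Standalone sketch of Armijo backtracking on f(x) = x0^2 + x1^2.
// All names here are hypothetical; only the sufficient-decrease test
// mirrors the class above.
public class ArmijoDemo {

    static double f(double[] x) {
        return x[0] * x[0] + x[1] * x[1];
    }

    static double[] gradF(double[] x) {
        return new double[] {2 * x[0], 2 * x[1]};
    }

    /** Returns a step satisfying f(x + step*d) <= f(x) + alf * step * (gradF(x) . d), or 0.0. */
    static double backtrack(double[] x, double[] d, double alf, int maxIter) {
        double[] g = gradF(x);
        double slope = g[0] * d[0] + g[1] * d[1]; // directional derivative; must be < 0 for descent
        double fx = f(x);
        double step = 1.0;                        // start with the full step
        for (int i = 0; i < maxIter; i++) {
            double[] candidate = {x[0] + step * d[0], x[1] + step * d[1]};
            if (f(candidate) <= fx + alf * step * slope)
                return step;                      // sufficient decrease reached
            step *= 0.5;                          // simplified backtrack: halve instead of interpolating
        }
        return 0.0;                               // no acceptable step found
    }

    public static void main(String[] args) {
        double[] x = {3.0, -4.0};
        double[] g = gradF(x);
        double[] d = {-g[0], -g[1]};              // steepest-descent direction
        System.out.println("accepted step = " + backtrack(x, d, 1e-4, 50));
    }
}

Run as-is, the demo starts from x = (3, -4), tries the full step along the steepest-descent direction, rejects it (f stays at 25, so there is no decrease at all), and accepts step = 0.5, which lands exactly on the minimizer (0, 0).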