eu.crisis_economics.abm.algorithms.optimization.BrentLineSearch.java Source code

Java tutorial

Introduction

Here is the source code for eu.crisis_economics.abm.algorithms.optimization.BrentLineSearch.java

Source

/*
 * This file is part of CRISIS, an economics simulator.
 * 
 * Copyright (C) 2015 John Kieran Phillips
 *
 * CRISIS is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * CRISIS is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with CRISIS.  If not, see <http://www.gnu.org/licenses/>.
 */
package eu.crisis_economics.abm.algorithms.optimization;

import java.util.Arrays;

import org.apache.commons.math3.analysis.MultivariateFunction;
import org.apache.commons.math3.analysis.UnivariateFunction;
import org.apache.commons.math3.optim.MaxEval;
import org.apache.commons.math3.optim.nonlinear.scalar.GoalType;
import org.apache.commons.math3.optim.univariate.BrentOptimizer;
import org.apache.commons.math3.optim.univariate.SearchInterval;
import org.apache.commons.math3.optim.univariate.UnivariateObjectiveFunction;
import org.apache.commons.math3.optim.univariate.UnivariateOptimizer;
import org.apache.commons.math3.optim.univariate.UnivariatePointValuePair;

import com.google.common.base.Preconditions;

import eu.crisis_economics.utilities.StateVerifier;

/**
  * Univariate line search minimization using Brent optimization.
  * <br><br>
  * A line search minimizes a scalar-valued multivariate function f along the
  * segment P + t * V / |V|, for t in [0, T], where P is a starting point, V is a
  * search direction and T is a maximum distance to travel. The one-dimensional
  * minimization over t is delegated to {@link BrentOptimizer}. All functionality
  * is exposed through static methods.
  * @author phillips
  */
public final class BrentLineSearch {

    static final private double DEFAULT_ERROR_TARGET_RELATIVE = 1.e-10, DEFAULT_ERROR_TARGET_ABSOLUTE = 1.e-14;
    static final private int DEFAULT_MAXIMUM_EVALUATIONS = 100;

    /**
      * The smallest relative tolerance accepted by the {@link BrentOptimizer}
      * constructor (twice the unit in the last place of 1.0). Smaller values
      * are rejected by the optimizer, so they are rejected here up front.
      */
    static final private double MINIMUM_RELATIVE_ERROR_TARGET = 2. * Math.ulp(1.);

    /**
      * Compute the unit vector parallel to {@code vector}.
      * <br><br>
      * Components are first rescaled by the largest absolute component, so that
      * squaring them can neither overflow to infinity nor underflow to zero. The
      * zero vector maps to the zero vector.
      *
      * @param vector
      *        The (non-null) vector to normalize. It is not modified.
      * @return
      *        A new array holding the normalized vector.
      * @throws IllegalArgumentException
      *         if any component of {@code vector} is NaN or infinite.
      */
    static private double[] normalize(final double[] vector) {
        double scale = 0.;
        for (final double component : vector) {
            Preconditions.checkArgument(!Double.isNaN(component) && !Double.isInfinite(component),
                    "BrentLineSearch: search direction components must be finite.");
            scale = Math.max(scale, Math.abs(component));
        }
        final double[] result = new double[vector.length];
        if (scale == 0.) {
            return result; // Zero vector: there is no direction to travel in.
        }
        double sumOfSquares = 0.;
        for (int i = 0; i < vector.length; ++i) {
            result[i] = vector[i] / scale; // Every scaled component lies in [-1, 1].
            sumOfSquares += result[i] * result[i];
        }
        // The largest scaled component has magnitude 1, so sumOfSquares >= 1.
        final double scaledNorm = Math.sqrt(sumOfSquares);
        for (int i = 0; i < result.length; ++i) {
            result[i] /= scaledNorm;
        }
        return result;
    }

    /**
      * @return {@code true} if and only if every component of {@code vector} is zero.
      */
    static private boolean isZeroVector(final double[] vector) {
        for (final double component : vector) {
            if (component != 0.) {
                return false;
            }
        }
        return true;
    }

    /**
      * Treat a combination of:
      *   (a) a scalar-valued multivariate function,
      *   (b) a search direction (L), and
      *   (c) a starting point (P),
      * as a scalar-valued univariate function whose sole argument is 
      * the distance along the normalized line search direction L,
      * starting at P.
      * <br><br>
      * If L is the zero vector, this function is constant and equal to
      * the multivariate function evaluated at P.
      * <br><br>
      * This class is not thread-safe: {@link #value(double)} writes to a
      * shared internal workspace array, and passes that array to the
      * underlying multivariate function.
      * 
      * @author phillips
      */
    static public final class LineSearchObjectiveFunction implements UnivariateFunction {
        private final MultivariateFunction function;

        private final double[] normalizedDirection, startingPoint, workspace;

        /** Whether the search direction supplied at construction was the zero vector. */
        private final boolean isZeroDirection;

        /**
          * @param function
          *        The scalar-valued multivariate function to restrict to a line.
          * @param startingPoint
          *        The root P of the line. This array is copied.
          * @param vectorDirection
          *        The direction L of the line. This need not be normalized, and
          *        may be the zero vector, but its components must be finite.
          * @throws IllegalArgumentException
          *         if P is empty, if P and L differ in length, or if L has a
          *         non-finite component.
          */
        LineSearchObjectiveFunction(final MultivariateFunction function, final double[] startingPoint,
                final double[] vectorDirection) {
            StateVerifier.checkNotNull(function, startingPoint, vectorDirection);
            Preconditions.checkArgument(startingPoint.length > 0);
            Preconditions.checkArgument(startingPoint.length == vectorDirection.length);
            this.function = function;
            this.normalizedDirection = normalize(vectorDirection);
            this.isZeroDirection = isZeroVector(normalizedDirection);
            this.startingPoint = Arrays.copyOf(startingPoint, startingPoint.length);
            this.workspace = new double[startingPoint.length];
        }

        /**
          * Evaluate the multivariate function at P + x * L / |L|.
          */
        @Override
        public double value(final double x) {
            for (int i = 0; i < startingPoint.length; ++i) {
                workspace[i] = startingPoint[i] + x * normalizedDirection[i];
            }
            return function.value(workspace);
        }
    }

    /**
      * The result of a line search operation. This class is immutable.
      * @author phillips
      */
    static final public class LineSearchResult { // Immutable
        private final double[] solutionVector;
        private final double evaluationAtSolution;

        /**
          * @param solutionVector
          *        The solution point. This array is copied, so later changes
          *        to the argument do not affect this object.
          * @param evaluationAtSolution
          *        The evaluation of the merit function at the solution.
          */
        LineSearchResult(final double[] solutionVector, final double evaluationAtSolution) {
            this.solutionVector = Arrays.copyOf(solutionVector, solutionVector.length);
            this.evaluationAtSolution = evaluationAtSolution;
        }

        /**
          * @return A copy of the solution obtained from a line search operation.
          */
        public double[] getSolutionPoint() {
            return Arrays.copyOf(solutionVector, solutionVector.length);
        }

        /**
          * @return The evaluation of the merit function at the solution.
          */
        public double getEvaluationAtSolution() {
            return evaluationAtSolution;
        }

        /**
          * Returns a brief description of this object. The exact details of the
          * string are subject to change, and should not be regarded as fixed.
          */
        @Override
        public String toString() {
            return "LineSearchResult, solution vector: " + Arrays.toString(solutionVector)
                    + ", evaluation at solution: " + evaluationAtSolution + ".";
        }
    }

    /**
      * Perform a line search minimization. This function accepts as input:
      *   (a) a starting point (a vector),
      *   (b) a direction in which to travel (a vector),
      *   (c) limits on the total distance to travel along (b).
      *   
      * With these inputs the function attempts to find the minimum of a
      * scalar-valued multivariate function along the line starting at 
      * (a) and pointing in the direction of (b). Distances are measured
      * along the *normalized* direction (b).
      * 
      * @param function
      *        A scalar-valued multivariate function to minimize,
      * @param startingPoint
      *        A vector starting point from which to begin the minimization (P),
      * @param vectorDirection
      *        A vector direction along which to travel from P, (V). If V is the
      *        zero vector, no search is performed and P is returned,
      * @param maximumDistanceToTravel
      *        The maximum distance to travel in the direction of V. This must
      *        be strictly positive and finite,
      * @param maximumEvaluations
      *        The maximum number of function evaluations to identify the minimum.
      *        This must be strictly positive,
      * @param relativeErrorGoal
      *        The relative error target of the minimization. This must be at least
      *        2 * Math.ulp(1.), which is the smallest value {@link BrentOptimizer}
      *        accepts,
      * @param absoluteErrorGoal
      *        The absolute error target of the minimization. This must be strictly
      *        positive.
      * @return
      *        A lightweight immutable struct containing the vector solution and
      *        the evaluation of the function at this point.
      * @throws IllegalArgumentException
      *         if any numeric argument is out of range, if P is empty, if P and V
      *         differ in length, or if V has a non-finite component.
      * @throws org.apache.commons.math3.exception.TooManyEvaluationsException
      *         if the optimizer exceeds {@code maximumEvaluations} evaluations.
      */
    static public LineSearchResult doLineSearch(final MultivariateFunction function, final double[] startingPoint,
            final double[] vectorDirection, final double maximumDistanceToTravel, final int maximumEvaluations,
            final double relativeErrorGoal, final double absoluteErrorGoal) {
        Preconditions.checkArgument(maximumEvaluations > 0,
                "BrentLineSearch: the maximum number of evaluations must be positive.");
        // BrentOptimizer requires BOTH tolerances to be valid; checking them here
        // gives a clear message instead of a failure deep inside the optimizer.
        Preconditions.checkArgument(relativeErrorGoal >= MINIMUM_RELATIVE_ERROR_TARGET,
                "BrentLineSearch: the relative error goal must be at least 2 * Math.ulp(1.).");
        Preconditions.checkArgument(absoluteErrorGoal > 0.,
                "BrentLineSearch: the absolute error goal must be strictly positive.");
        // Also rejects NaN and +Infinity: Brent cannot search an unbounded interval.
        Preconditions.checkArgument(maximumDistanceToTravel > 0. && maximumDistanceToTravel <= Double.MAX_VALUE,
                "BrentLineSearch: the maximum distance to travel must be positive and finite.");
        final LineSearchObjectiveFunction lineSearcher = new LineSearchObjectiveFunction(function, startingPoint,
                vectorDirection);
        if (lineSearcher.isZeroDirection) {
            // The objective is constant along a zero direction, so the starting point is a minimizer.
            return new LineSearchResult(lineSearcher.startingPoint, lineSearcher.value(0.));
        }
        final UnivariateOptimizer optimizer = new BrentOptimizer(relativeErrorGoal, absoluteErrorGoal);
        // The search starts at distance 0 (the point P itself).
        final UnivariatePointValuePair result = optimizer.optimize(new MaxEval(maximumEvaluations),
                new UnivariateObjectiveFunction(lineSearcher), GoalType.MINIMIZE,
                new SearchInterval(0, maximumDistanceToTravel, 0));
        final double distance = result.getPoint();
        final double[] vectorSolution = new double[lineSearcher.startingPoint.length];
        for (int i = 0; i < vectorSolution.length; ++i) {
            vectorSolution[i] = lineSearcher.startingPoint[i] + lineSearcher.normalizedDirection[i] * distance;
        }
        return new LineSearchResult(vectorSolution, result.getValue());
    }

    /**
      * Perform a line search minimization with default settings: at most
      * {@value #DEFAULT_MAXIMUM_EVALUATIONS} function evaluations, and the
      * default relative and absolute error targets. See
      * {@link #doLineSearch(MultivariateFunction, double[], double[], double, int, double, double)}.
      */
    static public LineSearchResult doLineSearch(final MultivariateFunction function, final double[] startingPoint,
            final double[] vectorDirection, final double maximumDistanceToTravel) {
        return doLineSearch(function, startingPoint, vectorDirection, maximumDistanceToTravel,
                DEFAULT_MAXIMUM_EVALUATIONS, DEFAULT_ERROR_TARGET_RELATIVE, DEFAULT_ERROR_TARGET_ABSOLUTE);
    }

    /**
      * Compute the maximum distance a line search may travel whilst
      * still remaining within the coordinate box specified by a pair
      * of (upper/lower) domain bounds.
      * 
      * @param startingPoint (P)
      *        The starting point (root) of a line search.
      * @param vectorDirection (V)
      *        The vector direction of the line search.
      * @param domainMaxima
      *        Upper bounds for each coordinate to test during the line
      *        search.
      * @param domainMinima
      *        Lower bounds for each coordinate to test during the line
      *        search.
      * @return
      *        The maximum distance along the *normalized* direction V
      *        that a line search may travel, starting at P, such that
      *        no coordinates are outside of the domain bounds. This is
      *        zero if V is the zero vector, or if V points out of the box
      *        from a point P on its boundary.
      * @throws IllegalArgumentException
      *         if the arrays differ in length, if P lies outside the box,
      *         or if V has a non-finite component.
      */
    static private double computeMaximumSearchDistanceFromDomainBounds(final double[] startingPoint,
            final double[] vectorDirection, final double[] domainMaxima, final double[] domainMinima) {
        Preconditions.checkNotNull(startingPoint);
        Preconditions.checkNotNull(vectorDirection);
        Preconditions.checkNotNull(domainMaxima);
        Preconditions.checkNotNull(domainMinima);
        final int dimension = startingPoint.length;
        Preconditions.checkArgument(vectorDirection.length == dimension && domainMaxima.length == dimension
                && domainMinima.length == dimension,
                "BrentLineSearch: starting point, direction and domain bounds must have equal lengths.");

        // Check whether the starting point is inside the boundary domain.
        for (int i = 0; i < dimension; ++i) {
            Preconditions.checkArgument(startingPoint[i] >= domainMinima[i]);
            Preconditions.checkArgument(startingPoint[i] <= domainMaxima[i]);
        }

        final double[] unitDirection = normalize(vectorDirection);
        if (isZeroVector(unitDirection)) {
            return 0.;
        }

        // Identify the limiting step size for each coordinate direction.
        double maximumDistanceToTravel = Double.MAX_VALUE;
        for (int i = 0; i < dimension; ++i) {
            final double step = unitDirection[i];
            if (step > 0.) {
                maximumDistanceToTravel = Math.min(maximumDistanceToTravel,
                        (domainMaxima[i] - startingPoint[i]) / step);
            } else if (step < 0.) {
                maximumDistanceToTravel = Math.min(maximumDistanceToTravel,
                        (startingPoint[i] - domainMinima[i]) / -step);
            }
            // step == 0: this coordinate never changes, so it imposes no limit.
        }
        return maximumDistanceToTravel;
    }

    /**
      * Perform a line search minimization confined to a coordinate box.
      * <br><br>
      * The maximum distance to travel is the largest distance along the
      * normalized direction V, starting from P, that keeps every coordinate
      * within [domainMinima[i], domainMaxima[i]]. If that distance is zero,
      * P and f(P) are returned without optimization. Otherwise the search
      * proceeds with the default evaluation limit and error targets. Box
      * bounds that leave the search unbounded (for example, infinite bounds)
      * are rejected with an {@link IllegalArgumentException}.
      *
      * @param function
      *        A scalar-valued multivariate function to minimize,
      * @param startingPoint
      *        The starting point P, which must lie inside the box,
      * @param vectorDirection
      *        The direction V along which to travel from P,
      * @param domainMaxima
      *        Upper bounds for each coordinate,
      * @param domainMinima
      *        Lower bounds for each coordinate.
      * @return
      *        A lightweight immutable struct containing the vector solution and
      *        the evaluation of the function at this point.
      */
    static public LineSearchResult doLineSearch(final MultivariateFunction function, final double[] startingPoint,
            final double[] vectorDirection, final double[] domainMaxima, final double[] domainMinima) {
        final double maximumDistanceToTravel = computeMaximumSearchDistanceFromDomainBounds(startingPoint,
                vectorDirection, domainMaxima, domainMinima);
        if (maximumDistanceToTravel == 0.) {
            // There is no room to move, so the starting point is the answer.
            final double[] solution = Arrays.copyOf(startingPoint, startingPoint.length);
            return new LineSearchResult(solution, function.value(startingPoint));
        }
        return doLineSearch(function, startingPoint, vectorDirection, maximumDistanceToTravel,
                DEFAULT_MAXIMUM_EVALUATIONS, DEFAULT_ERROR_TARGET_RELATIVE, DEFAULT_ERROR_TARGET_ABSOLUTE);
    }
}