com.insightml.math.optimization.AbstractOptimizable.java Source code

Java tutorial

Introduction

Here is the source code for com.insightml.math.optimization.AbstractOptimizable.java

Source

/*
 * Copyright (C) 2016 Stefan Hen
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.insightml.math.optimization;

import org.apache.commons.math3.analysis.MultivariateFunction;
import org.apache.commons.math3.optim.InitialGuess;
import org.apache.commons.math3.optim.MaxEval;
import org.apache.commons.math3.optim.MaxIter;
import org.apache.commons.math3.optim.OptimizationData;
import org.apache.commons.math3.optim.PointValuePair;
import org.apache.commons.math3.optim.SimpleBounds;
import org.apache.commons.math3.optim.nonlinear.scalar.GoalType;
import org.apache.commons.math3.optim.nonlinear.scalar.MultivariateOptimizer;
import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction;
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.AbstractSimplex;
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.BOBYQAOptimizer;
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.CMAESOptimizer;
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.CMAESOptimizer.PopulationSize;
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.CMAESOptimizer.Sigma;
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.MultiDirectionalSimplex;
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.NelderMeadSimplex;
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.PowellOptimizer;
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.SimplexOptimizer;
import org.apache.commons.math3.random.MersenneTwister;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.insightml.math.Vectors;
import com.insightml.utils.Check;
import com.insightml.utils.types.Triple;
import com.insightml.utils.ui.SimpleFormatter;

/**
 * Base class for objectives that are maximized with Apache Commons Math derivative-free optimizers.
 * <p>
 * The objective is this instance's {@code value(double[])}. That method is not defined in this class;
 * it is handed to the optimizers as an {@link ObjectiveFunction}. Each optimization round runs CMA-ES
 * (only when bounds are configured) followed by a Nelder-Mead simplex. A candidate is kept unless it
 * lowers the objective or exceeds the optional {@code trainMax} cap. With bounds, rounds repeat until
 * one improves the objective by less than the convergence threshold. Without bounds, a single round
 * is run.
 */
public abstract class AbstractOptimizable implements Optimizable {

    /**
     * Whether each round also tries the multi-directional simplex, Powell and BOBYQA optimizers.
     * This is currently disabled. The constant keeps that code type-checked without executing it.
     */
    private static final boolean USE_EXPERIMENTAL_OPTIMIZERS = false;

    /** Test score recorded when no test function is supplied. */
    private static final double NO_TEST_SCORE = Double.NEGATIVE_INFINITY;

    /** Population size used for CMA-ES. */
    private static final int CMAES_POPULATION_SIZE = 10;

    /** Initial step size (sigma) used for CMA-ES, identical in every dimension. */
    private static final double CMAES_SIGMA = 0.3;

    /** {@code maxIterations} argument of the CMA-ES optimizer. */
    private static final int CMAES_MAX_ITERATIONS = 30000;

    /**
     * Fixed CMA-ES random seed, so runs are reproducible. This is exactly the value the former
     * expression {@code 1334498400 * 1000} evaluated to after silent 32-bit {@code int} overflow. It is
     * written out so the overflow is no longer hidden, while the random sequence (and therefore
     * previously obtained results) stays the same.
     */
    private static final int CMAES_SEED = -1236429056;

    /** Relative and absolute thresholds passed to Powell's optimizer. */
    private static final double POWELL_TOLERANCE = 0.0000001;

    private static final Logger logger = LoggerFactory.getLogger(AbstractOptimizable.class);

    /** Iteration budget, precision and optional cap on accepted objective values. */
    private final Convergence convergence;
    /** Box constraints, or {@code null} for an unbounded search. */
    private final SimpleBounds bounds;
    /** Whether intermediate and final results are logged at INFO level. */
    private final boolean log;

    /**
     * Creates an unbounded optimizable with logging enabled.
     *
     * @param maxIt iteration budget from which every optimizer run derives its iteration and
     *        evaluation limits
     * @param precision convergence precision handed to {@link Convergence}
     */
    public AbstractOptimizable(final int maxIt, final double precision) {
        this(maxIt, precision, true);
    }

    /** Creates an unbounded optimizable without a {@code trainMax} cap. */
    private AbstractOptimizable(final int maxIt, final double precision, final boolean log) {
        this(maxIt, precision, null, null, null, log);
    }

    /**
     * Creates a one-dimensional, box-constrained optimizable with logging enabled.
     *
     * @param maxIt iteration budget for every optimizer run
     * @param precision convergence precision handed to {@link Convergence}
     * @param lower lower bound of the single coordinate
     * @param upper upper bound of the single coordinate
     */
    public AbstractOptimizable(final int maxIt, final double precision, final double lower, final double upper) {
        this(maxIt, precision, new double[] { lower }, new double[] { upper }, null, true);
    }

    /**
     * Creates a box-constrained optimizable with logging enabled.
     *
     * @param maxIt iteration budget for every optimizer run
     * @param precision convergence precision handed to {@link Convergence}
     * @param lower per-coordinate lower bounds, or {@code null} for an unbounded search
     * @param upper per-coordinate upper bounds, used together with {@code lower}
     */
    public AbstractOptimizable(final int maxIt, final double precision, final double[] lower,
            final double[] upper) {
        this(maxIt, precision, lower, upper, null, true);
    }

    /**
     * Creates an optimizable with full control over all settings.
     *
     * @param maxIt iteration budget for every optimizer run
     * @param precision convergence precision handed to {@link Convergence}
     * @param lower per-coordinate lower bounds, or {@code null} for an unbounded search (in which case
     *        {@code upper} is ignored)
     * @param upper per-coordinate upper bounds, used together with {@code lower}
     * @param trainMax optional cap: candidates whose objective value exceeds it are rejected; may be
     *        {@code null} for no cap
     * @param log whether to log intermediate and final results
     */
    public AbstractOptimizable(final int maxIt, final double precision, final double[] lower, final double[] upper,
            final Double trainMax, final boolean log) {
        convergence = new Convergence(maxIt, trainMax, precision);
        bounds = lower == null ? null : new SimpleBounds(lower, upper);
        this.log = log;
    }

    /**
     * Maximizes the objective starting from the lower corner of the configured bounds.
     *
     * @return the result of {@link #max(MultivariateFunction, double[])} without a test function
     * @throws IllegalStateException if this instance was created without bounds
     */
    public final Triple<double[], Double, Double> max() {
        if (bounds == null) {
            throw new IllegalStateException(
                    "max() starts from the lower bounds, but no bounds were configured; use max(double[]) instead");
        }
        return max(bounds.getLower());
    }

    /**
     * Maximizes the objective from the given starting point, without a test function.
     *
     * @param initial starting point
     * @return (point, objective value, {@code -Infinity} as the absent test score)
     */
    public final Triple<double[], Double, Double> max(final double[] initial) {
        return max(null, initial);
    }

    /**
     * Maximizes the objective from {@code initial}, optionally tracking a held-out test function.
     * <p>
     * When {@code test} is given, the point that scored best on {@code test} is returned instead of
     * the best-on-objective point if its test-score gain exceeds the objective value it gives up.
     *
     * @param test optional test function evaluated on each accepted point; may be {@code null}
     * @param initial starting point (its length is checked by {@code Check.size(initial, 1, 999)})
     * @return (point, objective value, test score). The test score is {@code -Infinity} when
     *         {@code test} is {@code null}.
     */
    @Override
    public final Triple<double[], Double, Double> max(final MultivariateFunction test, final double[] initial) {
        Check.size(initial, 1, 999);
        final double initialTrain = value(initial);
        final double initialTest = test == null ? NO_TEST_SCORE : test.value(initial);
        PointValuePair result = new PointValuePair(initial, initialTrain);
        Triple<double[], Double, Double> bestTrain = Triple.create(initial, initialTrain, initialTest);
        Triple<double[], Double, Double> bestTest = Triple.create(initial, initialTrain, initialTest);

        while (true) {
            result = iteration(result);
            // Stop once a round improves the objective by less than the absolute threshold.
            if (result.getSecond() < bestTrain.getSecond() + convergence.getAbsoluteThreshold()) {
                log("RESULT", result);
                break;
            }
            final double testScore = test == null ? NO_TEST_SCORE : test.value(result.getFirst());
            bestTrain = Triple.create(result.getFirst(), result.getSecond(), testScore);
            if (test != null && testScore > bestTest.getThird()) {
                bestTest = bestTrain;
            }
            // Without bounds only a single round is run.
            // TODO: prevent running Nelder-Mead twice on the same point.
            if (bounds == null) {
                break;
            }
        }

        if (test == null) {
            return bestTrain;
        }
        // Objective value given up vs. test score gained by preferring the best-on-test point.
        final double improveTrain = bestTrain.getSecond() - bestTest.getSecond();
        final double improveTest = bestTest.getThird() - bestTrain.getThird();
        if (improveTest > improveTrain) {
            logger.info("{} vs. {}", bestTrain, bestTest);
            return bestTest;
        }
        return bestTrain;
    }

    /**
     * Runs one optimization round. Every optimizer starts from {@code init}'s point, not from the
     * previous optimizer's result. The best accepted candidate (see {@link #select}) is kept.
     */
    private PointValuePair iteration(final PointValuePair init) {
        final double[] start = init.getFirst();
        PointValuePair result = init;
        if (bounds != null) {
            result = select("CM", cmaes(start, CMAES_POPULATION_SIZE), result);
        }
        result = select("NM", nelderMead(start), result);
        if (USE_EXPERIMENTAL_OPTIMIZERS) {
            result = select("MD", multiDirection(start), result);
            result = select("PO", powell(start), result);
            try {
                result = select("BO", bobyqa(start), result);
            } catch (final Exception e) {
                // best effort: a failing BOBYQA run must not abort the round
                logger.error("BOBYQA optimization failed", e);
            }
        }
        return result;
    }

    /**
     * Returns {@code newResult} unless it has a lower objective value than {@code oldResult} or exceeds
     * the optional {@code trainMax} cap, in which case {@code oldResult} is kept. The decision is logged
     * under the tag {@code method}.
     */
    private PointValuePair select(final String method, final PointValuePair newResult,
            final PointValuePair oldResult) {
        if (newResult.getValue() < oldResult.getValue()
                || (convergence.trainMax != null && newResult.getValue() > convergence.trainMax)) {
            log(method + " (rejc.)", newResult);
            return oldResult;
        }
        log(method, newResult);
        return newResult;
    }

    /** Nelder-Mead simplex search (reflection 1, expansion 2, contraction 0.5, shrink 0.5). */
    private PointValuePair nelderMead(final double[] initialValues) {
        return simplex(new NelderMeadSimplex(initialValues.length, 1, 2, 0.5, 0.5), initialValues);
    }

    /** Multi-directional simplex search (expansion 2, contraction 0.5). */
    private PointValuePair multiDirection(final double[] initialValues) {
        return simplex(new MultiDirectionalSimplex(initialValues.length, 2, 0.5), initialValues);
    }

    /** Runs a {@link SimplexOptimizer} with the given simplex strategy. */
    private PointValuePair simplex(final AbstractSimplex simplex, final double[] initialValues) {
        return optimize(new SimplexOptimizer(conv()), initialValues, simplex);
    }

    /** BOBYQA search using {@code 2 * n} interpolation points. */
    private PointValuePair bobyqa(final double[] initialValues) {
        return optimize(new BOBYQAOptimizer(2 * initialValues.length), initialValues);
    }

    /** Bounded CMA-ES search with a fixed seed, the given population size and a uniform sigma. */
    private PointValuePair cmaes(final double[] initialValues, final int initialSize) {
        final CMAESOptimizer cmaes = new CMAESOptimizer(
                CMAES_MAX_ITERATIONS,
                0, // stopFitness
                true, // isActiveCMA
                1, // diagonalOnly
                0, // checkFeasableCount
                new MersenneTwister(CMAES_SEED),
                false, // generateStatistics
                conv());
        return optimize(cmaes, initialValues, new PopulationSize(initialSize),
                new Sigma(Vectors.fill(CMAES_SIGMA, initialValues.length)), bounds);
    }

    /** Powell's conjugate-direction search. */
    private PointValuePair powell(final double[] initialValues) {
        final PowellOptimizer powell = new PowellOptimizer(POWELL_TOLERANCE, POWELL_TOLERANCE, conv());
        return optimize(powell, initialValues);
    }

    /**
     * Returns a fresh {@link Convergence} with this instance's settings for each optimizer run.
     * NOTE(review): presumably a new instance is built per run in case the checker keeps per-run state;
     * confirm.
     */
    private Convergence conv() {
        return new Convergence(convergence.maxIt, convergence.trainMax, convergence.getAbsoluteThreshold());
    }

    /**
     * Maximizes this objective with {@code optimizer}, starting from {@code initialValues} clamped into
     * the bounds. The iteration and evaluation limits are derived from {@code maxIt}, and
     * {@code data} is appended to the common optimization data.
     */
    private PointValuePair optimize(final MultivariateOptimizer optimizer, final double[] initialValues,
            final OptimizationData... data) {
        final int common = 5;
        final OptimizationData[] d = new OptimizationData[common + data.length];
        // Saturate instead of silently overflowing for very large iteration budgets.
        d[0] = new MaxIter((int) Math.min(Integer.MAX_VALUE, convergence.maxIt + 1L));
        d[1] = new MaxEval((int) Math.min(Integer.MAX_VALUE, convergence.maxIt * 2L));
        d[2] = new ObjectiveFunction(this);
        d[3] = GoalType.MAXIMIZE;
        d[4] = new InitialGuess(fixBounds(initialValues));
        System.arraycopy(data, 0, d, common, data.length);
        return optimizer.optimize(d);
    }

    /**
     * Clamps each coordinate of {@code old} into the configured bounds. Returns {@code old} itself when
     * there are no bounds, otherwise a new array.
     */
    private double[] fixBounds(final double[] old) {
        if (bounds == null) {
            return old;
        }
        // SimpleBounds#getLower/getUpper return copies, so fetch them once instead of per coordinate.
        final double[] lower = bounds.getLower();
        final double[] upper = bounds.getUpper();
        final double[] fix = new double[old.length];
        for (int i = 0; i < old.length; ++i) {
            fix[i] = Math.max(lower[i], Math.min(upper[i], old[i]));
        }
        return fix;
    }

    /** Logs {@code results} under the tag {@code method} when logging is enabled. */
    private void log(final String method, final PointValuePair results) {
        if (log) {
            final SimpleFormatter format = new SimpleFormatter(5, true);
            logger.info("{}: {}, {}", method, format.format(results.getKey()), format.format(results.getValue()));
        }
    }
}