net.myrrix.online.eval.ParameterOptimizer.java Source code

Java tutorial

Introduction

Here is the source code for net.myrrix.online.eval.ParameterOptimizer.java

Source

/*
 * Copyright Myrrix Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package net.myrrix.online.eval;

import java.io.File;
import java.io.IOException;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.regex.Pattern;

import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import org.apache.commons.math3.util.Pair;
import org.apache.mahout.cf.taste.common.TasteException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * <p>This is an experimental utility class which can find a nearly-optimal set of learning algorithm parameters,
 * for a given set of parameters and range of values, and a given metric to optimize. It simply performs an  
 * exhaustive search.</p>
 * 
 * <p>It can be run on the command line with:</p>
 * 
 * <p>{@code java net.myrrix.online.eval.ParameterOptimizer dataDirectory numSteps evaluationPercentage
 *  property=min:max [property2=min2:max2 ...]}</p>
 *  
 * <ul>
 *   <li>{@code dataDirectory}: directory containing test data</li>
 *   <li>{@code numSteps}: number of different values of each parameter to try. Generally use 3-5.</li>
 *   <li>{@code evaluationPercentage}: fraction of all data to use in the test. Lower this to down-sample
 *    a very large data set. Must be in (0,1].</li>
 *   <li>{@code property=min:max}: repeated argument specifying a system property and the range of values 
 *    to try, inclusive</li>
 * </ul>
 * 
 * @author Sean Owen
 * @since 1.0
 */
public final class ParameterOptimizer implements Callable<Map<String, Number>> {

    private static final Logger log = LoggerFactory.getLogger(ParameterOptimizer.class);
    private static final Pattern EQUALS = Pattern.compile("=");
    private static final Pattern COLON = Pattern.compile(":");

    /** Number of steps used by the two-argument constructor. */
    private static final int DEFAULT_NUM_STEPS = 4;

    /**
     * Orders logged results from highest to lowest evaluation value. {@link Double#compare(double, double)}
     * is used rather than {@code <} / {@code >} so that the ordering stays a consistent total order even
     * if an evaluator reports {@code NaN}; an inconsistent comparator can make the sort fail.
     */
    private static final Comparator<Pair<Double, String>> BY_VALUE_DESCENDING =
            new Comparator<Pair<Double, String>>() {
                @Override
                public int compare(Pair<Double, String> a, Pair<Double, String> b) {
                    return Double.compare(b.getFirst(), a.getFirst());
                }
            };

    private final Map<String, ParameterRange> parameterRanges;
    private final int numSteps;
    private final Callable<? extends Number> evaluator;
    private final boolean minimize;

    /**
     * Creates an optimizer that tries {@value #DEFAULT_NUM_STEPS} values per parameter and looks for the
     * values that maximize the evaluator's result.
     *
     * @param parameterRanges mapping between names of {@link System} properties to optimize and the
     *  {@link ParameterRange} of values to try for each
     * @param evaluator the objective to maximize
     */
    public ParameterOptimizer(Map<String, ParameterRange> parameterRanges, Callable<? extends Number> evaluator) {
        this(parameterRanges, DEFAULT_NUM_STEPS, evaluator, false);
    }

    /**
     * @param parameterRanges mapping between names of {@link System} properties whose parameters will be optimized
     *  (e.g. {@code model.als.lambda}), and a {@link ParameterRange} describing a range of parameter values to try
     * @param numSteps number of different values of each parameter to try; must be at least 2. Note that with
     *  m parameters, and n steps, running time will scale proportionally to n<sup>m</sup>
     * @param evaluator the objective to maximize (or minimize). This typically wraps a call to something like
     *  {@link PrecisionRecallEvaluator}
     * @param minimize if {@code true}, find values that minimize {@code evaluator}'s value, otherwise maximize
     * @throws NullPointerException if {@code parameterRanges} or {@code evaluator} is null
     * @throws IllegalArgumentException if {@code parameterRanges} is empty or {@code numSteps} is less than 2
     */
    public ParameterOptimizer(Map<String, ParameterRange> parameterRanges, int numSteps,
            Callable<? extends Number> evaluator, boolean minimize) {
        Preconditions.checkNotNull(parameterRanges);
        Preconditions.checkArgument(!parameterRanges.isEmpty(), "parameterRanges is empty");
        Preconditions.checkArgument(numSteps >= 2, "numSteps must be at least 2: %s", numSteps);
        Preconditions.checkNotNull(evaluator);
        this.parameterRanges = parameterRanges;
        this.numSteps = numSteps;
        this.evaluator = evaluator;
        this.minimize = minimize;
    }

    /**
     * @return {@link #findGoodParameterValues()}
     */
    @Override
    public Map<String, Number> call() throws ExecutionException {
        return findGoodParameterValues();
    }

    /**
     * Tries every combination of the candidate parameter values. For each combination the corresponding
     * {@link System} properties are set, the evaluator is invoked, and the combination yielding the best
     * result so far is remembered. Combinations for which the evaluator returns {@code null} are skipped.
     * The system properties are left at the values of the last combination tried.
     *
     * @return a {@link Map} between the names of the given {@link System} properties and the best value found
     *  for each during search; empty if no combination produced a result better than the initial bound
     * @throws ExecutionException if an error occurs while calling {@code evaluator}; the cause is the
     *  underlying exception
     * @throws IllegalArgumentException if the number of combinations to test does not fit in an {@code int}
     */
    public Map<String, Number> findGoodParameterValues() throws ExecutionException {

        int numProperties = parameterRanges.size();
        String[] propertyNames = new String[numProperties];
        Number[][] parameterValuesToTry = new Number[numProperties][];
        int index = 0;
        for (Map.Entry<String, ParameterRange> entry : parameterRanges.entrySet()) {
            propertyNames[index] = entry.getKey();
            parameterValuesToTry[index] = entry.getValue().buildSteps(numSteps);
            index++;
        }

        int numTests = countTests(parameterValuesToTry);

        List<Pair<Double, String>> testResultLinesByValue = Lists.newArrayListWithCapacity(numTests);

        Map<String, Number> bestParameterValues = Maps.newHashMap();
        // Start from the worst possible value so that the first real result always becomes the best
        double bestValue = minimize ? Double.POSITIVE_INFINITY : Double.NEGATIVE_INFINITY;

        for (int test = 0; test < numTests; test++) {

            // Resolve this combination once; it is reused below if it turns out to be the new best
            Number[] currentValues = new Number[numProperties];
            StringBuilder testResultLine = new StringBuilder();
            for (int prop = 0; prop < numProperties; prop++) {
                String property = propertyNames[prop];
                Number parameterValue = getParameterValueToTry(parameterValuesToTry, test, prop);
                currentValues[prop] = parameterValue;
                String propertyString = parameterValue.toString();
                log.info("Setting {}={}", property, propertyString);
                // The evaluator reads its parameters from system properties
                System.setProperty(property, propertyString);
                testResultLine.append('[').append(property).append('=').append(propertyString).append("] ");
            }

            Number evaluatorResult;
            try {
                evaluatorResult = evaluator.call();
            } catch (Exception e) {
                if (e instanceof InterruptedException) {
                    // Catching the exception cleared the interrupt flag; restore it for callers
                    Thread.currentThread().interrupt();
                }
                throw new ExecutionException(e);
            }
            if (evaluatorResult == null) {
                // No result for this combination; nothing to rank or compare
                continue;
            }
            double testValue = evaluatorResult.doubleValue();
            testResultLine.append("= ").append(testValue);
            testResultLinesByValue.add(new Pair<Double, String>(testValue, testResultLine.toString()));
            log.info("{}", testResultLine);

            // A NaN result never compares as better, so it can never become the best value
            if (minimize ? testValue < bestValue : testValue > bestValue) {
                log.info("New best value {}", testValue);
                bestValue = testValue;
                for (int prop = 0; prop < numProperties; prop++) {
                    bestParameterValues.put(propertyNames[prop], currentValues[prop]);
                }
            }

            // Re-log the complete ranking after every test so progress is visible during long runs
            Collections.sort(testResultLinesByValue, BY_VALUE_DESCENDING);

            for (Pair<Double, String> result : testResultLinesByValue) {
                log.info("{}", result.getSecond());
            }
            log.info("Best parameter values so far are {}", bestParameterValues);
        }

        log.info("Final best parameter values are {}", bestParameterValues);
        return bestParameterValues;
    }

    /**
     * Computes the number of parameter combinations, which is the product of the number of candidate values
     * of each parameter.
     *
     * @param parameterValuesToTry candidate values, one array per parameter
     * @return total number of combinations
     * @throws IllegalArgumentException if the product does not fit in an {@code int}
     */
    private static int countTests(Number[][] parameterValuesToTry) {
        // Accumulate in a long: the running product is checked against Integer.MAX_VALUE after every
        // multiplication, so the next product of it and an int length cannot overflow a long
        long numTests = 1L;
        for (Number[] toTry : parameterValuesToTry) {
            numTests *= toTry.length;
            Preconditions.checkArgument(numTests <= Integer.MAX_VALUE,
                    "Too many parameter combinations to test; reduce the number of parameters or steps");
        }
        return (int) numTests;
    }

    /**
     * Picks the value of one parameter for a given test number. The test number is treated as a mixed-radix
     * number whose digit for parameter {@code i} has radix {@code parameterValuesToTry[i].length}, with
     * parameter 0 as the least significant digit; the digit for {@code prop} indexes its candidate values.
     *
     * @param parameterValuesToTry candidate values, one array per parameter
     * @param test index of the combination being tested
     * @param prop index of the parameter whose value is wanted
     * @return the candidate value of parameter {@code prop} used in combination {@code test}
     */
    private static Number getParameterValueToTry(Number[][] parameterValuesToTry, int test, int prop) {
        int whichValueToTry = test;
        // Shift out the digits belonging to the parameters before this one
        for (int i = 0; i < prop; i++) {
            whichValueToTry /= parameterValuesToTry[i].length;
        }
        whichValueToTry %= parameterValuesToTry[prop].length;
        return parameterValuesToTry[prop][whichValueToTry];
    }

    /**
     * Parses one {@code min:max} (or single-value) range. Values that parse as integers produce an integer
     * range; otherwise both bounds are parsed as doubles.
     *
     * @param minMax one or two strings: the minimum, and optionally the maximum (defaults to the minimum)
     * @return the corresponding range
     * @throws NumberFormatException if a bound is not a number
     */
    private static ParameterRange parseRange(String[] minMax) {
        String minString = minMax[0];
        String maxString = minMax.length == 1 ? minMax[0] : minMax[1];
        try {
            return new ParameterRange(Integer.parseInt(minString), Integer.parseInt(maxString));
        } catch (NumberFormatException ignored) {
            // Not both integers; fall back to a floating-point range
            return new ParameterRange(Double.parseDouble(minString), Double.parseDouble(maxString));
        }
    }

    /**
     * Command-line entry point. Searches the given ranges of system property values for those that maximize
     * the mean average precision reported by {@link PrecisionRecallEvaluator} on the data directory, and
     * prints the best values found.
     *
     * @param args {@code dataDirectory numSteps evaluationPercentage property=min:max [property2=min2:max2 ...]}
     * @throws IllegalArgumentException if an argument is malformed or out of range
     * @throws Exception if the evaluation fails
     */
    public static void main(String[] args) throws Exception {
        if (args.length < 4) {
            System.err.println(
                    "Usage: dataDirectory numSteps evaluationPercentage property=min:max [property2=min2:max2 ...]");
            return;
        }

        final File dataDir = new File(args[0]);
        Preconditions.checkArgument(dataDir.exists() && dataDir.isDirectory(), "Not a directory: %s", dataDir);
        // listFiles() returns null if the directory cannot be read
        File[] dataFiles = dataDir.listFiles();
        Preconditions.checkArgument(dataFiles != null && dataFiles.length > 0, "No files in: %s", dataDir);
        int numSteps = Integer.parseInt(args[1]);
        Preconditions.checkArgument(numSteps >= 2, "# steps must be at least 2: %s", numSteps);
        final double evaluationPercentage = Double.parseDouble(args[2]);
        Preconditions.checkArgument(evaluationPercentage > 0.0 && evaluationPercentage <= 1.0,
                "evaluationPercentage must be in (0,1]: %s", evaluationPercentage);

        Map<String, ParameterRange> parameterRanges = Maps.newHashMapWithExpectedSize(args.length);
        for (int i = 3; i < args.length; i++) {
            String[] propValue = EQUALS.split(args[i]);
            Preconditions.checkArgument(propValue.length == 2 && !propValue[0].isEmpty(),
                    "Expected property=min:max but got: %s", args[i]);
            String systemProperty = propValue[0];
            String[] minMax = COLON.split(propValue[1]);
            Preconditions.checkArgument(minMax.length == 1 || minMax.length == 2,
                    "Expected property=min:max but got: %s", args[i]);
            parameterRanges.put(systemProperty, parseRange(minMax));
        }

        Callable<Number> evaluator = new Callable<Number>() {
            @Override
            public Number call() throws IOException, TasteException, InterruptedException {
                MyrrixIRStatistics stats = (MyrrixIRStatistics) new PrecisionRecallEvaluator().evaluate(dataDir,
                        0.9, evaluationPercentage, null);
                return stats == null ? null : stats.getMeanAveragePrecision();
            }
        };

        ParameterOptimizer optimizer = new ParameterOptimizer(parameterRanges, numSteps, evaluator, false);
        Map<String, Number> optimalValues = optimizer.findGoodParameterValues();
        System.out.println(optimalValues);
    }

}