com.anhth12.lambda.ml.param.HyperParams.java Source code

Java tutorial

Introduction

Here is the source code for com.anhth12.lambda.ml.param.HyperParams.java

Source

/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package com.anhth12.lambda.ml.param;

import com.anhth12.lambda.common.random.RandomManager;
import com.google.common.base.Preconditions;
import com.typesafe.config.Config;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import org.apache.commons.math3.random.RandomDataGenerator;
import scala.Predef;

/**
 * Static factory and utility methods for describing hyperparameter value
 * spaces ({@link HyperParamValues}) and for choosing combinations of
 * hyperparameter values to try.
 *
 * @author Tong Hoang Anh
 */
public final class HyperParams {

    /** Upper limit on {@code perParam ^ numParams} accepted when enumerating combinations. */
    private static final int MAX_COMBOS = 65536;

    /** Not instantiable; this class only exposes static methods. */
    private HyperParams() {
    }

    /**
     * @param fixValues the single value to use
     * @return a {@code ContinuousRange} constructed with {@code fixValues} as
     * both of its bounds
     */
    public static HyperParamValues<Double> fixed(double fixValues) {
        return new ContinuousRange(fixValues, fixValues);
    }

    /**
     * @param min first bound passed to {@code ContinuousRange}
     * @param max second bound passed to {@code ContinuousRange}
     * @return a {@code ContinuousRange} over the given bounds
     */
    public static HyperParamValues<Double> range(double min, double max) {
        return new ContinuousRange(min, max);
    }

    /**
     * @param value value passed to {@code ContinuousAround}
     * @param step step passed to {@code ContinuousAround}
     * @return a {@code ContinuousAround} built from the given value and step
     */
    public static HyperParamValues<Double> around(double value, double step) {
        return new ContinuousAround(value, step);
    }

    /**
     * @param fixedValue the single value to use
     * @return a {@code DiscreteRange} constructed with {@code fixedValue} as
     * both of its bounds
     */
    public static HyperParamValues<Integer> fixed(int fixedValue) {
        return new DiscreteRange(fixedValue, fixedValue);
    }

    /**
     * @param min first bound passed to {@code DiscreteRange}
     * @param max second bound passed to {@code DiscreteRange}
     * @return a {@code DiscreteRange} over the given bounds
     */
    public static HyperParamValues<Integer> range(int min, int max) {
        return new DiscreteRange(min, max);
    }

    /**
     * @param value value passed to {@code DiscreteAround}
     * @param step step passed to {@code DiscreteAround}
     * @return a {@code DiscreteAround} built from the given value and step
     */
    public static HyperParamValues<Integer> around(int value, int step) {
        return new DiscreteAround(value, step);
    }

    /**
     * @param values the candidate values
     * @param <T> type of the values
     * @return an {@code Unordered} wrapping the given values
     */
    public static <T> HyperParamValues<T> unorderedFromValues(Collection<T> values) {
        return new Unordered<>(values);
    }

    /**
     * Builds a {@link HyperParamValues} from the config value stored at
     * {@code key}.
     * <ul>
     * <li>A list whose first two elements both parse as ints (or, failing
     * that, as doubles) becomes a {@code range} over those two elements; any
     * further elements are not consulted in that case. Any other list,
     * including one with fewer than two elements, is handed as-is to
     * {@link #unorderedFromValues(Collection)}.</li>
     * <li>A string or number that parses as an int (or, failing that, as a
     * double) becomes a {@code fixed} value; otherwise it becomes a
     * single-element unordered set.</li>
     * </ul>
     *
     * @param config configuration to read from
     * @param key key of the value to interpret
     * @return hyperparameter values described by the config entry
     * @throws AssertionError if the config value is of any other type
     */
    public static HyperParamValues<?> fromConfig(Config config, String key) {
        switch (config.getValue(key).valueType()) {
        case LIST:
            List<String> stringValues = config.getStringList(key);
            // Only attempt the [min,max] interpretation when both elements exist;
            // shorter lists used to fail with IndexOutOfBoundsException here.
            if (stringValues.size() >= 2) {
                String first = stringValues.get(0);
                String second = stringValues.get(1);
                try {
                    return range(Integer.parseInt(first), Integer.parseInt(second));
                } catch (NumberFormatException ignored) {
                    // Not a pair of ints; try doubles next
                }
                try {
                    return range(Double.parseDouble(first), Double.parseDouble(second));
                } catch (NumberFormatException ignored) {
                    // Not a pair of numbers; treat the list as unordered values
                }
            }
            return unorderedFromValues(stringValues);
        case STRING:
        case NUMBER:
            String stringValue = config.getString(key);
            try {
                return fixed(Integer.parseInt(stringValue));
            } catch (NumberFormatException ignored) {
                // Not an int; try double next
            }
            try {
                return fixed(Double.parseDouble(stringValue));
            } catch (NumberFormatException ignored) {
                // Not a number; treat as a single unordered value
            }
            return unorderedFromValues(Collections.singletonList(stringValue));
        default:
            throw new AssertionError(config.getValue(key).valueType().name());
        }
    }

    /**
     * Enumerates every combination of trial values across the given
     * hyperparameters and returns up to {@code howMany} of them in random
     * order.
     *
     * @param ranges one value space per hyperparameter
     * @param howMany maximum number of combinations to return; must be positive
     * @param perParam number of trial values requested from each hyperparameter
     * via {@code getTrialValues}; must not be negative
     * @return a list of combinations, each holding one value per
     * hyperparameter in the iteration order of {@code ranges}. If
     * {@code ranges} is empty or {@code perParam} is 0, a single empty
     * combination is returned. If fewer than {@code howMany} combinations
     * exist, all of them are returned.
     * @throws IllegalArgumentException if {@code howMany} is not positive,
     * {@code perParam} is negative, or {@code perParam ^ ranges.size()}
     * exceeds the internal limit
     */
    public static List<List<?>> chooseHyperParameterCombos(Collection<HyperParamValues<?>> ranges, int howMany,
            int perParam) {
        Preconditions.checkArgument(howMany > 0, "howMany must be positive: %s", howMany);
        Preconditions.checkArgument(perParam >= 0, "perParam must not be negative: %s", perParam);

        int numParams = ranges.size();
        if (numParams == 0 || perParam == 0) {
            return Collections.<List<?>>singletonList(Collections.emptyList());
        }

        // Put some reasonable upper limit on the number of combos
        Preconditions.checkArgument(Math.pow(perParam, numParams) <= MAX_COMBOS,
                "Too many combinations: %s^%s exceeds %s", perParam, numParams, MAX_COMBOS);

        int howManyCombos = 1;
        List<List<?>> paramRanges = new ArrayList<>(numParams);
        for (HyperParamValues<?> range : ranges) {
            List<?> values = range.getTrialValues(perParam);
            paramRanges.add(values);
            howManyCombos *= values.size();
        }

        // Decode each combo number as a mixed-radix index: the first parameter
        // is the least-significant digit, with radix equal to its value count.
        // Every list is non-empty whenever this loop body runs, because an
        // empty list would make howManyCombos zero.
        List<List<?>> allCombinations = new ArrayList<>(howManyCombos);
        for (int combo = 0; combo < howManyCombos; combo++) {
            List<Object> combination = new ArrayList<>(numParams);
            int remaining = combo;
            for (List<?> values : paramRanges) {
                int size = values.size();
                combination.add(values.get(remaining % size));
                remaining /= size;
            }
            allCombinations.add(combination);
        }

        if (howMany >= howManyCombos) {
            Collections.shuffle(allCombinations);
            return allCombinations;
        }

        // Fewer combinations requested than exist: sample distinct positions at random.
        RandomDataGenerator rdg = new RandomDataGenerator(RandomManager.getRandom());
        int[] indices = rdg.nextPermutation(howManyCombos, howMany);
        List<List<?>> result = new ArrayList<>(indices.length);
        for (int index : indices) {
            // Use the sampled position itself; indexing by loop counter would
            // always pick the first howMany combinations.
            result.add(allCombinations.get(index));
        }
        Collections.shuffle(result);
        return result;
    }

    /**
     * @param numParams number of different hyperparameters
     * @param candidates minimum number of candidates to be built
     * @return smallest value such that pow(value, numParams) is at least the
     * number of candidates requested to build. Returns 0 if numParams is less
     * than 1.
     */
    public static int chooseValuesPerHyperParam(int numParams, int candidates) {
        if (numParams < 1) {
            return 0;
        }
        int valuesPerHyperParam = 0;
        long total;
        do {
            valuesPerHyperParam++;
            total = 1L;
            // Stop multiplying once the target is reached. Together with the
            // long accumulator this keeps the product from overflowing, which
            // previously could wrap around and yield a value that was too large.
            for (int i = 0; i < numParams && total < candidates; i++) {
                total *= valuesPerHyperParam;
            }
        } while (total < candidates);
        return valuesPerHyperParam;
    }

}