Java tutorial
/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package com.anhth12.lambda.ml.param;

import com.anhth12.lambda.common.random.RandomManager;
import com.google.common.base.Preconditions;
import com.typesafe.config.Config;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import org.apache.commons.math3.random.RandomDataGenerator;
import scala.Predef;

/**
 * Static factory and helper methods for building {@link HyperParamValues}
 * and for enumerating combinations of hyperparameter values to try.
 *
 * @author Tong Hoang Anh
 */
public class HyperParams {

    /**
     * Upper limit on {@code perParam ^ numParams} accepted by
     * {@link #chooseHyperParameterCombos(Collection, int, int)}.
     */
    private static final int MAX_COMBOS = 65536;

    private HyperParams() {
    }

    /**
     * @param fixValues the single value
     * @return a {@code ContinuousRange} whose minimum and maximum are both
     * {@code fixValues}
     */
    public static HyperParamValues<Double> fixed(double fixValues) {
        return new ContinuousRange(fixValues, fixValues);
    }

    /**
     * @param min first bound passed to {@code ContinuousRange}
     * @param max second bound passed to {@code ContinuousRange}
     * @return a {@code ContinuousRange} built from {@code min} and {@code max}
     */
    public static HyperParamValues<Double> range(double min, double max) {
        return new ContinuousRange(min, max);
    }

    /**
     * @param value first argument passed to {@code ContinuousAround}
     * @param step second argument passed to {@code ContinuousAround}
     * @return a {@code ContinuousAround} built from {@code value} and {@code step}
     */
    public static HyperParamValues<Double> around(double value, double step) {
        return new ContinuousAround(value, step);
    }

    /**
     * @param fixedValue the single value
     * @return a {@code DiscreteRange} whose minimum and maximum are both
     * {@code fixedValue}
     */
    public static HyperParamValues<Integer> fixed(int fixedValue) {
        return new DiscreteRange(fixedValue, fixedValue);
    }

    /**
     * @param min first bound passed to {@code DiscreteRange}
     * @param max second bound passed to {@code DiscreteRange}
     * @return a {@code DiscreteRange} built from {@code min} and {@code max}
     */
    public static HyperParamValues<Integer> range(int min, int max) {
        return new DiscreteRange(min, max);
    }

    /**
     * @param value first argument passed to {@code DiscreteAround}
     * @param step second argument passed to {@code DiscreteAround}
     * @return a {@code DiscreteAround} built from {@code value} and {@code step}
     */
    public static HyperParamValues<Integer> around(int value, int step) {
        return new DiscreteAround(value, step);
    }

    /**
     * @param values values to wrap
     * @param <T> type of the values
     * @return an {@code Unordered} built from {@code values}
     */
    public static <T> HyperParamValues<T> unorderedFromValues(Collection<T> values) {
        return new Unordered<>(values);
    }

    /**
     * Builds hyperparameter values from the config entry at {@code key}.
     * <ul>
     * <li>A list with at least two elements whose first two elements parse as
     * ints (or, failing that, as doubles) becomes a range over those two
     * values; elements after the second are not consulted in that case.</li>
     * <li>Any other list becomes an unordered set of its string values.</li>
     * <li>A string or number that parses as an int (or, failing that, as a
     * double) becomes a fixed value; otherwise it becomes an unordered set
     * containing that one string.</li>
     * </ul>
     *
     * @param config configuration to read from
     * @param key path of the value within {@code config}
     * @return hyperparameter values described by the config entry
     * @throws AssertionError if the config value is not a list, string or number
     */
    public static HyperParamValues<?> fromConfig(Config config, String key) {
        switch (config.getValue(key).valueType()) {
            case LIST:
                List<String> stringValues = config.getStringList(key);
                // A range needs two endpoints; shorter lists are handled as
                // unordered values instead of failing on get(1).
                if (stringValues.size() >= 2) {
                    try {
                        return range(Integer.parseInt(stringValues.get(0)),
                                Integer.parseInt(stringValues.get(1)));
                    } catch (NumberFormatException ignored) {
                        // not an int range; try doubles next
                    }
                    try {
                        return range(Double.parseDouble(stringValues.get(0)),
                                Double.parseDouble(stringValues.get(1)));
                    } catch (NumberFormatException ignored) {
                        // not a numeric range; fall back to unordered values
                    }
                }
                return unorderedFromValues(stringValues);
            case STRING:
            case NUMBER:
                String stringValue = config.getString(key);
                try {
                    return fixed(Integer.parseInt(stringValue));
                } catch (NumberFormatException ignored) {
                    // not an int; try double next
                }
                try {
                    return fixed(Double.parseDouble(stringValue));
                } catch (NumberFormatException ignored) {
                    // not numeric; fall back to a single unordered value
                }
                return unorderedFromValues(Collections.singletonList(stringValue));
            default:
                throw new AssertionError(config.getValue(key).valueType().name());
        }
    }

    /**
     * Enumerates every combination of trial values across the given
     * hyperparameters and returns up to {@code howMany} of them in random order.
     *
     * @param ranges one entry per hyperparameter
     * @param howMany maximum number of combinations to return; must be positive
     * @param perParam number of trial values requested from each hyperparameter
     * via {@code getTrialValues}; must be non-negative
     * @return if there are no hyperparameters or {@code perParam} is 0, a
     * singleton list holding one empty combination. Otherwise, all combinations
     * (shuffled) when there are at most {@code howMany} of them, or else
     * {@code howMany} distinct combinations chosen at random. Each combination
     * lists one value per hyperparameter, in the iteration order of
     * {@code ranges}.
     * @throws IllegalArgumentException if an argument is out of range or
     * {@code perParam ^ ranges.size()} exceeds the internal limit
     */
    public static List<List<?>> chooseHyperParameterCombos(Collection<HyperParamValues<?>> ranges,
            int howMany,
            int perParam) {
        Preconditions.checkArgument(howMany > 0);
        Preconditions.checkArgument(perParam >= 0);

        int numParams = ranges.size();
        if (numParams == 0 || perParam == 0) {
            return Collections.<List<?>>singletonList(Collections.emptyList());
        }

        // Put some reasonable upper limit on the number of combos
        Preconditions.checkArgument(Math.pow(perParam, numParams) <= MAX_COMBOS);

        int howManyCombos = 1;
        List<List<?>> paramRanges = new ArrayList<>(numParams);
        for (HyperParamValues<?> range : ranges) {
            List<?> values = range.getTrialValues(perParam);
            paramRanges.add(values);
            howManyCombos *= values.size();
        }

        // Treat the combo index as a mixed-radix number: the digit for each
        // parameter selects which of its trial values to use.
        List<List<?>> allCombinations = new ArrayList<>(howManyCombos);
        for (int combo = 0; combo < howManyCombos; combo++) {
            List<Object> combination = new ArrayList<>(numParams);
            int remaining = combo;
            for (List<?> values : paramRanges) {
                combination.add(values.get(remaining % values.size()));
                remaining /= values.size();
            }
            allCombinations.add(combination);
        }

        if (howMany >= howManyCombos) {
            Collections.shuffle(allCombinations);
            return allCombinations;
        }

        RandomDataGenerator rdg = new RandomDataGenerator(RandomManager.getRandom());
        int[] indices = rdg.nextPermutation(howManyCombos, howMany);
        List<List<?>> result = new ArrayList<>(indices.length);
        for (int index : indices) {
            // Look up the randomly selected index itself, not the loop position,
            // so the sample is not just the first howMany combinations.
            result.add(allCombinations.get(index));
        }
        Collections.shuffle(result);
        return result;
    }

    /**
     * @param numParams number of different hyperparameters
     * @param candidates minimum number of candidates to be built
     * @return smallest value such that pow(value, numParams) is at least the
     * number of candidates requested to build. Returns 0 if numParams is less
     * than 1.
     */
    public static int chooseValuesPerHyperParam(int numParams, int candidates) {
        if (numParams < 1) {
            return 0;
        }
        int valuesPerHyperParam = 0;
        long total;
        do {
            valuesPerHyperParam++;
            total = 1;
            // Stop multiplying once the target is reached: the product is then
            // always below candidates * valuesPerHyperParam, which fits in a
            // long, so it cannot overflow no matter how large numParams is.
            for (int i = 0; i < numParams && total < candidates; i++) {
                total *= valuesPerHyperParam;
            }
        } while (total < candidates);
        return valuesPerHyperParam;
    }
}