hivemall.common.EtaEstimator.java Source code

Java tutorial

Introduction

Here is the source code for hivemall.common.EtaEstimator.java

Source

/*
 * Hivemall: Hive scalable Machine Learning Library
 *
 * Copyright (C) 2015 Makoto YUI
 * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *         http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package hivemall.common;

import hivemall.utils.lang.NumberUtils;
import hivemall.utils.lang.Primitives;

import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;

import org.apache.commons.cli.CommandLine;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;

/**
 * Base class of learning rate (eta) schedules.
 *
 * <p>
 * An estimator is created with an initial learning rate {@code eta0} and reports the learning rate
 * to use at step {@code t} through {@link #eta(long)}. Implementations that adapt the rate from
 * feedback override {@link #update(float)}; for all others it is a no-op.
 */
public abstract class EtaEstimator {

    /** Initial learning rate given at construction time. */
    protected final float eta0;

    /**
     * @param eta0 initial learning rate
     */
    public EtaEstimator(float eta0) {
        this.eta0 = eta0;
    }

    /**
     * @return the initial learning rate given at construction time
     */
    public float eta0() {
        return eta0;
    }

    /**
     * Returns the learning rate to use at step {@code t}.
     *
     * @param t step counter
     * @return learning rate for the given step
     */
    public abstract float eta(long t);

    /**
     * Feedback hook for adaptive schedules. The default implementation does nothing.
     *
     * @param multipler factor by which an adaptive schedule scales its current learning rate
     */
    public void update(@Nonnegative float multipler) {
    }

    /**
     * Constant schedule: always returns the learning rate given at construction time.
     */
    public static final class FixedEtaEstimator extends EtaEstimator {

        public FixedEtaEstimator(float eta) {
            super(eta);
        }

        @Override
        public float eta(long t) {
            return eta0;
        }

    }

    /**
     * Decays the learning rate as {@code eta0 / (1 + t / total_steps)} until {@code t} reaches
     * {@code total_steps}, and returns {@code eta0 / 2} from then on.
     */
    public static final class SimpleEtaEstimator extends EtaEstimator {

        /** Learning rate returned once {@code t} has reached {@code total_steps}. */
        private final float finalEta;
        /** Kept as a double so that {@code t / total_steps} is a floating point division. */
        private final double total_steps;

        /**
         * @param eta0 initial learning rate
         * @param total_steps number of steps over which the rate decays from eta0 to eta0 / 2
         */
        public SimpleEtaEstimator(float eta0, long total_steps) {
            super(eta0);
            this.finalEta = (float) (eta0 / 2.d);
            this.total_steps = total_steps;
        }

        @Override
        public float eta(final long t) {
            // ">=" rather than ">": at t == total_steps the decay formula below already yields
            // eta0 / 2 (== finalEta), and returning early also keeps total_steps == 0 with
            // t == 0 from evaluating 0 / 0, which would produce NaN.
            if (t >= total_steps) {
                return finalEta;
            }
            return (float) (eta0 / (1.d + (t / total_steps)));
        }

    }

    /**
     * Inverse scaling schedule: {@code eta0 / t^power_t}.
     */
    public static final class InvscalingEtaEstimator extends EtaEstimator {

        private final double power_t;

        /**
         * @param eta0 initial learning rate
         * @param power_t exponent applied to the step counter
         */
        public InvscalingEtaEstimator(float eta0, double power_t) {
            super(eta0);
            this.power_t = power_t;
        }

        /**
         * @return {@code eta0 / t^power_t}, or {@code eta0} when that value is NaN or infinite
         */
        @Override
        public float eta(final long t) {
            final float eta = (float) (eta0 / Math.pow(t, power_t));
            if (!NumberUtils.isFinite(eta)) {
                // avoid NaN or INFINITY, e.g., t == 0 with a positive power_t divides by zero;
                // fall back to the undecayed learning rate
                return eta0;
            }
            return eta;
        }

    }

    /**
     * bold driver: Gemulla et al., Large-scale matrix factorization with distributed stochastic
     * gradient descent, KDD 2011.
     *
     * <p>
     * The current learning rate is independent of the step counter; it only changes through
     * {@link #update(float)} and never exceeds the initial learning rate.
     */
    public static final class AdjustingEtaEstimator extends EtaEstimator {

        /** Current learning rate; starts at eta0. */
        private float eta;

        public AdjustingEtaEstimator(float eta) {
            super(eta);
            this.eta = eta;
        }

        @Override
        public float eta(long t) {
            return eta;
        }

        /**
         * Multiplies the current learning rate by {@code multipler}. The update is ignored when
         * the product is not finite, and the result is capped at the initial learning rate.
         */
        @Override
        public void update(@Nonnegative float multipler) {
            float newEta = eta * multipler;
            if (!NumberUtils.isFinite(newEta)) {
                // avoid NaN or INFINITY
                return;
            }
            this.eta = Math.min(eta0, newEta); // never be larger than eta0
        }

    }

    /**
     * Same as {@link #get(CommandLine, float)} with a default initial learning rate of 0.1.
     */
    @Nonnull
    public static EtaEstimator get(@Nullable CommandLine cl) throws UDFArgumentException {
        return get(cl, 0.1f);
    }

    /**
     * Selects a learning rate schedule from command line options. The first matching rule wins:
     * <ol>
     * <li>{@code cl} is null: inverse scaling with {@code defaultEta0} and power_t = 0.1</li>
     * <li>option {@code boldDriver} is set: bold driver starting from {@code eta} (default 0.3)</li>
     * <li>option {@code eta} has a value: fixed learning rate</li>
     * <li>option {@code t} is set: simple decay over {@code t} steps starting from {@code eta0}</li>
     * <li>otherwise: inverse scaling with {@code eta0} and {@code power_t} (default 0.1)</li>
     * </ol>
     * {@code eta0} falls back to {@code defaultEta0} when the option is not given.
     *
     * @param cl parsed command line options, or null to use the defaults
     * @param defaultEta0 initial learning rate used when {@code eta0} is not specified
     * @return the selected estimator
     * @throws NumberFormatException if the value of {@code eta} (for a fixed rate) or {@code t}
     *         cannot be parsed as a number
     */
    @Nonnull
    public static EtaEstimator get(@Nullable CommandLine cl, float defaultEta0) throws UDFArgumentException {
        if (cl == null) {
            return new InvscalingEtaEstimator(defaultEta0, 0.1d);
        }

        if (cl.hasOption("boldDriver")) {
            float eta = Primitives.parseFloat(cl.getOptionValue("eta"), 0.3f);
            return new AdjustingEtaEstimator(eta);
        }

        String etaValue = cl.getOptionValue("eta");
        if (etaValue != null) {
            float eta = Float.parseFloat(etaValue);
            return new FixedEtaEstimator(eta);
        }

        float eta0 = Primitives.parseFloat(cl.getOptionValue("eta0"), defaultEta0);
        if (cl.hasOption("t")) {
            long t = Long.parseLong(cl.getOptionValue("t"));
            return new SimpleEtaEstimator(eta0, t);
        }

        double power_t = Primitives.parseDouble(cl.getOptionValue("power_t"), 0.1d);
        return new InvscalingEtaEstimator(eta0, power_t);
    }

}