hivemall.fm.FactorizationMachineUDTF.java Source code

Java tutorial

Introduction

Here is the source code for hivemall.fm.FactorizationMachineUDTF.java

Source

/*
 * Hivemall: Hive scalable Machine Learning Library
 *
 * Copyright (C) 2015 Makoto YUI
 * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *         http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package hivemall.fm;

import hivemall.UDTFWithOptions;
import hivemall.common.ConversionState;
import hivemall.common.EtaEstimator;
import hivemall.common.LossFunctions;
import hivemall.common.LossFunctions.LossFunction;
import hivemall.common.LossFunctions.LossType;
import hivemall.fm.FMStringFeatureMapModel.Entry;
import hivemall.utils.collections.IMapIterator;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.io.FileUtils;
import hivemall.utils.io.NioStatefullSegment;
import hivemall.utils.lang.NumberUtils;
import hivemall.utils.math.MathUtils;

import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Random;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;

import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.io.FloatWritable;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapred.Counters.Counter;
import org.apache.hadoop.mapred.Reporter;

@Description(name = "train_fm", value = "_FUNC_(array<string> x, double y [, const string options]) - Returns a prediction model")
public class FactorizationMachineUDTF extends UDTFWithOptions {
    private static final Log LOG = LogFactory.getLog(FactorizationMachineUDTF.class);
    private static final int INT_BYTES = Integer.SIZE / 8;

    protected ListObjectInspector _xOI;
    protected PrimitiveObjectInspector _yOI;

    /**
     * Probe for the input X
     */
    @Nullable
    protected Feature[] _probes;

    // ----------------------------------------
    // Learning hyper-parameters/options

    protected FMHyperParameters _params;

    protected boolean _classification;
    protected int _iterations;
    protected int _factors;
    protected boolean _parseFeatureAsInt;

    // adaptive regularization
    @Nullable
    protected Random _va_rand;
    protected float _validationRatio;
    protected int _validationThreshold;

    protected LossFunction _lossFunction;
    protected EtaEstimator _etaEstimator;
    protected ConversionState _cvState;

    // ----------------------------------------

    protected transient FactorizationMachineModel _model;

    /**
     * The number of training examples processed
     */
    protected long _t;

    // file IO
    private ByteBuffer _inputBuf;
    private NioStatefullSegment _fileIO;

    @Override
    protected Options getOptions() {
        Options opts = new Options();
        opts.addOption("c", "classification", false, "Act as classification");
        opts.addOption("seed", true, "Seed value [default: -1 (random)]");
        opts.addOption("iters", "iterations", true, "The number of iterations [default: 1]");
        opts.addOption("p", "num_features", true, "The size of feature dimensions");
        opts.addOption("factor", "factors", true, "The number of the latent variables [default: 5]");
        opts.addOption("sigma", true, "The standard deviation for initializing V [default: 0.1]");
        opts.addOption("lambda0", "lambda", true, "The initial lambda value for regularization [default: 0.01]");
        opts.addOption("lambdaW0", "lambda_w0", true,
                "The initial lambda value for W0 regularization [default: 0.01]");
        opts.addOption("lambdaWi", "lambda_wi", true,
                "The initial lambda value for Wi regularization [default: 0.01]");
        opts.addOption("lambdaV", "lambda_v", true,
                "The initial lambda value for V regularization [default: 0.01]");
        // regression
        opts.addOption("min", "min_target", true, "The minimum value of target variable");
        opts.addOption("max", "max_target", true, "The maximum value of target variable");
        // learning rates
        opts.addOption("eta", true, "The initial learning rate");
        opts.addOption("eta0", true, "The initial learning rate [default 0.05]");
        opts.addOption("t", "total_steps", true, "The total number of training examples");
        opts.addOption("power_t", true, "The exponent for inverse scaling learning rate [default 0.1]");
        // conversion check
        opts.addOption("disable_cv", "disable_cvtest", false,
                "Whether to disable convergence check [default: OFF]");
        opts.addOption("cv_rate", "convergence_rate", true, "Threshold to determine convergence [default: 0.005]");
        // adaptive regularization
        if (isAdaptiveRegularizationSupported()) {
            opts.addOption("adareg", "adaptive_regularizaion", false,
                    "Whether to enable adaptive regularization [default: OFF]");
            opts.addOption("va_ratio", "validation_ratio", true,
                    "Ratio of training data used for validation [default: 0.05f]");
            opts.addOption("va_threshold", "validation_threshold", true, "Threshold to start validation. "
                    + "At least N training examples are used before validation [default: 1000]");
        }
        // initialization of V
        opts.addOption("init_v", true, "Initialization strategy of matrix V [random, gaussian]"
                + "(default: 'random' for regression / 'gaussian' for classification)");
        opts.addOption("maxval", "max_init_value", true,
                "The maximum initial value in the matrix V [default: 0.5]");
        opts.addOption("min_init_stddev", true,
                "The minimum standard deviation of initial matrix V [default: 0.1]");
        // feature representation
        opts.addOption("int_feature", "feature_as_integer", false, "Parse a feature as integer [default: OFF]");
        return opts;
    }

    protected boolean isAdaptiveRegularizationSupported() {
        return true;
    }

    @Override
    protected CommandLine processOptions(@Nonnull ObjectInspector[] argOIs) throws UDFArgumentException {
        final FMHyperParameters params = _params;

        CommandLine cl = null;
        if (argOIs.length >= 3) {
            String rawArgs = HiveUtils.getConstString(argOIs[2]);
            cl = parseOptions(rawArgs);
            params.processOptions(cl);
        }

        this._classification = params.classification;
        this._iterations = params.iters;
        this._factors = params.factors;
        this._parseFeatureAsInt = params.parseFeatureAsInt;
        if (params.adaptiveReglarization) {
            this._va_rand = new Random(params.seed + 31L);
        }
        this._validationRatio = params.validationRatio;
        this._validationThreshold = params.validationThreshold;
        this._lossFunction = params.classification ? LossFunctions.getLossFunction(LossType.LogLoss)
                : LossFunctions.getLossFunction(LossType.SquaredLoss);
        this._etaEstimator = params.eta;
        this._cvState = new ConversionState(params.conversionCheck, params.convergenceRate);
        return cl;
    }

    @Override
    public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
        if (argOIs.length != 2 && argOIs.length != 3) {
            throw new UDFArgumentException(getClass().getSimpleName()
                    + " takes 2 or 3 arguments: array<string> x, double y [, CONSTANT STRING options]: "
                    + Arrays.toString(argOIs));
        }
        this._xOI = HiveUtils.asListOI(argOIs[0]);
        HiveUtils.validateFeatureOI(_xOI.getListElementObjectInspector());
        this._yOI = HiveUtils.asDoubleCompatibleOI(argOIs[1]);

        this._params = newHyperParameters();
        processOptions(argOIs);

        this._model = null;
        this._t = 0L;

        if (LOG.isInfoEnabled()) {
            LOG.info(_params);
        }

        return getOutputOI(_params);
    }

    @Nonnull
    protected FMHyperParameters newHyperParameters() {
        return new FMHyperParameters();
    }

    @Nonnull
    protected StructObjectInspector getOutputOI(@Nonnull FMHyperParameters params) {
        ArrayList<String> fieldNames = new ArrayList<String>();
        ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
        fieldNames.add("feature");
        if (params.parseFeatureAsInt) {
            fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
        } else {
            fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
        }
        fieldNames.add("W_i");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
        fieldNames.add("V_if");
        fieldOIs.add(ObjectInspectorFactory
                .getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableFloatObjectInspector));

        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    }

    @Nonnull
    protected FactorizationMachineModel initModel(@Nonnull FMHyperParameters params) throws UDFArgumentException {
        final FactorizationMachineModel model;
        if (params.parseFeatureAsInt) {
            if (params.numFeatures == -1) {
                model = new FMIntFeatureMapModel(params);
            } else {
                model = new FMArrayModel(params);
            }
        } else {
            model = new FMStringFeatureMapModel(params);
        }
        this._model = model;
        return model;
    }

    @Override
    public void process(Object[] args) throws HiveException {
        if (_model == null) {
            this._model = initModel(_params);
        }

        Feature[] x = parseFeatures(args[0]);
        if (x == null) {
            return;
        }
        this._probes = x;

        double y = PrimitiveObjectInspectorUtils.getDouble(args[1], _yOI);
        if (_classification) {
            y = (y > 0.d) ? 1.d : -1.d;
        }

        ++_t;
        recordTrain(x, y);
        boolean adaptiveRegularization = (_va_rand != null) && _t >= _validationThreshold;
        train(x, y, adaptiveRegularization);
    }

    @Nullable
    protected Feature[] parseFeatures(@Nonnull final Object arg) throws HiveException {
        return Feature.parseFeatures(arg, _xOI, _probes, _parseFeatureAsInt);
    }

    protected void recordTrain(@Nonnull final Feature[] x, final double y) throws HiveException {
        if (_iterations <= 1) {
            return;
        }

        ByteBuffer inputBuf = _inputBuf;
        NioStatefullSegment dst = _fileIO;
        if (inputBuf == null) {
            final File file;
            try {
                file = File.createTempFile("hivemall_fm", ".sgmt");
                file.deleteOnExit();
                if (!file.canWrite()) {
                    throw new UDFArgumentException("Cannot write a temporary file: " + file.getAbsolutePath());
                }
                LOG.info("Record training examples to a file: " + file.getAbsolutePath());
            } catch (IOException ioe) {
                throw new UDFArgumentException(ioe);
            } catch (Throwable e) {
                throw new UDFArgumentException(e);
            }

            this._inputBuf = inputBuf = ByteBuffer.allocateDirect(1024 * 1024); // 1 MiB
            this._fileIO = dst = new NioStatefullSegment(file, false);
        }

        int xBytes = Feature.requiredBytes(x);
        int recordBytes = (Integer.SIZE + Double.SIZE) / 8 + xBytes;
        int requiredBytes = (Integer.SIZE / 8) + recordBytes;
        int remain = inputBuf.remaining();
        if (remain < requiredBytes) {
            writeBuffer(inputBuf, dst);
        }

        inputBuf.putInt(recordBytes);
        inputBuf.putInt(x.length);
        for (Feature f : x) {
            f.writeTo(inputBuf);
        }
        inputBuf.putDouble(y);
    }

    private static void writeBuffer(@Nonnull ByteBuffer srcBuf, @Nonnull NioStatefullSegment dst)
            throws HiveException {
        srcBuf.flip();
        try {
            dst.write(srcBuf);
        } catch (IOException e) {
            throw new HiveException("Exception causes while writing a buffer to file", e);
        }
        srcBuf.clear();
    }

    public void train(@Nonnull final Feature[] x, final double y, final boolean adaptiveRegularization)
            throws HiveException {
        _model.check(x);

        try {
            if (adaptiveRegularization) {
                assert (_va_rand != null);
                final float rnd = _va_rand.nextFloat();
                if (rnd < _validationRatio) {
                    trainLambda(x, y); // adaptive regularization
                } else {
                    trainTheta(x, y);
                }
            } else {
                trainTheta(x, y);
            }
        } catch (Exception ex) {
            throw new HiveException("Exception caused in the " + _t + "-th call of train()", ex);
        }
    }

    /**
     * Update model parameters
     */
    protected void trainTheta(final Feature[] x, final double y) throws HiveException {
        final float eta = _etaEstimator.eta(_t);

        final double p = _model.predict(x);
        final double lossGrad = _model.dloss(p, y);
        double loss = _lossFunction.loss(p, y);
        _cvState.incrLoss(loss);

        if (MathUtils.closeToZero(lossGrad)) {
            return;
        }

        // w0 update
        _model.updateW0(lossGrad, eta);

        final double[] sumVfx = _model.sumVfX(x);
        for (Feature xi : x) {
            // wi update
            _model.updateWi(lossGrad, xi, eta);
            for (int f = 0, k = _factors; f < k; f++) {
                // Vif update
                _model.updateV(lossGrad, xi, f, sumVfx[f], eta);
            }
        }
    }

    /**
     * Update regularization parameters `lambda` as follows:
     * 
     * <pre>
     *      grad_lambdaw0 = (grad l(p,y)) * (-2 * alpha * w_0)
     *      grad_lambdawg = (grad l(p,y)) * (-2 * alpha * (\sum_{l \in group(g)} x_l * w_l))
     *      grad_lambdafg = (grad l(p,y)) * (-2 * alpha * (\sum_{l} x_l * v'_lf) * \sum_{l \in group(g)} x_l * v_lf) - \sum_{l \in group(g)} x^2_l * v_lf * v'_lf)
     * </pre>
     */
    protected void trainLambda(final Feature[] x, final double y) throws HiveException {
        final float eta = _etaEstimator.eta(_t);
        final double p = _model.predict(x);
        final double lossGrad = _model.dloss(p, y);

        _model.updateLambdaW0(lossGrad, eta);
        _model.updateLambdaW(x, lossGrad, eta);
        _model.updateLambdaV(x, lossGrad, eta);
    }

    @Override
    public void close() throws HiveException {
        this._probes = null;

        if (_t == 0) {
            this._model = null;
            return;
        }
        if (_iterations > 1) {
            runTrainingIteration(_iterations);
        }

        final int P = _model.getSize();
        if (P <= 0) {
            LOG.warn("Model size P was less than zero: " + P);
            this._model = null;
            return;
        }

        forwardModel();
        this._model = null;
    }

    protected void forwardModel() throws HiveException {
        if (_parseFeatureAsInt) {
            forwardAsIntFeature(_model, _factors);
        } else {
            FMStringFeatureMapModel strModel = (FMStringFeatureMapModel) _model;
            forwardAsStringFeature(strModel, _factors);
        }
    }

    private void forwardAsIntFeature(@Nonnull final FactorizationMachineModel model, final int factors)
            throws HiveException {
        final IntWritable f_idx = new IntWritable(0);
        final FloatWritable f_Wi = new FloatWritable(0.f);
        final FloatWritable[] f_Vi = HiveUtils.newFloatArray(factors, 0.f);

        final Object[] forwardObjs = new Object[3];
        forwardObjs[0] = f_idx;
        forwardObjs[1] = f_Wi;
        forwardObjs[2] = null;
        // W0
        f_idx.set(0);
        f_Wi.set(model.getW0());
        // V0 is null
        forward(forwardObjs);

        // Wi, Vif (i starts from 1..P)
        forwardObjs[2] = Arrays.asList(f_Vi);

        for (int i = model.getMinIndex(), maxIdx = model.getMaxIndex(); i <= maxIdx; i++) {
            final float[] vi = model.getV(i, false);
            if (vi == null) {
                continue;
            }
            f_idx.set(i);
            // set Wi
            final float w = model.getW(i);
            f_Wi.set(w);
            // set Vif
            for (int f = 0; f < factors; f++) {
                float v = vi[f];
                f_Vi[f].set(v);
            }
            forward(forwardObjs);
        }
    }

    private void forwardAsStringFeature(@Nonnull final FMStringFeatureMapModel model, final int factors)
            throws HiveException {
        final Text feature = new Text();
        final FloatWritable f_Wi = new FloatWritable(0.f);
        final FloatWritable[] f_Vi = HiveUtils.newFloatArray(factors, 0.f);

        final Object[] forwardObjs = new Object[3];
        forwardObjs[0] = feature;
        forwardObjs[1] = f_Wi;
        forwardObjs[2] = null;
        // W0
        feature.set("0");
        f_Wi.set(model.getW0());
        // V0 is null
        forward(forwardObjs);

        // Wi, Vif (i starts from 1..P)
        forwardObjs[2] = Arrays.asList(f_Vi);

        final IMapIterator<String, Entry> itor = model.entries();
        while (itor.next() != -1) {
            String i = itor.getKey();
            assert (i != null);
            // set i
            feature.set(i);
            Entry entry = itor.getValue();
            // set Wi
            f_Wi.set(entry.W);
            // set Vif
            final float[] Vi = entry.Vf;
            for (int f = 0; f < factors; f++) {
                float v = Vi[f];
                f_Vi[f].set(v);
            }
            forward(forwardObjs);
        }
    }

    protected void runTrainingIteration(int iterations) throws HiveException {
        final ByteBuffer inputBuf = this._inputBuf;
        final NioStatefullSegment fileIO = this._fileIO;
        assert (inputBuf != null);
        assert (fileIO != null);
        final long numTrainingExamples = _t;
        final boolean adaregr = _va_rand != null;

        final Reporter reporter = getReporter();
        final Counter iterCounter = (reporter == null) ? null
                : reporter.getCounter("hivemall.fm.FactorizationMachines$Counter", "iteration");

        try {
            if (fileIO.getPosition() == 0L) {// run iterations w/o temporary file
                if (inputBuf.position() == 0) {
                    return; // no training example
                }
                inputBuf.flip();

                int iter = 2;
                for (; iter <= iterations; iter++) {
                    reportProgress(reporter);
                    setCounterValue(iterCounter, iter);

                    while (inputBuf.remaining() > 0) {
                        int bytes = inputBuf.getInt();
                        assert (bytes > 0) : bytes;
                        int xLength = inputBuf.getInt();
                        final Feature[] x = new Feature[xLength];
                        for (int j = 0; j < xLength; j++) {
                            x[j] = instantiateFeature(inputBuf);
                        }
                        double y = inputBuf.getDouble();
                        // invoke train
                        ++_t;
                        train(x, y, adaregr);
                    }
                    if (_cvState.isConverged(iter, numTrainingExamples)) {
                        break;
                    }
                    inputBuf.rewind();
                }
                LOG.info("Performed " + Math.min(iter, iterations) + " iterations of "
                        + NumberUtils.formatNumber(numTrainingExamples) + " training examples on memory (thus "
                        + NumberUtils.formatNumber(_t) + " training updates in total) ");
            } else {// read training examples in the temporary file and invoke train for each example

                // write training examples in buffer to a temporary file
                if (inputBuf.remaining() > 0) {
                    writeBuffer(inputBuf, fileIO);
                }
                try {
                    fileIO.flush();
                } catch (IOException e) {
                    throw new HiveException("Failed to flush a file: " + fileIO.getFile().getAbsolutePath(), e);
                }
                if (LOG.isInfoEnabled()) {
                    File tmpFile = fileIO.getFile();
                    LOG.info(
                            "Wrote " + numTrainingExamples + " records to a temporary file for iterative training: "
                                    + tmpFile.getAbsolutePath() + " (" + FileUtils.prettyFileSize(tmpFile) + ")");
                }

                // run iterations
                int iter = 2;
                for (; iter <= iterations; iter++) {
                    setCounterValue(iterCounter, iter);

                    inputBuf.clear();
                    fileIO.resetPosition();
                    while (true) {
                        reportProgress(reporter);
                        // TODO prefetch
                        // writes training examples to a buffer in the temporary file
                        final int bytesRead;
                        try {
                            bytesRead = fileIO.read(inputBuf);
                        } catch (IOException e) {
                            throw new HiveException("Failed to read a file: " + fileIO.getFile().getAbsolutePath(),
                                    e);
                        }
                        if (bytesRead == 0) { // reached file EOF
                            break;
                        }
                        assert (bytesRead > 0) : bytesRead;

                        // reads training examples from a buffer
                        inputBuf.flip();
                        int remain = inputBuf.remaining();
                        if (remain < INT_BYTES) {
                            throw new HiveException("Illegal file format was detected");
                        }
                        while (remain >= INT_BYTES) {
                            int pos = inputBuf.position();
                            int recordBytes = inputBuf.getInt();
                            remain -= INT_BYTES;
                            if (remain < recordBytes) {
                                inputBuf.position(pos);
                                break;
                            }

                            final int xLength = inputBuf.getInt();
                            final Feature[] x = new Feature[xLength];
                            for (int j = 0; j < xLength; j++) {
                                x[j] = instantiateFeature(inputBuf);
                            }
                            double y = inputBuf.getDouble();

                            // invoke training
                            ++_t;
                            train(x, y, adaregr);

                            remain -= recordBytes;
                        }
                        inputBuf.compact();
                    }
                    if (_cvState.isConverged(iter, numTrainingExamples)) {
                        break;
                    }
                }
                LOG.info("Performed " + Math.min(iter, iterations) + " iterations of "
                        + NumberUtils.formatNumber(numTrainingExamples)
                        + " training examples on a secondary storage (thus " + NumberUtils.formatNumber(_t)
                        + " training updates in total)");
            }
        } finally {
            // delete the temporary file and release resources
            try {
                fileIO.close(true);
            } catch (IOException e) {
                throw new HiveException("Failed to close a file: " + fileIO.getFile().getAbsolutePath(), e);
            }
            this._inputBuf = null;
            this._fileIO = null;
        }
    }

    @Nonnull
    protected Feature instantiateFeature(@Nonnull final ByteBuffer input) {
        if (_parseFeatureAsInt) {
            return new IntFeature(input);
        } else {
            return new StringFeature(input);
        }
    }

}