hivemall.xgboost.XGBoostUDTF.java Source code

Java tutorial

Introduction

Here is the source code for hivemall.xgboost.XGBoostUDTF.java

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package hivemall.xgboost;

import hivemall.UDTFWithOptions;
import hivemall.utils.hadoop.HadoopUtils;
import hivemall.utils.hadoop.HiveUtils;

import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import javax.annotation.Nonnull;

import ml.dmlc.xgboost4j.LabeledPoint;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoostError;

import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;

/**
 * This is a base class to handle the options for XGBoost and provide common functions among various
 * tasks.
 */
public abstract class XGBoostUDTF extends UDTFWithOptions {
    // Logger shared by all instances; close() uses it to report the id and size of the built model
    private static final Log logger = LogFactory.getLog(XGBoostUDTF.class);

    // Settings for the XGBoost native library.
    // Being a static initializer, this runs once when the class is initialized,
    // i.e. before any instance is constructed.
    static {
        NativeLibLoader.initXGBoost();
    }

    // For input buffer: process() appends one parsed row per call and
    // close() hands the whole list to XGBoost
    private final List<LabeledPoint> featuresList;

    // For input parameters (all three are assigned in initialize())
    // Inspector for the 1st argument, the list of features
    private ListObjectInspector featureListOI;
    // Inspector for the elements of the feature list (string-typed)
    private PrimitiveObjectInspector featureElemOI;
    // Inspector for the 2nd argument, the target value (double-compatible)
    private PrimitiveObjectInspector targetOI;

    // For XGBoost options: maps an XGBoost parameter name to its value.
    // Values are stored as String, Integer or Double depending on the parameter.
    @Nonnull
    protected final Map<String, Object> params = new HashMap<String, Object>();

    // XGBoost options can be found in https://github.com/dmlc/xgboost/blob/master/doc/parameter.md
    // Most of the default values below follow the official ones.
    // This instance initializer runs for every new instance and seeds `params` with defaults;
    // processOptions() later puts user-supplied values into the same map.
    {
        // General parameters
        params.put("booster", "gbtree");
        params.put("num_round", 8);
        params.put("silent", 1);
        // Set to 1 by default because most of distributed systems assume
        // each worker has a single vcore.
        params.put("nthread", 1);

        // Parameters for both boosters
        params.put("alpha", 0.0);
        // This default value depends on a booster type, so it is left unset here
        // params.put("lambda", 0.0);

        // Parameters for Tree Booster
        params.put("eta", 0.3);
        params.put("gamma", 0.0);
        params.put("max_depth", 6);
        params.put("min_child_weight", 1);
        params.put("max_delta_step", 0);
        params.put("subsample", 1.0);
        params.put("colsample_bytree", 1.0);
        params.put("colsample_bylevel", 1.0);
        // The memory-based version of XGBoost only supports `exact`
        params.put("tree_method", "exact");

        // Learning Task Parameters
        params.put("base_score", 0.5);
    }

    public XGBoostUDTF() {
        this.featuresList = new ArrayList<>(1024);
    }

    /**
     * Declares every command-line style option accepted in the optional 3rd argument.
     * <p>
     * Each option takes a value, and each option name is identical to the XGBoost parameter it
     * configures. The defaults quoted in the descriptions mirror the values seeded into
     * {@code params} by this class.
     *
     * @return the set of supported options
     */
    @Override
    protected Options getOptions() {
        final Options opts = new Options();

        // General parameters
        opts.addOption("booster", true, "Set a booster to use, gbtree or gblinear. [default: gbtree]");
        opts.addOption("num_round", true, "Number of boosting iterations [default: 8]");
        opts.addOption("silent", true, "0 means printing running messages, 1 means silent mode [default: 1]");
        opts.addOption("nthread", true, "Number of parallel threads used to run xgboost [default: 1]");
        opts.addOption("num_pbuffer", true, "Size of prediction buffer [set automatically by xgboost]");
        opts.addOption("num_feature", true,
                "Feature dimension used in boosting [default: set automatically by xgboost]");

        // Parameters for both boosters
        opts.addOption("alpha", true, "L1 regularization term on weights [default: 0.0]");
        opts.addOption("lambda", true,
                "L2 regularization term on weights [default: 1.0 for gbtree, 0.0 for gblinear]");

        // Parameters for Tree Booster
        opts.addOption("eta", true, "Step size shrinkage used in update to prevents overfitting [default: 0.3]");
        opts.addOption("gamma", true,
                "Minimum loss reduction required to make a further partition on a leaf node of the tree [default: 0.0]");
        opts.addOption("max_depth", true, "Max depth of decision tree [default: 6]");
        opts.addOption("min_child_weight", true,
                "Minimum sum of instance weight(hessian) needed in a child [default: 1]");
        opts.addOption("max_delta_step", true,
                "Maximum delta step we allow each tree's weight estimation to be [default: 0]");
        opts.addOption("subsample", true, "Subsample ratio of the training instance [default: 1.0]");
        opts.addOption("colsample_bytree", true,
                "Subsample ratio of columns when constructing each tree [default: 1.0]");
        opts.addOption("colsample_bylevel", true,
                "Subsample ratio of columns for each split, in each level [default: 1.0]");

        // Parameters for Linear Booster
        opts.addOption("lambda_bias", true, "L2 regularization term on bias [default: 0.0]");

        // Learning Task Parameters
        opts.addOption("base_score", true, "Initial prediction score of all instances, global bias [default: 0.5]");
        opts.addOption("eval_metric", true,
                "Evaluation metrics for validation data [default according to objective]");

        return opts;
    }

    /**
     * Parses the optional 3rd argument (a constant option string) and copies every option the
     * user supplied into {@code params}, then verifies that a {@code Booster} can be created
     * from the resulting parameters.
     * <p>
     * Each option is stored under a key equal to its option name. Numeric options are converted
     * with {@link Integer#valueOf(String)} or {@link Double#valueOf(String)}, so a malformed
     * number surfaces as the {@link NumberFormatException} those methods throw.
     *
     * @param argOIs inspectors of the UDTF arguments; options are read only when at least three
     *        are present
     * @return the parsed command line, or {@code null} when no option argument was given
     * @throws UDFArgumentException if the trial creation of a {@code Booster} fails (option
     *         parsing itself may also raise it)
     */
    @Override
    protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
        CommandLine cl = null;
        if (argOIs.length >= 3) {
            final String rawArgs = HiveUtils.getConstString(argOIs[2]);
            cl = this.parseOptions(rawArgs);

            // General parameters
            putStringOption(cl, "booster");
            putIntOption(cl, "num_round");
            putIntOption(cl, "silent");
            putIntOption(cl, "nthread");
            putIntOption(cl, "num_pbuffer");
            putIntOption(cl, "num_feature");

            // Parameters for both boosters
            putDoubleOption(cl, "alpha");
            putDoubleOption(cl, "lambda");

            // Parameters for Tree Booster
            putDoubleOption(cl, "eta");
            putDoubleOption(cl, "gamma");
            putIntOption(cl, "max_depth");
            putIntOption(cl, "min_child_weight");
            putIntOption(cl, "max_delta_step");
            putDoubleOption(cl, "subsample");
            // These two were previously stored under the misspelled keys "colsamle_*",
            // so the user's values never replaced the defaults.
            putDoubleOption(cl, "colsample_bytree");
            putDoubleOption(cl, "colsample_bylevel");

            // Parameters for Linear Booster
            putDoubleOption(cl, "lambda_bias");

            // Learning Task Parameters
            putDoubleOption(cl, "base_score");
            putStringOption(cl, "eval_metric");
        }

        try {
            // Try to create a `Booster` instance to check if given XGBoost options
            // are valid, or not.
            createXGBooster(params, featuresList);
        } catch (Exception e) {
            throw new UDFArgumentException(e);
        }

        return cl;
    }

    /** Copies option {@code name} into {@code params} as a String, if the user supplied it. */
    private void putStringOption(@Nonnull final CommandLine cl, @Nonnull final String name) {
        if (cl.hasOption(name)) {
            params.put(name, cl.getOptionValue(name));
        }
    }

    /** Copies option {@code name} into {@code params} as an Integer, if the user supplied it. */
    private void putIntOption(@Nonnull final CommandLine cl, @Nonnull final String name) {
        if (cl.hasOption(name)) {
            params.put(name, Integer.valueOf(cl.getOptionValue(name)));
        }
    }

    /** Copies option {@code name} into {@code params} as a Double, if the user supplied it. */
    private void putDoubleOption(@Nonnull final CommandLine cl, @Nonnull final String name) {
        if (cl.hasOption(name)) {
            params.put(name, Double.valueOf(cl.getOptionValue(name)));
        }
    }

    /**
     * Builds the output schema shared by all the functions: a struct of
     * (string model_id, byte[] pred_model) describing a built model.
     */
    @Nonnull
    private static StructObjectInspector getReturnOIs() {
        final List<String> names = new ArrayList<>(2);
        names.add("model_id");
        names.add("pred_model");

        final List<ObjectInspector> inspectors = new ArrayList<>(2);
        inspectors.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector);
        inspectors.add(PrimitiveObjectInspectorFactory.javaByteArrayObjectInspector);

        return ObjectInspectorFactory.getStandardStructObjectInspector(names, inspectors);
    }

    /**
     * Validates the arguments, applies the optional option string and captures the object
     * inspectors later used by {@link #process(Object[])}.
     *
     * @param argOIs inspectors for (features, target [, options]); features must be a list of
     *        strings and target must be double-compatible
     * @return the output schema, a struct of (string model_id, byte[] pred_model)
     * @throws UDFArgumentException if fewer than two arguments are given, or if option
     *         processing or an argument type check fails
     */
    @Override
    public StructObjectInspector initialize(@Nonnull ObjectInspector[] argOIs) throws UDFArgumentException {
        // Report a descriptive error instead of the ArrayIndexOutOfBoundsException
        // that indexing argOIs[0] / argOIs[1] below would otherwise raise.
        if (argOIs.length < 2) {
            throw new UDFArgumentException(getClass().getSimpleName()
                    + " takes at least 2 arguments (array<string> features, double target [, string options]) but got "
                    + argOIs.length);
        }

        processOptions(argOIs);

        final ListObjectInspector listOI = HiveUtils.asListOI(argOIs[0]);
        final ObjectInspector elemOI = listOI.getListElementObjectInspector();
        this.featureListOI = listOI;
        this.featureElemOI = HiveUtils.asStringOI(elemOI);
        this.targetOI = HiveUtils.asDoubleCompatibleOI(argOIs[1]);
        return getReturnOIs();
    }

    /**
     * Validation hook invoked by process() with the target value of every row, before the row
     * is parsed and buffered. This base implementation accepts any value.
     * <p>
     * If `target` has a valid input range, a subclass overrides this and throws on violation.
     *
     * @param target the target value read from the 2nd argument
     * @throws HiveException never thrown here; declared so that overrides can reject a value
     */
    protected void checkTargetValue(double target) throws HiveException {
    }

    /**
     * Buffers one input row for training.
     * <p>
     * A row whose feature list is {@code null} is skipped. Otherwise the features are read as
     * strings, the target is checked through {@link #checkTargetValue(double)}, and the row is
     * parsed by {@code XGBoostUtils.parseFeatures}; a {@code null} parse result is dropped.
     *
     * @param args the row: args[0] is the feature list, args[1] the target value
     * @throws HiveException if the target check rejects the value
     */
    @Override
    public void process(@Nonnull Object[] args) throws HiveException {
        final Object rawFeatures = args[0];
        if (rawFeatures == null) {
            return;
        }

        // TODO: Need to support dense inputs
        final List<?> elems = featureListOI.getList(rawFeatures);
        final int numElems = elems.size();
        final String[] featureStrings = new String[numElems];
        for (int idx = 0; idx < numElems; idx++) {
            final Object elem = elems.get(idx);
            featureStrings[idx] = (String) featureElemOI.getPrimitiveJavaObject(elem);
        }

        final double label = PrimitiveObjectInspectorUtils.getDouble(args[1], this.targetOI);
        checkTargetValue(label);

        final LabeledPoint parsed = XGBoostUtils.parseFeatures(label, featureStrings);
        if (parsed == null) {
            return;
        }
        this.featuresList.add(parsed);
    }

    /**
     * Builds the id reported for a trained model: the prefix {@code "xgbmodel-"} followed by
     * the string returned by {@code HadoopUtils.getUniqueTaskIdString()}.
     */
    @Nonnull
    private static String generateUniqueModelId() {
        final String taskId = HadoopUtils.getUniqueTaskIdString();
        return "xgbmodel-" + taskId;
    }

    /**
     * Instantiates a {@code Booster} reflectively through its declared
     * {@code (Map, DMatrix[])} constructor, which is made accessible first.
     * <p>
     * The matrix array passed to that constructor holds a single {@code DMatrix} built from
     * {@code input}.
     *
     * @param params XGBoost parameters handed to the constructor
     * @param input rows from which the {@code DMatrix} is built
     * @return the newly created booster
     */
    @Nonnull
    private static Booster createXGBooster(final Map<String, Object> params, final List<LabeledPoint> input)
            throws NoSuchMethodException, XGBoostError, IllegalAccessException, InvocationTargetException,
            InstantiationException {
        final Constructor<Booster> boosterCtor =
                Booster.class.getDeclaredConstructor(Map.class, DMatrix[].class);
        boosterCtor.setAccessible(true);

        final DMatrix matrix = new DMatrix(input.iterator(), "");
        final DMatrix[] matrices = new DMatrix[] { matrix };
        return boosterCtor.newInstance(params, matrices);
    }

    /**
     * Trains a model on all buffered rows and forwards it as a single
     * (model_id, pred_model) row.
     * <p>
     * Training runs {@code num_round} boosting iterations over a {@code DMatrix} built from the
     * buffered rows. The native handles of the training matrix and the booster are released
     * before returning, whether or not training succeeded.
     *
     * @throws HiveException if training, serialization or forwarding fails; any other
     *         exception is wrapped as the cause
     */
    @Override
    public void close() throws HiveException {
        // NOTE(review): an empty buffer (no process() call kept a row) is still passed to
        // XGBoost as before; confirm how XGBoost handles an empty DMatrix.
        DMatrix trainData = null;
        Booster booster = null;
        try {
            // Kick off training with XGBoost
            trainData = new DMatrix(featuresList.iterator(), "");
            booster = createXGBooster(params, featuresList);
            final int numRound = (Integer) params.get("num_round");
            for (int i = 0; i < numRound; i++) {
                booster.update(trainData, i);
            }

            // Output the built model
            final String modelId = generateUniqueModelId();
            final byte[] predModel = booster.toByteArray();
            logger.info("model_id:" + modelId + " size:" + predModel.length);
            forward(new Object[] { modelId, predModel });
        } catch (HiveException e) {
            // Already the declared type (e.g. raised by forward); do not wrap it again
            throw e;
        } catch (Exception e) {
            throw new HiveException(e);
        } finally {
            // Release native memory eagerly instead of waiting for finalization.
            // The model bytes were already copied into `predModel` above.
            if (booster != null) {
                booster.dispose();
            }
            if (trainData != null) {
                trainData.dispose();
            }
        }
    }

}