/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package hivemall.xgboost;

import hivemall.UDTFWithOptions;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.lang.NativeLibLoader;
import hivemall.utils.lang.Primitives;

import java.io.ByteArrayInputStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;

import javax.annotation.Nonnull;

import ml.dmlc.xgboost4j.LabeledPoint;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoost;
import ml.dmlc.xgboost4j.java.XGBoostError;

import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;

public abstract class XGBoostPredictUDTF extends UDTFWithOptions {

    // Object inspectors for the input arguments
    private PrimitiveObjectInspector rowIdOI;
    private ListObjectInspector featureListOI;
    private PrimitiveObjectInspector featureElemOI;
    private PrimitiveObjectInspector modelIdOI;
    private PrimitiveObjectInspector modelOI;

    // Input buffers: one loaded Booster and one row buffer per model_id
    private Map<String, Booster> mapToModel;
    private Map<String, List<LabeledPointWithRowId>> rowBuffer;
    private int batch_size;

    // Load the XGBoost native library once per JVM
    static {
        NativeLibLoader.initXGBoost();
    }

    public XGBoostPredictUDTF() {
        super();
    }

    @Override
    protected Options getOptions() {
        Options opts = new Options();
        opts.addOption("batch_size", true, "Number of rows to predict together [default: 128]");
        return opts;
    }

    @Override
    protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
        int _batch_size = 128;

        CommandLine cl = null;
        if (argOIs.length >= 5) {
            String rawArgs = HiveUtils.getConstString(argOIs[4]);
            cl = parseOptions(rawArgs);
            _batch_size = Primitives.parseInt(cl.getOptionValue("batch_size"), _batch_size);
            if (_batch_size < 1) {
                throw new IllegalArgumentException(
                    "batch_size must be greater than 0: " + _batch_size);
            }
        }

        this.batch_size = _batch_size;
        return cl;
    }

    /** Override this to output predicted results depending on a task type */
    @Nonnull
    protected abstract StructObjectInspector getReturnOI();

    protected abstract void forwardPredicted(@Nonnull final List<LabeledPointWithRowId> testData,
            @Nonnull final float[][] predicted) throws HiveException;

    @Override
    public StructObjectInspector initialize(@Nonnull ObjectInspector[] argOIs)
            throws UDFArgumentException {
        if (argOIs.length != 4 && argOIs.length != 5) {
            throw new UDFArgumentException(this.getClass().getSimpleName()
                    + " takes 4 or 5 arguments: string rowid, string[] features, string model_id,"
                    + " array<byte> pred_model [, string options]: " + argOIs.length);
        }
        processOptions(argOIs);
        this.rowIdOI = HiveUtils.asStringOI(argOIs[0]);
        final ListObjectInspector listOI = HiveUtils.asListOI(argOIs[1]);
        final ObjectInspector elemOI = listOI.getListElementObjectInspector();
        this.featureListOI = listOI;
        this.featureElemOI = HiveUtils.asStringOI(elemOI);
        this.modelIdOI = HiveUtils.asStringOI(argOIs[2]);
        this.modelOI = HiveUtils.asBinaryOI(argOIs[3]);
        this.mapToModel = new HashMap<String, Booster>();
        this.rowBuffer = new HashMap<String, List<LabeledPointWithRowId>>();
        return getReturnOI();
    }

    @Nonnull
    private static DMatrix createDMatrix(@Nonnull final List<LabeledPointWithRowId> data)
            throws XGBoostError {
        final List<LabeledPoint> points = new ArrayList<>(data.size());
        for (LabeledPointWithRowId d : data) {
            points.add(d.point);
        }
        return new DMatrix(points.iterator(), "");
    }

    @Nonnull
    private static Booster initXgBooster(@Nonnull final byte[] input) throws HiveException {
        try {
            return XGBoost.loadModel(new ByteArrayInputStream(input));
        } catch (Exception e) {
            throw new HiveException(e);
        }
    }

    private void predictAndFlush(final Booster model, final List<LabeledPointWithRowId> buf)
            throws HiveException {
        final DMatrix testData;
        final float[][] predicted;
        try {
            testData = createDMatrix(buf);
            predicted = model.predict(testData);
        } catch (XGBoostError e) {
            throw new HiveException(e);
        }
        forwardPredicted(buf, predicted);
        buf.clear();
    }

    @Override
    public void process(Object[] args) throws HiveException {
        if (args[1] == null) {
            return;
        }
        final String rowId = PrimitiveObjectInspectorUtils.getString(args[0], rowIdOI);
        final List<?> features = featureListOI.getList(args[1]);
        final String[] fv = new String[features.size()];
        for (int i = 0; i < features.size(); i++) {
            fv[i] = (String) featureElemOI.getPrimitiveJavaObject(features.get(i));
        }
        final String modelId = PrimitiveObjectInspectorUtils.getString(args[2], modelIdOI);

        // Deserialize each model at most once per task
        if (!mapToModel.containsKey(modelId)) {
            final byte[] predModel =
                    PrimitiveObjectInspectorUtils.getBinary(args[3], modelOI).getBytes();
            mapToModel.put(modelId, initXgBooster(predModel));
        }

        final LabeledPoint point = XGBoostUtils.parseFeatures(0.f, fv);
        if (point == null) {
            return;
        }

        // Buffer the row and flush once batch_size rows have accumulated for this model
        List<LabeledPointWithRowId> buf = rowBuffer.get(modelId);
        if (buf == null) {
            buf = new ArrayList<LabeledPointWithRowId>();
            rowBuffer.put(modelId, buf);
        }
        buf.add(new LabeledPointWithRowId(rowId, point));
        if (buf.size() >= batch_size) {
            predictAndFlush(mapToModel.get(modelId), buf);
        }
    }

    public static final class LabeledPointWithRowId {

        @Nonnull
        final String rowId;
        @Nonnull
        final LabeledPoint point;

        LabeledPointWithRowId(@Nonnull String rowId, @Nonnull LabeledPoint point) {
            this.rowId = rowId;
            this.point = point;
        }

        @Nonnull
        public String getRowId() {
            return rowId;
        }

        @Nonnull
        public LabeledPoint getPoint() {
            return point;
        }
    }

    @Override
    public void close() throws HiveException {
        // Flush any rows still buffered for each model
        for (Entry<String, List<LabeledPointWithRowId>> e : rowBuffer.entrySet()) {
            predictAndFlush(mapToModel.get(e.getKey()), e.getValue());
        }
    }
}
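
/*
 * The class above is abstract: each prediction task defines its output row schema via
 * getReturnOI() and emits rows in forwardPredicted(). The following is a minimal sketch
 * of what a concrete subclass could look like, forwarding (rowid, predicted) pairs for a
 * single-output model. The class name XGBoostPredictExampleUDTF is hypothetical and not
 * part of Hivemall; the ObjectInspectorFactory/PrimitiveObjectInspectorFactory calls are
 * standard Hive serde2 APIs, written fully qualified to avoid adding imports above.
 */
class XGBoostPredictExampleUDTF extends XGBoostPredictUDTF {

    @Override
    protected StructObjectInspector getReturnOI() {
        // Two output columns: string rowid, double predicted
        final List<String> fieldNames = new ArrayList<String>();
        final List<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
        fieldNames.add("rowid");
        fieldOIs.add(org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory.javaStringObjectInspector);
        fieldNames.add("predicted");
        fieldOIs.add(org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory.javaDoubleObjectInspector);
        return org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.getStandardStructObjectInspector(
            fieldNames, fieldOIs);
    }

    @Override
    protected void forwardPredicted(@Nonnull final List<LabeledPointWithRowId> testData,
            @Nonnull final float[][] predicted) throws HiveException {
        // Booster.predict() returns one row per buffered input row; for a single-output
        // model the score is predicted[i][0]. Multi-class tasks would inspect the full row.
        for (int i = 0; i < testData.size(); i++) {
            final Object[] forwardObj =
                    new Object[] {testData.get(i).getRowId(), (double) predicted[i][0]};
            forward(forwardObj);
        }
    }
}

/*
 * Hypothetical HiveQL usage, assuming the subclass above is registered as a temporary
 * function and that test rows are joined with a table holding serialized models:
 *
 *   CREATE TEMPORARY FUNCTION xgboost_predict_example
 *     AS 'hivemall.xgboost.XGBoostPredictExampleUDTF';
 *
 *   SELECT xgboost_predict_example(rowid, features, model_id, pred_model, '-batch_size 256')
 *   FROM test_data CROSS JOIN models;
 *
 * The optional fifth argument is parsed by processOptions(); larger batch_size values
 * trade memory for fewer native predict calls.
 */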