hivemall.dataset.LogisticRegressionDataGeneratorUDTFWrapper.java Source code

Java tutorial

Introduction

Here is the source code for hivemall.dataset.LogisticRegressionDataGeneratorUDTFWrapper.java

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package hivemall.dataset;

import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.Random;

import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.Collector;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDTF;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;

import hivemall.UDTFWithOptions;

/**
 * A wrapper of {@link LogisticRegressionDataGeneratorUDTF}. This wrapper is needed
 * because Spark cannot handle HadoopUtils#getTaskId() correctly.
 *
 * <p>Every call is forwarded to a privately held delegate instance. Non-public members of the
 * delegate (and the {@code collector} field of {@link GenericUDTF}) are reached via reflection.
 */
@Description(name = "lr_datagen", value = "_FUNC_(options string) - Generates a logistic regression dataset")
public final class LogisticRegressionDataGeneratorUDTFWrapper extends UDTFWithOptions {
    /** The wrapped generator that every call is forwarded to. */
    private transient LogisticRegressionDataGeneratorUDTF udtf = new LogisticRegressionDataGeneratorUDTF();

    /**
     * Returns the option definitions declared by the wrapped generator.
     *
     * @return the {@link Options} produced by the delegate's {@code getOptions()}
     * @throws IllegalStateException if the delegate's method cannot be located or invoked
     */
    @Override
    protected Options getOptions() {
        try {
            Method m = udtf.getClass().getDeclaredMethod("getOptions");
            m.setAccessible(true);
            return (Options) m.invoke(udtf);
        } catch (Exception e) {
            // Fail loudly instead of handing a null Options object back to the caller.
            throw new IllegalStateException(
                "Failed to invoke getOptions() on " + udtf.getClass().getName(), e);
        }
    }

    /**
     * Forwards option processing to the wrapped generator.
     *
     * @param objectInspectors the argument inspectors passed on to the delegate
     * @return the {@link CommandLine} returned by the delegate's {@code processOptions}
     * @throws UDFArgumentException if the delegate rejects the arguments, or if its
     *         {@code processOptions} method cannot be located or invoked
     */
    @Override
    protected CommandLine processOptions(ObjectInspector[] objectInspectors) throws UDFArgumentException {
        try {
            // The parameter type must be given, otherwise a zero-argument overload is looked up.
            Method m = udtf.getClass().getDeclaredMethod("processOptions", ObjectInspector[].class);
            m.setAccessible(true);
            // Cast to Object so the array is passed as ONE argument rather than being
            // spread over Method#invoke's varargs parameter.
            return (CommandLine) m.invoke(udtf, (Object) objectInspectors);
        } catch (InvocationTargetException e) {
            // The delegate itself threw: surface its exception, not the reflective wrapper.
            Throwable cause = e.getCause();
            if (cause instanceof UDFArgumentException) {
                throw (UDFArgumentException) cause;
            }
            throw new UDFArgumentException(cause != null ? cause : e);
        } catch (Exception e) {
            throw new UDFArgumentException(e);
        }
    }

    /**
     * Prepares the wrapped generator and then delegates initialization to it.
     *
     * <p>Before delegating, this wrapper's collector is handed to the delegate and the delegate's
     * {@code rnd1}/{@code rnd2} fields are populated with {@link Random} instances seeded from its
     * {@code r_seed} field plus the current thread id.
     *
     * @param argOIs the argument inspectors passed on to the delegate
     * @return the output inspector returned by the delegate's {@code initialize}
     * @throws UDFArgumentException if the reflective preparation fails or the delegate rejects
     *         the arguments
     */
    @Override
    public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
        try {
            shareCollectorWithDelegate();
            // To avoid HadoopUtils#getTaskId()
            seedDelegateRandoms();
        } catch (Exception e) {
            // A half-configured delegate would misbehave later; report the root cause now.
            throw new UDFArgumentException(e);
        }
        return udtf.initialize(argOIs);
    }

    /** Copies this wrapper's {@code collector} (a non-public GenericUDTF field) into the delegate. */
    private void shareCollectorWithDelegate() throws NoSuchFieldException, IllegalAccessException {
        Field collectorField = GenericUDTF.class.getDeclaredField("collector");
        collectorField.setAccessible(true);
        udtf.setCollector((Collector) collectorField.get(this));
    }

    /** Assigns the delegate's {@code rnd1}/{@code rnd2} fields using a thread-id based seed. */
    private void seedDelegateRandoms() throws NoSuchFieldException, IllegalAccessException {
        Class<?> clazz = udtf.getClass();
        Field rnd1Field = clazz.getDeclaredField("rnd1");
        Field rnd2Field = clazz.getDeclaredField("rnd2");
        Field seedField = clazz.getDeclaredField("r_seed");
        seedField.setAccessible(true);
        // NOTE(review): r_seed is read before udtf.initialize(argOIs) has run. If the delegate
        // only assigns r_seed while processing its options during initialize(), the value read
        // here is still the field's initial value -- confirm against the delegate's source.
        final long seed = seedField.getLong(udtf) + (int) Thread.currentThread().getId();
        rnd1Field.setAccessible(true);
        rnd2Field.setAccessible(true);
        rnd1Field.set(udtf, new Random(seed));
        rnd2Field.set(udtf, new Random(seed + 1));
    }

    /**
     * Forwards a row to the wrapped generator.
     *
     * @param objects the row arguments passed on to the delegate
     * @throws HiveException if the delegate fails
     */
    @Override
    public void process(Object[] objects) throws HiveException {
        udtf.process(objects);
    }

    /**
     * Closes the wrapped generator.
     *
     * @throws HiveException if the delegate fails
     */
    @Override
    public void close() throws HiveException {
        udtf.close();
    }
}