com.ibm.bi.dml.runtime.functionobjects.Builtin.java Source code

Java tutorial

Introduction

Here is the source code for com.ibm.bi.dml.runtime.functionobjects.Builtin.java

Source

/**
 * (C) Copyright IBM Corp. 2010, 2015
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * 
*/

package com.ibm.bi.dml.runtime.functionobjects;

import java.util.EnumMap;
import java.util.HashMap;

import org.apache.commons.math3.util.FastMath;

import com.ibm.bi.dml.api.DMLScript;
import com.ibm.bi.dml.runtime.DMLRuntimeException;
import com.ibm.bi.dml.runtime.DMLScriptException;
import com.ibm.bi.dml.runtime.DMLUnsupportedOperationException;

/**
 *  Class with pre-defined set of objects. This class can not be instantiated elsewhere.
 *  
 *  Notes on commons.math FastMath:
 *  * FastMath uses lookup tables and interpolation instead of native calls.
 *  * The memory overhead for those tables is roughly 48KB in total (acceptable)
 *  * Micro and application benchmarks showed significantly (30%-3x) performance improvements
 *    for most operations; without loss of accuracy.
 *  * atan / sqrt were 20% slower in FastMath and hence, we use Math there
 *  * round / abs were equivalent in FastMath and hence, we use Math there
 *  * Finally, there is just one argument against FastMath - The comparison heavily depends
 *    on the JVM. For example, currently the IBM JDK JIT compiles to HW instructions for sqrt
 *    which makes this operation very efficient; as soon as other operations like log/exp are
 *    similarly compiled, we should rerun the micro benchmarks, and switch back if necessary.
 *  
 */
public class Builtin extends ValueFunction {

    private static final long serialVersionUID = 3836744687789840574L;

    public enum BuiltinFunctionCode {
        INVALID, SIN, COS, TAN, ASIN, ACOS, ATAN, LOG, LOG_NZ, MIN, MAX, ABS, SQRT, EXP, PLOGP, PRINT, NROW, NCOL, LENGTH, ROUND, MAXINDEX, MININDEX, STOP, CEIL, FLOOR, CUMSUM, CUMPROD, CUMMIN, CUMMAX, INVERSE, SPROP, SIGMOID, SELP
    };

    public BuiltinFunctionCode bFunc;

    private static final boolean FASTMATH = true;

    static public HashMap<String, BuiltinFunctionCode> String2BuiltinFunctionCode;
    static {
        String2BuiltinFunctionCode = new HashMap<String, BuiltinFunctionCode>();

        String2BuiltinFunctionCode.put("sin", BuiltinFunctionCode.SIN);
        String2BuiltinFunctionCode.put("cos", BuiltinFunctionCode.COS);
        String2BuiltinFunctionCode.put("tan", BuiltinFunctionCode.TAN);
        String2BuiltinFunctionCode.put("asin", BuiltinFunctionCode.ASIN);
        String2BuiltinFunctionCode.put("acos", BuiltinFunctionCode.ACOS);
        String2BuiltinFunctionCode.put("atan", BuiltinFunctionCode.ATAN);
        String2BuiltinFunctionCode.put("log", BuiltinFunctionCode.LOG);
        String2BuiltinFunctionCode.put("log_nz", BuiltinFunctionCode.LOG_NZ);
        String2BuiltinFunctionCode.put("min", BuiltinFunctionCode.MIN);
        String2BuiltinFunctionCode.put("max", BuiltinFunctionCode.MAX);
        String2BuiltinFunctionCode.put("maxindex", BuiltinFunctionCode.MAXINDEX);
        String2BuiltinFunctionCode.put("minindex", BuiltinFunctionCode.MININDEX);
        String2BuiltinFunctionCode.put("abs", BuiltinFunctionCode.ABS);
        String2BuiltinFunctionCode.put("sqrt", BuiltinFunctionCode.SQRT);
        String2BuiltinFunctionCode.put("exp", BuiltinFunctionCode.EXP);
        String2BuiltinFunctionCode.put("plogp", BuiltinFunctionCode.PLOGP);
        String2BuiltinFunctionCode.put("print", BuiltinFunctionCode.PRINT);
        String2BuiltinFunctionCode.put("nrow", BuiltinFunctionCode.NROW);
        String2BuiltinFunctionCode.put("ncol", BuiltinFunctionCode.NCOL);
        String2BuiltinFunctionCode.put("length", BuiltinFunctionCode.LENGTH);
        String2BuiltinFunctionCode.put("round", BuiltinFunctionCode.ROUND);
        String2BuiltinFunctionCode.put("stop", BuiltinFunctionCode.STOP);
        String2BuiltinFunctionCode.put("ceil", BuiltinFunctionCode.CEIL);
        String2BuiltinFunctionCode.put("floor", BuiltinFunctionCode.FLOOR);
        String2BuiltinFunctionCode.put("ucumk+", BuiltinFunctionCode.CUMSUM);
        String2BuiltinFunctionCode.put("ucum*", BuiltinFunctionCode.CUMPROD);
        String2BuiltinFunctionCode.put("ucummin", BuiltinFunctionCode.CUMMIN);
        String2BuiltinFunctionCode.put("ucummax", BuiltinFunctionCode.CUMMAX);
        String2BuiltinFunctionCode.put("inverse", BuiltinFunctionCode.INVERSE);
        String2BuiltinFunctionCode.put("sprop", BuiltinFunctionCode.SPROP);
        String2BuiltinFunctionCode.put("sigmoid", BuiltinFunctionCode.SIGMOID);
        String2BuiltinFunctionCode.put("sel+", BuiltinFunctionCode.SELP);
    }

    // We should create one object for every builtin function that we support
    private static Builtin sinObj = null, cosObj = null, tanObj = null, asinObj = null, acosObj = null,
            atanObj = null;
    private static Builtin logObj = null, lognzObj = null, minObj = null, maxObj = null, maxindexObj = null,
            minindexObj = null;
    private static Builtin absObj = null, sqrtObj = null, expObj = null, plogpObj = null, printObj = null;
    private static Builtin nrowObj = null, ncolObj = null, lengthObj = null, roundObj = null, ceilObj = null,
            floorObj = null;
    private static Builtin inverseObj = null, cumsumObj = null, cumprodObj = null, cumminObj = null,
            cummaxObj = null;
    private static Builtin stopObj = null, spropObj = null, sigmoidObj = null, selpObj = null;

    private Builtin(BuiltinFunctionCode bf) {
        bFunc = bf;
    }

    public BuiltinFunctionCode getBuiltinFunctionCode() {
        return bFunc;
    }

    /**
     * 
     * @param str
     * @return
     */
    public static Builtin getBuiltinFnObject(String str) {
        BuiltinFunctionCode code = String2BuiltinFunctionCode.get(str);
        return getBuiltinFnObject(code);
    }

    /**
     * 
     * @param code
     * @return
     */
    public static Builtin getBuiltinFnObject(BuiltinFunctionCode code) {
        if (code == null)
            return null;

        switch (code) {
        case SIN:
            if (sinObj == null)
                sinObj = new Builtin(BuiltinFunctionCode.SIN);
            return sinObj;

        case COS:
            if (cosObj == null)
                cosObj = new Builtin(BuiltinFunctionCode.COS);
            return cosObj;
        case TAN:
            if (tanObj == null)
                tanObj = new Builtin(BuiltinFunctionCode.TAN);
            return tanObj;
        case ASIN:
            if (asinObj == null)
                asinObj = new Builtin(BuiltinFunctionCode.ASIN);
            return asinObj;

        case ACOS:
            if (acosObj == null)
                acosObj = new Builtin(BuiltinFunctionCode.ACOS);
            return acosObj;
        case ATAN:
            if (atanObj == null)
                atanObj = new Builtin(BuiltinFunctionCode.ATAN);
            return atanObj;
        case LOG:
            if (logObj == null)
                logObj = new Builtin(BuiltinFunctionCode.LOG);
            return logObj;
        case LOG_NZ:
            if (lognzObj == null)
                lognzObj = new Builtin(BuiltinFunctionCode.LOG_NZ);
            return lognzObj;
        case MAX:
            if (maxObj == null)
                maxObj = new Builtin(BuiltinFunctionCode.MAX);
            return maxObj;
        case MAXINDEX:
            if (maxindexObj == null)
                maxindexObj = new Builtin(BuiltinFunctionCode.MAXINDEX);
            return maxindexObj;
        case MIN:
            if (minObj == null)
                minObj = new Builtin(BuiltinFunctionCode.MIN);
            return minObj;
        case MININDEX:
            if (minindexObj == null)
                minindexObj = new Builtin(BuiltinFunctionCode.MININDEX);
            return minindexObj;
        case ABS:
            if (absObj == null)
                absObj = new Builtin(BuiltinFunctionCode.ABS);
            return absObj;
        case SQRT:
            if (sqrtObj == null)
                sqrtObj = new Builtin(BuiltinFunctionCode.SQRT);
            return sqrtObj;
        case EXP:
            if (expObj == null)
                expObj = new Builtin(BuiltinFunctionCode.EXP);
            return expObj;
        case PLOGP:
            if (plogpObj == null)
                plogpObj = new Builtin(BuiltinFunctionCode.PLOGP);
            return plogpObj;
        case PRINT:
            if (printObj == null)
                printObj = new Builtin(BuiltinFunctionCode.PRINT);
            return printObj;
        case NROW:
            if (nrowObj == null)
                nrowObj = new Builtin(BuiltinFunctionCode.NROW);
            return nrowObj;
        case NCOL:
            if (ncolObj == null)
                ncolObj = new Builtin(BuiltinFunctionCode.NCOL);
            return ncolObj;
        case LENGTH:
            if (lengthObj == null)
                lengthObj = new Builtin(BuiltinFunctionCode.LENGTH);
            return lengthObj;
        case ROUND:
            if (roundObj == null)
                roundObj = new Builtin(BuiltinFunctionCode.ROUND);
            return roundObj;
        case CEIL:
            if (ceilObj == null)
                ceilObj = new Builtin(BuiltinFunctionCode.CEIL);
            return ceilObj;
        case FLOOR:
            if (floorObj == null)
                floorObj = new Builtin(BuiltinFunctionCode.FLOOR);
            return floorObj;
        case CUMSUM:
            if (cumsumObj == null)
                cumsumObj = new Builtin(BuiltinFunctionCode.CUMSUM);
            return cumsumObj;
        case CUMPROD:
            if (cumprodObj == null)
                cumprodObj = new Builtin(BuiltinFunctionCode.CUMPROD);
            return cumprodObj;
        case CUMMIN:
            if (cumminObj == null)
                cumminObj = new Builtin(BuiltinFunctionCode.CUMMIN);
            return cumminObj;
        case CUMMAX:
            if (cummaxObj == null)
                cummaxObj = new Builtin(BuiltinFunctionCode.CUMMAX);
            return cummaxObj;
        case INVERSE:
            if (inverseObj == null)
                inverseObj = new Builtin(BuiltinFunctionCode.INVERSE);
            return inverseObj;
        case STOP:
            if (stopObj == null)
                stopObj = new Builtin(BuiltinFunctionCode.STOP);
            return stopObj;

        case SPROP:
            if (spropObj == null)
                spropObj = new Builtin(BuiltinFunctionCode.SPROP);
            return spropObj;

        case SIGMOID:
            if (sigmoidObj == null)
                sigmoidObj = new Builtin(BuiltinFunctionCode.SIGMOID);
            return sigmoidObj;

        case SELP:
            if (selpObj == null)
                selpObj = new Builtin(BuiltinFunctionCode.SELP);
            return selpObj;

        default:
            // Unknown code --> return null
            return null;
        }
    }

    public Object clone() throws CloneNotSupportedException {
        // cloning is not supported for singleton classes
        throw new CloneNotSupportedException();
    }

    public boolean checkArity(int _arity) throws DMLUnsupportedOperationException {
        switch (bFunc) {
        case ABS:
        case SIN:
        case COS:
        case TAN:
        case ASIN:
        case ACOS:
        case ATAN:
        case SQRT:
        case EXP:
        case PLOGP:
        case NROW:
        case NCOL:
        case LENGTH:
        case ROUND:
        case PRINT:
        case MAXINDEX:
        case MININDEX:
        case STOP:
        case CEIL:
        case FLOOR:
        case CUMSUM:
        case INVERSE:
        case SPROP:
        case SIGMOID:
        case SELP:
            return (_arity == 1);

        case LOG:
        case LOG_NZ:
            return (_arity == 1 || _arity == 2);

        case MAX:
        case MIN:
            return (_arity == 2);
        default:
            throw new DMLUnsupportedOperationException("checkNumberOfOperands(): Unknown opcode: " + bFunc);
        }
    }

    public double execute(double in) throws DMLRuntimeException {
        switch (bFunc) {
        case SIN:
            return FASTMATH ? FastMath.sin(in) : Math.sin(in);
        case COS:
            return FASTMATH ? FastMath.cos(in) : Math.cos(in);
        case TAN:
            return FASTMATH ? FastMath.tan(in) : Math.tan(in);
        case ASIN:
            return FASTMATH ? FastMath.asin(in) : Math.asin(in);
        case ACOS:
            return FASTMATH ? FastMath.acos(in) : Math.acos(in);
        case ATAN:
            return Math.atan(in); //faster in Math
        case CEIL:
            return FASTMATH ? FastMath.ceil(in) : Math.ceil(in);
        case FLOOR:
            return FASTMATH ? FastMath.floor(in) : Math.floor(in);
        case LOG:
            //if ( in <= 0 )
            //   throw new DMLRuntimeException("Builtin.execute(): logarithm can only be computed for non-negative numbers (input = " + in + ").");
            // for negative numbers, Math.log will return NaN
            return FASTMATH ? FastMath.log(in) : Math.log(in);
        case LOG_NZ:
            return (in == 0) ? 0 : FASTMATH ? FastMath.log(in) : Math.log(in);

        case ABS:
            return Math.abs(in); //no need for FastMath

        case SQRT:
            //if ( in < 0 )
            //   throw new DMLRuntimeException("Builtin.execute(): squareroot can only be computed for non-negative numbers (input = " + in + ").");
            return Math.sqrt(in); //faster in Math

        case PLOGP:
            if (in == 0.0)
                return 0.0;
            else if (in < 0)
                return Double.NaN;
            else
                return (in * (FASTMATH ? FastMath.log(in) : Math.log(in)));

        case EXP:
            return FASTMATH ? FastMath.exp(in) : Math.exp(in);

        case ROUND:
            return Math.round(in); //no need for FastMath

        case SPROP:
            //sample proportion: P*(1-P)
            return in * (1 - in);

        case SIGMOID:
            //sigmoid: 1/(1+exp(-x))
            return FASTMATH ? 1 / (1 + FastMath.exp(-in)) : 1 / (1 + Math.exp(-in));

        case SELP:
            //select positive: x*(x>0)
            return (in > 0) ? in : 0;

        default:
            throw new DMLRuntimeException("Builtin.execute(): Unknown operation: " + bFunc);
        }
    }

    public double execute(long in) throws DMLRuntimeException {
        return this.execute((double) in);
    }

    /*
     * Builtin functions with two inputs
     */

    public double execute(double in1, double in2) throws DMLRuntimeException {
        switch (bFunc) {

        /*
         * Arithmetic relational operators (==, !=, <=, >=) must be instead of
         * <code>Double.compare()</code> due to the inconsistencies in the way
         * NaN and -0.0 are handled. The behavior of methods in
         * <code>Double</code> class are designed mainly to make Java
         * collections work properly. For more details, see the help for
         * <code>Double.equals()</code> and <code>Double.comapreTo()</code>.
         */
        case MAX:
        case CUMMAX:
            //return (Double.compare(in1, in2) >= 0 ? in1 : in2);
            return (in1 >= in2 ? in1 : in2);
        case MIN:
        case CUMMIN:
            //return (Double.compare(in1, in2) <= 0 ? in1 : in2);
            return (in1 <= in2 ? in1 : in2);

        // *** HACK ALERT *** HACK ALERT *** HACK ALERT ***
        // rowIndexMax() and its siblings require comparing four values, but
        // the aggregation API only allows two values. So the execute()
        // method receives as its argument the two cell values to be
        // compared and performs just the value part of the comparison. We
        // return an integer cast down to a double, since the aggregation
        // API doesn't have any way to return anything but a double. The
        // integer returned takes on three posssible values: //
        // .     0 => keep the index associated with in1 //
        // .     1 => use the index associated with in2 //
        // .     2 => use whichever index is higher (tie in value) //
        case MAXINDEX:
            if (in1 == in2) {
                return 2;
            } else if (in1 > in2) {
                return 1;
            } else { // in1 < in2
                return 0;
            }
        case MININDEX:
            if (in1 == in2) {
                return 2;
            } else if (in1 < in2) {
                return 1;
            } else { // in1 > in2
                return 0;
            }
            // *** END HACK ***
        case LOG:
            //if ( in1 <= 0 )
            //   throw new DMLRuntimeException("Builtin.execute(): logarithm can be computed only for non-negative numbers.");
            if (FASTMATH)
                return (FastMath.log(in1) / FastMath.log(in2));
            else
                return (Math.log(in1) / Math.log(in2));
        case LOG_NZ:
            if (FASTMATH)
                return (in1 == 0) ? 0 : (FastMath.log(in1) / FastMath.log(in2));
            else
                return (in1 == 0) ? 0 : (Math.log(in1) / Math.log(in2));

        default:
            throw new DMLRuntimeException("Builtin.execute(): Unknown operation: " + bFunc);
        }
    }

    /**
     * Simplified version without exception handling
     * 
     * @param in1
     * @param in2
     * @return
     */
    public double execute2(double in1, double in2) {
        switch (bFunc) {
        case MAX:
        case CUMMAX:
            //return (Double.compare(in1, in2) >= 0 ? in1 : in2); 
            return (in1 >= in2 ? in1 : in2);
        case MIN:
        case CUMMIN:
            //return (Double.compare(in1, in2) <= 0 ? in1 : in2); 
            return (in1 <= in2 ? in1 : in2);
        case MAXINDEX:
            return (in1 >= in2) ? 1 : 0;
        case MININDEX:
            return (in1 <= in2) ? 1 : 0;

        default:
            // For performance reasons, avoid throwing an exception 
            return -1;
        }
    }

    public double execute(long in1, long in2) throws DMLRuntimeException {
        switch (bFunc) {

        case MAX:
        case CUMMAX:
            return (in1 >= in2 ? in1 : in2);

        case MIN:
        case CUMMIN:
            return (in1 <= in2 ? in1 : in2);

        case MAXINDEX:
            return (in1 >= in2) ? 1 : 0;
        case MININDEX:
            return (in1 <= in2) ? 1 : 0;

        case LOG:
            //if ( in1 <= 0 )
            //   throw new DMLRuntimeException("Builtin.execute(): logarithm can be computed only for non-negative numbers.");
            if (FASTMATH)
                return (FastMath.log(in1) / FastMath.log(in2));
            else
                return (Math.log(in1) / Math.log(in2));
        case LOG_NZ:
            if (FASTMATH)
                return (in1 == 0) ? 0 : (FastMath.log(in1) / FastMath.log(in2));
            else
                return (in1 == 0) ? 0 : (Math.log(in1) / Math.log(in2));

        default:
            throw new DMLRuntimeException("Builtin.execute(): Unknown operation: " + bFunc);
        }
    }

    // currently, it is used only for PRINT and STOP
    public String execute(String in1) throws DMLRuntimeException {
        switch (bFunc) {
        case PRINT:
            if (!DMLScript.suppressPrint2Stdout())
                System.out.println(in1);
            return null;
        case STOP:
            throw new DMLScriptException(in1);
        default:
            throw new DMLRuntimeException("Builtin.execute(): Unknown operation: " + bFunc);
        }
    }

}