org.apache.sysml.runtime.functionobjects.Builtin.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.sysml.runtime.functionobjects.Builtin.java

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.sysml.runtime.functionobjects;

import java.util.EnumMap;
import java.util.HashMap;

import org.apache.commons.math3.util.FastMath;

import org.apache.sysml.api.DMLScript;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.DMLScriptException;

/**
 *  Class with pre-defined set of objects. This class can not be instantiated elsewhere.
 *  
 *  Notes on commons.math FastMath:
 *  * FastMath uses lookup tables and interpolation instead of native calls.
 *  * The memory overhead for those tables is roughly 48KB in total (acceptable)
 *  * Micro and application benchmarks showed significantly (30%-3x) performance improvements
 *    for most operations; without loss of accuracy.
 *  * atan / sqrt were 20% slower in FastMath and hence, we use Math there
 *  * round / abs were equivalent in FastMath and hence, we use Math there
 *  * Finally, there is just one argument against FastMath - The comparison heavily depends
 *    on the JVM. For example, currently the IBM JDK JIT compiles to HW instructions for sqrt
 *    which makes this operation very efficient; as soon as other operations like log/exp are
 *    similarly compiled, we should rerun the micro benchmarks, and switch back if necessary.
 *  
 */
public class Builtin extends ValueFunction {

    private static final long serialVersionUID = 3836744687789840574L;

    public enum BuiltinCode {
        SIN, COS, TAN, ASIN, ACOS, ATAN, LOG, LOG_NZ, MIN, MAX, ABS, SIGN, SQRT, EXP, PLOGP, PRINT, PRINTF, NROW, NCOL, LENGTH, ROUND, MAXINDEX, MININDEX, STOP, CEIL, FLOOR, CUMSUM, CUMPROD, CUMMIN, CUMMAX, INVERSE, SPROP, SIGMOID, SELP
    };

    public BuiltinCode bFunc;

    private static final boolean FASTMATH = true;

    static public HashMap<String, BuiltinCode> String2BuiltinCode;
    static {
        String2BuiltinCode = new HashMap<String, BuiltinCode>();

        String2BuiltinCode.put("sin", BuiltinCode.SIN);
        String2BuiltinCode.put("cos", BuiltinCode.COS);
        String2BuiltinCode.put("tan", BuiltinCode.TAN);
        String2BuiltinCode.put("asin", BuiltinCode.ASIN);
        String2BuiltinCode.put("acos", BuiltinCode.ACOS);
        String2BuiltinCode.put("atan", BuiltinCode.ATAN);
        String2BuiltinCode.put("log", BuiltinCode.LOG);
        String2BuiltinCode.put("log_nz", BuiltinCode.LOG_NZ);
        String2BuiltinCode.put("min", BuiltinCode.MIN);
        String2BuiltinCode.put("max", BuiltinCode.MAX);
        String2BuiltinCode.put("maxindex", BuiltinCode.MAXINDEX);
        String2BuiltinCode.put("minindex", BuiltinCode.MININDEX);
        String2BuiltinCode.put("abs", BuiltinCode.ABS);
        String2BuiltinCode.put("sign", BuiltinCode.SIGN);
        String2BuiltinCode.put("sqrt", BuiltinCode.SQRT);
        String2BuiltinCode.put("exp", BuiltinCode.EXP);
        String2BuiltinCode.put("plogp", BuiltinCode.PLOGP);
        String2BuiltinCode.put("print", BuiltinCode.PRINT);
        String2BuiltinCode.put("printf", BuiltinCode.PRINTF);
        String2BuiltinCode.put("nrow", BuiltinCode.NROW);
        String2BuiltinCode.put("ncol", BuiltinCode.NCOL);
        String2BuiltinCode.put("length", BuiltinCode.LENGTH);
        String2BuiltinCode.put("round", BuiltinCode.ROUND);
        String2BuiltinCode.put("stop", BuiltinCode.STOP);
        String2BuiltinCode.put("ceil", BuiltinCode.CEIL);
        String2BuiltinCode.put("floor", BuiltinCode.FLOOR);
        String2BuiltinCode.put("ucumk+", BuiltinCode.CUMSUM);
        String2BuiltinCode.put("ucum*", BuiltinCode.CUMPROD);
        String2BuiltinCode.put("ucummin", BuiltinCode.CUMMIN);
        String2BuiltinCode.put("ucummax", BuiltinCode.CUMMAX);
        String2BuiltinCode.put("inverse", BuiltinCode.INVERSE);
        String2BuiltinCode.put("sprop", BuiltinCode.SPROP);
        String2BuiltinCode.put("sigmoid", BuiltinCode.SIGMOID);
        String2BuiltinCode.put("sel+", BuiltinCode.SELP);
    }

    // We should create one object for every builtin function that we support
    private static Builtin sinObj = null, cosObj = null, tanObj = null, asinObj = null, acosObj = null,
            atanObj = null;
    private static Builtin logObj = null, lognzObj = null, minObj = null, maxObj = null, maxindexObj = null,
            minindexObj = null;
    private static Builtin absObj = null, signObj = null, sqrtObj = null, expObj = null, plogpObj = null,
            printObj = null, printfObj;
    private static Builtin nrowObj = null, ncolObj = null, lengthObj = null, roundObj = null, ceilObj = null,
            floorObj = null;
    private static Builtin inverseObj = null, cumsumObj = null, cumprodObj = null, cumminObj = null,
            cummaxObj = null;
    private static Builtin stopObj = null, spropObj = null, sigmoidObj = null, selpObj = null;

    private Builtin(BuiltinCode bf) {
        bFunc = bf;
    }

    public BuiltinCode getBuiltinCode() {
        return bFunc;
    }

    public static Builtin getBuiltinFnObject(String str) {
        BuiltinCode code = String2BuiltinCode.get(str);
        return getBuiltinFnObject(code);
    }

    public static Builtin getBuiltinFnObject(BuiltinCode code) {
        if (code == null)
            return null;

        switch (code) {
        case SIN:
            if (sinObj == null)
                sinObj = new Builtin(BuiltinCode.SIN);
            return sinObj;

        case COS:
            if (cosObj == null)
                cosObj = new Builtin(BuiltinCode.COS);
            return cosObj;
        case TAN:
            if (tanObj == null)
                tanObj = new Builtin(BuiltinCode.TAN);
            return tanObj;
        case ASIN:
            if (asinObj == null)
                asinObj = new Builtin(BuiltinCode.ASIN);
            return asinObj;

        case ACOS:
            if (acosObj == null)
                acosObj = new Builtin(BuiltinCode.ACOS);
            return acosObj;
        case ATAN:
            if (atanObj == null)
                atanObj = new Builtin(BuiltinCode.ATAN);
            return atanObj;
        case LOG:
            if (logObj == null)
                logObj = new Builtin(BuiltinCode.LOG);
            return logObj;
        case LOG_NZ:
            if (lognzObj == null)
                lognzObj = new Builtin(BuiltinCode.LOG_NZ);
            return lognzObj;
        case MAX:
            if (maxObj == null)
                maxObj = new Builtin(BuiltinCode.MAX);
            return maxObj;
        case MAXINDEX:
            if (maxindexObj == null)
                maxindexObj = new Builtin(BuiltinCode.MAXINDEX);
            return maxindexObj;
        case MIN:
            if (minObj == null)
                minObj = new Builtin(BuiltinCode.MIN);
            return minObj;
        case MININDEX:
            if (minindexObj == null)
                minindexObj = new Builtin(BuiltinCode.MININDEX);
            return minindexObj;
        case ABS:
            if (absObj == null)
                absObj = new Builtin(BuiltinCode.ABS);
            return absObj;
        case SIGN:
            if (signObj == null)
                signObj = new Builtin(BuiltinCode.SIGN);
            return signObj;
        case SQRT:
            if (sqrtObj == null)
                sqrtObj = new Builtin(BuiltinCode.SQRT);
            return sqrtObj;
        case EXP:
            if (expObj == null)
                expObj = new Builtin(BuiltinCode.EXP);
            return expObj;
        case PLOGP:
            if (plogpObj == null)
                plogpObj = new Builtin(BuiltinCode.PLOGP);
            return plogpObj;
        case PRINT:
            if (printObj == null)
                printObj = new Builtin(BuiltinCode.PRINT);
            return printObj;
        case PRINTF:
            if (printfObj == null) {
                printfObj = new Builtin(BuiltinCode.PRINTF);
            }
            return printfObj;
        case NROW:
            if (nrowObj == null)
                nrowObj = new Builtin(BuiltinCode.NROW);
            return nrowObj;
        case NCOL:
            if (ncolObj == null)
                ncolObj = new Builtin(BuiltinCode.NCOL);
            return ncolObj;
        case LENGTH:
            if (lengthObj == null)
                lengthObj = new Builtin(BuiltinCode.LENGTH);
            return lengthObj;
        case ROUND:
            if (roundObj == null)
                roundObj = new Builtin(BuiltinCode.ROUND);
            return roundObj;
        case CEIL:
            if (ceilObj == null)
                ceilObj = new Builtin(BuiltinCode.CEIL);
            return ceilObj;
        case FLOOR:
            if (floorObj == null)
                floorObj = new Builtin(BuiltinCode.FLOOR);
            return floorObj;
        case CUMSUM:
            if (cumsumObj == null)
                cumsumObj = new Builtin(BuiltinCode.CUMSUM);
            return cumsumObj;
        case CUMPROD:
            if (cumprodObj == null)
                cumprodObj = new Builtin(BuiltinCode.CUMPROD);
            return cumprodObj;
        case CUMMIN:
            if (cumminObj == null)
                cumminObj = new Builtin(BuiltinCode.CUMMIN);
            return cumminObj;
        case CUMMAX:
            if (cummaxObj == null)
                cummaxObj = new Builtin(BuiltinCode.CUMMAX);
            return cummaxObj;
        case INVERSE:
            if (inverseObj == null)
                inverseObj = new Builtin(BuiltinCode.INVERSE);
            return inverseObj;
        case STOP:
            if (stopObj == null)
                stopObj = new Builtin(BuiltinCode.STOP);
            return stopObj;

        case SPROP:
            if (spropObj == null)
                spropObj = new Builtin(BuiltinCode.SPROP);
            return spropObj;

        case SIGMOID:
            if (sigmoidObj == null)
                sigmoidObj = new Builtin(BuiltinCode.SIGMOID);
            return sigmoidObj;

        case SELP:
            if (selpObj == null)
                selpObj = new Builtin(BuiltinCode.SELP);
            return selpObj;

        default:
            // Unknown code --> return null
            return null;
        }
    }

    public Object clone() throws CloneNotSupportedException {
        // cloning is not supported for singleton classes
        throw new CloneNotSupportedException();
    }

    public double execute(double in) throws DMLRuntimeException {
        switch (bFunc) {
        case SIN:
            return FASTMATH ? FastMath.sin(in) : Math.sin(in);
        case COS:
            return FASTMATH ? FastMath.cos(in) : Math.cos(in);
        case TAN:
            return FASTMATH ? FastMath.tan(in) : Math.tan(in);
        case ASIN:
            return FASTMATH ? FastMath.asin(in) : Math.asin(in);
        case ACOS:
            return FASTMATH ? FastMath.acos(in) : Math.acos(in);
        case ATAN:
            return Math.atan(in); //faster in Math
        case CEIL:
            return FASTMATH ? FastMath.ceil(in) : Math.ceil(in);
        case FLOOR:
            return FASTMATH ? FastMath.floor(in) : Math.floor(in);
        case LOG:
            return FASTMATH ? FastMath.log(in) : Math.log(in);
        case LOG_NZ:
            return (in == 0) ? 0 : FASTMATH ? FastMath.log(in) : Math.log(in);
        case ABS:
            return Math.abs(in); //no need for FastMath         
        case SIGN:
            return FASTMATH ? FastMath.signum(in) : Math.signum(in);
        case SQRT:
            return Math.sqrt(in); //faster in Math      
        case EXP:
            return FASTMATH ? FastMath.exp(in) : Math.exp(in);
        case ROUND:
            return Math.round(in); //no need for FastMath

        case PLOGP:
            if (in == 0.0)
                return 0.0;
            else if (in < 0)
                return Double.NaN;
            else
                return (in * (FASTMATH ? FastMath.log(in) : Math.log(in)));

        case SPROP:
            //sample proportion: P*(1-P)
            return in * (1 - in);

        case SIGMOID:
            //sigmoid: 1/(1+exp(-x))
            return FASTMATH ? 1 / (1 + FastMath.exp(-in)) : 1 / (1 + Math.exp(-in));

        case SELP:
            //select positive: x*(x>0)
            return (in > 0) ? in : 0;

        default:
            throw new DMLRuntimeException("Builtin.execute(): Unknown operation: " + bFunc);
        }
    }

    public double execute(long in) throws DMLRuntimeException {
        return execute((double) in);
    }

    /*
     * Builtin functions with two inputs
     */
    public double execute(double in1, double in2) throws DMLRuntimeException {
        switch (bFunc) {

        /*
         * Arithmetic relational operators (==, !=, <=, >=) must be instead of
         * <code>Double.compare()</code> due to the inconsistencies in the way
         * NaN and -0.0 are handled. The behavior of methods in
         * <code>Double</code> class are designed mainly to make Java
         * collections work properly. For more details, see the help for
         * <code>Double.equals()</code> and <code>Double.comapreTo()</code>.
         */
        case MAX:
        case CUMMAX:
            //return (Double.compare(in1, in2) >= 0 ? in1 : in2);
            return (in1 >= in2 ? in1 : in2);
        case MIN:
        case CUMMIN:
            //return (Double.compare(in1, in2) <= 0 ? in1 : in2);
            return (in1 <= in2 ? in1 : in2);

        // *** HACK ALERT *** HACK ALERT *** HACK ALERT ***
        // rowIndexMax() and its siblings require comparing four values, but
        // the aggregation API only allows two values. So the execute()
        // method receives as its argument the two cell values to be
        // compared and performs just the value part of the comparison. We
        // return an integer cast down to a double, since the aggregation
        // API doesn't have any way to return anything but a double. The
        // integer returned takes on three posssible values: //
        // .     0 => keep the index associated with in1 //
        // .     1 => use the index associated with in2 //
        // .     2 => use whichever index is higher (tie in value) //
        case MAXINDEX:
            if (in1 == in2) {
                return 2;
            } else if (in1 > in2) {
                return 1;
            } else { // in1 < in2
                return 0;
            }
        case MININDEX:
            if (in1 == in2) {
                return 2;
            } else if (in1 < in2) {
                return 1;
            } else { // in1 > in2
                return 0;
            }
            // *** END HACK ***
        case LOG:
            //if ( in1 <= 0 )
            //   throw new DMLRuntimeException("Builtin.execute(): logarithm can be computed only for non-negative numbers.");
            if (FASTMATH)
                return (FastMath.log(in1) / FastMath.log(in2));
            else
                return (Math.log(in1) / Math.log(in2));
        case LOG_NZ:
            if (FASTMATH)
                return (in1 == 0) ? 0 : (FastMath.log(in1) / FastMath.log(in2));
            else
                return (in1 == 0) ? 0 : (Math.log(in1) / Math.log(in2));

        default:
            throw new DMLRuntimeException("Builtin.execute(): Unknown operation: " + bFunc);
        }
    }

    /**
     * Simplified version without exception handling
     * 
     * @param in1 double 1
     * @param in2 double 2
     * @return result
     */
    public double execute2(double in1, double in2) {
        switch (bFunc) {
        case MAX:
        case CUMMAX:
            //return (Double.compare(in1, in2) >= 0 ? in1 : in2); 
            return (in1 >= in2 ? in1 : in2);
        case MIN:
        case CUMMIN:
            //return (Double.compare(in1, in2) <= 0 ? in1 : in2); 
            return (in1 <= in2 ? in1 : in2);
        case MAXINDEX:
            return (in1 >= in2) ? 1 : 0;
        case MININDEX:
            return (in1 <= in2) ? 1 : 0;

        default:
            // For performance reasons, avoid throwing an exception 
            return -1;
        }
    }

    public double execute(long in1, long in2) throws DMLRuntimeException {
        switch (bFunc) {

        case MAX:
        case CUMMAX:
            return (in1 >= in2 ? in1 : in2);

        case MIN:
        case CUMMIN:
            return (in1 <= in2 ? in1 : in2);

        case MAXINDEX:
            return (in1 >= in2) ? 1 : 0;
        case MININDEX:
            return (in1 <= in2) ? 1 : 0;

        case LOG:
            //if ( in1 <= 0 )
            //   throw new DMLRuntimeException("Builtin.execute(): logarithm can be computed only for non-negative numbers.");
            if (FASTMATH)
                return (FastMath.log(in1) / FastMath.log(in2));
            else
                return (Math.log(in1) / Math.log(in2));
        case LOG_NZ:
            if (FASTMATH)
                return (in1 == 0) ? 0 : (FastMath.log(in1) / FastMath.log(in2));
            else
                return (in1 == 0) ? 0 : (Math.log(in1) / Math.log(in2));

        default:
            throw new DMLRuntimeException("Builtin.execute(): Unknown operation: " + bFunc);
        }
    }

    // currently, it is used only for PRINT, PRINTF and STOP
    public String execute(String in1) throws DMLRuntimeException {
        switch (bFunc) {
        case PRINT:
            if (!DMLScript.suppressPrint2Stdout())
                System.out.println(in1);
            return null;
        case PRINTF:
            if (!DMLScript.suppressPrint2Stdout())
                System.out.println(in1);
            return null;
        case STOP:
            throw new DMLScriptException(in1);
        default:
            throw new DMLRuntimeException("Builtin.execute(): Unknown operation: " + bFunc);
        }
    }
}