// Java tutorial
/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.apache.sysml.runtime.functionobjects; import java.util.HashMap; import org.apache.commons.math3.util.FastMath; import org.apache.sysml.api.DMLScript; import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.DMLScriptException; /** * Class with pre-defined set of objects. This class can not be instantiated elsewhere. * * Notes on commons.math FastMath: * * FastMath uses lookup tables and interpolation instead of native calls. * * The memory overhead for those tables is roughly 48KB in total (acceptable) * * Micro and application benchmarks showed significantly (30%-3x) performance improvements * for most operations; without loss of accuracy. * * atan / sqrt were 20% slower in FastMath and hence, we use Math there * * round / abs were equivalent in FastMath and hence, we use Math there * * Finally, there is just one argument against FastMath - The comparison heavily depends * on the JVM. For example, currently the IBM JDK JIT compiles to HW instructions for sqrt * which makes this operation very efficient; as soon as other operations like log/exp are * similarly compiled, we should rerun the micro benchmarks, and switch back if necessary. 
* */ public class Builtin extends ValueFunction { private static final long serialVersionUID = 3836744687789840574L; public enum BuiltinCode { SIN, COS, TAN, ASIN, ACOS, ATAN, LOG, LOG_NZ, MIN, MAX, ABS, SIGN, SQRT, EXP, PLOGP, PRINT, PRINTF, NROW, NCOL, LENGTH, ROUND, MAXINDEX, MININDEX, STOP, CEIL, FLOOR, CUMSUM, CUMPROD, CUMMIN, CUMMAX, INVERSE, SPROP, SIGMOID, SELP }; public BuiltinCode bFunc; private static final boolean FASTMATH = true; static public HashMap<String, BuiltinCode> String2BuiltinCode; static { String2BuiltinCode = new HashMap<String, BuiltinCode>(); String2BuiltinCode.put("sin", BuiltinCode.SIN); String2BuiltinCode.put("cos", BuiltinCode.COS); String2BuiltinCode.put("tan", BuiltinCode.TAN); String2BuiltinCode.put("asin", BuiltinCode.ASIN); String2BuiltinCode.put("acos", BuiltinCode.ACOS); String2BuiltinCode.put("atan", BuiltinCode.ATAN); String2BuiltinCode.put("log", BuiltinCode.LOG); String2BuiltinCode.put("log_nz", BuiltinCode.LOG_NZ); String2BuiltinCode.put("min", BuiltinCode.MIN); String2BuiltinCode.put("max", BuiltinCode.MAX); String2BuiltinCode.put("maxindex", BuiltinCode.MAXINDEX); String2BuiltinCode.put("minindex", BuiltinCode.MININDEX); String2BuiltinCode.put("abs", BuiltinCode.ABS); String2BuiltinCode.put("sign", BuiltinCode.SIGN); String2BuiltinCode.put("sqrt", BuiltinCode.SQRT); String2BuiltinCode.put("exp", BuiltinCode.EXP); String2BuiltinCode.put("plogp", BuiltinCode.PLOGP); String2BuiltinCode.put("print", BuiltinCode.PRINT); String2BuiltinCode.put("printf", BuiltinCode.PRINTF); String2BuiltinCode.put("nrow", BuiltinCode.NROW); String2BuiltinCode.put("ncol", BuiltinCode.NCOL); String2BuiltinCode.put("length", BuiltinCode.LENGTH); String2BuiltinCode.put("round", BuiltinCode.ROUND); String2BuiltinCode.put("stop", BuiltinCode.STOP); String2BuiltinCode.put("ceil", BuiltinCode.CEIL); String2BuiltinCode.put("floor", BuiltinCode.FLOOR); String2BuiltinCode.put("ucumk+", BuiltinCode.CUMSUM); String2BuiltinCode.put("ucum*", 
BuiltinCode.CUMPROD); String2BuiltinCode.put("ucummin", BuiltinCode.CUMMIN); String2BuiltinCode.put("ucummax", BuiltinCode.CUMMAX); String2BuiltinCode.put("inverse", BuiltinCode.INVERSE); String2BuiltinCode.put("sprop", BuiltinCode.SPROP); String2BuiltinCode.put("sigmoid", BuiltinCode.SIGMOID); String2BuiltinCode.put("sel+", BuiltinCode.SELP); } // We should create one object for every builtin function that we support private static Builtin sinObj = null, cosObj = null, tanObj = null, asinObj = null, acosObj = null, atanObj = null; private static Builtin logObj = null, lognzObj = null, minObj = null, maxObj = null, maxindexObj = null, minindexObj = null; private static Builtin absObj = null, signObj = null, sqrtObj = null, expObj = null, plogpObj = null, printObj = null, printfObj; private static Builtin nrowObj = null, ncolObj = null, lengthObj = null, roundObj = null, ceilObj = null, floorObj = null; private static Builtin inverseObj = null, cumsumObj = null, cumprodObj = null, cumminObj = null, cummaxObj = null; private static Builtin stopObj = null, spropObj = null, sigmoidObj = null, selpObj = null; private Builtin(BuiltinCode bf) { bFunc = bf; } public BuiltinCode getBuiltinCode() { return bFunc; } public static Builtin getBuiltinFnObject(String str) { BuiltinCode code = String2BuiltinCode.get(str); return getBuiltinFnObject(code); } public static Builtin getBuiltinFnObject(BuiltinCode code) { if (code == null) return null; switch (code) { case SIN: if (sinObj == null) sinObj = new Builtin(BuiltinCode.SIN); return sinObj; case COS: if (cosObj == null) cosObj = new Builtin(BuiltinCode.COS); return cosObj; case TAN: if (tanObj == null) tanObj = new Builtin(BuiltinCode.TAN); return tanObj; case ASIN: if (asinObj == null) asinObj = new Builtin(BuiltinCode.ASIN); return asinObj; case ACOS: if (acosObj == null) acosObj = new Builtin(BuiltinCode.ACOS); return acosObj; case ATAN: if (atanObj == null) atanObj = new Builtin(BuiltinCode.ATAN); return atanObj; case LOG: 
if (logObj == null) logObj = new Builtin(BuiltinCode.LOG); return logObj; case LOG_NZ: if (lognzObj == null) lognzObj = new Builtin(BuiltinCode.LOG_NZ); return lognzObj; case MAX: if (maxObj == null) maxObj = new Builtin(BuiltinCode.MAX); return maxObj; case MAXINDEX: if (maxindexObj == null) maxindexObj = new Builtin(BuiltinCode.MAXINDEX); return maxindexObj; case MIN: if (minObj == null) minObj = new Builtin(BuiltinCode.MIN); return minObj; case MININDEX: if (minindexObj == null) minindexObj = new Builtin(BuiltinCode.MININDEX); return minindexObj; case ABS: if (absObj == null) absObj = new Builtin(BuiltinCode.ABS); return absObj; case SIGN: if (signObj == null) signObj = new Builtin(BuiltinCode.SIGN); return signObj; case SQRT: if (sqrtObj == null) sqrtObj = new Builtin(BuiltinCode.SQRT); return sqrtObj; case EXP: if (expObj == null) expObj = new Builtin(BuiltinCode.EXP); return expObj; case PLOGP: if (plogpObj == null) plogpObj = new Builtin(BuiltinCode.PLOGP); return plogpObj; case PRINT: if (printObj == null) printObj = new Builtin(BuiltinCode.PRINT); return printObj; case PRINTF: if (printfObj == null) { printfObj = new Builtin(BuiltinCode.PRINTF); } return printfObj; case NROW: if (nrowObj == null) nrowObj = new Builtin(BuiltinCode.NROW); return nrowObj; case NCOL: if (ncolObj == null) ncolObj = new Builtin(BuiltinCode.NCOL); return ncolObj; case LENGTH: if (lengthObj == null) lengthObj = new Builtin(BuiltinCode.LENGTH); return lengthObj; case ROUND: if (roundObj == null) roundObj = new Builtin(BuiltinCode.ROUND); return roundObj; case CEIL: if (ceilObj == null) ceilObj = new Builtin(BuiltinCode.CEIL); return ceilObj; case FLOOR: if (floorObj == null) floorObj = new Builtin(BuiltinCode.FLOOR); return floorObj; case CUMSUM: if (cumsumObj == null) cumsumObj = new Builtin(BuiltinCode.CUMSUM); return cumsumObj; case CUMPROD: if (cumprodObj == null) cumprodObj = new Builtin(BuiltinCode.CUMPROD); return cumprodObj; case CUMMIN: if (cumminObj == null) cumminObj = 
new Builtin(BuiltinCode.CUMMIN); return cumminObj; case CUMMAX: if (cummaxObj == null) cummaxObj = new Builtin(BuiltinCode.CUMMAX); return cummaxObj; case INVERSE: if (inverseObj == null) inverseObj = new Builtin(BuiltinCode.INVERSE); return inverseObj; case STOP: if (stopObj == null) stopObj = new Builtin(BuiltinCode.STOP); return stopObj; case SPROP: if (spropObj == null) spropObj = new Builtin(BuiltinCode.SPROP); return spropObj; case SIGMOID: if (sigmoidObj == null) sigmoidObj = new Builtin(BuiltinCode.SIGMOID); return sigmoidObj; case SELP: if (selpObj == null) selpObj = new Builtin(BuiltinCode.SELP); return selpObj; default: // Unknown code --> return null return null; } } public Object clone() throws CloneNotSupportedException { // cloning is not supported for singleton classes throw new CloneNotSupportedException(); } public double execute(double in) throws DMLRuntimeException { switch (bFunc) { case SIN: return FASTMATH ? FastMath.sin(in) : Math.sin(in); case COS: return FASTMATH ? FastMath.cos(in) : Math.cos(in); case TAN: return FASTMATH ? FastMath.tan(in) : Math.tan(in); case ASIN: return FASTMATH ? FastMath.asin(in) : Math.asin(in); case ACOS: return FASTMATH ? FastMath.acos(in) : Math.acos(in); case ATAN: return Math.atan(in); //faster in Math case CEIL: return FASTMATH ? FastMath.ceil(in) : Math.ceil(in); case FLOOR: return FASTMATH ? FastMath.floor(in) : Math.floor(in); case LOG: return FASTMATH ? FastMath.log(in) : Math.log(in); case LOG_NZ: return (in == 0) ? 0 : FASTMATH ? FastMath.log(in) : Math.log(in); case ABS: return Math.abs(in); //no need for FastMath case SIGN: return FASTMATH ? FastMath.signum(in) : Math.signum(in); case SQRT: return Math.sqrt(in); //faster in Math case EXP: return FASTMATH ? FastMath.exp(in) : Math.exp(in); case ROUND: return Math.round(in); //no need for FastMath case PLOGP: if (in == 0.0) return 0.0; else if (in < 0) return Double.NaN; else return (in * (FASTMATH ? 
FastMath.log(in) : Math.log(in))); case SPROP: //sample proportion: P*(1-P) return in * (1 - in); case SIGMOID: //sigmoid: 1/(1+exp(-x)) return FASTMATH ? 1 / (1 + FastMath.exp(-in)) : 1 / (1 + Math.exp(-in)); case SELP: //select positive: x*(x>0) return (in > 0) ? in : 0; default: throw new DMLRuntimeException("Builtin.execute(): Unknown operation: " + bFunc); } } public double execute(long in) throws DMLRuntimeException { return execute((double) in); } /* * Builtin functions with two inputs */ public double execute(double in1, double in2) throws DMLRuntimeException { switch (bFunc) { /* * Arithmetic relational operators (==, !=, <=, >=) must be instead of * <code>Double.compare()</code> due to the inconsistencies in the way * NaN and -0.0 are handled. The behavior of methods in * <code>Double</code> class are designed mainly to make Java * collections work properly. For more details, see the help for * <code>Double.equals()</code> and <code>Double.comapreTo()</code>. */ case MAX: case CUMMAX: //return (Double.compare(in1, in2) >= 0 ? in1 : in2); return (in1 >= in2 ? in1 : in2); case MIN: case CUMMIN: //return (Double.compare(in1, in2) <= 0 ? in1 : in2); return (in1 <= in2 ? in1 : in2); // *** HACK ALERT *** HACK ALERT *** HACK ALERT *** // rowIndexMax() and its siblings require comparing four values, but // the aggregation API only allows two values. So the execute() // method receives as its argument the two cell values to be // compared and performs just the value part of the comparison. We // return an integer cast down to a double, since the aggregation // API doesn't have any way to return anything but a double. The // integer returned takes on three posssible values: // // . 0 => keep the index associated with in1 // // . 1 => use the index associated with in2 // // . 
2 => use whichever index is higher (tie in value) // case MAXINDEX: if (in1 == in2) { return 2; } else if (in1 > in2) { return 1; } else { // in1 < in2 return 0; } case MININDEX: if (in1 == in2) { return 2; } else if (in1 < in2) { return 1; } else { // in1 > in2 return 0; } // *** END HACK *** case LOG: //if ( in1 <= 0 ) // throw new DMLRuntimeException("Builtin.execute(): logarithm can be computed only for non-negative numbers."); if (FASTMATH) return (FastMath.log(in1) / FastMath.log(in2)); else return (Math.log(in1) / Math.log(in2)); case LOG_NZ: if (FASTMATH) return (in1 == 0) ? 0 : (FastMath.log(in1) / FastMath.log(in2)); else return (in1 == 0) ? 0 : (Math.log(in1) / Math.log(in2)); default: throw new DMLRuntimeException("Builtin.execute(): Unknown operation: " + bFunc); } } /** * Simplified version without exception handling * * @param in1 double 1 * @param in2 double 2 * @return result */ public double execute2(double in1, double in2) { switch (bFunc) { case MAX: case CUMMAX: //return (Double.compare(in1, in2) >= 0 ? in1 : in2); return (in1 >= in2 ? in1 : in2); case MIN: case CUMMIN: //return (Double.compare(in1, in2) <= 0 ? in1 : in2); return (in1 <= in2 ? in1 : in2); case MAXINDEX: return (in1 >= in2) ? 1 : 0; case MININDEX: return (in1 <= in2) ? 1 : 0; default: // For performance reasons, avoid throwing an exception return -1; } } public double execute(long in1, long in2) throws DMLRuntimeException { switch (bFunc) { case MAX: case CUMMAX: return (in1 >= in2 ? in1 : in2); case MIN: case CUMMIN: return (in1 <= in2 ? in1 : in2); case MAXINDEX: return (in1 >= in2) ? 1 : 0; case MININDEX: return (in1 <= in2) ? 1 : 0; case LOG: //if ( in1 <= 0 ) // throw new DMLRuntimeException("Builtin.execute(): logarithm can be computed only for non-negative numbers."); if (FASTMATH) return (FastMath.log(in1) / FastMath.log(in2)); else return (Math.log(in1) / Math.log(in2)); case LOG_NZ: if (FASTMATH) return (in1 == 0) ? 
0 : (FastMath.log(in1) / FastMath.log(in2)); else return (in1 == 0) ? 0 : (Math.log(in1) / Math.log(in2)); default: throw new DMLRuntimeException("Builtin.execute(): Unknown operation: " + bFunc); } } // currently, it is used only for PRINT, PRINTF and STOP public String execute(String in1) throws DMLRuntimeException { switch (bFunc) { case PRINT: if (!DMLScript.suppressPrint2Stdout()) System.out.println(in1); return null; case PRINTF: if (!DMLScript.suppressPrint2Stdout()) System.out.println(in1); return null; case STOP: throw new DMLScriptException(in1); default: throw new DMLRuntimeException("Builtin.execute(): Unknown operation: " + bFunc); } } }