org.apache.tajo.engine.function.FunctionLoader.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.tajo.engine.function.FunctionLoader.java

Source

/***
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.tajo.engine.function;

import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import org.apache.commons.collections.Predicate;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.tajo.annotation.Nullable;
import org.apache.tajo.catalog.CatalogUtil;
import org.apache.tajo.catalog.FunctionDesc;
import org.apache.tajo.common.TajoDataTypes;
import org.apache.tajo.conf.TajoConf;
import org.apache.tajo.engine.function.annotation.Description;
import org.apache.tajo.engine.function.annotation.ParamOptionTypes;
import org.apache.tajo.engine.function.annotation.ParamTypes;
import org.apache.tajo.function.*;
import org.apache.tajo.plan.function.python.PythonScriptEngine;
import org.apache.tajo.util.ClassUtil;

import java.io.IOException;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.*;

import static org.apache.tajo.catalog.proto.CatalogProtos.FunctionType.GENERAL;

public class FunctionLoader {

    private static Log LOG = LogFactory.getLog(FunctionLoader.class);
    public static final String PYTHON_FUNCTION_NAMESPACE = "python";

    /**
     * Load built-in functions
     *
     * @return
     */
    public static Map<FunctionSignature, FunctionDesc> loadBuiltinFunctions() {
        Map<FunctionSignature, FunctionDesc> map = Maps.newHashMap();

        List<FunctionDesc> dd = Lists.newArrayList();
        for (FunctionDesc f : findLegacyFunctions()) {
            map.put(f.getSignature(), f);

            if (f.getSignature().getName().equals("pow") || f.getSignature().getName().equals("pi")) {
                dd.add(f);
            }
        }

        for (FunctionDesc f : findScalarFunctions()) {
            if (map.containsKey(f.getSignature())) {
                FunctionDesc existing = map.get(f.getSignature());
                existing.getInvocation().setScalar(f.getInvocation().getScalar());
            } else {
                map.put(f.getSignature(), f);
            }
        }

        return map;
    }

    /**
     * Load functions defined by users.
     *
     * @param conf
     * @return
     * @throws IOException
     */
    public static Optional<List<FunctionDesc>> loadUserDefinedFunctions(TajoConf conf) throws IOException {
        List<FunctionDesc> functionList = new LinkedList<>();

        String[] codePaths = conf.getStrings(TajoConf.ConfVars.PYTHON_CODE_DIR.varname);
        if (codePaths != null) {
            FileSystem localFS = FileSystem.getLocal(conf);
            for (String codePathStr : codePaths) {

                Path codePath;
                try {
                    codePath = new Path(codePathStr);
                } catch (IllegalArgumentException e) {
                    LOG.warn("Illegal function path", e);
                    continue;
                }

                List<Path> filePaths = new ArrayList<>();
                if (localFS.isDirectory(codePath)) {
                    for (FileStatus file : localFS.listStatus(codePath,
                            (Path path) -> path.getName().endsWith(PythonScriptEngine.FILE_EXTENSION))) {
                        filePaths.add(file.getPath());
                    }
                } else {
                    filePaths.add(codePath);
                }
                for (Path filePath : filePaths) {
                    PythonScriptEngine.registerFunctions(filePath.toUri(), FunctionLoader.PYTHON_FUNCTION_NAMESPACE)
                            .forEach(functionList::add);
                }
            }
        }

        return Optional.of(functionList);
    }

    public static Set<FunctionDesc> findScalarFunctions() {
        Set<FunctionDesc> functions = Sets.newHashSet();

        Set<Method> scalarFunctions = findPublicStaticMethods("org.apache.tajo.engine.function",
                (Object object) -> ((Method) object).getAnnotation(ScalarFunction.class) != null);

        for (Method method : scalarFunctions) {
            ScalarFunction annotation = method.getAnnotation(ScalarFunction.class);
            functions.addAll(buildFunctionDescs(annotation, method));
        }

        return functions;
    }

    private static Set<Method> findPublicStaticMethods(String packageName, Predicate predicate) {
        Set<Class> found = findFunctionCollections(packageName);
        Set<Method> filtered = Sets.newHashSet();

        for (Class clazz : found) {
            for (Method method : clazz.getMethods()) {
                if (isPublicStaticMethod(method) && (predicate == null || predicate.evaluate(method))) {
                    filtered.add(method);
                }
            }
        }

        return filtered;
    }

    private static boolean isPublicStaticMethod(Method method) {
        return Modifier.isPublic(method.getModifiers()) && Modifier.isStatic(method.getModifiers());
    }

    private static Set<Class> findFunctionCollections(String packageName) {
        return ClassUtil.findClasses(null, packageName,
                (Object object) -> ((Class) object).getAnnotation(FunctionCollection.class) != null);
    }

    private static Collection<FunctionDesc> buildFunctionDescs(ScalarFunction annotation, Method method) {
        List<FunctionDesc> functionDescs = Lists.newArrayList();

        FunctionInvocation invocation = new FunctionInvocation();
        invocation.setScalar(extractStaticMethodInvocation(method));
        FunctionSupplement supplement = extractSupplement(annotation);

        // primary name
        functionDescs.add(new FunctionDesc(extractSignature(annotation, null), invocation, supplement));

        // for multiple aliases
        for (String alias : annotation.synonyms()) {
            functionDescs.add(new FunctionDesc(extractSignature(annotation, alias), invocation, supplement));
        }

        return functionDescs;
    }

    private static FunctionSignature extractSignature(ScalarFunction annotation, @Nullable String alias) {
        return new FunctionSignature(GENERAL, alias != null ? alias : annotation.name(),
                CatalogUtil.newSimpleDataType(annotation.returnType()),
                CatalogUtil.newSimpleDataTypeArray(annotation.paramTypes()));
    }

    private static FunctionSupplement extractSupplement(ScalarFunction function) {
        return new FunctionSupplement(function.shortDescription(), function.detail(), function.example());
    }

    private static StaticMethodInvocationDesc extractStaticMethodInvocation(Method method) {
        Preconditions.checkArgument(Modifier.isPublic(method.getModifiers()));
        Preconditions.checkArgument(Modifier.isStatic(method.getModifiers()));

        String methodName = method.getName();
        Class returnClass = method.getReturnType();
        Class[] paramClasses = method.getParameterTypes();
        return new StaticMethodInvocationDesc(method.getDeclaringClass(), methodName, returnClass, paramClasses);
    }

    /**
     * This method finds and build FunctionDesc for the legacy function and UD(A)F system.
     *
     * @return A list of FunctionDescs
     */
    public static List<FunctionDesc> findLegacyFunctions() {
        List<FunctionDesc> sqlFuncs = new ArrayList<>();

        Set<Class> functionClasses = ClassUtil.findClasses(Function.class, "org.apache.tajo.engine.function");

        for (Class eachClass : functionClasses) {
            if (eachClass.isInterface() || Modifier.isAbstract(eachClass.getModifiers())) {
                continue;
            }
            Function function = null;
            try {
                function = (Function) eachClass.newInstance();
            } catch (Exception e) {
                LOG.warn(eachClass + " cannot instantiate Function class because of " + e.getMessage(), e);
                continue;
            }
            String functionName = function.getClass().getAnnotation(Description.class).functionName();
            String[] synonyms = function.getClass().getAnnotation(Description.class).synonyms();
            String description = function.getClass().getAnnotation(Description.class).description();
            String detail = function.getClass().getAnnotation(Description.class).detail();
            String example = function.getClass().getAnnotation(Description.class).example();
            TajoDataTypes.Type returnType = function.getClass().getAnnotation(Description.class).returnType();
            ParamTypes[] paramArray = function.getClass().getAnnotation(Description.class).paramTypes();

            String[] allFunctionNames = null;
            if (synonyms != null && synonyms.length > 0) {
                allFunctionNames = new String[1 + synonyms.length];
                allFunctionNames[0] = functionName;
                System.arraycopy(synonyms, 0, allFunctionNames, 1, synonyms.length);
            } else {
                allFunctionNames = new String[] { functionName };
            }

            for (String eachFunctionName : allFunctionNames) {
                for (ParamTypes params : paramArray) {
                    ParamOptionTypes[] paramOptionArray;
                    if (params.paramOptionTypes() == null
                            || params.paramOptionTypes().getClass().getAnnotation(ParamTypes.class) == null) {
                        paramOptionArray = new ParamOptionTypes[0];
                    } else {
                        paramOptionArray = params.paramOptionTypes().getClass().getAnnotation(ParamTypes.class)
                                .paramOptionTypes();
                    }

                    TajoDataTypes.Type[] paramTypes = params.paramTypes();
                    if (paramOptionArray.length > 0)
                        paramTypes = params.paramTypes().clone();

                    for (int i = 0; i < paramOptionArray.length + 1; i++) {
                        FunctionDesc functionDesc = new FunctionDesc(eachFunctionName, function.getClass(),
                                function.getFunctionType(), CatalogUtil.newSimpleDataType(returnType),
                                paramTypes.length == 0 ? CatalogUtil.newSimpleDataTypeArray()
                                        : CatalogUtil.newSimpleDataTypeArray(paramTypes));

                        functionDesc.setDescription(description);
                        functionDesc.setExample(example);
                        functionDesc.setDetail(detail);
                        sqlFuncs.add(functionDesc);

                        if (i != paramOptionArray.length) {
                            paramTypes = new TajoDataTypes.Type[paramTypes.length
                                    + paramOptionArray[i].paramOptionTypes().length];
                            System.arraycopy(params.paramTypes(), 0, paramTypes, 0, paramTypes.length);
                            System.arraycopy(paramOptionArray[i].paramOptionTypes(), 0, paramTypes,
                                    paramTypes.length, paramOptionArray[i].paramOptionTypes().length);
                        }
                    }
                }
            }
        }

        return sqlFuncs;
    }
}