com.facebook.presto.sql.gen.LambdaBytecodeGenerator.java Source code

Java tutorial

Introduction

Here is the source code for com.facebook.presto.sql.gen.LambdaBytecodeGenerator.java, the Presto class that compiles SQL lambda expressions into JVM bytecode.

Source

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.sql.gen;

import com.facebook.presto.bytecode.BytecodeBlock;
import com.facebook.presto.bytecode.BytecodeNode;
import com.facebook.presto.bytecode.ClassDefinition;
import com.facebook.presto.bytecode.FieldDefinition;
import com.facebook.presto.bytecode.MethodDefinition;
import com.facebook.presto.bytecode.Parameter;
import com.facebook.presto.bytecode.ParameterizedType;
import com.facebook.presto.bytecode.Scope;
import com.facebook.presto.bytecode.Variable;
import com.facebook.presto.bytecode.expression.BytecodeExpression;
import com.facebook.presto.bytecode.expression.BytecodeExpressions;
import com.facebook.presto.metadata.FunctionRegistry;
import com.facebook.presto.spi.ConnectorSession;
import com.facebook.presto.sql.relational.CallExpression;
import com.facebook.presto.sql.relational.ConstantExpression;
import com.facebook.presto.sql.relational.InputReferenceExpression;
import com.facebook.presto.sql.relational.LambdaDefinitionExpression;
import com.facebook.presto.sql.relational.RowExpression;
import com.facebook.presto.sql.relational.RowExpressionVisitor;
import com.facebook.presto.sql.relational.VariableReferenceExpression;
import com.facebook.presto.util.Reflection;
import com.google.common.base.VerifyException;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.primitives.Primitives;
import org.objectweb.asm.Handle;
import org.objectweb.asm.Opcodes;
import org.objectweb.asm.Type;

import java.lang.invoke.MethodHandle;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

import static com.facebook.presto.bytecode.Access.FINAL;
import static com.facebook.presto.bytecode.Access.PRIVATE;
import static com.facebook.presto.bytecode.Access.PUBLIC;
import static com.facebook.presto.bytecode.Access.STATIC;
import static com.facebook.presto.bytecode.Access.a;
import static com.facebook.presto.bytecode.Parameter.arg;
import static com.facebook.presto.bytecode.ParameterizedType.type;
import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantClass;
import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantFalse;
import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantString;
import static com.facebook.presto.bytecode.expression.BytecodeExpressions.getStatic;
import static com.facebook.presto.bytecode.expression.BytecodeExpressions.invokeDynamic;
import static com.facebook.presto.bytecode.expression.BytecodeExpressions.invokeStatic;
import static com.facebook.presto.bytecode.expression.BytecodeExpressions.newArray;
import static com.facebook.presto.bytecode.expression.BytecodeExpressions.setStatic;
import static com.facebook.presto.spi.StandardErrorCode.COMPILER_ERROR;
import static com.facebook.presto.sql.gen.BytecodeUtils.boxPrimitiveIfNecessary;
import static com.facebook.presto.sql.gen.BytecodeUtils.unboxPrimitiveIfNecessary;
import static com.facebook.presto.sql.gen.LambdaCapture.LAMBDA_CAPTURE_METHOD;
import static com.facebook.presto.util.Failures.checkCondition;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.util.Objects.requireNonNull;
import static org.objectweb.asm.Type.getMethodType;
import static org.objectweb.asm.Type.getType;

/**
 * Static helpers that compile a {@link LambdaDefinitionExpression} into a method of a generated class and
 * emit the bytecode that turns such a method into an instance of a functional interface.
 */
public final class LambdaBytecodeGenerator {
    private LambdaBytecodeGenerator() {
    }

    /**
     * Declares, on {@code classDefinition}, a method implementing the body of {@code lambdaExpression}
     * together with the {@link MethodHandle} fields that refer to it.
     *
     * <p>The generated method takes a {@link ConnectorSession} named {@code session} followed by one
     * parameter per lambda argument. Each lambda argument is declared with the boxed form of its Java type
     * and is named {@code "lambda_" + argumentName}.
     *
     * @param lambdaExpression the lambda to compile
     * @param fieldName name used both for the generated method and for the static field holding its handle
     * @param classDefinition the class that receives the generated method and fields
     * @param preGeneratedExpressions passed through to the {@link RowExpressionCompiler} for the lambda body
     * @param callSiteBinder passed through to the {@link RowExpressionCompiler} for the lambda body
     * @param cachedInstanceBinder passed through to the {@link RowExpressionCompiler} for the lambda body
     * @param functionRegistry passed through to the {@link RowExpressionCompiler} for the lambda body
     * @return a {@link CompiledLambda} describing the generated method and its handle fields
     */
    public static CompiledLambda preGenerateLambdaExpression(LambdaDefinitionExpression lambdaExpression,
            String fieldName, ClassDefinition classDefinition, PreGeneratedExpressions preGeneratedExpressions,
            CallSiteBinder callSiteBinder, CachedInstanceBinder cachedInstanceBinder,
            FunctionRegistry functionRegistry) {
        ImmutableList.Builder<Parameter> parameters = ImmutableList.builder();
        ImmutableMap.Builder<String, ParameterAndType> parameterMapBuilder = ImmutableMap.builder();

        // The session is always the first parameter of the generated lambda method.
        parameters.add(arg("session", ConnectorSession.class));
        for (int i = 0; i < lambdaExpression.getArguments().size(); i++) {
            // Lambda arguments are always passed boxed.
            Class<?> boxedType = Primitives.wrap(lambdaExpression.getArgumentTypes().get(i).getJavaType());
            String argumentName = lambdaExpression.getArguments().get(i);
            Parameter parameter = arg("lambda_" + argumentName, boxedType);
            parameters.add(parameter);
            parameterMapBuilder.put(argumentName, new ParameterAndType(parameter, boxedType));
        }

        // Inside the lambda body, variable references resolve to the parameters declared above.
        RowExpressionCompiler innerExpressionCompiler = new RowExpressionCompiler(callSiteBinder,
                cachedInstanceBinder, variableReferenceCompiler(parameterMapBuilder.build()), functionRegistry,
                preGeneratedExpressions);

        return defineLambdaMethodAndField(innerExpressionCompiler, classDefinition, fieldName, parameters.build(),
                lambdaExpression);
    }

    /**
     * Declares the lambda method and its two {@link MethodHandle} fields on {@code classDefinition}.
     *
     * <p>The method is public, named {@code fieldAndMethodName}, and returns the boxed form of the lambda
     * body's Java type. A private static final field of the same name is assigned in the class initializer
     * from {@code Reflection.methodHandle(...)}; a private final instance field named
     * {@code "binded_" + fieldAndMethodName} is declared but not assigned here (see
     * {@link CompiledLambda#generateInitialization}).
     *
     * @param innerExpressionCompiler compiler used for the lambda body
     * @param classDefinition the class that receives the method and fields
     * @param fieldAndMethodName name of the generated method and of the static handle field
     * @param inputParameters parameters of the generated method, in declaration order
     * @param lambda the lambda whose body is compiled
     * @return a {@link CompiledLambda} describing the generated method and fields
     */
    private static CompiledLambda defineLambdaMethodAndField(RowExpressionCompiler innerExpressionCompiler,
            ClassDefinition classDefinition, String fieldAndMethodName, List<Parameter> inputParameters,
            LambdaDefinitionExpression lambda) {
        Class<?> returnType = Primitives.wrap(lambda.getBody().getType().getJavaType());
        MethodDefinition method = classDefinition.declareMethod(a(PUBLIC), fieldAndMethodName, type(returnType),
                inputParameters);

        Scope scope = method.getScope();
        Variable wasNull = scope.declareVariable(boolean.class, "wasNull");
        BytecodeNode compiledBody = innerExpressionCompiler.compile(lambda.getBody(), scope);
        // Body: wasNull = false; <compiled lambda body>; box the result if needed; return it.
        method.getBody().putVariable(wasNull, false).append(compiledBody)
                .append(boxPrimitiveIfNecessary(scope, returnType)).ret(returnType);

        FieldDefinition staticField = classDefinition.declareField(a(PRIVATE, STATIC, FINAL), fieldAndMethodName,
                type(MethodHandle.class));
        FieldDefinition instanceField = classDefinition.declareField(a(PRIVATE, FINAL),
                "binded_" + fieldAndMethodName, type(MethodHandle.class));

        // Class initializer: staticField = Reflection.methodHandle(<this class>, <method name>, <parameter classes>)
        classDefinition.getClassInitializer().getBody()
                .append(setStatic(staticField,
                        invokeStatic(Reflection.class, "methodHandle", MethodHandle.class,
                                constantClass(classDefinition.getType()), constantString(fieldAndMethodName),
                                newArray(type(Class[].class), inputParameters.stream().map(Parameter::getType)
                                        .map(BytecodeExpressions::constantClass).collect(toImmutableList())))));

        // NOTE(review): newer ASM releases deprecate this 4-argument Handle constructor in favor of one that
        // also takes an isInterface flag -- confirm against the ASM version this project builds with.
        Handle lambdaAsmHandle = new Handle(Opcodes.H_INVOKEVIRTUAL, method.getThis().getType().getClassName(),
                method.getName(), method.getMethodDescriptor());

        return new CompiledLambda(lambdaAsmHandle, method.getReturnType(), method.getParameterTypes(), staticField,
                instanceField);
    }

    /**
     * Emits bytecode that evaluates the capture expressions and produces an instance of
     * {@code lambdaInterface} through an {@code invokedynamic} call bootstrapped by
     * {@code LAMBDA_CAPTURE_METHOD}.
     *
     * <p>The values handed to the call site are, in order: {@code this}, the {@code session} variable, and
     * one boxed temporary per capture expression.
     *
     * @param context generator context supplying the scope and the expression generator
     * @param captureExpressions expressions whose values are captured by the lambda
     * @param compiledLambda the previously generated lambda method
     * @param lambdaInterface the functional interface the lambda is instantiated as
     * @return a block that leaves the lambda instance as its result
     * @throws VerifyException if {@code lambdaInterface} is not annotated with {@link FunctionalInterface}
     */
    public static BytecodeNode generateLambda(BytecodeGeneratorContext context,
            List<RowExpression> captureExpressions, CompiledLambda compiledLambda, Class<?> lambdaInterface) {
        if (!lambdaInterface.isAnnotationPresent(FunctionalInterface.class)) {
            // lambdaInterface is checked to be annotated with FunctionalInterface when generating ScalarFunctionImplementation
            throw new VerifyException("lambda should be generated as class annotated with FunctionalInterface");
        }

        BytecodeBlock block = new BytecodeBlock().setDescription("Partial apply");
        Scope scope = context.getScope();

        Variable wasNull = scope.getVariable("wasNull");

        // generate values to be captured
        ImmutableList.Builder<BytecodeExpression> captureVariableBuilder = ImmutableList.builder();
        for (RowExpression captureExpression : captureExpressions) {
            Class<?> valueType = Primitives.wrap(captureExpression.getType().getJavaType());
            Variable valueVariable = scope.createTempVariable(valueType);
            block.append(context.generate(captureExpression));
            block.append(boxPrimitiveIfNecessary(scope, valueType));
            block.putVariable(valueVariable);
            // Clear the flag after each capture; presumably nullness is now carried by the boxed
            // reference -- confirm against boxPrimitiveIfNecessary.
            block.append(wasNull.set(constantFalse()));
            captureVariableBuilder.add(valueVariable);
        }

        List<BytecodeExpression> captureVariables = ImmutableList.<BytecodeExpression>builder()
                .add(scope.getThis(), scope.getVariable("session")).addAll(captureVariableBuilder.build()).build();

        // Signature of the lambda method as seen through the functional interface: the leading
        // ConnectorSession and capture parameters are bound at the call site, so they are dropped here.
        Type[] instantiatedArgumentTypes = compiledLambda.getParameterTypes().stream()
                .skip(captureExpressions.size() + 1) // skip capture variables and ConnectorSession
                .map(ParameterizedType::getAsmType)
                .toArray(Type[]::new);
        Type instantiatedMethodAsmType = getMethodType(compiledLambda.getReturnType().getAsmType(),
                instantiatedArgumentTypes);

        block.append(invokeDynamic(LAMBDA_CAPTURE_METHOD,
                ImmutableList.of(getType(getSingleApplyMethod(lambdaInterface)),
                        compiledLambda.getLambdaAsmHandle(), instantiatedMethodAsmType),
                "apply", type(lambdaInterface), captureVariables));
        return block;
    }

    /**
     * Finds the single public method named {@code apply} on a functional interface.
     *
     * @param lambdaFunctionInterface the interface to inspect
     * @return the only method named {@code apply} reported by {@link Class#getMethods()}
     */
    private static Method getSingleApplyMethod(Class<?> lambdaFunctionInterface) {
        checkCondition(lambdaFunctionInterface.isAnnotationPresent(FunctionalInterface.class), COMPILER_ERROR,
                "Lambda function interface is required to be annotated with FunctionalInterface");

        List<Method> applyMethods = Arrays.stream(lambdaFunctionInterface.getMethods())
                .filter(method -> method.getName().equals("apply")).collect(toImmutableList());

        checkCondition(applyMethods.size() == 1, COMPILER_ERROR,
                "Expect to have exactly 1 method with name 'apply' in interface "
                        + lambdaFunctionInterface.getName());
        return applyMethods.get(0);
    }

    /**
     * Builds the visitor used to compile variable references inside a lambda body.
     *
     * <p>Only {@code visitVariableReference} is supported; every other visit method throws
     * {@link UnsupportedOperationException}.
     *
     * @param parameterMap lambda argument name to the generated method parameter and its boxed type
     * @return a visitor that loads the referenced parameter and unboxes it if necessary
     */
    private static RowExpressionVisitor<BytecodeNode, Scope> variableReferenceCompiler(
            Map<String, ParameterAndType> parameterMap) {
        return new RowExpressionVisitor<BytecodeNode, Scope>() {
            @Override
            public BytecodeNode visitInputReference(InputReferenceExpression node, Scope scope) {
                throw new UnsupportedOperationException();
            }

            @Override
            public BytecodeNode visitCall(CallExpression call, Scope scope) {
                throw new UnsupportedOperationException();
            }

            @Override
            public BytecodeNode visitConstant(ConstantExpression literal, Scope scope) {
                throw new UnsupportedOperationException();
            }

            @Override
            public BytecodeNode visitLambda(LambdaDefinitionExpression lambda, Scope context) {
                throw new UnsupportedOperationException();
            }

            @Override
            public BytecodeNode visitVariableReference(VariableReferenceExpression reference, Scope context) {
                ParameterAndType parameterAndType = parameterMap.get(reference.getName());
                if (parameterAndType == null) {
                    // Fail with the offending name instead of a bare NullPointerException.
                    throw new VerifyException(
                            "Lambda body references unknown variable: " + reference.getName());
                }
                Parameter parameter = parameterAndType.getParameter();
                Class<?> type = parameterAndType.getType();
                return new BytecodeBlock().append(parameter).append(unboxPrimitiveIfNecessary(context, type));
            }
        };
    }

    /**
     * Describes a lambda method generated by {@link #preGenerateLambdaExpression}: its ASM handle, its
     * signature, and the two {@link MethodHandle} fields declared for it.
     */
    static class CompiledLambda {
        private final FieldDefinition staticField;
        // the instance field will be bound to "this" in the constructor
        private final FieldDefinition instanceField;

        // lambda method information
        private final Handle lambdaAsmHandle;
        private final ParameterizedType returnType;
        private final List<ParameterizedType> parameterTypes;

        /**
         * @param lambdaAsmHandle ASM handle of the generated lambda method
         * @param returnType return type of the generated lambda method
         * @param parameterTypes parameter types of the generated lambda method; copied defensively
         * @param staticField static field holding the unbound method handle
         * @param instanceField instance field that receives the handle bound to {@code this}
         * @throws NullPointerException if any argument is null
         */
        public CompiledLambda(Handle lambdaAsmHandle, ParameterizedType returnType,
                List<ParameterizedType> parameterTypes, FieldDefinition staticField,
                FieldDefinition instanceField) {
            this.staticField = requireNonNull(staticField, "staticField is null");
            this.instanceField = requireNonNull(instanceField, "instanceField is null");
            this.lambdaAsmHandle = requireNonNull(lambdaAsmHandle, "lambdaAsmHandle is null");
            this.returnType = requireNonNull(returnType, "returnType is null");
            this.parameterTypes = ImmutableList.copyOf(requireNonNull(parameterTypes, "parameterTypes is null"));
        }

        public Handle getLambdaAsmHandle() {
            return lambdaAsmHandle;
        }

        public ParameterizedType getReturnType() {
            return returnType;
        }

        public List<ParameterizedType> getParameterTypes() {
            return parameterTypes;
        }

        public FieldDefinition getInstanceField() {
            return instanceField;
        }

        /**
         * Appends bytecode that assigns the instance field:
         * {@code this.<instanceField> = <staticField>.bindTo((Object) this)}.
         *
         * @param thisVariable the {@code this} variable of the method being generated
         * @param block the block to append the assignment to
         */
        public void generateInitialization(Variable thisVariable, BytecodeBlock block) {
            block.append(thisVariable.setField(instanceField,
                    getStatic(staticField).invoke("bindTo", MethodHandle.class, thisVariable.cast(Object.class))));
        }
    }
}