com.facebook.presto.sql.planner.optimizations.HashGenerationOptimizer.java Source code

Java tutorial

Introduction

Here is the source code for com.facebook.presto.sql.planner.optimizations.HashGenerationOptimizer.java

Source

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.sql.planner.optimizations;

import com.facebook.presto.Session;
import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.metadata.FunctionRegistry;
import com.facebook.presto.spi.type.BigintType;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.planner.PlanNodeIdAllocator;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.SymbolAllocator;
import com.facebook.presto.sql.planner.plan.AggregationNode;
import com.facebook.presto.sql.planner.plan.DistinctLimitNode;
import com.facebook.presto.sql.planner.plan.IndexJoinNode;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.planner.plan.MarkDistinctNode;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.ProjectNode;
import com.facebook.presto.sql.planner.plan.RowNumberNode;
import com.facebook.presto.sql.planner.plan.SemiJoinNode;
import com.facebook.presto.sql.planner.plan.SimplePlanRewriter;
import com.facebook.presto.sql.planner.plan.TopNRowNumberNode;
import com.facebook.presto.sql.planner.plan.WindowNode;
import com.facebook.presto.sql.tree.CoalesceExpression;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.FunctionCall;
import com.facebook.presto.sql.tree.LongLiteral;
import com.facebook.presto.sql.tree.QualifiedName;
import com.facebook.presto.sql.tree.QualifiedNameReference;
import com.facebook.presto.sql.tree.Window;
import com.facebook.presto.type.TypeUtils;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;

import java.util.List;
import java.util.Map;
import java.util.Optional;

import static com.google.common.base.Preconditions.checkArgument;
import static java.util.Objects.requireNonNull;

/**
 * Plan optimizer that pre-computes hash values for nodes that accept a hash symbol.
 *
 * <p>When hash generation optimization is enabled for the session, the plan is rewritten so
 * that every aggregation, distinct-limit, mark-distinct, row-number, top-N-row-number, window,
 * join, semi-join and index-join node that has grouping / partitioning / join symbols gets a
 * {@link ProjectNode} inserted below it. That projection passes through all of the source's
 * output symbols and adds one extra symbol holding the combined hash of the relevant symbols.
 * The rewritten node is then constructed with that hash symbol.
 */
public class HashGenerationOptimizer extends PlanOptimizer {
    /** Seed literal the combined hash expression starts from. */
    public static final int INITIAL_HASH_VALUE = 0;
    /** Mangled name of the {@code HASH_CODE} operator, used as the per-symbol hash function. */
    private static final String HASH_CODE = FunctionRegistry.mangleOperatorName("HASH_CODE");

    /**
     * Rewrites the plan to add pre-computed hash projections if the session enables it.
     *
     * @param plan root of the plan to optimize
     * @param session session whose system properties decide whether the rewrite runs
     * @param types types of the symbols in the plan
     * @param symbolAllocator allocator used to create the new hash symbols
     * @param idAllocator allocator used to create ids for the new plan nodes
     * @return the rewritten plan, or {@code plan} itself when the optimization is disabled
     * @throws NullPointerException if any argument is null
     */
    @Override
    public PlanNode optimize(PlanNode plan, Session session, Map<Symbol, Type> types,
            SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator) {
        requireNonNull(plan, "plan is null");
        requireNonNull(session, "session is null");
        requireNonNull(types, "types is null");
        requireNonNull(symbolAllocator, "symbolAllocator is null");
        requireNonNull(idAllocator, "idAllocator is null");
        if (SystemSessionProperties.isOptimizeHashGenerationEnabled(session)) {
            return SimplePlanRewriter.rewriteWith(new Rewriter(idAllocator, symbolAllocator, types), plan, null);
        }
        return plan;
    }

    /**
     * Rewriter that visits each hash-consuming node type, rewrites its source(s) first, and
     * then rebuilds the node on top of a hash-producing projection.
     */
    private static class Rewriter extends SimplePlanRewriter<Void> {
        private final PlanNodeIdAllocator idAllocator;
        private final SymbolAllocator symbolAllocator;
        private final Map<Symbol, Type> types;

        private Rewriter(PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator,
                Map<Symbol, Type> types) {
            this.idAllocator = requireNonNull(idAllocator, "idAllocator is null");
            this.symbolAllocator = requireNonNull(symbolAllocator, "symbolAllocator is null");
            this.types = requireNonNull(types, "types is null");
        }

        /**
         * Adds a hash over the group-by symbols, unless there is no grouping or the grouping
         * is a single bigint column (see {@link #canSkipHashGeneration}).
         */
        @Override
        public PlanNode visitAggregation(AggregationNode node, RewriteContext<Void> context) {
            PlanNode rewrittenSource = context.rewrite(node.getSource(), null);
            // Nothing changed below and nothing to hash: keep the original node instance.
            if (rewrittenSource == node.getSource() && node.getGroupBy().isEmpty()) {
                return node;
            }
            // Rebuild on the rewritten source, but without a hash symbol.
            if (node.getGroupBy().isEmpty() || canSkipHashGeneration(node)) {
                return new AggregationNode(idAllocator.getNextId(), rewrittenSource, node.getGroupBy(),
                        node.getAggregations(), node.getFunctions(), node.getMasks(), node.getStep(),
                        node.getSampleWeight(), node.getConfidence(), Optional.empty());
            }

            Symbol hashSymbol = symbolAllocator.newHashSymbol();
            PlanNode hashProjectNode = getHashProjectNode(idAllocator, rewrittenSource, hashSymbol,
                    node.getGroupBy());
            return new AggregationNode(idAllocator.getNextId(), hashProjectNode, node.getGroupBy(),
                    node.getAggregations(), node.getFunctions(), node.getMasks(), node.getStep(),
                    node.getSampleWeight(), node.getConfidence(), Optional.of(hashSymbol));
        }

        /** Returns true when the aggregation groups by exactly one symbol of type bigint. */
        private boolean canSkipHashGeneration(AggregationNode node) {
            // HACK: bigint grouped aggregation has special operators that do not use precomputed hash, so we can skip hash generation
            return node.getGroupBy().size() == 1
                    && types.get(Iterables.getOnlyElement(node.getGroupBy())).equals(BigintType.BIGINT);
        }

        /** Adds a hash over all output symbols of the distinct-limit node. */
        @Override
        public PlanNode visitDistinctLimit(DistinctLimitNode node, RewriteContext<Void> context) {
            PlanNode rewrittenSource = context.rewrite(node.getSource(), null);
            Symbol hashSymbol = symbolAllocator.newHashSymbol();
            PlanNode hashProjectNode = getHashProjectNode(idAllocator, rewrittenSource, hashSymbol,
                    node.getOutputSymbols());
            return new DistinctLimitNode(idAllocator.getNextId(), hashProjectNode, node.getLimit(),
                    Optional.of(hashSymbol));
        }

        /** Adds a hash over the distinct symbols of the mark-distinct node. */
        @Override
        public PlanNode visitMarkDistinct(MarkDistinctNode node, RewriteContext<Void> context) {
            PlanNode rewrittenSource = context.rewrite(node.getSource(), null);
            Symbol hashSymbol = symbolAllocator.newHashSymbol();
            PlanNode hashProjectNode = getHashProjectNode(idAllocator, rewrittenSource, hashSymbol,
                    node.getDistinctSymbols());
            return new MarkDistinctNode(idAllocator.getNextId(), hashProjectNode, node.getMarkerSymbol(),
                    node.getDistinctSymbols(), Optional.of(hashSymbol));
        }

        /** Adds a hash over the partition-by symbols, if there are any. */
        @Override
        public PlanNode visitRowNumber(RowNumberNode node, RewriteContext<Void> context) {
            PlanNode rewrittenSource = context.rewrite(node.getSource(), null);
            // Nothing changed below and nothing to hash: keep the original node instance.
            if (rewrittenSource == node.getSource() && node.getPartitionBy().isEmpty()) {
                return node;
            }

            if (!node.getPartitionBy().isEmpty()) {
                Symbol hashSymbol = symbolAllocator.newHashSymbol();
                PlanNode hashProjectNode = getHashProjectNode(idAllocator, rewrittenSource, hashSymbol,
                        node.getPartitionBy());
                return new RowNumberNode(idAllocator.getNextId(), hashProjectNode, node.getPartitionBy(),
                        node.getRowNumberSymbol(), node.getMaxRowCountPerPartition(), Optional.of(hashSymbol));
            }
            // No partitioning: only the source changed, so carry over the node's existing hash symbol.
            return new RowNumberNode(idAllocator.getNextId(), rewrittenSource, node.getPartitionBy(),
                    node.getRowNumberSymbol(), node.getMaxRowCountPerPartition(), node.getHashSymbol());
        }

        /** Adds a hash over the partition-by symbols, if there are any. */
        @Override
        public PlanNode visitTopNRowNumber(TopNRowNumberNode node, RewriteContext<Void> context) {
            PlanNode rewrittenSource = context.rewrite(node.getSource(), null);
            // Nothing changed below and nothing to hash: keep the original node instance.
            if (rewrittenSource == node.getSource() && node.getPartitionBy().isEmpty()) {
                return node;
            }

            // No partitioning: only the source changed, so carry over the node's existing hash symbol.
            if (node.getPartitionBy().isEmpty()) {
                return new TopNRowNumberNode(idAllocator.getNextId(), rewrittenSource, node.getPartitionBy(),
                        node.getOrderBy(), node.getOrderings(), node.getRowNumberSymbol(),
                        node.getMaxRowCountPerPartition(), node.isPartial(), node.getHashSymbol());
            }
            Symbol hashSymbol = symbolAllocator.newHashSymbol();
            PlanNode hashProjectNode = getHashProjectNode(idAllocator, rewrittenSource, hashSymbol,
                    node.getPartitionBy());
            return new TopNRowNumberNode(idAllocator.getNextId(), hashProjectNode, node.getPartitionBy(),
                    node.getOrderBy(), node.getOrderings(), node.getRowNumberSymbol(),
                    node.getMaxRowCountPerPartition(), node.isPartial(), Optional.of(hashSymbol));
        }

        /**
         * Adds one hash per side, computed over that side's equi-join symbols. A join without
         * equi-join clauses is rebuilt on the rewritten children with no hash symbols.
         */
        @Override
        public PlanNode visitJoin(JoinNode node, RewriteContext<Void> context) {
            List<JoinNode.EquiJoinClause> clauses = node.getCriteria();

            List<Symbol> leftSymbols = Lists.transform(clauses, JoinNode.EquiJoinClause::getLeft);
            List<Symbol> rightSymbols = Lists.transform(clauses, JoinNode.EquiJoinClause::getRight);

            PlanNode rewrittenLeft = context.rewrite(node.getLeft(), null);
            PlanNode rewrittenRight = context.rewrite(node.getRight(), null);

            if (clauses.isEmpty()) {
                // No Hash is necessary for cross join.
                // Keep the node's own join type: this rewrite must not change join semantics.
                return new JoinNode(idAllocator.getNextId(), node.getType(), rewrittenLeft, rewrittenRight,
                        node.getCriteria(), Optional.empty(), Optional.empty());
            }

            Symbol leftHashSymbol = symbolAllocator.newHashSymbol();
            Symbol rightHashSymbol = symbolAllocator.newHashSymbol();

            PlanNode leftHashProjectNode = getHashProjectNode(idAllocator, rewrittenLeft, leftHashSymbol,
                    leftSymbols);
            PlanNode rightHashProjectNode = getHashProjectNode(idAllocator, rewrittenRight, rightHashSymbol,
                    rightSymbols);

            return new JoinNode(idAllocator.getNextId(), node.getType(), leftHashProjectNode, rightHashProjectNode,
                    node.getCriteria(), Optional.of(leftHashSymbol), Optional.of(rightHashSymbol));
        }

        /** Adds a hash over the single join symbol on each side of the semi-join. */
        @Override
        public PlanNode visitSemiJoin(SemiJoinNode node, RewriteContext<Void> context) {
            PlanNode rewrittenSource = context.rewrite(node.getSource(), null);
            PlanNode rewrittenFilteringSource = context.rewrite(node.getFilteringSource(), null);

            Symbol sourceHashSymbol = symbolAllocator.newHashSymbol();
            Symbol filteringSourceHashSymbol = symbolAllocator.newHashSymbol();

            PlanNode sourceHashProjectNode = getHashProjectNode(idAllocator, rewrittenSource, sourceHashSymbol,
                    ImmutableList.of(node.getSourceJoinSymbol()));
            PlanNode filteringSourceHashProjectNode = getHashProjectNode(idAllocator, rewrittenFilteringSource,
                    filteringSourceHashSymbol, ImmutableList.of(node.getFilteringSourceJoinSymbol()));

            return new SemiJoinNode(idAllocator.getNextId(), sourceHashProjectNode, filteringSourceHashProjectNode,
                    node.getSourceJoinSymbol(), node.getFilteringSourceJoinSymbol(), node.getSemiJoinOutput(),
                    Optional.of(sourceHashSymbol), Optional.of(filteringSourceHashSymbol));
        }

        /** Adds one hash per side, computed over that side's symbols from the join criteria. */
        @Override
        public PlanNode visitIndexJoin(IndexJoinNode node, RewriteContext<Void> context) {
            PlanNode rewrittenIndex = context.rewrite(node.getIndexSource(), null);
            PlanNode rewrittenProbe = context.rewrite(node.getProbeSource(), null);

            Symbol indexHashSymbol = symbolAllocator.newHashSymbol();
            Symbol probeHashSymbol = symbolAllocator.newHashSymbol();

            List<IndexJoinNode.EquiJoinClause> clauses = node.getCriteria();

            List<Symbol> indexSymbols = Lists.transform(clauses, IndexJoinNode.EquiJoinClause::getIndex);
            List<Symbol> probeSymbols = Lists.transform(clauses, IndexJoinNode.EquiJoinClause::getProbe);

            PlanNode indexHashProjectNode = getHashProjectNode(idAllocator, rewrittenIndex, indexHashSymbol,
                    indexSymbols);
            PlanNode probeHashProjectNode = getHashProjectNode(idAllocator, rewrittenProbe, probeHashSymbol,
                    probeSymbols);

            // Argument order is probe first, then index, for both the sources and the hash symbols.
            return new IndexJoinNode(idAllocator.getNextId(), node.getType(), probeHashProjectNode,
                    indexHashProjectNode, node.getCriteria(), Optional.of(probeHashSymbol),
                    Optional.of(indexHashSymbol));
        }

        /** Adds a hash over the partition-by symbols, if there are any. */
        @Override
        public PlanNode visitWindow(WindowNode node, RewriteContext<Void> context) {
            PlanNode rewrittenSource = context.rewrite(node.getSource(), null);
            // Nothing changed below and nothing to hash: keep the original node instance.
            if (rewrittenSource == node.getSource() && node.getPartitionBy().isEmpty()) {
                return node;
            }
            // No partitioning: rebuild on the rewritten source without a hash symbol.
            if (node.getPartitionBy().isEmpty()) {
                return new WindowNode(idAllocator.getNextId(), rewrittenSource, node.getPartitionBy(),
                        node.getOrderBy(), node.getOrderings(), node.getFrame(), node.getWindowFunctions(),
                        node.getSignatures(), Optional.empty(), node.getPrePartitionedInputs(),
                        node.getPreSortedOrderPrefix());
            }
            Symbol hashSymbol = symbolAllocator.newHashSymbol();
            PlanNode hashProjectNode = getHashProjectNode(idAllocator, rewrittenSource, hashSymbol,
                    node.getPartitionBy());
            return new WindowNode(idAllocator.getNextId(), hashProjectNode, node.getPartitionBy(),
                    node.getOrderBy(), node.getOrderings(), node.getFrame(), node.getWindowFunctions(),
                    node.getSignatures(), Optional.of(hashSymbol), node.getPrePartitionedInputs(),
                    node.getPreSortedOrderPrefix());
        }
    }

    /**
     * Builds a projection over {@code source} that forwards every output symbol of the source
     * unchanged and additionally assigns the combined hash of {@code partitioningSymbols} to
     * {@code hashSymbol}.
     *
     * @param idAllocator allocator for the new projection's node id
     * @param source node to project over
     * @param hashSymbol symbol that receives the hash expression
     * @param partitioningSymbols symbols to hash, in order; must not be empty
     * @return the new projection node
     * @throws IllegalArgumentException if {@code partitioningSymbols} is empty
     */
    private static ProjectNode getHashProjectNode(PlanNodeIdAllocator idAllocator, PlanNode source,
            Symbol hashSymbol, List<Symbol> partitioningSymbols) {
        checkArgument(!partitioningSymbols.isEmpty(), "partitioningSymbols is empty");
        ImmutableMap.Builder<Symbol, Expression> outputSymbols = ImmutableMap.builder();
        // Identity projections: each source symbol maps to a reference to itself.
        for (Symbol symbol : source.getOutputSymbols()) {
            Expression expression = new QualifiedNameReference(symbol.toQualifiedName());
            outputSymbols.put(symbol, expression);
        }

        Expression hashExpression = getHashExpression(partitioningSymbols);
        outputSymbols.put(hashSymbol, hashExpression);
        return new ProjectNode(idAllocator.getNextId(), source, outputSymbols.build());
    }

    /**
     * Builds the combined hash expression by folding the symbols left to right, starting from
     * the {@link #INITIAL_HASH_VALUE} literal.
     */
    private static Expression getHashExpression(List<Symbol> partitioningSymbols) {
        Expression hashExpression = new LongLiteral(String.valueOf(INITIAL_HASH_VALUE));
        for (Symbol symbol : partitioningSymbols) {
            hashExpression = getHashFunctionCall(hashExpression, symbol);
        }
        return hashExpression;
    }

    /**
     * Builds {@code combine_hash(previousHashValue, COALESCE(HASH_CODE(symbol), NULL_HASH_CODE))}.
     */
    private static Expression getHashFunctionCall(Expression previousHashValue, Symbol symbol) {
        FunctionCall functionCall = new FunctionCall(QualifiedName.of(HASH_CODE), Optional.<Window>empty(), false,
                ImmutableList.<Expression>of(new QualifiedNameReference(symbol.toQualifiedName())));
        List<Expression> arguments = ImmutableList.of(previousHashValue, orNullHashCode(functionCall));
        return new FunctionCall(QualifiedName.of("combine_hash"), arguments);
    }

    /** Wraps the expression in a COALESCE that falls back to {@code TypeUtils.NULL_HASH_CODE}. */
    private static Expression orNullHashCode(Expression expression) {
        return new CoalesceExpression(expression, new LongLiteral(String.valueOf(TypeUtils.NULL_HASH_CODE)));
    }
}