org.apache.hadoop.hive.ql.optimizer.calcite.rules.HiveJoinToMultiJoinRule.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.hadoop.hive.ql.optimizer.calcite.rules.HiveJoinToMultiJoinRule.java

Source

/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.hadoop.hive.ql.optimizer.calcite.rules;

import java.util.ArrayList;
import java.util.List;
import java.util.Set;

import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.core.RelFactories.ProjectFactory;
import org.apache.calcite.rel.rules.JoinCommuteRule;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.Pair;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.optimizer.calcite.CalciteSemanticException;
import org.apache.hadoop.hive.ql.optimizer.calcite.HiveCalciteUtil;
import org.apache.hadoop.hive.ql.optimizer.calcite.HiveCalciteUtil.JoinPredicateInfo;
import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelOptUtil;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveJoin;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveMultiJoin;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveProject;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;

/**
 * Planner rule that merges a {@link HiveJoin} with a child that is itself a
 * {@link HiveJoin} or a {@link HiveMultiJoin} into a single
 * {@link HiveMultiJoin}, provided that the join conditions of the parent and
 * of the child are evaluated over the same set of key columns.
 *
 * <p>The merge is first attempted with the left child. If that fails, the
 * inputs of the join are swapped and the merge is attempted with the
 * (original) right child.
 */
public class HiveJoinToMultiJoinRule extends RelOptRule {

    /** Default instance: matches {@link HiveJoin} and uses the default Hive project factory. */
    public static final HiveJoinToMultiJoinRule INSTANCE = new HiveJoinToMultiJoinRule(HiveJoin.class,
            HiveProject.DEFAULT_PROJECT_FACTORY);

    /** Factory used to re-create the top project produced by swapping the join inputs. */
    private final ProjectFactory projectFactory;

    private static final Log LOG = LogFactory.getLog(HiveJoinToMultiJoinRule.class);

    //~ Constructors -----------------------------------------------------------

    /**
     * Creates a JoinToMultiJoinRule.
     *
     * @param clazz          join class the rule matches on; note that
     *                       {@link #onMatch(RelOptRuleCall)} casts the matched
     *                       node to {@link HiveJoin}
     * @param projectFactory factory used to build the project placed on top of
     *                       the multijoin when the join inputs had to be swapped
     */
    public HiveJoinToMultiJoinRule(Class<? extends Join> clazz, ProjectFactory projectFactory) {
        super(operand(clazz, operand(RelNode.class, any()), operand(RelNode.class, any())));
        this.projectFactory = projectFactory;
    }

    //~ Methods ----------------------------------------------------------------

    /**
     * Tries to replace the matched join by a {@link HiveMultiJoin}, merging it
     * with its left child or, failing that, with its right child. If neither
     * merge is possible the call is left untouched.
     *
     * @param call rule call whose operands are the join, its left input and
     *             its right input, in that order
     */
    @Override
    public void onMatch(RelOptRuleCall call) {
        final HiveJoin join = call.rel(0);
        final RelNode left = call.rel(1);
        final RelNode right = call.rel(2);

        // 1. We try to merge this join with the left child
        RelNode multiJoin = mergeJoin(join, left, right);
        if (multiJoin != null) {
            call.transformTo(multiJoin);
            return;
        }

        // 2. If we cannot, we swap the inputs so we can try
        //    to merge it with its right child
        final RelNode swapped = JoinCommuteRule.swap(join, true);
        if (swapped == null) {
            // swap() may return null when the join cannot be swapped. This is
            // checked explicitly rather than with an assert: assertions are
            // normally disabled at runtime, in which case a null here would
            // only show up as a NullPointerException further down.
            return;
        }

        //    The result of the swapping operation is either
        //    i)  a Project or,
        //    ii) if the project is trivial, a raw join
        final HiveJoin newJoin;
        Project topProject = null;
        if (swapped instanceof HiveJoin) {
            newJoin = (HiveJoin) swapped;
        } else {
            topProject = (Project) swapped;
            newJoin = (HiveJoin) swapped.getInput(0);
        }

        // 3. We try to merge the join with the right child, which is now
        //    the left input of the swapped join
        multiJoin = mergeJoin(newJoin, right, left);
        if (multiJoin == null) {
            return;
        }
        if (topProject != null) {
            // Restore the original column order on top of the multijoin
            multiJoin = projectFactory.createProject(multiJoin, topProject.getChildExps(),
                    topProject.getRowType().getFieldNames());
        }
        call.transformTo(multiJoin);
    }

    /**
     * Tries to merge the join with its left child. The left child must be a
     * {@link HiveJoin} or a {@link HiveMultiJoin} whose keys are combinable
     * with those of {@code join} for this to happen.
     *
     * @param join  parent join
     * @param left  left input of {@code join}; candidate to be merged
     * @param right right input of {@code join}; added as the last input of the
     *              resulting multijoin
     * @return the new {@link HiveMultiJoin}, or {@code null} if the join could
     *         not be merged with its left child
     */
    private static RelNode mergeJoin(HiveJoin join, RelNode left, RelNode right) {
        final RexBuilder rexBuilder = join.getCluster().getRexBuilder();

        // Left child: collect its join specs and check whether it is combinable
        final RexNode leftCondition;
        final List<Pair<Integer, Integer>> leftJoinInputs;
        final List<JoinRelType> leftJoinTypes;
        final List<RexNode> leftJoinFilters;
        boolean combinable;
        if (left instanceof HiveJoin) {
            final HiveJoin hj = (HiveJoin) left;
            leftCondition = hj.getCondition();
            // A binary join has a single join spec between its inputs 0 and 1
            leftJoinInputs = ImmutableList.of(Pair.of(0, 1));
            leftJoinTypes = ImmutableList.of(hj.getJoinType());
            leftJoinFilters = ImmutableList.of(hj.getJoinFilter());
            try {
                combinable = isCombinableJoin(join, hj);
            } catch (CalciteSemanticException e) {
                LOG.trace("Failed to merge join-join", e);
                combinable = false;
            }
        } else if (left instanceof HiveMultiJoin) {
            final HiveMultiJoin hmj = (HiveMultiJoin) left;
            leftCondition = hmj.getCondition();
            leftJoinInputs = hmj.getJoinInputs();
            leftJoinTypes = hmj.getJoinTypes();
            leftJoinFilters = hmj.getJoinFilters();
            try {
                combinable = isCombinableJoin(join, hmj);
            } catch (CalciteSemanticException e) {
                LOG.trace("Failed to merge join-multijoin", e);
                combinable = false;
            }
        } else {
            // The left child is not a join or multijoin operator
            return null;
        }
        if (!combinable) {
            // The join operation in the child is not on the same keys
            return null;
        }

        // Start the new multijoin specs from those of the left child
        final List<RelNode> newInputs = Lists.newArrayList();
        final List<RexNode> newJoinCondition = Lists.newArrayList();
        final List<Pair<Integer, Integer>> joinInputs = Lists.newArrayList();
        final List<JoinRelType> joinTypes = Lists.newArrayList();
        final List<RexNode> joinFilters = Lists.newArrayList();

        newJoinCondition.add(leftCondition);
        for (int i = 0; i < leftJoinInputs.size(); i++) {
            joinInputs.add(leftJoinInputs.get(i));
            joinTypes.add(leftJoinTypes.get(i));
            joinFilters.add(leftJoinFilters.get(i));
        }
        newInputs.addAll(left.getInputs());
        final int numberLeftInputs = newInputs.size();

        // Right child
        newInputs.add(right);

        // The condition list now always holds exactly two entries (left child
        // condition + parent condition), so no further size check is needed.
        newJoinCondition.add(join.getCondition());

        // Split the parent condition into per-input key expressions and
        // the remaining filter
        final List<RelDataTypeField> systemFieldList = ImmutableList.of();
        final List<List<RexNode>> joinKeyExprs = new ArrayList<List<RexNode>>();
        final List<Integer> filterNulls = new ArrayList<Integer>();
        for (int i = 0; i < newInputs.size(); i++) {
            joinKeyExprs.add(new ArrayList<RexNode>());
        }
        final RexNode filters;
        try {
            filters = HiveRelOptUtil.splitHiveJoinCondition(systemFieldList, newInputs, join.getCondition(),
                    joinKeyExprs, filterNulls, null);
        } catch (CalciteSemanticException e) {
            LOG.trace("Failed to merge joins", e);
            return null;
        }

        // Record which of the new inputs contribute keys to the parent condition
        final ImmutableBitSet.Builder keysInInputsBuilder = ImmutableBitSet.builder();
        for (int i = 0; i < newInputs.size(); i++) {
            if (!joinKeyExprs.get(i).isEmpty()) {
                keysInInputsBuilder.set(i);
            }
        }
        final ImmutableBitSet keysInInputs = keysInInputsBuilder.build();
        final ImmutableBitSet leftReferencedInputs = keysInInputs
                .intersect(ImmutableBitSet.range(numberLeftInputs));
        final ImmutableBitSet rightReferencedInputs = keysInInputs
                .intersect(ImmutableBitSet.range(numberLeftInputs, newInputs.size()));

        final JoinRelType joinType = join.getJoinType();
        if (joinType != JoinRelType.INNER) {
            // If we cannot merge, we bail out: a non-inner join is only merged
            // when it references at most one input on each side
            if (leftReferencedInputs.cardinality() > 1 || rightReferencedInputs.cardinality() > 1) {
                return null;
            }
            // Otherwise, we add a single spec to the join specs
            final int leftInput = keysInInputs.nextSetBit(0);
            final int rightInput = keysInInputs.nextSetBit(numberLeftInputs);
            joinInputs.add(Pair.of(leftInput, rightInput));
            joinTypes.add(joinType);
            joinFilters.add(filters);
        } else {
            // Inner join: one spec per (left input, right input) pair
            // referenced by the condition
            for (int i : leftReferencedInputs) {
                for (int j : rightReferencedInputs) {
                    joinInputs.add(Pair.of(i, j));
                    joinTypes.add(joinType);
                    joinFilters.add(filters);
                }
            }
        }

        // We can now create a multijoin operator
        final RexNode newCondition = RexUtil.flatten(rexBuilder,
                RexUtil.composeConjunction(rexBuilder, newJoinCondition, false));
        return new HiveMultiJoin(join.getCluster(), newInputs, newCondition, join.getRowType(), joinInputs,
                joinTypes, joinFilters);
    }

    /**
     * Returns true if the conditions of the join and of its left child join
     * execute over the same keys.
     *
     * @param join          parent join
     * @param leftChildJoin left input of {@code join}
     * @return whether both joins can be combined into a multijoin
     * @throws CalciteSemanticException if the join predicate information
     *         cannot be constructed
     */
    private static boolean isCombinableJoin(HiveJoin join, HiveJoin leftChildJoin) throws CalciteSemanticException {
        final JoinPredicateInfo joinPredInfo = HiveCalciteUtil.JoinPredicateInfo.constructJoinPredicateInfo(join,
                join.getCondition());
        final JoinPredicateInfo leftChildJoinPredInfo = HiveCalciteUtil.JoinPredicateInfo
                .constructJoinPredicateInfo(leftChildJoin, leftChildJoin.getCondition());
        return isCombinablePredicate(joinPredInfo, leftChildJoinPredInfo, leftChildJoin.getInputs().size());
    }

    /**
     * Returns true if the conditions of the join and of its left child
     * multijoin execute over the same keys.
     *
     * @param join          parent join
     * @param leftChildJoin left input of {@code join}
     * @return whether the join can be combined into the multijoin
     * @throws CalciteSemanticException if the join predicate information
     *         cannot be constructed
     */
    private static boolean isCombinableJoin(HiveJoin join, HiveMultiJoin leftChildJoin)
            throws CalciteSemanticException {
        final JoinPredicateInfo joinPredInfo = HiveCalciteUtil.JoinPredicateInfo.constructJoinPredicateInfo(join,
                join.getCondition());
        final JoinPredicateInfo leftChildJoinPredInfo = HiveCalciteUtil.JoinPredicateInfo
                .constructJoinPredicateInfo(leftChildJoin, leftChildJoin.getCondition());
        return isCombinablePredicate(joinPredInfo, leftChildJoinPredInfo, leftChildJoin.getInputs().size());
    }

    /**
     * To be able to combine a parent join and its left input join child,
     * the left keys over which the parent join is executed need to be the same
     * as those of the child join.
     * Thus, we iterate over the different inputs of the child, checking whether
     * the keys of the parent are the same as the keys of that input.
     *
     * @param joinPredInfo          predicate information of the parent join
     * @param leftChildJoinPredInfo predicate information of the left child
     * @param noLeftChildInputs     number of inputs of the left child
     * @return true if the parent has left keys and they are equal to the keys
     *         of at least one input of the child; false otherwise
     * @throws CalciteSemanticException declared for callers; propagated from
     *         the predicate information accessors, if they throw
     */
    private static boolean isCombinablePredicate(JoinPredicateInfo joinPredInfo,
            JoinPredicateInfo leftChildJoinPredInfo, int noLeftChildInputs) throws CalciteSemanticException {
        final Set<Integer> keys = joinPredInfo.getProjsJoinKeysInChildSchema(0);
        if (keys.isEmpty()) {
            // No keys on the left side of the parent: nothing to match against
            return false;
        }
        for (int i = 0; i < noLeftChildInputs; i++) {
            if (keys.equals(leftChildJoinPredInfo.getProjsJoinKeysInJoinSchema(i))) {
                return true;
            }
        }
        return false;
    }
}