org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelOptUtil.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelOptUtil.java

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.hadoop.hive.ql.optimizer.calcite;

import com.google.common.collect.HashMultimap;
import com.google.common.collect.Multimap;
import com.google.common.collect.Sets;
import java.util.AbstractList;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;

import com.google.common.collect.ImmutableList;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptTable;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.RelReferentialConstraint;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.Aggregate.Group;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Filter;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rel.core.Sort;
import org.apache.calcite.rel.core.TableScan;
import org.apache.calcite.rel.metadata.RelColumnOrigin;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexFieldAccess;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexOver;
import org.apache.calcite.rex.RexTableInputRef;
import org.apache.calcite.rex.RexTableInputRef.RelTableRef;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.Pair;
import org.apache.commons.lang3.tuple.Triple;
import org.apache.hadoop.hive.ql.exec.FunctionRegistry;
import org.apache.hadoop.hive.ql.optimizer.calcite.translator.TypeConverter;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class HiveRelOptUtil extends RelOptUtil {

    private static final Logger LOG = LoggerFactory.getLogger(HiveRelOptUtil.class);

    /**
     * Splits out the equi-join (and optionally, a single non-equi) components
     * of a join condition, and returns what's left. Projection might be
     * required by the caller to provide join keys that are not direct field
     * references.
     *
     * @param sysFieldList  list of system fields
     * @param inputs        join inputs
     * @param condition     join condition
     * @param joinKeys      The join keys from the inputs which are equi-join
     *                      keys
     * @param filterNulls   The join key positions for which null values will not
     *                      match. null values only match for the "is not distinct
     *                      from" condition.
     * @param rangeOp       if null, only locate equi-joins; otherwise, locate a
     *                      single non-equi join predicate and return its operator
     *                      in this list; join keys associated with the non-equi
     *                      join predicate are at the end of the key lists
     *                      returned
     * @return What's left, never null
     * @throws CalciteSemanticException
     */
    public static RexNode splitHiveJoinCondition(List<RelDataTypeField> sysFieldList, List<RelNode> inputs,
            RexNode condition, List<List<RexNode>> joinKeys, List<Integer> filterNulls, List<SqlOperator> rangeOp)
            throws CalciteSemanticException {
        final List<RexNode> nonEquiList = new ArrayList<>();

        splitJoinCondition(sysFieldList, inputs, condition, joinKeys, filterNulls, rangeOp, nonEquiList);

        // Convert the remainders into a list that are AND'ed together.
        return RexUtil.composeConjunction(inputs.get(0).getCluster().getRexBuilder(), nonEquiList, false);
    }

    private static void splitJoinCondition(List<RelDataTypeField> sysFieldList, List<RelNode> inputs,
            RexNode condition, List<List<RexNode>> joinKeys, List<Integer> filterNulls, List<SqlOperator> rangeOp,
            List<RexNode> nonEquiList) throws CalciteSemanticException {
        final int sysFieldCount = sysFieldList.size();
        final RelOptCluster cluster = inputs.get(0).getCluster();
        final RexBuilder rexBuilder = cluster.getRexBuilder();

        if (condition instanceof RexCall) {
            RexCall call = (RexCall) condition;
            if (call.getOperator() == SqlStdOperatorTable.AND) {
                for (RexNode operand : call.getOperands()) {
                    splitJoinCondition(sysFieldList, inputs, operand, joinKeys, filterNulls, rangeOp, nonEquiList);
                }
                return;
            }

            RexNode leftKey = null;
            RexNode rightKey = null;
            int leftInput = 0;
            int rightInput = 0;
            List<RelDataTypeField> leftFields = null;
            List<RelDataTypeField> rightFields = null;
            boolean reverse = false;

            SqlKind kind = call.getKind();

            // Only consider range operators if we haven't already seen one
            if ((kind == SqlKind.EQUALS) || (filterNulls != null && kind == SqlKind.IS_NOT_DISTINCT_FROM)
                    || (rangeOp != null && rangeOp.isEmpty()
                            && (kind == SqlKind.GREATER_THAN || kind == SqlKind.GREATER_THAN_OR_EQUAL
                                    || kind == SqlKind.LESS_THAN || kind == SqlKind.LESS_THAN_OR_EQUAL))) {
                final List<RexNode> operands = call.getOperands();
                RexNode op0 = operands.get(0);
                RexNode op1 = operands.get(1);

                final ImmutableBitSet projRefs0 = InputFinder.bits(op0);
                final ImmutableBitSet projRefs1 = InputFinder.bits(op1);

                final ImmutableBitSet[] inputsRange = new ImmutableBitSet[inputs.size()];
                int totalFieldCount = 0;
                for (int i = 0; i < inputs.size(); i++) {
                    final int firstField = totalFieldCount + sysFieldCount;
                    totalFieldCount = firstField + inputs.get(i).getRowType().getFieldCount();
                    inputsRange[i] = ImmutableBitSet.range(firstField, totalFieldCount);
                }

                boolean foundBothInputs = false;
                for (int i = 0; i < inputs.size() && !foundBothInputs; i++) {
                    if (projRefs0.intersects(inputsRange[i])
                            && projRefs0.union(inputsRange[i]).equals(inputsRange[i])) {
                        if (leftKey == null) {
                            leftKey = op0;
                            leftInput = i;
                            leftFields = inputs.get(leftInput).getRowType().getFieldList();
                        } else {
                            rightKey = op0;
                            rightInput = i;
                            rightFields = inputs.get(rightInput).getRowType().getFieldList();
                            reverse = true;
                            foundBothInputs = true;
                        }
                    } else if (projRefs1.intersects(inputsRange[i])
                            && projRefs1.union(inputsRange[i]).equals(inputsRange[i])) {
                        if (leftKey == null) {
                            leftKey = op1;
                            leftInput = i;
                            leftFields = inputs.get(leftInput).getRowType().getFieldList();
                        } else {
                            rightKey = op1;
                            rightInput = i;
                            rightFields = inputs.get(rightInput).getRowType().getFieldList();
                            foundBothInputs = true;
                        }
                    }
                }

                if ((leftKey != null) && (rightKey != null)) {
                    // adjustment array
                    int[] adjustments = new int[totalFieldCount];
                    for (int i = 0; i < inputs.size(); i++) {
                        final int adjustment = inputsRange[i].nextSetBit(0);
                        for (int j = adjustment; j < inputsRange[i].length(); j++) {
                            adjustments[j] = -adjustment;
                        }
                    }

                    // replace right Key input ref
                    rightKey = rightKey.accept(
                            new RelOptUtil.RexInputConverter(rexBuilder, rightFields, rightFields, adjustments));

                    // left key only needs to be adjusted if there are system
                    // fields, but do it for uniformity
                    leftKey = leftKey.accept(
                            new RelOptUtil.RexInputConverter(rexBuilder, leftFields, leftFields, adjustments));

                    RelDataType leftKeyType = leftKey.getType();
                    RelDataType rightKeyType = rightKey.getType();

                    if (leftKeyType != rightKeyType) {
                        // perform casting using Hive rules
                        TypeInfo rType = TypeConverter.convert(rightKeyType);
                        TypeInfo lType = TypeConverter.convert(leftKeyType);
                        TypeInfo tgtType = FunctionRegistry.getCommonClassForComparison(lType, rType);

                        if (tgtType == null) {
                            throw new CalciteSemanticException(
                                    "Cannot find common type for join keys " + leftKey + " (type " + leftKeyType
                                            + ") and " + rightKey + " (type " + rightKeyType + ")");
                        }
                        RelDataType targetKeyType = TypeConverter.convert(tgtType, rexBuilder.getTypeFactory());

                        if (leftKeyType != targetKeyType
                                && TypeInfoUtils.isConversionRequiredForComparison(tgtType, lType)) {
                            leftKey = rexBuilder.makeCast(targetKeyType, leftKey);
                        }

                        if (rightKeyType != targetKeyType
                                && TypeInfoUtils.isConversionRequiredForComparison(tgtType, rType)) {
                            rightKey = rexBuilder.makeCast(targetKeyType, rightKey);
                        }
                    }
                }
            }

            if ((leftKey != null) && (rightKey != null)) {
                // found suitable join keys
                // add them to key list, ensuring that if there is a
                // non-equi join predicate, it appears at the end of the
                // key list; also mark the null filtering property
                addJoinKey(joinKeys.get(leftInput), leftKey, (rangeOp != null) && !rangeOp.isEmpty());
                addJoinKey(joinKeys.get(rightInput), rightKey, (rangeOp != null) && !rangeOp.isEmpty());
                if (filterNulls != null && kind == SqlKind.EQUALS) {
                    // nulls are considered not matching for equality comparison
                    // add the position of the most recently inserted key
                    filterNulls.add(joinKeys.get(leftInput).size() - 1);
                }
                if (rangeOp != null && kind != SqlKind.EQUALS && kind != SqlKind.IS_DISTINCT_FROM) {
                    if (reverse) {
                        kind = reverse(kind);
                    }
                    rangeOp.add(op(kind, call.getOperator()));
                }
                return;
            } // else fall through and add this condition as nonEqui condition
        }

        // The operator is not of RexCall type
        // So we fail. Fall through.
        // Add this condition to the list of non-equi-join conditions.
        nonEquiList.add(condition);
    }

    private static SqlKind reverse(SqlKind kind) {
        switch (kind) {
        case GREATER_THAN:
            return SqlKind.LESS_THAN;
        case GREATER_THAN_OR_EQUAL:
            return SqlKind.LESS_THAN_OR_EQUAL;
        case LESS_THAN:
            return SqlKind.GREATER_THAN;
        case LESS_THAN_OR_EQUAL:
            return SqlKind.GREATER_THAN_OR_EQUAL;
        default:
            return kind;
        }
    }

    private static void addJoinKey(List<RexNode> joinKeyList, RexNode key, boolean preserveLastElementInList) {
        if (!joinKeyList.isEmpty() && preserveLastElementInList) {
            joinKeyList.add(joinKeyList.size() - 1, key);
        } else {
            joinKeyList.add(key);
        }
    }

    /**
     * Creates a relational expression that projects the given fields of the
     * input.
     *
     * <p>Optimizes if the fields are the identity projection.
     *
     * @param relBuilder RelBuilder
     * @param child Input relational expression
     * @param posList Source of each projected field
     * @return Relational expression that projects given fields
     */
    public static RelNode createProject(final RelBuilder relBuilder, final RelNode child,
            final List<Integer> posList) {
        RelDataType rowType = child.getRowType();
        final List<String> fieldNames = rowType.getFieldNames();
        final RexBuilder rexBuilder = child.getCluster().getRexBuilder();
        return createProject(child, new AbstractList<RexNode>() {
            public int size() {
                return posList.size();
            }

            public RexNode get(int index) {
                final int pos = posList.get(index);
                return rexBuilder.makeInputRef(child, pos);
            }
        }, new AbstractList<String>() {
            public int size() {
                return posList.size();
            }

            public String get(int index) {
                final int pos = posList.get(index);
                return fieldNames.get(pos);
            }
        }, true, relBuilder);
    }

    public static RexNode splitCorrelatedFilterCondition(Filter filter, List<RexNode> joinKeys,
            List<RexNode> correlatedJoinKeys, boolean extractCorrelatedFieldAccess) {
        final List<RexNode> nonEquiList = new ArrayList<>();

        splitCorrelatedFilterCondition(filter, filter.getCondition(), joinKeys, correlatedJoinKeys, nonEquiList,
                extractCorrelatedFieldAccess);

        // Convert the remainders into a list that are AND'ed together.
        return RexUtil.composeConjunction(filter.getCluster().getRexBuilder(), nonEquiList, true);
    }

    private static void splitCorrelatedFilterCondition(Filter filter, RexNode condition, List<RexNode> joinKeys,
            List<RexNode> correlatedJoinKeys, List<RexNode> nonEquiList, boolean extractCorrelatedFieldAccess) {
        if (condition instanceof RexCall) {
            RexCall call = (RexCall) condition;
            if (call.getOperator().getKind() == SqlKind.AND) {
                for (RexNode operand : call.getOperands()) {
                    splitCorrelatedFilterCondition(filter, operand, joinKeys, correlatedJoinKeys, nonEquiList,
                            extractCorrelatedFieldAccess);
                }
                return;
            }

            if (call.getOperator().getKind() == SqlKind.EQUALS) {
                final List<RexNode> operands = call.getOperands();
                RexNode op0 = operands.get(0);
                RexNode op1 = operands.get(1);

                if (extractCorrelatedFieldAccess) {
                    if (!RexUtil.containsFieldAccess(op0) && (op1 instanceof RexFieldAccess)) {
                        joinKeys.add(op0);
                        correlatedJoinKeys.add(op1);
                        return;
                    } else if ((op0 instanceof RexFieldAccess) && !RexUtil.containsFieldAccess(op1)) {
                        correlatedJoinKeys.add(op0);
                        joinKeys.add(op1);
                        return;
                    }
                } else {
                    if (!(RexUtil.containsInputRef(op0)) && (op1 instanceof RexInputRef)) {
                        correlatedJoinKeys.add(op0);
                        joinKeys.add(op1);
                        return;
                    } else if ((op0 instanceof RexInputRef) && !(RexUtil.containsInputRef(op1))) {
                        joinKeys.add(op0);
                        correlatedJoinKeys.add(op1);
                        return;
                    }
                }
            }
        }

        // The operator is not of RexCall type
        // So we fail. Fall through.
        // Add this condition to the list of non-equi-join conditions.
        nonEquiList.add(condition);
    }

    /**
     * Creates a LogicalAggregate that removes all duplicates from the result of
     * an underlying relational expression.
     *
     * @param rel underlying rel
     * @return rel implementing SingleValueAgg
     */
    public static RelNode createSingleValueAggRel(RelOptCluster cluster, RelNode rel,
            RelFactories.AggregateFactory aggregateFactory) {
        // assert (rel.getRowType().getFieldCount() == 1);
        final int aggCallCnt = rel.getRowType().getFieldCount();
        final List<AggregateCall> aggCalls = new ArrayList<>();

        for (int i = 0; i < aggCallCnt; i++) {
            aggCalls.add(AggregateCall.create(SqlStdOperatorTable.SINGLE_VALUE, false, false, ImmutableList.of(i),
                    -1, 0, rel, null, null));
        }

        return aggregateFactory.createAggregate(rel, false, ImmutableBitSet.of(), null, aggCalls);
    }

    /**
     * Given a RelNode, it checks whether there is any filtering condition
     * below. Basically we check whether the operators
     * below altered the PK cardinality in any way
     */
    public static boolean isRowFilteringPlan(final RelMetadataQuery mq, RelNode operator) {
        final Multimap<Class<? extends RelNode>, RelNode> nodesBelowNonFkInput = mq.getNodeTypes(operator);
        for (Entry<Class<? extends RelNode>, Collection<RelNode>> e : nodesBelowNonFkInput.asMap().entrySet()) {
            if (e.getKey() == TableScan.class) {
                if (e.getValue().size() > 1) {
                    // Bail out as we may not have more than one TS on non-FK side
                    return true;
                }
            } else if (e.getKey() == Project.class) {
                // We check there is no windowing expression
                for (RelNode node : e.getValue()) {
                    Project p = (Project) node;
                    for (RexNode expr : p.getChildExps()) {
                        if (expr instanceof RexOver) {
                            // Bail out as it may change cardinality
                            return true;
                        }
                    }
                }
            } else if (e.getKey() == Aggregate.class) {
                // We check there is are not grouping sets
                for (RelNode node : e.getValue()) {
                    Aggregate a = (Aggregate) node;
                    if (a.getGroupType() != Group.SIMPLE) {
                        // Bail out as it may change cardinality
                        return true;
                    }
                }
            } else if (e.getKey() == Sort.class) {
                // We check whether there is a limit clause
                for (RelNode node : e.getValue()) {
                    Sort s = (Sort) node;
                    if (s.fetch != null || s.offset != null) {
                        // Bail out as it may change cardinality
                        return true;
                    }
                }
            } else {
                // Bail out, we cannot rewrite the expression if non-fk side cardinality
                // is being altered
                return true;
            }
        }
        // It passed all the tests
        return false;
    }

    /**
     * Returns a triple where first value represents whether we could extract a FK-PK join
     * or not, the second value is a pair with the column from left and right input that
     * are used for the FK-PK join, and the third value are the predicates that are not
     * part of the FK-PK condition. Currently we can only extract one FK-PK join.
     */
    public static PKFKJoinInfo extractPKFKJoin(Join join, List<RexNode> joinFilters, boolean leftInputPotentialFK,
            RelMetadataQuery mq) {
        final List<RexNode> residualPreds = new ArrayList<>();
        final JoinRelType joinType = join.getJoinType();
        final RelNode fkInput = leftInputPotentialFK ? join.getLeft() : join.getRight();
        final PKFKJoinInfo cannotExtract = PKFKJoinInfo.of(false, null, null);

        if (joinType != JoinRelType.INNER) {
            // If it is not an inner, we transform it as the metadata
            // providers for expressions do not pull information through
            // outer join (as it would not be correct)
            join = join.copy(join.getTraitSet(), join.getCluster().getRexBuilder().makeLiteral(true),
                    join.getLeft(), join.getRight(), JoinRelType.INNER, false);
        }

        // 1) Gather all tables from the FK side and the table from the
        // non-FK side
        final Set<RelTableRef> leftTables = mq.getTableReferences(join.getLeft());
        final Set<RelTableRef> rightTables = Sets.difference(mq.getTableReferences(join),
                mq.getTableReferences(join.getLeft()));
        final Set<RelTableRef> fkTables = join.getLeft() == fkInput ? leftTables : rightTables;
        final Set<RelTableRef> nonFkTables = join.getLeft() == fkInput ? rightTables : leftTables;

        // 2) Check whether there is a FK relationship
        Set<RexCall> candidatePredicates = new HashSet<>();
        EquivalenceClasses ec = new EquivalenceClasses();
        for (RexNode conj : joinFilters) {
            if (!conj.isA(SqlKind.EQUALS)) {
                // Not an equality, continue
                residualPreds.add(conj);
                continue;
            }
            RexCall equiCond = (RexCall) conj;
            RexNode eqOp1 = equiCond.getOperands().get(0);
            if (!RexUtil.isReferenceOrAccess(eqOp1, true)) {
                // Ignore
                residualPreds.add(conj);
                continue;
            }
            Set<RexNode> eqOp1ExprsLineage = mq.getExpressionLineage(join, eqOp1);
            if (eqOp1ExprsLineage == null) {
                // Cannot be mapped, continue
                residualPreds.add(conj);
                continue;
            }
            RexNode eqOp2 = equiCond.getOperands().get(1);
            if (!RexUtil.isReferenceOrAccess(eqOp2, true)) {
                // Ignore
                residualPreds.add(conj);
                continue;
            }
            Set<RexNode> eqOp2ExprsLineage = mq.getExpressionLineage(join, eqOp2);
            if (eqOp2ExprsLineage == null) {
                // Cannot be mapped, continue
                residualPreds.add(conj);
                continue;
            }
            List<RexTableInputRef> eqOp2ExprsFiltered = null;
            for (RexNode eqOpExprLineage1 : eqOp1ExprsLineage) {
                RexTableInputRef inputRef1 = extractTableInputRef(eqOpExprLineage1);
                if (inputRef1 == null) {
                    // This condition could not be map into an input reference
                    continue;
                }
                if (eqOp2ExprsFiltered == null) {
                    // First iteration
                    eqOp2ExprsFiltered = new ArrayList<>();
                    for (RexNode eqOpExprLineage2 : eqOp2ExprsLineage) {
                        RexTableInputRef inputRef2 = extractTableInputRef(eqOpExprLineage2);
                        if (inputRef2 == null) {
                            // Bail out as this condition could not be map into an input reference
                            continue;
                        }
                        // Add to list of expressions for follow-up iterations
                        eqOp2ExprsFiltered.add(inputRef2);
                        // Add to equivalence classes and backwards mapping
                        ec.addEquivalence(inputRef1, inputRef2, equiCond);
                        candidatePredicates.add(equiCond);
                    }
                } else {
                    // Rest of iterations, only adding, no checking
                    for (RexTableInputRef inputRef2 : eqOp2ExprsFiltered) {
                        ec.addEquivalence(inputRef1, inputRef2, equiCond);
                    }
                }
            }
            if (!candidatePredicates.contains(conj)) {
                // We add it to residual already
                residualPreds.add(conj);
            }
        }
        if (ec.getEquivalenceClassesMap().isEmpty()) {
            // This may be a cartesian product, we bail out
            return cannotExtract;
        }

        // 4) For each table, check whether there is a matching on the non-FK side.
        // If there is and it is the only condition, we are ready to transform
        for (final RelTableRef nonFkTable : nonFkTables) {
            final List<String> nonFkTableQName = nonFkTable.getQualifiedName();
            for (RelTableRef tRef : fkTables) {
                List<RelReferentialConstraint> constraints = tRef.getTable().getReferentialConstraints();
                for (RelReferentialConstraint constraint : constraints) {
                    if (constraint.getTargetQualifiedName().equals(nonFkTableQName)) {
                        EquivalenceClasses ecT = EquivalenceClasses.copy(ec);
                        Set<RexNode> removedOriginalPredicates = new HashSet<>();
                        ImmutableBitSet.Builder lBitSet = ImmutableBitSet.builder();
                        ImmutableBitSet.Builder rBitSet = ImmutableBitSet.builder();
                        boolean allContained = true;
                        for (int pos = 0; pos < constraint.getNumColumns(); pos++) {
                            int foreignKeyPos = constraint.getColumnPairs().get(pos).source;
                            RelDataType foreignKeyColumnType = tRef.getTable().getRowType().getFieldList()
                                    .get(foreignKeyPos).getType();
                            RexTableInputRef foreignKeyColumnRef = RexTableInputRef.of(tRef, foreignKeyPos,
                                    foreignKeyColumnType);
                            int uniqueKeyPos = constraint.getColumnPairs().get(pos).target;
                            RexTableInputRef uniqueKeyColumnRef = RexTableInputRef.of(nonFkTable, uniqueKeyPos,
                                    nonFkTable.getTable().getRowType().getFieldList().get(uniqueKeyPos).getType());
                            if (ecT.getEquivalenceClassesMap().containsKey(uniqueKeyColumnRef)
                                    && ecT.getEquivalenceClassesMap().get(uniqueKeyColumnRef)
                                            .contains(foreignKeyColumnRef)) {
                                // Remove this condition from eq classes as we have checked that it is present
                                // in the join condition. In turn, populate the columns that are referenced
                                // from the join inputs
                                for (RexCall originalPred : ecT.removeEquivalence(uniqueKeyColumnRef,
                                        foreignKeyColumnRef)) {
                                    ImmutableBitSet leftCols = RelOptUtil.InputFinder
                                            .bits(originalPred.getOperands().get(0));
                                    ImmutableBitSet rightCols = RelOptUtil.InputFinder
                                            .bits(originalPred.getOperands().get(1));
                                    // Get length and flip column references if join condition specified in
                                    // reverse order to join sources
                                    int nFieldsLeft = join.getLeft().getRowType().getFieldList().size();
                                    int nFieldsRight = join.getRight().getRowType().getFieldList().size();
                                    int nSysFields = join.getSystemFieldList().size();
                                    ImmutableBitSet rightFieldsBitSet = ImmutableBitSet.range(
                                            nSysFields + nFieldsLeft, nSysFields + nFieldsLeft + nFieldsRight);
                                    if (rightFieldsBitSet.contains(leftCols)) {
                                        ImmutableBitSet t = leftCols;
                                        leftCols = rightCols;
                                        rightCols = t;
                                    }
                                    lBitSet.set(leftCols.nextSetBit(0) - nSysFields);
                                    rBitSet.set(rightCols.nextSetBit(0) - (nSysFields + nFieldsLeft));
                                    removedOriginalPredicates.add(originalPred);
                                }
                            } else {
                                // No relationship, we cannot do anything
                                allContained = false;
                                break;
                            }
                        }
                        if (allContained) {
                            // This is a PK-FK, reassign equivalence classes and remove conditions
                            // TODO: Support inference of multiple PK-FK relationships

                            // 4.1) Add to residual whatever is remaining
                            candidatePredicates.removeAll(removedOriginalPredicates);
                            residualPreds.addAll(candidatePredicates);
                            // 4.2) Return result
                            return PKFKJoinInfo.of(true, Pair.of(lBitSet.build(), rBitSet.build()), residualPreds);
                        }
                    }
                }
            }
        }

        return cannotExtract;
    }

    public static class PKFKJoinInfo {
        public final boolean isPkFkJoin;
        public final Pair<ImmutableBitSet, ImmutableBitSet> pkFkJoinColumns;
        public final List<RexNode> additionalPredicates;

        private PKFKJoinInfo(boolean isPkFkJoin, Pair<ImmutableBitSet, ImmutableBitSet> pkFkJoinColumns,
                List<RexNode> additionalPredicates) {
            this.isPkFkJoin = isPkFkJoin;
            this.pkFkJoinColumns = pkFkJoinColumns;
            this.additionalPredicates = additionalPredicates == null ? null
                    : ImmutableList.copyOf(additionalPredicates);
        }

        public static PKFKJoinInfo of(boolean isPkFkJoin, Pair<ImmutableBitSet, ImmutableBitSet> pkFkJoinColumns,
                List<RexNode> additionalPredicates) {
            return new PKFKJoinInfo(isPkFkJoin, pkFkJoinColumns, additionalPredicates);
        }
    }

    public static RewritablePKFKJoinInfo isRewritablePKFKJoin(Join join, boolean leftInputPotentialFK,
            RelMetadataQuery mq) {
        final JoinRelType joinType = join.getJoinType();
        final RexNode cond = join.getCondition();
        final RelNode fkInput = leftInputPotentialFK ? join.getLeft() : join.getRight();
        final RelNode nonFkInput = leftInputPotentialFK ? join.getRight() : join.getLeft();
        final RewritablePKFKJoinInfo nonRewritable = RewritablePKFKJoinInfo.of(false, null);

        if (joinType != JoinRelType.INNER) {
            // If it is not an inner, we transform it as the metadata
            // providers for expressions do not pull information through
            // outer join (as it would not be correct)
            join = join.copy(join.getTraitSet(), cond, join.getLeft(), join.getRight(), JoinRelType.INNER, false);
        }

        // 1) Check whether there is any filtering condition on the
        // non-FK side. Basically we check whether the operators
        // below altered the PK cardinality in any way
        if (HiveRelOptUtil.isRowFilteringPlan(mq, nonFkInput)) {
            return nonRewritable;
        }

        // 2) Check whether there is an FK relationship
        final Map<RexTableInputRef, RexNode> refToRex = new HashMap<>();
        final EquivalenceClasses ec = new EquivalenceClasses();
        for (RexNode conj : RelOptUtil.conjunctions(cond)) {
            if (!conj.isA(SqlKind.EQUALS)) {
                // Not an equality, we bail out
                return nonRewritable;
            }
            RexCall equiCond = (RexCall) conj;
            RexNode eqOp1 = equiCond.getOperands().get(0);
            Set<RexNode> eqOp1ExprsLineage = mq.getExpressionLineage(join, eqOp1);
            if (eqOp1ExprsLineage == null) {
                // Cannot be mapped, bail out
                return nonRewritable;
            }
            RexNode eqOp2 = equiCond.getOperands().get(1);
            Set<RexNode> eqOp2ExprsLineage = mq.getExpressionLineage(join, eqOp2);
            if (eqOp2ExprsLineage == null) {
                // Cannot be mapped, bail out
                return nonRewritable;
            }
            List<RexTableInputRef> eqOp2ExprsFiltered = null;
            for (RexNode eqOpExprLineage1 : eqOp1ExprsLineage) {
                RexTableInputRef inputRef1 = extractTableInputRef(eqOpExprLineage1);
                if (inputRef1 == null) {
                    // Bail out as this condition could not be map into an input reference
                    return nonRewritable;
                }
                refToRex.put(inputRef1, eqOp1);
                if (eqOp2ExprsFiltered == null) {
                    // First iteration
                    eqOp2ExprsFiltered = new ArrayList<>();
                    for (RexNode eqOpExprLineage2 : eqOp2ExprsLineage) {
                        RexTableInputRef inputRef2 = extractTableInputRef(eqOpExprLineage2);
                        if (inputRef2 == null) {
                            // Bail out as this condition could not be map into an input reference
                            return nonRewritable;
                        }
                        // Add to list of expressions for follow-up iterations
                        eqOp2ExprsFiltered.add(inputRef2);
                        // Add to equivalence classes and backwards mapping
                        ec.addEquivalence(inputRef1, inputRef2);
                        refToRex.put(inputRef2, eqOp2);
                    }
                } else {
                    // Rest of iterations, only adding, no checking
                    for (RexTableInputRef inputRef2 : eqOp2ExprsFiltered) {
                        ec.addEquivalence(inputRef1, inputRef2);
                    }
                }
            }
        }
        if (ec.getEquivalenceClassesMap().isEmpty()) {
            // This may be a cartesian product, we bail out
            return nonRewritable;
        }

        // 3) Gather all tables from the FK side and the table from the
        // non-FK side
        final Set<RelTableRef> leftTables = mq.getTableReferences(join.getLeft());
        final Set<RelTableRef> rightTables = Sets.difference(mq.getTableReferences(join),
                mq.getTableReferences(join.getLeft()));
        final Set<RelTableRef> fkTables = join.getLeft() == fkInput ? leftTables : rightTables;
        final Set<RelTableRef> nonFkTables = join.getLeft() == fkInput ? rightTables : leftTables;
        assert nonFkTables.size() == 1;
        final RelTableRef nonFkTable = nonFkTables.iterator().next();
        final List<String> nonFkTableQName = nonFkTable.getQualifiedName();

        // 4) For each table, check whether there is a matching on the non-FK side.
        // If there is and it is the only condition, we are ready to transform
        boolean canBeRewritten = false;
        List<RexNode> nullableNodes = null;
        for (RelTableRef tRef : fkTables) {
            List<RelReferentialConstraint> constraints = tRef.getTable().getReferentialConstraints();
            for (RelReferentialConstraint constraint : constraints) {
                if (constraint.getTargetQualifiedName().equals(nonFkTableQName)) {
                    nullableNodes = new ArrayList<>();
                    EquivalenceClasses ecT = EquivalenceClasses.copy(ec);
                    boolean allContained = true;
                    for (int pos = 0; pos < constraint.getNumColumns(); pos++) {
                        int foreignKeyPos = constraint.getColumnPairs().get(pos).source;
                        RelDataType foreignKeyColumnType = tRef.getTable().getRowType().getFieldList()
                                .get(foreignKeyPos).getType();
                        RexTableInputRef foreignKeyColumnRef = RexTableInputRef.of(tRef, foreignKeyPos,
                                foreignKeyColumnType);
                        int uniqueKeyPos = constraint.getColumnPairs().get(pos).target;
                        RexTableInputRef uniqueKeyColumnRef = RexTableInputRef.of(nonFkTable, uniqueKeyPos,
                                nonFkTable.getTable().getRowType().getFieldList().get(uniqueKeyPos).getType());
                        if (ecT.getEquivalenceClassesMap().containsKey(uniqueKeyColumnRef) && ecT
                                .getEquivalenceClassesMap().get(uniqueKeyColumnRef).contains(foreignKeyColumnRef)) {
                            if (foreignKeyColumnType.isNullable()) {
                                if (joinType == JoinRelType.INNER) {
                                    // If it is nullable and it is an INNER, we just need a IS NOT NULL filter
                                    RexNode originalCondOp = refToRex.get(foreignKeyColumnRef);
                                    assert originalCondOp != null;
                                    nullableNodes.add(originalCondOp);
                                } else {
                                    // If it is nullable and this is not an INNER, we cannot execute any transformation
                                    allContained = false;
                                    break;
                                }
                            }
                            // Remove this condition from eq classes as we have checked that it is present
                            // in the join condition
                            ecT.removeEquivalence(uniqueKeyColumnRef, foreignKeyColumnRef);
                        } else {
                            // No relationship, we cannot do anything
                            allContained = false;
                            break;
                        }
                    }
                    if (allContained && ecT.getEquivalenceClassesMap().isEmpty()) {
                        // We made it
                        canBeRewritten = true;
                        break;
                    }
                }
            }
        }

        return RewritablePKFKJoinInfo.of(canBeRewritten, nullableNodes);
    }

    public static class RewritablePKFKJoinInfo {
        public final boolean rewritable;
        public final List<RexNode> nullableNodes;

        private RewritablePKFKJoinInfo(boolean rewritable, List<RexNode> nullableNodes) {
            this.rewritable = rewritable;
            this.nullableNodes = nullableNodes == null ? null : ImmutableList.copyOf(nullableNodes);
        }

        public static RewritablePKFKJoinInfo of(boolean rewritable, List<RexNode> nullableNodes) {
            return new RewritablePKFKJoinInfo(rewritable, nullableNodes);
        }
    }

    private static RexTableInputRef extractTableInputRef(RexNode node) {
        RexTableInputRef ref = null;
        if (node instanceof RexTableInputRef) {
            ref = (RexTableInputRef) node;
        } else if (RexUtil.isLosslessCast(node)
                && ((RexCall) node).getOperands().get(0) instanceof RexTableInputRef) {
            ref = (RexTableInputRef) ((RexCall) node).getOperands().get(0);
        }
        return ref;
    }

    /**
     * Class representing an equivalence class, i.e., a set of equivalent columns
     *
     * TODO: This is a subset of a private class in materialized view rewriting
     * in Calcite. It should be moved to its own class in Calcite so it can be
     * accessible here.
     */
    private static class EquivalenceClasses {

        // Contains the node to equivalence class nodes
        private final Map<RexTableInputRef, Set<RexTableInputRef>> nodeToEquivalenceClass;
        // Contains the pair of equivalences to original expression that they originate from
        private final Multimap<Pair<RexTableInputRef, RexTableInputRef>, RexCall> equivalenceToOriginalNode;

        protected EquivalenceClasses() {
            nodeToEquivalenceClass = new HashMap<>();
            equivalenceToOriginalNode = HashMultimap.create();
        }

        protected void addEquivalence(RexTableInputRef p1, RexTableInputRef p2, RexCall originalCond) {
            addEquivalence(p1, p2);
            equivalenceToOriginalNode.put(Pair.of(p1, p2), originalCond);
            equivalenceToOriginalNode.put(Pair.of(p2, p1), originalCond);
        }

        protected void addEquivalence(RexTableInputRef p1, RexTableInputRef p2) {
            Set<RexTableInputRef> c1 = nodeToEquivalenceClass.get(p1);
            Set<RexTableInputRef> c2 = nodeToEquivalenceClass.get(p2);
            if (c1 != null && c2 != null) {
                // Both present, we need to merge
                if (c1.size() < c2.size()) {
                    // We swap them to merge
                    Set<RexTableInputRef> c2Temp = c2;
                    c2 = c1;
                    c1 = c2Temp;
                }
                for (RexTableInputRef newRef : c2) {
                    c1.add(newRef);
                    nodeToEquivalenceClass.put(newRef, c1);
                }
            } else if (c1 != null) {
                // p1 present, we need to merge into it
                c1.add(p2);
                nodeToEquivalenceClass.put(p2, c1);
            } else if (c2 != null) {
                // p2 present, we need to merge into it
                c2.add(p1);
                nodeToEquivalenceClass.put(p1, c2);
            } else {
                // None are present, add to same equivalence class
                Set<RexTableInputRef> equivalenceClass = new LinkedHashSet<>();
                equivalenceClass.add(p1);
                equivalenceClass.add(p2);
                nodeToEquivalenceClass.put(p1, equivalenceClass);
                nodeToEquivalenceClass.put(p2, equivalenceClass);
            }
        }

        protected Map<RexTableInputRef, Set<RexTableInputRef>> getEquivalenceClassesMap() {
            return nodeToEquivalenceClass;
        }

        // Returns the original nodes that the equivalences were generated from
        protected Set<RexCall> removeEquivalence(RexTableInputRef p1, RexTableInputRef p2) {
            nodeToEquivalenceClass.get(p1).remove(p2);
            if (nodeToEquivalenceClass.get(p1).size() == 1) { // self
                nodeToEquivalenceClass.remove(p1);
            }
            nodeToEquivalenceClass.get(p2).remove(p1);
            if (nodeToEquivalenceClass.get(p2).size() == 1) { // self
                nodeToEquivalenceClass.remove(p2);
            }
            Set<RexCall> originalNodes = new HashSet<>();
            originalNodes.addAll(equivalenceToOriginalNode.removeAll(Pair.of(p1, p2)));
            originalNodes.addAll(equivalenceToOriginalNode.removeAll(Pair.of(p2, p1)));
            return originalNodes;
        }

        protected static EquivalenceClasses copy(EquivalenceClasses ec) {
            final EquivalenceClasses newEc = new EquivalenceClasses();
            for (Entry<RexTableInputRef, Set<RexTableInputRef>> e : ec.nodeToEquivalenceClass.entrySet()) {
                newEc.nodeToEquivalenceClass.put(e.getKey(), Sets.newLinkedHashSet(e.getValue()));
            }
            for (Entry<Pair<RexTableInputRef, RexTableInputRef>, Collection<RexCall>> e : ec.equivalenceToOriginalNode
                    .asMap().entrySet()) {
                newEc.equivalenceToOriginalNode.putAll(e.getKey(), e.getValue());
            }
            return newEc;
        }
    }

    public static Pair<RelOptTable, List<Integer>> getColumnOriginSet(RelNode rel, ImmutableBitSet colSet) {
        RelMetadataQuery mq = rel.getCluster().getMetadataQuery();
        RexBuilder rexBuilder = rel.getCluster().getRexBuilder();
        Map<RelTableRef, List<Integer>> tabToOriginColumns = new HashMap<>();
        for (int col : colSet) {
            final RexInputRef tempColRef = rexBuilder.makeInputRef(rel, col);
            Set<RexNode> columnOrigins = mq.getExpressionLineage(rel, tempColRef);
            if (null == columnOrigins || columnOrigins.isEmpty()) {
                // if even on
                return null;
            }
            // we have either one or multiple origins of the column, we need to make sure that all of the column
            for (RexNode orgCol : columnOrigins) {
                RexTableInputRef inputRef = extractTableInputRef(orgCol);
                if (inputRef == null) {
                    return null;
                }
                List<Integer> cols = tabToOriginColumns.get(inputRef.getTableRef());
                if (cols == null) {
                    cols = new ArrayList<>();
                }
                cols.add(inputRef.getIndex());
                tabToOriginColumns.put(inputRef.getTableRef(), cols);
            }
        }

        // return the first table which has same number of backtracked columns as colSet
        // ideally we should return all, in case one doesn't work we can fall back to another
        for (Entry<RelTableRef, List<Integer>> mapEntries : tabToOriginColumns.entrySet()) {
            RelTableRef tblRef = mapEntries.getKey();
            List<Integer> mapColList = mapEntries.getValue();
            if (mapColList.size() == colSet.cardinality()) {
                RelOptTable tbl = tblRef.getTable();
                return Pair.of(tbl, mapColList);
            }
        }
        return null;
    }
}