io.prestosql.sql.planner.EqualityInference.java Source code

Java tutorial

Introduction

Here is the source code for io.prestosql.sql.planner.EqualityInference.java

Source

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.prestosql.sql.planner;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Predicate;
import com.google.common.base.Predicates;
import com.google.common.collect.ComparisonChain;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ImmutableSetMultimap;
import com.google.common.collect.Iterables;
import com.google.common.collect.Ordering;
import com.google.common.collect.SetMultimap;
import io.prestosql.sql.tree.ComparisonExpression;
import io.prestosql.sql.tree.Expression;
import io.prestosql.sql.tree.ExpressionTreeRewriter;
import io.prestosql.sql.tree.InListExpression;
import io.prestosql.sql.tree.InPredicate;
import io.prestosql.util.DisjointSet;

import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Predicates.equalTo;
import static com.google.common.base.Predicates.not;
import static com.google.common.collect.Iterables.filter;
import static io.prestosql.sql.ExpressionUtils.extractConjuncts;
import static io.prestosql.sql.planner.DeterminismEvaluator.isDeterministic;
import static io.prestosql.sql.planner.NullabilityAnalyzer.mayReturnNullOnNonNullInput;
import static java.util.Objects.requireNonNull;

/**
 * Makes equality-based inferences to rewrite Expressions and to generate equality sets in terms of specified symbol scopes
 */
public class EqualityInference {
    // Ordering that ranks equivalent Expressions by preference; the minimum element is used as the canonical.
    // Current cost heuristic, applied in this order:
    // 1) Prefer fewer input symbols
    // 2) Prefer smaller expression trees
    // 3) Sort the expressions alphabetically - creates a stable consistent ordering (extremely useful for unit testing)
    // TODO: be more precise in determining the cost of an expression
    private static final Ordering<Expression> CANONICAL_ORDERING = Ordering.from((left, right) -> {
        int leftSymbolCount = SymbolsExtractor.extractAll(left).size();
        int rightSymbolCount = SymbolsExtractor.extractAll(right).size();

        int leftTreeSize = SubExpressionExtractor.extract(left).size();
        int rightTreeSize = SubExpressionExtractor.extract(right).size();

        String leftText = left.toString();
        String rightText = right.toString();

        return ComparisonChain.start()
                .compare(leftSymbolCount, rightSymbolCount)
                .compare(leftTreeSize, rightTreeSize)
                .compare(leftText, rightText)
                .result();
    });

    // Groups of mutually equal expressions; each group is keyed by its CANONICAL_ORDERING-minimal member
    private final SetMultimap<Expression, Expression> equalitySets; // Indexed by canonical expression
    // Reverse lookup of equalitySets: every member of a group points at that group's key
    private final Map<Expression, Expression> canonicalMap; // Map each known expression to canonical expression
    // Immutable copy of the derived-expression set handed to the constructor; these expressions are
    // skipped when equalities are partitioned by scope
    private final Set<Expression> derivedExpressions;

    /**
     * Builds the inference from already-computed equality groups.
     *
     * @param equalityGroups groups of mutually equal expressions; empty groups are ignored
     * @param derivedExpressions expressions to be treated as derived; copied into an immutable set
     */
    private EqualityInference(Iterable<Set<Expression>> equalityGroups, Set<Expression> derivedExpressions) {
        // Key every non-empty group by its most preferable member
        ImmutableSetMultimap.Builder<Expression, Expression> groupsByCanonical = ImmutableSetMultimap.builder();
        for (Set<Expression> group : equalityGroups) {
            if (group.isEmpty()) {
                continue;
            }
            Expression canonical = CANONICAL_ORDERING.min(group);
            groupsByCanonical.putAll(canonical, group);
        }
        equalitySets = groupsByCanonical.build();

        // Invert the multimap so each member can be resolved to the canonical of its group
        ImmutableMap.Builder<Expression, Expression> memberToCanonical = ImmutableMap.builder();
        for (Map.Entry<Expression, Expression> canonicalAndMember : equalitySets.entries()) {
            memberToCanonical.put(canonicalAndMember.getValue(), canonicalAndMember.getKey());
        }
        canonicalMap = memberToCanonical.build();

        this.derivedExpressions = ImmutableSet.copyOf(derivedExpressions);
    }

    /**
     * Attempts to rewrite a deterministic Expression so that it only uses symbols allowed by the
     * symbol scope, given the known equalities.
     *
     * @param expression the expression to rewrite; must be deterministic
     * @param symbolScope predicate selecting the symbols the rewritten expression may reference
     * @return the rewritten expression, or null if the rewrite was unsuccessful
     * @throws IllegalArgumentException if the expression is not deterministic
     */
    public Expression rewriteExpression(Expression expression, Predicate<Symbol> symbolScope) {
        boolean deterministic = isDeterministic(expression);
        checkArgument(deterministic, "Only deterministic expressions may be considered for rewrite");

        boolean allowFullReplacement = true;
        return rewriteExpression(expression, symbolScope, allowFullReplacement);
    }

    /**
     * Attempts to rewrite an Expression so that it only uses symbols allowed by the symbol scope,
     * given the known equalities. Unlike {@link #rewriteExpression(Expression, Predicate)}, no
     * determinism check is performed on the input.
     *
     * @param expression the expression to rewrite
     * @param symbolScope predicate selecting the symbols the rewritten expression may reference
     * @return the rewritten expression, or null if the rewrite was unsuccessful
     */
    public Expression rewriteExpressionAllowNonDeterministic(Expression expression, Predicate<Symbol> symbolScope) {
        boolean allowFullReplacement = true;
        return rewriteExpression(expression, symbolScope, allowFullReplacement);
    }

    /**
     * Shared implementation of the rewrite.
     *
     * @param expression the expression to rewrite
     * @param symbolScope predicate selecting the symbols the rewritten expression may reference
     * @param allowFullReplacement when false, the expression itself is excluded from the replaceable
     *        sub-expressions, so only its strict sub-expressions may be substituted
     * @return the rewritten expression, or null if it still references symbols outside the scope
     */
    private Expression rewriteExpression(Expression expression, Predicate<Symbol> symbolScope,
            boolean allowFullReplacement) {
        Iterable<Expression> replaceable = SubExpressionExtractor.extract(expression);
        if (!allowFullReplacement) {
            replaceable = filter(replaceable, not(equalTo(expression)));
        }

        // Record a scoped canonical for every sub-expression that has one
        ImmutableMap.Builder<Expression, Expression> replacements = ImmutableMap.builder();
        for (Expression candidate : replaceable) {
            Expression scopedCanonical = getScopedCanonical(candidate, symbolScope);
            if (scopedCanonical == null) {
                continue;
            }
            replacements.put(candidate, scopedCanonical);
        }

        // Perform a naive single-pass traversal to try to rewrite non-compliant portions of the tree. Prefers to replace
        // larger subtrees over smaller subtrees
        // TODO: this rewrite can probably be made more sophisticated
        ExpressionNodeInliner inliner = new ExpressionNodeInliner(replacements.build());
        Expression rewritten = ExpressionTreeRewriter.rewriteWith(inliner, expression);

        // If the rewritten expression is still not compliant with the symbol scope, just give up
        boolean withinScope = symbolToExpressionPredicate(symbolScope).apply(rewritten);
        return withinScope ? rewritten : null;
    }

    /**
     * Dumps the inference equalities as equality expressions that are partitioned by the symbolScope.
     * All stored equalities are returned in a compact set and are classified into three groups as
     * determined by the symbol scope:
     * <ol>
     * <li>equalities that fit entirely within the symbol scope</li>
     * <li>equalities that fit entirely outside of the symbol scope</li>
     * <li>equalities that straddle the symbol scope</li>
     * </ol>
     * <pre>
     * Example:
     *   Stored Equalities:
     *     a = b = c
     *     d = e = f = g
     *
     *   Symbol Scope:
     *     a, b, d, e
     *
     *   Output EqualityPartition:
     *     Scope Equalities:
     *       a = b
     *       d = e
     *     Complement Scope Equalities
     *       f = g
     *     Scope Straddling Equalities
     *       a = c
     *       d = f
     * </pre>
     *
     * @param symbolScope predicate selecting the symbols that are considered inside the scope
     * @return the three groups of equality expressions
     */
    public EqualityPartition generateEqualitiesPartitionedBy(Predicate<Symbol> symbolScope) {
        ImmutableSet.Builder<Expression> insideEqualities = ImmutableSet.builder();
        ImmutableSet.Builder<Expression> outsideEqualities = ImmutableSet.builder();
        ImmutableSet.Builder<Expression> straddlingEqualities = ImmutableSet.builder();

        Predicate<Symbol> complementScope = not(symbolScope);

        for (Collection<Expression> equalitySet : equalitySets.asMap().values()) {
            Set<Expression> insideExpressions = new LinkedHashSet<>();
            Set<Expression> outsideExpressions = new LinkedHashSet<>();
            Set<Expression> unplacedExpressions = new LinkedHashSet<>();

            // Try to push each non-derived expression into one side of the scope
            for (Expression member : equalitySet) {
                if (derivedExpressions.contains(member)) {
                    continue;
                }
                Expression insideRewrite = rewriteExpression(member, symbolScope, false);
                Expression outsideRewrite = rewriteExpression(member, complementScope, false);
                if (insideRewrite != null) {
                    insideExpressions.add(insideRewrite);
                }
                if (outsideRewrite != null) {
                    outsideExpressions.add(outsideRewrite);
                }
                if (insideRewrite == null && outsideRewrite == null) {
                    unplacedExpressions.add(member);
                }
            }

            // Compile the equality expressions on each side of the scope
            Expression insideCanonical = getCanonical(insideExpressions);
            if (insideExpressions.size() >= 2) {
                addEqualitiesWithCanonical(insideEqualities, insideCanonical, insideExpressions);
            }
            Expression outsideCanonical = getCanonical(outsideExpressions);
            if (outsideExpressions.size() >= 2) {
                addEqualitiesWithCanonical(outsideEqualities, outsideCanonical, outsideExpressions);
            }

            // Compile the scope straddling equality expressions: link the two per-side canonicals
            // (when present) with every expression that could not be placed on either side
            List<Expression> connectingCandidates = new ArrayList<>();
            connectingCandidates.add(insideCanonical);
            connectingCandidates.add(outsideCanonical);
            connectingCandidates.addAll(unplacedExpressions);
            List<Expression> connectingExpressions = ImmutableList
                    .copyOf(filter(connectingCandidates, Predicates.notNull()));

            Expression connectingCanonical = getCanonical(connectingExpressions);
            if (connectingCanonical != null) {
                addEqualitiesWithCanonical(straddlingEqualities, connectingCanonical, connectingExpressions);
            }
        }

        return new EqualityPartition(insideEqualities.build(), outsideEqualities.build(),
                straddlingEqualities.build());
    }

    /**
     * Adds one "canonical = expression" equality to the sink for every expression that differs from the canonical.
     */
    private static void addEqualitiesWithCanonical(ImmutableSet.Builder<Expression> sink, Expression canonical,
            Iterable<Expression> expressions) {
        for (Expression expression : filter(expressions, not(equalTo(canonical)))) {
            sink.add(new ComparisonExpression(ComparisonExpression.Operator.EQUAL, canonical, expression));
        }
    }

    /**
     * Returns the most preferable expression (the minimum under CANONICAL_ORDERING) to be used as the
     * canonical expression, or null when no expressions are given.
     */
    private static Expression getCanonical(Iterable<Expression> expressions) {
        return Iterables.isEmpty(expressions) ? null : CANONICAL_ORDERING.min(expressions);
    }

    /**
     * Returns a canonical expression that is fully contained by the symbolScope and that is equivalent
     * to the specified expression. Returns null if the expression is not known to the inference or if
     * none of its equivalents fits within the scope.
     */
    @VisibleForTesting
    Expression getScopedCanonical(Expression expression, Predicate<Symbol> symbolScope) {
        Expression groupKey = canonicalMap.get(expression);
        if (groupKey != null) {
            Predicate<Expression> fitsScope = symbolToExpressionPredicate(symbolScope);
            Iterable<Expression> scopedEquivalents = filter(equalitySets.get(groupKey), fitsScope);
            return getCanonical(scopedEquivalents);
        }
        return null;
    }

    /**
     * Lifts a symbol predicate to an expression predicate that accepts an expression only when every
     * unique symbol extracted from it satisfies the symbol predicate.
     */
    private static Predicate<Expression> symbolToExpressionPredicate(final Predicate<Symbol> symbolScope) {
        return expression -> {
            for (Symbol symbol : SymbolsExtractor.extractUnique(expression)) {
                if (!symbolScope.apply(symbol)) {
                    return false;
                }
            }
            return true;
        };
    }

    /**
     * Determines whether an Expression may be successfully applied to the equality inference.
     * After single-value IN predicates are normalized to equalities, a candidate must be a deterministic
     * EQUAL comparison that does not return null on non-null input and whose two sides differ.
     */
    public static Predicate<Expression> isInferenceCandidate() {
        return expression -> {
            Expression normalized = normalizeInPredicateToEquality(expression);
            if (!(normalized instanceof ComparisonExpression)) {
                return false;
            }
            if (!isDeterministic(normalized) || mayReturnNullOnNonNullInput(normalized)) {
                return false;
            }
            ComparisonExpression comparison = (ComparisonExpression) normalized;
            if (comparison.getOperator() != ComparisonExpression.Operator.EQUAL) {
                return false;
            }
            // We should only consider equalities that have distinct left and right components
            return !comparison.getLeft().equals(comparison.getRight());
        };
    }

    /**
     * Rewrites an InPredicate whose value list is an InListExpression with exactly one value as an
     * equality between the tested value and that single list value. Any other expression is returned as is.
     */
    private static Expression normalizeInPredicateToEquality(Expression expression) {
        if (!(expression instanceof InPredicate)) {
            return expression;
        }
        InPredicate inPredicate = (InPredicate) expression;
        if (!(inPredicate.getValueList() instanceof InListExpression)) {
            return expression;
        }
        InListExpression valueList = (InListExpression) inPredicate.getValueList();
        if (valueList.getValues().size() != 1) {
            return expression;
        }
        Expression onlyValue = Iterables.getOnlyElement(valueList.getValues());
        return new ComparisonExpression(ComparisonExpression.Operator.EQUAL, inPredicate.getValue(), onlyValue);
    }

    /**
     * Provides a convenience Iterable of the conjuncts of the given expression that are not inference
     * candidates, i.e. those that would not be picked up by the inference.
     */
    public static Iterable<Expression> nonInferrableConjuncts(Expression expression) {
        Predicate<Expression> notCandidate = not(isInferenceCandidate());
        return filter(extractConjuncts(expression), notCandidate);
    }

    /**
     * Creates an EqualityInference from the inference candidates found among the conjuncts of the
     * given expressions.
     */
    public static EqualityInference createEqualityInference(Expression... expressions) {
        Builder inferenceBuilder = new Builder();
        for (Expression source : expressions) {
            inferenceBuilder.extractInferenceCandidates(source);
        }
        return inferenceBuilder.build();
    }

    /**
     * Immutable holder for the three groups of equalities produced when partitioning by a symbol scope:
     * those within the scope, those within the scope's complement, and those straddling both.
     */
    public static class EqualityPartition {
        private final List<Expression> scopeEqualities;
        private final List<Expression> scopeComplementEqualities;
        private final List<Expression> scopeStraddlingEqualities;

        /**
         * Copies each group into an immutable list; none of the arguments may be null.
         */
        public EqualityPartition(Iterable<Expression> scopeEqualities,
                Iterable<Expression> scopeComplementEqualities, Iterable<Expression> scopeStraddlingEqualities) {
            Iterable<Expression> inScope = requireNonNull(scopeEqualities, "scopeEqualities is null");
            this.scopeEqualities = ImmutableList.copyOf(inScope);

            Iterable<Expression> inComplement = requireNonNull(scopeComplementEqualities,
                    "scopeComplementEqualities is null");
            this.scopeComplementEqualities = ImmutableList.copyOf(inComplement);

            Iterable<Expression> straddling = requireNonNull(scopeStraddlingEqualities,
                    "scopeStraddlingEqualities is null");
            this.scopeStraddlingEqualities = ImmutableList.copyOf(straddling);
        }

        /** Returns the equalities that fit entirely within the symbol scope. */
        public List<Expression> getScopeEqualities() {
            return this.scopeEqualities;
        }

        /** Returns the equalities that fit entirely outside of the symbol scope. */
        public List<Expression> getScopeComplementEqualities() {
            return this.scopeComplementEqualities;
        }

        /** Returns the equalities that straddle the symbol scope. */
        public List<Expression> getScopeStraddlingEqualities() {
            return this.scopeStraddlingEqualities;
        }
    }

    public static class Builder {
        private final DisjointSet<Expression> equalities = new DisjointSet<>();
        private final Set<Expression> derivedExpressions = new LinkedHashSet<>();

        public Builder extractInferenceCandidates(Expression expression) {
            return addAllEqualities(filter(extractConjuncts(expression), isInferenceCandidate()));
        }

        public Builder addAllEqualities(Iterable<Expression> expressions) {
            for (Expression expression : expressions) {
                addEquality(expression);
            }
            return this;
        }

        public Builder addEquality(Expression expression) {
            expression = normalizeInPredicateToEquality(expression);
            checkArgument(isInferenceCandidate().apply(expression),
                    "Expression must be a simple equality: " + expression);
            ComparisonExpression comparison = (ComparisonExpression) expression;
            addEquality(comparison.getLeft(), comparison.getRight());
            return this;
        }

        public Builder addEquality(Expression expression1, Expression expression2) {
            checkArgument(!expression1.equals(expression2),
                    "Need to provide equality between different expressions");
            checkArgument(isDeterministic(expression1), "Expression must be deterministic: " + expression1);
            checkArgument(isDeterministic(expression2), "Expression must be deterministic: " + expression2);

            equalities.findAndUnion(expression1, expression2);
            return this;
        }

        /**
         * Performs one pass of generating more equivalences by rewriting sub-expressions in terms of known equivalences.
         */
        private void generateMoreEquivalences() {
            Collection<Set<Expression>> equivalentClasses = equalities.getEquivalentClasses();

            // Map every expression to the set of equivalent expressions
            ImmutableMap.Builder<Expression, Set<Expression>> mapBuilder = ImmutableMap.builder();
            for (Set<Expression> expressions : equivalentClasses) {
                expressions.forEach(expression -> mapBuilder.put(expression, expressions));
            }

            // For every non-derived expression, extract the sub-expressions and see if they can be rewritten as other expressions. If so,
            // use this new information to update the known equalities.
            Map<Expression, Set<Expression>> map = mapBuilder.build();
            for (Expression expression : map.keySet()) {
                if (!derivedExpressions.contains(expression)) {
                    for (Expression subExpression : filter(SubExpressionExtractor.extract(expression),
                            not(equalTo(expression)))) {
                        Set<Expression> equivalentSubExpressions = map.get(subExpression);
                        if (equivalentSubExpressions != null) {
                            for (Expression equivalentSubExpression : filter(equivalentSubExpressions,
                                    not(equalTo(subExpression)))) {
                                Expression rewritten = ExpressionTreeRewriter.rewriteWith(
                                        new ExpressionNodeInliner(
                                                ImmutableMap.of(subExpression, equivalentSubExpression)),
                                        expression);
                                equalities.findAndUnion(expression, rewritten);
                                derivedExpressions.add(rewritten);
                            }
                        }
                    }
                }
            }
        }

        public EqualityInference build() {
            generateMoreEquivalences();
            return new EqualityInference(equalities.getEquivalentClasses(), derivedExpressions);
        }
    }
}