edu.uci.ics.asterix.optimizer.rules.PushAggFuncIntoStandaloneAggregateRule.java Source code

Java tutorial

Introduction

Here is the source code for edu.uci.ics.asterix.optimizer.rules.PushAggFuncIntoStandaloneAggregateRule.java

Source

/*
 * Copyright 2009-2013 by The Regents of the University of California
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * you may obtain a copy of the License from
 * 
 *     http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package edu.uci.ics.asterix.optimizer.rules;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;

import org.apache.commons.lang3.mutable.Mutable;
import org.apache.commons.lang3.mutable.MutableObject;

import edu.uci.ics.asterix.om.functions.AsterixBuiltinFunctions;
import edu.uci.ics.hyracks.algebricks.common.exceptions.AlgebricksException;
import edu.uci.ics.hyracks.algebricks.core.algebra.base.ILogicalExpression;
import edu.uci.ics.hyracks.algebricks.core.algebra.base.ILogicalOperator;
import edu.uci.ics.hyracks.algebricks.core.algebra.base.IOptimizationContext;
import edu.uci.ics.hyracks.algebricks.core.algebra.base.LogicalExpressionTag;
import edu.uci.ics.hyracks.algebricks.core.algebra.base.LogicalOperatorTag;
import edu.uci.ics.hyracks.algebricks.core.algebra.base.LogicalVariable;
import edu.uci.ics.hyracks.algebricks.core.algebra.expressions.AbstractFunctionCallExpression;
import edu.uci.ics.hyracks.algebricks.core.algebra.expressions.AggregateFunctionCallExpression;
import edu.uci.ics.hyracks.algebricks.core.algebra.expressions.ConstantExpression;
import edu.uci.ics.hyracks.algebricks.core.algebra.expressions.VariableReferenceExpression;
import edu.uci.ics.hyracks.algebricks.core.algebra.functions.FunctionIdentifier;
import edu.uci.ics.hyracks.algebricks.core.algebra.operators.logical.AbstractBinaryJoinOperator;
import edu.uci.ics.hyracks.algebricks.core.algebra.operators.logical.AbstractLogicalOperator;
import edu.uci.ics.hyracks.algebricks.core.algebra.operators.logical.AggregateOperator;
import edu.uci.ics.hyracks.algebricks.core.algebra.operators.logical.AssignOperator;
import edu.uci.ics.hyracks.algebricks.core.algebra.operators.logical.visitors.VariableUtilities;
import edu.uci.ics.hyracks.algebricks.core.rewriter.base.IAlgebraicRewriteRule;

/**
 * Pushes aggregate functions into a stand alone aggregate operator (no group by).
 * <p>
 * Two plan shapes are handled by {@link #rewritePost}:
 * <ul>
 * <li>assign &lt;-- aggregate &lt;-- !(group-by)</li>
 * <li>assign &lt;-- join (condition TRUE) whose branches are, recursively, either such joins or
 * aggregate operators</li>
 * </ul>
 * In both cases an aggregate function call in the assign that consumes the variable produced by a
 * listify aggregate is moved into the aggregate operator, and the assign expression is replaced by
 * a reference to the new aggregate variable.
 */
public class PushAggFuncIntoStandaloneAggregateRule implements IAlgebraicRewriteRule {

    /**
     * This rule does all its work in {@link #rewritePost}; the pre-order pass never rewrites.
     * 
     * @return always false
     */
    @Override
    public boolean rewritePre(Mutable<ILogicalOperator> opRef, IOptimizationContext context)
            throws AlgebricksException {
        return false;
    }

    /**
     * Tries to push the aggregate function calls of an assign operator into the aggregate
     * operator(s) below it.
     * 
     * @param opRef
     *            reference to the operator to inspect; only assign operators are considered.
     * @param context
     *            optimization context used to create variables and recompute type environments.
     * @return true if at least one aggregate function was actually pushed, false otherwise.
     * @throws AlgebricksException
     */
    @Override
    public boolean rewritePost(Mutable<ILogicalOperator> opRef, IOptimizationContext context)
            throws AlgebricksException {
        // Pattern to match: assign <-- aggregate <-- !(group-by)
        AbstractLogicalOperator op = (AbstractLogicalOperator) opRef.getValue();
        if (op.getOperatorTag() != LogicalOperatorTag.ASSIGN) {
            return false;
        }
        AssignOperator assignOp = (AssignOperator) op;

        AbstractLogicalOperator op2 = (AbstractLogicalOperator) op.getInputs().get(0).getValue();
        if (op2.getOperatorTag() == LogicalOperatorTag.AGGREGATE) {
            // pushAggregateFunction itself verifies that the aggregate expression is a listify.
            return pushAggregateFunction((AggregateOperator) op2, assignOp, context);
        }
        if (isJoin(op2)) {
            AbstractBinaryJoinOperator join = (AbstractBinaryJoinOperator) op2;
            // Tries to push aggregates through the join. Only report a change if something was
            // actually pushed; otherwise the assign is unchanged and the rule would claim progress
            // on every invocation.
            if (containsAggregate(assignOp.getExpressions()) && pushableThroughJoin(join)) {
                return pushAggregateFunctionThroughJoin(join, assignOp, context);
            }
        }
        return false;
    }

    /**
     * @param op
     *            operator to test.
     * @return true if the operator is tagged as an inner join or a left outer join.
     */
    private static boolean isJoin(AbstractLogicalOperator op) {
        LogicalOperatorTag tag = op.getOperatorTag();
        return tag == LogicalOperatorTag.INNERJOIN || tag == LogicalOperatorTag.LEFTOUTERJOIN;
    }

    /**
     * Recursively check whether the list of expressions contains an aggregate function, i.e. a
     * function call for which {@code AsterixBuiltinFunctions.getAggregateFunction} returns a
     * non-null identifier.
     * 
     * @param exprRefs
     *            expressions to inspect; arguments of non-aggregate function calls are searched too.
     * @return true if the list contains an aggregate function and false otherwise.
     */
    private boolean containsAggregate(List<Mutable<ILogicalExpression>> exprRefs) {
        for (Mutable<ILogicalExpression> exprRef : exprRefs) {
            ILogicalExpression expr = exprRef.getValue();
            if (expr.getExpressionTag() != LogicalExpressionTag.FUNCTION_CALL) {
                continue;
            }
            AbstractFunctionCallExpression funcExpr = (AbstractFunctionCallExpression) expr;
            FunctionIdentifier funcIdent = AsterixBuiltinFunctions
                    .getAggregateFunction(funcExpr.getFunctionIdentifier());
            if (funcIdent != null) {
                // This is an aggregation function.
                return true;
            }
            // Not an aggregate itself: recursively look in func args.
            if (containsAggregate(funcExpr.getArguments())) {
                return true;
            }
        }
        return false;
    }

    /**
     * Check whether the join is aggregate-pushable, that is,
     * 1) the join condition is true;
     * 2) each join branch produces only one tuple, which is checked structurally: every branch
     * must be an aggregate operator or, recursively, another pushable join.
     * 
     * @param join
     *            the join to test.
     * @return true if pushable
     */
    private boolean pushableThroughJoin(AbstractBinaryJoinOperator join) {
        ILogicalExpression condition = join.getCondition().getValue();
        if (!condition.equals(ConstantExpression.TRUE)) {
            return false;
        }
        // Every branch must qualify; bail out on the first one that does not.
        for (Mutable<ILogicalOperator> branchRef : join.getInputs()) {
            AbstractLogicalOperator branch = (AbstractLogicalOperator) branchRef.getValue();
            if (branch.getOperatorTag() == LogicalOperatorTag.AGGREGATE) {
                continue;
            }
            if (isJoin(branch) && pushableThroughJoin((AbstractBinaryJoinOperator) branch)) {
                continue;
            }
            return false;
        }
        return true;
    }

    /**
     * Does the actual push of aggregates for qualified joins. Every aggregate operator reachable
     * through nested joins is visited, even after one of them has already been rewritten.
     * 
     * @param join
     *            join whose branches are searched for aggregate operators.
     * @param assignOp
     *            that contains aggregate function calls.
     * @param context
     *            optimization context forwarded to {@link #pushAggregateFunction}.
     * @return true if an aggregate function was pushed into at least one aggregate operator.
     * @throws AlgebricksException
     */
    private boolean pushAggregateFunctionThroughJoin(AbstractBinaryJoinOperator join, AssignOperator assignOp,
            IOptimizationContext context) throws AlgebricksException {
        boolean changed = false;
        for (Mutable<ILogicalOperator> branchRef : join.getInputs()) {
            AbstractLogicalOperator branch = (AbstractLogicalOperator) branchRef.getValue();
            if (branch.getOperatorTag() == LogicalOperatorTag.AGGREGATE) {
                // Non-short-circuit on purpose: always attempt the push for every branch.
                changed |= pushAggregateFunction((AggregateOperator) branch, assignOp, context);
            } else if (isJoin(branch)) {
                changed |= pushAggregateFunctionThroughJoin((AbstractBinaryJoinOperator) branch, assignOp,
                        context);
            }
        }
        return changed;
    }

    /**
     * Moves the aggregate function calls of {@code assignOp} that consume the (single) listify
     * variable of {@code aggOp} into {@code aggOp}. On success, the aggregate operator's original
     * variable and expression are replaced by one new variable/expression pair per pushed function,
     * and each pushed assign expression becomes a reference to its new variable.
     * <p>
     * Nothing is modified when this method returns false.
     * 
     * @param aggOp
     *            the stand alone aggregate operator to push into.
     * @param assignOp
     *            the assign operator holding the aggregate function calls.
     * @param context
     *            used to create the new variables and to recompute the type environments of both
     *            operators.
     * @return true if the operators were rewritten, false if the pattern did not match.
     * @throws AlgebricksException
     */
    private boolean pushAggregateFunction(AggregateOperator aggOp, AssignOperator assignOp,
            IOptimizationContext context) throws AlgebricksException {
        AbstractLogicalOperator op3 = (AbstractLogicalOperator) aggOp.getInputs().get(0).getValue();
        // If there's a group by below the agg, then we want to have the agg pushed into the group by.
        if (op3.getOperatorTag() == LogicalOperatorTag.GROUP) {
            return false;
        }
        if (aggOp.getVariables().size() != 1) {
            return false;
        }
        ILogicalExpression aggExpr = aggOp.getExpressions().get(0).getValue();
        if (aggExpr.getExpressionTag() != LogicalExpressionTag.FUNCTION_CALL) {
            return false;
        }
        AbstractFunctionCallExpression origAggFuncExpr = (AbstractFunctionCallExpression) aggExpr;
        // Compare by value rather than by reference so the check does not depend on the
        // identifier being the very same instance as the LISTIFY constant.
        if (!AsterixBuiltinFunctions.LISTIFY.equals(origAggFuncExpr.getFunctionIdentifier())) {
            return false;
        }

        LogicalVariable aggVar = aggOp.getVariables().get(0);
        List<LogicalVariable> used = new LinkedList<LogicalVariable>();
        VariableUtilities.getUsedVariables(assignOp, used);
        if (!used.contains(aggVar)) {
            return false;
        }

        List<Mutable<ILogicalExpression>> srcAssignExprRefs = new LinkedList<Mutable<ILogicalExpression>>();
        if (!findAggFuncExprRef(assignOp.getExpressions(), aggVar, srcAssignExprRefs)) {
            return false;
        }
        if (srcAssignExprRefs.isEmpty()) {
            return false;
        }

        // From here on the operators are mutated; all pattern checks have passed.
        aggOp.getExpressions().clear();
        aggOp.getVariables().clear();

        for (Mutable<ILogicalExpression> srcAssignExprRef : srcAssignExprRefs) {
            AbstractFunctionCallExpression assignFuncExpr = (AbstractFunctionCallExpression) srcAssignExprRef
                    .getValue();
            FunctionIdentifier aggFuncIdent = AsterixBuiltinFunctions
                    .getAggregateFunction(assignFuncExpr.getFunctionIdentifier());

            // Push the agg func into the agg op: the new aggregate takes over the listify's argument.
            List<Mutable<ILogicalExpression>> aggArgs = new ArrayList<Mutable<ILogicalExpression>>();
            aggArgs.add(origAggFuncExpr.getArguments().get(0));
            AggregateFunctionCallExpression aggFuncExpr = AsterixBuiltinFunctions
                    .makeAggregateFunctionExpression(aggFuncIdent, aggArgs);
            LogicalVariable newVar = context.newVar();
            aggOp.getVariables().add(newVar);
            aggOp.getExpressions().add(new MutableObject<ILogicalExpression>(aggFuncExpr));

            // The assign now just "renames" the variable to make sure the upstream plan still works.
            srcAssignExprRef.setValue(new VariableReferenceExpression(newVar));
        }

        context.computeAndSetTypeEnvironmentForOperator(aggOp);
        context.computeAndSetTypeEnvironmentForOperator(assignOp);
        return true;
    }

    /**
     * Collects the references to aggregate function calls that use {@code aggVar}, searching the
     * given expressions and, recursively, the arguments of non-aggregate function calls.
     * 
     * @param exprRefs
     *            expressions to search.
     * @param aggVar
     *            the variable produced by the aggregate operator.
     * @param srcAssignExprRefs
     *            output list; references to matching aggregate function calls are appended. It may
     *            already contain entries when this method returns false.
     * @return false if {@code aggVar} is referenced directly by a variable expression in the
     *         searched expressions (in which case the push must not happen), true otherwise.
     */
    private boolean findAggFuncExprRef(List<Mutable<ILogicalExpression>> exprRefs, LogicalVariable aggVar,
            List<Mutable<ILogicalExpression>> srcAssignExprRefs) {
        for (Mutable<ILogicalExpression> exprRef : exprRefs) {
            ILogicalExpression expr = exprRef.getValue();

            if (expr.getExpressionTag() == LogicalExpressionTag.VARIABLE) {
                // A direct use of aggVar outside of an aggregate function call blocks the rewrite.
                if (((VariableReferenceExpression) expr).getVariableReference().equals(aggVar)) {
                    return false;
                }
                continue;
            }

            if (expr.getExpressionTag() != LogicalExpressionTag.FUNCTION_CALL) {
                continue;
            }
            AbstractFunctionCallExpression funcExpr = (AbstractFunctionCallExpression) expr;
            FunctionIdentifier funcIdent = AsterixBuiltinFunctions
                    .getAggregateFunction(funcExpr.getFunctionIdentifier());
            if (funcIdent == null) {
                // Recursively look in func args.
                if (!findAggFuncExprRef(funcExpr.getArguments(), aggVar, srcAssignExprRefs)) {
                    return false;
                }
            } else {
                // Check if this is the expr that uses aggVar.
                Collection<LogicalVariable> usedVars = new HashSet<LogicalVariable>();
                funcExpr.getUsedVariables(usedVars);
                if (usedVars.contains(aggVar)) {
                    srcAssignExprRefs.add(exprRef);
                }
            }
        }
        return true;
    }
}