org.apache.sysml.hops.codegen.opt.PlanSelectionFuseCostBasedV2.java Source code

Introduction

Here is the source code for org.apache.sysml.hops.codegen.opt.PlanSelectionFuseCostBasedV2.java, the cost-based codegen plan selector of Apache SystemML.

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.sysml.hops.codegen.opt;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map.Entry;
import java.util.stream.Collectors;

import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
import org.apache.sysml.hops.AggBinaryOp;
import org.apache.sysml.hops.AggUnaryOp;
import org.apache.sysml.hops.BinaryOp;
import org.apache.sysml.hops.DnnOp;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.Hop.AggOp;
import org.apache.sysml.hops.Hop.DataGenMethod;
import org.apache.sysml.hops.Hop.DataOpTypes;
import org.apache.sysml.hops.Hop.Direction;
import org.apache.sysml.hops.Hop.OpOp2;
import org.apache.sysml.hops.Hop.OpOpN;
import org.apache.sysml.hops.IndexingOp;
import org.apache.sysml.hops.LiteralOp;
import org.apache.sysml.hops.NaryOp;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.hops.ParameterizedBuiltinOp;
import org.apache.sysml.hops.ReorgOp;
import org.apache.sysml.hops.TernaryOp;
import org.apache.sysml.hops.UnaryOp;
import org.apache.sysml.hops.codegen.opt.ReachabilityGraph.SubProblem;
import org.apache.sysml.hops.codegen.template.CPlanMemoTable;
import org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry;
import org.apache.sysml.hops.codegen.template.TemplateBase.TemplateType;
import org.apache.sysml.hops.codegen.template.TemplateOuterProduct;
import org.apache.sysml.hops.codegen.template.TemplateRow;
import org.apache.sysml.hops.codegen.template.TemplateUtils;
import org.apache.sysml.hops.rewrite.HopRewriteUtils;
import org.apache.sysml.runtime.codegen.LibSpoofPrimitives;
import org.apache.sysml.runtime.controlprogram.caching.LazyWriteBuffer;
import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysml.runtime.controlprogram.parfor.util.IDSequence;
import org.apache.sysml.runtime.util.UtilFunctions;
import org.apache.sysml.utils.Statistics;

/**
 * This cost-based plan selection algorithm chooses fused operators
 * based on the DAG structure and resulting overall costs. This includes
 * holistic decisions on 
 * <ul>
 *   <li>Materialization points per consumer</li>
 *   <li>Sparsity exploitation and operator ordering</li>
 *   <li>Decisions on overlapping template types</li>
 *   <li>Decisions on multi-aggregates with shared reads</li>
 *   <li>Constraints (e.g., memory budgets and block sizes)</li>  
 * </ul>
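 * <p>For example, given {@code R1 = sum(X*Y)} and {@code R2 = rowSums(X*Y)},
 * the shared intermediate {@code X*Y} is an interesting materialization point:
 * the optimizer decides whether to materialize it once or to redundantly
 * recompute it inside both fused aggregate operators, whichever costs less.</p>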
 * 
 */
public class PlanSelectionFuseCostBasedV2 extends PlanSelection {
    private static final Log LOG = LogFactory.getLog(PlanSelectionFuseCostBasedV2.class.getName());

    //common bandwidth characteristics, with a conservative write bandwidth in order 
    //to cover result allocation, write into main memory, and potential evictions
    private static final double WRITE_BANDWIDTH_IO = 512 * 1024 * 1024; //512MB/s
    private static final double WRITE_BANDWIDTH_MEM = 2d * 1024 * 1024 * 1024; //2GB/s
    private static final double READ_BANDWIDTH_MEM = 32d * 1024 * 1024 * 1024; //32GB/s
    private static final double READ_BANDWIDTH_BROADCAST = WRITE_BANDWIDTH_IO / 4;
    private static final double COMPUTE_BANDWIDTH = 2d * 1024 * 1024 * 1024 //2GFLOPs/core
            * InfrastructureAnalyzer.getLocalParallelism();
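    //e.g., under this model, writing a dense 1000x1000 intermediate (8MB) to main
    //memory costs 8MB / 2GB/s ~= 4ms, while reading it back costs 8MB / 32GB/s ~= 0.25ms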

    //sparsity estimate for unknown sparsity to prefer sparse-safe fusion plans
    private static final double SPARSE_SAFE_SPARSITY_EST = 0.1;

    //after evaluating the costs of the opening heuristics fuse-all and fuse-no-redundancy,
    //remaining candidate plans of large partitions (w/ >= COST_MIN_EPS_NUM_POINTS mat points)
    //are only evaluated if the current costs are > (1+COST_MIN_EPS) * static (i.e., minimal) costs.
    public static final double COST_MIN_EPS = 0.01; //1%
    public static final int COST_MIN_EPS_NUM_POINTS = 20; //2^20 = 1M plans
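    //e.g., with minimal static costs of 2.0s, the 2^20 candidate plans of a
    //20-point partition are only enumerated if the best heuristic plan exceeds 2.02s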

    //In order to avoid unnecessary repeated reoptimization we use a plan cache for
    //mapping partition signatures (including input sizes) to optimal plans. However,
    //since hop ids change during dynamic recompilation, we use an approximate signature
    //that is cheap to compute and therefore only use this for large partitions.
    private static final int PLAN_CACHE_NUM_POINTS = 10; //2^10 = 1024
    private static final int PLAN_CACHE_SIZE = 1024;
    private static final LinkedHashMap<PartitionSignature, boolean[]> _planCache = new LinkedHashMap<>();

    //optimizer configuration
    public static boolean COST_PRUNING = true;
    public static boolean STRUCTURAL_PRUNING = true;
    public static boolean PLAN_CACHING = true;
    private static final TemplateRow ROW_TPL = new TemplateRow();

    //cost vector id generator, whose ids are only used for memoization per call to getPlanCost;
    //hence, we use a sequence generator per optimizer instance to avoid thread contention in 
    //multi-threaded parfor scenarios with concurrent dynamic recompilation and thus optimization.
    private final IDSequence COST_ID = new IDSequence();

    @Override
    public void selectPlans(CPlanMemoTable memo, ArrayList<Hop> roots) {
        //step 1: analyze connected partitions (nodes, roots, mat points)
        Collection<PlanPartition> parts = PlanAnalyzer.analyzePlanPartitions(memo, roots, true);

        //step 2: optimize individual plan partitions
        int sumMatPoints = 0;
        for (PlanPartition part : parts) {
            //create composite templates (within the partition)
            createAndAddMultiAggPlans(memo, part.getPartition(), part.getRoots());

            //plan enumeration and plan selection
            selectPlans(memo, part);
            sumMatPoints += part.getMatPointsExt().length;
        }

        //step 3: add composite templates (across partitions)
        createAndAddMultiAggPlans(memo, roots);

        //take all distinct best plans
        for (Entry<Long, List<MemoTableEntry>> e : getBestPlans().entrySet())
            memo.setDistinct(e.getKey(), e.getValue());

        //maintain statistics
        if (DMLScript.STATISTICS) {
            if (sumMatPoints >= 63)
                LOG.warn("Long overflow on maintaining codegen statistics " + "for a DAG with " + sumMatPoints
                        + " interesting points.");
            Statistics.incrementCodegenEnumAll(UtilFunctions.pow(2, sumMatPoints));
        }
    }

    private void selectPlans(CPlanMemoTable memo, PlanPartition part) {
        //prune special case patterns and invalid plans (e.g., blocksize)
        pruneInvalidAndSpecialCasePlans(memo, part);

        //if no materialization points, use basic fuse-all w/ partition awareness
        if (part.getMatPointsExt() == null || part.getMatPointsExt().length == 0) {
            for (Long hopID : part.getRoots())
                rSelectPlansFuseAll(memo, memo.getHopRefs().get(hopID), null, part.getPartition());
        } else {
            //obtain hop compute costs per cell once
            HashMap<Long, Double> computeCosts = new HashMap<>();
            for (Long hopID : part.getPartition())
                getComputeCosts(memo.getHopRefs().get(hopID), computeCosts);

            //prepare pruning helpers and prune memo table w/ determined mat points
            StaticCosts costs = new StaticCosts(computeCosts, sumComputeCost(computeCosts), getReadCost(part, memo),
                    getWriteCost(part.getRoots(), memo), minOuterSparsity(part, memo));
            ReachabilityGraph rgraph = STRUCTURAL_PRUNING ? new ReachabilityGraph(part, memo) : null;
            if (STRUCTURAL_PRUNING) {
                part.setMatPointsExt(rgraph.getSortedSearchSpace());
                for (Long hopID : part.getPartition())
                    memo.pruneRedundant(hopID, true, part.getMatPointsExt());
            }

            //enumerate and cost plans, returns optional plan
            boolean[] bestPlan = enumPlans(memo, part, costs, rgraph, part.getMatPointsExt(), 0);

            //prune memo table wrt best plan and select plans
            HashSet<Long> visited = new HashSet<>();
            for (Long hopID : part.getRoots())
                rPruneSuboptimalPlans(memo, memo.getHopRefs().get(hopID), visited, part, part.getMatPointsExt(),
                        bestPlan);
            HashSet<Long> visited2 = new HashSet<>();
            for (Long hopID : part.getRoots())
                rPruneInvalidPlans(memo, memo.getHopRefs().get(hopID), visited2, part, bestPlan);

            for (Long hopID : part.getRoots())
                rSelectPlansFuseAll(memo, memo.getHopRefs().get(hopID), null, part.getPartition());
        }
    }

    /**
     * Core plan enumeration algorithm, invoked recursively for conditionally independent
     * subproblems. This algorithm fully explores the exponential search space of 2^m,
     * where m is the number of interesting materialization points. We iterate over
     * a linearized search space without ever instantiating the search tree. Furthermore,
     * in order to reduce the enumeration overhead, we apply two high-impact pruning
     * techniques: (1) pruning by evolving lower/upper cost bounds, and (2) pruning by
     * conditional structural properties (so-called cutsets of interesting points). 
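     * For example, m = 20 interesting points define a linearized search space
     * of 2^20 (about 1M) boolean assignments.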
     * 
     * @param memo memoization table of partial fusion plans
     * @param part connected component (partition) of partial fusion plans with all necessary meta data
     * @param costs summary of static costs (e.g., partition reads, writes, and compute costs per operator)
     * @param rgraph reachability graph of interesting materialization points
     * @param matPoints sorted materialization points (defining the search space)
     * @param off offset for recursive invocation, indicating the fixed plan part
     * @return optimal assignment of materialization points
     */
    private boolean[] enumPlans(CPlanMemoTable memo, PlanPartition part, StaticCosts costs,
            ReachabilityGraph rgraph, InterestingPoint[] matPoints, int off) {
        //scan linearized search space, w/ skips for branch and bound pruning
        //and structural pruning (where we solve conditionally independent problems)
        //bestC is monotonically non-increasing and serves as the upper bound
        final int Mlen = matPoints.length - off;
        final long len = UtilFunctions.pow(2, Mlen);
        long numEvalPlans = 2, numEvalPartPlans = 0;

        //evaluate heuristics fuse-all and fuse-no-redundancy to quickly obtain a good upper bound
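        //(assignment 0 = all points disabled = fuse-all; assignment len-1 =
        //all points enabled = materialize every intermediate = no redundancy)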
        final boolean[] plan0 = createAssignment(Mlen, off, 0); // fuse-all
        final boolean[] planN = createAssignment(Mlen, off, len - 1); //fuse-no-redundancy
        final double C0 = getPlanCost(memo, part, matPoints, plan0, costs._computeCosts, Double.MAX_VALUE);
        final double CN = getPlanCost(memo, part, matPoints, planN, costs._computeCosts, Double.MAX_VALUE);
        boolean[] bestPlan = (C0 <= CN) ? plan0 : planN;
        double bestC = Math.min(C0, CN);
        final boolean evalRemain = (Mlen < COST_MIN_EPS_NUM_POINTS || !COST_PRUNING
                || bestC > (1 + COST_MIN_EPS) * costs.getMinCosts());
        if (LOG.isTraceEnabled())
            LOG.trace("Enum opening: " + Arrays.toString(bestPlan) + " -> " + bestC);
        if (!evalRemain)
            LOG.warn("Skip enum for |M|=" + Mlen + ", C=" + bestC + ", Cmin=" + costs.getMinCosts());

        //probe plan cache for existing optimized plan
        PartitionSignature pKey = null;
        if (probePlanCache(matPoints)) {
            pKey = new PartitionSignature(part, matPoints.length, costs, C0, CN);
            boolean[] plan = getPlan(pKey);
            if (plan != null) {
                Statistics.incrementCodegenEnumAllP((rgraph != null || !STRUCTURAL_PRUNING) ? len : 0);
                return plan;
            }
        }

        //evaluate remaining plans, except already evaluated heuristics
        for (long i = 1; i < len - 1 && evalRemain; i++) {
            //construct assignment
            boolean[] plan = createAssignment(Mlen, off, i);
            long pskip = 0; //skip after costing

            //skip plans with structural pruning
            if (STRUCTURAL_PRUNING && (rgraph != null) && rgraph.isCutSet(plan)) {
                //compute skip (which also acts as boundary for subproblems)
                pskip = rgraph.getNumSkipPlans(plan);
                if (LOG.isTraceEnabled())
                    LOG.trace("Enum: Structural pruning for cut set: " + rgraph.getCutSet(plan));

                //obtain the conditionally independent subproblems of this cut set
                SubProblem[] prob = rgraph.getSubproblems(plan);

                //solve subproblems independently and combine into best plan
                for (int j = 0; j < prob.length; j++) {
                    if (LOG.isTraceEnabled())
                        LOG.trace("Enum: Subproblem " + (j + 1) + "/" + prob.length + ": " + prob[j]);
                    boolean[] bestTmp = enumPlans(memo, part, costs, null, prob[j].freeMat, prob[j].offset);
                    LibSpoofPrimitives.vectWrite(bestTmp, plan, prob[j].freePos);
                }

                //note: the overall plan costs are evaluated in full, which reuses
                //the default code path; hence we postpone the skip until after costing
            }
            //skip plans with branch and bound pruning (cost)
            else if (COST_PRUNING) {
                double lbC = getLowerBoundCosts(part, matPoints, memo, costs, plan);
                if (lbC >= bestC) {
                    long skip = getNumSkipPlans(plan);
                    if (LOG.isTraceEnabled())
                        LOG.trace("Enum: Skip " + skip + " plans (by cost).");
                    i += skip - 1;
                    continue;
                }
            }

            //cost assignment on hops; stop early if costs exceed bestC
            double pCBound = COST_PRUNING ? bestC : Double.MAX_VALUE;
            double C = getPlanCost(memo, part, matPoints, plan, costs._computeCosts, pCBound);
            if (LOG.isTraceEnabled())
                LOG.trace("Enum: " + Arrays.toString(plan) + " -> " + C);
            numEvalPartPlans += (C == Double.POSITIVE_INFINITY) ? 1 : 0;
            numEvalPlans++;

            //cost comparisons
            if (bestPlan == null || C < bestC) {
                bestC = C;
                bestPlan = plan;
                if (LOG.isTraceEnabled())
                    LOG.trace("Enum: Found new best plan.");
            }

            //post skipping
            i += pskip;
            if (pskip != 0 && LOG.isTraceEnabled())
                LOG.trace("Enum: Skip " + pskip + " plans (by structure).");
        }

        if (DMLScript.STATISTICS) {
            Statistics.incrementCodegenEnumAllP((rgraph != null || !STRUCTURAL_PRUNING) ? len : 0);
            Statistics.incrementCodegenEnumEval(numEvalPlans);
            Statistics.incrementCodegenEnumEvalP(numEvalPartPlans);
        }
        if (LOG.isTraceEnabled())
            LOG.trace("Enum: Optimal plan: " + Arrays.toString(bestPlan));

        //cache best plans of large partitions for reuse
        if (probePlanCache(matPoints))
            putPlan(pKey, bestPlan);

        //return best plan w/o the fixed (offset) plan part
        return (bestPlan == null) ? new boolean[Mlen] : Arrays.copyOfRange(bestPlan, off, bestPlan.length);
    }

    private static boolean[] createAssignment(int len, int off, long pos) {
        boolean[] ret = new boolean[off + len];
        Arrays.fill(ret, 0, off, true);
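        //decode pos as a len-bit binary number (most significant bit first),
        //e.g., len=3, off=0, pos=5 (binary 101) yields [true, false, true]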
        long tmp = pos;
        for (int i = 0; i < len; i++) {
            long mask = UtilFunctions.pow(2, len - i - 1);
            ret[off + i] = tmp >= mask;
            tmp %= mask;
        }
        return ret;
    }

    private static long getNumSkipPlans(boolean[] plan) {
        int pos = ArrayUtils.lastIndexOf(plan, true);
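        //all subsequent assignments only enable additional points after the last
        //enabled one, so their lower bounds are at least as large as the current
        //one; e.g., for [true, false, false] we can skip 2^2 = 4 plans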
        return UtilFunctions.pow(2, plan.length - pos - 1);
    }

    private static double getLowerBoundCosts(PlanPartition part, InterestingPoint[] M, CPlanMemoTable memo,
            StaticCosts costs, boolean[] plan) {
        //compute the lower bound from static and plan-dependent costs
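        //i.e., lb = max(read, compute) + write + materialization costs, where reads
        //overlap with compute, while writes and materialized intermediates add up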
        double lb = Math.max(costs._read, costs._compute) + costs._write
                + getMaterializationCost(part, M, memo, plan);

        //if the partition contains outer templates, we need to correct the lower bound
        if (part.hasOuter())
            lb *= costs._minSparsity;

        return lb;
    }

    private static double getMaterializationCost(PlanPartition part, InterestingPoint[] M, CPlanMemoTable memo,
            boolean[] plan) {
        double costs = 0;
        //currently active materialization points
        HashSet<Long> matTargets = new HashSet<>();
        for (int i = 0; i < plan.length; i++) {
            long hopID = M[i].getToHopID();
            if (plan[i] && !matTargets.contains(hopID)) {
                matTargets.add(hopID);
                Hop hop = memo.getHopRefs().get(hopID);
                long size = getSize(hop);
                costs += size * 8 / WRITE_BANDWIDTH_MEM + size * 8 / READ_BANDWIDTH_MEM;
            }
        }
        //points with non-partition consumers
        for (Long hopID : part.getExtConsumed())
            if (!matTargets.contains(hopID)) {
                matTargets.add(hopID);
                Hop hop = memo.getHopRefs().get(hopID);
                costs += getSize(hop) * 8 / WRITE_BANDWIDTH_MEM;
            }

        return costs;
    }

    private static double getReadCost(PlanPartition part, CPlanMemoTable memo) {
        double costs = 0;
        //get partition input reads (at least read once)
        for (Long hopID : part.getInputs()) {
            Hop hop = memo.getHopRefs().get(hopID);
            costs += getSafeMemEst(hop) / READ_BANDWIDTH_MEM;
        }
        return costs;
    }

    private static double getWriteCost(Collection<Long> R, CPlanMemoTable memo) {
        double costs = 0;
        for (Long hopID : R) {
            Hop hop = memo.getHopRefs().get(hopID);
            costs += getSize(hop) * 8 / WRITE_BANDWIDTH_MEM;
        }
        return costs;
    }

    private static double sumComputeCost(HashMap<Long, Double> computeCosts) {
        return computeCosts.values().stream().mapToDouble(d -> d / COMPUTE_BANDWIDTH).sum();
    }

    private static double minOuterSparsity(PlanPartition part, CPlanMemoTable memo) {
        return !part.hasOuter() ? 1.0
                : part.getPartition().stream().map(k -> HopRewriteUtils.getLargestInput(memo.getHopRefs().get(k)))
                        .mapToDouble(h -> h.dimsKnown(true) ? h.getSparsity() : SPARSE_SAFE_SPARSITY_EST).min()
                        .orElse(SPARSE_SAFE_SPARSITY_EST);
    }

    private static double sumTmpInputOutputSize(CPlanMemoTable memo, CostVector vect) {
        //size of intermediate inputs and outputs, i.e., the output and all inputs other than transient reads
        return vect.outSize + vect.inSizes.entrySet().stream()
                .filter(e -> !HopRewriteUtils.isData(memo.getHopRefs().get(e.getKey()), DataOpTypes.TRANSIENTREAD))
                .mapToDouble(e -> e.getValue()).sum();
    }

    private static double sumInputMemoryEstimates(CPlanMemoTable memo, CostVector vect) {
        return vect.inSizes.keySet().stream().mapToDouble(e -> getSafeMemEst(memo.getHopRefs().get(e))).sum();
    }

    private static double getSafeMemEst(Hop hop) {
        return !hop.dimsKnown() ? getSize(hop) * 8 : hop.getOutputMemEstimate();
    }

    private static long getSize(Hop hop) {
        return Math.max(hop.getDim1(), 1) * Math.max(hop.getDim2(), 1);
    }

    //within-partition multi-agg templates
    private static void createAndAddMultiAggPlans(CPlanMemoTable memo, HashSet<Long> partition, HashSet<Long> R) {
        //create index of plans that reference full aggregates to avoid circular dependencies
        HashSet<Long> refHops = new HashSet<>();
        for (Entry<Long, List<MemoTableEntry>> e : memo.getPlans().entrySet())
            if (!e.getValue().isEmpty()) {
                Hop hop = memo.getHopRefs().get(e.getKey());
                for (Hop c : hop.getInput())
                    refHops.add(c.getHopID());
            }

        //find all full aggregations (the fact that they are in the same partition guarantees
        //that they also have common subexpressions; also, full aggregations are by definition root nodes)
        ArrayList<Long> fullAggs = new ArrayList<>();
        for (Long hopID : R) {
            Hop root = memo.getHopRefs().get(hopID);
            if (!refHops.contains(hopID) && isMultiAggregateRoot(root))
                fullAggs.add(hopID);
        }
        if (LOG.isTraceEnabled()) {
            LOG.trace("Found within-partition ua(RC) aggregations: "
                    + Arrays.toString(fullAggs.toArray(new Long[0])));
        }

        //construct and add multiagg template plans (w/ max 3 aggregations)
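        //e.g., five candidate aggregates form the groups (0,1,2) and (3,4), while a
        //trailing group of size 1 would be skipped as fusion requires >= 2 aggregates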
        for (int i = 0; i < fullAggs.size(); i += 3) {
            int ito = Math.min(i + 3, fullAggs.size());
            if (ito - i >= 2) {
                MemoTableEntry me = new MemoTableEntry(TemplateType.MAGG, fullAggs.get(i), fullAggs.get(i + 1),
                        ((ito - i) == 3) ? fullAggs.get(i + 2) : -1, ito - i);
                if (isValidMultiAggregate(memo, me)) {
                    for (int j = i; j < ito; j++) {
                        memo.add(memo.getHopRefs().get(fullAggs.get(j)), me);
                        if (LOG.isTraceEnabled())
                            LOG.trace("Added multiagg plan: " + fullAggs.get(j) + " " + me);
                    }
                } else if (LOG.isTraceEnabled()) {
                    LOG.trace("Removed invalid multiagg plan: " + me);
                }
            }
        }
    }

    //across-partition multi-agg templates with shared reads
    private void createAndAddMultiAggPlans(CPlanMemoTable memo, ArrayList<Hop> roots) {
        //collect full aggregations as initial set of candidates
        HashSet<Long> fullAggs = new HashSet<>();
        Hop.resetVisitStatus(roots);
        for (Hop hop : roots)
            rCollectFullAggregates(hop, fullAggs);
        Hop.resetVisitStatus(roots);

        //remove operators with assigned multi-agg plans
        fullAggs.removeIf(p -> memo.contains(p, TemplateType.MAGG));

        //check applicability for further analysis
        if (fullAggs.size() <= 1)
            return;

        if (LOG.isTraceEnabled()) {
            LOG.trace("Found across-partition ua(RC) aggregations: "
                    + Arrays.toString(fullAggs.toArray(new Long[0])));
        }

        //collect information for all candidates 
        //(subsumed aggregations, and inputs to fused operators) 
        List<AggregateInfo> aggInfos = new ArrayList<>();
        for (Long hopID : fullAggs) {
            Hop aggHop = memo.getHopRefs().get(hopID);
            AggregateInfo tmp = new AggregateInfo(aggHop);
            for (int i = 0; i < aggHop.getInput().size(); i++) {
                Hop c = HopRewriteUtils.isMatrixMultiply(aggHop) && i == 0
                        ? aggHop.getInput().get(0).getInput().get(0)
                        : aggHop.getInput().get(i);
                rExtractAggregateInfo(memo, c, tmp, TemplateType.CELL);
            }
            if (tmp._fusedInputs.isEmpty()) {
                if (HopRewriteUtils.isMatrixMultiply(aggHop)) {
                    tmp.addFusedInput(aggHop.getInput().get(0).getInput().get(0).getHopID());
                    tmp.addFusedInput(aggHop.getInput().get(1).getHopID());
                } else
                    tmp.addFusedInput(aggHop.getInput().get(0).getHopID());
            }
            aggInfos.add(tmp);
        }

        if (LOG.isTraceEnabled()) {
            LOG.trace("Extracted across-partition ua(RC) aggregation info: ");
            for (AggregateInfo info : aggInfos)
                LOG.trace(info);
        }

        //sort aggregations by num dependencies to simplify merging
        //clusters of aggregations with parallel dependencies
        aggInfos = aggInfos.stream().sorted(Comparator.comparing(a -> a._inputAggs.size()))
                .collect(Collectors.toList());

        //greedy grouping of multi-agg candidates
        boolean converged = false;
        while (!converged) {
            AggregateInfo merged = null;
            for (int i = 0; i < aggInfos.size(); i++) {
                AggregateInfo current = aggInfos.get(i);
                for (int j = i + 1; j < aggInfos.size(); j++) {
                    AggregateInfo that = aggInfos.get(j);
                    if (current.isMergable(that)) {
                        merged = current.merge(that);
                        aggInfos.remove(j);
                        j--;
                    }
                }
            }
            converged = (merged == null);
        }

        if (LOG.isTraceEnabled()) {
            LOG.trace("Merged across-partition ua(RC) aggregation info: ");
            for (AggregateInfo info : aggInfos)
                LOG.trace(info);
        }

        //construct and add multiagg template plans (w/ max 3 aggregations)
        for (AggregateInfo info : aggInfos) {
            if (info._aggregates.size() <= 1)
                continue;
            Long[] aggs = info._aggregates.keySet().toArray(new Long[0]);
            MemoTableEntry me = new MemoTableEntry(TemplateType.MAGG, aggs[0], aggs[1],
                    (aggs.length > 2) ? aggs[2] : -1, aggs.length);
            for (int i = 0; i < aggs.length; i++) {
                memo.add(memo.getHopRefs().get(aggs[i]), me);
                addBestPlan(aggs[i], me);
                if (LOG.isTraceEnabled())
                    LOG.trace("Added multiagg* plan: " + aggs[i] + " " + me);

            }
        }
    }

    private static boolean isMultiAggregateRoot(Hop root) {
        return (HopRewriteUtils.isAggUnaryOp(root, AggOp.SUM, AggOp.SUM_SQ, AggOp.MIN, AggOp.MAX)
                && ((AggUnaryOp) root).getDirection() == Direction.RowCol)
                || (root instanceof AggBinaryOp && root.getDim1() == 1 && root.getDim2() == 1
                        && HopRewriteUtils.isTransposeOperation(root.getInput().get(0)));
    }

    private static boolean isValidMultiAggregate(CPlanMemoTable memo, MemoTableEntry me) {
        //ensure consistent input sizes (otherwise potential for incorrect results)
        boolean ret = true;
        Hop refSize = memo.getHopRefs().get(me.input1).getInput().get(0);
        for (int i = 1; ret && i < 3; i++) {
            if (me.isPlanRef(i))
                ret &= HopRewriteUtils.isEqualSize(refSize, memo.getHopRefs().get(me.input(i)).getInput().get(0));
        }

        //ensure that aggregates are independent of each other, i.e.,
        //they do not have potentially transitive parent-child references
        for (int i = 0; ret && i < 3; i++)
            if (me.isPlanRef(i)) {
                HashSet<Long> probe = new HashSet<>();
                for (int j = 0; j < 3; j++)
                    if (i != j)
                        probe.add(me.input(j));
                ret &= rCheckMultiAggregate(memo.getHopRefs().get(me.input(i)), probe);
            }
        return ret;
    }

    private static boolean rCheckMultiAggregate(Hop current, HashSet<Long> probe) {
        boolean ret = true;
        for (Hop c : current.getInput())
            ret &= rCheckMultiAggregate(c, probe);
        ret &= !probe.contains(current.getHopID());
        return ret;
    }

    private static void rCollectFullAggregates(Hop current, HashSet<Long> aggs) {
        if (current.isVisited())
            return;

        //collect all applicable full aggregations per read
        if (isMultiAggregateRoot(current))
            aggs.add(current.getHopID());

        //recursively process children
        for (Hop c : current.getInput())
            rCollectFullAggregates(c, aggs);

        current.setVisited();
    }

    private static void rExtractAggregateInfo(CPlanMemoTable memo, Hop current, AggregateInfo aggInfo,
            TemplateType type) {
        //collect input aggregates (dependents)
        if (isMultiAggregateRoot(current))
            aggInfo.addInputAggregate(current.getHopID());

        //recursively process children
        MemoTableEntry me = (type != null) ? memo.getBest(current.getHopID()) : null;
        for (int i = 0; i < current.getInput().size(); i++) {
            Hop c = current.getInput().get(i);
            if (me != null && me.isPlanRef(i))
                rExtractAggregateInfo(memo, c, aggInfo, type);
            else {
                if (type != null && c.getDataType().isMatrix()) //add fused input
                    aggInfo.addFusedInput(c.getHopID());
                rExtractAggregateInfo(memo, c, aggInfo, null);
            }
        }
    }

    private static HashSet<Long> collectIrreplaceableRowOps(CPlanMemoTable memo, PlanPartition part) {
        //get row entries that are (a) reachable from row-wise ops (top-down), other than
        //operator root nodes, or (b) dependent upon row-wise ops (bottom-up)
        HashSet<Long> blacklist = new HashSet<>();
        HashSet<Pair<Long, Integer>> visited = new HashSet<>();
        for (Long hopID : part.getRoots()) {
            rCollectDependentRowOps(memo.getHopRefs().get(hopID), memo, part, blacklist, visited, null, false);
        }
        return blacklist;
    }

    private static void rCollectDependentRowOps(Hop hop, CPlanMemoTable memo, PlanPartition part,
            HashSet<Long> blacklist, HashSet<Pair<Long, Integer>> visited, TemplateType type, boolean foundRowOp) {
        //avoid redundant evaluation of processed and non-partition nodes
        Pair<Long, Integer> key = Pair.of(hop.getHopID(),
                (foundRowOp ? Short.MAX_VALUE : 0) + ((type != null) ? type.ordinal() + 1 : 0));
        if (visited.contains(key) || !part.getPartition().contains(hop.getHopID())) {
            return;
        }

        //process node itself (top-down)
        MemoTableEntry me = (type == null) ? memo.getBest(hop.getHopID()) : memo.getBest(hop.getHopID(), type);
        boolean inRow = (me != null && me.type == TemplateType.ROW && type == TemplateType.ROW);
        boolean diffPlans = part.getMatPointsExt().length > 0 //guard against plan differences
                && memo.contains(hop.getHopID(), TemplateType.ROW)
                && !memo.hasOnlyExactMatches(hop.getHopID(), TemplateType.ROW, TemplateType.CELL);
        if (inRow && foundRowOp)
            blacklist.add(hop.getHopID());
        if (isRowAggOp(hop, inRow) || diffPlans) {
            blacklist.add(hop.getHopID());
            foundRowOp = true;
        }

        //process children recursively
        for (int i = 0; i < hop.getInput().size(); i++) {
            boolean lfoundRowOp = foundRowOp && me != null
                    && (me.isPlanRef(i) || isImplicitlyFused(hop, i, me.type));
            rCollectDependentRowOps(hop.getInput().get(i), memo, part, blacklist, visited,
                    me != null ? me.type : null, lfoundRowOp);
        }

        //process node itself (bottom-up)
        if (!blacklist.contains(hop.getHopID())) {
            for (int i = 0; i < hop.getInput().size(); i++)
                if (me != null && me.type == TemplateType.ROW
                        && (me.isPlanRef(i) || isImplicitlyFused(hop, i, me.type))
                        && blacklist.contains(hop.getInput().get(i).getHopID())) {
                    blacklist.add(hop.getHopID());
                }
        }

        visited.add(key);
    }

    private static boolean isRowAggOp(Hop hop, boolean inRow) {
        return HopRewriteUtils.isBinary(hop, OpOp2.CBIND) || HopRewriteUtils.isNary(hop, OpOpN.CBIND)
                || (hop instanceof AggBinaryOp
                        && (inRow || !hop.dimsKnown() || (hop.getDim1() != 1 && hop.getDim2() != 1)))
                || (HopRewriteUtils.isTransposeOperation(hop) && (hop.getDim1() != 1 && hop.getDim2() != 1)
                        && !HopRewriteUtils.isDataGenOp(hop.getInput().get(0), DataGenMethod.SEQ))
                || (hop instanceof AggUnaryOp && inRow);
    }

    private static boolean isValidRow2CellOp(Hop hop) {
        return !(HopRewriteUtils.isBinary(hop, OpOp2.CBIND)
                || (hop instanceof AggBinaryOp && hop.getDim1() != 1 && hop.getDim2() != 1));
    }

    private static void pruneInvalidAndSpecialCasePlans(CPlanMemoTable memo, PlanPartition part) {
        //prune invalid row entries w/ violated blocksize constraint
        if (OptimizerUtils.isSparkExecutionMode()) {
            for (Long hopID : part.getPartition()) {
                if (!memo.contains(hopID, TemplateType.ROW))
                    continue;
                Hop hop = memo.getHopRefs().get(hopID);
                boolean isSpark = DMLScript.rtplatform == RUNTIME_PLATFORM.SPARK
                        || OptimizerUtils.getTotalMemEstimate(hop.getInput().toArray(new Hop[0]), hop,
                                true) > OptimizerUtils.getLocalMemBudget();
                boolean validNcol = hop.getDataType().isScalar()
                        || (HopRewriteUtils.isTransposeOperation(hop) ? hop.getDim1() <= hop.getRowsInBlock()
                                : hop.getDim2() <= hop.getColsInBlock());
                for (Hop in : hop.getInput())
                    validNcol &= in.getDataType().isScalar() || (in.getDim2() <= in.getColsInBlock())
                            || (hop instanceof AggBinaryOp && in.getDim1() <= in.getRowsInBlock()
                                    && HopRewriteUtils.isTransposeOperation(in));
                if (isSpark && !validNcol) {
                    List<MemoTableEntry> blacklist = memo.get(hopID, TemplateType.ROW);
                    memo.remove(memo.getHopRefs().get(hopID), TemplateType.ROW);
                    memo.removeAllRefTo(hopID, TemplateType.ROW);
                    if (LOG.isTraceEnabled()) {
                        LOG.trace("Removed row memo table entries w/ violated blocksize constraint (" + hopID
                                + "): " + Arrays.toString(blacklist.toArray(new MemoTableEntry[0])));
                    }
                }
            }
        }

        //prune row aggregates with pure cellwise operations
        //(we determine a blacklist of all operators in a partition that either
        //depend upon row aggregates or on which row aggregates depend)
        HashSet<Long> blacklist = collectIrreplaceableRowOps(memo, part);
        for (Long hopID : part.getPartition()) {
            if (blacklist.contains(hopID))
                continue;
            MemoTableEntry me = memo.getBest(hopID, TemplateType.ROW);
            if (me != null && me.type == TemplateType.ROW
                    && memo.hasOnlyExactMatches(hopID, TemplateType.ROW, TemplateType.CELL)) {
                List<MemoTableEntry> rmList = memo.get(hopID, TemplateType.ROW);
                memo.remove(memo.getHopRefs().get(hopID), new HashSet<>(rmList));
                if (LOG.isTraceEnabled()) {
                    LOG.trace("Removed row memo table entries w/o aggregation: "
                            + Arrays.toString(rmList.toArray(new MemoTableEntry[0])));
                }
            }
        }

        //prune suboptimal outer product plans that are dominated by outer product plans w/ same number of 
        //references but better fusion properties (e.g., for the patterns Y=X*(U%*%t(V)) and sum(Y*(U2%*%t(V2))), 
        //we'd prune sum(X*(U%*%t(V))*Z), Z=U2%*%t(V2) because this would unnecessarily destroy a fusion pattern.
        for (Long hopID : part.getPartition()) {
            if (memo.countEntries(hopID, TemplateType.OUTER) == 2) {
                List<MemoTableEntry> entries = memo.get(hopID, TemplateType.OUTER);
                MemoTableEntry me1 = entries.get(0);
                MemoTableEntry me2 = entries.get(1);
                MemoTableEntry rmEntry = TemplateOuterProduct.dropAlternativePlan(memo, me1, me2);
                if (rmEntry != null) {
                    memo.remove(memo.getHopRefs().get(hopID), Collections.singleton(rmEntry));
                    memo.getPlansBlacklisted().remove(rmEntry.input(rmEntry.getPlanRefIndex()));
                    if (LOG.isTraceEnabled())
                        LOG.trace("Removed dominated outer product memo table entry: " + rmEntry);
                }
            }
        }
    }

    private static void rPruneSuboptimalPlans(CPlanMemoTable memo, Hop current, HashSet<Long> visited,
            PlanPartition part, InterestingPoint[] matPoints, boolean[] plan) {
        //memoization (not via hop visit status because we may start in the middle of the DAG)
        if (visited.contains(current.getHopID()))
            return;

        //remove memo table entries if necessary
        long hopID = current.getHopID();
        if (part.getPartition().contains(hopID) && memo.contains(hopID)) {
            Iterator<MemoTableEntry> iter = memo.get(hopID).iterator();
            while (iter.hasNext()) {
                MemoTableEntry me = iter.next();
                if (!hasNoRefToMatPoint(hopID, me, matPoints, plan) && me.type != TemplateType.OUTER) {
                    iter.remove();
                    if (LOG.isTraceEnabled())
                        LOG.trace("Removed memo table entry: " + me);
                }
            }
        }

        //process children recursively
        for (Hop c : current.getInput())
            rPruneSuboptimalPlans(memo, c, visited, part, matPoints, plan);

        visited.add(current.getHopID());
    }

    private static void rPruneInvalidPlans(CPlanMemoTable memo, Hop current, HashSet<Long> visited,
            PlanPartition part, boolean[] plan) {
        //memoization (not via hop visit status because we may start in the middle of the DAG)
        if (visited.contains(current.getHopID()))
            return;

        //process children recursively
        for (Hop c : current.getInput())
            rPruneInvalidPlans(memo, c, visited, part, plan);

        //find invalid row aggregate leaf nodes (see TemplateRow.open) w/o matrix inputs, 
        //i.e., plans that become invalid after the previous pruning step
        long hopID = current.getHopID();
        if (part.getPartition().contains(hopID) && memo.contains(hopID, TemplateType.ROW)) {
            Iterator<MemoTableEntry> iter = memo.get(hopID, TemplateType.ROW).iterator();
            while (iter.hasNext()) {
                MemoTableEntry me = iter.next();
                //convert leaf node with pure vector inputs
                boolean applyLeaf = (!me.hasPlanRef() && !TemplateUtils.hasMatrixInput(current));

                //convert inner node without row template input
                boolean applyInner = !applyLeaf && !ROW_TPL.open(current);
                for (int i = 0; i < 3 && applyInner; i++)
                    if (me.isPlanRef(i))
                        applyInner &= !memo.contains(me.input(i), TemplateType.ROW);

                if (applyLeaf || applyInner) {
                    String type = applyLeaf ? "leaf" : "inner";
                    if (isValidRow2CellOp(current)) {
                        me.type = TemplateType.CELL;
                        if (LOG.isTraceEnabled())
                            LOG.trace("Converted " + type + " memo table entry from row to cell: " + me);
                    } else {
                        if (LOG.isTraceEnabled())
                            LOG.trace("Removed " + type + " memo table entry row (unsupported cell): " + me);
                        iter.remove();
                    }
                }
            }
        }

        visited.add(current.getHopID());
    }

    /////////////////////////////////////////////////////////
    // Cost model for fused operators w/ materialization points
    //////////

    private double getPlanCost(CPlanMemoTable memo, PlanPartition part, InterestingPoint[] matPoints,
            boolean[] plan, HashMap<Long, Double> computeCosts, final double costBound) {
        //high level heuristic: every hop or fused operator has the following cost: 
        //WRITE + max(COMPUTE, READ), where WRITE costs are given by the output size, 
        //READ costs by the input sizes, and COMPUTE by operation specific FLOP
        //counts times number of cells of main input, disregarding sparsity for now.
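        //e.g., a fused operator over two dense 1000x1000 inputs producing a
        //1000x1000 output at 4 FLOPs per cell would be costed (illustratively) as
        //8MB/WRITE_BANDWIDTH_MEM + max(16MB/READ_BANDWIDTH_MEM, 4MFLOP/COMPUTE_BANDWIDTH)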

        HashSet<VisitMarkCost> visited = new HashSet<>();
        double costs = 0;
        int rem = part.getRoots().size();
        for (Long hopID : part.getRoots()) {
            costs += rGetPlanCosts(memo, memo.getHopRefs().get(hopID), visited, part, matPoints, plan, computeCosts,
                    null, null, costBound - costs);
            if (costs >= costBound && --rem > 0) //stop early
                return Double.POSITIVE_INFINITY;
        }
        return costs;
    }

    private double rGetPlanCosts(CPlanMemoTable memo, final Hop current, HashSet<VisitMarkCost> visited,
            PlanPartition part, InterestingPoint[] matPoints, boolean[] plan, HashMap<Long, Double> computeCosts,
            CostVector costsCurrent, TemplateType currentType, final double costBound) {
        final long currentHopId = current.getHopID();
        //memoization per hop id and cost vector to account for redundant
        //computation without double counting materialized results or compute
        //costs of complex operation DAGs within a single fused operator
        if (!visited.add(new VisitMarkCost(currentHopId,
                (costsCurrent == null || currentType == TemplateType.MAGG) ? -1 : costsCurrent.ID)))
            return 0; //already accounted for

        //open template if necessary, including memoization
        //under awareness of current plan choice
        MemoTableEntry best = null;
        boolean opened = (currentType == null);
        if (memo.contains(currentHopId)) {
            //note: this is the inner loop of plan enumeration and hence, we do not 
            //use streams, lambda expressions, etc., to avoid unnecessary overhead
            if (currentType == null) {
                for (MemoTableEntry me : memo.get(currentHopId))
                    best = me.isValid() && hasNoRefToMatPoint(currentHopId, me, matPoints, plan)
                            && BasicPlanComparator.icompare(me, best) < 0 ? me : best;
                opened = true;
            } else {
                for (MemoTableEntry me : memo.get(currentHopId))
                    best = (me.type == currentType || me.type == TemplateType.CELL)
                            && hasNoRefToMatPoint(currentHopId, me, matPoints, plan)
                            && TypedPlanComparator.icompare(me, best, currentType) < 0 ? me : best;
            }
        }

        //create new cost vector if opened, initialized with write costs
        CostVector costVect = !opened ? costsCurrent : new CostVector(getSize(current));
        double costs = 0;

        //add other roots for multi-agg template to account for shared costs
        if (opened && best != null && best.type == TemplateType.MAGG) {
            //account costs to first multi-agg root 
            if (best.input1 == currentHopId)
                for (int i = 1; i < 3; i++) {
                    if (!best.isPlanRef(i))
                        continue;
                    costs += rGetPlanCosts(memo, memo.getHopRefs().get(best.input(i)), visited, part, matPoints,
                            plan, computeCosts, costVect, TemplateType.MAGG, costBound - costs);
                    if (costs >= costBound)
                        return Double.POSITIVE_INFINITY;
                }
            //skip other multi-agg roots
            else
                return 0;
        }

        //add compute costs of current operator to costs vector
        costVect.computeCosts += computeCosts.get(currentHopId);

        //process children recursively
        for (int i = 0; i < current.getInput().size(); i++) {
            Hop c = current.getInput().get(i);
            if (best != null && best.isPlanRef(i))
                costs += rGetPlanCosts(memo, c, visited, part, matPoints, plan, computeCosts, costVect, best.type,
                        costBound - costs);
            else if (best != null && isImplicitlyFused(current, i, best.type))
                costVect.addInputSize(c.getInput().get(0).getHopID(), getSize(c));
            else { //include children and I/O costs
                if (part.getPartition().contains(c.getHopID()))
                    costs += rGetPlanCosts(memo, c, visited, part, matPoints, plan, computeCosts, null, null,
                            costBound - costs);
                if (costVect != null && c.getDataType().isMatrix())
                    costVect.addInputSize(c.getHopID(), getSize(c));
            }
            if (costs >= costBound)
                return Double.POSITIVE_INFINITY;
        }

        //add costs for opened fused operator
        if (opened) {
            double memInputs = sumInputMemoryEstimates(memo, costVect);
            double tmpCosts = costVect.outSize * 8 / WRITE_BANDWIDTH_MEM
                    + Math.max(memInputs / READ_BANDWIDTH_MEM, costVect.computeCosts / COMPUTE_BANDWIDTH);
            //read correction for distributed computation
            if (memInputs > OptimizerUtils.getLocalMemBudget())
                tmpCosts += costVect.getSideInputSize() * 8 / READ_BANDWIDTH_BROADCAST;
            //sparsity correction for outer-product template (and sparse-safe cell)
            Hop driver = memo.getHopRefs().get(costVect.getMaxInputSizeHopID());
            if (best != null && best.type == TemplateType.OUTER)
                tmpCosts *= driver.dimsKnown(true) ? driver.getSparsity() : SPARSE_SAFE_SPARSITY_EST;
            //write correction for known evictions in CP
            else if (memInputs <= OptimizerUtils.getLocalMemBudget()
                    && sumTmpInputOutputSize(memo, costVect) * 8 > LazyWriteBuffer.getWriteBufferLimit())
                tmpCosts += costVect.outSize * 8 / WRITE_BANDWIDTH_IO;
            costs += tmpCosts;
            if (LOG.isTraceEnabled()) {
                String type = (best != null) ? best.type.name() : "HOP";
                LOG.trace("Cost vector (" + type + " " + currentHopId + "): " + costVect + " -> " + tmpCosts);
            }
        }
        //add costs for non-partition read in the middle of fused operator
        else if (part.getExtConsumed().contains(current.getHopID())) {
            costs += rGetPlanCosts(memo, current, visited, part, matPoints, plan, computeCosts, null, null,
                    costBound - costs);
        }

        //sanity check for valid (finite, non-negative) costs
        if (costs < 0 || Double.isNaN(costs) || Double.isInfinite(costs))
            throw new RuntimeException("Wrong cost estimate: " + costs);

        return costs;
    }

    private static void getComputeCosts(Hop current, HashMap<Long, Double> computeCosts) {
        //get costs for given hop
        double costs = 1;
        if (current instanceof UnaryOp) {
            switch (((UnaryOp) current).getOp()) {
            case ABS:
            case ROUND:
            case CEIL:
            case FLOOR:
            case SIGN:
                costs = 1;
                break;
            case SPROP:
            case SQRT:
                costs = 2;
                break;
            case EXP:
                costs = 18;
                break;
            case SIGMOID:
                costs = 21;
                break;
            case LOG:
            case LOG_NZ:
                costs = 32;
                break;
            case NCOL:
            case NROW:
            case PRINT:
            case ASSERT:
            case CAST_AS_BOOLEAN:
            case CAST_AS_DOUBLE:
            case CAST_AS_INT:
            case CAST_AS_MATRIX:
            case CAST_AS_SCALAR:
                costs = 1;
                break;
            case SIN:
                costs = 18;
                break;
            case COS:
                costs = 22;
                break;
            case TAN:
                costs = 42;
                break;
            case ASIN:
                costs = 93;
                break;
            case ACOS:
                costs = 103;
                break;
            case ATAN:
                costs = 40;
                break;
            case SINH:
                costs = 93;
                break; // TODO: hyperbolic costs currently mirror ASIN/ACOS/ATAN
            case COSH:
                costs = 103;
                break;
            case TANH:
                costs = 40;
                break;
            case CUMSUM:
            case CUMMIN:
            case CUMMAX:
            case CUMPROD:
                costs = 1;
                break;
            default:
                LOG.warn("Cost model not " + "implemented yet for: " + ((UnaryOp) current).getOp());
            }
        } else if (current instanceof BinaryOp) {
            switch (((BinaryOp) current).getOp()) {
            case MULT:
            case PLUS:
            case MINUS:
            case MIN:
            case MAX:
            case AND:
            case OR:
            case EQUAL:
            case NOTEQUAL:
            case LESS:
            case LESSEQUAL:
            case GREATER:
            case GREATEREQUAL:
            case CBIND:
            case RBIND:
                costs = 1;
                break;
            case INTDIV:
                costs = 6;
                break;
            case MODULUS:
                costs = 8;
                break;
            case DIV:
                costs = 22;
                break;
            case LOG:
            case LOG_NZ:
                costs = 32;
                break;
            case POW:
                costs = (HopRewriteUtils.isLiteralOfValue(current.getInput().get(1), 2) ? 1 : 16);
                break;
            case MINUS_NZ:
            case MINUS1_MULT:
                costs = 2;
                break;
            case MOMENT:
                int type = (int) (current.getInput().get(1) instanceof LiteralOp
                        ? HopRewriteUtils.getIntValueSafe((LiteralOp) current.getInput().get(1))
                        : 2);
                switch (type) {
                case 0:
                    costs = 1;
                    break; //count
                case 1:
                    costs = 8;
                    break; //mean
                case 2:
                    costs = 16;
                    break; //cm2
                case 3:
                    costs = 31;
                    break; //cm3
                case 4:
                    costs = 51;
                    break; //cm4
                case 5:
                    costs = 16;
                    break; //variance
                }
                break;
            case COV:
                costs = 23;
                break;
            default:
                LOG.warn("Cost model not " + "implemented yet for: " + ((BinaryOp) current).getOp());
            }
        } else if (current instanceof TernaryOp) {
            switch (((TernaryOp) current).getOp()) {
            case IFELSE:
            case PLUS_MULT:
            case MINUS_MULT:
                costs = 2;
                break;
            case CTABLE:
                costs = 3;
                break;
            case MOMENT:
                int type = (int) (current.getInput().get(1) instanceof LiteralOp
                        ? HopRewriteUtils.getIntValueSafe((LiteralOp) current.getInput().get(1))
                        : 2);
                switch (type) {
                case 0:
                    costs = 2;
                    break; //count
                case 1:
                    costs = 9;
                    break; //mean
                case 2:
                    costs = 17;
                    break; //cm2
                case 3:
                    costs = 32;
                    break; //cm3
                case 4:
                    costs = 52;
                    break; //cm4
                case 5:
                    costs = 17;
                    break; //variance
                }
                break;
            case COV:
                costs = 23;
                break;
            default:
                LOG.warn("Cost model not " + "implemented yet for: " + ((TernaryOp) current).getOp());
            }
        } else if (current instanceof NaryOp) {
            costs = HopRewriteUtils.isNary(current, OpOpN.MIN, OpOpN.MAX) ? current.getInput().size() : 1;
        } else if (current instanceof ParameterizedBuiltinOp) {
            costs = 1;
        } else if (current instanceof IndexingOp) {
            costs = 1;
        } else if (current instanceof ReorgOp) {
            costs = 1;
        } else if (current instanceof DnnOp) {
            switch (((DnnOp) current).getOp()) {
            case BIASADD:
            case BIASMULT:
                costs = 2;
                break;
            default:
                LOG.warn("Cost model not implemented yet for: " + ((DnnOp) current).getOp());
            }
        } else if (current instanceof AggBinaryOp) {
            //outer product template w/ matrix-matrix 
            //or row template w/ matrix-vector or matrix-matrix
            costs = 2 * current.getInput().get(0).getDim2();
            if (current.getInput().get(0).dimsKnown(true))
                costs *= current.getInput().get(0).getSparsity();
        } else if (current instanceof AggUnaryOp) {
            switch (((AggUnaryOp) current).getOp()) {
            case SUM:
                costs = 4;
                break;
            case SUM_SQ:
                costs = 5;
                break;
            case MIN:
            case MAX:
                costs = 1;
                break;
            default:
                LOG.warn("Cost model not " + "implemented yet for: " + ((AggUnaryOp) current).getOp());
            }
            switch (((AggUnaryOp) current).getDirection()) {
            case Col:
                costs *= Math.max(current.getInput().get(0).getDim1(), 1);
                break;
            case Row:
                costs *= Math.max(current.getInput().get(0).getDim2(), 1);
                break;
            case RowCol:
                costs *= getSize(current.getInput().get(0));
                break;
            }
        }

        //scale by current output size in order to correctly reflect
        //a mix of row and cell operations in the same fused operator
        //(e.g., row template with fused column vector operations)
        costs *= getSize(current);

        computeCosts.put(current.getHopID(), costs);
    }

    private static boolean hasNoRefToMatPoint(long hopID, MemoTableEntry me, InterestingPoint[] M, boolean[] plan) {
        return !InterestingPoint.isMatPoint(M, hopID, me, plan);
    }

    private static boolean isImplicitlyFused(Hop hop, int index, TemplateType type) {
        return type == TemplateType.ROW && HopRewriteUtils.isMatrixMultiply(hop) && index == 0
                && HopRewriteUtils.isTransposeOperation(hop.getInput().get(index));
    }

    private static boolean probePlanCache(InterestingPoint[] matPoints) {
        return PLAN_CACHING && matPoints.length >= PLAN_CACHE_NUM_POINTS;
    }

    private static boolean[] getPlan(PartitionSignature pKey) {
        boolean[] plan = null;
        synchronized (_planCache) {
            plan = _planCache.get(pKey);
        }
        if (DMLScript.STATISTICS) {
            if (plan != null)
                Statistics.incrementCodegenPlanCacheHits();
            Statistics.incrementCodegenPlanCacheTotal();
        }
        return plan;
    }

    private static void putPlan(PartitionSignature pKey, boolean[] plan) {
        synchronized (_planCache) {
            //maintain size of plan cache (remove first)
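            //(the insertion-ordered LinkedHashMap makes this a FIFO eviction
            //of the oldest cached plan)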
            if (_planCache.size() >= PLAN_CACHE_SIZE) {
                Iterator<Entry<PartitionSignature, boolean[]>> iter = _planCache.entrySet().iterator();
                iter.next();
                iter.remove();
            }

            //add last entry 
            _planCache.put(pKey, plan);
        }
    }

    private class CostVector {
        public final long ID;
        public final double outSize;
        public double computeCosts = 0;
        public final HashMap<Long, Double> inSizes = new HashMap<>();

        public CostVector(double outputSize) {
            ID = COST_ID.getNextID();
            outSize = outputSize;
        }

        public void addInputSize(long hopID, double inputSize) {
            //ensures that input sizes are not double counted
            inSizes.put(hopID, inputSize);
        }

        @SuppressWarnings("unused")
        public double getInputSize() {
            return inSizes.values().stream().mapToDouble(d -> d.doubleValue()).sum();
        }

        public double getSideInputSize() {
            double max = getMaxInputSize();
            return inSizes.values().stream().filter(d -> d < max).mapToDouble(d -> d.doubleValue()).sum();
        }

        public double getMaxInputSize() {
            return inSizes.values().stream().mapToDouble(d -> d.doubleValue()).max().orElse(0);
        }

        public long getMaxInputSizeHopID() {
            long id = -1;
            double max = 0;
            for (Entry<Long, Double> e : inSizes.entrySet())
                if (max < e.getValue()) {
                    id = e.getKey();
                    max = e.getValue();
                }
            return id;
        }

        @Override
        public String toString() {
            return "[" + outSize + ", " + computeCosts + ", {"
                    + Arrays.toString(inSizes.keySet().toArray(new Long[0])) + ", "
                    + Arrays.toString(inSizes.values().toArray(new Double[0])) + "}]";
        }
    }

    private static class StaticCosts {
        public final HashMap<Long, Double> _computeCosts;
        public final double _compute;
        public final double _read;
        public final double _write;
        public final double _minSparsity;

        public StaticCosts(HashMap<Long, Double> allComputeCosts, double computeCost, double readCost,
                double writeCost, double minSparsity) {
            _computeCosts = allComputeCosts;
            _compute = computeCost;
            _read = readCost;
            _write = writeCost;
            _minSparsity = minSparsity;
        }

        public double getMinCosts() {
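            //minimal achievable costs: reads overlap with compute (max),
            //whereas the final writes are always additive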
            return Math.max(_read, _compute) + _write;
        }
    }

    private static class AggregateInfo {
        public final HashMap<Long, Hop> _aggregates;
        public final HashSet<Long> _inputAggs = new HashSet<>();
        public final HashSet<Long> _fusedInputs = new HashSet<>();

        public AggregateInfo(Hop aggregate) {
            _aggregates = new HashMap<>();
            _aggregates.put(aggregate.getHopID(), aggregate);
        }

        public void addInputAggregate(long hopID) {
            _inputAggs.add(hopID);
        }

        public void addFusedInput(long hopID) {
            _fusedInputs.add(hopID);
        }

        public boolean isMergable(AggregateInfo that) {
            //check independence
            boolean ret = _aggregates.size() < 3 && _aggregates.size() + that._aggregates.size() <= 3;
            for (Long hopID : that._aggregates.keySet())
                ret &= !_inputAggs.contains(hopID);
            for (Long hopID : _aggregates.keySet())
                ret &= !that._inputAggs.contains(hopID);
            //check partial shared reads
            ret &= !CollectionUtils.intersection(_fusedInputs, that._fusedInputs).isEmpty();
            //check consistent sizes (result correctness)
            Hop in1 = _aggregates.values().iterator().next();
            Hop in2 = that._aggregates.values().iterator().next();
            return ret && HopRewriteUtils.isEqualSize(
                    in1.getInput().get(HopRewriteUtils.isMatrixMultiply(in1) ? 1 : 0),
                    in2.getInput().get(HopRewriteUtils.isMatrixMultiply(in2) ? 1 : 0));
        }

        public AggregateInfo merge(AggregateInfo that) {
            _aggregates.putAll(that._aggregates);
            _inputAggs.addAll(that._inputAggs);
            _fusedInputs.addAll(that._fusedInputs);
            return this;
        }

        @Override
        public String toString() {
            return "[" + Arrays.toString(_aggregates.keySet().toArray(new Long[0])) + ": " + "{"
                    + Arrays.toString(_inputAggs.toArray(new Long[0])) + "}," + "{"
                    + Arrays.toString(_fusedInputs.toArray(new Long[0])) + "}]";
        }
    }

    private class PartitionSignature {
        private final int partNodes, inputNodes, rootNodes, matPoints;
        private final double cCompute, cRead, cWrite, cPlan0, cPlanN;

        public PartitionSignature(PlanPartition part, int M, StaticCosts costs, double cP0, double cPN) {
            partNodes = part.getPartition().size();
            inputNodes = part.getInputs().size();
            rootNodes = part.getRoots().size();
            matPoints = M;
            cCompute = costs._compute;
            cRead = costs._read;
            cWrite = costs._write;
            cPlan0 = cP0;
            cPlanN = cPN;
        }

        @Override
        public int hashCode() {
            return UtilFunctions.intHashCode(
                    Arrays.hashCode(new int[] { partNodes, inputNodes, rootNodes, matPoints }),
                    Arrays.hashCode(new double[] { cCompute, cRead, cWrite, cPlan0, cPlanN }));
        }

        @Override
        public boolean equals(Object o) {
            if (!(o instanceof PartitionSignature))
                return false;
            PartitionSignature that = (PartitionSignature) o;
            return partNodes == that.partNodes && inputNodes == that.inputNodes && rootNodes == that.rootNodes
                    && matPoints == that.matPoints && cCompute == that.cCompute && cRead == that.cRead
                    && cWrite == that.cWrite && cPlan0 == that.cPlan0 && cPlanN == that.cPlanN;
        }
    }
}