com.ibm.bi.dml.hops.rewrite.RewriteMatrixMultChainOptimization.java Source code

Java tutorial

Introduction

Here is the source code for com.ibm.bi.dml.hops.rewrite.RewriteMatrixMultChainOptimization.java

Source

/**
 * (C) Copyright IBM Corp. 2010, 2015
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * 
*/

package com.ibm.bi.dml.hops.rewrite;

import java.util.ArrayList;
import java.util.Arrays;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;

import com.ibm.bi.dml.hops.AggBinaryOp;
import com.ibm.bi.dml.hops.Hop;
import com.ibm.bi.dml.hops.HopsException;
import com.ibm.bi.dml.utils.Explain;

/**
 * Rule: Determine the optimal order of execution for a chain of matrix
 * multiplications.
 * <p>
 * Solution: classic dynamic programming (matrix-chain ordering). Currently the
 * cost model is based only on matrix dimensions; a chain with any unknown
 * (non-positive) dimension is left unmodified.
 * <p>
 * Goal: reduce the number of computations in the run-time (map-reduce) layer.
 */
public class RewriteMatrixMultChainOptimization extends HopRewriteRule {

    private static final Log LOG = LogFactory.getLog(RewriteMatrixMultChainOptimization.class.getName());

    /** Set to true to force TRACE logging for this class (internal debugging only). */
    private static final boolean LDEBUG = false;

    static {
        // for internal debugging only
        if (LDEBUG) {
            Logger.getLogger("com.ibm.bi.dml.hops.rewrite.RewriteMatrixMultChainOptimization")
                    .setLevel(Level.TRACE);
        }
    }

    /**
     * Applies matrix-mult chain optimization to every DAG rooted at one of the
     * given hops.
     *
     * @param roots DAG roots; may be null
     * @param state program rewrite status (unused by this rule)
     * @return the same {@code roots} list, or null if {@code roots} was null
     * @throws HopsException if a malformed chain is encountered
     */
    @Override
    public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) throws HopsException {
        if (roots == null)
            return null;

        for (Hop h : roots) {
            // Find the optimal order for the chain whose result is the current HOP
            rule_OptimizeMMChains(h);
        }

        return roots;
    }

    /**
     * Applies matrix-mult chain optimization to the DAG rooted at {@code root}.
     *
     * @param root DAG root; may be null
     * @param state program rewrite status (unused by this rule)
     * @return the same {@code root}, or null if {@code root} was null
     * @throws HopsException if a malformed chain is encountered
     */
    @Override
    public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) throws HopsException {
        if (root == null)
            return null;

        // Find the optimal order for the chain whose result is the current HOP
        rule_OptimizeMMChains(root);

        return root;
    }

    /**
     * Recurses top-down through all unvisited hops in the DAG, optimizing every
     * matrix-multiply chain whose last operator is encountered, and marks each
     * processed hop as {@code DONE}.
     */
    private void rule_OptimizeMMChains(Hop hop) throws HopsException {
        if (hop.getVisited() == Hop.VisitStatus.DONE)
            return;

        // hop is known to be unvisited here, so only the operator kind is checked
        if (hop instanceof AggBinaryOp && ((AggBinaryOp) hop).isMatrixMultiply()
                && !((AggBinaryOp) hop).hasLeftPMInput()) {
            // Try to find and optimize the chain in which current Hop is the
            // last operator
            optimizeMMChain(hop);
        }

        for (Hop hi : hop.getInput())
            rule_OptimizeMMChains(hi);

        hop.setVisited(Hop.VisitStatus.DONE);
    }

    /**
     * Optimizes the matrix multiplication chain whose last operator is {@code hop}.
     * <ol>
     * <li>Identify the chain operands (mmChain) and operators (mmOperators).</li>
     * <li>Build the dimension array; abort without changes if any dim is unknown.</li>
     * <li>Clear all links among the operators involved in the chain.</li>
     * <li>Find the optimal ordering via dynamic programming.</li>
     * <li>Relink the hops according to that ordering.</li>
     * </ol>
     *
     * @param hop last (outermost) matrix multiply of the chain
     * @throws HopsException if a chain operator does not have exactly two inputs,
     *         or if chain dimensions are inconsistent
     */
    private void optimizeMMChain(Hop hop) throws HopsException {
        if (LOG.isTraceEnabled()) {
            LOG.trace("MM Chain Optimization for HOP: (" + " " + hop.getClass().getSimpleName() + ", "
                    + hop.getHopID() + ", " + hop.getName() + ")");
        }

        ArrayList<Hop> mmChain = new ArrayList<Hop>();     // chain operands, in left-to-right order
        ArrayList<Hop> mmOperators = new ArrayList<Hop>(); // matmult hops of the chain; index 0 is hop

        // Step 1: Identify the chain (mmChain) and its operators (mmOperators).
        mmOperators.add(hop);
        // Initialize mmChain with my inputs
        for (Hop hi : hop.getInput()) {
            mmChain.add(hi);
        }

        // expand each Hop in mmChain to find the entire matrix multiplication chain
        int i = 0;
        while (i < mmChain.size()) {

            boolean expandable = false;

            Hop h = mmChain.get(i);
            /*
             * Check if mmChain[i] is expandable: 
             * 1) It must be MATMULT (without a left permutation-matrix input)
             * 2) It must not have been visited already 
             *    (one MATMULT should get expanded only in one chain)
             * 3) Its output should not be used in multiple places
             *    (either within chain or outside the chain)
             */
            if (h instanceof AggBinaryOp && ((AggBinaryOp) h).isMatrixMultiply()
                    && !((AggBinaryOp) h).hasLeftPMInput() && h.getVisited() != Hop.VisitStatus.DONE) {
                // check if the output of "h" is used at multiple places. If yes, it can
                // not be expanded.
                if (h.getParent().size() > 1 || inputCount((Hop) ((h.getParent().toArray())[0]), h) > 1) {
                    // NOTE(review): this stops expansion of the whole chain, so all
                    // remaining entries of mmChain are treated as opaque operands.
                    expandable = false;
                    break;
                } else
                    expandable = true;
            }

            // NOTE(review): non-expandable entries (the chain's leaf operands) are also
            // marked DONE, so rule_OptimizeMMChains will not descend into them later;
            // nested chains below a leaf are therefore not optimized. Confirm intended.
            h.setVisited(Hop.VisitStatus.DONE);

            if (!expandable) {
                i = i + 1;
            } else {
                ArrayList<Hop> tempList = h.getInput();
                if (tempList.size() != 2) {
                    throw new HopsException(hop.printErrorLocation()
                            + "Hops::rule_OptimizeMMChain(): AggBinary must have exactly two inputs.");
                }

                // add current operator to mmOperators, and replace it in mmChain by
                // its two inputs (left operand in place, right operand right after it)
                mmOperators.add(h);
                mmChain.set(i, tempList.get(0));
                mmChain.add(i + 1, tempList.get(1));
            }
        }

        // print the MMChain
        if (LOG.isTraceEnabled()) {
            LOG.trace("Identified MM Chain: ");
            for (Hop h : mmChain) {
                logTraceHop(h, 1);
            }
        }

        // If the chain size is 2, then there is nothing to optimize.
        if (mmChain.size() == 2)
            return;

        // Step 2: construct dims array
        int size = mmChain.size();
        double[] dimsArray = new double[size + 1];
        boolean dimsKnown = getDimsArray(hop, mmChain, dimsArray);

        if (dimsKnown) {
            // Step 3: clear the links among Hops within the identified chain
            clearLinksWithinChain(hop, mmOperators);

            // Step 4: Find the optimal ordering via dynamic programming.
            int[][] split = mmChainDP(dimsArray, size);

            // Step 5: Relink the hops using the optimal ordering (split[][]) found from DP.
            if (LOG.isTraceEnabled())
                LOG.trace("Optimal MM Chain: ");
            mmChainRelinkHops(mmOperators.get(0), 0, size - 1, mmChain, mmOperators, 1, split, 1);
        }
    }

    /**
     * Core dynamic program for matrix-chain ordering on a given array of matrix
     * dimensions (operand {@code i} is {@code dimArray[i] x dimArray[i+1]}).
     * <p>
     * Thomas H. Cormen, Charles E. Leiserson, Ronald L. Rivest, Clifford Stein
     * Introduction to Algorithms, Third Edition, MIT Press, page 395.
     *
     * @param dimArray dimension array of length {@code size + 1}
     * @param size number of operands in the chain
     * @return split table: for {@code i < j}, {@code split[i][j] = k} means the
     *         cheapest order for operands i..j is (i..k) x (k+1..j); entries with
     *         {@code i >= j} are -1
     */
    private int[][] mmChainDP(double[] dimArray, int size) {
        double[][] dpMatrix = new double[size][size]; //min cost table
        int[][] split = new int[size][size]; //min cost index table

        //init minimum costs for chains of length 1
        for (int i = 0; i < size; i++) {
            Arrays.fill(dpMatrix[i], 0);
            Arrays.fill(split[i], -1);
        }

        //compute cost-optimal chains for increasing chain sizes 
        for (int l = 2; l <= size; l++) { // chain length
            for (int i = 0; i < size - l + 1; i++) {
                int j = i + l - 1;
                // find cost of (i,j)
                dpMatrix[i][j] = Double.MAX_VALUE;
                for (int k = i; k <= j - 1; k++) {
                    //recursive cost computation: sub-chain costs + cost of the final product
                    double cost = dpMatrix[i][k] + dpMatrix[k + 1][j]
                            + (dimArray[i] * dimArray[k + 1] * dimArray[j + 1]);

                    //prune suboptimal
                    if (cost < dpMatrix[i][j]) {
                        dpMatrix[i][j] = cost;
                        split[i][j] = k;
                    }
                }

                if (LOG.isTraceEnabled()) {
                    LOG.trace("mmchainopt [i=" + (i + 1) + ",j=" + (j + 1) + "]: costs = " + dpMatrix[i][j]
                            + ", split = " + (split[i][j] + 1));
                }
            }
        }

        return split;
    }

    /**
     * Relinks the hops of a chain after the optimal order (split[][]) has been
     * found. mmChain holds the basic operands of the chain; mmOperators holds
     * the hops that store the intermediate results. For example, in
     * {@code A = B %*% (C %*% D)} there are three hops in mmChain (B,C,D) and two
     * in mmOperators (one for each %*%).
     * <p>
     * Operators are consumed in order from {@code mmOperators}. The index of the
     * next unused operator is threaded through the recursion via the return
     * value, so the right sub-chain continues after the operators the left
     * sub-chain used. Passing the same index to both sub-chains would link one
     * operator under two parents in ragged orderings such as ((AB)C)((DE)F).
     *
     * @param h operator hop to relink (or a leaf operand when {@code i == j})
     * @param i first operand index of the sub-chain
     * @param j last operand index of the sub-chain
     * @param mmChain chain operands
     * @param mmOperators chain operators (index 0 is the chain root)
     * @param opIndex index of the next unused operator in mmOperators
     * @param split split table from {@link #mmChainDP(double[], int)}
     * @param level indentation level for trace output
     * @return index of the next unused operator after relinking this sub-chain
     */
    private int mmChainRelinkHops(Hop h, int i, int j, ArrayList<Hop> mmChain, ArrayList<Hop> mmOperators,
            int opIndex, int[][] split, int level) {
        //single matrix - end of recursion
        if (i == j) {
            logTraceHop(h, level);
            return opIndex;
        }

        if (LOG.isTraceEnabled()) {
            String offset = Explain.getIdentation(level);
            LOG.trace(offset + "(");
        }

        int k = split[i][j];

        // Set Input1 for current Hop h: a single operand or a fresh operator
        if (i == k) {
            h.getInput().add(mmChain.get(i));
            mmChain.get(i).getParent().add(h);
        } else {
            h.getInput().add(mmOperators.get(opIndex));
            mmOperators.get(opIndex).getParent().add(h);
            opIndex = opIndex + 1;
        }

        // Set Input2 for current Hop h: a single operand or a fresh operator
        if (k + 1 == j) {
            h.getInput().add(mmChain.get(j));
            mmChain.get(j).getParent().add(h);
        } else {
            h.getInput().add(mmOperators.get(opIndex));
            mmOperators.get(opIndex).getParent().add(h);
            opIndex = opIndex + 1;
        }

        // Find children for both the inputs; the right sub-chain must start with
        // the operator index left over by the left sub-chain.
        opIndex = mmChainRelinkHops(h.getInput().get(0), i, k, mmChain, mmOperators, opIndex, split, level + 1);
        opIndex = mmChainRelinkHops(h.getInput().get(1), k + 1, j, mmChain, mmOperators, opIndex, split, level + 1);

        // Propagate properties of input hops to current hop h
        h.refreshSizeInformation();

        if (LOG.isTraceEnabled()) {
            String offset = Explain.getIdentation(level);
            LOG.trace(offset + ")");
        }

        return opIndex;
    }

    /**
     * Removes all input links of the chain operators, and the corresponding
     * parent links from their inputs.
     *
     * @param hop chain root, used for error location reporting
     * @param operators chain operators (index 0 is the chain root)
     * @throws HopsException if an operator does not have exactly two inputs, or a
     *         non-root operator has more than one parent
     */
    private void clearLinksWithinChain(Hop hop, ArrayList<Hop> operators) throws HopsException {
        for (int i = 0; i < operators.size(); i++) {
            Hop op = operators.get(i);
            if (op.getInput().size() != 2 || (i != 0 && op.getParent().size() > 1)) {
                throw new HopsException(hop.printErrorLocation()
                        + "Unexpected error while applying optimization on matrix-mult chain. \n");
            }
            Hop input1 = op.getInput().get(0);
            Hop input2 = op.getInput().get(1);

            op.getInput().clear();
            // removes one occurrence each, so X %*% X (input1 == input2) is fully unlinked
            input1.getParent().remove(op);
            input2.getParent().remove(op);
        }
    }

    /**
     * Obtains all dimension information of the chain and constructs the dims array.
     * If all dimensions are known, it returns true; otherwise the mmchain rewrite
     * should be ended without modifications.
     *
     * @param hop chain root, used for error location reporting
     * @param chain chain operands in left-to-right order
     * @param dimsArray output array of length {@code chain.size() + 1}; operand i
     *        is {@code dimsArray[i] x dimsArray[i+1]}
     * @return true if all dimensions are known (and dimsArray was populated)
     * @throws HopsException if adjacent operands have mismatching dimensions
     */
    private boolean getDimsArray(Hop hop, ArrayList<Hop> chain, double[] dimsArray) throws HopsException {
        boolean dimsKnown = true;

        // Build the array containing dimensions from all matrices in the chain
        // check the dimensions in the matrix chain to ensure all dimensions are known
        for (int i = 0; i < chain.size(); i++) {
            if (chain.get(i).getDim1() <= 0 || chain.get(i).getDim2() <= 0)
                dimsKnown = false;
        }

        if (dimsKnown) { //populate dims array if all dims known
            for (int i = 0; i < chain.size(); i++) {
                if (i == 0) {
                    dimsArray[i] = chain.get(i).getDim1();
                    if (dimsArray[i] <= 0) {
                        throw new HopsException(hop.printErrorLocation()
                                + "Hops::optimizeMMChain() : Invalid Matrix Dimension: " + dimsArray[i]);
                    }
                } else {
                    if (chain.get(i - 1).getDim2() != chain.get(i).getDim1()) {
                        throw new HopsException(
                                hop.printErrorLocation() + "Hops::optimizeMMChain() : Matrix Dimension Mismatch: "
                                        + chain.get(i - 1).getDim2() + " != " + chain.get(i).getDim1());
                    }
                }
                dimsArray[i + 1] = chain.get(i).getDim2();
                if (dimsArray[i + 1] <= 0) {
                    throw new HopsException(hop.printErrorLocation()
                            + "Hops::optimizeMMChain() : Invalid Matrix Dimension: " + dimsArray[i + 1]);
                }
            }
        }

        return dimsKnown;
    }

    /**
     * Counts how often {@code h} appears among the inputs of {@code p}.
     *
     * @param p parent hop
     * @param h candidate input hop
     * @return number of input slots of {@code p} equal to {@code h}
     */
    private int inputCount(Hop p, Hop h) {
        int count = 0;
        for (Hop in : p.getInput())
            if (in.equals(h))
                count++;
        return count;
    }

    /**
     * Logs a one-line description (name, class, id, dimensions) of a hop at TRACE
     * level, indented by {@code level}.
     *
     * @param hop hop to describe
     * @param level indentation level
     */
    private void logTraceHop(Hop hop, int level) {
        if (LOG.isTraceEnabled()) {
            String offset = Explain.getIdentation(level);
            LOG.trace(offset + "Hop " + hop.getName() + "(" + hop.getClass().getSimpleName() + ", " + hop.getHopID()
                    + ")" + " " + hop.getDim1() + "x" + hop.getDim2());
        }
    }
}