com.ibm.bi.dml.hops.rewrite.ProgramRewriter.java Source code

Java tutorial

Introduction

Here is the source code for com.ibm.bi.dml.hops.rewrite.ProgramRewriter.java

Source

/**
 * (C) Copyright IBM Corp. 2010, 2015
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * 
*/

package com.ibm.bi.dml.hops.rewrite;

import java.util.ArrayList;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;

import com.ibm.bi.dml.hops.Hop;
import com.ibm.bi.dml.hops.HopsException;
import com.ibm.bi.dml.hops.OptimizerUtils;
import com.ibm.bi.dml.parser.DMLProgram;
import com.ibm.bi.dml.parser.ForStatement;
import com.ibm.bi.dml.parser.ForStatementBlock;
import com.ibm.bi.dml.parser.FunctionStatement;
import com.ibm.bi.dml.parser.FunctionStatementBlock;
import com.ibm.bi.dml.parser.IfStatement;
import com.ibm.bi.dml.parser.IfStatementBlock;
import com.ibm.bi.dml.parser.LanguageException;
import com.ibm.bi.dml.parser.ParForStatementBlock;
import com.ibm.bi.dml.parser.StatementBlock;
import com.ibm.bi.dml.parser.WhileStatement;
import com.ibm.bi.dml.parser.WhileStatementBlock;

/**
 * This program rewriter applies a variety of rule-based rewrites
 * on all hop dags of the given program in one pass over the entire
 * program. 
 * 
 */
public class ProgramRewriter {
    private static final Log LOG = LogFactory.getLog(ProgramRewriter.class.getName());

    //internal local debug level
    private static final boolean LDEBUG = false;
    private static final boolean CHECK = false;

    private ArrayList<HopRewriteRule> _dagRuleSet = null;
    private ArrayList<StatementBlockRewriteRule> _sbRuleSet = null;

    static {
        // for internal debugging only
        if (LDEBUG) {
            Logger.getLogger("com.ibm.bi.dml.hops.rewrite").setLevel((Level) Level.DEBUG);
        }

    }

    public ProgramRewriter() {
        // by default which is used during initial compile 
        // apply all (static and dynamic) rewrites
        this(true, true);
    }

    public ProgramRewriter(boolean staticRewrites, boolean dynamicRewrites) {
        //initialize HOP DAG rewrite ruleSet (with fixed rewrite order)
        _dagRuleSet = new ArrayList<HopRewriteRule>();

        //initialize StatementBlock rewrite ruleSet (with fixed rewrite order)
        _sbRuleSet = new ArrayList<StatementBlockRewriteRule>();

        //STATIC REWRITES (which do not rely on size information)
        if (staticRewrites) {
            //add static HOP DAG rewrite rules
            _dagRuleSet.add(new RewriteTransientWriteParentHandling());
            _dagRuleSet.add(new RewriteRemoveReadAfterWrite()); //dependency: before blocksize
            _dagRuleSet.add(new RewriteBlockSizeAndReblock());
            _dagRuleSet.add(new RewriteRemoveUnnecessaryCasts());
            if (OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION)
                _dagRuleSet.add(new RewriteCommonSubexpressionElimination());
            if (OptimizerUtils.ALLOW_CONSTANT_FOLDING)
                _dagRuleSet.add(new RewriteConstantFolding()); //dependency: cse
            if (OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION)
                _dagRuleSet.add(new RewriteAlgebraicSimplificationStatic()); //dependencies: cse
            if (OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION) //dependency: simplifications (no need to merge leafs again)
                _dagRuleSet.add(new RewriteCommonSubexpressionElimination());
            if (OptimizerUtils.ALLOW_AUTO_VECTORIZATION)
                _dagRuleSet.add(new RewriteIndexingVectorization()); //dependency: cse, simplifications
            _dagRuleSet.add(new RewriteInjectSparkPReadCheckpointing()); //dependency: reblock

            //add statment block rewrite rules
            if (OptimizerUtils.ALLOW_BRANCH_REMOVAL)
                _sbRuleSet.add(new RewriteRemoveUnnecessaryBranches()); //dependency: constant folding      
            if (OptimizerUtils.ALLOW_SPLIT_HOP_DAGS)
                _sbRuleSet.add(new RewriteSplitDagUnknownCSVRead()); //dependency: reblock   
            if (OptimizerUtils.ALLOW_INDIVIDUAL_SB_SPECIFIC_OPS)
                _sbRuleSet.add(new RewriteSplitDagDataDependentOperators());
            if (OptimizerUtils.ALLOW_AUTO_VECTORIZATION)
                _sbRuleSet.add(new RewriteForLoopVectorization()); //dependency: reblock (reblockop)
            _sbRuleSet.add(new RewriteInjectSparkLoopCheckpointing(true)); //dependency: reblock (blocksizes)
        }

        // DYNAMIC REWRITES (which do require size information)
        if (dynamicRewrites) {
            _dagRuleSet.add(new RewriteMatrixMultChainOptimization()); //dependency: cse 

            if (OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION) {
                _dagRuleSet.add(new RewriteAlgebraicSimplificationDynamic()); //dependencies: cse
                _dagRuleSet.add(new RewriteAlgebraicSimplificationStatic()); //dependencies: cse
            }

            //reapply cse after rewrites because (1) applied rewrites on operators w/ multiple parents, and
            //(2) newly introduced operators potentially created redundancy (incl leaf merge to allow for cse)
            if (OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION)
                _dagRuleSet.add(new RewriteCommonSubexpressionElimination(true)); //dependency: simplifications          
        }
    }

    /**
     * Construct a program rewriter for a given rewrite which is passed from outside.
     * 
     * @param rewrite
     */
    public ProgramRewriter(HopRewriteRule rewrite) {
        //initialize HOP DAG rewrite ruleSet (with fixed rewrite order)
        _dagRuleSet = new ArrayList<HopRewriteRule>();
        _dagRuleSet.add(rewrite);

        _sbRuleSet = new ArrayList<StatementBlockRewriteRule>();
    }

    /**
     * Construct a program rewriter for a given rewrite which is passed from outside.
     * 
     * @param rewrite
     */
    public ProgramRewriter(StatementBlockRewriteRule rewrite) {
        //initialize HOP DAG rewrite ruleSet (with fixed rewrite order)
        _dagRuleSet = new ArrayList<HopRewriteRule>();

        _sbRuleSet = new ArrayList<StatementBlockRewriteRule>();
        _sbRuleSet.add(rewrite);
    }

    /**
     * Construct a program rewriter for the given rewrite sets which are passed from outside.
     * 
     * @param rewrite
     */
    public ProgramRewriter(ArrayList<HopRewriteRule> hRewrites, ArrayList<StatementBlockRewriteRule> sbRewrites) {
        //initialize HOP DAG rewrite ruleSet (with fixed rewrite order)
        _dagRuleSet = new ArrayList<HopRewriteRule>();
        _dagRuleSet.addAll(hRewrites);

        _sbRuleSet = new ArrayList<StatementBlockRewriteRule>();
        _sbRuleSet.addAll(sbRewrites);
    }

    /**
     * 
     * @param dmlp
     * @return
     * @throws LanguageException
     * @throws HopsException
     */
    public ProgramRewriteStatus rewriteProgramHopDAGs(DMLProgram dmlp) throws LanguageException, HopsException {
        ProgramRewriteStatus state = new ProgramRewriteStatus();

        // for each namespace, handle function statement blocks
        for (String namespaceKey : dmlp.getNamespaces().keySet())
            for (String fname : dmlp.getFunctionStatementBlocks(namespaceKey).keySet()) {
                FunctionStatementBlock fsblock = dmlp.getFunctionStatementBlock(namespaceKey, fname);
                rewriteStatementBlockHopDAGs(fsblock, state);
                rewriteStatementBlock(fsblock, state);
            }

        // handle regular statement blocks in "main" method
        for (int i = 0; i < dmlp.getNumStatementBlocks(); i++) {
            StatementBlock current = dmlp.getStatementBlock(i);
            rewriteStatementBlockHopDAGs(current, state);
        }
        dmlp.setStatementBlocks(rewriteStatementBlocks(dmlp.getStatementBlocks(), state));

        return state;
    }

    /**
     * 
     * @param current
     * @throws LanguageException
     * @throws HopsException
     */
    public void rewriteStatementBlockHopDAGs(StatementBlock current, ProgramRewriteStatus state)
            throws LanguageException, HopsException {
        //ensure robustness for calls from outside
        if (state == null)
            state = new ProgramRewriteStatus();

        if (current instanceof FunctionStatementBlock) {
            FunctionStatementBlock fsb = (FunctionStatementBlock) current;
            FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0);
            for (StatementBlock sb : fstmt.getBody())
                rewriteStatementBlockHopDAGs(sb, state);
        } else if (current instanceof WhileStatementBlock) {
            WhileStatementBlock wsb = (WhileStatementBlock) current;
            WhileStatement wstmt = (WhileStatement) wsb.getStatement(0);
            wsb.setPredicateHops(rewriteHopDAG(wsb.getPredicateHops(), state));
            for (StatementBlock sb : wstmt.getBody())
                rewriteStatementBlockHopDAGs(sb, state);
        } else if (current instanceof IfStatementBlock) {
            IfStatementBlock isb = (IfStatementBlock) current;
            IfStatement istmt = (IfStatement) isb.getStatement(0);
            isb.setPredicateHops(rewriteHopDAG(isb.getPredicateHops(), state));
            for (StatementBlock sb : istmt.getIfBody())
                rewriteStatementBlockHopDAGs(sb, state);
            for (StatementBlock sb : istmt.getElseBody())
                rewriteStatementBlockHopDAGs(sb, state);
        } else if (current instanceof ForStatementBlock) //incl parfor
        {
            ForStatementBlock fsb = (ForStatementBlock) current;
            ForStatement fstmt = (ForStatement) fsb.getStatement(0);
            fsb.setFromHops(rewriteHopDAG(fsb.getFromHops(), state));
            fsb.setToHops(rewriteHopDAG(fsb.getToHops(), state));
            fsb.setIncrementHops(rewriteHopDAG(fsb.getIncrementHops(), state));
            for (StatementBlock sb : fstmt.getBody())
                rewriteStatementBlockHopDAGs(sb, state);
        } else //generic (last-level)
        {
            current.set_hops(rewriteHopDAGs(current.get_hops(), state));
        }
    }

    /**
     * 
     * @param roots
     * @throws LanguageException
     * @throws HopsException
     */
    public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) throws HopsException {
        for (HopRewriteRule r : _dagRuleSet) {
            Hop.resetVisitStatus(roots); //reset for each rule
            roots = r.rewriteHopDAGs(roots, state);

            if (CHECK) {
                LOG.info("Validation after: " + r.getClass().getName());
                HopDagValidator.validateHopDag(roots);
            }
        }

        return roots;
    }

    /**
     * 
     * @param root
     * @throws LanguageException
     * @throws HopsException
     */
    public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) throws HopsException {
        for (HopRewriteRule r : _dagRuleSet) {
            root.resetVisitStatus(); //reset for each rule
            root = r.rewriteHopDAG(root, state);

            if (CHECK) {
                LOG.info("Validation after: " + r.getClass().getName());
                HopDagValidator.validateHopDag(root);
            }
        }

        return root;
    }

    /**
     * 
     * @param sbs
     * @return
     * @throws HopsException 
     */
    public ArrayList<StatementBlock> rewriteStatementBlocks(ArrayList<StatementBlock> sbs,
            ProgramRewriteStatus state) throws HopsException {
        //ensure robustness for calls from outside
        if (state == null)
            state = new ProgramRewriteStatus();

        ArrayList<StatementBlock> tmp = new ArrayList<StatementBlock>();

        //rewrite statement blocks (with potential expansion)
        for (StatementBlock sb : sbs)
            tmp.addAll(rewriteStatementBlock(sb, state));

        //copy results into original collection
        sbs.clear();
        sbs.addAll(tmp);

        return sbs;
    }

    /**
     * 
     * @param sb
     * @return
     * @throws HopsException
     */
    private ArrayList<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus status)
            throws HopsException {
        ArrayList<StatementBlock> ret = new ArrayList<StatementBlock>();
        ret.add(sb);

        //recursive invocation
        if (sb instanceof FunctionStatementBlock) {
            FunctionStatementBlock fsb = (FunctionStatementBlock) sb;
            FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0);
            fstmt.setBody(rewriteStatementBlocks(fstmt.getBody(), status));
        } else if (sb instanceof WhileStatementBlock) {
            WhileStatementBlock wsb = (WhileStatementBlock) sb;
            WhileStatement wstmt = (WhileStatement) wsb.getStatement(0);
            wstmt.setBody(rewriteStatementBlocks(wstmt.getBody(), status));
        } else if (sb instanceof IfStatementBlock) {
            IfStatementBlock isb = (IfStatementBlock) sb;
            IfStatement istmt = (IfStatement) isb.getStatement(0);
            istmt.setIfBody(rewriteStatementBlocks(istmt.getIfBody(), status));
            istmt.setElseBody(rewriteStatementBlocks(istmt.getElseBody(), status));
        } else if (sb instanceof ForStatementBlock) //incl parfor
        {
            //maintain parfor context information (e.g., for checkpointing)
            boolean prestatus = status.isInParforContext();
            if (sb instanceof ParForStatementBlock)
                status.setInParforContext(true);

            ForStatementBlock fsb = (ForStatementBlock) sb;
            ForStatement fstmt = (ForStatement) fsb.getStatement(0);
            fstmt.setBody(rewriteStatementBlocks(fstmt.getBody(), status));

            status.setInParforContext(prestatus);
        }

        //apply rewrite rules
        for (StatementBlockRewriteRule r : _sbRuleSet) {
            ArrayList<StatementBlock> tmp = new ArrayList<StatementBlock>();
            for (StatementBlock sbc : ret)
                tmp.addAll(r.rewriteStatementBlock(sbc, status));

            //take over set of rewritten sbs      
            ret.clear();
            ret.addAll(tmp);
        }

        return ret;
    }
}