// Java tutorial
/** * (C) Copyright IBM Corp. 2010, 2015 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * */ package com.ibm.bi.dml.hops.rewrite; import java.util.ArrayList; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.log4j.Level; import org.apache.log4j.Logger; import com.ibm.bi.dml.hops.Hop; import com.ibm.bi.dml.hops.HopsException; import com.ibm.bi.dml.hops.OptimizerUtils; import com.ibm.bi.dml.parser.DMLProgram; import com.ibm.bi.dml.parser.ForStatement; import com.ibm.bi.dml.parser.ForStatementBlock; import com.ibm.bi.dml.parser.FunctionStatement; import com.ibm.bi.dml.parser.FunctionStatementBlock; import com.ibm.bi.dml.parser.IfStatement; import com.ibm.bi.dml.parser.IfStatementBlock; import com.ibm.bi.dml.parser.LanguageException; import com.ibm.bi.dml.parser.ParForStatementBlock; import com.ibm.bi.dml.parser.StatementBlock; import com.ibm.bi.dml.parser.WhileStatement; import com.ibm.bi.dml.parser.WhileStatementBlock; /** * This program rewriter applies a variety of rule-based rewrites * on all hop dags of the given program in one pass over the entire * program. 
* */ public class ProgramRewriter { private static final Log LOG = LogFactory.getLog(ProgramRewriter.class.getName()); //internal local debug level private static final boolean LDEBUG = false; private static final boolean CHECK = false; private ArrayList<HopRewriteRule> _dagRuleSet = null; private ArrayList<StatementBlockRewriteRule> _sbRuleSet = null; static { // for internal debugging only if (LDEBUG) { Logger.getLogger("com.ibm.bi.dml.hops.rewrite").setLevel((Level) Level.DEBUG); } } public ProgramRewriter() { // by default which is used during initial compile // apply all (static and dynamic) rewrites this(true, true); } public ProgramRewriter(boolean staticRewrites, boolean dynamicRewrites) { //initialize HOP DAG rewrite ruleSet (with fixed rewrite order) _dagRuleSet = new ArrayList<HopRewriteRule>(); //initialize StatementBlock rewrite ruleSet (with fixed rewrite order) _sbRuleSet = new ArrayList<StatementBlockRewriteRule>(); //STATIC REWRITES (which do not rely on size information) if (staticRewrites) { //add static HOP DAG rewrite rules _dagRuleSet.add(new RewriteTransientWriteParentHandling()); _dagRuleSet.add(new RewriteRemoveReadAfterWrite()); //dependency: before blocksize _dagRuleSet.add(new RewriteBlockSizeAndReblock()); _dagRuleSet.add(new RewriteRemoveUnnecessaryCasts()); if (OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION) _dagRuleSet.add(new RewriteCommonSubexpressionElimination()); if (OptimizerUtils.ALLOW_CONSTANT_FOLDING) _dagRuleSet.add(new RewriteConstantFolding()); //dependency: cse if (OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION) _dagRuleSet.add(new RewriteAlgebraicSimplificationStatic()); //dependencies: cse if (OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION) //dependency: simplifications (no need to merge leafs again) _dagRuleSet.add(new RewriteCommonSubexpressionElimination()); if (OptimizerUtils.ALLOW_AUTO_VECTORIZATION) _dagRuleSet.add(new RewriteIndexingVectorization()); //dependency: cse, simplifications 
_dagRuleSet.add(new RewriteInjectSparkPReadCheckpointing()); //dependency: reblock //add statment block rewrite rules if (OptimizerUtils.ALLOW_BRANCH_REMOVAL) _sbRuleSet.add(new RewriteRemoveUnnecessaryBranches()); //dependency: constant folding if (OptimizerUtils.ALLOW_SPLIT_HOP_DAGS) _sbRuleSet.add(new RewriteSplitDagUnknownCSVRead()); //dependency: reblock if (OptimizerUtils.ALLOW_INDIVIDUAL_SB_SPECIFIC_OPS) _sbRuleSet.add(new RewriteSplitDagDataDependentOperators()); if (OptimizerUtils.ALLOW_AUTO_VECTORIZATION) _sbRuleSet.add(new RewriteForLoopVectorization()); //dependency: reblock (reblockop) _sbRuleSet.add(new RewriteInjectSparkLoopCheckpointing(true)); //dependency: reblock (blocksizes) } // DYNAMIC REWRITES (which do require size information) if (dynamicRewrites) { _dagRuleSet.add(new RewriteMatrixMultChainOptimization()); //dependency: cse if (OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION) { _dagRuleSet.add(new RewriteAlgebraicSimplificationDynamic()); //dependencies: cse _dagRuleSet.add(new RewriteAlgebraicSimplificationStatic()); //dependencies: cse } //reapply cse after rewrites because (1) applied rewrites on operators w/ multiple parents, and //(2) newly introduced operators potentially created redundancy (incl leaf merge to allow for cse) if (OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION) _dagRuleSet.add(new RewriteCommonSubexpressionElimination(true)); //dependency: simplifications } } /** * Construct a program rewriter for a given rewrite which is passed from outside. * * @param rewrite */ public ProgramRewriter(HopRewriteRule rewrite) { //initialize HOP DAG rewrite ruleSet (with fixed rewrite order) _dagRuleSet = new ArrayList<HopRewriteRule>(); _dagRuleSet.add(rewrite); _sbRuleSet = new ArrayList<StatementBlockRewriteRule>(); } /** * Construct a program rewriter for a given rewrite which is passed from outside. 
* * @param rewrite */ public ProgramRewriter(StatementBlockRewriteRule rewrite) { //initialize HOP DAG rewrite ruleSet (with fixed rewrite order) _dagRuleSet = new ArrayList<HopRewriteRule>(); _sbRuleSet = new ArrayList<StatementBlockRewriteRule>(); _sbRuleSet.add(rewrite); } /** * Construct a program rewriter for the given rewrite sets which are passed from outside. * * @param rewrite */ public ProgramRewriter(ArrayList<HopRewriteRule> hRewrites, ArrayList<StatementBlockRewriteRule> sbRewrites) { //initialize HOP DAG rewrite ruleSet (with fixed rewrite order) _dagRuleSet = new ArrayList<HopRewriteRule>(); _dagRuleSet.addAll(hRewrites); _sbRuleSet = new ArrayList<StatementBlockRewriteRule>(); _sbRuleSet.addAll(sbRewrites); } /** * * @param dmlp * @return * @throws LanguageException * @throws HopsException */ public ProgramRewriteStatus rewriteProgramHopDAGs(DMLProgram dmlp) throws LanguageException, HopsException { ProgramRewriteStatus state = new ProgramRewriteStatus(); // for each namespace, handle function statement blocks for (String namespaceKey : dmlp.getNamespaces().keySet()) for (String fname : dmlp.getFunctionStatementBlocks(namespaceKey).keySet()) { FunctionStatementBlock fsblock = dmlp.getFunctionStatementBlock(namespaceKey, fname); rewriteStatementBlockHopDAGs(fsblock, state); rewriteStatementBlock(fsblock, state); } // handle regular statement blocks in "main" method for (int i = 0; i < dmlp.getNumStatementBlocks(); i++) { StatementBlock current = dmlp.getStatementBlock(i); rewriteStatementBlockHopDAGs(current, state); } dmlp.setStatementBlocks(rewriteStatementBlocks(dmlp.getStatementBlocks(), state)); return state; } /** * * @param current * @throws LanguageException * @throws HopsException */ public void rewriteStatementBlockHopDAGs(StatementBlock current, ProgramRewriteStatus state) throws LanguageException, HopsException { //ensure robustness for calls from outside if (state == null) state = new ProgramRewriteStatus(); if (current instanceof 
FunctionStatementBlock) { FunctionStatementBlock fsb = (FunctionStatementBlock) current; FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0); for (StatementBlock sb : fstmt.getBody()) rewriteStatementBlockHopDAGs(sb, state); } else if (current instanceof WhileStatementBlock) { WhileStatementBlock wsb = (WhileStatementBlock) current; WhileStatement wstmt = (WhileStatement) wsb.getStatement(0); wsb.setPredicateHops(rewriteHopDAG(wsb.getPredicateHops(), state)); for (StatementBlock sb : wstmt.getBody()) rewriteStatementBlockHopDAGs(sb, state); } else if (current instanceof IfStatementBlock) { IfStatementBlock isb = (IfStatementBlock) current; IfStatement istmt = (IfStatement) isb.getStatement(0); isb.setPredicateHops(rewriteHopDAG(isb.getPredicateHops(), state)); for (StatementBlock sb : istmt.getIfBody()) rewriteStatementBlockHopDAGs(sb, state); for (StatementBlock sb : istmt.getElseBody()) rewriteStatementBlockHopDAGs(sb, state); } else if (current instanceof ForStatementBlock) //incl parfor { ForStatementBlock fsb = (ForStatementBlock) current; ForStatement fstmt = (ForStatement) fsb.getStatement(0); fsb.setFromHops(rewriteHopDAG(fsb.getFromHops(), state)); fsb.setToHops(rewriteHopDAG(fsb.getToHops(), state)); fsb.setIncrementHops(rewriteHopDAG(fsb.getIncrementHops(), state)); for (StatementBlock sb : fstmt.getBody()) rewriteStatementBlockHopDAGs(sb, state); } else //generic (last-level) { current.set_hops(rewriteHopDAGs(current.get_hops(), state)); } } /** * * @param roots * @throws LanguageException * @throws HopsException */ public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) throws HopsException { for (HopRewriteRule r : _dagRuleSet) { Hop.resetVisitStatus(roots); //reset for each rule roots = r.rewriteHopDAGs(roots, state); if (CHECK) { LOG.info("Validation after: " + r.getClass().getName()); HopDagValidator.validateHopDag(roots); } } return roots; } /** * * @param root * @throws LanguageException * @throws 
HopsException */ public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) throws HopsException { for (HopRewriteRule r : _dagRuleSet) { root.resetVisitStatus(); //reset for each rule root = r.rewriteHopDAG(root, state); if (CHECK) { LOG.info("Validation after: " + r.getClass().getName()); HopDagValidator.validateHopDag(root); } } return root; } /** * * @param sbs * @return * @throws HopsException */ public ArrayList<StatementBlock> rewriteStatementBlocks(ArrayList<StatementBlock> sbs, ProgramRewriteStatus state) throws HopsException { //ensure robustness for calls from outside if (state == null) state = new ProgramRewriteStatus(); ArrayList<StatementBlock> tmp = new ArrayList<StatementBlock>(); //rewrite statement blocks (with potential expansion) for (StatementBlock sb : sbs) tmp.addAll(rewriteStatementBlock(sb, state)); //copy results into original collection sbs.clear(); sbs.addAll(tmp); return sbs; } /** * * @param sb * @return * @throws HopsException */ private ArrayList<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus status) throws HopsException { ArrayList<StatementBlock> ret = new ArrayList<StatementBlock>(); ret.add(sb); //recursive invocation if (sb instanceof FunctionStatementBlock) { FunctionStatementBlock fsb = (FunctionStatementBlock) sb; FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0); fstmt.setBody(rewriteStatementBlocks(fstmt.getBody(), status)); } else if (sb instanceof WhileStatementBlock) { WhileStatementBlock wsb = (WhileStatementBlock) sb; WhileStatement wstmt = (WhileStatement) wsb.getStatement(0); wstmt.setBody(rewriteStatementBlocks(wstmt.getBody(), status)); } else if (sb instanceof IfStatementBlock) { IfStatementBlock isb = (IfStatementBlock) sb; IfStatement istmt = (IfStatement) isb.getStatement(0); istmt.setIfBody(rewriteStatementBlocks(istmt.getIfBody(), status)); istmt.setElseBody(rewriteStatementBlocks(istmt.getElseBody(), status)); } else if (sb instanceof ForStatementBlock) 
//incl parfor { //maintain parfor context information (e.g., for checkpointing) boolean prestatus = status.isInParforContext(); if (sb instanceof ParForStatementBlock) status.setInParforContext(true); ForStatementBlock fsb = (ForStatementBlock) sb; ForStatement fstmt = (ForStatement) fsb.getStatement(0); fstmt.setBody(rewriteStatementBlocks(fstmt.getBody(), status)); status.setInParforContext(prestatus); } //apply rewrite rules for (StatementBlockRewriteRule r : _sbRuleSet) { ArrayList<StatementBlock> tmp = new ArrayList<StatementBlock>(); for (StatementBlock sbc : ret) tmp.addAll(r.rewriteStatementBlock(sbc, status)); //take over set of rewritten sbs ret.clear(); ret.addAll(tmp); } return ret; } }