org.apache.sysml.runtime.controlprogram.parfor.opt.CostEstimator.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.sysml.runtime.controlprogram.parfor.opt.CostEstimator.java

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.sysml.runtime.controlprogram.parfor.opt;

import java.util.ArrayList;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import org.apache.sysml.lops.LopProperties.ExecType;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.parfor.opt.OptNode.ParamType;
import org.apache.sysml.runtime.controlprogram.parfor.opt.PerfTestTool.TestMeasure;

/**
 * Base class for all potential cost estimators.
 * 
 * Concrete subclasses supply estimates for leaf nodes of an {@link OptNode} plan tree
 * (see the abstract getLeafNodeEstimate methods). This class aggregates those leaf
 * estimates over inner nodes (generic, function call, if, while, for, parfor) for the
 * supported measures EXEC_TIME and MEMORY_USAGE.
 * 
 * TODO account for shared read-only matrices when computing aggregated stats
 * 
 */
public abstract class CostEstimator {

    protected static final Log LOG = LogFactory.getLog(CostEstimator.class.getName());

    //default parameters
    public static final double DEFAULT_EST_PARALLELISM = 1.0; //default degree of parallelism: serial
    public static final long FACTOR_NUM_ITERATIONS = 10; //default problem size: assumed iteration count for while loops and for/parfor loops without a known NUM_ITERATIONS
    public static final double DEFAULT_TIME_ESTIMATE = 5; //default execution time: 5ms
    public static final double DEFAULT_MEM_ESTIMATE_CP = 1024; //default memory consumption: 1KB 
    public static final double DEFAULT_MEM_ESTIMATE_MR = 10 * 1024 * 1024; //default memory consumption: 10MB // TODO investigate unused constant

    /**
     * Main leaf node estimation method - to be overwritten by specific cost estimators.
     * 
     * @param measure the quantity to estimate (e.g. EXEC_TIME or MEMORY_USAGE)
     * @param node internal representation of a plan alternative for program blocks and instructions
     * @return the estimated value of {@code measure} for the leaf node
     *         (presumably ms for EXEC_TIME and bytes for MEMORY_USAGE, matching the
     *         default constants of this class - confirm in concrete estimators)
     * @throws DMLRuntimeException if DMLRuntimeException occurs
     */
    public abstract double getLeafNodeEstimate(TestMeasure measure, OptNode node) throws DMLRuntimeException;

    /**
     * Main leaf node estimation method - to be overwritten by specific cost estimators.
     * 
     * @param measure the quantity to estimate (e.g. EXEC_TIME or MEMORY_USAGE)
     * @param node internal representation of a plan alternative for program blocks and instructions
     * @param et forced execution type for leaf node 
     * @return the estimated value of {@code measure} for the leaf node under the forced execution type
     * @throws DMLRuntimeException if DMLRuntimeException occurs
     */
    public abstract double getLeafNodeEstimate(TestMeasure measure, OptNode node, ExecType et)
            throws DMLRuntimeException;

    /////////
    //methods invariant to concrete estimator
    ///

    /**
     * Main estimation method, using each leaf node's default execution type.
     * 
     * @param measure the quantity to estimate (e.g. EXEC_TIME or MEMORY_USAGE)
     * @param node internal representation of a plan alternative for program blocks and instructions
     * @return the aggregated estimate for the (sub)tree rooted at {@code node};
     *         -1 if the node type / measure combination is not handled
     * @throws DMLRuntimeException if DMLRuntimeException occurs
     */
    public double getEstimate(TestMeasure measure, OptNode node) throws DMLRuntimeException {
        return getEstimate(measure, node, null);
    }

    /**
     * Main estimation method. Leaf nodes are delegated to the concrete estimator;
     * inner nodes aggregate their children's estimates depending on node type and measure.
     * 
     * @param measure the quantity to estimate (e.g. EXEC_TIME or MEMORY_USAGE)
     * @param node internal representation of a plan alternative for program blocks and instructions
     * @param et execution type forced onto all leaf nodes, or {@code null} for each leaf's default
     * @return the aggregated estimate for the (sub)tree rooted at {@code node};
     *         -1 if the node type / measure combination is not handled
     * @throws DMLRuntimeException if DMLRuntimeException occurs
     */
    public double getEstimate(TestMeasure measure, OptNode node, ExecType et) throws DMLRuntimeException {
        double val = -1; //sentinel for unhandled node type / measure combinations

        if (node.isLeaf()) {
            if (et != null)
                val = getLeafNodeEstimate(measure, node, et); //forced type
            else
                val = getLeafNodeEstimate(measure, node); //default   
        } else {
            //aggregation methods for different program block types and measure types
            //TODO EXEC TIME requires reconsideration of for/parfor/if predicates 
            //TODO MEMORY requires reconsideration of parfor -> potential overestimation, but safe
            switch (measure) {
            case EXEC_TIME:
                switch (node.getNodeType()) {
                case GENERIC:
                case FUNCCALL:
                    val = getSumEstimate(measure, node.getChilds(), et);
                    break;
                case IF:
                    //two children: average of both branches; otherwise pessimistic max
                    if (node.getChilds().size() == 2)
                        val = getWeightedEstimate(measure, node.getChilds(), et);
                    else
                        val = getMaxEstimate(measure, node.getChilds(), et);
                    break;
                case WHILE:
                    //iteration count unknown: assume the default problem size
                    val = FACTOR_NUM_ITERATIONS * getSumEstimate(measure, node.getChilds(), et);
                    break;
                case FOR:
                    val = getNumIterations(node) * getSumEstimate(measure, node.getChilds(), et);
                    break;
                case PARFOR:
                    //iterations spread over node.getK() workers (presumably the parfor degree of parallelism)
                    val = getNumIterations(node) * getSumEstimate(measure, node.getChilds(), et) / node.getK();
                    break;
                default:
                    //do nothing
                    break;
                }
                break;

            case MEMORY_USAGE:
                switch (node.getNodeType()) {
                case GENERIC:
                case FUNCCALL:
                case IF:
                case WHILE:
                case FOR:
                    val = getMaxEstimate(measure, node.getChilds(), et);
                    break;
                case PARFOR:
                    if (node.getExecType() == OptNode.ExecType.MR)
                        val = getMaxEstimate(measure, node.getChilds(), et); //executed in different JVMs
                    else if (node.getExecType() == OptNode.ExecType.CP)
                        val = getMaxEstimate(measure, node.getChilds(), et) * node.getK(); //everything executed within 1 JVM
                    break;
                default:
                    //do nothing
                    break;
                }
                break;
            }
        }

        return val;
    }

    /**
     * Returns the number of iterations recorded on a for/parfor node, falling back to
     * {@link #FACTOR_NUM_ITERATIONS} if the node carries no NUM_ITERATIONS parameter.
     */
    private static double getNumIterations(OptNode node) {
        String tmp = node.getParam(ParamType.NUM_ITERATIONS);
        return (tmp != null) ? (double) Long.parseLong(tmp) : FACTOR_NUM_ITERATIONS;
    }

    /**
     * Returns the default estimate for the given measure.
     * 
     * @param measure the quantity to estimate
     * @return {@link #DEFAULT_TIME_ESTIMATE} for EXEC_TIME, {@link #DEFAULT_MEM_ESTIMATE_CP}
     *         for MEMORY_USAGE, and -1 for any other measure
     */
    protected double getDefaultEstimate(TestMeasure measure) {
        double val = -1;

        switch (measure) {
        case EXEC_TIME:
            val = DEFAULT_TIME_ESTIMATE;
            break;
        case MEMORY_USAGE:
            val = DEFAULT_MEM_ESTIMATE_CP;
            break;
        }

        return val;
    }

    /**
     * Returns the maximum estimate over the given nodes.
     * 
     * Note: the result is floored at {@link Double#MIN_VALUE} (the smallest positive double),
     * so an empty list or a list of only non-positive estimates (e.g. -1 for unhandled nodes)
     * yields {@code Double.MIN_VALUE} rather than a negative value.
     */
    protected double getMaxEstimate(TestMeasure measure, ArrayList<OptNode> nodes, ExecType et)
            throws DMLRuntimeException {
        double max = Double.MIN_VALUE; //smallest positive value
        for (OptNode n : nodes) {
            double tmp = getEstimate(measure, n, et);
            if (tmp > max)
                max = tmp;
        }
        return max;
    }

    /**
     * Returns the sum of the estimates of the given nodes (0 for an empty list).
     */
    protected double getSumEstimate(TestMeasure measure, ArrayList<OptNode> nodes, ExecType et)
            throws DMLRuntimeException {
        double sum = 0;
        for (OptNode n : nodes)
            sum += getEstimate(measure, n, et);
        return sum;
    }

    /**
     * Returns the arithmetic mean of the estimates of the given nodes, i.e. every node
     * is weighted equally. Returns 0 for an empty list (consistent with getSumEstimate)
     * instead of the NaN that 0.0/0 would produce.
     */
    protected double getWeightedEstimate(TestMeasure measure, ArrayList<OptNode> nodes, ExecType et)
            throws DMLRuntimeException {
        int len = nodes.size();
        if (len == 0)
            return 0; //avoid 0/0 = NaN propagating into aggregated estimates
        double ret = 0;
        for (OptNode n : nodes)
            ret += getEstimate(measure, n, et);
        ret /= len; //weighting
        return ret;
    }
}