com.ibm.bi.dml.runtime.controlprogram.parfor.RemoteParForSpark.java Source code

Java tutorial

Introduction

Here is the source code for com.ibm.bi.dml.runtime.controlprogram.parfor.RemoteParForSpark.java

Source

/**
 * (C) Copyright IBM Corp. 2010, 2015
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * 
*/

package com.ibm.bi.dml.runtime.controlprogram.parfor;

import java.util.List;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.spark.Accumulator;
import org.apache.spark.api.java.JavaSparkContext;

import scala.Tuple2;

import com.ibm.bi.dml.api.DMLScript;
import com.ibm.bi.dml.runtime.DMLRuntimeException;
import com.ibm.bi.dml.runtime.DMLUnsupportedOperationException;
import com.ibm.bi.dml.runtime.controlprogram.LocalVariableMap;
import com.ibm.bi.dml.runtime.controlprogram.context.ExecutionContext;
import com.ibm.bi.dml.runtime.controlprogram.context.SparkExecutionContext;
import com.ibm.bi.dml.utils.Statistics;

/**
 * This class serves two purposes: (1) isolating Spark imports to enable running in 
 * environments where no Spark libraries are available, and (2) to follow the same
 * structure as the parfor remote_mr job submission.
 * 
 * NOTE: currently, we still exchange inputs and outputs via hdfs (this covers the general case
 * if data already resides in HDFS, in-memory data, and partitioned inputs; also, it allows for
 * pre-aggregation by overwriting partial task results with pre-aggregated results from subsequent
 * iterations)
 * 
 * TODO broadcast variables if possible
 * TODO reducebykey on variable names
 */
public class RemoteParForSpark {

    protected static final Log LOG = LogFactory.getLog(RemoteParForSpark.class.getName());

    /**
     * Submits the given parfor tasks as a single Spark job (eagerly evaluated via
     * {@code collect()}) and de-serializes the collected result handles into
     * per-worker local variable maps.
     *
     * @param pfid parfor program id; currently unused in this Spark backend —
     *        kept for signature parity with the remote_mr job submission
     * @param program serialized program to be executed by the remote workers
     * @param tasks list of parfor tasks, i.e., the units of work distributed over the cluster
     * @param ec execution context; must be a {@link SparkExecutionContext}
     * @param cpCaching indicates if CP caching should be enabled in the remote workers
     * @param numMappers desired degree of parallelism (number of RDD partitions)
     * @return job return object carrying success flag, executed task/iteration counts,
     *         and the de-serialized result variable maps
     * @throws DMLRuntimeException if {@code ec} is not a {@link SparkExecutionContext}
     *         or result de-serialization fails
     * @throws DMLUnsupportedOperationException if an unsupported operation is encountered
     */
    public static RemoteParForJobReturn runJob(long pfid, String program, List<Task> tasks, ExecutionContext ec,
            boolean cpCaching, int numMappers) throws DMLRuntimeException, DMLUnsupportedOperationException {
        String jobname = "ParFor-ESP";
        long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;

        //fail fast with a descriptive error instead of an unchecked ClassCastException
        if (!(ec instanceof SparkExecutionContext)) {
            throw new DMLRuntimeException("RemoteParForSpark requires a SparkExecutionContext, but found: "
                    + (ec == null ? "null" : ec.getClass().getName()));
        }
        SparkExecutionContext sec = (SparkExecutionContext) ec;
        JavaSparkContext sc = sec.getSparkContext();

        //initialize accumulators for tasks/iterations
        Accumulator<Integer> aTasks = sc.accumulator(0);
        Accumulator<Integer> aIters = sc.accumulator(0);

        //run remote_spark parfor job 
        //(w/o lazy evaluation to fit existing parfor framework, e.g., result merge)
        RemoteParForSparkWorker func = new RemoteParForSparkWorker(program, cpCaching, aTasks, aIters);
        List<Tuple2<Long, String>> out = sc.parallelize(tasks, numMappers) //create rdd of parfor tasks
                .flatMapToPair(func) //execute parfor tasks 
                .collect(); //get output handles

        //de-serialize results (accumulator reads are safe here: collect() already forced execution)
        LocalVariableMap[] results = RemoteParForUtils.getResults(out, LOG);
        int numTasks = aTasks.value(); //get accumulator value
        int numIters = aIters.value(); //get accumulator value

        //create output symbol table entries
        RemoteParForJobReturn ret = new RemoteParForJobReturn(true, numTasks, numIters, results);

        //maintain statistics (heavy hitters only tracked when statistics are enabled)
        Statistics.incrementNoOfCompiledSPInst();
        Statistics.incrementNoOfExecutedSPInst();
        if (DMLScript.STATISTICS) {
            Statistics.maintainCPHeavyHitters(jobname, System.nanoTime() - t0);
        }

        return ret;
    }
}