/*
 * Copyright © 2016 Cask Data, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License. You may obtain a copy of
 * the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */


import co.cask.cdap.common.conf.CConfiguration;
import co.cask.cdap.common.internal.guava.ClassPath;
import co.cask.cdap.common.lang.ClassLoaders;
import co.cask.cdap.common.lang.ClassPathResources;
import co.cask.cdap.common.lang.FilterClassLoader;
import co.cask.cdap.common.lang.ProgramClassLoader;
import co.cask.cdap.common.lang.WeakReferenceDelegatorClassLoader;
import co.cask.cdap.common.lang.jar.BundleJarUtil;
import org.apache.hadoop.conf.Configuration;
import org.apache.spark.SparkConf;
import org.apache.spark.streaming.DStreamGraph;
import org.apache.spark.streaming.StreamingContext;
import org.apache.twill.common.Cancellable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Tuple2;
import scala.collection.parallel.TaskSupport;
import scala.collection.parallel.ThreadPoolTaskSupport;
import scala.collection.parallel.mutable.ParArray;

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.Collections;
import java.util.Properties;
import java.util.Set;
import java.util.concurrent.SynchronousQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.zip.ZipEntry;
import java.util.zip.ZipOutputStream;

/**
 * Utility class with common functions needed by the Spark runtime implementation.
 */
public final class SparkRuntimeUtils {

    // Class-wide logger, named after this utility class.
    private static final Logger LOG = LoggerFactory.getLogger(SparkRuntimeUtils.class.getName());

    // ClassLoader filter
    public static final FilterClassLoader.Filter SPARK_PROGRAM_CLASS_LOADER_FILTER = new FilterClassLoader.Filter() {

        final FilterClassLoader.Filter defaultFilter = FilterClassLoader.defaultFilter();
        volatile Set<ClassPath.ResourceInfo> sparkStreamingResources;

        public boolean acceptResource(final String resource) {
            // All Spark API, Spark, Scala and Akka classes should come from parent.
            if (resource.startsWith("co/cask/cdap/api/spark/")) {
                return true;
            if (resource.startsWith("scala/")) {
                return true;
            if (resource.startsWith("akka/")) {
                return true;
            if (resource.startsWith("org/apache/spark/")) {
                // Only allows the core Spark Streaming classes, but not any streaming extensions (like Kafka).
                if (resource.startsWith("org/apache/spark/streaming")) {
                    return Iterables.any(getSparkStreamingResources(), new Predicate<ClassPath.ResourceInfo>() {
                        public boolean apply(ClassPath.ResourceInfo input) {
                            return input.getResourceName().equals(resource);
                return true;
            return defaultFilter.acceptResource(resource);

        public boolean acceptPackage(final String packageName) {
            if (packageName.equals("co.cask.cdap.api.spark") || packageName.startsWith("co.cask.cdap.api.spark.")) {
                return true;
            if (packageName.equals("scala") || packageName.startsWith("scala.")) {
                return true;
            if (packageName.equals("akka") || packageName.startsWith("akka.")) {
                return true;
            if (packageName.equals("org.apache.spark") || packageName.startsWith("org.apache.spark.")) {
                // Only allows the core Spark Streaming classes, but not any streaming extensions (like Kafka).
                if (packageName.equals("org.apache.spark.streaming")
                        || packageName.startsWith("org.apache.spark.streaming.")) {
                    return Iterables.any(Iterables.filter(getSparkStreamingResources(), ClassPath.ClassInfo.class),
                            new Predicate<ClassPath.ClassInfo>() {
                                public boolean apply(ClassPath.ClassInfo input) {
                                    return input.getPackageName().equals(packageName);
                return true;
            return defaultFilter.acceptResource(packageName);

         * Gets the set of resources information that are from the Spark Streaming Core. It excludes any
         * Spark streaming extensions, such as Kafka or Flume. They need to be excluded since they are not
         * part of Spark distribution and it should be loaded from the user program ClassLoader. This filtering
         * is needed for unit-testing because in unit-test, those extension classes are loadable from the system
         * classloader, causing same classes being loaded through different classloader.
        private Set<ClassPath.ResourceInfo> getSparkStreamingResources() {
            if (sparkStreamingResources != null) {
                return sparkStreamingResources;
            synchronized (this) {
                if (sparkStreamingResources != null) {
                    return sparkStreamingResources;

                try {
                    sparkStreamingResources = ClassPathResources.getClassPathResources(getClass().getClassLoader(),
                } catch (IOException e) {
                    LOG.warn("Failed to find resources for Spark StreamingContext.", e);
                    sparkStreamingResources = Collections.emptySet();
                return sparkStreamingResources;

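    /**
     * Creates a {@link ProgramClassLoader} that has Spark classes visible.
     *
     * @param cConf the CDAP configuration passed to the {@link ProgramClassLoader}
     * @param dir the directory passed to the {@link ProgramClassLoader}
     * @param unfilteredClassLoader the ClassLoader that, after filtering with
     *                              {@link #SPARK_PROGRAM_CLASS_LOADER_FILTER}, becomes the parent
     * @return a new {@link ProgramClassLoader}
     */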
    public static ProgramClassLoader createProgramClassLoader(CConfiguration cConf, File dir,
            ClassLoader unfilteredClassLoader) {
        ClassLoader parent = new FilterClassLoader(unfilteredClassLoader, SPARK_PROGRAM_CLASS_LOADER_FILTER);
        return new ProgramClassLoader(cConf, dir, parent);

    /**
     * Creates a zip file which contains a serialized {@link Properties} under the given zip entry name, together with
     * all files under the given directory. This is called from Client.createConfArchive() as a workaround for the
     * SPARK-13441 bug.
     *
     * @param sparkConf the {@link SparkConf} to save
     * @param propertiesEntryName name of the zip entry for the properties
     * @param confDirPath directory to scan for files to include in the zip file
     * @param outputZipPath output file
     * @return the zip file
     * @throws RuntimeException if the archive cannot be written; the underlying IOException is kept as the cause
     */
    public static File createConfArchive(SparkConf sparkConf, final String propertiesEntryName, String confDirPath,
            String outputZipPath) {
        final Properties properties = new Properties();
        for (Tuple2<String, String> tuple : sparkConf.getAll()) {
            properties.put(tuple._1(), tuple._2());

        try {
            File confDir = new File(confDirPath);
            final File zipFile = new File(outputZipPath);
            BundleJarUtil.createArchive(confDir, new OutputSupplier<ZipOutputStream>() {
                public ZipOutputStream getOutput() throws IOException {
                    ZipOutputStream zipOutput = new ZipOutputStream(new FileOutputStream(zipFile));
                    zipOutput.putNextEntry(new ZipEntry(propertiesEntryName));
          , "Spark configuration.");

                    return zipOutput;
            LOG.debug("Spark config archive created at {} from {}", zipFile, confDir);
            return zipFile;
        } catch (IOException e) {
            throw new RuntimeException(e);

    /**
     * Sets the context ClassLoader to the given {@link SparkClassLoader}, wrapped in a
     * {@link WeakReferenceDelegatorClassLoader}. It will also set the ClassLoader for the {@link Configuration}
     * contained inside the {@link SparkClassLoader}.
     *
     * @param sparkClassLoader the {@link SparkClassLoader} to install
     * @return a {@link Cancellable} to reset the classloader to the one prior to the call
     */
    public static Cancellable setContextClassLoader(final SparkClassLoader sparkClassLoader) {
        final Configuration hConf = sparkClassLoader.getRuntimeContext().getConfiguration();
        final ClassLoader oldConfClassLoader = hConf.getClassLoader();

        // Always wrap it with WeakReference to avoid ClassLoader leakage from Spark.
        ClassLoader classLoader = new WeakReferenceDelegatorClassLoader(sparkClassLoader);
        final ClassLoader oldClassLoader = ClassLoaders.setContextClassLoader(classLoader);
        return new Cancellable() {
            public void cancel() {

                // Do not remove the next line.
                // This is necessary to keep a strong reference to the SparkClassLoader so that it won't get GC until this
                // cancel() is called
                LOG.trace("Reset context ClassLoader. The SparkClassLoader is: {}", sparkClassLoader);

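    /**
     * Sets the {@link TaskSupport} for the given Scala {@link ParArray} to {@link ThreadPoolTaskSupport}.
     * This method is mainly used by {@link SparkRunnerClassLoader} to set the {@link TaskSupport} for the
     * parallel array used inside the {@link DStreamGraph} class in Spark, to avoid thread leakage after the
     * Spark program execution finishes.
     *
     * @param parArray the parallel array whose {@link TaskSupport} is replaced
     * @param <T> the element type of the array
     * @return the same {@code parArray} instance
     */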
    public static <T> ParArray<T> setTaskSupport(ParArray<T> parArray) {
        ThreadPoolExecutor executor = new ThreadPoolExecutor(0, Integer.MAX_VALUE, 1, TimeUnit.SECONDS,
                new SynchronousQueue<Runnable>(),
                new ThreadFactoryBuilder().setNameFormat("task-support-%d").build());
        parArray.tasksupport_$eq(new ThreadPoolTaskSupport(executor));
        return parArray;

    private SparkRuntimeUtils() {
        // Static utility holder; not meant to be instantiated.
    }