com.thinkbiganalytics.spark.repl.SparkScriptEngine.java Source code

Introduction

Here is the source code for com.thinkbiganalytics.spark.repl.SparkScriptEngine.java, a Spring component from Kylo's thinkbig-commons-spark-repl module that evaluates Scala scripts through the Spark REPL interface.

Source

package com.thinkbiganalytics.spark.repl;

/*-
 * #%L
 * thinkbig-commons-spark-repl
 * %%
 * Copyright (C) 2017 ThinkBig Analytics
 * %%
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * 
 *     http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * #L%
 */

import com.google.common.base.Joiner;
import com.google.common.base.Throwables;
import com.thinkbiganalytics.spark.SparkInterpreterBuilder;

import org.apache.commons.io.IOUtils;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.ComponentScan;
import org.springframework.stereotype.Component;

import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.net.URLClassLoader;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.regex.Pattern;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import javax.script.ScriptException;

import scala.collection.JavaConversions;
import scala.tools.nsc.Settings;
import scala.tools.nsc.interpreter.IMain;
import scala.tools.nsc.interpreter.Results;

/**
 * Evaluates Scala scripts using the Spark REPL interface.
 */
@Component
@ComponentScan("com.thinkbiganalytics.spark")
public class SparkScriptEngine extends ScriptEngine {

    private static final Logger log = LoggerFactory.getLogger(SparkScriptEngine.class);

    /**
     * Matches a multi-line comment in Scala. The reluctant quantifier ensures that code
     * appearing between two separate comments is not also stripped.
     */
    private static final Pattern COMMENT = Pattern.compile("/\\*.*?\\*/");

    /**
     * Matches continuation lines in Scala
     */
    private static final Pattern LINE_CONTINUATION = Pattern.compile("^\\s*\\.");

    /**
     * Spark configuration
     */
    @Autowired
    private SparkConf conf;

    /**
     * List of patterns to deny in scripts
     */
    @Nullable
    private List<Pattern> denyPatterns;

    /**
     * Spark REPL interface
     */
    @Nullable
    private IMain interpreter;

    @Autowired
    private SparkInterpreterBuilder builder;

    @Nonnull
    @Override
    public ClassLoader getClassLoader() {
        // Get current context class loader
        final Thread currentThread = Thread.currentThread();
        final ClassLoader contextClassLoader = currentThread.getContextClassLoader();

        // Get interpreter class loader from context
        getInterpreter().setContextClassLoader();
        final ClassLoader interpreterClassLoader = currentThread.getContextClassLoader();

        // Reset context
        currentThread.setContextClassLoader(contextClassLoader);
        return interpreterClassLoader;
    }

    @Nonnull
    @Override
    protected SparkContext createSparkContext() {
        // Allow interpreter to modify Thread context for Spark
        getInterpreter().setContextClassLoader();

        // The SparkContext ClassLoader is needed during initialization (only for YARN master)
        return executeWithSparkClassLoader(new Callable<SparkContext>() {
            @Override
            public SparkContext call() throws Exception {
                log.info("Creating spark context with spark conf {}", conf);
                return new SparkContext(conf);
            }
        });
    }

    @Override
    protected void execute(@Nonnull final String script) throws ScriptException {
        log.debug("Executing script:\n{}", script);

        // Convert script to single line (for checking security violations)
        final StringBuilder safeScriptBuilder = new StringBuilder(script.length());

        for (final String line : script.split("\n")) {
            if (!LINE_CONTINUATION.matcher(line).find()) {
                safeScriptBuilder.append(';');
            }
            safeScriptBuilder.append(line);
        }

        final String safeScript = COMMENT.matcher(safeScriptBuilder.toString()).replaceAll("");

        // Check for security violations
        for (final Pattern pattern : getDenyPatterns()) {
            if (pattern.matcher(safeScript).find()) {
                log.error("Not executing script that matches deny pattern: {}", pattern);
                throw new ScriptException("Script not executed due to security policy.");
            }
        }

        // Execute the original script; safeScript is used only for the checks above
        try {
            getInterpreter().interpret(script);
        } catch (final AssertionError e) {
            log.warn("Caught assertion error when executing script. Retrying...", e);
            reset();
            getInterpreter().interpret(script);
        }
    }
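
    // Illustration (not part of the original source): given the three-line script
    //     val df = sqlContext.read
    //       .json("/tmp/example.json")
    //     df.count() /* just counting */
    // the loop above prefixes each non-continuation line with ';' while joining, so
    // the deny patterns can match statements that were split across lines:
    //     ;val df = sqlContext.read  .json("/tmp/example.json");df.count() /* just counting */
    // and the COMMENT pattern then removes the block comment before the checks run.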

    @Override
    protected void reset() {
        super.reset();

        // Clear the interpreter
        if (interpreter != null) {
            interpreter.close();
            interpreter = null;
        }
    }

    /**
     * Executes the specified callable after replacing the current context class loader.
     *
     * <p>This is a work-around to avoid {@link ClassCastException} issues caused by conflicts between Hadoop and the Kylo Spark Shell. Spark uses the context class loader when loading Hadoop
     * components for running Spark on YARN. When both Hadoop and the Kylo Spark Shell provide the same class, both copies can end up loaded while creating a {@link SparkContext}. The fix is to
     * set the context class loader to the same class loader that was used to load the {@link SparkContext} class.</p>
     *
     * @param callable the function to be executed
     * @param <T>      the return type
     * @return the return value
     */
    private <T> T executeWithSparkClassLoader(@Nonnull final Callable<T> callable) {
        // Set context class loader
        final Thread currentThread = Thread.currentThread();
        final ClassLoader contextClassLoader = currentThread.getContextClassLoader();

        final ClassLoader sparkClassLoader = new ForwardingClassLoader(SparkContext.class.getClassLoader(),
                contextClassLoader);
        currentThread.setContextClassLoader(sparkClassLoader);

        // Execute callable
        try {
            return callable.call();
        } catch (final Exception e) {
            throw Throwables.propagate(e);
        } finally {
            // Reset context class loader
            currentThread.setContextClassLoader(contextClassLoader);
        }
    }
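
    // ForwardingClassLoader is referenced above but defined elsewhere in this package.
    // A minimal sketch of the behavior the code above relies on (the primary loader is
    // tried first, the second only as a fallback); an assumption, not the actual class:
    //
    //     class ForwardingClassLoader extends ClassLoader {
    //         private final ClassLoader fallback;
    //
    //         ForwardingClassLoader(final ClassLoader primary, final ClassLoader fallback) {
    //             super(primary);  // standard parent delegation to the primary loader
    //             this.fallback = fallback;
    //         }
    //
    //         @Override
    //         protected Class<?> findClass(final String name) throws ClassNotFoundException {
    //             return fallback.loadClass(name);  // consulted only when the primary fails
    //         }
    //     }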

    /**
     * Gets the list of patterns that should prevent a script from executing.
     *
     * @return the deny patterns list
     * @throws IllegalStateException if the deny patterns resource exists but cannot be read
     */
    @Nonnull
    private List<Pattern> getDenyPatterns() {
        if (denyPatterns == null) {
            // Load custom or default deny patterns: the custom file is looked up at the
            // classpath root, the default file relative to this class's package
            String resourceName = "spark-deny-patterns.conf";
            InputStream resourceStream = getClass().getResourceAsStream("/" + resourceName);
            if (resourceStream == null) {
                resourceName = "spark-deny-patterns.default.conf";
                resourceStream = getClass().getResourceAsStream(resourceName);
            }

            // Parse lines
            final List<String> denyPatternLines;
            if (resourceStream != null) {
                try {
                    denyPatternLines = IOUtils.readLines(resourceStream, "UTF-8");
                    log.info("Loaded Spark deny patterns from {}.", resourceName);
                } catch (final IOException e) {
                    throw new IllegalStateException("Unable to load " + resourceName, e);
                }
            } else {
                log.info("Missing default Spark deny patterns.");
                denyPatternLines = Collections.emptyList();
            }

            // Compile patterns, ignoring comment and blank lines
            denyPatterns = new ArrayList<>();
            for (final String line : denyPatternLines) {
                final String trimLine = line.trim();
                if (!trimLine.startsWith("#") && !trimLine.isEmpty()) {
                    denyPatterns.add(Pattern.compile(trimLine));
                }
            }
        }
        return denyPatterns;
    }
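
    // Hypothetical spark-deny-patterns.conf illustrating the format parsed above:
    // one regular expression per line; '#' comment lines and blank lines are ignored.
    //
    //     # Prevent scripts from stopping the shared context or killing the JVM
    //     sc\.stop
    //     System\.exit
    //     Runtime\.getRuntime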

    /**
     * Gets the Spark REPL interface to be used.
     *
     * @return the interpreter
     */
    @Nonnull
    private IMain getInterpreter() {
        if (this.interpreter == null) {
            // Determine engine settings
            final Settings settings = getSettings();

            // Initialize engine
            final ClassLoader parentClassLoader = getClass().getClassLoader();
            final SparkInterpreterBuilder b = this.builder.withSettings(settings).withPrintWriter(getPrintWriter())
                    .withClassLoader(parentClassLoader);
            final IMain interpreter = b.newInstance();

            interpreter.setContextClassLoader();
            interpreter.initializeSynchronous();

            // Setup environment
            final scala.collection.immutable.List<String> empty = JavaConversions
                    .asScalaBuffer(new ArrayList<String>()).toList();
            final Results.Result result = interpreter.bind("engine", SparkScriptEngine.class.getName(), this,
                    empty);
            if (result instanceof Results.Error$) {
                throw new IllegalStateException("Failed to initialize interpreter");
            }

            this.interpreter = interpreter;
        }
        return this.interpreter;
    }
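
    // Because this instance is bound into the interpreter as "engine", evaluated Scala
    // scripts can reference it directly; for example, the (hypothetical) script line
    //     println(engine.getClass.getName)
    // would print com.thinkbiganalytics.spark.repl.SparkScriptEngine.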

    /**
     * Gets the settings for the interpreter.
     *
     * @return the interpreter settings
     */
    @Nonnull
    private Settings getSettings() {
        final Settings settings = new Settings();

        if (settings.classpath().isDefault()) {
            final String classPath = Joiner.on(File.pathSeparatorChar)
                    .join(((URLClassLoader) getClass().getClassLoader()).getURLs())
                    + File.pathSeparator + System.getProperty("java.class.path");
            settings.classpath().value_$eq(classPath);
        }
        return settings;
    }
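
    // On Unix (hypothetical paths), the resulting classpath string looks like
    //     file:/opt/kylo/lib/engine.jar:file:/opt/kylo/lib/guava.jar:/opt/spark/conf:...
    // that is, this class loader's own URLs followed by the JVM's java.class.path entries.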
}
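
Usage

Below is a minimal sketch of how this component might be obtained and driven. It assumes the abstract ScriptEngine base class (not shown in this listing) exposes a public eval(String) entry point that delegates to the protected execute(String) above; that method name, like the script itself, is an assumption rather than something confirmed by this file.

import org.springframework.context.annotation.AnnotationConfigApplicationContext;

public class SparkScriptEngineExample {

    public static void main(final String[] args) throws Exception {
        // Scan the package named by @ComponentScan so that SparkConf,
        // SparkInterpreterBuilder, and SparkScriptEngine are all registered
        try (final AnnotationConfigApplicationContext context =
                     new AnnotationConfigApplicationContext("com.thinkbiganalytics.spark")) {
            final SparkScriptEngine engine = context.getBean(SparkScriptEngine.class);

            // Hypothetical entry point, assumed to delegate to execute(String)
            engine.eval("println(1 + 1)");
        }
    }
}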