com.thinkbiganalytics.spark.repl.ScriptEngine.java Source code

Java tutorial

Introduction

Here is the source code for com.thinkbiganalytics.spark.repl.ScriptEngine.java

Source

package com.thinkbiganalytics.spark.repl;

/*-
 * #%L
 * thinkbig-commons-spark-repl
 * %%
 * Copyright (C) 2017 ThinkBig Analytics
 * %%
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * 
 *     http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * #L%
 */

import com.google.common.base.Charsets;
import com.google.common.base.Throwables;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Maps;
import com.thinkbiganalytics.spark.util.ArrayUtils;

import org.apache.commons.io.output.ByteArrayOutputStream;
import org.apache.spark.SparkContext;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.hive.HiveContext;

import java.io.BufferedWriter;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.io.UnsupportedEncodingException;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import javax.script.ScriptException;

import scala.tools.nsc.interpreter.NamedParam;

/**
 * Interface for an interpreter that compiles and evaluates Scala code containing a Spark job.
 *
 * <p>Scripts may access a {@link SparkContext} through the {@code sc} variable and a {@link SQLContext} through the
 * {@code sqlContext} variable.</p>
 *
 * <p>This class is <i>thread-safe</i> and ensures that only one script </p>
 */
public abstract class ScriptEngine {

    /**
     * End of line character
     */
    private static final byte[] END_LINE = new byte[] { '\n' };

    /**
     * Label used by the compiler to indicate a compile error
     */
    private static final byte[] LABEL = "<console>".getBytes(Charsets.UTF_8);

    /**
     * Separator between label, line number, and error message
     */
    private static final byte[] SEPARATOR = new byte[] { ':' };

    /**
     * Exception thrown by the last script
     */
    @Nonnull
    private final AtomicReference<Throwable> exception = new AtomicReference<>();

    /**
     * Compiler output stream for capturing compile errors
     */
    @Nonnull
    private final ByteArrayOutputStream out = new ByteArrayOutputStream();

    /**
     * Result of the last script
     */
    @Nonnull
    private final AtomicReference<Object> result = new AtomicReference<>();
    /**
     * Map of variable names to values for bindings
     */
    @Nonnull
    private final Map<String, Object> values = Maps.newHashMap();
    /**
     * Spark context
     */
    @Nullable
    private SparkContext sparkContext;
    /**
     * Spark SQL context
     */
    @Nullable
    private SQLContext sqlContext;

    /**
     * Executes the specified script.
     *
     * @param script the script to be executed
     * @return the value returned from the script
     * @throws ScriptException if an error occurs in the script
     */
    @Nullable
    public synchronized Object eval(@Nonnull final String script) throws ScriptException {
        List<NamedParam> bindings = ImmutableList.of();
        return eval(script, bindings);
    }

    /**
     * Executes the specified script with the given bindings.
     *
     * @param script   the script to be executed
     * @param bindings the variable bindings to be accessible to the script
     * @return the value returned from the script
     * @throws ScriptException if an error occurs in the script
     */
    @Nullable
    public synchronized Object eval(@Nonnull final String script, @Nonnull final List<NamedParam> bindings)
            throws ScriptException {
        // Define class containing script
        final StringBuilder cls = new StringBuilder();
        cls.append("class Script (engine: com.thinkbiganalytics.spark.repl.ScriptEngine)");
        cls.append("    extends com.thinkbiganalytics.spark.repl.Script (engine) {\n");
        cls.append("  override def eval (): Any = {\n");
        cls.append(script);
        cls.append("  }\n");

        // Add bindings to class
        this.values.clear();

        for (NamedParam param : bindings) {
            cls.append("  def ");
            cls.append(param.name());
            cls.append(" (): ");
            cls.append(param.tpe());
            cls.append(" = getValue(\"");
            cls.append(param.name());
            cls.append("\")\n");
            this.values.put(param.name(), param.value());
        }

        cls.append("}\n");

        // Instantiate class
        cls.append("new Script(engine).run()\n");

        // Execute script
        this.out.reset();

        execute(cls.toString());

        // Check for exception and return result
        checkCompileError();
        checkRuntimeError();

        return this.result.get();
    }

    /**
     * Gets the class loader used by the interpreter.
     *
     * @return the class loader
     */
    @Nonnull
    public abstract ClassLoader getClassLoader();

    /**
     * Gets the {@code SparkContext} available to scripts as {@code sc}.
     *
     * @return the Spark context
     */
    @Nonnull
    public SparkContext getSparkContext() {
        if (this.sparkContext == null) {
            this.sparkContext = createSparkContext();
        }
        return this.sparkContext;
    }

    /**
     * Gets the {@code SQLContext} available to scripts as {@code sqlContext}.
     *
     * @return the SQL context
     */
    @Nonnull
    public SQLContext getSQLContext() {
        if (this.sqlContext == null) {
            this.sqlContext = new HiveContext(getSparkContext());
        }
        return this.sqlContext;
    }

    /**
     * Creates the {@code SparkContext} that will be available to scripts as {@code sc}.
     *
     * @return the Spark context
     */
    @Nonnull
    protected abstract SparkContext createSparkContext();

    /**
     * Executes the specified script.
     *
     * @param script the script to be executed
     * @throws ScriptException if an error occurs in the script
     */
    protected abstract void execute(@Nonnull final String script) throws ScriptException;

    /**
     * Gets the writer for capturing compile errors.
     *
     * @return the compiler output stream
     */
    protected PrintWriter getPrintWriter() {
        return new PrintWriter(this.out);
    }

    /**
     * Resets the engine state so the {@link SparkContext} can be recreated.
     */
    protected void reset() {
        // Stop Spark
        if (sparkContext != null && !sparkContext.isStopped()) {
            sparkContext.stop();
        }

        // Clear instance variables
        exception.set(null);
        out.reset();
        result.set(null);
        sparkContext = null;
        sqlContext = null;
    }

    /**
     * Gets the value of the specified binding.
     *
     * @param name the name of the binding
     * @return the value of the binding
     */
    @Nullable
    Object getValue(@Nonnull final String name) {
        return this.values.get(name);
    }

    /**
     * Sets the runtime exception for the current script.
     *
     * @param t the exception
     */
    void setException(@Nonnull final Throwable t) {
        this.exception.set(t);
    }

    /**
     * Sets the result of the current script.
     *
     * @param result the result
     */
    void setResult(@Nullable final Object result) {
        this.exception.set(null);
        this.result.set(result);
    }

    /**
     * Checks the output stream for a compile error.
     *
     * @throws ScriptException if a compile error is found
     */
    private void checkCompileError() throws ScriptException {
        byte[] outBytes = this.out.toByteArray();

        // Look for label
        int labelIndex = ArrayUtils.indexOf(outBytes, 0, outBytes.length, LABEL, 0, LABEL.length, 0);
        if (labelIndex == -1) {
            return;
        }

        // Look for start and end of message
        int lineIndex = labelIndex + LABEL.length + 1;
        int msgStart = ArrayUtils.indexOf(outBytes, 0, outBytes.length, SEPARATOR, 0, SEPARATOR.length, lineIndex)
                + 2;
        int msgEnd = ArrayUtils.indexOf(outBytes, 0, outBytes.length, END_LINE, 0, END_LINE.length, msgStart);

        // Throw exception
        int line;
        String message;

        try {
            line = Integer.parseInt(new String(outBytes, lineIndex, msgStart - lineIndex - 2));
            message = new String(outBytes, msgStart, msgEnd - msgStart, "UTF-8");
        } catch (UnsupportedEncodingException e) {
            throw new IllegalStateException(e);
        }

        throw new ScriptException(message, "<console>", line);
    }

    /**
     * Checks for a runtime exception.
     *
     * @throws ScriptException if an exception is found
     */
    private void checkRuntimeError() throws ScriptException {
        Throwable exception = this.exception.get();

        if (exception != null) {
            Throwables.propagateIfPossible(exception, ScriptException.class);

            if (exception instanceof Exception) {
                throw new ScriptException((Exception) exception);
            } else {
                throw new ScriptException(exception.getMessage());
            }
        }
    }
}