org.apache.zeppelin.spark.PySparkInterpreter.java Source code

Introduction

Here is the source code for org.apache.zeppelin.spark.PySparkInterpreter.java. The class launches a companion python process (zeppelin_pyspark.py) and bridges it to Zeppelin's Spark interpreter over a Py4J gateway.

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.zeppelin.spark;

import java.io.BufferedWriter;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.io.PipedInputStream;
import java.io.PipedOutputStream;
import java.net.MalformedURLException;
import java.net.ServerSocket;
import java.net.URL;
import java.net.URLClassLoader;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Properties;

import org.apache.commons.compress.utils.IOUtils;
import org.apache.commons.exec.CommandLine;
import org.apache.commons.exec.DefaultExecutor;
import org.apache.commons.exec.ExecuteException;
import org.apache.commons.exec.ExecuteResultHandler;
import org.apache.commons.exec.ExecuteWatchdog;
import org.apache.commons.exec.PumpStreamHandler;
import org.apache.commons.exec.environment.EnvironmentUtils;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.SQLContext;
import org.apache.zeppelin.interpreter.Interpreter;
import org.apache.zeppelin.interpreter.InterpreterContext;
import org.apache.zeppelin.interpreter.InterpreterException;
import org.apache.zeppelin.interpreter.InterpreterGroup;
import org.apache.zeppelin.interpreter.InterpreterPropertyBuilder;
import org.apache.zeppelin.interpreter.InterpreterResult;
import org.apache.zeppelin.interpreter.InterpreterResult.Code;
import org.apache.zeppelin.interpreter.LazyOpenInterpreter;
import org.apache.zeppelin.interpreter.WrappedInterpreter;
import org.apache.zeppelin.spark.dep.DependencyContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import py4j.GatewayServer;

/**
 * Interprets PySpark statements by launching a companion python process
 * running zeppelin_pyspark.py and exchanging statements and results with
 * it over a Py4J gateway.
 */
public class PySparkInterpreter extends Interpreter implements ExecuteResultHandler {
    Logger logger = LoggerFactory.getLogger(PySparkInterpreter.class);
    private GatewayServer gatewayServer;
    private DefaultExecutor executor;
    private int port;
    private ByteArrayOutputStream outputStream;
    private ByteArrayOutputStream errStream;
    private BufferedWriter ins;
    private PipedInputStream in;
    private ByteArrayOutputStream input;
    private String scriptPath;
    boolean pythonScriptRunning = false;

    // Register this interpreter under the name "pyspark" in the "spark" group,
    // exposing zeppelin.pyspark.python as a configurable property.
    static {
        Interpreter.register("pyspark", "spark", PySparkInterpreter.class.getName(),
                new InterpreterPropertyBuilder().add("zeppelin.pyspark.python",
                        SparkInterpreter.getSystemDefault("PYSPARK_PYTHON", null, "python"),
                        "Python command to run pyspark with").build());
    }

    public PySparkInterpreter(Properties property) {
        super(property);

        scriptPath = System.getProperty("java.io.tmpdir") + "/zeppelin_pyspark.py";
    }

    // Copies the bundled python/zeppelin_pyspark.py resource to the temp
    // directory so the external python process can execute it.
    private void createPythonScript() {
        ClassLoader classLoader = getClass().getClassLoader();
        File out = new File(scriptPath);

        if (out.exists() && out.isDirectory()) {
            throw new InterpreterException("Can't create python script " + out.getAbsolutePath());
        }

        try (FileOutputStream outStream = new FileOutputStream(out)) {
            IOUtils.copy(classLoader.getResourceAsStream("python/zeppelin_pyspark.py"), outStream);
        } catch (IOException e) {
            throw new InterpreterException(e);
        }

        logger.info("File {} created", scriptPath);
    }

    @Override
    public void open() {
        DepInterpreter depInterpreter = getDepInterpreter();

        // load libraries from Dependency Interpreter
        URL[] urls = new URL[0];

        if (depInterpreter != null) {
            DependencyContext depc = depInterpreter.getDependencyContext();
            if (depc != null) {
                List<File> files = depc.getFiles();
                List<URL> urlList = new LinkedList<URL>();
                if (files != null) {
                    for (File f : files) {
                        try {
                            urlList.add(f.toURI().toURL());
                        } catch (MalformedURLException e) {
                            logger.error("Error", e);
                        }
                    }

                    urls = urlList.toArray(urls);
                }
            }
        }

        // Swap in a classloader that can see the dependency jars while the
        // gateway server and python process start, then restore the original.
        ClassLoader oldCl = Thread.currentThread().getContextClassLoader();
        try {
            URLClassLoader newCl = new URLClassLoader(urls, oldCl);
            Thread.currentThread().setContextClassLoader(newCl);
            createGatewayServerAndStartScript();
        } catch (Exception e) {
            logger.error("Error", e);
            throw new InterpreterException(e);
        } finally {
            Thread.currentThread().setContextClassLoader(oldCl);
        }
    }

    private void createGatewayServerAndStartScript() {
        // create python script
        createPythonScript();

        port = findRandomOpenPortOnAllLocalInterfaces();

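        // Expose this interpreter instance as the Py4J entry point on the
        // chosen port; zeppelin_pyspark.py connects back and drives the
        // callbacks below (getStatements, setStatementsFinished, ...).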
        gatewayServer = new GatewayServer(this, port);
        gatewayServer.start();

        // Launch the python process: <python> zeppelin_pyspark.py <gatewayPort> <sparkVersionNumber>
        CommandLine cmd = CommandLine.parse(getProperty("zeppelin.pyspark.python"));
        cmd.addArgument(scriptPath, false);
        cmd.addArgument(Integer.toString(port), false);
        cmd.addArgument(Integer.toString(getSparkInterpreter().getSparkVersion().toNumber()), false);
        executor = new DefaultExecutor();
        outputStream = new ByteArrayOutputStream();
        PipedOutputStream ps = new PipedOutputStream();
        in = null;
        try {
            in = new PipedInputStream(ps);
        } catch (IOException e1) {
            throw new InterpreterException(e1);
        }
        ins = new BufferedWriter(new OutputStreamWriter(ps));

        input = new ByteArrayOutputStream();

        // Pump the process's stdout and stderr into outputStream; feed its
        // stdin from the piped stream that 'ins' writes to.
        PumpStreamHandler streamHandler = new PumpStreamHandler(outputStream, outputStream, in);
        executor.setStreamHandler(streamHandler);
        executor.setWatchdog(new ExecuteWatchdog(ExecuteWatchdog.INFINITE_TIMEOUT));

        try {
            Map<String, String> env = EnvironmentUtils.getProcEnvironment();

            executor.execute(cmd, env, this);
            pythonScriptRunning = true;
        } catch (IOException e) {
            throw new InterpreterException(e);
        }

        try {
            // Note: this writes into the local 'input' buffer; it does not
            // reach the python process's stdin, which is fed through 'ins'.
            input.write("import sys, getopt\n".getBytes());
            ins.flush();
        } catch (IOException e) {
            throw new InterpreterException(e);
        }
    }

    private int findRandomOpenPortOnAllLocalInterfaces() {
        int port;
        try (ServerSocket socket = new ServerSocket(0)) {
            port = socket.getLocalPort();
        } catch (IOException e) {
            throw new InterpreterException(e);
        }
        return port;
    }

    @Override
    public void close() {
        executor.getWatchdog().destroyProcess();
        gatewayServer.shutdown();
    }

    PythonInterpretRequest pythonInterpretRequest = null;

    /**
     * A statement batch handed to the python process, together with the Spark
     * job group used for cancellation and progress tracking.
     */
    public class PythonInterpretRequest {
        public String statements;
        public String jobGroup;

        public PythonInterpretRequest(String statements, String jobGroup) {
            this.statements = statements;
            this.jobGroup = jobGroup;
        }

        public String statements() {
            return statements;
        }

        public String jobGroup() {
            return jobGroup;
        }
    }

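    // These notifier objects implement the Java/python rendezvous: interpret()
    // parks a request and signals statementSetNotifier; the python thread picks
    // it up in getStatements() and answers through statementFinishedNotifier.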
    final Object statementSetNotifier = new Object();

    // Called from the python side via Py4J: blocks until interpret() publishes
    // a request, then hands it over and clears the slot.
    public PythonInterpretRequest getStatements() {
        synchronized (statementSetNotifier) {
            while (pythonInterpretRequest == null) {
                try {
                    statementSetNotifier.wait(1000);
                } catch (InterruptedException e) {
                }
            }
            PythonInterpretRequest req = pythonInterpretRequest;
            pythonInterpretRequest = null;
            return req;
        }
    }

    String statementOutput = null;
    boolean statementError = false;
    final Object statementFinishedNotifier = new Object();

    // Called from the python side via Py4J when a statement batch has finished.
    public void setStatementsFinished(String out, boolean error) {
        synchronized (statementFinishedNotifier) {
            statementOutput = out;
            statementError = error;
            statementFinishedNotifier.notify();
        }

    }

    boolean pythonScriptInitialized = false;
    final Object pythonScriptInitializeNotifier = new Object();

    // Called from the python side via Py4J once zeppelin_pyspark.py has
    // finished its setup and is ready to accept statements.
    public void onPythonScriptInitialized() {
        synchronized (pythonScriptInitializeNotifier) {
            pythonScriptInitialized = true;
            pythonScriptInitializeNotifier.notifyAll();
        }
    }

    @Override
    public InterpreterResult interpret(String st, InterpreterContext context) {
        if (!pythonScriptRunning) {
            return new InterpreterResult(Code.ERROR, "python process not running. " + outputStream.toString());
        }

        outputStream.reset();

        // Wait up to 10 seconds for the python script to report that it has
        // finished initializing.
        synchronized (pythonScriptInitializeNotifier) {
            long startTime = System.currentTimeMillis();
            while (!pythonScriptInitialized && pythonScriptRunning
                    && System.currentTimeMillis() - startTime < 10 * 1000) {
                try {
                    pythonScriptInitializeNotifier.wait(1000);
                } catch (InterruptedException e) {
                }
            }
        }

        if (!pythonScriptRunning) {
            // python script failed to initialize and terminated
            return new InterpreterResult(Code.ERROR, "failed to start pyspark. " + outputStream.toString());
        }
        if (!pythonScriptInitialized) {
            // timeout. didn't get initialized message
            return new InterpreterResult(Code.ERROR, "pyspark is not responding. " + outputStream.toString());
        }

        SparkInterpreter sparkInterpreter = getSparkInterpreter();
        if (!sparkInterpreter.getSparkVersion().isPysparkSupported()) {
            return new InterpreterResult(Code.ERROR,
                    "pyspark " + sparkInterpreter.getSparkContext().version() + " is not supported");
        }
        String jobGroup = sparkInterpreter.getJobGroup(context);
        ZeppelinContext z = sparkInterpreter.getZeppelinContext();
        z.setInterpreterContext(context);
        z.setGui(context.getGui());
        pythonInterpretRequest = new PythonInterpretRequest(st, jobGroup);
        statementOutput = null;

        // Wake the python-side getStatements() call that is waiting for work.
        synchronized (statementSetNotifier) {
            statementSetNotifier.notify();
        }

        // Block until setStatementsFinished() delivers the result.
        synchronized (statementFinishedNotifier) {
            while (statementOutput == null) {
                try {
                    statementFinishedNotifier.wait(1000);
                } catch (InterruptedException e) {
                }
            }
        }

        if (statementError) {
            return new InterpreterResult(Code.ERROR, statementOutput);
        } else {
            return new InterpreterResult(Code.SUCCESS, statementOutput);
        }
    }

    @Override
    public void cancel(InterpreterContext context) {
        SparkInterpreter sparkInterpreter = getSparkInterpreter();
        sparkInterpreter.cancel(context);
    }

    @Override
    public FormType getFormType() {
        return FormType.NATIVE;
    }

    @Override
    public int getProgress(InterpreterContext context) {
        SparkInterpreter sparkInterpreter = getSparkInterpreter();
        return sparkInterpreter.getProgress(context);
    }

    @Override
    public List<String> completion(String buf, int cursor) {
        // not supported
        return new LinkedList<String>();
    }

    // Locates the SparkInterpreter in this interpreter group, unwrapping any
    // LazyOpenInterpreter/WrappedInterpreter layers and forcing it open.
    private SparkInterpreter getSparkInterpreter() {
        InterpreterGroup intpGroup = getInterpreterGroup();
        LazyOpenInterpreter lazy = null;
        SparkInterpreter spark = null;
        synchronized (intpGroup) {
            for (Interpreter intp : getInterpreterGroup()) {
                if (intp.getClassName().equals(SparkInterpreter.class.getName())) {
                    Interpreter p = intp;
                    while (p instanceof WrappedInterpreter) {
                        if (p instanceof LazyOpenInterpreter) {
                            lazy = (LazyOpenInterpreter) p;
                        }
                        p = ((WrappedInterpreter) p).getInnerInterpreter();
                    }
                    spark = (SparkInterpreter) p;
                }
            }
        }
        if (lazy != null) {
            lazy.open();
        }
        return spark;
    }

    public ZeppelinContext getZeppelinContext() {
        SparkInterpreter sparkIntp = getSparkInterpreter();
        if (sparkIntp != null) {
            return sparkIntp.getZeppelinContext();
        } else {
            return null;
        }
    }

    public JavaSparkContext getJavaSparkContext() {
        SparkInterpreter intp = getSparkInterpreter();
        if (intp == null) {
            return null;
        } else {
            return new JavaSparkContext(intp.getSparkContext());
        }
    }

    public SparkConf getSparkConf() {
        JavaSparkContext sc = getJavaSparkContext();
        if (sc == null) {
            return null;
        } else {
            return sc.getConf();
        }
    }

    public SQLContext getSQLContext() {
        SparkInterpreter intp = getSparkInterpreter();
        if (intp == null) {
            return null;
        } else {
            return intp.getSQLContext();
        }
    }

    private DepInterpreter getDepInterpreter() {
        InterpreterGroup intpGroup = getInterpreterGroup();
        if (intpGroup == null) {
            return null;
        }
        synchronized (intpGroup) {
            for (Interpreter intp : intpGroup) {
                if (intp.getClassName().equals(DepInterpreter.class.getName())) {
                    Interpreter p = intp;
                    while (p instanceof WrappedInterpreter) {
                        p = ((WrappedInterpreter) p).getInnerInterpreter();
                    }
                    return (DepInterpreter) p;
                }
            }
        }
        return null;
    }

    @Override
    public void onProcessComplete(int exitValue) {
        pythonScriptRunning = false;
        logger.info("python process terminated. exit code {}", exitValue);
    }

    @Override
    public void onProcessFailed(ExecuteException e) {
        pythonScriptRunning = false;
        logger.error("python process failed", e);
    }
}
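
Example

The listing hands statements from interpret() to the python process through a pair of monitor objects and Py4J callbacks. Below is a minimal, self-contained sketch of that rendezvous pattern, not part of Zeppelin: it collapses the two monitors into a single lock, and the class name HandshakeSketch, the run() method, and the fixed port 25333 (Py4J's default) are all illustrative. It assumes only that py4j is on the classpath.

import py4j.GatewayServer;

/**
 * Minimal sketch of the rendezvous used by PySparkInterpreter: a Java entry
 * point exposed over Py4J, with a monitor handing work to the thread driven
 * by the python client and collecting its result.
 */
public class HandshakeSketch {
    private final Object lock = new Object();
    private String pendingStatement; // set by run(), consumed by getStatements()
    private String result;           // set by setStatementsFinished()

    // Called from the python client via Py4J: block until work is published.
    public String getStatements() {
        synchronized (lock) {
            while (pendingStatement == null) {
                try {
                    lock.wait(1000);
                } catch (InterruptedException e) {
                    // ignore and re-check the loop condition
                }
            }
            String s = pendingStatement;
            pendingStatement = null;
            return s;
        }
    }

    // Called from the python client when execution of the statement finishes.
    public void setStatementsFinished(String out) {
        synchronized (lock) {
            result = out;
            lock.notifyAll();
        }
    }

    // Java-side caller, mirroring interpret(): publish a statement, then wait
    // for the python side to report back. wait() releases the monitor, so
    // getStatements() can run while this thread is parked.
    public String run(String statement) {
        synchronized (lock) {
            pendingStatement = statement;
            result = null;
            lock.notifyAll();
            while (result == null) {
                try {
                    lock.wait(1000);
                } catch (InterruptedException e) {
                    // ignore and re-check the loop condition
                }
            }
            return result;
        }
    }

    public static void main(String[] args) {
        // Expose this object as the Py4J entry point; PySparkInterpreter
        // probes for a free port instead of hard-coding one.
        GatewayServer server = new GatewayServer(new HandshakeSketch(), 25333);
        server.start();
        System.out.println("gateway listening on 25333");
    }
}

A python client could exercise this with py4j's standard API: connect with py4j.java_gateway.JavaGateway(), then call gateway.entry_point.getStatements() to fetch work and gateway.entry_point.setStatementsFinished(result) to answer, which is the same contract zeppelin_pyspark.py fulfills against the real interpreter.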