com.cloudera.livy.client.local.LocalClient.java Source code

Introduction

Here is the source code for com.cloudera.livy.client.local.LocalClient.java. This class is the local implementation of the LivyClient interface: it launches the Spark driver (RemoteDriver) either in-process or in a child spark-submit process, and communicates with it over an RPC channel.
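
Before reading the implementation, the following minimal sketch shows how the LivyClient interface implemented here is typically used. Job, JobContext, JobHandle, submit() and stop() all appear in the source below; LivyClientBuilder and the "local:spark" URI are assumptions about the surrounding client library, shown only for illustration.

import java.net.URI;
import java.util.Arrays;

import com.cloudera.livy.Job;
import com.cloudera.livy.JobContext;
import com.cloudera.livy.JobHandle;
import com.cloudera.livy.LivyClient;
import com.cloudera.livy.LivyClientBuilder;

public class LocalClientExample {

    // A trivial job, following the same pattern as the AddJarJob / AddFileJob
    // classes at the bottom of LocalClient.java.
    private static class CountJob implements Job<Long> {
        @Override
        public Long call(JobContext jc) throws Exception {
            // jc.sc() exposes the Spark context, just as AddJarJob uses it below.
            return jc.sc().parallelize(Arrays.asList(1, 2, 3, 4)).count();
        }
    }

    public static void main(String[] args) throws Exception {
        // LivyClientBuilder is an assumption about the wider client API;
        // LocalClient itself is created through a factory, not constructed
        // directly (its constructor is package-private).
        LivyClient client = new LivyClientBuilder()
                .setURI(new URI("local:spark"))
                .build();
        try {
            JobHandle<Long> handle = client.submit(new CountJob());
            System.out.println("count = " + handle.get());  // JobHandle extends Future
        } finally {
            client.stop();
        }
    }
}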

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.cloudera.livy.client.local;

import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.io.Writer;
import java.net.URI;
import java.net.URL;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.UUID;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicInteger;

import com.google.common.base.Charsets;
import com.google.common.base.Joiner;
import com.google.common.base.Preconditions;
import com.google.common.base.Strings;
import com.google.common.base.Throwables;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.io.Resources;
import io.netty.channel.ChannelHandlerContext;
import io.netty.util.concurrent.GenericFutureListener;
import io.netty.util.concurrent.Promise;
import org.apache.spark.SparkException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.cloudera.livy.Job;
import com.cloudera.livy.JobContext;
import com.cloudera.livy.JobHandle;
import com.cloudera.livy.LivyClient;
import com.cloudera.livy.client.common.BufferUtils;
import com.cloudera.livy.client.local.driver.RemoteDriver;
import com.cloudera.livy.client.local.rpc.Rpc;
import com.cloudera.livy.client.local.rpc.RpcServer;
import static com.cloudera.livy.client.local.LocalConf.Entry.*;

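/**
 * LivyClient implementation that runs the Spark driver locally. The driver
 * (RemoteDriver) is started either on an in-process thread (test mode) or in a
 * child process launched through spark-submit, and all job submission, status,
 * and result traffic flows over the RPC channel registered with RpcServer.
 */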
public class LocalClient implements LivyClient {
    private static final Logger LOG = LoggerFactory.getLogger(LocalClient.class);

    private static final long DEFAULT_SHUTDOWN_TIMEOUT = 10000; // In milliseconds

    private static final String OSX_TEST_OPTS = "SPARK_OSX_TEST_OPTS";
    private static final String SPARK_JARS_KEY = "spark.jars";
    private static final String SPARK_HOME_ENV = "SPARK_HOME";
    private static final String SPARK_HOME_KEY = "spark.home";
    private static final String DRIVER_OPTS_KEY = "spark.driver.extraJavaOptions";
    private static final String EXECUTOR_OPTS_KEY = "spark.executor.extraJavaOptions";

    private final LocalClientFactory factory;
    private final LocalConf conf;
    private final AtomicInteger childIdGenerator;
    private final Thread driverThread;
    private final Map<String, JobHandleImpl<?>> jobs;
    private final Rpc driverRpc;
    private final ClientProtocol protocol;
    private volatile boolean isAlive;
    private final String clientId;

    LocalClient(LocalClientFactory factory, LocalConf conf) throws IOException, SparkException {
        this.factory = factory;
        this.conf = conf;
        this.childIdGenerator = new AtomicInteger();
        this.jobs = Maps.newConcurrentMap();

        clientId = UUID.randomUUID().toString();
        String secret = factory.getServer().createSecret();
        this.driverThread = startDriver(factory.getServer(), clientId, secret);
        this.protocol = new ClientProtocol();

        try {
            // The RPC server will take care of timeouts here.
            this.driverRpc = factory.getServer().registerClient(clientId, secret, protocol).get();
        } catch (Throwable e) {
            LOG.warn("Error while waiting for client to connect.", e);
            driverThread.interrupt();
            try {
                driverThread.join();
            } catch (InterruptedException ie) {
                // Give up.
                LOG.debug("Interrupted before driver thread was finished.");
            }
            throw Throwables.propagate(e);
        }

        driverRpc.addListener(new Rpc.Listener() {
            @Override
            public void rpcClosed(Rpc rpc) {
                if (isAlive) {
                    LOG.warn("Client RPC channel closed unexpectedly.");
                    isAlive = false;
                }
            }
        });
        isAlive = true;
    }

    @Override
    public String clientId() {
        return clientId;
    }

    @Override
    public <T> JobHandle<T> submit(Job<T> job) {
        return protocol.submit(job);
    }

    @Override
    public <T> Future<T> run(Job<T> job) {
        return protocol.run(job);
    }

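    /**
     * Ends the driver session and closes the RPC channel, then waits up to
     * DEFAULT_SHUTDOWN_TIMEOUT for the driver thread to finish, interrupting it
     * if the timeout elapses.
     */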
    @Override
    public void stop() {
        if (isAlive) {
            isAlive = false;
            try {
                protocol.endSession();
            } catch (Exception e) {
                LOG.warn("Exception while waiting for end session reply.", e);
            } finally {
                driverRpc.close();
                factory.unref();
            }
        }

        long endTime = System.currentTimeMillis() + DEFAULT_SHUTDOWN_TIMEOUT;
        try {
            driverThread.join(DEFAULT_SHUTDOWN_TIMEOUT);
        } catch (InterruptedException ie) {
            LOG.debug("Interrupted before driver thread was finished.");
        }
        if (endTime - System.currentTimeMillis() <= 0) {
            LOG.warn("Timed out shutting down remote driver, interrupting...");
            driverThread.interrupt();
        }
    }

    @Override
    public Future<?> addJar(URI uri) {
        return run(new AddJarJob(uri.toString()));
    }

    @Override
    public Future<?> addFile(URI uri) {
        return run(new AddFileJob(uri.toString()));
    }

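    /**
     * Submits a pre-serialized job without creating a JobHandle; callers poll
     * for the outcome via getBypassJobStatus() using the returned job ID.
     */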
    public String bypass(ByteBuffer serializedJob, boolean sync) {
        return protocol.bypass(serializedJob, sync);
    }

    public Future<BypassJobStatus> getBypassJobStatus(String id) {
        return protocol.getBypassJobStatus(id);
    }

    void cancel(String jobId) {
        protocol.cancel(jobId);
    }

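    /**
     * Launches the remote driver and returns the thread that monitors it. With
     * CLIENT_IN_PROCESS set, RemoteDriver.main() runs directly on the returned
     * thread (test mode). Otherwise a child process is spawned: the spark-submit
     * script from spark.home / SPARK_HOME when available, or a direct invocation
     * of org.apache.spark.deploy.SparkSubmit as a fallback; the returned thread
     * then just waits for the child to exit.
     */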
    private Thread startDriver(final RpcServer rpcServer, final String clientId, final String secret)
            throws IOException {
        Runnable runnable;
        final String serverAddress = rpcServer.getAddress();
        final String serverPort = String.valueOf(rpcServer.getPort());

        if (conf.get(CLIENT_IN_PROCESS) != null) {
            // Mostly for testing things quickly. Do not do this in production.
            LOG.warn("!!!! Running remote driver in-process. !!!!");
            runnable = new Runnable() {
                @Override
                public void run() {
                    List<String> args = Lists.newArrayList();
                    args.add("--remote-host");
                    args.add(serverAddress);
                    args.add("--remote-port");
                    args.add(serverPort);
                    args.add("--client-id");
                    args.add(clientId);
                    args.add("--secret");
                    args.add(secret);

                    for (Map.Entry<String, String> e : conf) {
                        String key = e.getKey();
                        if (!key.startsWith("spark.")) {
                            key = LocalConf.SPARK_CONF_PREFIX + key;
                        }
                        args.add("--conf");
                        args.add(String.format("%s=%s", key, e.getValue()));
                    }
                    try {
                        RemoteDriver.main(args.toArray(new String[args.size()]));
                    } catch (Exception e) {
                        LOG.error("Error running driver.", e);
                    }
                }
            };
        } else {
            // If a Spark installation is provided, use the spark-submit script. Otherwise, call the
            // SparkSubmit class directly, which has some caveats (like having to provide a proper
            // version of Guava on the classpath depending on the deploy mode).
            String sparkHome = conf.get(SPARK_HOME_KEY);
            if (sparkHome == null) {
                sparkHome = System.getenv(SPARK_HOME_ENV);
            }
            if (sparkHome == null) {
                sparkHome = System.getProperty(SPARK_HOME_KEY);
            }

            String osxTestOpts = "";
            if (Strings.nullToEmpty(System.getProperty("os.name")).toLowerCase().contains("mac")) {
                osxTestOpts = Strings.nullToEmpty(System.getenv(OSX_TEST_OPTS));
            }

            String driverJavaOpts = Joiner.on(" ").skipNulls().join(osxTestOpts, conf.get(DRIVER_OPTS_KEY));
            String executorJavaOpts = Joiner.on(" ").skipNulls().join(osxTestOpts, conf.get(EXECUTOR_OPTS_KEY));

            // Create a file with all the job properties to be read by spark-submit. Change the
            // file's permissions so that only the owner can read it. This avoids having the
            // connection secret show up in the child process's command line.
            File properties = File.createTempFile("spark-submit.", ".properties");
            if (!properties.setReadable(false) || !properties.setReadable(true, true)) {
                throw new IOException("Cannot change permissions of job properties file.");
            }
            properties.deleteOnExit();

            Properties allProps = new Properties();
            // first load the defaults from spark-defaults.conf if available
            try {
                URL sparkDefaultsUrl = Thread.currentThread().getContextClassLoader()
                        .getResource("spark-defaults.conf");
                if (sparkDefaultsUrl != null) {
                    LOG.info("Loading spark defaults: " + sparkDefaultsUrl);
                    allProps.load(new ByteArrayInputStream(Resources.toByteArray(sparkDefaultsUrl)));
                }
            } catch (Exception e) {
                String msg = "Exception trying to load spark-defaults.conf: " + e;
                throw new IOException(msg, e);
            }
            // then load the SparkClientImpl config
            for (Map.Entry<String, String> e : conf) {
                String key = e.getKey();
                if (!key.startsWith("spark.")) {
                    key = LocalConf.SPARK_CONF_PREFIX + key;
                }
                allProps.put(key, e.getValue());
            }
            allProps.put(LocalConf.SPARK_CONF_PREFIX + CLIENT_ID.key, clientId);
            allProps.put(LocalConf.SPARK_CONF_PREFIX + CLIENT_SECRET.key, secret);
            allProps.put(DRIVER_OPTS_KEY, driverJavaOpts);
            allProps.put(EXECUTOR_OPTS_KEY, executorJavaOpts);

            Writer writer = new OutputStreamWriter(new FileOutputStream(properties), Charsets.UTF_8);
            try {
                allProps.store(writer, "Spark Context configuration");
            } finally {
                writer.close();
            }

            // Define how to pass options to the child process. If launching in client (or local)
            // mode, the driver options need to be passed directly on the command line. Otherwise,
            // SparkSubmit will take care of that for us.
            String master = conf.get("spark.master");
            Preconditions.checkArgument(master != null, "spark.master is not defined.");

            List<String> argv = Lists.newArrayList();

            if (sparkHome != null) {
                argv.add(new File(sparkHome, "bin/spark-submit").getAbsolutePath());
            } else {
                LOG.info("No spark.home provided, calling SparkSubmit directly.");
                argv.add(new File(System.getProperty("java.home"), "bin/java").getAbsolutePath());

                if (master.startsWith("local") || master.startsWith("mesos") || master.endsWith("-client")
                        || master.startsWith("spark")) {
                    String mem = conf.get("spark.driver.memory");
                    if (mem != null) {
                        argv.add("-Xms" + mem);
                        argv.add("-Xmx" + mem);
                    }

                    String cp = conf.get("spark.driver.extraClassPath");
                    if (cp != null) {
                        argv.add("-classpath");
                        argv.add(cp);
                    }

                    String libPath = conf.get("spark.driver.extraLibPath");
                    if (libPath != null) {
                        argv.add("-Djava.library.path=" + libPath);
                    }

                    String extra = conf.get(DRIVER_OPTS_KEY);
                    if (extra != null) {
                        for (String opt : extra.split("[ ]")) {
                            if (!opt.trim().isEmpty()) {
                                argv.add(opt.trim());
                            }
                        }
                    }
                }

                argv.add("org.apache.spark.deploy.SparkSubmit");
            }

            if (master.equals("yarn-cluster")) {
                String executorCores = conf.get("spark.executor.cores");
                if (executorCores != null) {
                    argv.add("--executor-cores");
                    argv.add(executorCores);
                }

                String executorMemory = conf.get("spark.executor.memory");
                if (executorMemory != null) {
                    argv.add("--executor-memory");
                    argv.add(executorMemory);
                }

                String numOfExecutors = conf.get("spark.executor.instances");
                if (numOfExecutors != null) {
                    argv.add("--num-executors");
                    argv.add(numOfExecutors);
                }
            }

            argv.add("--properties-file");
            argv.add(properties.getAbsolutePath());
            argv.add("--class");
            argv.add(RemoteDriver.class.getName());

            String jar = "spark-internal";
            String livyJars = conf.get(LIVY_JARS);
            if (livyJars == null) {
                String livyHome = System.getenv("LIVY_HOME");
                Preconditions.checkState(livyHome != null, "Need one of LIVY_HOME or %s set.", LIVY_JARS.key);

                File clientJars = new File(livyHome, "client-jars");
                Preconditions.checkState(clientJars.isDirectory(),
                        "Cannot find 'client-jars' directory under LIVY_HOME.");

                List<String> jars = new ArrayList<>();
                for (File f : clientJars.listFiles()) {
                    jars.add(f.getAbsolutePath());
                }
                livyJars = Joiner.on(",").join(jars);
            }

            // Merge the Livy jars with any user-provided jars and pass the result on the
            // command line. The properties file has already been written at this point, so
            // mutating 'conf' here would never reach spark-submit; the command-line --jars
            // option takes precedence over any spark.jars value in the properties file.
            String userJars = conf.get(SPARK_JARS_KEY);
            String allJars = (userJars != null) ? Joiner.on(",").join(livyJars, userJars) : livyJars;
            argv.add("--jars");
            argv.add(allJars);

            argv.add(jar);
            argv.add("--remote-host");
            argv.add(serverAddress);
            argv.add("--remote-port");
            argv.add(serverPort);

            LOG.info("Running client driver with argv: {}", Joiner.on(" ").join(argv));
            final Process child = new ProcessBuilder(argv.toArray(new String[argv.size()])).start();

            int childId = childIdGenerator.incrementAndGet();
            redirect("stdout-redir-" + childId, child.getInputStream());
            redirect("stderr-redir-" + childId, child.getErrorStream());

            runnable = new Runnable() {
                @Override
                public void run() {
                    try {
                        int exitCode = child.waitFor();
                        if (exitCode != 0) {
                            rpcServer.cancelClient(clientId, "Child process exited before connecting back");
                            LOG.warn("Child process exited with code {}.", exitCode);
                        }
                    } catch (InterruptedException ie) {
                        LOG.warn("Waiting thread interrupted, killing child process.");
                        Thread.interrupted();
                        child.destroy();
                    } catch (Exception e) {
                        LOG.warn("Exception while waiting for child process.", e);
                    }
                }
            };
        }

        Thread thread = new Thread(runnable);
        thread.setDaemon(true);
        thread.setName("Driver");
        thread.start();
        return thread;
    }

    private void redirect(String name, InputStream in) {
        Thread thread = new Thread(new Redirector(in));
        thread.setName(name);
        thread.setDaemon(true);
        thread.start();
    }

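    /**
     * RPC protocol spoken with the remote driver. The package-private methods
     * issue outgoing requests (submit, cancel, bypass, end session); the private
     * handle() methods receive the messages pushed back by the driver (job
     * state, metrics, and results).
     */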
    private class ClientProtocol extends BaseProtocol {

        <T> JobHandleImpl<T> submit(Job<T> job) {
            final String jobId = UUID.randomUUID().toString();
            Object msg = new JobRequest<T>(jobId, job);

            final Promise<T> promise = driverRpc.createPromise();
            final JobHandleImpl<T> handle = new JobHandleImpl<T>(LocalClient.this, promise, jobId);
            jobs.put(jobId, handle);

            final io.netty.util.concurrent.Future<Void> rpc = driverRpc.call(msg);
            LOG.debug("Send JobRequest[{}].", jobId);

            // Link the RPC and the promise so that events from one are propagated to the other as
            // needed.
            rpc.addListener(new GenericFutureListener<io.netty.util.concurrent.Future<Void>>() {
                @Override
                public void operationComplete(io.netty.util.concurrent.Future<Void> f) {
                    if (f.isSuccess()) {
                        handle.changeState(JobHandle.State.QUEUED);
                    } else if (!promise.isDone()) {
                        promise.setFailure(f.cause());
                    }
                }
            });
            promise.addListener(new GenericFutureListener<Promise<T>>() {
                @Override
                public void operationComplete(Promise<T> p) {
                    jobs.remove(jobId);
                    if (p.isCancelled() && !rpc.isDone()) {
                        rpc.cancel(true);
                    }
                }
            });
            return handle;
        }

        <T> Future<T> run(Job<T> job) {
            @SuppressWarnings("unchecked")
            final io.netty.util.concurrent.Future<T> rpc = (io.netty.util.concurrent.Future<T>) driverRpc
                    .call(new SyncJobRequest(job), Object.class);
            return rpc;
        }

        String bypass(ByteBuffer serializedJob, boolean sync) {
            String jobId = UUID.randomUUID().toString();
            Object msg = new BypassJobRequest(jobId, BufferUtils.toByteArray(serializedJob), sync);
            driverRpc.call(msg);
            return jobId;
        }

        Future<BypassJobStatus> getBypassJobStatus(String id) {
            return driverRpc.call(new GetBypassJobStatus(id), BypassJobStatus.class);
        }

        void cancel(String jobId) {
            driverRpc.call(new CancelJob(jobId));
        }

        Future<?> endSession() {
            return driverRpc.call(new EndSession());
        }

        private void handle(ChannelHandlerContext ctx, java.lang.Error msg) {
            LOG.warn("Error reported from remote driver.", msg.getCause());
        }

        private void handle(ChannelHandlerContext ctx, JobMetrics msg) {
            JobHandleImpl<?> handle = jobs.get(msg.jobId);
            if (handle != null) {
                handle.getMetrics().addMetrics(msg.sparkJobId, msg.stageId, msg.taskId, msg.metrics);
            } else {
                LOG.warn("Received metrics for unknown job {}", msg.jobId);
            }
        }

        private void handle(ChannelHandlerContext ctx, JobResult msg) {
            JobHandleImpl<?> handle = jobs.remove(msg.id);
            if (handle != null) {
                LOG.info("Received result for {}", msg.id);
                Throwable error = msg.error != null ? new SparkException(msg.error) : null;
                if (error == null) {
                    handle.setSuccess(msg.result);
                } else {
                    handle.setFailure(error);
                }
            } else {
                LOG.warn("Received result for unknown job {}", msg.id);
            }
        }

        private void handle(ChannelHandlerContext ctx, JobStarted msg) {
            JobHandleImpl<?> handle = jobs.get(msg.id);
            if (handle != null) {
                handle.changeState(JobHandle.State.STARTED);
            } else {
                LOG.warn("Received event for unknown job {}", msg.id);
            }
        }

        private void handle(ChannelHandlerContext ctx, JobSubmitted msg) {
            JobHandleImpl<?> handle = jobs.get(msg.clientJobId);
            if (handle != null) {
                LOG.info("Received spark job ID: {} for {}", msg.sparkJobId, msg.clientJobId);
                handle.addSparkJobId(msg.sparkJobId);
            } else {
                LOG.warn("Received spark job ID: {} for unknown job {}", msg.sparkJobId, msg.clientJobId);
            }
        }

    }

    private class Redirector implements Runnable {

        private final BufferedReader in;

        Redirector(InputStream in) {
            this.in = new BufferedReader(new InputStreamReader(in));
        }

        @Override
        public void run() {
            try {
                String line = null;
                while ((line = in.readLine()) != null) {
                    LOG.info(line);
                }
            } catch (Exception e) {
                LOG.warn("Error in redirector thread.", e);
            }
        }

    }

    private static class AddJarJob implements Job<Object> {

        private final String path;

        AddJarJob() {
            this(null);
        }

        AddJarJob(String path) {
            this.path = path;
        }

        @Override
        public Object call(JobContext jc) throws Exception {
            jc.sc().addJar(path);
            return null;
        }

    }

    private static class AddFileJob implements Job<Object> {

        private final String path;

        AddFileJob() {
            this(null);
        }

        AddFileJob(String path) {
            this.path = path;
        }

        @Override
        public Object call(JobContext jc) throws Exception {
            jc.sc().addFile(path);
            return null;
        }

    }

}
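
For reference, with a yarn-cluster master the argv assembled by startDriver() resolves to a command line roughly like the following (all paths, hosts, and values are hypothetical):

/opt/spark/bin/spark-submit \
    --executor-cores 2 \
    --executor-memory 4g \
    --num-executors 10 \
    --properties-file /tmp/spark-submit.1234.properties \
    --class com.cloudera.livy.client.local.driver.RemoteDriver \
    --jars /opt/livy/client-jars/livy-api.jar,/opt/livy/client-jars/livy-client-local.jar \
    spark-internal \
    --remote-host client-host.example.com \
    --remote-port 52341

Everything after the spark-internal marker is passed through to RemoteDriver.main() as application arguments, and the connection secret travels only inside the owner-readable properties file, never on the command line.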