org.apache.crunch.impl.spark.SparkRuntime.java Source code

Introduction

Here is the source code for org.apache.crunch.impl.spark.SparkRuntime.java. This class is the Apache Crunch execution engine for Spark: it implements PipelineExecution as a future whose background monitor thread materializes each output PCollection as a Spark RDD, in dependency order, and writes the results to their Hadoop targets.
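
SparkRuntime is not usually constructed directly: SparkPipeline builds one when the pipeline is run. The snippet below is a minimal, hypothetical sketch of how an application reaches this class; the local master URL, app name, and file paths are placeholders, not part of the source above.

import org.apache.crunch.PCollection;
import org.apache.crunch.PipelineExecution;
import org.apache.crunch.impl.spark.SparkPipeline;

public class SparkRuntimeUsage {
    public static void main(String[] args) throws Exception {
        // Placeholder master and app name; use a real cluster URL in practice.
        SparkPipeline pipeline = new SparkPipeline("local", "crunch-spark-demo");
        PCollection<String> lines = pipeline.readTextFile("/tmp/input.txt");
        pipeline.writeTextFile(lines, "/tmp/output");
        // runAsync() creates a SparkRuntime, starts its monitor thread, and
        // returns it as a PipelineExecution future.
        PipelineExecution execution = pipeline.runAsync();
        execution.waitUntilDone();
        pipeline.done();
    }
}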

Source

/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.crunch.impl.spark;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import com.google.common.util.concurrent.AbstractFuture;
import org.apache.crunch.CombineFn;
import org.apache.crunch.PCollection;
import org.apache.crunch.PipelineExecution;
import org.apache.crunch.PipelineResult;
import org.apache.crunch.SourceTarget;
import org.apache.crunch.Target;
import org.apache.crunch.impl.dist.collect.PCollectionImpl;
import org.apache.crunch.impl.spark.fn.MapFunction;
import org.apache.crunch.impl.spark.fn.OutputConverterFunction;
import org.apache.crunch.impl.spark.fn.PairMapFunction;
import org.apache.crunch.io.MapReduceTarget;
import org.apache.crunch.io.PathTarget;
import org.apache.crunch.materialize.MaterializableIterable;
import org.apache.crunch.types.Converter;
import org.apache.crunch.types.PType;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.filecache.DistributedCache;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.FileUtil;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapreduce.Job;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaRDDLike;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.storage.StorageLevel;

import java.io.IOException;
import java.net.URI;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicReference;

public class SparkRuntime extends AbstractFuture<PipelineResult> implements PipelineExecution {

    private SparkPipeline pipeline;
    private JavaSparkContext sparkContext;
    private Configuration conf;
    private CombineFn combineFn;
    private SparkRuntimeContext ctxt;
    private Map<PCollectionImpl<?>, Set<Target>> outputTargets;
    private Map<PCollectionImpl<?>, MaterializableIterable> toMaterialize;
    private Map<PCollection<?>, StorageLevel> toCache;
    private final CountDownLatch doneSignal = new CountDownLatch(1);
    private AtomicReference<Status> status = new AtomicReference<Status>(Status.READY);
    private PipelineResult result;
    private boolean started;
    private Thread monitorThread;

    // Note that this is the opposite of the MR sort
    static final Comparator<PCollectionImpl<?>> DEPTH_COMPARATOR = new Comparator<PCollectionImpl<?>>() {
        @Override
        public int compare(PCollectionImpl<?> left, PCollectionImpl<?> right) {
            int cmp = left.getDepth() - right.getDepth();
            if (cmp == 0) {
                // Ensure we don't throw away two output collections at the same depth.
                // Using the collection name would be nicer here, but names aren't
                // necessarily unique.
                cmp = Integer.compare(left.hashCode(), right.hashCode());
            }
            return cmp;
        }
    };

    public SparkRuntime(SparkPipeline pipeline, JavaSparkContext sparkContext, Configuration conf,
            Map<PCollectionImpl<?>, Set<Target>> outputTargets,
            Map<PCollectionImpl<?>, MaterializableIterable> toMaterialize,
            Map<PCollection<?>, StorageLevel> toCache) {
        this.pipeline = pipeline;
        this.sparkContext = sparkContext;
        this.conf = conf;
        this.ctxt = new SparkRuntimeContext(sparkContext.broadcast(conf),
                sparkContext.accumulator(Maps.<String, Long>newHashMap(), new CounterAccumulatorParam()));
        this.outputTargets = Maps.newTreeMap(DEPTH_COMPARATOR);
        this.outputTargets.putAll(outputTargets);
        this.toMaterialize = toMaterialize;
        this.toCache = toCache;
        this.status.set(Status.READY);
        this.monitorThread = new Thread(new Runnable() {
            @Override
            public void run() {
                monitorLoop();
            }
        });
    }

    public void setCombineFn(CombineFn combineFn) {
        this.combineFn = combineFn;
    }

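    // Returns the pending CombineFn and clears the field so it is consumed
    // exactly once by the stage that requested it.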
    public CombineFn getCombineFn() {
        CombineFn ret = combineFn;
        this.combineFn = null;
        return ret;
    }

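    /**
     * Registers the Hadoop distributed cache entries with the Spark context.
     * Directory entries are first merged into a single "sparkreadable-" file,
     * since SparkContext#addFile ships individual files to the workers.
     */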
    private void distributeFiles() {
        try {
            URI[] uris = DistributedCache.getCacheFiles(conf);
            if (uris != null) {
                URI[] outURIs = new URI[uris.length];
                for (int i = 0; i < uris.length; i++) {
                    Path path = new Path(uris[i]);
                    FileSystem fs = path.getFileSystem(conf);
                    if (fs.isFile(path)) {
                        outURIs[i] = uris[i];
                    } else {
                        Path mergePath = new Path(path.getParent(), "sparkreadable-" + path.getName());
                        FileUtil.copyMerge(fs, path, fs, mergePath, false, conf, "");
                        outURIs[i] = mergePath.toUri();
                    }
                    sparkContext.addFile(outURIs[i].toString());
                }
                DistributedCache.setCacheFiles(outURIs, conf);
            }
        } catch (IOException e) {
            throw new RuntimeException("Error retrieving cache files", e);
        }
    }

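    /** Starts the monitor thread on the first call; subsequent calls are no-ops. */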
    public synchronized SparkRuntime execute() {
        if (!started) {
            monitorThread.start();
            started = true;
        }
        return this;
    }

    public JavaSparkContext getSparkContext() {
        return sparkContext;
    }

    public SparkRuntimeContext getRuntimeContext() {
        return ctxt;
    }

    public Configuration getConfiguration() {
        return conf;
    }

    public boolean isValid(JavaRDDLike<?, ?> rdd) {
        return (rdd != null); //TODO: support multi-contexts
    }

    public StorageLevel getStorageLevel(PCollection<?> pcollection) {
        return toCache.get(pcollection);
    }

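    // The Spark runtime does not build an execution plan graph, so no dot file is produced.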
    @Override
    public String getPlanDotFile() {
        return "";
    }

    @Override
    public void waitFor(long timeout, TimeUnit timeUnit) throws InterruptedException {
        doneSignal.await(timeout, timeUnit);
    }

    @Override
    public void waitUntilDone() throws InterruptedException {
        doneSignal.await();
    }

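    /**
     * Body of the monitor thread. Each pass runs every output collection whose
     * target dependencies have already been written: the collection is turned
     * into an RDD, saved to each of its targets, and removed from the pending
     * map. The loop exits once all outputs are produced or a job fails.
     */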
    private void monitorLoop() {
        status.set(Status.RUNNING);
        Map<PCollectionImpl<?>, Set<SourceTarget<?>>> targetDeps = Maps
                .<PCollectionImpl<?>, PCollectionImpl<?>, Set<SourceTarget<?>>>newTreeMap(DEPTH_COMPARATOR);
        for (PCollectionImpl<?> pcollect : outputTargets.keySet()) {
            targetDeps.put(pcollect, pcollect.getTargetDependencies());
        }

        while (!targetDeps.isEmpty() && doneSignal.getCount() > 0) {
            Set<Target> allTargets = Sets.newHashSet();
            for (PCollectionImpl<?> pcollect : targetDeps.keySet()) {
                allTargets.addAll(outputTargets.get(pcollect));
            }

            Map<PCollectionImpl<?>, JavaRDDLike<?, ?>> pcolToRdd = Maps.newTreeMap(DEPTH_COMPARATOR);
            for (PCollectionImpl<?> pcollect : targetDeps.keySet()) {
                if (Sets.intersection(allTargets, targetDeps.get(pcollect)).isEmpty()) {
                    JavaRDDLike<?, ?> rdd = ((SparkCollection) pcollect).getJavaRDDLike(this);
                    pcolToRdd.put(pcollect, rdd);
                }
            }
            distributeFiles();
            for (Map.Entry<PCollectionImpl<?>, JavaRDDLike<?, ?>> e : pcolToRdd.entrySet()) {
                JavaRDDLike<?, ?> rdd = e.getValue();
                PType<?> ptype = e.getKey().getPType();
                Set<Target> targets = outputTargets.get(e.getKey());
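                // Cache the RDD so it is computed only once when it feeds multiple targets.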
                if (targets.size() > 1) {
                    rdd.rdd().cache();
                }
                for (Target t : targets) {
                    Configuration conf = new Configuration(getConfiguration());
                    if (t instanceof MapReduceTarget) { //TODO: check this earlier
                        Converter c = t.getConverter(ptype);
                        JavaPairRDD<?, ?> outRDD;
                        if (rdd instanceof JavaRDD) {
                            outRDD = ((JavaRDD) rdd).map(new MapFunction(ptype.getOutputMapFn(), ctxt))
                                    .map(new OutputConverterFunction(c));
                        } else {
                            outRDD = ((JavaPairRDD) rdd).map(new PairMapFunction(ptype.getOutputMapFn(), ctxt))
                                    .map(new OutputConverterFunction(c));
                        }

                        try {
                            Job job = new Job(conf);
                            if (t instanceof PathTarget) {
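                                // Write to a temporary path first; handleOutputs then promotes
                                // the files to their final destination.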
                                PathTarget pt = (PathTarget) t;
                                pt.configureForMapReduce(job, ptype, pt.getPath(), null);
                                Path tmpPath = pipeline.createTempPath();
                                outRDD.saveAsNewAPIHadoopFile(tmpPath.toString(), c.getKeyClass(),
                                        c.getValueClass(), job.getOutputFormatClass(), job.getConfiguration());
                                pt.handleOutputs(job.getConfiguration(), tmpPath, -1);
                            } else if (t instanceof MapReduceTarget) {
                                MapReduceTarget mrt = (MapReduceTarget) t;
                                mrt.configureForMapReduce(job, ptype, new Path("/tmp"), null);
                                outRDD.saveAsHadoopDataset(new JobConf(job.getConfiguration()));
                            } else {
                                throw new IllegalArgumentException(
                                        "Spark execution cannot handle non-MapReduceTarget: " + t);
                            }
                        } catch (Exception et) {
                            et.printStackTrace();
                            status.set(Status.FAILED);
                            set(PipelineResult.EMPTY);
                            doneSignal.countDown();
                        }
                    }
                }
            }
            for (PCollectionImpl<?> output : pcolToRdd.keySet()) {
                if (toMaterialize.containsKey(output)) {
                    MaterializableIterable mi = toMaterialize.get(output);
                    if (mi.isSourceTarget()) {
                        output.materializeAt((SourceTarget) mi.getSource());
                    }
                }
                targetDeps.remove(output);
            }
        }
        if (status.get() != Status.FAILED && status.get() != Status.KILLED) {
            status.set(Status.SUCCEEDED);
            result = new PipelineResult(ImmutableList.of(new PipelineResult.StageResult("Spark", null)),
                    Status.SUCCEEDED);
            set(result);
        } else {
            set(PipelineResult.EMPTY);
        }
        doneSignal.countDown();
    }

    @Override
    public PipelineResult get() throws InterruptedException, ExecutionException {
        if (getStatus() == Status.READY) {
            execute();
        }
        return super.get();
    }

    @Override
    public PipelineResult get(long timeout, TimeUnit unit)
            throws InterruptedException, TimeoutException, ExecutionException {
        if (getStatus() == Status.READY) {
            execute();
        }
        return super.get(timeout, unit);
    }

    @Override
    public Status getStatus() {
        return status.get();
    }

    @Override
    public PipelineResult getResult() {
        return result;
    }

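    /** Stops the underlying Spark context and completes this future with an empty result. */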
    @Override
    public void kill() throws InterruptedException {
        if (started) {
            sparkContext.stop();
            set(PipelineResult.EMPTY);
        }
    }
}