co.cask.cdap.etl.batch.spark.SparkBatchSourceFactory.java Source code


Introduction

Here is the source code for co.cask.cdap.etl.batch.spark.SparkBatchSourceFactory.java, a POJO that stores batch source information set from SparkBatchSourceContext and used by ETLSparkProgram to create the source RDD.

Source

/*
 * Copyright © 2015 Cask Data, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License. You may obtain a copy of
 * the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */

package co.cask.cdap.etl.batch.spark;

import co.cask.cdap.api.data.batch.Input;
import co.cask.cdap.api.data.batch.InputFormatProvider;
import co.cask.cdap.api.data.batch.Split;
import co.cask.cdap.api.data.format.FormatSpecification;
import co.cask.cdap.api.data.format.StructuredRecord;
import co.cask.cdap.api.data.stream.StreamBatchReadable;
import co.cask.cdap.api.spark.JavaSparkExecutionContext;
import co.cask.cdap.api.stream.StreamEventDecoder;
import com.google.common.base.Objects;
import com.google.common.base.Throwables;
import com.google.common.collect.ImmutableMap;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.mapreduce.InputFormat;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;

import java.io.DataInput;
import java.io.DataInputStream;
import java.io.DataOutput;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.URI;
import java.util.List;
import java.util.Map;
import javax.annotation.Nullable;

import static java.lang.Thread.currentThread;

/**
 * A POJO class for storing source information being set from {@link SparkBatchSourceContext} and used in
 * {@link ETLSparkProgram}.
 */
final class SparkBatchSourceFactory {

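    /**
     * Identifies which of the three source types is configured. The byte id is written as
     * the first byte of the serialized form and dispatched on during deserialization.
     */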
    private enum SourceType {
        STREAM(1), PROVIDER(2), DATASET(3);

        private final byte id;

        SourceType(int id) {
            this.id = (byte) id;
        }

        static SourceType from(byte id) {
            for (SourceType type : values()) {
                if (type.id == id) {
                    return type;
                }
            }
            throw new IllegalArgumentException("No SourceType found for id " + id);
        }
    }

    private final StreamBatchReadable streamBatchReadable;
    private final InputFormatProvider inputFormatProvider;
    private final DatasetInfo datasetInfo;

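    /** Creates a factory backed by a stream source. */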
    static SparkBatchSourceFactory create(StreamBatchReadable streamBatchReadable) {
        return new SparkBatchSourceFactory(streamBatchReadable, null, null);
    }

    static SparkBatchSourceFactory create(InputFormatProvider inputFormatProvider) {
        return new SparkBatchSourceFactory(null, inputFormatProvider, null);
    }

    static SparkBatchSourceFactory create(String datasetName) {
        return create(datasetName, ImmutableMap.<String, String>of());
    }

    static SparkBatchSourceFactory create(String datasetName, Map<String, String> datasetArgs) {
        return create(datasetName, datasetArgs, null);
    }

    static SparkBatchSourceFactory create(String datasetName, Map<String, String> datasetArgs,
            @Nullable List<Split> splits) {
        return new SparkBatchSourceFactory(null, null, new DatasetInfo(datasetName, datasetArgs, splits));
    }

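    /**
     * Creates a factory from an {@link Input}, dispatching on its concrete subtype.
     */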
    static SparkBatchSourceFactory create(Input input) {
        if (input instanceof Input.DatasetInput) {
            // Note: if the input format provider is trackable, it comes in as a DatasetInput
            Input.DatasetInput datasetInput = (Input.DatasetInput) input;
            return create(datasetInput.getName(), datasetInput.getArguments(), datasetInput.getSplits());
        } else if (input instanceof Input.StreamInput) {
            Input.StreamInput streamInput = (Input.StreamInput) input;
            return create(streamInput.getStreamBatchReadable());
        } else if (input instanceof Input.InputFormatProviderInput) {
            Input.InputFormatProviderInput ifpInput = (Input.InputFormatProviderInput) input;
            return new SparkBatchSourceFactory(null, ifpInput.getInputFormatProvider(), null);
        }
        throw new IllegalArgumentException("Unknown input format type: " + input.getClass().getCanonicalName());
    }

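    /**
     * Reconstructs a factory from the wire format produced by {@link #serialize(OutputStream)}:
     * a single type byte followed by the fields of the corresponding source type.
     */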
    static SparkBatchSourceFactory deserialize(InputStream inputStream) throws IOException {
        DataInput input = new DataInputStream(inputStream);

        // Deserialize based on the type
        switch (SourceType.from(input.readByte())) {
        case STREAM:
            return new SparkBatchSourceFactory(new StreamBatchReadable(URI.create(input.readUTF())), null, null);
        case PROVIDER:
            return new SparkBatchSourceFactory(null, new BasicInputFormatProvider(input.readUTF(),
                    Serializations.deserializeMap(input, Serializations.createStringObjectReader())), null);
        case DATASET:
            return new SparkBatchSourceFactory(null, null, DatasetInfo.deserialize(input));
        }
        throw new IllegalArgumentException("Invalid input. Failed to decode SparkBatchSourceFactory.");
    }

    private SparkBatchSourceFactory(@Nullable StreamBatchReadable streamBatchReadable,
            @Nullable InputFormatProvider inputFormatProvider, @Nullable DatasetInfo datasetInfo) {
        this.streamBatchReadable = streamBatchReadable;
        this.inputFormatProvider = inputFormatProvider;
        this.datasetInfo = datasetInfo;
    }

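    /**
     * Writes this factory to the given stream as a type byte followed by the fields of
     * the configured source type.
     */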
    public void serialize(OutputStream outputStream) throws IOException {
        DataOutput output = new DataOutputStream(outputStream);
        if (streamBatchReadable != null) {
            output.writeByte(SourceType.STREAM.id);
            output.writeUTF(streamBatchReadable.toURI().toString());
            return;
        }
        if (inputFormatProvider != null) {
            output.writeByte(SourceType.PROVIDER.id);
            output.writeUTF(inputFormatProvider.getInputFormatClassName());
            Serializations.serializeMap(inputFormatProvider.getInputFormatConfiguration(),
                    Serializations.createStringObjectWriter(), output);
            return;
        }
        if (datasetInfo != null) {
            output.writeByte(SourceType.DATASET.id);
            datasetInfo.serialize(output);
            return;
        }
        // This should never happen since the constructor is private and is only invoked from the
        // static create() methods, which ensure that exactly one source type is specified.
        throw new IllegalStateException("Unknown source type");
    }

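    /**
     * Creates a {@link JavaPairRDD} from whichever source type is configured: a stream,
     * an explicit {@link InputFormatProvider}, or a dataset.
     */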
    @SuppressWarnings("unchecked")
    public <K, V> JavaPairRDD<K, V> createRDD(JavaSparkExecutionContext sec, JavaSparkContext jsc,
            Class<K> keyClass, Class<V> valueClass) {
        if (streamBatchReadable != null) {
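            // Stream source: a FormatSpecification takes precedence, then a named decoder
            // class; otherwise the value class alone determines how events are decoded.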
            FormatSpecification formatSpec = streamBatchReadable.getFormatSpecification();
            if (formatSpec != null) {
                return (JavaPairRDD<K, V>) sec.fromStream(streamBatchReadable.getStreamName(), formatSpec,
                        streamBatchReadable.getStartTime(), streamBatchReadable.getEndTime(),
                        StructuredRecord.class);
            }

            String decoderType = streamBatchReadable.getDecoderType();
            if (decoderType == null) {
                return (JavaPairRDD<K, V>) sec.fromStream(streamBatchReadable.getStreamName(),
                        streamBatchReadable.getStartTime(), streamBatchReadable.getEndTime(), valueClass);
            } else {
                try {
                    Class<StreamEventDecoder<K, V>> decoderClass = (Class<StreamEventDecoder<K, V>>)
                            currentThread().getContextClassLoader().loadClass(decoderType);
                    return sec.fromStream(streamBatchReadable.getStreamName(), streamBatchReadable.getStartTime(),
                            streamBatchReadable.getEndTime(), decoderClass, keyClass, valueClass);
                } catch (Exception e) {
                    throw Throwables.propagate(e);
                }
            }
        }
        if (inputFormatProvider != null) {
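            // Start from an empty Configuration so that only the provider-supplied
            // settings are handed to the InputFormat.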
            Configuration hConf = new Configuration();
            hConf.clear();
            for (Map.Entry<String, String> entry : inputFormatProvider.getInputFormatConfiguration().entrySet()) {
                hConf.set(entry.getKey(), entry.getValue());
            }
            ClassLoader classLoader = Objects.firstNonNull(currentThread().getContextClassLoader(),
                    getClass().getClassLoader());
            try {
                @SuppressWarnings("unchecked")
                Class<InputFormat> inputFormatClass = (Class<InputFormat>) classLoader
                        .loadClass(inputFormatProvider.getInputFormatClassName());
                return jsc.newAPIHadoopRDD(hConf, inputFormatClass, keyClass, valueClass);
            } catch (ClassNotFoundException e) {
                throw Throwables.propagate(e);
            }
        }
        if (datasetInfo != null) {
            return sec.fromDataset(datasetInfo.getDatasetName(), datasetInfo.getDatasetArgs());
        }
        // This should never happen since the constructor is private and is only invoked from the
        // static create() methods, which ensure that exactly one source type is specified.
        throw new IllegalStateException("Unknown source type");
    }

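    /**
     * Minimal {@link InputFormatProvider} holding just a class name and a configuration
     * map; used when deserializing a PROVIDER source.
     */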
    private static final class BasicInputFormatProvider implements InputFormatProvider {

        private final String inputFormatClassName;
        private final Map<String, String> configuration;

        private BasicInputFormatProvider(String inputFormatClassName, Map<String, String> configuration) {
            this.inputFormatClassName = inputFormatClassName;
            this.configuration = ImmutableMap.copyOf(configuration);
        }

        @Override
        public String getInputFormatClassName() {
            return inputFormatClassName;
        }

        @Override
        public Map<String, String> getInputFormatConfiguration() {
            return configuration;
        }
    }
}
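
Example

For illustration, here is a minimal sketch of the serialize/deserialize round trip. It
assumes the example is compiled into the same co.cask.cdap.etl.batch.spark package (the
factory class is package-private), and the dataset name "purchases" and its argument map
are hypothetical placeholders, not part of the original source.

package co.cask.cdap.etl.batch.spark;

import com.google.common.collect.ImmutableMap;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;

public class SparkBatchSourceFactoryExample {

    public static void main(String[] args) throws IOException {
        // Hypothetical dataset name and runtime arguments, for illustration only.
        SparkBatchSourceFactory factory =
                SparkBatchSourceFactory.create("purchases", ImmutableMap.of("key", "value"));

        // serialize() writes the DATASET type byte followed by the DatasetInfo fields.
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        factory.serialize(out);

        // deserialize() reads the type byte back and rebuilds a dataset-backed factory.
        SparkBatchSourceFactory restored =
                SparkBatchSourceFactory.deserialize(new ByteArrayInputStream(out.toByteArray()));
    }
}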