org.datavec.spark.transform.SparkTransformExecutor.java Source code

Java tutorial

Introduction

Here is the source code for org.datavec.spark.transform.SparkTransformExecutor.java

Source

/*
 *  * Copyright 2016 Skymind, Inc.
 *  *
 *  *    Licensed under the Apache License, Version 2.0 (the "License");
 *  *    you may not use this file except in compliance with the License.
 *  *    You may obtain a copy of the License at
 *  *
 *  *        http://www.apache.org/licenses/LICENSE-2.0
 *  *
 *  *    Unless required by applicable law or agreed to in writing, software
 *  *    distributed under the License is distributed on an "AS IS" BASIS,
 *  *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  *    See the License for the specific language governing permissions and
 *  *    limitations under the License.
 */

package org.datavec.spark.transform;

import org.datavec.spark.SequenceEmptyRecordFunction;
import org.datavec.spark.functions.EmptyRecordFunction;
import org.datavec.spark.transform.join.*;
import org.datavec.spark.transform.misc.ColumnAsKeyPairFunction;
import org.datavec.spark.transform.reduce.MapToPairForReducerFunction;
import org.datavec.spark.transform.sequence.SparkMapToPairByColumnFunction;
import org.datavec.spark.transform.transform.SequenceSplitFunction;
import org.datavec.spark.transform.sequence.SparkGroupToSequenceFunction;
import org.datavec.spark.transform.sequence.SparkSequenceFilterFunction;
import org.datavec.spark.transform.sequence.SparkSequenceTransformFunction;
import org.apache.commons.math3.util.Pair;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.datavec.api.writable.Writable;
import org.datavec.api.transform.DataAction;
import org.datavec.api.transform.Transform;
import org.datavec.api.transform.TransformProcess;
import org.datavec.api.transform.join.Join;
import org.datavec.api.transform.rank.CalculateSortedRank;
import org.datavec.api.transform.sequence.ConvertToSequence;
import org.datavec.api.transform.filter.Filter;
import org.datavec.api.transform.reduce.IReducer;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.schema.SequenceSchema;
import org.datavec.api.transform.sequence.SequenceSplit;
import org.datavec.spark.transform.analysis.SequenceFlatMapFunction;
import org.datavec.spark.transform.rank.UnzipForCalculateSortedRankFunction;
import org.datavec.spark.transform.filter.SparkFilterFunction;
import org.datavec.spark.transform.reduce.ReducerFunction;
import org.datavec.spark.transform.transform.SparkTransformFunction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Tuple2;

import java.util.Comparator;
import java.util.List;

/**
 * Executes a DataVec {@link TransformProcess} on Spark RDDs.
 * <p>
 * All functionality is exposed through static methods. The public constructor is retained only
 * for backward compatibility and is deprecated.
 *
 * @author Alex Black
 */
public class SparkTransformExecutor {

    private static final Logger log = LoggerFactory.getLogger(SparkTransformExecutor.class);

    /**
     * Name of a boolean JVM system property (read via {@link #isTryCatch()}).
     * <p>
     * When set to {@code true}, the output of every transform step is passed through an empty-record
     * filter ({@link EmptyRecordFunction} for non-sequence data, {@link SequenceEmptyRecordFunction}
     * for sequence data), so empty records are dropped. NOTE(review): presumably the transform
     * functions catch and log errors and emit empty records in this mode — confirm in
     * SparkTransformFunction / SparkSequenceTransformFunction.
     */
    public final static String LOG_ERROR_PROPERTY = "org.datavec.spark.transform.logerrors";

    /**
     * @deprecated Use static methods instead of instance methods on SparkTransformExecutor
     */
    @Deprecated
    public SparkTransformExecutor() {

    }

    /**
     * Execute the specified TransformProcess with the given input data<br>
     * Note: this method can only be used if the TransformProcess returns non-sequence data. For TransformProcesses
     * that return a sequence, use {@link #executeToSequence(JavaRDD, TransformProcess)}
     *
     * @param inputWritables   Input data to process
     * @param transformProcess TransformProcess to execute
     * @return Processed data
     * @throws IllegalStateException if the TransformProcess produces sequence data, or the input's column
     *                               count does not match the TransformProcess's initial schema
     */
    public static JavaRDD<List<Writable>> execute(JavaRDD<List<Writable>> inputWritables,
            TransformProcess transformProcess) {
        if (transformProcess.getFinalSchema() instanceof SequenceSchema) {
            throw new IllegalStateException("Cannot return sequence data with this method");
        }

        return execute(inputWritables, null, transformProcess).getFirst();
    }

    /**
     * Execute the specified TransformProcess with the given input data<br>
     * Note: this method can only be used if the TransformProcess
     * starts with non-sequential data,
     * but returns <i>sequence</i>
     * data (after grouping or converting to a sequence as one of the steps)
     *
     * @param inputWritables   Input data to process
     * @param transformProcess TransformProcess to execute
     * @return Processed (sequence) data
     * @throws IllegalStateException if the TransformProcess does not produce sequence data, or the input's
     *                               column count does not match the TransformProcess's initial schema
     */
    public static JavaRDD<List<List<Writable>>> executeToSequence(JavaRDD<List<Writable>> inputWritables,
            TransformProcess transformProcess) {
        if (!(transformProcess.getFinalSchema() instanceof SequenceSchema)) {
            throw new IllegalStateException("Cannot return non-sequence data with this method");
        }

        return execute(inputWritables, null, transformProcess).getSecond();
    }

    /**
     * Execute the specified TransformProcess with the given <i>sequence</i> input data<br>
     * Note: this method can only be used if the TransformProcess starts with sequence data, but returns <i>non-sequential</i>
     * data (after reducing or converting sequential data to individual examples)
     *
     * @param inputSequence    Input sequence data to process
     * @param transformProcess TransformProcess to execute
     * @return Processed (non-sequential) data
     * @throws IllegalStateException if the TransformProcess produces sequence data, or the input's column
     *                               count does not match the TransformProcess's initial schema
     */
    public static JavaRDD<List<Writable>> executeSequenceToSeparate(JavaRDD<List<List<Writable>>> inputSequence,
            TransformProcess transformProcess) {
        if (transformProcess.getFinalSchema() instanceof SequenceSchema) {
            throw new IllegalStateException("Cannot return sequence data with this method");
        }

        return execute(null, inputSequence, transformProcess).getFirst();
    }

    /**
     * Execute the specified TransformProcess with the given <i>sequence</i> input data<br>
     * Note: this method can only be used if the TransformProcess starts with sequence data, and also returns sequence data
     *
     * @param inputSequence    Input sequence data to process
     * @param transformProcess TransformProcess to execute
     * @return Processed (sequence) data
     * @throws IllegalStateException if the TransformProcess does not produce sequence data, or the input's
     *                               column count does not match the TransformProcess's initial schema
     */
    public static JavaRDD<List<List<Writable>>> executeSequenceToSequence(
            JavaRDD<List<List<Writable>>> inputSequence, TransformProcess transformProcess) {
        if (!(transformProcess.getFinalSchema() instanceof SequenceSchema)) {
            throw new IllegalStateException("Cannot return non-sequence data with this method");
        }

        return execute(null, inputSequence, transformProcess).getSecond();
    }

    /**
     * Returns true if the executor is in try-catch mode.
     *
     * @return {@code true} iff the system property named by {@link #LOG_ERROR_PROPERTY} is set to
     *         {@code "true"} (case-insensitive, per {@link Boolean#getBoolean(String)})
     */
    public static boolean isTryCatch() {
        return Boolean.getBoolean(LOG_ERROR_PROPERTY);
    }

    /**
     * Core execution loop shared by the public {@code execute*} methods.
     * <p>
     * The public callers pass exactly one of {@code inputWritables} / {@code inputSequence} as non-null.
     * Each {@link DataAction} of the process is applied in order. Operations that convert between
     * sequence and non-sequence data switch which of the two RDDs is "current" and null out the other.
     * Apart from the up-front column-count check (which fetches one element), the Spark operations
     * built here are lazy: no work happens until the caller runs an action on the result.
     *
     * @param inputWritables   non-sequence input, or null if the input is sequence data
     * @param inputSequence    sequence input, or null if the input is non-sequence data
     * @param transformProcess process whose actions are applied
     * @return pair of (non-sequence result, sequence result); the one not matching the final data kind is null
     * @throws IllegalArgumentException if both inputs are null
     * @throws IllegalStateException    on a column-count mismatch or an action inapplicable to the current data kind
     * @throws RuntimeException         if an action of an unknown type is encountered
     */
    private static Pair<JavaRDD<List<Writable>>, JavaRDD<List<List<Writable>>>> execute(
            JavaRDD<List<Writable>> inputWritables, JavaRDD<List<List<Writable>>> inputSequence,
            TransformProcess transformProcess) {
        if (inputWritables == null && inputSequence == null) {
            throw new IllegalArgumentException(
                    "Cannot execute transform process: both inputWritables and inputSequence are null");
        }
        validateInputColumns(inputWritables, inputSequence, transformProcess);

        JavaRDD<List<Writable>> currentWritables = inputWritables;
        JavaRDD<List<List<Writable>>> currentSequence = inputSequence;

        for (DataAction d : transformProcess.getActionList()) {
            if (d.getTransform() != null) {
                Transform t = d.getTransform();
                if (currentWritables != null) {
                    Function<List<Writable>, List<Writable>> function = new SparkTransformFunction(t);
                    if (isTryCatch())
                        currentWritables = currentWritables.map(function).filter(new EmptyRecordFunction());
                    else
                        currentWritables = currentWritables.map(function);
                } else {
                    Function<List<List<Writable>>, List<List<Writable>>> function = new SparkSequenceTransformFunction(
                            t);
                    if (isTryCatch())
                        currentSequence = currentSequence.map(function).filter(new SequenceEmptyRecordFunction());
                    else
                        currentSequence = currentSequence.map(function);
                }
            } else if (d.getFilter() != null) {
                Filter f = d.getFilter();
                if (currentWritables != null) {
                    currentWritables = currentWritables.filter(new SparkFilterFunction(f));
                } else {
                    currentSequence = currentSequence.filter(new SparkSequenceFilterFunction(f));
                }
            } else if (d.getConvertToSequence() != null) {
                if (currentWritables == null) {
                    // Previously this dereferenced null; fail with the same exception type as the other branches
                    throw new IllegalStateException(
                            "Cannot execute ConvertToSequence operation: current writables are null "
                                    + "(data is already a sequence?)");
                }
                ConvertToSequence cts = d.getConvertToSequence();

                // Key each record by the configured column, group records sharing a key...
                Schema schema = cts.getInputSchema();
                int colIdx = schema.getIndexOfColumn(cts.getKeyColumn());
                JavaPairRDD<Writable, List<Writable>> withKey = currentWritables
                        .mapToPair(new SparkMapToPairByColumnFunction(colIdx));
                JavaPairRDD<Writable, Iterable<List<Writable>>> grouped = withKey.groupByKey();

                // ...then turn each group into one sequence, ordered by the process's comparator
                currentSequence = grouped.map(new SparkGroupToSequenceFunction(cts.getComparator()));
                currentWritables = null;
            } else if (d.getConvertFromSequence() != null) {
                if (currentSequence == null) {
                    throw new IllegalStateException(
                            "Cannot execute ConvertFromSequence operation: current sequence is null");
                }

                currentWritables = currentSequence.flatMap(new SequenceFlatMapFunction());
                currentSequence = null;
            } else if (d.getSequenceSplit() != null) {
                SequenceSplit sequenceSplit = d.getSequenceSplit();
                if (currentSequence == null)
                    throw new IllegalStateException(
                            "Error during execution of SequenceSplit: currentSequence is null");
                currentSequence = currentSequence.flatMap(new SequenceSplitFunction(sequenceSplit));
            } else if (d.getReducer() != null) {
                IReducer reducer = d.getReducer();

                if (currentWritables == null)
                    throw new IllegalStateException(
                            "Error during execution of reduction: current writables are null. "
                                    + "Trying to execute a reduce operation on a sequence?");
                JavaPairRDD<String, List<Writable>> pair = currentWritables
                        .mapToPair(new MapToPairForReducerFunction(reducer));

                currentWritables = pair.groupByKey().map(new ReducerFunction(reducer));
            } else if (d.getCalculateSortedRank() != null) {
                CalculateSortedRank csr = d.getCalculateSortedRank();

                if (currentWritables == null) {
                    throw new IllegalStateException(
                            "Error during execution of CalculateSortedRank: current writables are null. "
                                    + "Trying to execute a CalculateSortedRank operation on a sequence? (not currently supported)");
                }

                Comparator<Writable> comparator = csr.getComparator();
                String sortColumn = csr.getSortOnColumn();
                int sortColumnIdx = csr.getInputSchema().getIndexOfColumn(sortColumn);
                boolean ascending = csr.isAscending();
                // Key by the sort column, sort, then attach each record's position as its rank.
                // NOTE: this likely isn't the most efficient implementation.
                JavaPairRDD<Writable, List<Writable>> pairRDD = currentWritables
                        .mapToPair(new ColumnAsKeyPairFunction(sortColumnIdx));
                pairRDD = pairRDD.sortByKey(comparator, ascending);

                JavaPairRDD<Tuple2<Writable, List<Writable>>, Long> zipped = pairRDD.zipWithIndex();
                currentWritables = zipped.map(new UnzipForCalculateSortedRankFunction());
            } else {
                throw new RuntimeException("Unknown/not implemented action: " + d);
            }
        }

        return new Pair<>(currentWritables, currentSequence);
    }

    /**
     * Checks that the first input record (or the first step of the first sequence) has as many columns as
     * the process's initial schema. Empty inputs, and an empty first sequence, are not checked.
     * Uses {@code take(1)} rather than {@code first()} so an empty RDD does not throw.
     */
    private static void validateInputColumns(JavaRDD<List<Writable>> inputWritables,
            JavaRDD<List<List<Writable>>> inputSequence, TransformProcess transformProcess) {
        int expectedColumns = transformProcess.getInitialSchema().numColumns();
        if (inputWritables != null) {
            List<List<Writable>> head = inputWritables.take(1);
            if (!head.isEmpty() && head.get(0).size() != expectedColumns) {
                throw new IllegalStateException("Input data number of columns (" + head.get(0).size()
                        + ") does not match the number of columns for the transform process ("
                        + expectedColumns + ")");
            }
        } else {
            List<List<List<Writable>>> head = inputSequence.take(1);
            if (head.isEmpty()) {
                return;
            }
            List<List<Writable>> firstSeq = head.get(0);
            if (firstSeq.size() > 0 && firstSeq.get(0).size() != expectedColumns) {
                throw new IllegalStateException("Input sequence data number of columns (" + firstSeq.get(0).size()
                        + ") does not match the number of columns for the transform process ("
                        + expectedColumns + ")");
            }
        }
    }

    /**
     * Execute a join on the specified data
     *
     * @param join  Join to execute
     * @param left  Left data for join
     * @param right Right data for join
     * @return Joined data
     */
    public static JavaRDD<List<Writable>> executeJoin(Join join, JavaRDD<List<Writable>> left,
            JavaRDD<List<Writable>> right) {
        // Extract the join-key values from each side, giving JavaPairRDD<List<Writable>, JoinValue>;
        // the boolean flag marks whether a value came from the left side
        JavaPairRDD<List<Writable>, JoinValue> leftJV = left.mapToPair(new MapToJoinValuesFunction(true, join));
        JavaPairRDD<List<Writable>, JoinValue> rightJV = right.mapToPair(new MapToJoinValuesFunction(false, join));

        // Merge both sides, group by key, then execute the join per key group
        JavaPairRDD<List<Writable>, JoinValue> both = leftJV.union(rightJV);
        JavaPairRDD<List<Writable>, Iterable<JoinValue>> grouped = both.groupByKey();
        return grouped.flatMap(new ExecuteJoinFlatMapFunction(join));
    }
}