org.apache.beam.runners.spark.translation.TranslationUtils.java Source code


Introduction

Here is the source code for org.apache.beam.runners.spark.translation.TranslationUtils.java, a set of utilities that help translate Beam transformations into Spark transformations.

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.beam.runners.spark.translation;

import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterators;
import com.google.common.collect.Maps;
import java.io.Serializable;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.beam.runners.core.InMemoryStateInternals;
import org.apache.beam.runners.core.StateInternals;
import org.apache.beam.runners.core.StateInternalsFactory;
import org.apache.beam.runners.spark.SparkRunner;
import org.apache.beam.runners.spark.util.SideInputBroadcast;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature;
import org.apache.beam.sdk.transforms.reflect.DoFnSignatures;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.transforms.windowing.GlobalWindows;
import org.apache.beam.sdk.transforms.windowing.Window;
import org.apache.beam.sdk.transforms.windowing.WindowFn;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.WindowingStrategy;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.api.java.function.VoidFunction;
import org.apache.spark.streaming.api.java.JavaDStream;
import org.apache.spark.streaming.api.java.JavaPairDStream;
import scala.Tuple2;

/** A set of utilities to help translate Beam transformations into Spark transformations. */
public final class TranslationUtils {

    private TranslationUtils() {
    }

    /**
     * In-memory state internals factory.
     *
     * @param <K> State key type.
     */
    static class InMemoryStateInternalsFactory<K> implements StateInternalsFactory<K>, Serializable {
        @Override
        public StateInternals stateInternalsForKey(K key) {
            return InMemoryStateInternals.forKey(key);
        }
    }

    /**
     * A SparkKeyedCombineFn function applied to grouped KVs.
     *
     * @param <K> Grouped key type.
     * @param <InputT> Grouped values type.
     * @param <OutputT> Output type.
     */
    public static class CombineGroupedValues<K, InputT, OutputT>
            implements Function<WindowedValue<KV<K, Iterable<InputT>>>, WindowedValue<KV<K, OutputT>>> {
        private final SparkKeyedCombineFn<K, InputT, ?, OutputT> fn;

        public CombineGroupedValues(SparkKeyedCombineFn<K, InputT, ?, OutputT> fn) {
            this.fn = fn;
        }

        @Override
        public WindowedValue<KV<K, OutputT>> call(WindowedValue<KV<K, Iterable<InputT>>> windowedKv)
                throws Exception {
            return WindowedValue.of(KV.of(windowedKv.getValue().getKey(), fn.apply(windowedKv)),
                    windowedKv.getTimestamp(), windowedKv.getWindows(), windowedKv.getPane());
        }
    }

    /**
     * Checks if the window transformation should be applied or skipped.
     *
     * <p>Avoid running assign windows if both the input and the transform use the global window, or
     * if the user has not specified a WindowFn (meaning they only changed the triggering or the
     * allowed lateness).
     *
     * @param transform The {@link Window.Assign} transformation.
     * @param context The {@link EvaluationContext}.
     * @param <T> PCollection type.
     * @param <W> {@link BoundedWindow} type.
     * @return {@code true} if assigning windows should be skipped.
     */
    public static <T, W extends BoundedWindow> boolean skipAssignWindows(Window.Assign<T> transform,
            EvaluationContext context) {
        @SuppressWarnings("unchecked")
        WindowFn<? super T, W> windowFn = (WindowFn<? super T, W>) transform.getWindowFn();
        return windowFn == null
                || (context.getInput(transform).getWindowingStrategy().getWindowFn() instanceof GlobalWindows
                        && windowFn instanceof GlobalWindows);
    }
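
    // Illustration (not part of the original class): applying, e.g.,
    // Window.into(FixedWindows.of(Duration.standardMinutes(1))) to a globally windowed input
    // still needs window assignment, so this returns false, whereas a
    // Window.<T>configure().triggering(...) transform carries no WindowFn, so this returns
    // true and the assign step is skipped.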

    /** Transform a pair stream into a value stream. */
    public static <T1, T2> JavaDStream<T2> dStreamValues(JavaPairDStream<T1, T2> pairDStream) {
        return pairDStream.map(new Function<Tuple2<T1, T2>, T2>() {
            @Override
            public T2 call(Tuple2<T1, T2> v1) throws Exception {
                return v1._2();
            }
        });
    }

    /** {@link KV} to pair function. */
    public static <K, V> PairFunction<KV<K, V>, K, V> toPairFunction() {
        return new PairFunction<KV<K, V>, K, V>() {
            @Override
            public Tuple2<K, V> call(KV<K, V> kv) {
                return new Tuple2<>(kv.getKey(), kv.getValue());
            }
        };
    }

    /** {@link KV} to pair flatmap function. */
    public static <K, V> PairFlatMapFunction<Iterator<KV<K, V>>, K, V> toPairFlatMapFunction() {
        return new PairFlatMapFunction<Iterator<KV<K, V>>, K, V>() {
            @Override
            public Iterable<Tuple2<K, V>> call(final Iterator<KV<K, V>> itr) {
                final Iterator<Tuple2<K, V>> outputItr = Iterators.transform(itr,
                        new com.google.common.base.Function<KV<K, V>, Tuple2<K, V>>() {

                            @Override
                            public Tuple2<K, V> apply(KV<K, V> kv) {
                                return new Tuple2<>(kv.getKey(), kv.getValue());
                            }
                        });
                return new Iterable<Tuple2<K, V>>() {

                    @Override
                    public Iterator<Tuple2<K, V>> iterator() {
                        return outputItr;
                    }
                };
            }
        };
    }

    /** A pair to {@link KV} function. */
    static <K, V> Function<Tuple2<K, V>, KV<K, V>> fromPairFunction() {
        return new Function<Tuple2<K, V>, KV<K, V>>() {
            @Override
            public KV<K, V> call(Tuple2<K, V> t2) {
                return KV.of(t2._1(), t2._2());
            }
        };
    }

    /** A pair to {@link KV} flatmap function. */
    static <K, V> FlatMapFunction<Iterator<Tuple2<K, V>>, KV<K, V>> fromPairFlatMapFunction() {
        return new FlatMapFunction<Iterator<Tuple2<K, V>>, KV<K, V>>() {
            @Override
            public Iterable<KV<K, V>> call(Iterator<Tuple2<K, V>> itr) {
                final Iterator<KV<K, V>> outputItr = Iterators.transform(itr,
                        new com.google.common.base.Function<Tuple2<K, V>, KV<K, V>>() {
                            @Override
                            public KV<K, V> apply(Tuple2<K, V> t2) {
                                return KV.of(t2._1(), t2._2());
                            }
                        });
                return new Iterable<KV<K, V>>() {
                    @Override
                    public Iterator<KV<K, V>> iterator() {
                        return outputItr;
                    }
                };
            }
        };
    }

    /** Extract key from a {@link WindowedValue} {@link KV} into a pair. */
    public static <K, V> PairFunction<WindowedValue<KV<K, V>>, K, WindowedValue<KV<K, V>>> toPairByKeyInWindowedValue() {
        return new PairFunction<WindowedValue<KV<K, V>>, K, WindowedValue<KV<K, V>>>() {
            @Override
            public Tuple2<K, WindowedValue<KV<K, V>>> call(WindowedValue<KV<K, V>> windowedKv) throws Exception {
                return new Tuple2<>(windowedKv.getValue().getKey(), windowedKv);
            }
        };
    }

    /** Move the {@link WindowedValue} from the value of a {@link KV} to wrap the whole {@link KV}. */
    static <K, V> Function<KV<K, WindowedValue<V>>, WindowedValue<KV<K, V>>> toKVByWindowInValue() {
        return new Function<KV<K, WindowedValue<V>>, WindowedValue<KV<K, V>>>() {
            @Override
            public WindowedValue<KV<K, V>> call(KV<K, WindowedValue<V>> kv) throws Exception {
                WindowedValue<V> wv = kv.getValue();
                return wv.withValue(KV.of(kv.getKey(), wv.getValue()));
            }
        };
    }

    /**
     * A utility class to filter {@link TupleTag}s.
     *
     * @param <V> TupleTag type.
     */
    public static final class TupleTagFilter<V>
            implements Function<Tuple2<TupleTag<V>, WindowedValue<?>>, Boolean> {

        private final TupleTag<V> tag;

        public TupleTagFilter(TupleTag<V> tag) {
            this.tag = tag;
        }

        @Override
        public Boolean call(Tuple2<TupleTag<V>, WindowedValue<?>> input) {
            return tag.equals(input._1());
        }
    }

    /**
     * Create SideInputs as Broadcast variables.
     *
     * @param views The {@link PCollectionView}s.
     * @param context The {@link EvaluationContext}.
     * @return a map of tagged {@link SideInputBroadcast}s and their {@link WindowingStrategy}.
     */
    static Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>> getSideInputs(
            List<PCollectionView<?>> views, EvaluationContext context) {
        return getSideInputs(views, context.getSparkContext(), context.getPViews());
    }

    /**
     * Create SideInputs as Broadcast variables.
     *
     * @param views The {@link PCollectionView}s.
     * @param context The {@link JavaSparkContext}.
     * @param pviews The {@link SparkPCollectionView}.
     * @return a map of tagged {@link SideInputBroadcast}s and their {@link WindowingStrategy}.
     */
    public static Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>> getSideInputs(
            List<PCollectionView<?>> views, JavaSparkContext context, SparkPCollectionView pviews) {
        if (views == null) {
            return ImmutableMap.of();
        } else {
            Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>> sideInputs = Maps.newHashMap();
            for (PCollectionView<?> view : views) {
                SideInputBroadcast helper = pviews.getPCollectionView(view, context);
                WindowingStrategy<?, ?> windowingStrategy = view.getWindowingStrategyInternal();
                sideInputs.put(view.getTagInternal(),
                        KV.<WindowingStrategy<?, ?>, SideInputBroadcast<?>>of(windowingStrategy, helper));
            }
            return sideInputs;
        }
    }

    /**
     * Reject a splittable {@link DoFn}.
     *
     * @param doFn the {@link DoFn} to possibly reject.
     */
    public static void rejectSplittable(DoFn<?, ?> doFn) {
        DoFnSignature signature = DoFnSignatures.getSignature(doFn.getClass());

        if (signature.processElement().isSplittable()) {
            throw new UnsupportedOperationException(String.format("%s does not support splittable DoFn: %s",
                    SparkRunner.class.getSimpleName(), doFn));
        }
    }

    /**
     * Reject a {@link DoFn} that declares state or timers.
     *
     * @param doFn the {@link DoFn} to possibly reject.
     */
    public static void rejectStateAndTimers(DoFn<?, ?> doFn) {
        DoFnSignature signature = DoFnSignatures.getSignature(doFn.getClass());

        if (signature.stateDeclarations().size() > 0) {
            throw new UnsupportedOperationException(
                    String.format("Found %s annotations on %s, but %s cannot yet be used with state in the %s.",
                            DoFn.StateId.class.getSimpleName(), doFn.getClass().getName(),
                            DoFn.class.getSimpleName(), SparkRunner.class.getSimpleName()));
        }

        if (signature.timerDeclarations().size() > 0) {
            throw new UnsupportedOperationException(
                    String.format("Found %s annotations on %s, but %s cannot yet be used with timers in the %s.",
                            DoFn.TimerId.class.getSimpleName(), doFn.getClass().getName(),
                            DoFn.class.getSimpleName(), SparkRunner.class.getSimpleName()));
        }
    }

    /** A {@link VoidFunction} that does nothing. */
    public static <T> VoidFunction<T> emptyVoidFunction() {
        return new VoidFunction<T>() {
            @Override
            public void call(T t) throws Exception {
                // Empty implementation.
            }
        };
    }

    /**
     * A utility method that adapts {@link PairFunction} to a {@link PairFlatMapFunction} with an
     * {@link Iterator} input. This is particularly useful because it allows functions written for
     * {@code mapToPair} to be reused in {@code flatMapToPair}.
     *
     * @param pairFunction the {@link PairFunction} to adapt.
     * @param <T> the input type.
     * @param <K> the output key type.
     * @param <V> the output value type.
     * @return a {@link PairFlatMapFunction} that accepts an {@link Iterator} as an input and applies
     *     the {@link PairFunction} on every element.
     */
    public static <T, K, V> PairFlatMapFunction<Iterator<T>, K, V> pairFunctionToPairFlatMapFunction(
            final PairFunction<T, K, V> pairFunction) {
        return new PairFlatMapFunction<Iterator<T>, K, V>() {

            @Override
            public Iterable<Tuple2<K, V>> call(Iterator<T> itr) throws Exception {
                final Iterator<Tuple2<K, V>> outputItr = Iterators.transform(itr,
                        new com.google.common.base.Function<T, Tuple2<K, V>>() {

                            @Override
                            public Tuple2<K, V> apply(T t) {
                                try {
                                    return pairFunction.call(t);
                                } catch (Exception e) {
                                    throw new RuntimeException(e);
                                }
                            }
                        });
                return new Iterable<Tuple2<K, V>>() {

                    @Override
                    public Iterator<Tuple2<K, V>> iterator() {
                        return outputItr;
                    }
                };
            }
        };
    }

    /**
     * A utility method that adapts {@link Function} to a {@link FlatMapFunction} with an {@link
     * Iterator} input. This is particularly useful because it allows functions written for
     * {@code map} to be reused in {@code flatMap}.
     *
     * @param func the {@link Function} to adapt.
     * @param <InputT> the input type.
     * @param <OutputT> the output type.
     * @return a {@link FlatMapFunction} that accepts an {@link Iterator} as an input and applies the
     *     {@link Function} on every element.
     */
    public static <InputT, OutputT> FlatMapFunction<Iterator<InputT>, OutputT> functionToFlatMapFunction(
            final Function<InputT, OutputT> func) {
        return new FlatMapFunction<Iterator<InputT>, OutputT>() {

            @Override
            public Iterable<OutputT> call(Iterator<InputT> itr) throws Exception {
                final Iterator<OutputT> outputItr = Iterators.transform(itr,
                        new com.google.common.base.Function<InputT, OutputT>() {

                            @Override
                            public OutputT apply(InputT t) {
                                try {
                                    return func.call(t);
                                } catch (Exception e) {
                                    throw new RuntimeException(e);
                                }
                            }
                        });
                return new Iterable<OutputT>() {

                    @Override
                    public Iterator<OutputT> iterator() {
                        return outputItr;
                    }
                };
            }
        };
    }
}
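
Example usage

The sketch below is not part of the original file. It illustrates how two of the public helpers above, toPairFunction() and functionToFlatMapFunction(), might be used directly against Spark's Java RDD API, assuming the Spark 1.x-style FlatMapFunction shown in the listing (whose call returns an Iterable). The class name TranslationUtilsExample, the local master, and the sample data are made up for illustration.

import java.util.Arrays;
import org.apache.beam.runners.spark.translation.TranslationUtils;
import org.apache.beam.sdk.values.KV;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;

public class TranslationUtilsExample {
    public static void main(String[] args) {
        JavaSparkContext jsc = new JavaSparkContext("local[2]", "translation-utils-example");

        // A Beam-style KV dataset held in a Spark RDD.
        JavaRDD<KV<String, Integer>> kvs =
                jsc.parallelize(Arrays.asList(KV.of("a", 1), KV.of("b", 2)));

        // KV -> Tuple2: turn the RDD into a JavaPairRDD with toPairFunction().
        JavaPairRDD<String, Integer> pairs =
                kvs.mapToPair(TranslationUtils.<String, Integer>toPairFunction());

        // An element-wise Function adapted into a per-partition FlatMapFunction with
        // functionToFlatMapFunction(), so it can be used with mapPartitions.
        Function<KV<String, Integer>, String> format = new Function<KV<String, Integer>, String>() {
            @Override
            public String call(KV<String, Integer> kv) {
                return kv.getKey() + "=" + kv.getValue();
            }
        };
        JavaRDD<String> formatted =
                kvs.mapPartitions(TranslationUtils.functionToFlatMapFunction(format));

        System.out.println(pairs.collectAsMap());
        System.out.println(formatted.collect());
        jsc.stop();
    }
}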