org.apache.beam.runners.spark.translation.SparkKeyedCombineFn.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.beam.runners.spark.translation.SparkKeyedCombineFn.java

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.beam.runners.spark.translation;

import com.google.common.base.Function;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import javax.annotation.Nullable;
import org.apache.beam.runners.spark.util.SideInputBroadcast;
import org.apache.beam.sdk.transforms.CombineWithContext;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
import org.apache.beam.sdk.transforms.windowing.PaneInfo;
import org.apache.beam.sdk.transforms.windowing.TimestampCombiner;
import org.apache.beam.sdk.transforms.windowing.WindowFn;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.WindowingStrategy;
import org.joda.time.Instant;

/**
 * Adapts a {@link org.apache.beam.sdk.transforms.CombineWithContext.CombineFnWithContext} to the
 * three functions of Spark's {@link org.apache.spark.rdd.PairRDDFunctions#combineByKey}
 * ({@code createCombiner}, {@code mergeValue}, {@code mergeCombiners}) plus output extraction, for
 * the SparkRunner. Every call into the wrapped function receives a
 * {@link org.apache.beam.sdk.transforms.CombineWithContext.Context} built by {@code ctxtForInput}.
 *
 * <p>Accumulation is tracked per window. Inputs are exploded into one value per window and visited
 * in the order produced by {@code sortByWindows}. Consecutive values whose windows are equal are
 * accumulated together. When the {@link WindowFn} is merging (e.g. Sessions), windows that
 * intersect are also accumulated together, with the windows merged. Output timestamps are derived
 * with the windowing strategy's {@link TimestampCombiner}.
 *
 * <p>NOTE(review): a value is only compared against the current running window, so correct
 * merging relies on {@code sortByWindows} placing windows that should merge next to each other.
 * Confirm against the ordering defined in {@code SparkAbstractCombineFn}.
 *
 * @param <K> the key type.
 * @param <InputT> the type of the input values.
 * @param <AccumT> the accumulator type.
 * @param <OutputT> the type of the combined output.
 */
public class SparkKeyedCombineFn<K, InputT, AccumT, OutputT> extends SparkAbstractCombineFn {
    private final CombineWithContext.CombineFnWithContext<InputT, AccumT, OutputT> combineFn;

    /**
     * Creates a keyed combine function for the SparkRunner.
     *
     * @param combineFn the Beam combine function to wrap.
     * @param runtimeContext the Spark runtime context, handed to the superclass.
     * @param sideInputs side inputs by tag, each paired with its windowing strategy and broadcast.
     * @param windowingStrategy the windowing strategy whose {@link WindowFn} and
     *     {@link TimestampCombiner} drive the per-window accumulation.
     */
    public SparkKeyedCombineFn(CombineWithContext.CombineFnWithContext<InputT, AccumT, OutputT> combineFn,
            SparkRuntimeContext runtimeContext,
            Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>> sideInputs,
            WindowingStrategy<?, ?> windowingStrategy) {
        super(runtimeContext, sideInputs, windowingStrategy);
        this.combineFn = combineFn;
    }

    /**
     * Applies the combine function directly to a key's already-grouped values (post-grouping path).
     *
     * @param windowedKv a windowed key together with all of its grouped input values.
     * @return the combined output, computed with a context derived from {@code windowedKv}.
     */
    public OutputT apply(WindowedValue<KV<K, Iterable<InputT>>> windowedKv) {
        // apply combine function on grouped values.
        return combineFn.apply(windowedKv.getValue().getValue(), ctxtForInput(windowedKv));
    }

    /**
     * Implements the {@code createCombiner} function of Spark's
     * {@link org.apache.spark.rdd.PairRDDFunctions#combineByKey}. It turns a single windowed input
     * into one accumulator per (possibly merged) window.
     *
     * @param wkvi a windowed key/value input; it may belong to several windows.
     * @return one windowed accumulator per distinct window, in the visiting order, each timestamped
     *     via the windowing strategy's {@link TimestampCombiner} and marked
     *     {@link PaneInfo#NO_FIRING}.
     */
    Iterable<WindowedValue<KV<K, AccumT>>> createCombiner(WindowedValue<KV<K, InputT>> wkvi) {
        // explode into one value per window and sort them, so they are visited in window order.
        Iterable<WindowedValue<KV<K, InputT>>> sortedInputs = sortByWindows(wkvi.explodeWindows());

        TimestampCombiner timestampCombiner = windowingStrategy.getTimestampCombiner();
        WindowFn<?, BoundedWindow> windowFn = windowingStrategy.getWindowFn();
        // if merging, merge overlapping windows, e.g. Sessions.
        final boolean merging = !windowFn.isNonMerging();

        //--- inputs iterator, by window order.
        final Iterator<WindowedValue<KV<K, InputT>>> iterator = sortedInputs.iterator();
        WindowedValue<KV<K, InputT>> currentInput = iterator.next();
        // every exploded value carries exactly one window.
        BoundedWindow currentWindow = Iterables.getOnlyElement(currentInput.getWindows());

        // first create the accumulator and accumulate first input.
        K key = currentInput.getValue().getKey();
        AccumT accumulator = combineFn.createAccumulator(ctxtForInput(currentInput));
        accumulator = combineFn.addInput(accumulator, currentInput.getValue().getValue(),
                ctxtForInput(currentInput));

        // keep track of the timestamps assigned by the TimestampCombiner.
        Instant windowTimestamp = timestampCombiner.assign(currentWindow,
                windowFn.getOutputTime(currentInput.getTimestamp(), currentWindow));

        // accumulate the next windows, or output.
        List<WindowedValue<KV<K, AccumT>>> output = Lists.newArrayList();

        while (iterator.hasNext()) {
            WindowedValue<KV<K, InputT>> nextValue = iterator.next();
            BoundedWindow nextWindow = Iterables.getOnlyElement(nextValue.getWindows());

            boolean mergingAndIntersecting = merging
                    && isIntersecting((IntervalWindow) currentWindow, (IntervalWindow) nextWindow);

            if (mergingAndIntersecting || nextWindow.equals(currentWindow)) {
                if (mergingAndIntersecting) {
                    // merge intersecting windows.
                    currentWindow = merge((IntervalWindow) currentWindow, (IntervalWindow) nextWindow);
                }
                // keep accumulating into the same (possibly merged) window.
                accumulator = combineFn.addInput(accumulator, nextValue.getValue().getValue(),
                        ctxtForInput(nextValue));
                Instant nextTimestamp = timestampCombiner.assign(currentWindow,
                        windowFn.getOutputTime(nextValue.getTimestamp(), currentWindow));
                // merge() rather than combine(): if the window was just merged, the running
                // timestamp was assigned for the previous (smaller) window and must be re-targeted
                // to the merged one (relevant e.g. for END_OF_WINDOW). This matches mergeCombiners.
                windowTimestamp = timestampCombiner.merge(currentWindow,
                        Lists.newArrayList(windowTimestamp, nextTimestamp));
            } else {
                // moving to the next window: emit the current accumulation first.
                output.add(WindowedValue.of(KV.of(key, accumulator), windowTimestamp, currentWindow,
                        PaneInfo.NO_FIRING));
                // re-init accumulator, window and timestamp.
                accumulator = combineFn.createAccumulator(ctxtForInput(nextValue));
                accumulator = combineFn.addInput(accumulator, nextValue.getValue().getValue(),
                        ctxtForInput(nextValue));
                currentWindow = nextWindow;
                windowTimestamp = timestampCombiner.assign(currentWindow,
                        windowFn.getOutputTime(nextValue.getTimestamp(), currentWindow));
            }
        }

        // add last accumulator to the output.
        output.add(WindowedValue.of(KV.of(key, accumulator), windowTimestamp, currentWindow, PaneInfo.NO_FIRING));

        return output;
    }

    /**
     * Implements the {@code mergeValue} function of Spark's
     * {@link org.apache.spark.rdd.PairRDDFunctions#combineByKey}. It folds a new input into the
     * existing per-window accumulators.
     *
     * @param wkvi the new windowed key/value input.
     * @param wkvas the existing per-window accumulators for the same key.
     * @return the merged per-window accumulators.
     */
    Iterable<WindowedValue<KV<K, AccumT>>> mergeValue(WindowedValue<KV<K, InputT>> wkvi,
            Iterable<WindowedValue<KV<K, AccumT>>> wkvas) {
        // by calling createCombiner on the inputs and afterwards merging the accumulators, we avoid
        // an explode&accumulate for the input that would result in poor O(n^2) performance:
        // first sort the exploded input - O(nlogn).
        // follow with an accumulators sort = O(mlogm).
        // now for each (exploded) input, find a matching accumulator (if exists) to merge into, or
        // create a new one - O(n*m).
        // this results in - O(nlogn) + O(mlogm) + O(n*m) ~> O(n^2)
        // instead, calling createCombiner will create accumulators from the input - O(nlogn) + O(n).
        // now, calling mergeCombiners will finally result in - O((n+m)log(n+m)) + O(n+m) ~> O(nlogn).
        return mergeCombiners(createCombiner(wkvi), wkvas);
    }

    /**
     * Implements the {@code mergeCombiners} function of Spark's
     * {@link org.apache.spark.rdd.PairRDDFunctions#combineByKey}. It merges two sets of per-window
     * accumulators for the same key.
     *
     * @param a1 per-window accumulators (already exploded to one window each).
     * @param a2 per-window accumulators (already exploded to one window each).
     * @return one accumulator per distinct (possibly merged) window, in the visiting order.
     */
    Iterable<WindowedValue<KV<K, AccumT>>> mergeCombiners(Iterable<WindowedValue<KV<K, AccumT>>> a1,
            Iterable<WindowedValue<KV<K, AccumT>>> a2) {
        // concatenate accumulators.
        Iterable<WindowedValue<KV<K, AccumT>>> accumulators = Iterables.concat(a1, a2);

        // sort accumulators, no need to explode since inputs were exploded.
        Iterable<WindowedValue<KV<K, AccumT>>> sortedAccumulators = sortByWindows(accumulators);

        TimestampCombiner timestampCombiner = windowingStrategy.getTimestampCombiner();
        // if merging, merge overlapping windows, e.g. Sessions.
        final boolean merging = !windowingStrategy.getWindowFn().isNonMerging();

        //--- accumulators iterator, by window order.
        final Iterator<WindowedValue<KV<K, AccumT>>> iterator = sortedAccumulators.iterator();

        // get the first accumulator and assign it to the current window's accumulators.
        WindowedValue<KV<K, AccumT>> currentValue = iterator.next();
        K key = currentValue.getValue().getKey();
        BoundedWindow currentWindow = Iterables.getOnlyElement(currentValue.getWindows());
        List<AccumT> currentWindowAccumulators = Lists.newArrayList();
        currentWindowAccumulators.add(currentValue.getValue().getValue());

        // timestamps were already combined per element in createCombiner; here they are only
        // merged per (possibly merged) window.
        List<Instant> windowTimestamps = Lists.newArrayList();
        windowTimestamps.add(currentValue.getTimestamp());

        // accumulate the next windows, or output.
        List<WindowedValue<KV<K, AccumT>>> output = Lists.newArrayList();

        while (iterator.hasNext()) {
            WindowedValue<KV<K, AccumT>> nextValue = iterator.next();
            BoundedWindow nextWindow = Iterables.getOnlyElement(nextValue.getWindows());

            boolean mergingAndIntersecting = merging
                    && isIntersecting((IntervalWindow) currentWindow, (IntervalWindow) nextWindow);

            if (mergingAndIntersecting || nextWindow.equals(currentWindow)) {
                if (mergingAndIntersecting) {
                    // merge intersecting windows.
                    currentWindow = merge((IntervalWindow) currentWindow, (IntervalWindow) nextWindow);
                }
                // add to window accumulators.
                currentWindowAccumulators.add(nextValue.getValue().getValue());
                windowTimestamps.add(nextValue.getTimestamp());
            } else {
                // before moving to the next window, emit the merged accumulation of the current one.
                output.add(mergeWindowAccumulators(timestampCombiner, key, currentWindow,
                        currentWindowAccumulators, windowTimestamps));

                // start fresh lists rather than clearing: the previous ones were just handed to
                // the combine function (via an unmodifiable view) and should not be mutated.
                currentWindowAccumulators = Lists.newArrayList();
                currentWindowAccumulators.add(nextValue.getValue().getValue());
                windowTimestamps = Lists.newArrayList();
                windowTimestamps.add(nextValue.getTimestamp());
                currentWindow = nextWindow;
            }
        }

        // merge the last chunk of accumulators.
        output.add(mergeWindowAccumulators(timestampCombiner, key, currentWindow,
                currentWindowAccumulators, windowTimestamps));

        return output;
    }

    /**
     * Merges all accumulators collected for a single (possibly merged) window into one.
     *
     * @param timestampCombiner the windowing strategy's timestamp combiner.
     * @param key the key shared by all accumulators.
     * @param window the (possibly merged) window the accumulators belong to.
     * @param accumulators the accumulators to merge; not modified.
     * @param timestamps the timestamps of the accumulators, merged into {@code window}.
     * @return a single windowed accumulator for {@code window}, marked {@link PaneInfo#NO_FIRING}.
     */
    private WindowedValue<KV<K, AccumT>> mergeWindowAccumulators(TimestampCombiner timestampCombiner,
            K key, BoundedWindow window, List<AccumT> accumulators, List<Instant> timestamps) {
        // merge the timestamps of all accumulators into the target window.
        Instant mergedTimestamp = timestampCombiner.merge(window, timestamps);

        // wrap the window's accumulators as a KV<K, Iterable<AccumT>> so the combine function gets
        // a context for the (possibly merged) window, then reduce it to a KV<K, AccumT>.
        Iterable<AccumT> accumsToMerge = Iterables.unmodifiableIterable(accumulators);
        WindowedValue<KV<K, Iterable<AccumT>>> preMergeWindowedValue = WindowedValue.of(
                KV.of(key, accumsToMerge), mergedTimestamp, window, PaneInfo.NO_FIRING);
        AccumT accumulated = combineFn.mergeAccumulators(accumsToMerge, ctxtForInput(preMergeWindowedValue));
        return preMergeWindowedValue.withValue(KV.of(key, accumulated));
    }

    /**
     * Extracts the final output from each per-window accumulator, keeping its window, timestamp
     * and pane.
     *
     * <p>The result is a lazy {@link Iterables#transform} view, so the combine function's
     * {@code extractOutput} runs each time the returned iterable is traversed.
     *
     * @param wkvas the per-window accumulators.
     * @return the windowed outputs; a {@code null} element maps to {@code null}.
     */
    Iterable<WindowedValue<OutputT>> extractOutput(Iterable<WindowedValue<KV<K, AccumT>>> wkvas) {
        return Iterables.transform(wkvas, new Function<WindowedValue<KV<K, AccumT>>, WindowedValue<OutputT>>() {
            @Nullable
            @Override
            public WindowedValue<OutputT> apply(@Nullable WindowedValue<KV<K, AccumT>> wkva) {
                if (wkva == null) {
                    return null;
                }
                AccumT accumulator = wkva.getValue().getValue();
                return wkva.withValue(combineFn.extractOutput(accumulator, ctxtForInput(wkva)));
            }
        });
    }
}