org.apache.beam.runners.spark.translation.SparkGlobalCombineFn.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.beam.runners.spark.translation.SparkGlobalCombineFn.java

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.beam.runners.spark.translation;

import com.google.common.base.Function;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import javax.annotation.Nullable;
import org.apache.beam.runners.spark.util.SideInputBroadcast;
import org.apache.beam.sdk.transforms.CombineWithContext;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
import org.apache.beam.sdk.transforms.windowing.PaneInfo;
import org.apache.beam.sdk.transforms.windowing.TimestampCombiner;
import org.apache.beam.sdk.transforms.windowing.WindowFn;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.WindowingStrategy;
import org.joda.time.Instant;

/**
 * A {@link org.apache.beam.sdk.transforms.CombineFnBase.GlobalCombineFn}
 * with a {@link CombineWithContext.Context} for the SparkRunner.
 *
 * <p>
 * The accumulator handed between {@link #zeroValue()}, {@link #seqOp} and {@link #combOp}
 * is an {@code Iterable<WindowedValue<AccumT>>}: a list of windowed accumulators built by walking
 * the values in the order returned by {@code sortByWindows} and grouping consecutive values that
 * share a window (or, for merging window functions, whose windows intersect).
 * </p>
 */
public class SparkGlobalCombineFn<InputT, AccumT, OutputT> extends SparkAbstractCombineFn {
    private final CombineWithContext.CombineFnWithContext<InputT, AccumT, OutputT> combineFn;

    /**
     * Creates a global combine function wrapper.
     *
     * @param combineFn the context-aware combine function every operation delegates to.
     * @param runtimeContext passed through to the superclass.
     * @param sideInputs passed through to the superclass.
     * @param windowingStrategy passed through to the superclass; it is read here to obtain the
     *        {@link WindowFn} and the {@link TimestampCombiner}.
     */
    public SparkGlobalCombineFn(CombineWithContext.CombineFnWithContext<InputT, AccumT, OutputT> combineFn,
            SparkRuntimeContext runtimeContext,
            Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>> sideInputs,
            WindowingStrategy<?, ?> windowingStrategy) {
        super(runtimeContext, sideInputs, windowingStrategy);
        this.combineFn = combineFn;
    }

    /**
     * Implements Spark's zeroValue function in:
     * <p>
     * {@link org.apache.spark.api.java.JavaRDD#aggregate}.
     * </p>
     *
     * @return a new, empty, mutable list of windowed accumulators.
     */
    Iterable<WindowedValue<AccumT>> zeroValue() {
        return Lists.newArrayList();
    }

    /**
     * Builds the windowed accumulators for a single input element.
     *
     * <p>
     * The input is exploded into one value per window and sorted with {@code sortByWindows}.
     * Consecutive values are folded into the same accumulator when their windows are equal or,
     * for a merging {@link WindowFn}, when their windows intersect (in which case the windows
     * are merged as well). Merging window functions are expected to produce
     * {@link IntervalWindow}s; the casts below fail with a {@link ClassCastException} otherwise.
     * </p>
     *
     * @param input the element to accumulate, possibly assigned to several windows.
     * @return one windowed accumulator per group of windows, each tagged with
     *         {@link PaneInfo#NO_FIRING} and the timestamp computed by the
     *         {@link TimestampCombiner}.
     */
    private Iterable<WindowedValue<AccumT>> createAccumulator(WindowedValue<InputT> input) {

        // sort exploded inputs.
        Iterable<WindowedValue<InputT>> sortedInputs = sortByWindows(input.explodeWindows());

        TimestampCombiner timestampCombiner = windowingStrategy.getTimestampCombiner();
        WindowFn<?, BoundedWindow> windowFn = windowingStrategy.getWindowFn();

        //--- inputs iterator, by window order.
        final Iterator<WindowedValue<InputT>> iterator = sortedInputs.iterator();
        WindowedValue<InputT> currentInput = iterator.next();
        BoundedWindow currentWindow = Iterables.getFirst(currentInput.getWindows(), null);

        // first create the accumulator and accumulate first input.
        AccumT accumulator = newAccumulatorWith(currentInput);

        // keep track of the timestamps assigned by the TimestampCombiner.
        Instant windowTimestamp = timestampCombiner.assign(currentWindow,
                windowFn.getOutputTime(currentInput.getTimestamp(), currentWindow));

        // accumulate the next windows, or output.
        List<WindowedValue<AccumT>> output = Lists.newArrayList();

        // if merging, merge overlapping windows, e.g. Sessions.
        final boolean merging = !windowFn.isNonMerging();

        while (iterator.hasNext()) {
            WindowedValue<InputT> nextValue = iterator.next();
            BoundedWindow nextWindow = Iterables.getOnlyElement(nextValue.getWindows());

            // the casts are only evaluated for merging window functions (short-circuit).
            boolean mergingAndIntersecting = merging
                    && isIntersecting((IntervalWindow) currentWindow, (IntervalWindow) nextWindow);

            if (mergingAndIntersecting || nextWindow.equals(currentWindow)) {
                if (mergingAndIntersecting) {
                    // merge intersecting windows.
                    currentWindow = merge((IntervalWindow) currentWindow, (IntervalWindow) nextWindow);
                }
                // same (possibly merged) window: keep accumulating.
                accumulator = combineFn.addInput(accumulator, nextValue.getValue(), ctxtForInput(nextValue));
                // the output time is computed against the (possibly just merged) current window.
                windowTimestamp = timestampCombiner.merge(currentWindow, windowTimestamp,
                        windowFn.getOutputTime(nextValue.getTimestamp(), currentWindow));
            } else {
                // moving to the next window, first add the current accumulation to output
                // and initialize the accumulator.
                output.add(WindowedValue.of(accumulator, windowTimestamp, currentWindow, PaneInfo.NO_FIRING));
                // re-init accumulator, window and timestamp.
                accumulator = newAccumulatorWith(nextValue);
                currentWindow = nextWindow;
                windowTimestamp = timestampCombiner.assign(currentWindow,
                        windowFn.getOutputTime(nextValue.getTimestamp(), currentWindow));
            }
        }

        // add last accumulator to the output.
        output.add(WindowedValue.of(accumulator, windowTimestamp, currentWindow, PaneInfo.NO_FIRING));

        return output;
    }

    /**
     * Creates a fresh accumulator and adds the given value to it.
     *
     * @param value the first value of a new window group.
     * @return the accumulator returned by the combine function after adding {@code value}.
     */
    private AccumT newAccumulatorWith(WindowedValue<InputT> value) {
        AccumT accumulator = combineFn.createAccumulator(ctxtForInput(value));
        return combineFn.addInput(accumulator, value.getValue(), ctxtForInput(value));
    }

    /**
     * Implement Spark's seqOp function in:
     * <p>
     * {@link org.apache.spark.api.java.JavaRDD#aggregate}.
     * </p>
     *
     * @param accum the windowed accumulators gathered so far.
     * @param input the next element to fold in.
     * @return the result of {@link #combOp} applied to {@code accum} and the accumulators
     *         created for {@code input}.
     */
    Iterable<WindowedValue<AccumT>> seqOp(Iterable<WindowedValue<AccumT>> accum, WindowedValue<InputT> input) {
        return combOp(accum, createAccumulator(input));
    }

    /**
     * Implement Spark's combOp function in:
     * <p>
     * {@link org.apache.spark.api.java.JavaRDD#aggregate}.
     * </p>
     *
     * <p>
     * Both inputs are concatenated, sorted with {@code sortByWindows}, and consecutive
     * accumulators that share a window (or whose windows intersect, for a merging
     * {@link WindowFn}) are merged into a single windowed accumulator.
     * </p>
     *
     * @param a1 the first set of windowed accumulators.
     * @param a2 the second set of windowed accumulators.
     * @return the merged windowed accumulators; a new empty list when both inputs are empty.
     */
    Iterable<WindowedValue<AccumT>> combOp(Iterable<WindowedValue<AccumT>> a1, Iterable<WindowedValue<AccumT>> a2) {

        // concatenate accumulators.
        Iterable<WindowedValue<AccumT>> accumulators = Iterables.concat(a1, a2);
        // if empty, return an empty accumulators iterable.
        if (!accumulators.iterator().hasNext()) {
            return Lists.newArrayList();
        }

        // sort accumulators, no need to explode since inputs were exploded.
        Iterable<WindowedValue<AccumT>> sortedAccumulators = sortByWindows(accumulators);

        TimestampCombiner timestampCombiner = windowingStrategy.getTimestampCombiner();

        //--- accumulators iterator, by window order.
        final Iterator<WindowedValue<AccumT>> iterator = sortedAccumulators.iterator();

        // get the first accumulator and assign it to the current window's accumulators.
        WindowedValue<AccumT> currentValue = iterator.next();
        BoundedWindow currentWindow = Iterables.getFirst(currentValue.getWindows(), null);
        List<AccumT> currentWindowAccumulators = Lists.newArrayList();
        currentWindowAccumulators.add(currentValue.getValue());

        // keep track of the timestamps assigned by the TimestampCombiner,
        // in createAccumulator we already combined the timestamps assigned
        // to individual elements, here we will just merge them.
        List<Instant> windowTimestamps = Lists.newArrayList();
        windowTimestamps.add(currentValue.getTimestamp());

        // accumulate the next windows, or output.
        List<WindowedValue<AccumT>> output = Lists.newArrayList();

        // if merging, merge overlapping windows, e.g. Sessions.
        final boolean merging = !windowingStrategy.getWindowFn().isNonMerging();

        while (iterator.hasNext()) {
            WindowedValue<AccumT> nextValue = iterator.next();
            BoundedWindow nextWindow = Iterables.getOnlyElement(nextValue.getWindows());

            // the casts are only evaluated for merging window functions (short-circuit).
            boolean mergingAndIntersecting = merging
                    && isIntersecting((IntervalWindow) currentWindow, (IntervalWindow) nextWindow);

            if (mergingAndIntersecting || nextWindow.equals(currentWindow)) {
                if (mergingAndIntersecting) {
                    // merge intersecting windows.
                    currentWindow = merge((IntervalWindow) currentWindow, (IntervalWindow) nextWindow);
                }
                // add to window accumulators.
                currentWindowAccumulators.add(nextValue.getValue());
                windowTimestamps.add(nextValue.getTimestamp());
            } else {
                // before moving to the next window, emit the merged accumulation
                // of the current one.
                output.add(mergeWindowAccumulators(timestampCombiner, currentWindow,
                        currentWindowAccumulators, windowTimestamps));

                // re-init accumulator, window and timestamps.
                // the lists are reused; this is safe only because the merge above has
                // already been evaluated at this point.
                currentWindowAccumulators.clear();
                currentWindowAccumulators.add(nextValue.getValue());
                currentWindow = nextWindow;
                windowTimestamps.clear();
                windowTimestamps.add(nextValue.getTimestamp());
            }
        }

        // merge the last chunk of accumulators.
        output.add(mergeWindowAccumulators(timestampCombiner, currentWindow,
                currentWindowAccumulators, windowTimestamps));

        return output;
    }

    /**
     * Merges all accumulators collected for one (possibly merged) window into a single
     * windowed accumulator.
     *
     * @param timestampCombiner used to merge {@code timestamps} into the output timestamp.
     * @param window the window the accumulators belong to.
     * @param accumulators the accumulators to merge; exposed to the combine function through
     *        an unmodifiable view.
     * @param timestamps the timestamps of the accumulators being merged.
     * @return the merged accumulator in {@code window}, with the merged timestamp and
     *         {@link PaneInfo#NO_FIRING}.
     */
    private WindowedValue<AccumT> mergeWindowAccumulators(TimestampCombiner timestampCombiner,
            BoundedWindow window, List<AccumT> accumulators, List<Instant> timestamps) {
        // merge the timestamps of all accumulators to merge.
        Instant mergedTimestamp = timestampCombiner.merge(window, timestamps);

        // wrap the accumulators in a windowed value so a combine context can be derived for them.
        Iterable<AccumT> accumsToMerge = Iterables.unmodifiableIterable(accumulators);
        WindowedValue<Iterable<AccumT>> preMergeWindowedValue = WindowedValue.of(accumsToMerge,
                mergedTimestamp, window, PaneInfo.NO_FIRING);

        // applying the actual combiner onto the accumulators.
        AccumT accumulated = combineFn.mergeAccumulators(accumsToMerge, ctxtForInput(preMergeWindowedValue));
        return preMergeWindowedValue.withValue(accumulated);
    }

    /**
     * Turns windowed accumulators into windowed outputs.
     *
     * @param wvas the windowed accumulators to extract outputs from.
     * @return a lazily transformed view of {@code wvas} in which every non-null element keeps
     *         its window metadata and carries the combine function's extracted output;
     *         {@code null} elements map to {@code null}.
     */
    Iterable<WindowedValue<OutputT>> extractOutput(Iterable<WindowedValue<AccumT>> wvas) {
        return Iterables.transform(wvas, new Function<WindowedValue<AccumT>, WindowedValue<OutputT>>() {
            @Nullable
            @Override
            public WindowedValue<OutputT> apply(@Nullable WindowedValue<AccumT> wva) {
                if (wva == null) {
                    return null;
                }
                return wva.withValue(combineFn.extractOutput(wva.getValue(), ctxtForInput(wva)));
            }
        });
    }
}