org.apache.beam.runners.apex.translation.ParDoTranslator.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.beam.runners.apex.translation.ParDoTranslator.java

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.beam.runners.apex.translation;

import static com.google.common.base.Preconditions.checkArgument;

import com.datatorrent.api.Operator;
import com.datatorrent.api.Operator.OutputPort;
import com.google.common.collect.Maps;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import org.apache.beam.runners.apex.ApexRunner;
import org.apache.beam.runners.apex.translation.operators.ApexParDoOperator;
import org.apache.beam.runners.core.SplittableParDoViaKeyedWorkItems.ProcessElements;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature;
import org.apache.beam.sdk.transforms.reflect.DoFnSignatures;
import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
import org.apache.beam.sdk.util.WindowedValue.FullWindowedValueCoder;
import org.apache.beam.sdk.util.WindowedValue.WindowedValueCoder;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.PValue;
import org.apache.beam.sdk.values.TupleTag;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link ParDo.MultiOutput} is translated to {@link ApexParDoOperator} that wraps the {@link DoFn}.
 */
class ParDoTranslator<InputT, OutputT> implements TransformTranslator<ParDo.MultiOutput<InputT, OutputT>> {
    private static final long serialVersionUID = 1L;
    private static final Logger LOG = LoggerFactory.getLogger(ParDoTranslator.class);

    @Override
    public void translate(ParDo.MultiOutput<InputT, OutputT> transform, TranslationContext context) {
        DoFn<InputT, OutputT> doFn = transform.getFn();
        DoFnSignature signature = DoFnSignatures.getSignature(doFn.getClass());

        if (signature.processElement().isSplittable()) {
            throw new UnsupportedOperationException(String.format("%s does not support splittable DoFn: %s",
                    ApexRunner.class.getSimpleName(), doFn));
        }
        if (signature.stateDeclarations().size() > 0) {
            throw new UnsupportedOperationException(
                    String.format("Found %s annotations on %s, but %s cannot yet be used with state in the %s.",
                            DoFn.StateId.class.getSimpleName(), doFn.getClass().getName(),
                            DoFn.class.getSimpleName(), ApexRunner.class.getSimpleName()));
        }

        if (signature.timerDeclarations().size() > 0) {
            throw new UnsupportedOperationException(
                    String.format("Found %s annotations on %s, but %s cannot yet be used with timers in the %s.",
                            DoFn.TimerId.class.getSimpleName(), doFn.getClass().getName(),
                            DoFn.class.getSimpleName(), ApexRunner.class.getSimpleName()));
        }

        Map<TupleTag<?>, PValue> outputs = context.getOutputs();
        PCollection<InputT> input = context.getInput();
        List<PCollectionView<?>> sideInputs = transform.getSideInputs();
        Coder<InputT> inputCoder = input.getCoder();
        WindowedValueCoder<InputT> wvInputCoder = FullWindowedValueCoder.of(inputCoder,
                input.getWindowingStrategy().getWindowFn().windowCoder());

        ApexParDoOperator<InputT, OutputT> operator = new ApexParDoOperator<>(context.getPipelineOptions(), doFn,
                transform.getMainOutputTag(), transform.getAdditionalOutputTags().getAll(),
                input.getWindowingStrategy(), sideInputs, wvInputCoder, context.getStateBackend());

        Map<PCollection<?>, OutputPort<?>> ports = Maps.newHashMapWithExpectedSize(outputs.size());
        for (Entry<TupleTag<?>, PValue> output : outputs.entrySet()) {
            checkArgument(output.getValue() instanceof PCollection, "%s %s outputs non-PCollection %s of type %s",
                    ParDo.MultiOutput.class.getSimpleName(), context.getFullName(), output.getValue(),
                    output.getValue().getClass().getSimpleName());
            PCollection<?> pc = (PCollection<?>) output.getValue();
            if (output.getKey().equals(transform.getMainOutputTag())) {
                ports.put(pc, operator.output);
            } else {
                int portIndex = 0;
                for (TupleTag<?> tag : transform.getAdditionalOutputTags().getAll()) {
                    if (tag.equals(output.getKey())) {
                        ports.put(pc, operator.additionalOutputPorts[portIndex]);
                        break;
                    }
                    portIndex++;
                }
            }
        }
        context.addOperator(operator, ports);
        context.addStream(context.getInput(), operator.input);
        if (!sideInputs.isEmpty()) {
            addSideInputs(operator.sideInput1, sideInputs, context);
        }
    }

    static class SplittableProcessElementsTranslator<InputT, OutputT, RestrictionT, TrackerT extends RestrictionTracker<RestrictionT>>
            implements TransformTranslator<ProcessElements<InputT, OutputT, RestrictionT, TrackerT>> {

        @Override
        public void translate(ProcessElements<InputT, OutputT, RestrictionT, TrackerT> transform,
                TranslationContext context) {

            Map<TupleTag<?>, PValue> outputs = context.getOutputs();
            PCollection<InputT> input = context.getInput();
            List<PCollectionView<?>> sideInputs = transform.getSideInputs();
            Coder<InputT> inputCoder = input.getCoder();
            WindowedValueCoder<InputT> wvInputCoder = FullWindowedValueCoder.of(inputCoder,
                    input.getWindowingStrategy().getWindowFn().windowCoder());

            @SuppressWarnings({ "rawtypes", "unchecked" })
            DoFn<InputT, OutputT> doFn = (DoFn) transform.newProcessFn(transform.getFn());
            ApexParDoOperator<InputT, OutputT> operator = new ApexParDoOperator<>(context.getPipelineOptions(),
                    doFn, transform.getMainOutputTag(), transform.getAdditionalOutputTags().getAll(),
                    input.getWindowingStrategy(), sideInputs, wvInputCoder, context.getStateBackend());

            Map<PCollection<?>, OutputPort<?>> ports = Maps.newHashMapWithExpectedSize(outputs.size());
            for (Entry<TupleTag<?>, PValue> output : outputs.entrySet()) {
                checkArgument(output.getValue() instanceof PCollection,
                        "%s %s outputs non-PCollection %s of type %s", ParDo.MultiOutput.class.getSimpleName(),
                        context.getFullName(), output.getValue(), output.getValue().getClass().getSimpleName());
                PCollection<?> pc = (PCollection<?>) output.getValue();
                if (output.getKey().equals(transform.getMainOutputTag())) {
                    ports.put(pc, operator.output);
                } else {
                    int portIndex = 0;
                    for (TupleTag<?> tag : transform.getAdditionalOutputTags().getAll()) {
                        if (tag.equals(output.getKey())) {
                            ports.put(pc, operator.additionalOutputPorts[portIndex]);
                            break;
                        }
                        portIndex++;
                    }
                }
            }

            context.addOperator(operator, ports);
            context.addStream(context.getInput(), operator.input);
            if (!sideInputs.isEmpty()) {
                addSideInputs(operator.sideInput1, sideInputs, context);
            }

        }
    }

    static void addSideInputs(Operator.InputPort<?> sideInputPort, List<PCollectionView<?>> sideInputs,
            TranslationContext context) {
        Operator.InputPort<?>[] sideInputPorts = { sideInputPort };
        if (sideInputs.size() > sideInputPorts.length) {
            PCollection<?> unionCollection = unionSideInputs(sideInputs, context);
            context.addStream(unionCollection, sideInputPorts[0]);
        } else {
            // the number of ports for side inputs is fixed and each port can only take one input.
            for (int i = 0; i < sideInputs.size(); i++) {
                context.addStream(context.getViewInput(sideInputs.get(i)), sideInputPorts[i]);
            }
        }
    }

    private static PCollection<?> unionSideInputs(List<PCollectionView<?>> sideInputs, TranslationContext context) {
        checkArgument(sideInputs.size() > 1, "requires multiple side inputs");
        // flatten and assign union tag
        List<PCollection<Object>> sourceCollections = new ArrayList<>();
        Map<PCollection<?>, Integer> unionTags = new HashMap<>();
        PCollection<Object> firstSideInput = context.getViewInput(sideInputs.get(0));
        for (int i = 0; i < sideInputs.size(); i++) {
            PCollectionView<?> sideInput = sideInputs.get(i);
            PCollection<?> sideInputCollection = context.getViewInput(sideInput);
            if (!sideInputCollection.getWindowingStrategy().equals(firstSideInput.getWindowingStrategy())) {
                // TODO: check how to handle this in stream codec
                //String msg = "Multiple side inputs with different window strategies.";
                //throw new UnsupportedOperationException(msg);
                LOG.warn("Side inputs union with different windowing strategies {} {}",
                        firstSideInput.getWindowingStrategy(), sideInputCollection.getWindowingStrategy());
            }
            if (!sideInputCollection.getCoder().equals(firstSideInput.getCoder())) {
                String msg = "Multiple side inputs with different coders.";
                throw new UnsupportedOperationException(msg);
            }
            sourceCollections.add(context.<PCollection<Object>>getViewInput(sideInput));
            unionTags.put(sideInputCollection, i);
        }

        PCollection<Object> resultCollection = FlattenPCollectionTranslator.intermediateCollection(firstSideInput,
                firstSideInput.getCoder());
        FlattenPCollectionTranslator.flattenCollections(sourceCollections, unionTags, resultCollection, context);
        return resultCollection;
    }
}