org.apache.beam.fn.harness.control.ProcessBundleHandler.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.beam.fn.harness.control.ProcessBundleHandler.java

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.beam.fn.harness.control;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.Iterables.getOnlyElement;

import com.google.common.collect.Collections2;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableMultimap;
import com.google.common.collect.Lists;
import com.google.common.collect.Multimap;
import com.google.protobuf.ByteString;
import com.google.protobuf.BytesValue;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Message;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;
import org.apache.beam.fn.harness.data.BeamFnDataClient;
import org.apache.beam.fn.harness.fake.FakeStepContext;
import org.apache.beam.fn.harness.fn.ThrowingConsumer;
import org.apache.beam.fn.harness.fn.ThrowingRunnable;
import org.apache.beam.fn.v1.BeamFnApi;
import org.apache.beam.runners.core.BeamFnDataReadRunner;
import org.apache.beam.runners.core.BeamFnDataWriteRunner;
import org.apache.beam.runners.core.BoundedSourceRunner;
import org.apache.beam.runners.core.DoFnRunner;
import org.apache.beam.runners.core.DoFnRunners;
import org.apache.beam.runners.core.DoFnRunners.OutputManager;
import org.apache.beam.runners.core.NullSideInputReader;
import org.apache.beam.runners.dataflow.util.DoFnInfo;
import org.apache.beam.sdk.io.BoundedSource;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.util.SerializableUtils;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.WindowingStrategy;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Processes {@link org.apache.beam.fn.v1.BeamFnApi.ProcessBundleRequest}s by materializing
 * the set of required runners for each {@link org.apache.beam.fn.v1.BeamFnApi.FunctionSpec},
 * wiring them together based upon the {@code input} and {@code output} map definitions.
 *
 * <p>Finally executes the DAG based graph by starting all runners in reverse topological order,
 * and finishing all runners in forward topological order.
 */
public class ProcessBundleHandler {
    // TODO: What should the initial set of URNs be?
    /** URN of a transform materialized as a {@link BeamFnDataReadRunner}. */
    private static final String DATA_INPUT_URN = "urn:org.apache.beam:source:runner:0.1";
    /** URN of a transform materialized as a {@link BeamFnDataWriteRunner}. */
    private static final String DATA_OUTPUT_URN = "urn:org.apache.beam:sink:runner:0.1";
    /** URN of a transform whose payload is a serialized {@link DoFnInfo}, run via a {@link DoFnRunner}. */
    private static final String JAVA_DO_FN_URN = "urn:org.apache.beam:dofn:java:0.1";
    /** URN of a transform materialized as a {@link BoundedSourceRunner}. */
    private static final String JAVA_SOURCE_URN = "urn:org.apache.beam:source:java:0.1";

    private static final Logger LOG = LoggerFactory.getLogger(ProcessBundleHandler.class);

    private final PipelineOptions options;
    private final Function<String, Message> fnApiRegistry;
    private final BeamFnDataClient beamFnDataClient;

    /**
     * Creates a handler.
     *
     * @param options pipeline options passed to every {@link BoundedSourceRunner} this handler creates
     * @param fnApiRegistry resolves a reference (a process bundle descriptor reference or a coder
     *     reference) to the registered Fn API message
     * @param beamFnDataClient data client handed to the data read and write runners
     */
    public ProcessBundleHandler(PipelineOptions options, Function<String, Message> fnApiRegistry,
            BeamFnDataClient beamFnDataClient) {
        this.options = options;
        this.fnApiRegistry = fnApiRegistry;
        this.beamFnDataClient = beamFnDataClient;
    }

    /**
     * Materializes the runner for a single primitive transform and wires it into the bundle.
     *
     * <p>The runner is selected by the transform's {@link BeamFnApi.FunctionSpec} URN. Its start and
     * finish functions are reported through {@code addStartFunction} / {@code addFinishFunction}.
     * If the runner accepts elements, its consumer is registered for every input target of the
     * transform through {@code addConsumer}.
     *
     * @param primitiveTransform the transform to materialize
     * @param processBundleInstructionId supplies the instruction id; passed to the data read/write runners
     * @param consumers returns the consumers registered for a given output target
     * @param addConsumer registers a consumer for an input target
     * @param addStartFunction receives the runner's start function, if any
     * @param addFinishFunction receives the runner's finish function, if any
     * @throws IllegalArgumentException if the function spec URN is not one this handler knows
     */
    protected <InputT, OutputT> void createConsumersForPrimitiveTransform(
            BeamFnApi.PrimitiveTransform primitiveTransform, Supplier<String> processBundleInstructionId,
            Function<BeamFnApi.Target, Collection<ThrowingConsumer<WindowedValue<OutputT>>>> consumers,
            BiConsumer<BeamFnApi.Target, ThrowingConsumer<WindowedValue<InputT>>> addConsumer,
            Consumer<ThrowingRunnable> addStartFunction, Consumer<ThrowingRunnable> addFinishFunction)
            throws IOException {

        BeamFnApi.FunctionSpec functionSpec = primitiveTransform.getFunctionSpec();

        // For every output PCollection, map the output name to the consumers of that output.
        ImmutableMap.Builder<String, Collection<ThrowingConsumer<WindowedValue<OutputT>>>> outputMapBuilder =
                ImmutableMap.builder();
        for (Map.Entry<String, BeamFnApi.PCollection> entry : primitiveTransform.getOutputsMap().entrySet()) {
            BeamFnApi.Target outputTarget = BeamFnApi.Target.newBuilder()
                    .setPrimitiveTransformReference(primitiveTransform.getId())
                    .setName(entry.getKey())
                    .build();
            outputMapBuilder.put(entry.getKey(), consumers.apply(outputTarget));
        }
        ImmutableMap<String, Collection<ThrowingConsumer<WindowedValue<OutputT>>>> outputMap =
                outputMapBuilder.build();

        // Based upon the function spec, populate the start/finish/consumer information.
        // A null consumer means the runner does not accept elements from upstream transforms.
        ThrowingConsumer<WindowedValue<InputT>> consumer;
        switch (functionSpec.getUrn()) {
        case DATA_OUTPUT_URN: {
            BeamFnApi.Target target = BeamFnApi.Target.newBuilder()
                    .setPrimitiveTransformReference(primitiveTransform.getId())
                    .setName(getOnlyElement(primitiveTransform.getOutputsMap().keySet()))
                    .build();
            BeamFnApi.Coder coderSpec = (BeamFnApi.Coder) fnApiRegistry.apply(
                    getOnlyElement(primitiveTransform.getOutputsMap().values()).getCoderReference());
            BeamFnDataWriteRunner<InputT> remoteGrpcWriteRunner = new BeamFnDataWriteRunner<>(functionSpec,
                    processBundleInstructionId, target, coderSpec, beamFnDataClient);
            addStartFunction.accept(remoteGrpcWriteRunner::registerForOutput);
            consumer = remoteGrpcWriteRunner::consume;
            addFinishFunction.accept(remoteGrpcWriteRunner::close);
            break;
        }

        case DATA_INPUT_URN: {
            BeamFnApi.Target target = BeamFnApi.Target.newBuilder()
                    .setPrimitiveTransformReference(primitiveTransform.getId())
                    .setName(getOnlyElement(primitiveTransform.getInputsMap().keySet()))
                    .build();
            BeamFnApi.Coder coderSpec = (BeamFnApi.Coder) fnApiRegistry.apply(
                    getOnlyElement(primitiveTransform.getOutputsMap().values()).getCoderReference());
            BeamFnDataReadRunner<OutputT> remoteGrpcReadRunner = new BeamFnDataReadRunner<>(functionSpec,
                    processBundleInstructionId, target, coderSpec, beamFnDataClient, outputMap);
            addStartFunction.accept(remoteGrpcReadRunner::registerInputLocation);
            // The read runner only emits into outputMap; it has nothing to register as a consumer.
            consumer = null;
            addFinishFunction.accept(remoteGrpcReadRunner::blockTillReadFinishes);
            break;
        }

        case JAVA_DO_FN_URN: {
            DoFnRunner<InputT, OutputT> doFnRunner = createDoFnRunner(functionSpec, outputMap);
            addStartFunction.accept(doFnRunner::startBundle);
            addFinishFunction.accept(doFnRunner::finishBundle);
            consumer = doFnRunner::processElement;
            break;
        }

        case JAVA_SOURCE_URN: {
            @SuppressWarnings({ "unchecked", "rawtypes" })
            BoundedSourceRunner<BoundedSource<OutputT>, OutputT> sourceRunner = createBoundedSourceRunner(
                    functionSpec, outputMap);
            // The source runner consumes WindowedValue<BoundedSource<OutputT>>, not WindowedValue<InputT>;
            // the raw cast bridges the two type parameters, so keep it under the suppression.
            @SuppressWarnings({ "unchecked", "rawtypes" })
            ThrowingConsumer<WindowedValue<InputT>> sourceConsumer = (ThrowingConsumer)
                    (ThrowingConsumer<WindowedValue<BoundedSource<OutputT>>>) sourceRunner::runReadLoop;
            // TODO: Remove and replace with source being sent across gRPC port
            addStartFunction.accept(sourceRunner::start);
            consumer = sourceConsumer;
            break;
        }

        default:
            throw new IllegalArgumentException(String.format("Unknown FunctionSpec %s", functionSpec));
        }

        if (consumer != null) {
            for (Map.Entry<String, BeamFnApi.Target.List> entry : primitiveTransform.getInputsMap().entrySet()) {
                for (BeamFnApi.Target upstreamTarget : entry.getValue().getTargetList()) {
                    addConsumer.accept(upstreamTarget, consumer);
                }
            }
        }
    }

    /**
     * Executes the process bundle request.
     *
     * <p>Looks up the referenced {@link BeamFnApi.ProcessBundleDescriptor} and materializes a runner
     * for each of its primitive transforms. It then runs every start function followed by every
     * finish function.
     *
     * @param request the instruction whose process bundle descriptor reference is executed
     * @return a response builder carrying a default {@link BeamFnApi.ProcessBundleResponse}
     * @throws Exception propagated from runner creation or from any start or finish function
     */
    public BeamFnApi.InstructionResponse.Builder processBundle(BeamFnApi.InstructionRequest request)
            throws Exception {
        BeamFnApi.InstructionResponse.Builder response = BeamFnApi.InstructionResponse.newBuilder()
                .setProcessBundle(BeamFnApi.ProcessBundleResponse.getDefaultInstance());

        String bundleId = request.getProcessBundle().getProcessBundleDescriptorReference();
        BeamFnApi.ProcessBundleDescriptor bundleDescriptor =
                (BeamFnApi.ProcessBundleDescriptor) fnApiRegistry.apply(bundleId);

        Multimap<BeamFnApi.Target, ThrowingConsumer<WindowedValue<Object>>> outputTargetToConsumer =
                HashMultimap.create();
        List<ThrowingRunnable> startFunctions = new ArrayList<>();
        List<ThrowingRunnable> finishFunctions = new ArrayList<>();
        // We process the primitive transform list in reverse order
        // because we assume that the runner provides it in topological order.
        // This means that all the start/finish functions will be in reverse topological order.
        for (BeamFnApi.PrimitiveTransform primitiveTransform
                : Lists.reverse(bundleDescriptor.getPrimitiveTransformList())) {
            createConsumersForPrimitiveTransform(primitiveTransform, request::getInstructionId,
                    outputTargetToConsumer::get, outputTargetToConsumer::put, startFunctions::add,
                    finishFunctions::add);
        }

        // Already in reverse topological order so we don't need to do anything.
        for (ThrowingRunnable startFunction : startFunctions) {
            LOG.debug("Starting function {}", startFunction);
            startFunction.run();
        }

        // Need to reverse this since we want to call finish in topological order.
        for (ThrowingRunnable finishFunction : Lists.reverse(finishFunctions)) {
            LOG.debug("Finishing function {}", finishFunction);
            finishFunction.run();
        }

        return response;
    }

    /**
     * Converts a {@link org.apache.beam.fn.v1.BeamFnApi.FunctionSpec} into a {@link DoFnRunner}.
     *
     * @param functionSpec spec whose data is a {@link BytesValue} holding a serialized {@link DoFnInfo}
     * @param outputMap consumers per output, keyed by the decimal string of the DoFnInfo output id
     * @return a runner that forwards each output to the consumers registered for its tag
     * @throws IllegalArgumentException if the payload cannot be unpacked, or if the output ids in
     *     {@code outputMap} do not match those in the DoFnInfo
     */
    private <InputT, OutputT> DoFnRunner<InputT, OutputT> createDoFnRunner(BeamFnApi.FunctionSpec functionSpec,
            Map<String, Collection<ThrowingConsumer<WindowedValue<OutputT>>>> outputMap) {
        ByteString serializedFn;
        try {
            serializedFn = functionSpec.getData().unpack(BytesValue.class).getValue();
        } catch (InvalidProtocolBufferException e) {
            throw new IllegalArgumentException(String.format("Unable to unwrap DoFn %s", functionSpec), e);
        }
        // NOTE(review): Java native deserialization of a runner-supplied payload; this assumes the
        // runner (the source of the ProcessBundleDescriptor) is trusted.
        DoFnInfo<?, ?> doFnInfo = (DoFnInfo<?, ?>) SerializableUtils
                .deserializeFromByteArray(serializedFn.toByteArray(), "DoFnInfo");

        checkArgument(
                Objects.equals(new HashSet<>(Collections2.transform(outputMap.keySet(), Long::parseLong)),
                        doFnInfo.getOutputMap().keySet()),
                "Unexpected mismatch between transform output map %s and DoFnInfo output map %s.",
                outputMap.keySet(), doFnInfo.getOutputMap());

        // Re-key the consumers from the numeric output id to the DoFn's TupleTag.
        ImmutableMultimap.Builder<TupleTag<?>, ThrowingConsumer<WindowedValue<OutputT>>> tagToOutput =
                ImmutableMultimap.builder();
        for (Map.Entry<Long, TupleTag<?>> entry : doFnInfo.getOutputMap().entrySet()) {
            tagToOutput.putAll(entry.getValue(), outputMap.get(Long.toString(entry.getKey())));
        }
        @SuppressWarnings({ "unchecked", "rawtypes" })
        final Map<TupleTag<?>, Collection<ThrowingConsumer<WindowedValue<?>>>> tagBasedOutputMap =
                (Map) tagToOutput.build().asMap();

        OutputManager outputManager = new OutputManager() {
            @Override
            public <T> void output(TupleTag<T> tag, WindowedValue<T> output) {
                Collection<ThrowingConsumer<WindowedValue<?>>> consumers = tagBasedOutputMap.get(tag);
                if (consumers == null) {
                    /* This is a normal case, e.g., if a DoFn has output but that output is not
                     * consumed. Drop the output. */
                    return;
                }
                try {
                    for (ThrowingConsumer<WindowedValue<?>> consumer : consumers) {
                        consumer.accept(output);
                    }
                } catch (RuntimeException | Error e) {
                    // Unchecked failures propagate as-is rather than being wrapped (Errors in particular
                    // must not be disguised as RuntimeExceptions).
                    throw e;
                } catch (Throwable t) {
                    throw new RuntimeException(t);
                }
            }
        };

        // NOTE(review): a fresh default PipelineOptions is used here instead of this handler's
        // `options` (which the BoundedSourceRunner receives); confirm whether that is intended.
        @SuppressWarnings({ "unchecked", "rawtypes", "deprecation" })
        DoFnRunner<InputT, OutputT> runner = DoFnRunners.simpleRunner(PipelineOptionsFactory.create(), /* TODO */
                (DoFn) doFnInfo.getDoFn(), NullSideInputReader.empty(), /* TODO */
                outputManager, (TupleTag) doFnInfo.getOutputMap().get(doFnInfo.getMainOutput()),
                new ArrayList<>(doFnInfo.getOutputMap().values()), new FakeStepContext(),
                (WindowingStrategy) doFnInfo.getWindowingStrategy());
        return runner;
    }

    /**
     * Creates a {@link BoundedSourceRunner} for the given spec, using this handler's pipeline options.
     *
     * @param functionSpec the source transform's function spec
     * @param outputMap consumers per output name
     * @return the source runner
     */
    private <InputT extends BoundedSource<OutputT>, OutputT> BoundedSourceRunner<InputT, OutputT> createBoundedSourceRunner(
            BeamFnApi.FunctionSpec functionSpec,
            Map<String, Collection<ThrowingConsumer<WindowedValue<OutputT>>>> outputMap) {

        @SuppressWarnings({ "rawtypes", "unchecked" })
        BoundedSourceRunner<InputT, OutputT> runner = new BoundedSourceRunner(options, functionSpec, outputMap);
        return runner;
    }
}