org.apache.beam.runners.core.construction.ParDos.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.beam.runners.core.construction.ParDos.java

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.beam.runners.core.construction;

import static com.google.common.base.Preconditions.checkArgument;

import com.google.auto.value.AutoValue;
import com.google.common.base.Optional;
import com.google.protobuf.Any;
import com.google.protobuf.ByteString;
import com.google.protobuf.BytesValue;
import com.google.protobuf.InvalidProtocolBufferException;
import java.io.IOException;
import java.io.Serializable;
import java.util.List;
import java.util.Map;
import org.apache.beam.runners.core.construction.PTransforms.TransformPayloadTranslator;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.IterableCoder;
import org.apache.beam.sdk.common.runner.v1.RunnerApi;
import org.apache.beam.sdk.common.runner.v1.RunnerApi.Components;
import org.apache.beam.sdk.common.runner.v1.RunnerApi.FunctionSpec;
import org.apache.beam.sdk.common.runner.v1.RunnerApi.ParDoPayload;
import org.apache.beam.sdk.common.runner.v1.RunnerApi.Parameter.Type;
import org.apache.beam.sdk.common.runner.v1.RunnerApi.SdkFunctionSpec;
import org.apache.beam.sdk.common.runner.v1.RunnerApi.SideInput;
import org.apache.beam.sdk.common.runner.v1.RunnerApi.SideInput.Builder;
import org.apache.beam.sdk.common.runner.v1.RunnerApi.StateSpec;
import org.apache.beam.sdk.common.runner.v1.RunnerApi.TimerSpec;
import org.apache.beam.sdk.runners.AppliedPTransform;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.Materializations;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.ParDo.MultiOutput;
import org.apache.beam.sdk.transforms.ViewFn;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.Cases;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.RestrictionTrackerParameter;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.WindowParameter;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature.StateDeclaration;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature.TimerDeclaration;
import org.apache.beam.sdk.transforms.reflect.DoFnSignatures;
import org.apache.beam.sdk.transforms.windowing.WindowMappingFn;
import org.apache.beam.sdk.util.SerializableUtils;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.util.WindowedValue.FullWindowedValueCoder;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.WindowingStrategy;

/**
 * Utilities for interacting with {@link ParDo} instances and {@link ParDoPayload} protos.
 *
 * <p>User-defined functions ({@link DoFn}, {@link ViewFn} and {@link WindowMappingFn}) are written
 * into the proto as the bytes produced by {@link SerializableUtils#serializeToByteArray}, wrapped
 * in a {@link BytesValue} and tagged with one of the {@code CUSTOM_JAVA_*_URN} constants declared
 * below.
 */
public class ParDos {
    /**
     * The URN for a {@link ParDoPayload}.
     */
    public static final String PAR_DO_PAYLOAD_URN = "urn:beam:pardo:v1";
    /**
     * The URN for an unknown Java {@link DoFn}.
     */
    public static final String CUSTOM_JAVA_DO_FN_URN = "urn:beam:dofn:javasdk:0.1";
    /**
     * The URN for an unknown Java {@link ViewFn}.
     */
    public static final String CUSTOM_JAVA_VIEW_FN_URN = "urn:beam:viewfn:javasdk:0.1";
    /**
     * The URN for an unknown Java {@link WindowMappingFn}.
     */
    public static final String CUSTOM_JAVA_WINDOW_MAPPING_FN_URN = "urn:beam:windowmappingfn:javasdk:0.1";

    /**
     * A {@link TransformPayloadTranslator} for {@link ParDo}.
     */
    public static class ParDoPayloadTranslator
            implements PTransforms.TransformPayloadTranslator<ParDo.MultiOutput<?, ?>> {
        /**
         * Returns a new translator instance.
         *
         * <p>The return type is intentionally left as the raw {@link TransformPayloadTranslator} so
         * that existing callers keep compiling unchanged.
         */
        public static TransformPayloadTranslator create() {
            return new ParDoPayloadTranslator();
        }

        private ParDoPayloadTranslator() {
        }

        /**
         * Wraps the {@link ParDoPayload} of the applied transform in a {@link FunctionSpec} whose
         * URN is {@link #PAR_DO_PAYLOAD_URN}.
         *
         * @param transform the applied {@link ParDo.MultiOutput} to translate
         * @param components forwarded to {@link ParDos#toProto(ParDo.MultiOutput, SdkComponents)}
         * @return the function spec carrying the packed payload
         */
        @Override
        public FunctionSpec translate(AppliedPTransform<?, ?, MultiOutput<?, ?>> transform,
                SdkComponents components) {
            ParDoPayload payload = toProto(transform.getTransform(), components);
            return RunnerApi.FunctionSpec.newBuilder().setUrn(PAR_DO_PAYLOAD_URN).setParameter(Any.pack(payload))
                    .build();
        }
    }

    /**
     * Converts a {@link ParDo.MultiOutput} into its {@link ParDoPayload} representation.
     *
     * <p>The payload contains the serialized {@link DoFn} together with its main output tag, one
     * {@link SideInput} per side input (keyed by the id of the view's internal tag), and the extra
     * {@code ProcessElement} parameters that have a proto representation.
     *
     * @param parDo the transform to convert
     * @param components currently not read by this method
     * @return the populated payload
     * @throws UnsupportedOperationException if the {@link DoFn} declares any state or timers, since
     *     their translation is not implemented yet
     */
    public static ParDoPayload toProto(ParDo.MultiOutput<?, ?> parDo, SdkComponents components) {
        DoFnSignature signature = DoFnSignatures.getSignature(parDo.getFn().getClass());
        Map<String, StateDeclaration> states = signature.stateDeclarations();
        Map<String, TimerDeclaration> timers = signature.timerDeclarations();
        List<Parameter> parameters = signature.processElement().extraParameters();

        ParDoPayload.Builder builder = ParDoPayload.newBuilder();
        builder.setDoFn(toProto(parDo.getFn(), parDo.getMainOutputTag()));
        for (PCollectionView<?> sideInput : parDo.getSideInputs()) {
            builder.putSideInputs(sideInput.getTagInternal().getId(), toProto(sideInput));
        }
        for (Parameter parameter : parameters) {
            // Only parameters with a proto counterpart are recorded; the rest are skipped.
            Optional<RunnerApi.Parameter> protoParameter = toProto(parameter);
            if (protoParameter.isPresent()) {
                builder.addParameters(protoParameter.get());
            }
        }
        for (Map.Entry<String, StateDeclaration> state : states.entrySet()) {
            builder.putStateSpecs(state.getKey(), toProto(state.getValue()));
        }
        for (Map.Entry<String, TimerDeclaration> timer : timers.entrySet()) {
            builder.putTimerSpecs(timer.getKey(), toProto(timer.getValue()));
        }
        return builder.build();
    }

    /**
     * Extracts the {@link DoFn} stored in the given payload.
     *
     * @param payload a payload whose DoFn spec uses {@link #CUSTOM_JAVA_DO_FN_URN}
     * @return the deserialized {@link DoFn}
     * @throws InvalidProtocolBufferException if the spec parameter cannot be unpacked as a
     *     {@link BytesValue}
     * @throws IllegalArgumentException if the DoFn spec has a different URN
     */
    public static DoFn<?, ?> getDoFn(ParDoPayload payload) throws InvalidProtocolBufferException {
        return doFnAndMainOutputTagFromProto(payload.getDoFn()).getDoFn();
    }

    /**
     * Extracts the main output {@link TupleTag} stored alongside the {@link DoFn} in the payload.
     *
     * @param payload a payload whose DoFn spec uses {@link #CUSTOM_JAVA_DO_FN_URN}
     * @return the deserialized main output tag
     * @throws InvalidProtocolBufferException if the spec parameter cannot be unpacked as a
     *     {@link BytesValue}
     * @throws IllegalArgumentException if the DoFn spec has a different URN
     */
    public static TupleTag<?> getMainOutputTag(ParDoPayload payload) throws InvalidProtocolBufferException {
        return doFnAndMainOutputTagFromProto(payload.getDoFn()).getMainOutputTag();
    }

    // TODO: Implement
    /** Placeholder: state declarations cannot be translated yet, so this always throws. */
    private static StateSpec toProto(StateDeclaration state) {
        throw new UnsupportedOperationException("Not yet supported");
    }

    // TODO: Implement
    /** Placeholder: timer declarations cannot be translated yet, so this always throws. */
    private static TimerSpec toProto(TimerDeclaration timer) {
        throw new UnsupportedOperationException("Not yet supported");
    }

    /**
     * Serializable pair of a {@link DoFn} and its main output {@link TupleTag}; this is the object
     * that is serialized into the DoFn spec of a {@link ParDoPayload}.
     */
    @AutoValue
    abstract static class DoFnAndMainOutput implements Serializable {
        public static DoFnAndMainOutput of(DoFn<?, ?> fn, TupleTag<?> tag) {
            return new AutoValue_ParDos_DoFnAndMainOutput(fn, tag);
        }

        abstract DoFn<?, ?> getDoFn();

        abstract TupleTag<?> getMainOutputTag();
    }

    /** Serializes the {@link DoFn} and its main output tag under {@link #CUSTOM_JAVA_DO_FN_URN}. */
    private static SdkFunctionSpec toProto(DoFn<?, ?> fn, TupleTag<?> tag) {
        return javaSerializedSpec(CUSTOM_JAVA_DO_FN_URN, DoFnAndMainOutput.of(fn, tag));
    }

    /**
     * Reads back the {@link DoFnAndMainOutput} written by {@link #toProto(DoFn, TupleTag)}.
     *
     * @throws IllegalArgumentException if the spec URN is not {@link #CUSTOM_JAVA_DO_FN_URN}
     */
    private static DoFnAndMainOutput doFnAndMainOutputTagFromProto(SdkFunctionSpec fnSpec)
            throws InvalidProtocolBufferException {
        FunctionSpec spec = fnSpec.getSpec();
        // Same message template as the ViewFn / WindowMappingFn checks below.
        checkArgument(spec.getUrn().equals(CUSTOM_JAVA_DO_FN_URN), "Can't deserialize unknown %s type %s",
                DoFn.class.getSimpleName(), spec.getUrn());
        return (DoFnAndMainOutput) deserializeJavaPayload(spec, "Custom DoFn And Main Output tag");
    }

    /**
     * Maps a {@code ProcessElement} extra parameter to its proto form.
     *
     * @return the proto parameter for window and restriction-tracker parameters, or
     *     {@link Optional#absent()} for every other parameter kind
     */
    private static Optional<RunnerApi.Parameter> toProto(Parameter parameter) {
        return parameter.match(new Cases.WithDefault<Optional<RunnerApi.Parameter>>() {
            @Override
            public Optional<RunnerApi.Parameter> dispatch(WindowParameter p) {
                return Optional.of(RunnerApi.Parameter.newBuilder().setType(Type.WINDOW).build());
            }

            @Override
            public Optional<RunnerApi.Parameter> dispatch(RestrictionTrackerParameter p) {
                return Optional.of(RunnerApi.Parameter.newBuilder().setType(Type.RESTRICTION_TRACKER).build());
            }

            @Override
            protected Optional<RunnerApi.Parameter> dispatchDefault(Parameter p) {
                return Optional.absent();
            }
        });
    }

    /**
     * Builds the {@link SideInput} proto for a view: the access pattern URN comes from the view
     * function's materialization, and the view function and window mapping function are embedded
     * as serialized specs.
     */
    private static SideInput toProto(PCollectionView<?> view) {
        Builder builder = SideInput.newBuilder();
        builder.setAccessPattern(
                FunctionSpec.newBuilder().setUrn(view.getViewFn().getMaterialization().getUrn()).build());
        builder.setViewFn(toProto(view.getViewFn()));
        builder.setWindowMappingFn(toProto(view.getWindowMappingFn()));
        return builder.build();
    }

    /**
     * Reconstructs a {@link PCollectionView} from a {@link SideInput} proto.
     *
     * @param sideInput the side input description; its access pattern must be
     *     {@link Materializations#ITERABLE_MATERIALIZATION_URN}
     * @param id the side input id; used both as the id of the view's tag and as the key of the
     *     corresponding input of {@code parDoTransform}
     * @param parDoTransform the transform that consumes the side input
     * @param components the components used to resolve the input PCollection, its windowing
     *     strategy and its coder
     * @return the reconstructed view
     * @throws IllegalArgumentException if the access pattern, view function or window mapping
     *     function carries an unsupported URN
     * @throws IOException declared by this method; the visible calls declare
     *     {@link InvalidProtocolBufferException}, other sources are not visible here
     */
    public static PCollectionView<?> fromProto(SideInput sideInput, String id, RunnerApi.PTransform parDoTransform,
            Components components) throws IOException {
        // Validate the materialization first so that an unsupported side input is rejected before
        // any serialized function payload is deserialized.
        checkArgument(sideInput.getAccessPattern().getUrn().equals(Materializations.ITERABLE_MATERIALIZATION_URN),
                "Unknown View Materialization URN %s", sideInput.getAccessPattern().getUrn());

        TupleTag<?> tag = new TupleTag<>(id);
        WindowMappingFn<?> windowMappingFn = windowMappingFnFromProto(sideInput.getWindowMappingFn());
        ViewFn<?, ?> viewFn = viewFnFromProto(sideInput.getViewFn());

        RunnerApi.PCollection inputCollection = components
                .getPcollectionsOrThrow(parDoTransform.getInputsOrThrow(id));
        WindowingStrategy<?, ?> windowingStrategy = WindowingStrategies.fromProto(
                components.getWindowingStrategiesOrThrow(inputCollection.getWindowingStrategyId()), components);
        Coder<?> elemCoder = Coders.fromProto(components.getCodersOrThrow(inputCollection.getCoderId()),
                components);
        // The element and window types are only known as wildcards here, so the coder is built
        // through a raw cast.
        @SuppressWarnings({ "unchecked", "rawtypes" })
        Coder<Iterable<WindowedValue<?>>> coder = (Coder) IterableCoder
                .of(FullWindowedValueCoder.of(elemCoder, windowingStrategy.getWindowFn().windowCoder()));

        // The tag and view function were deserialized as wildcard types; the casts line them up
        // with the iterable-of-windowed-values shape checked via the access pattern above.
        @SuppressWarnings("unchecked")
        PCollectionView<?> view = new RunnerPCollectionView<>((TupleTag<Iterable<WindowedValue<?>>>) tag,
                (ViewFn<Iterable<WindowedValue<?>>, ?>) viewFn, windowMappingFn, windowingStrategy, coder);
        return view;
    }

    /** Serializes a {@link ViewFn} under {@link #CUSTOM_JAVA_VIEW_FN_URN}. */
    private static SdkFunctionSpec toProto(ViewFn<?, ?> viewFn) {
        return javaSerializedSpec(CUSTOM_JAVA_VIEW_FN_URN, viewFn);
    }

    /**
     * Reads back a {@link ViewFn} written by {@link #toProto(ViewFn)}.
     *
     * @throws IllegalArgumentException if the spec URN is not {@link #CUSTOM_JAVA_VIEW_FN_URN}
     */
    private static ViewFn<?, ?> viewFnFromProto(SdkFunctionSpec viewFn) throws InvalidProtocolBufferException {
        FunctionSpec spec = viewFn.getSpec();
        checkArgument(spec.getUrn().equals(CUSTOM_JAVA_VIEW_FN_URN), "Can't deserialize unknown %s type %s",
                ViewFn.class.getSimpleName(), spec.getUrn());
        return (ViewFn<?, ?>) deserializeJavaPayload(spec, "Custom ViewFn");
    }

    /** Serializes a {@link WindowMappingFn} under {@link #CUSTOM_JAVA_WINDOW_MAPPING_FN_URN}. */
    private static SdkFunctionSpec toProto(WindowMappingFn<?> windowMappingFn) {
        return javaSerializedSpec(CUSTOM_JAVA_WINDOW_MAPPING_FN_URN, windowMappingFn);
    }

    /**
     * Reads back a {@link WindowMappingFn} written by {@link #toProto(WindowMappingFn)}.
     *
     * @throws IllegalArgumentException if the spec URN is not
     *     {@link #CUSTOM_JAVA_WINDOW_MAPPING_FN_URN}
     */
    private static WindowMappingFn<?> windowMappingFnFromProto(SdkFunctionSpec windowMappingFn)
            throws InvalidProtocolBufferException {
        FunctionSpec spec = windowMappingFn.getSpec();
        checkArgument(spec.getUrn().equals(CUSTOM_JAVA_WINDOW_MAPPING_FN_URN),
                "Can't deserialize unknown %s type %s", WindowMappingFn.class.getSimpleName(), spec.getUrn());
        return (WindowMappingFn<?>) deserializeJavaPayload(spec, "Custom WindowMappingFn");
    }

    /**
     * Builds an {@link SdkFunctionSpec} with the given URN whose parameter is a {@link BytesValue}
     * holding the bytes of {@code value} as produced by
     * {@link SerializableUtils#serializeToByteArray}.
     */
    private static SdkFunctionSpec javaSerializedSpec(String urn, Serializable value) {
        ByteString serialized = ByteString.copyFrom(SerializableUtils.serializeToByteArray(value));
        FunctionSpec spec = FunctionSpec.newBuilder()
                .setUrn(urn)
                .setParameter(Any.pack(BytesValue.newBuilder().setValue(serialized).build()))
                .build();
        return SdkFunctionSpec.newBuilder().setSpec(spec).build();
    }

    /**
     * Unpacks the {@link BytesValue} parameter of {@code spec} and deserializes its bytes.
     *
     * <p>Callers are expected to have checked the spec URN already.
     *
     * <p>NOTE(review): this deserializes bytes taken straight from the proto, presumably with
     * Java-native deserialization; confirm that payloads only ever come from trusted sources.
     *
     * @param spec the function spec carrying the serialized object
     * @param description passed through to {@link SerializableUtils#deserializeFromByteArray};
     *     presumably used to describe the object in its error messages
     * @throws InvalidProtocolBufferException if the parameter is not a {@link BytesValue}
     */
    private static Object deserializeJavaPayload(FunctionSpec spec, String description)
            throws InvalidProtocolBufferException {
        byte[] serialized = spec.getParameter().unpack(BytesValue.class).getValue().toByteArray();
        return SerializableUtils.deserializeFromByteArray(serialized, description);
    }
}