org.apache.beam.runners.direct.ExecutorServiceParallelExecutor.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.beam.runners.direct.ExecutorServiceParallelExecutor.java

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.beam.runners.direct;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Map;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.beam.runners.local.ExecutionDriver;
import org.apache.beam.runners.local.ExecutionDriver.DriverState;
import org.apache.beam.runners.local.PipelineMessageReceiver;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.PipelineResult.State;
import org.apache.beam.sdk.runners.AppliedPTransform;
import org.apache.beam.sdk.util.UserCodeException;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PValue;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Optional;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.CacheBuilder;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.CacheLoader;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.LoadingCache;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.RemovalListener;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.util.concurrent.MoreExecutors;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.util.concurrent.ThreadFactoryBuilder;
import org.joda.time.Duration;
import org.joda.time.Instant;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A {@link PipelineExecutor} that uses an underlying {@link ExecutorService} and {@link
 * EvaluationContext} to execute a {@link Pipeline}.
 *
 * <p>Bundles of keyed {@link PCollection PCollections} are executed serially per (step, key) using
 * a cache of serial {@link TransformExecutorService TransformExecutorServices}; all other bundles are
 * scheduled on a shared parallel {@link TransformExecutorService}. Both are backed by a fixed-size
 * worker pool of {@code targetParallelism} threads.
 */
final class ExecutorServiceParallelExecutor implements PipelineExecutor,
        BundleProcessor<PCollection<?>, CommittedBundle<?>, AppliedPTransform<?, ?, ?>> {
    private static final Logger LOG = LoggerFactory.getLogger(ExecutorServiceParallelExecutor.class);

    private final int targetParallelism;
    /** Fixed-size worker pool backing both the parallel and the per-key serial executors. */
    private final ExecutorService executorService;

    private final TransformEvaluatorRegistry registry;

    private final EvaluationContext evaluationContext;

    private final TransformExecutorFactory executorFactory;
    private final TransformExecutorService parallelExecutorService;
    private final LoadingCache<StepAndKey, TransformExecutorService> serialExecutorServices;

    /** Collects notifications that {@link #waitUntilFinish} consumes. */
    private final QueueMessageReceiver visibleUpdates;

    private final ExecutorService metricsExecutor;

    /** Current pipeline state; only ever moves from RUNNING to a terminal state, at most once. */
    private final AtomicReference<State> pipelineState = new AtomicReference<>(State.RUNNING);

    /**
     * Creates a new executor.
     *
     * @param targetParallelism number of worker threads in the underlying pool
     * @param registry registry of transform evaluators; cleaned up when the pipeline terminates
     * @param transformEnforcements model enforcements to apply, passed to the transform executor factory
     * @param context evaluation context for the pipeline
     * @param metricsExecutor executor shut down together with this executor
     * @return a new, not yet started, executor
     */
    public static ExecutorServiceParallelExecutor create(int targetParallelism, TransformEvaluatorRegistry registry,
            Map<String, Collection<ModelEnforcementFactory>> transformEnforcements, EvaluationContext context,
            ExecutorService metricsExecutor) {
        return new ExecutorServiceParallelExecutor(targetParallelism, registry, transformEnforcements, context,
                metricsExecutor);
    }

    private ExecutorServiceParallelExecutor(int targetParallelism, TransformEvaluatorRegistry registry,
            Map<String, Collection<ModelEnforcementFactory>> transformEnforcements, EvaluationContext context,
            ExecutorService metricsExecutor) {
        this.targetParallelism = targetParallelism;
        this.metricsExecutor = metricsExecutor;
        // Don't use Daemon threads for workers. The Pipeline should continue to execute even if there
        // are no other active threads (for example, because waitUntilFinish was not called)
        this.executorService = Executors.newFixedThreadPool(targetParallelism,
                new ThreadFactoryBuilder().setThreadFactory(MoreExecutors.platformThreadFactory())
                        .setNameFormat("direct-runner-worker").build());
        this.registry = registry;
        this.evaluationContext = context;

        // Weak Values allows TransformExecutorServices that are no longer in use to be reclaimed.
        // Executing TransformExecutorServices have a strong reference to their TransformExecutorService
        // which stops the TransformExecutorServices from being prematurely garbage collected
        serialExecutorServices = CacheBuilder.newBuilder().weakValues()
                .removalListener(shutdownExecutorServiceListener())
                .build(serialTransformExecutorServiceCacheLoader());

        this.visibleUpdates = new QueueMessageReceiver();

        parallelExecutorService = TransformExecutorServices.parallel(executorService);
        executorFactory = new DirectTransformExecutor.Factory(context, registry, transformEnforcements);
    }

    /** Loader that creates a fresh serial executor, on the shared worker pool, per (step, key). */
    private CacheLoader<StepAndKey, TransformExecutorService> serialTransformExecutorServiceCacheLoader() {
        return new CacheLoader<StepAndKey, TransformExecutorService>() {
            @Override
            public TransformExecutorService load(StepAndKey stepAndKey) throws Exception {
                return TransformExecutorServices.serial(executorService);
            }
        };
    }

    /** Shuts down a serial executor when it is evicted from the cache (if it was not collected). */
    private RemovalListener<StepAndKey, TransformExecutorService> shutdownExecutorServiceListener() {
        return notification -> {
            TransformExecutorService service = notification.getValue();
            // Weakly-held values may already have been garbage collected.
            if (service != null) {
                service.shutdown();
            }
        };
    }

    /**
     * Computes the initial root bundles and starts driving the pipeline on the worker pool.
     *
     * @throws UserCodeException wrapping any failure from a root provider's initial-input computation
     */
    @Override
    // TODO: [BEAM-4563] Pass Future back to consumer to check for async errors
    @SuppressWarnings("FutureReturnValueIgnored")
    public void start(DirectGraph graph, RootProviderRegistry rootProviderRegistry) {
        // Request at least 3 splits per root, or one per worker thread when there are more workers.
        int numTargetSplits = Math.max(3, targetParallelism);
        ImmutableMap.Builder<AppliedPTransform<?, ?, ?>, ConcurrentLinkedQueue<CommittedBundle<?>>> pendingRootBundlesBuilder = ImmutableMap
                .builder();
        for (AppliedPTransform<?, ?, ?> root : graph.getRootTransforms()) {
            ConcurrentLinkedQueue<CommittedBundle<?>> pending = new ConcurrentLinkedQueue<>();
            try {
                Collection<CommittedBundle<?>> initialInputs = rootProviderRegistry.getInitialInputs(root,
                        numTargetSplits);
                pending.addAll(initialInputs);
            } catch (Exception e) {
                throw UserCodeException.wrap(e);
            }
            pendingRootBundlesBuilder.put(root, pending);
        }
        // Build once so the evaluation context and the driver share the same map instance.
        ImmutableMap<AppliedPTransform<?, ?, ?>, ConcurrentLinkedQueue<CommittedBundle<?>>> pendingRootBundles = pendingRootBundlesBuilder
                .build();
        evaluationContext.initialize(pendingRootBundles);
        final ExecutionDriver executionDriver = QuiescenceDriver.create(evaluationContext, graph, this,
                visibleUpdates, pendingRootBundles);
        // Run one driver step per task; resubmit until the driver reports a terminal state, then map
        // that state onto the pipeline state and shut everything down.
        executorService.submit(new Runnable() {
            @Override
            public void run() {
                DriverState drive = executionDriver.drive();
                if (drive.isTermainal()) {
                    State newPipelineState = State.UNKNOWN;
                    switch (drive) {
                    case FAILED:
                        newPipelineState = State.FAILED;
                        break;
                    case SHUTDOWN:
                        newPipelineState = State.DONE;
                        break;
                    case CONTINUE:
                        throw new IllegalStateException(
                                String.format("%s should not be a terminal state", DriverState.CONTINUE));
                    default:
                        throw new IllegalArgumentException(
                                String.format("Unknown %s %s", DriverState.class.getSimpleName(), drive));
                    }
                    shutdownIfNecessary(newPipelineState);
                } else {
                    executorService.submit(this);
                }
            }
        });
    }

    /** Schedules {@code bundle} for evaluation by {@code consumer}. */
    @Override
    public void process(CommittedBundle<?> bundle, AppliedPTransform<?, ?, ?> consumer,
            CompletionCallback onComplete) {
        evaluateBundle(consumer, bundle, onComplete);
    }

    /**
     * Creates a transform executor for the bundle and schedules it, unless the pipeline is already in
     * a terminal state. Keyed bundles go to the serial executor for their (step, key); others go to
     * the shared parallel executor.
     */
    private <T> void evaluateBundle(final AppliedPTransform<?, ?, ?> transform, final CommittedBundle<T> bundle,
            final CompletionCallback onComplete) {
        TransformExecutorService transformExecutor;

        if (isKeyed(bundle.getPCollection())) {
            final StepAndKey stepAndKey = StepAndKey.of(transform, bundle.getKey());
            // This executor will remain reachable until it has executed all scheduled transforms.
            // The TransformExecutors keep a strong reference to the Executor, the ExecutorService keeps
            // a reference to the scheduled DirectTransformExecutor callable. Follow-up TransformExecutors
            // (scheduled due to the completion of another DirectTransformExecutor) are provided to the
            // ExecutorService before the Earlier DirectTransformExecutor callable completes.
            transformExecutor = serialExecutorServices.getUnchecked(stepAndKey);
        } else {
            transformExecutor = parallelExecutorService;
        }

        TransformExecutor callable = executorFactory.create(bundle, transform, onComplete, transformExecutor);
        if (!pipelineState.get().isTerminal()) {
            transformExecutor.schedule(callable);
        }
    }

    private boolean isKeyed(PValue pvalue) {
        return evaluationContext.isKeyed(pvalue);
    }

    /**
     * Waits until the pipeline reaches a terminal state, a failure is reported, or {@code duration}
     * elapses.
     *
     * @param duration maximum time to wait; {@link Duration#ZERO} waits without a deadline
     * @return the pipeline state when the wait ends; may still be non-terminal if the deadline passed
     * @throws Exception a reported failure, rethrown as-is if it is an {@link Exception} or {@link
     *     Error}, otherwise wrapped in an {@link Exception}
     */
    @Override
    public State waitUntilFinish(Duration duration) throws Exception {
        Instant completionTime;
        if (duration.equals(Duration.ZERO)) {
            completionTime = new Instant(Long.MAX_VALUE);
        } else {
            completionTime = Instant.now().plus(duration);
        }

        // A DONE/CANCELLED update does not end the wait by itself: it may be published before
        // shutdownIfNecessary records the terminal pipeline state, and returning on it could report a
        // stale RUNNING state. Instead, keep draining updates until the queue is empty AND the pipeline
        // state is terminal, or until the deadline passes.
        while (Instant.now().isBefore(completionTime)) {
            // Poll with a short timeout rather than blocking indefinitely, so the deadline and the
            // pipeline state are re-checked regularly and the core is relinquished between checks.
            VisibleExecutorUpdate update = visibleUpdates.tryNext(Duration.millis(25L));
            if (update == null) {
                if (pipelineState.get().isTerminal()) {
                    // there are no updates to process and no updates will ever be published because the
                    // executor is shutdown
                    return pipelineState.get();
                }
            } else if (update.thrown.isPresent()) {
                Throwable thrown = update.thrown.get();
                if (thrown instanceof Exception) {
                    throw (Exception) thrown;
                } else if (thrown instanceof Error) {
                    throw (Error) thrown;
                } else {
                    throw new Exception("Unknown Type of Throwable", thrown);
                }
            }
        }
        return pipelineState.get();
    }

    @Override
    public State getPipelineState() {
        return pipelineState.get();
    }

    /** Cancels the pipeline: shuts down all executors and publishes a cancellation update. */
    @Override
    public void stop() {
        shutdownIfNecessary(State.CANCELLED);
        visibleUpdates.cancelled();
    }

    /**
     * Shuts down all executors and cleans up the evaluator registry when {@code newState} is
     * terminal, then moves the pipeline from RUNNING to {@code newState}. Every shutdown step is
     * attempted even if an earlier one fails.
     *
     * @throws IllegalStateException if any shutdown step failed; the individual failures are attached
     *     as suppressed exceptions and the exception is also published to {@link #visibleUpdates}
     */
    private void shutdownIfNecessary(State newState) {
        if (!newState.isTerminal()) {
            return;
        }
        LOG.debug("Pipeline has terminated. Shutting down.");

        final Collection<Exception> errors = new ArrayList<>();
        // Stop accepting new work before shutting down the executor. This ensures that threads don't
        // try to add work to the shutdown executor.
        try {
            serialExecutorServices.invalidateAll();
        } catch (final RuntimeException re) {
            errors.add(re);
        }
        try {
            serialExecutorServices.cleanUp();
        } catch (final RuntimeException re) {
            errors.add(re);
        }
        try {
            parallelExecutorService.shutdown();
        } catch (final RuntimeException re) {
            errors.add(re);
        }
        try {
            executorService.shutdown();
        } catch (final RuntimeException re) {
            errors.add(re);
        }
        try {
            metricsExecutor.shutdown();
        } catch (final RuntimeException re) {
            errors.add(re);
        }
        try {
            registry.cleanup();
        } catch (final Exception e) {
            errors.add(e);
        }
        pipelineState.compareAndSet(State.RUNNING, newState); // ensure we hit a terminal node
        if (!errors.isEmpty()) {
            final IllegalStateException exception = new IllegalStateException("Error"
                    + (errors.size() == 1 ? "" : "s") + " during executor shutdown:\n"
                    + errors.stream().map(Exception::getMessage).collect(Collectors.joining("\n- ", "- ", "")));
            // The message only carries getMessage(); keep the original failures (and their stack
            // traces and causes) reachable for diagnosis.
            errors.forEach(exception::addSuppressed);
            visibleUpdates.failed(exception);
            throw exception;
        }
    }

    /**
     * An update of interest to the user. Used in {@link #waitUntilFinish} to decide whether to return
     * normally or throw an exception.
     */
    private static class VisibleExecutorUpdate {
        /** The failure to surface to the caller, if any. */
        private final Optional<? extends Throwable> thrown;
        /** The state this update reports; {@code null} for updates created from an {@link Exception}. */
        @Nullable
        private final State newState;

        public static VisibleExecutorUpdate fromException(Exception e) {
            return new VisibleExecutorUpdate(null, e);
        }

        public static VisibleExecutorUpdate fromError(Error err) {
            return new VisibleExecutorUpdate(State.FAILED, err);
        }

        public static VisibleExecutorUpdate finished() {
            return new VisibleExecutorUpdate(State.DONE, null);
        }

        public static VisibleExecutorUpdate cancelled() {
            return new VisibleExecutorUpdate(State.CANCELLED, null);
        }

        private VisibleExecutorUpdate(State newState, @Nullable Throwable exception) {
            this.thrown = Optional.fromNullable(exception);
            this.newState = newState;
        }

        @Nullable
        State getNewState() {
            return newState;
        }
    }

    /** A {@link PipelineMessageReceiver} that enqueues every message for {@link #waitUntilFinish}. */
    private static class QueueMessageReceiver implements PipelineMessageReceiver {
        // If the type of BlockingQueue changes, ensure the findbugs filter is updated appropriately
        private final BlockingQueue<VisibleExecutorUpdate> updates = new LinkedBlockingQueue<>();

        @Override
        public void failed(Exception e) {
            updates.offer(VisibleExecutorUpdate.fromException(e));
        }

        @Override
        public void failed(Error e) {
            updates.offer(VisibleExecutorUpdate.fromError(e));
        }

        @Override
        public void cancelled() {
            updates.offer(VisibleExecutorUpdate.cancelled());
        }

        @Override
        public void completed() {
            updates.offer(VisibleExecutorUpdate.finished());
        }

        /**
         * Try to get the next unconsumed message in this {@link QueueMessageReceiver}, waiting up to
         * {@code timeout}; returns {@code null} if none arrived in time.
         */
        @Nullable
        private VisibleExecutorUpdate tryNext(Duration timeout) throws InterruptedException {
            return updates.poll(timeout.getMillis(), TimeUnit.MILLISECONDS);
        }
    }
}