com.facebook.presto.execution.SqlStageExecution.java Source code

Java tutorial

Introduction

Here is the source code for com.facebook.presto.execution.SqlStageExecution.java

Source

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.execution;

import com.facebook.presto.OutputBuffers;
import com.facebook.presto.Session;
import com.facebook.presto.execution.StateMachine.StateChangeListener;
import com.facebook.presto.metadata.Split;
import com.facebook.presto.spi.Node;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.StandardErrorCode;
import com.facebook.presto.split.RemoteSplit;
import com.facebook.presto.sql.planner.PlanFragment;
import com.facebook.presto.sql.planner.plan.PlanFragmentId;
import com.facebook.presto.sql.planner.plan.PlanNodeId;
import com.facebook.presto.sql.planner.plan.RemoteSourceNode;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableMultimap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Multimap;

import javax.annotation.concurrent.ThreadSafe;

import java.net.URI;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;

import static com.facebook.presto.OutputBuffers.INITIAL_EMPTY_OUTPUT_BUFFERS;
import static com.facebook.presto.util.ImmutableCollectors.toImmutableList;
import static com.google.common.base.MoreObjects.toStringHelper;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.Sets.newConcurrentHashSet;
import static io.airlift.concurrent.MoreFutures.firstCompletedFuture;
import static io.airlift.http.client.HttpUriBuilder.uriBuilderFrom;
import static java.util.Objects.requireNonNull;
import static java.util.concurrent.CompletableFuture.completedFuture;

/**
 * Drives the execution of a single stage: creates remote tasks on nodes, feeds them
 * splits and exchange locations, propagates output buffer updates, and moves the
 * stage state machine forward in response to task state changes.
 *
 * <p>Mutating operations are {@code synchronized} on this instance. Task state change
 * callbacks registered in {@link #scheduleTask(Node, PlanNodeId, Iterable)} are not
 * synchronized; they only touch the concurrent sets and the state machine.
 */
@ThreadSafe
public final class SqlStageExecution {
    private final StageStateMachine stateMachine;
    private final RemoteTaskFactory remoteTaskFactory;
    private final NodeTaskMap nodeTaskMap;

    // source fragment id -> remote source node of this stage's fragment that reads from it
    private final Map<PlanFragmentId, RemoteSourceNode> exchangeSources;

    // tasks created so far, grouped by the node they were created for
    private final Map<Node, Set<RemoteTask>> tasks = new ConcurrentHashMap<>();
    private final AtomicInteger nextTaskId = new AtomicInteger();
    // ids of every task created by this stage, and of those observed in the FINISHED state
    private final Set<TaskId> allTasks = newConcurrentHashSet();
    private final Set<TaskId> finishedTasks = newConcurrentHashSet();

    // HashMultimap is not thread safe; it is only accessed from synchronized methods
    private final Multimap<PlanNodeId, URI> exchangeLocations = HashMultimap.create();
    // plan nodes for which no more splits will be added; replayed to tasks created later
    private final Set<PlanNodeId> completeSources = newConcurrentHashSet();
    private final Set<PlanFragmentId> completeSourceFragments = newConcurrentHashSet();

    private final AtomicReference<OutputBuffers> outputBuffers = new AtomicReference<>(
            INITIAL_EMPTY_OUTPUT_BUFFERS);

    /**
     * Creates a stage execution backed by a new {@link StageStateMachine} built from the
     * given stage id, location, session, fragment and executor.
     *
     * @throws NullPointerException if any argument is null
     */
    public SqlStageExecution(StageId stageId, URI location, PlanFragment fragment,
            RemoteTaskFactory remoteTaskFactory, Session session, NodeTaskMap nodeTaskMap,
            ExecutorService executor) {
        this(new StageStateMachine(requireNonNull(stageId, "stageId is null"),
                requireNonNull(location, "location is null"), requireNonNull(session, "session is null"),
                requireNonNull(fragment, "fragment is null"), requireNonNull(executor, "executor is null")),
                remoteTaskFactory, nodeTaskMap);
    }

    /**
     * Creates a stage execution on top of an existing state machine and indexes the
     * fragment's remote source nodes by each of their source fragment ids.
     *
     * @throws NullPointerException if any argument is null
     */
    public SqlStageExecution(StageStateMachine stateMachine, RemoteTaskFactory remoteTaskFactory,
            NodeTaskMap nodeTaskMap) {
        this.stateMachine = requireNonNull(stateMachine, "stateMachine is null");
        this.remoteTaskFactory = requireNonNull(remoteTaskFactory, "remoteTaskFactory is null");
        this.nodeTaskMap = requireNonNull(nodeTaskMap, "nodeTaskMap is null");

        ImmutableMap.Builder<PlanFragmentId, RemoteSourceNode> fragmentToExchangeSource = ImmutableMap.builder();
        for (RemoteSourceNode remoteSourceNode : stateMachine.getFragment().getRemoteSourceNodes()) {
            for (PlanFragmentId planFragmentId : remoteSourceNode.getSourceFragmentIds()) {
                fragmentToExchangeSource.put(planFragmentId, remoteSourceNode);
            }
        }
        this.exchangeSources = fragmentToExchangeSource.build();
    }

    /** Returns the id of this stage, as reported by the state machine. */
    public StageId getStageId() {
        return stateMachine.getStageId();
    }

    /** Returns the current state of this stage. */
    public StageState getState() {
        return stateMachine.getState();
    }

    /** Registers a listener that is forwarded every stage state change. */
    public void addStateChangeListener(StateChangeListener<StageState> stateChangeListener) {
        stateMachine.addStateChangeListener(stateChangeListener::stateChanged);
    }

    /** Returns the plan fragment executed by this stage. */
    public PlanFragment getFragment() {
        return stateMachine.getFragment();
    }

    /** Requests the transition of the stage to the scheduling state. */
    public void beginScheduling() {
        stateMachine.transitionToScheduling();
    }

    /** Requests the transition of the stage to the scheduling-splits state. */
    public synchronized void transitionToSchedulingSplits() {
        stateMachine.transitionToSchedulingSplits();
    }

    /**
     * Marks scheduling as complete. If the state machine accepts the transition to
     * scheduled, the stage is advanced to running/finished based on the tasks seen so
     * far, and every task is told that the partitioned source (if any) has no more splits.
     */
    public synchronized void schedulingComplete() {
        if (!stateMachine.transitionToScheduled()) {
            return;
        }

        // Inspect each task's own state: the stage was just moved to scheduled above, so
        // task progress made before this point has to be picked up from the tasks themselves.
        if (getAllTasks().stream().anyMatch(task -> task.getTaskInfo().getState() == TaskState.RUNNING)) {
            stateMachine.transitionToRunning();
        }
        if (finishedTasks.containsAll(allTasks)) {
            stateMachine.transitionToFinished();
        }

        PlanNodeId partitionedSource = stateMachine.getFragment().getPartitionedSource();
        if (partitionedSource != null) {
            for (RemoteTask task : getAllTasks()) {
                task.noMoreSplits(partitionedSource);
            }
            // remembered so that tasks created afterwards are also notified
            completeSources.add(partitionedSource);
        }
    }

    /** Transitions the stage to canceled and cancels every task created so far. */
    public synchronized void cancel() {
        stateMachine.transitionToCanceled();
        getAllTasks().forEach(RemoteTask::cancel);
    }

    /** Transitions the stage to aborted and aborts every task created so far. */
    public synchronized void abort() {
        stateMachine.transitionToAborted();
        getAllTasks().forEach(RemoteTask::abort);
    }

    /** Returns the sum, in bytes, of the memory reservation reported by each task. */
    public synchronized long getMemoryReservation() {
        return getAllTasks().stream()
                .mapToLong(task -> task.getTaskInfo().getStats().getMemoryReservation().toBytes()).sum();
    }

    /**
     * Builds the stage info from the state machine, supplying the current task infos and
     * an empty list for the second supplier (presumably sub-stage infos; confirm against
     * {@code StageStateMachine.getStageInfo}).
     */
    public StageInfo getStageInfo() {
        return stateMachine.getStageInfo(
                () -> getAllTasks().stream().map(RemoteTask::getTaskInfo).collect(toImmutableList()),
                ImmutableList::of);
    }

    /**
     * Records a new exchange location for the remote source that reads the given fragment
     * and hands a matching remote split to every existing task.
     *
     * @throws NullPointerException if {@code exchangeLocation} is null
     * @throws IllegalArgumentException if the fragment is not a source of this stage
     */
    public synchronized void addExchangeLocation(ExchangeLocation exchangeLocation) {
        requireNonNull(exchangeLocation, "exchangeLocation is null");
        RemoteSourceNode remoteSource = exchangeSources.get(exchangeLocation.getPlanFragmentId());
        checkArgument(remoteSource != null, "Unknown remote source %s. Known sources are %s",
                exchangeLocation.getPlanFragmentId(), exchangeSources.keySet());

        // kept so that tasks scheduled later receive this location as an initial split
        exchangeLocations.put(remoteSource.getId(), exchangeLocation.getUri());
        for (RemoteTask task : getAllTasks()) {
            task.addSplits(remoteSource.getId(), ImmutableList
                    .of(createRemoteSplitFor(task.getTaskInfo().getTaskId(), exchangeLocation.getUri())));
        }
    }

    /**
     * Records that the given source fragment will produce no more exchange locations.
     * Once all source fragments of the owning remote source node are complete, every
     * task is told that the node has no more splits.
     *
     * @throws NullPointerException if {@code fragmentId} is null
     * @throws IllegalArgumentException if the fragment is not a source of this stage
     */
    public synchronized void noMoreExchangeLocationsFor(PlanFragmentId fragmentId) {
        requireNonNull(fragmentId, "fragmentId is null");
        RemoteSourceNode remoteSource = exchangeSources.get(fragmentId);
        checkArgument(remoteSource != null, "Unknown remote source %s. Known sources are %s", fragmentId,
                exchangeSources.keySet());

        completeSourceFragments.add(fragmentId);

        // is the source now complete?
        if (completeSourceFragments.containsAll(remoteSource.getSourceFragmentIds())) {
            completeSources.add(remoteSource.getId());
            for (RemoteTask task : getAllTasks()) {
                task.noMoreSplits(remoteSource.getId());
            }
        }
    }

    /**
     * Installs the given output buffers and pushes them to every task, unless their
     * version is not newer than the version currently held.
     *
     * @throws NullPointerException if {@code outputBuffers} is null
     */
    public synchronized void setOutputBuffers(OutputBuffers outputBuffers) {
        requireNonNull(outputBuffers, "outputBuffers is null");

        while (true) {
            OutputBuffers currentOutputBuffers = this.outputBuffers.get();
            // ignore stale or duplicate updates
            if (outputBuffers.getVersion() <= currentOutputBuffers.getVersion()) {
                return;
            }

            if (this.outputBuffers.compareAndSet(currentOutputBuffers, outputBuffers)) {
                for (RemoteTask task : getAllTasks()) {
                    task.setOutputBuffers(outputBuffers);
                }
                return;
            }
        }
    }

    // do not synchronize
    // this is used for query info building which should be independent of scheduling work
    /** Returns whether at least one node has an entry in the task map. */
    public boolean hasTasks() {
        return !tasks.isEmpty();
    }

    /** Returns an immutable snapshot of every task created by this stage, across all nodes. */
    public synchronized List<RemoteTask> getAllTasks() {
        return tasks.values().stream().flatMap(Set::stream).collect(toImmutableList());
    }

    /**
     * Returns a future derived from the state change futures of all current tasks via
     * {@code firstCompletedFuture}, or an already completed future when there are no tasks.
     */
    public synchronized CompletableFuture<?> getTaskStateChange() {
        List<RemoteTask> allTasks = getAllTasks();
        if (allTasks.isEmpty()) {
            return completedFuture(null);
        }

        List<CompletableFuture<TaskInfo>> stateChangeFutures = allTasks.stream()
                .map(task -> task.getStateChange(task.getTaskInfo())).collect(toImmutableList());

        // NOTE(review): the boolean flag is passed through to firstCompletedFuture as-is;
        // confirm its meaning against the airlift MoreFutures documentation.
        return firstCompletedFuture(stateChangeFutures, true);
    }

    /**
     * Creates and starts a new task on the given node with no partitioned-source splits.
     *
     * @throws NullPointerException if {@code node} is null
     */
    public synchronized RemoteTask scheduleTask(Node node) {
        requireNonNull(node, "node is null");

        return scheduleTask(node, null, ImmutableList.<Split>of());
    }

    /**
     * Assigns splits of the partitioned source to the given node. If the node has no
     * task yet, a new task is created with the splits; otherwise the splits are added
     * to one of the node's existing tasks.
     *
     * @return the newly created tasks: a single task if one was created, otherwise empty
     * @throws NullPointerException if {@code node} or {@code splits} is null
     * @throws IllegalStateException if the fragment has no partitioned source
     */
    public synchronized Set<RemoteTask> scheduleSplits(Node node, Iterable<Split> splits) {
        requireNonNull(node, "node is null");
        requireNonNull(splits, "splits is null");

        PlanNodeId partitionedSource = stateMachine.getFragment().getPartitionedSource();
        checkState(partitionedSource != null, "Partitioned source is null");

        ImmutableSet.Builder<RemoteTask> newTasks = ImmutableSet.builder();
        Collection<RemoteTask> nodeTasks = this.tasks.get(node);
        if (nodeTasks == null) {
            newTasks.add(scheduleTask(node, partitionedSource, splits));
        } else {
            // a node's task set is only ever created together with its first task
            RemoteTask task = nodeTasks.iterator().next();
            task.addSplits(partitionedSource, splits);
        }
        return newTasks.build();
    }

    /**
     * Creates a task on the given node, seeds it with the supplied source splits plus a
     * remote split for every known exchange location, registers it, wires up the stage
     * state transitions driven by the task's state, and starts it (or aborts it if the
     * stage is already done).
     *
     * @param sourceId plan node the {@code sourceSplits} belong to; may be null when
     *        {@code sourceSplits} is empty
     */
    private synchronized RemoteTask scheduleTask(Node node, PlanNodeId sourceId, Iterable<Split> sourceSplits) {
        TaskId taskId = new TaskId(stateMachine.getStageId(), String.valueOf(nextTaskId.getAndIncrement()));

        ImmutableMultimap.Builder<PlanNodeId, Split> initialSplits = ImmutableMultimap.builder();
        for (Split sourceSplit : sourceSplits) {
            initialSplits.put(sourceId, sourceSplit);
        }
        for (Entry<PlanNodeId, URI> entry : exchangeLocations.entries()) {
            initialSplits.put(entry.getKey(), createRemoteSplitFor(taskId, entry.getValue()));
        }

        RemoteTask task = remoteTaskFactory.createRemoteTask(stateMachine.getSession(), taskId, node,
                stateMachine.getFragment(), initialSplits.build(), outputBuffers.get(),
                nodeTaskMap.getSplitCountChangeListener(node));

        // sources that completed before this task existed
        completeSources.forEach(task::noMoreSplits);

        allTasks.add(taskId);
        tasks.computeIfAbsent(node, key -> newConcurrentHashSet()).add(task);
        nodeTaskMap.addTask(node, task);

        task.addStateChangeListener(taskInfo -> {
            StageState stageState = getState();
            if (stageState.isDone()) {
                return;
            }

            TaskState taskState = taskInfo.getState();
            if (taskState == TaskState.FAILED) {
                // build the fallback exception only when the task reported no failure
                RuntimeException failure = taskInfo.getFailures().stream().findFirst()
                        .map(ExecutionFailureInfo::toException).orElseGet(() -> new PrestoException(
                                StandardErrorCode.INTERNAL_ERROR, "A task failed for an unknown reason"));
                stateMachine.transitionToFailed(failure);
            } else if (taskState == TaskState.ABORTED) {
                // A task should only be in the aborted state if the STAGE is done (ABORTED or FAILED)
                stateMachine.transitionToFailed(new PrestoException(StandardErrorCode.INTERNAL_ERROR,
                        "A task is in the ABORTED state but stage is " + stageState));
            } else if (taskState == TaskState.FINISHED) {
                finishedTasks.add(task.getTaskId());
            }

            // running/finished transitions are only attempted once scheduling has completed
            if (stageState == StageState.SCHEDULED || stageState == StageState.RUNNING) {
                if (taskState == TaskState.RUNNING) {
                    stateMachine.transitionToRunning();
                }
                if (finishedTasks.containsAll(allTasks)) {
                    stateMachine.transitionToFinished();
                }
            }
        });

        if (!stateMachine.getState().isDone()) {
            task.start();
        } else {
            // stage finished while we were scheduling this task
            task.abort();
        }

        return task;
    }

    /** Forwards the start timestamp of a get-splits call to the state machine for recording. */
    public void recordGetSplitTime(long start) {
        stateMachine.recordGetSplitTime(start);
    }

    /**
     * Builds a remote split whose location is {@code taskLocation} with the path segments
     * {@code results} and the consuming task's id appended.
     */
    private static Split createRemoteSplitFor(TaskId taskId, URI taskLocation) {
        URI splitLocation = uriBuilderFrom(taskLocation).appendPath("results").appendPath(taskId.toString())
                .build();
        return new Split("remote", new RemoteSplit(splitLocation));
    }

    /** Returns the string form of the underlying state machine. */
    @Override
    public String toString() {
        return stateMachine.toString();
    }

    /** Immutable pair of a source plan fragment id and the URI its output can be read from. */
    public static class ExchangeLocation {
        private final PlanFragmentId planFragmentId;
        private final URI uri;

        /**
         * @throws NullPointerException if either argument is null
         */
        public ExchangeLocation(PlanFragmentId planFragmentId, URI uri) {
            this.planFragmentId = requireNonNull(planFragmentId, "planFragmentId is null");
            this.uri = requireNonNull(uri, "uri is null");
        }

        /** Returns the id of the fragment producing the data. */
        public PlanFragmentId getPlanFragmentId() {
            return planFragmentId;
        }

        /** Returns the location of the data. */
        public URI getUri() {
            return uri;
        }

        @Override
        public String toString() {
            return toStringHelper(this).add("planFragmentId", planFragmentId).add("uri", uri).toString();
        }
    }
}