io.prestosql.execution.BenchmarkNodeScheduler.java Source code

Java tutorial

Introduction

Here is the source code for io.prestosql.execution.BenchmarkNodeScheduler.java

Source

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.prestosql.execution;

import com.google.common.base.Splitter;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMultimap;
import com.google.common.collect.Iterators;
import com.google.common.collect.Multimap;
import io.prestosql.client.NodeVersion;
import io.prestosql.connector.ConnectorId;
import io.prestosql.execution.scheduler.FlatNetworkTopology;
import io.prestosql.execution.scheduler.LegacyNetworkTopology;
import io.prestosql.execution.scheduler.NetworkLocation;
import io.prestosql.execution.scheduler.NetworkTopology;
import io.prestosql.execution.scheduler.NodeScheduler;
import io.prestosql.execution.scheduler.NodeSchedulerConfig;
import io.prestosql.execution.scheduler.NodeSelector;
import io.prestosql.metadata.InMemoryNodeManager;
import io.prestosql.metadata.PrestoNode;
import io.prestosql.metadata.Split;
import io.prestosql.spi.HostAddress;
import io.prestosql.spi.Node;
import io.prestosql.spi.connector.ConnectorSplit;
import io.prestosql.sql.planner.plan.PlanNodeId;
import io.prestosql.testing.TestingTransactionHandle;
import io.prestosql.util.FinalizerService;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OperationsPerInvocation;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.TearDown;
import org.openjdk.jmh.annotations.Warmup;
import org.openjdk.jmh.runner.Runner;
import org.openjdk.jmh.runner.options.Options;
import org.openjdk.jmh.runner.options.OptionsBuilder;
import org.openjdk.jmh.runner.options.VerboseMode;

import java.net.URI;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;

import static io.airlift.concurrent.Threads.daemonThreadsNamed;
import static io.prestosql.execution.scheduler.NodeSchedulerConfig.NetworkTopologyType.BENCHMARK;
import static io.prestosql.execution.scheduler.NodeSchedulerConfig.NetworkTopologyType.FLAT;
import static io.prestosql.execution.scheduler.NodeSchedulerConfig.NetworkTopologyType.LEGACY;
import static java.util.concurrent.Executors.newCachedThreadPool;
import static java.util.concurrent.Executors.newScheduledThreadPool;

@SuppressWarnings("MethodMayBeStatic")
@State(Scope.Thread)
@OutputTimeUnit(TimeUnit.MICROSECONDS)
@Fork(1)
@Warmup(iterations = 10, time = 500, timeUnit = TimeUnit.MILLISECONDS)
@Measurement(iterations = 10, time = 500, timeUnit = TimeUnit.MILLISECONDS)
@BenchmarkMode(Mode.AverageTime)
public class BenchmarkNodeScheduler {
    private static final int MAX_SPLITS_PER_NODE = 100;
    private static final int MAX_PENDING_SPLITS_PER_TASK_PER_NODE = 50;
    private static final int NODES = 200;
    private static final int DATA_NODES = 10_000;
    private static final int RACKS = DATA_NODES / 25;
    private static final int SPLITS = NODES * (MAX_SPLITS_PER_NODE + MAX_PENDING_SPLITS_PER_TASK_PER_NODE / 3);
    private static final int SPLIT_BATCH_SIZE = 100;
    private static final ConnectorId CONNECTOR_ID = new ConnectorId("test_connector_id");

    @Benchmark
    @OperationsPerInvocation(SPLITS)
    public Object benchmark(BenchmarkData data) {
        List<RemoteTask> remoteTasks = ImmutableList.copyOf(data.getTaskMap().values());
        Iterator<MockRemoteTaskFactory.MockRemoteTask> finishingTask = Iterators.cycle(data.getTaskMap().values());
        Iterator<Split> splits = data.getSplits().iterator();
        Set<Split> batch = new HashSet<>();
        while (splits.hasNext() || !batch.isEmpty()) {
            Multimap<Node, Split> assignments = data.getNodeSelector().computeAssignments(batch, remoteTasks)
                    .getAssignments();
            for (Node node : assignments.keySet()) {
                MockRemoteTaskFactory.MockRemoteTask remoteTask = data.getTaskMap().get(node);
                remoteTask.addSplits(ImmutableMultimap.<PlanNodeId, Split>builder()
                        .putAll(new PlanNodeId("sourceId"), assignments.get(node)).build());
                remoteTask.startSplits(MAX_SPLITS_PER_NODE);
            }
            if (assignments.size() == batch.size()) {
                batch.clear();
            } else {
                batch.removeAll(assignments.values());
            }
            while (batch.size() < SPLIT_BATCH_SIZE && splits.hasNext()) {
                batch.add(splits.next());
            }
            finishingTask.next().finishSplits((int) Math.ceil(MAX_SPLITS_PER_NODE / 50.0));
        }

        return remoteTasks;
    }

    @SuppressWarnings("FieldMayBeFinal")
    @State(Scope.Thread)
    public static class BenchmarkData {
        @Param({ LEGACY, BENCHMARK, FLAT })
        private String topologyName = LEGACY;

        private FinalizerService finalizerService = new FinalizerService();
        private NodeSelector nodeSelector;
        private Map<Node, MockRemoteTaskFactory.MockRemoteTask> taskMap = new HashMap<>();
        private List<Split> splits = new ArrayList<>();

        @Setup
        public void setup() {
            TestingTransactionHandle transactionHandle = TestingTransactionHandle.create();

            finalizerService.start();
            NodeTaskMap nodeTaskMap = new NodeTaskMap(finalizerService);

            ImmutableList.Builder<Node> nodeBuilder = ImmutableList.builder();
            for (int i = 0; i < NODES; i++) {
                nodeBuilder.add(new PrestoNode("node" + i, URI.create("http://" + addressForHost(i).getHostText()),
                        NodeVersion.UNKNOWN, false));
            }
            List<Node> nodes = nodeBuilder.build();
            MockRemoteTaskFactory remoteTaskFactory = new MockRemoteTaskFactory(
                    newCachedThreadPool(daemonThreadsNamed("remoteTaskExecutor-%s")),
                    newScheduledThreadPool(2, daemonThreadsNamed("remoteTaskScheduledExecutor-%s")));
            for (int i = 0; i < nodes.size(); i++) {
                Node node = nodes.get(i);
                ImmutableList.Builder<Split> initialSplits = ImmutableList.builder();
                for (int j = 0; j < MAX_SPLITS_PER_NODE + MAX_PENDING_SPLITS_PER_TASK_PER_NODE; j++) {
                    initialSplits.add(new Split(CONNECTOR_ID, transactionHandle, new TestSplitRemote(i)));
                }
                TaskId taskId = new TaskId("test", 1, i);
                MockRemoteTaskFactory.MockRemoteTask remoteTask = remoteTaskFactory.createTableScanTask(taskId,
                        node, initialSplits.build(), nodeTaskMap.createPartitionedSplitCountTracker(node, taskId));
                nodeTaskMap.addTask(node, remoteTask);
                taskMap.put(node, remoteTask);
            }

            for (int i = 0; i < SPLITS; i++) {
                splits.add(new Split(CONNECTOR_ID, transactionHandle,
                        new TestSplitRemote(ThreadLocalRandom.current().nextInt(DATA_NODES))));
            }

            InMemoryNodeManager nodeManager = new InMemoryNodeManager();
            nodeManager.addNode(CONNECTOR_ID, nodes);
            NodeScheduler nodeScheduler = new NodeScheduler(getNetworkTopology(), nodeManager,
                    getNodeSchedulerConfig(), nodeTaskMap);
            nodeSelector = nodeScheduler.createNodeSelector(CONNECTOR_ID);
        }

        @TearDown
        public void tearDown() {
            finalizerService.destroy();
        }

        private NodeSchedulerConfig getNodeSchedulerConfig() {
            return new NodeSchedulerConfig().setMaxSplitsPerNode(MAX_SPLITS_PER_NODE).setIncludeCoordinator(false)
                    .setNetworkTopology(topologyName)
                    .setMaxPendingSplitsPerTask(MAX_PENDING_SPLITS_PER_TASK_PER_NODE);
        }

        private NetworkTopology getNetworkTopology() {
            NetworkTopology topology;
            switch (topologyName) {
            case LEGACY:
                topology = new LegacyNetworkTopology();
                break;
            case FLAT:
                topology = new FlatNetworkTopology();
                break;
            case BENCHMARK:
                topology = new BenchmarkNetworkTopology();
                break;
            default:
                throw new IllegalStateException();
            }
            return topology;
        }

        public Map<Node, MockRemoteTaskFactory.MockRemoteTask> getTaskMap() {
            return taskMap;
        }

        public NodeSelector getNodeSelector() {
            return nodeSelector;
        }

        public List<Split> getSplits() {
            return splits;
        }
    }

    public static void main(String[] args) throws Throwable {
        Options options = new OptionsBuilder().verbosity(VerboseMode.NORMAL)
                .include(".*" + BenchmarkNodeScheduler.class.getSimpleName() + ".*").build();
        new Runner(options).run();
    }

    private static class BenchmarkNetworkTopology implements NetworkTopology {
        @Override
        public NetworkLocation locate(HostAddress address) {
            List<String> parts = new ArrayList<>(
                    ImmutableList.copyOf(Splitter.on(".").split(address.getHostText())));
            Collections.reverse(parts);
            return NetworkLocation.create(parts);
        }

        @Override
        public List<String> getLocationSegmentNames() {
            return ImmutableList.of("rack", "machine");
        }
    }

    private static class TestSplitRemote implements ConnectorSplit {
        private final List<HostAddress> hosts;

        public TestSplitRemote(int dataHost) {
            hosts = ImmutableList.of(addressForHost(dataHost));
        }

        @Override
        public boolean isRemotelyAccessible() {
            return true;
        }

        @Override
        public List<HostAddress> getAddresses() {
            return hosts;
        }

        @Override
        public Object getInfo() {
            return this;
        }
    }

    private static HostAddress addressForHost(int host) {
        int rack = Integer.hashCode(host) % RACKS;
        return HostAddress.fromParts("host" + host + ".rack" + rack, 1);
    }
}