org.apache.storm.grouping.LoadAwareShuffleGroupingTest.java Source code

Introduction

Here is the source code for org.apache.storm.grouping.LoadAwareShuffleGroupingTest.java, the unit test for Storm's LoadAwareShuffleGrouping.
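
A minimal sketch of the call sequence this test exercises: prepare the grouping, publish a load snapshot with refreshLoad, then pick target tasks with chooseTasks. It assumes the Storm APIs used in the listing below; the class name LoadAwareSketch, the stream ids, and the load values are illustrative stand-ins, and the WorkerTopologyContext would come from a running worker (the test mocks one instead).

import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.apache.storm.generated.GlobalStreamId;
import org.apache.storm.grouping.LoadAwareShuffleGrouping;
import org.apache.storm.grouping.LoadMapping;
import org.apache.storm.task.WorkerTopologyContext;

class LoadAwareSketch {
    // Illustrative only: the context normally comes from the running worker.
    static List<Integer> pickTask(WorkerTopologyContext context) {
        LoadAwareShuffleGrouping grouping = new LoadAwareShuffleGrouping();
        grouping.prepare(context, new GlobalStreamId("spout", "default"), Arrays.asList(1, 2));

        Map<Integer, Double> localLoad = new HashMap<>();
        localLoad.put(1, 0.2); // lightly loaded task
        localLoad.put(2, 0.8); // heavily loaded task
        LoadMapping load = new LoadMapping();
        load.setLocal(localLoad);
        grouping.refreshLoad(load); // rebuild the weighted choice ring

        // over many calls, task 1 should now be chosen more often than task 2
        return grouping.chooseTasks(100, Collections.emptyList());
    }
}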

Source

/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.storm.grouping;

import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.google.common.util.concurrent.MoreExecutors;
import java.util.Arrays;

import org.apache.storm.Config;
import org.apache.storm.daemon.GrouperFactory;
import org.apache.storm.generated.GlobalStreamId;
import org.apache.storm.generated.Grouping;
import org.apache.storm.generated.NodeInfo;
import org.apache.storm.generated.NullStruct;
import org.apache.storm.task.WorkerTopologyContext;
import org.junit.Ignore;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

public class LoadAwareShuffleGroupingTest {
    public static final double ACCEPTABLE_MARGIN = 0.015;
    private static final Logger LOG = LoggerFactory.getLogger(LoadAwareShuffleGroupingTest.class);

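    /**
     * Builds a mocked WorkerTopologyContext in which every supplied task id is mapped
     * to this worker's own node and port, so all candidate tasks count as worker-local.
     */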
    private WorkerTopologyContext mockContext(List<Integer> availableTaskIds) {
        Map<String, Object> conf = new HashMap<>();
        conf.put(Config.STORM_NETWORK_TOPOGRAPHY_PLUGIN,
                "org.apache.storm.networktopography.DefaultRackDNSToSwitchMapping");
        conf.put(Config.TOPOLOGY_LOCALITYAWARE_HIGHER_BOUND_PERCENT, 0.8);
        conf.put(Config.TOPOLOGY_LOCALITYAWARE_LOWER_BOUND_PERCENT, 0.2);

        WorkerTopologyContext context = mock(WorkerTopologyContext.class);
        when(context.getConf()).thenReturn(conf);
        Map<Integer, NodeInfo> taskNodeToPort = new HashMap<>();
        NodeInfo nodeInfo = new NodeInfo("node-id", Sets.newHashSet(6700L));
        availableTaskIds.forEach(e -> taskNodeToPort.put(e, nodeInfo));
        when(context.getTaskToNodePort()).thenReturn(new AtomicReference<>(taskNodeToPort));
        when(context.getThisWorkerHost()).thenReturn("node-id");
        when(context.getThisWorkerPort()).thenReturn(6700);
        return context;
    }

    @Test
    public void testUnevenLoadOverTime() throws Exception {
        LoadAwareShuffleGrouping grouping = new LoadAwareShuffleGrouping();
        WorkerTopologyContext context = mockContext(Arrays.asList(1, 2));
        grouping.prepare(context, new GlobalStreamId("a", "default"), Arrays.asList(1, 2));
        double expectedOneWeight = 100.0;
        double expectedTwoWeight = 100.0;

        Map<Integer, Double> localLoad = new HashMap<>();
        localLoad.put(1, 1.0);
        localLoad.put(2, 0.0);
        LoadMapping lm = new LoadMapping();
        lm.setLocal(localLoad);
        // First verify that if a task has a high load, its distribution will drop over time
        for (int i = 9; i >= 0; i--) {
            grouping.refreshLoad(lm);
            expectedOneWeight -= 10.0;
            Map<Integer, Double> countByType = count(grouping.choices, grouping.rets);
            LOG.info("contByType = {}", countByType);
            double expectedOnePercentage = expectedOneWeight / (expectedOneWeight + expectedTwoWeight);
            double expectedTwoPercentage = expectedTwoWeight / (expectedOneWeight + expectedTwoWeight);
            assertEquals("i = " + i, expectedOnePercentage,
                    countByType.getOrDefault(1, 0.0) / LoadAwareShuffleGrouping.CAPACITY, 0.01);
            assertEquals("i = " + i, expectedTwoPercentage,
                    countByType.getOrDefault(2, 0.0) / LoadAwareShuffleGrouping.CAPACITY, 0.01);
        }

        // Now verify that when the load is switched, we can recover
        localLoad.put(1, 0.0);
        localLoad.put(2, 1.0);
        lm.setLocal(localLoad);

        while (expectedOneWeight < 100.0) {
            grouping.refreshLoad(lm);
            expectedOneWeight += 1.0;
            expectedTwoWeight = Math.max(0.0, expectedTwoWeight - 10.0);
            Map<Integer, Double> countByType = count(grouping.choices, grouping.rets);
            LOG.info("contByType = {}", countByType);
            double expectedOnePercentage = expectedOneWeight / (expectedOneWeight + expectedTwoWeight);
            double expectedTwoPercentage = expectedTwoWeight / (expectedOneWeight + expectedTwoWeight);
            assertEquals(expectedOnePercentage,
                    countByType.getOrDefault(1, 0.0) / LoadAwareShuffleGrouping.CAPACITY, 0.01);
            assertEquals(expectedTwoPercentage,
                    countByType.getOrDefault(2, 0.0) / LoadAwareShuffleGrouping.CAPACITY, 0.01);
        }
    }

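    // Tallies how often each task id appears in the grouping's precomputed choice ring:
    // each entry of choices indexes into rets, whose lists each hold a single task id.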
    private Map<Integer, Double> count(int[] choices, List<Integer>[] rets) {
        Map<Integer, Double> ret = new HashMap<>();
        for (int i : choices) {
            int task = rets[i].get(0);
            ret.put(task, ret.getOrDefault(task, 0.0) + 1);
        }
        return ret;
    }

    @Test
    public void testLoadAwareShuffleGroupingWithEvenLoad() {
        // just pick an arbitrary number
        final int numTasks = 7;
        final LoadAwareShuffleGrouping grouper = new LoadAwareShuffleGrouping();

        // Define our taskIds and loads
        final List<Integer> availableTaskIds = getAvailableTaskIds(numTasks);
        final LoadMapping loadMapping = buildLocalTasksEvenLoadMapping(availableTaskIds);

        final WorkerTopologyContext context = mockContext(availableTaskIds);
        grouper.prepare(context, null, availableTaskIds);

        // Keep track of how many times we see each taskId
        // distribution should be even for all nodes when loads are even
        int desiredTaskCountPerTask = 5000;
        int totalEmits = numTasks * desiredTaskCountPerTask;
        int minPrCount = (int) (totalEmits * ((1.0 / numTasks) - ACCEPTABLE_MARGIN));
        int maxPrCount = (int) (totalEmits * ((1.0 / numTasks) + ACCEPTABLE_MARGIN));

        int[] taskCounts = runChooseTasksWithVerification(grouper, totalEmits, numTasks, loadMapping);

        for (int i = 0; i < numTasks; i++) {
            assertTrue("Distribution should be even for all nodes with small delta",
                    taskCounts[i] >= minPrCount && taskCounts[i] <= maxPrCount);
        }
    }

    @Test
    public void testLoadAwareShuffleGroupingWithEvenLoadMultiThreaded()
            throws InterruptedException, ExecutionException {
        final int numTasks = 7;

        final LoadAwareShuffleGrouping grouper = new LoadAwareShuffleGrouping();

        // Task Id not used, so just pick a static value
        final int inputTaskId = 100;
        // Define our taskIds - the test expects these to increase by one, starting from zero
        final List<Integer> availableTaskIds = getAvailableTaskIds(numTasks);
        final LoadMapping loadMapping = buildLocalTasksEvenLoadMapping(availableTaskIds);

        final WorkerTopologyContext context = mockContext(availableTaskIds);
        grouper.prepare(context, null, availableTaskIds);

        // force a rebuild of the ring
        grouper.refreshLoad(loadMapping);

        // The chooseTasks calls should finish before the ring is refreshed again;
        // groupingExecutionsPerThread may need adjusting on a very slow machine.
        // A race between refreshing the ring and choosing tasks is tolerated, so the
        // distribution will not be exactly even, though the difference should be small.
        // If all threadTasks finish before the ring is refreshed, the distribution
        // would be exactly even.
        final int groupingExecutionsPerThread = numTasks * 5000;
        final int numThreads = 10;
        int totalEmits = groupingExecutionsPerThread * numThreads;

        List<Callable<int[]>> threadTasks = Lists.newArrayList();
        for (int x = 0; x < numThreads; x++) {
            Callable<int[]> threadTask = new Callable<int[]>() {
                @Override
                public int[] call() throws Exception {
                    int[] taskCounts = new int[availableTaskIds.size()];
                    for (int i = 1; i <= groupingExecutionsPerThread; i++) {
                        List<Integer> taskIds = grouper.chooseTasks(inputTaskId, Lists.newArrayList());

                        // Validate a single task id return
                        assertNotNull("Not null taskId list returned", taskIds);
                        assertEquals("Single task Id returned", 1, taskIds.size());

                        int taskId = taskIds.get(0);

                        assertTrue("TaskId should exist", taskId >= 0 && taskId < availableTaskIds.size());
                        taskCounts[taskId]++;
                    }
                    return taskCounts;
                }
            };

            // Add to our collection.
            threadTasks.add(threadTask);
        }

        ExecutorService executor = Executors.newFixedThreadPool(threadTasks.size());
        List<Future<int[]>> taskResults = executor.invokeAll(threadTasks);

        // invokeAll blocks until every task has finished; get() just retrieves each result
        int[] taskIdTotals = new int[numTasks];
        for (Future<int[]> taskResult : taskResults) {
            int[] taskDistributions = taskResult.get();
            for (int i = 0; i < taskDistributions.length; i++) {
                taskIdTotals[i] += taskDistributions[i];
            }
        }
        executor.shutdown();

        int minPrCount = (int) (totalEmits * ((1.0 / numTasks) - ACCEPTABLE_MARGIN));
        int maxPrCount = (int) (totalEmits * ((1.0 / numTasks) + ACCEPTABLE_MARGIN));

        for (int i = 0; i < numTasks; i++) {
            assertTrue("Distribution should be even for all nodes with small delta",
                    taskIdTotals[i] >= minPrCount && taskIdTotals[i] <= maxPrCount);
        }
    }

    @Test
    public void testShuffleLoadEven() {
        // port of the Clojure test test-shuffle-load-even
        LoadAwareCustomStreamGrouping shuffler = GrouperFactory.mkGrouper(mockContext(Lists.newArrayList(1, 2)),
                "comp", "stream", null, Grouping.shuffle(new NullStruct()), Lists.newArrayList(1, 2),
                Collections.emptyMap());
        int numMessages = 100000;
        int minPrCount = (int) (numMessages * (0.5 - ACCEPTABLE_MARGIN));
        int maxPrCount = (int) (numMessages * (0.5 + ACCEPTABLE_MARGIN));
        LoadMapping load = new LoadMapping();
        Map<Integer, Double> loadInfoMap = new HashMap<>();
        loadInfoMap.put(1, 0.0);
        loadInfoMap.put(2, 0.0);
        load.setLocal(loadInfoMap);

        // force a rebuild of the ring
        shuffler.refreshLoad(load);

        List<Object> data = Lists.newArrayList(1, 2);
        int[] frequencies = new int[3];
        for (int i = 0; i < numMessages; i++) {
            List<Integer> tasks = shuffler.chooseTasks(1, data);
            for (int task : tasks) {
                frequencies[task]++;
            }
        }

        int load1 = frequencies[1];
        int load2 = frequencies[2];

        LOG.info("Frequency info: load1 = {}, load2 = {}", load1, load2);

        assertTrue(load1 >= minPrCount);
        assertTrue(load1 <= maxPrCount);
        assertTrue(load2 >= minPrCount);
        assertTrue(load2 <= maxPrCount);
    }

    @Ignore
    @Test
    public void testBenchmarkLoadAwareShuffleGroupingEvenLoad() {
        final int numTasks = 10;
        List<Integer> availableTaskIds = getAvailableTaskIds(numTasks);
        runSimpleBenchmark(new LoadAwareShuffleGrouping(), availableTaskIds,
                buildLocalTasksEvenLoadMapping(availableTaskIds));
    }

    @Ignore
    @Test
    public void testBenchmarkLoadAwareShuffleGroupingUnevenLoad() {
        final int numTasks = 10;
        List<Integer> availableTaskIds = getAvailableTaskIds(numTasks);
        runSimpleBenchmark(new LoadAwareShuffleGrouping(), availableTaskIds,
                buildLocalTasksUnevenLoadMapping(availableTaskIds));
    }

    @Ignore
    @Test
    public void testBenchmarkLoadAwareShuffleGroupingEvenLoadAndMultiThreaded()
            throws ExecutionException, InterruptedException {
        final int numTasks = 10;
        final int numThreads = 2;
        List<Integer> availableTaskIds = getAvailableTaskIds(numTasks);
        runMultithreadedBenchmark(new LoadAwareShuffleGrouping(), availableTaskIds,
                buildLocalTasksEvenLoadMapping(availableTaskIds), numThreads);
    }

    @Ignore
    @Test
    public void testBenchmarkLoadAwareShuffleGroupingUnevenLoadAndMultiThreaded()
            throws ExecutionException, InterruptedException {
        final int numTasks = 10;
        final int numThreads = 2;
        List<Integer> availableTaskIds = getAvailableTaskIds(numTasks);
        runMultithreadedBenchmark(new LoadAwareShuffleGrouping(), availableTaskIds,
                buildLocalTasksUnevenLoadMapping(availableTaskIds), numThreads);
    }

    private List<Integer> getAvailableTaskIds(int numTasks) {
        // this method should return sequential numbers starting at 0
        final List<Integer> availableTaskIds = Lists.newArrayList();
        for (int i = 0; i < numTasks; i++) {
            availableTaskIds.add(i);
        }
        return availableTaskIds;
    }

    private LoadMapping buildLocalTasksEvenLoadMapping(List<Integer> availableTasks) {
        LoadMapping loadMapping = new LoadMapping();
        Map<Integer, Double> localLoadMap = new HashMap<>(availableTasks.size());
        for (int i = 0; i < availableTasks.size(); i++) {
            localLoadMap.put(availableTasks.get(i), 0.1);
        }
        loadMapping.setLocal(localLoadMap);
        return loadMapping;
    }

    private LoadMapping buildLocalTasksUnevenLoadMapping(List<Integer> availableTasks) {
        LoadMapping loadMapping = new LoadMapping();
        Map<Integer, Double> localLoadMap = new HashMap<>(availableTasks.size());
        for (int i = 0; i < availableTasks.size(); i++) {
            localLoadMap.put(availableTasks.get(i), 0.1 * (i + 1));
        }
        loadMapping.setLocal(localLoadMap);
        return loadMapping;
    }

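    // Assigns each task a random load in [0, 1).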
    private LoadMapping buildLocalTasksRandomLoadMapping(List<Integer> availableTasks) {
        LoadMapping loadMapping = new LoadMapping();
        Map<Integer, Double> localLoadMap = new HashMap<>(availableTasks.size());
        for (int i = 0; i < availableTasks.size(); i++) {
            localLoadMap.put(availableTasks.get(i), Math.random());
        }
        loadMapping.setLocal(localLoadMap);
        return loadMapping;
    }

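    // Drives chooseTasks() totalEmits times, asserting that every call returns exactly
    // one valid task id, and returns how often each task was picked.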
    private int[] runChooseTasksWithVerification(LoadAwareShuffleGrouping grouper, int totalEmits, int numTasks,
            LoadMapping loadMapping) {
        int[] taskCounts = new int[numTasks];

        // Task Id not used, so just pick a static value
        int inputTaskId = 100;

        // force a rebuild of the ring
        grouper.refreshLoad(loadMapping);

        for (int i = 1; i <= totalEmits; i++) {
            List<Integer> taskIds = grouper.chooseTasks(inputTaskId, Lists.newArrayList());

            // Validate a single task id return
            assertNotNull("Not null taskId list returned", taskIds);
            assertEquals("Single task Id returned", 1, taskIds.size());

            int taskId = taskIds.get(0);

            assertTrue("TaskId should exist", taskId >= 0 && taskId < numTasks);
            taskCounts[taskId]++;
        }
        return taskCounts;
    }

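    // Single-threaded throughput benchmark: warms up for 60 seconds, then times two
    // billion chooseTasks() calls while a scheduled thread refreshes the load every second.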
    private void runSimpleBenchmark(LoadAwareCustomStreamGrouping grouper, List<Integer> availableTaskIds,
            LoadMapping loadMapping) {
        // Task Id not used, so just pick a static value
        final int inputTaskId = 100;

        WorkerTopologyContext context = mockContext(availableTaskIds);
        grouper.prepare(context, null, availableTaskIds);

        // periodically call refreshLoad every second to simulate the worker load update timer
        ScheduledExecutorService refreshService = MoreExecutors
                .getExitingScheduledExecutorService(new ScheduledThreadPoolExecutor(1));
        refreshService.scheduleAtFixedRate(() -> grouper.refreshLoad(loadMapping), 1, 1, TimeUnit.SECONDS);

        long current = System.currentTimeMillis();
        int idx = 0;
        while (true) {
            grouper.chooseTasks(inputTaskId, Lists.newArrayList());

            idx++;
            if (idx % 100000 == 0) {
                // warm up for 60 seconds
                if (System.currentTimeMillis() - current >= 60_000) {
                    break;
                }
            }
        }

        current = System.currentTimeMillis();
        for (int i = 1; i <= 2_000_000_000; i++) {
            grouper.chooseTasks(inputTaskId, Lists.newArrayList());
        }

        LOG.info("Duration: {} ms", (System.currentTimeMillis() - current));

        refreshService.shutdownNow();
    }

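    // Multi-threaded variant of the benchmark above: after the same warm-up, each of
    // numThreads callers times its own two billion chooseTasks() calls and the slowest
    // thread's duration is logged.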
    private void runMultithreadedBenchmark(LoadAwareCustomStreamGrouping grouper, List<Integer> availableTaskIds,
            LoadMapping loadMapping, int numThreads) throws InterruptedException, ExecutionException {
        // Task Id not used, so just pick a static value
        final int inputTaskId = 100;

        final WorkerTopologyContext context = mockContext(availableTaskIds);

        // Call prepare with our available taskIds
        grouper.prepare(context, null, availableTaskIds);

        // periodically call refreshLoad every second to simulate the worker load update timer
        ScheduledExecutorService refreshService = MoreExecutors
                .getExitingScheduledExecutorService(new ScheduledThreadPoolExecutor(1));
        refreshService.scheduleAtFixedRate(() -> grouper.refreshLoad(loadMapping), 1, 1, TimeUnit.SECONDS);

        long current = System.currentTimeMillis();
        int idx = 0;
        while (true) {
            grouper.chooseTasks(inputTaskId, Lists.newArrayList());

            idx++;
            if (idx % 100000 == 0) {
                // warm up for 60 seconds
                if (System.currentTimeMillis() - current >= 60_000) {
                    break;
                }
            }
        }

        final int groupingExecutionsPerThread = 2_000_000_000;

        List<Callable<Long>> threadTasks = Lists.newArrayList();
        for (int x = 0; x < numThreads; x++) {
            Callable<Long> threadTask = new Callable<Long>() {
                @Override
                public Long call() throws Exception {
                    long current = System.currentTimeMillis();
                    for (int i = 1; i <= groupingExecutionsPerThread; i++) {
                        grouper.chooseTasks(inputTaskId, Lists.newArrayList());
                    }
                    return System.currentTimeMillis() - current;
                }
            };

            // Add to our collection.
            threadTasks.add(threadTask);
        }

        ExecutorService executor = Executors.newFixedThreadPool(threadTasks.size());
        List<Future<Long>> taskResults = executor.invokeAll(threadTasks);

        // invokeAll blocks until every task has finished; track the slowest thread
        long maxDurationMillis = 0L;
        for (Future<Long> taskResult : taskResults) {
            long durationMillis = taskResult.get();
            if (maxDurationMillis < durationMillis) {
                maxDurationMillis = durationMillis;
            }
        }
        executor.shutdown();

        LOG.info("Max duration among threads is : {} ms", maxDurationMillis);

        refreshService.shutdownNow();
    }
}
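
Running the test

To run just this test from the Storm source tree, a standard Maven Surefire invocation should work (the module layout is an assumption; in recent Storm releases this class lives under storm-client):

mvn test -Dtest=LoadAwareShuffleGroupingTest

The benchmark methods annotated with @Ignore stay skipped unless the annotations are removed.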