org.apache.tez.dag.library.vertexmanager.TestShuffleVertexManager.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.tez.dag.library.vertexmanager.TestShuffleVertexManager.java.

Source

/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.tez.dag.library.vertexmanager;

import com.google.common.collect.Maps;

import org.apache.hadoop.conf.Configuration;
import org.apache.tez.common.ReflectionUtils;
import org.apache.tez.common.TezUtils;
import org.apache.tez.dag.api.EdgeManagerPlugin;
import org.apache.tez.dag.api.EdgeManagerPluginContext;
import org.apache.tez.dag.api.EdgeManagerPluginDescriptor;
import org.apache.tez.dag.api.EdgeProperty;
import org.apache.tez.dag.api.EdgeProperty.SchedulingType;
import org.apache.tez.dag.api.InputDescriptor;
import org.apache.tez.dag.api.OutputDescriptor;
import org.apache.tez.dag.api.TezUncheckedException;
import org.apache.tez.dag.api.UserPayload;
import org.apache.tez.dag.api.VertexLocationHint;
import org.apache.tez.dag.api.VertexManagerPluginContext;
import org.apache.tez.dag.api.VertexManagerPluginContext.TaskWithLocationHint;
import org.apache.tez.dag.api.VertexManagerPluginDescriptor;
import org.apache.tez.dag.api.event.VertexState;
import org.apache.tez.dag.api.event.VertexStateUpdate;
import org.apache.tez.runtime.api.events.DataMovementEvent;
import org.apache.tez.runtime.api.events.VertexManagerEvent;
import org.apache.tez.runtime.library.shuffle.impl.ShuffleUserPayloads.VertexManagerEventPayloadProto;
import org.junit.Assert;
import org.junit.Test;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;

import static org.mockito.Mockito.any;
import static org.mockito.Mockito.anyInt;
import static org.mockito.Mockito.anyList;
import static org.mockito.Mockito.anyMap;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

public class TestShuffleVertexManager {

    /**
     * Exercises {@link ShuffleVertexManager} auto-parallelism end to end against a mocked
     * {@link VertexManagerPluginContext}.
     *
     * <p>The scenarios, in order:
     * <ul>
     *   <li>configuration through {@code ShuffleVertexManagerConfigBuilder} with auto-reduce
     *       parallelism enabled and disabled, checking the resulting fields and whether Tez is
     *       notified that a reconfiguration is planned;</li>
     *   <li>scheduling when the bipartite sources have zero tasks, both when the source
     *       notification arrives after and before {@code onVertexStarted};</li>
     *   <li>parallelism left unchanged when the reported output is large;</li>
     *   <li>TEZ-978: deferring the parallelism decision until enough output has been reported;</li>
     *   <li>honouring the max slow-start fraction;</li>
     *   <li>parallelism reduced for small output, then routing through the new edge managers.</li>
     * </ul>
     *
     * <p>Every manager created here shares the same {@code mockContext}, so the Mockito
     * {@code verify(..., times(n))} counts are cumulative across the whole method. Statements
     * must therefore stay in their current order.
     */
    @SuppressWarnings({ "unchecked", "rawtypes" })
    @Test(timeout = 5000)
    public void testShuffleVertexManagerAutoParallelism() throws Exception {
        Configuration conf = new Configuration();
        conf.setBoolean(ShuffleVertexManager.TEZ_SHUFFLE_VERTEX_MANAGER_ENABLE_AUTO_PARALLEL, true);
        conf.setLong(ShuffleVertexManager.TEZ_SHUFFLE_VERTEX_MANAGER_DESIRED_TASK_INPUT_SIZE, 1000L);

        // Two scatter-gather (bipartite) sources and one broadcast source feed the managed vertex.
        HashMap<String, EdgeProperty> mockInputVertices = new HashMap<String, EdgeProperty>();
        String mockSrcVertexId1 = "Vertex1";
        EdgeProperty eProp1 = EdgeProperty.create(EdgeProperty.DataMovementType.SCATTER_GATHER,
                EdgeProperty.DataSourceType.PERSISTED, SchedulingType.SEQUENTIAL, OutputDescriptor.create("out"),
                InputDescriptor.create("in"));
        String mockSrcVertexId2 = "Vertex2";
        EdgeProperty eProp2 = EdgeProperty.create(EdgeProperty.DataMovementType.SCATTER_GATHER,
                EdgeProperty.DataSourceType.PERSISTED, SchedulingType.SEQUENTIAL, OutputDescriptor.create("out"),
                InputDescriptor.create("in"));
        String mockSrcVertexId3 = "Vertex3";
        EdgeProperty eProp3 = EdgeProperty.create(EdgeProperty.DataMovementType.BROADCAST,
                EdgeProperty.DataSourceType.PERSISTED, SchedulingType.SEQUENTIAL, OutputDescriptor.create("out"),
                InputDescriptor.create("in"));

        final String mockManagedVertexId = "Vertex4";

        mockInputVertices.put(mockSrcVertexId1, eProp1);
        mockInputVertices.put(mockSrcVertexId2, eProp2);
        mockInputVertices.put(mockSrcVertexId3, eProp3);

        final VertexManagerPluginContext mockContext = mock(VertexManagerPluginContext.class);
        when(mockContext.getInputVertexEdgeProperties()).thenReturn(mockInputVertices);
        when(mockContext.getVertexName()).thenReturn(mockManagedVertexId);
        when(mockContext.getVertexNumTasks(mockManagedVertexId)).thenReturn(4);

        // Check configuration applied via the builder's setters.
        ShuffleVertexManager.ShuffleVertexManagerConfigBuilder configurer = ShuffleVertexManager
                .createConfigBuilder(null);
        VertexManagerPluginDescriptor pluginDesc = configurer.setAutoReduceParallelism(true)
                .setDesiredTaskInputSize(1000L).setMinTaskParallelism(10).setSlowStartMaxSrcCompletionFraction(0.5f)
                .build();
        when(mockContext.getUserPayload()).thenReturn(pluginDesc.getUserPayload());

        ShuffleVertexManager manager = ReflectionUtils.createClazzInstance(pluginDesc.getClassName(),
                new Class[] { VertexManagerPluginContext.class }, new Object[] { mockContext });
        manager.initialize();
        verify(mockContext, times(1)).vertexReconfigurationPlanned(); // Tez notified of reconfig

        Assert.assertTrue(manager.enableAutoParallelism);
        Assert.assertEquals(1000L, manager.desiredTaskInputDataSize);
        Assert.assertEquals(10, manager.minTaskParallelism);
        // Exact comparison (delta 0) is intentional: these are configured constants, not computed values.
        Assert.assertEquals(0.25f, manager.slowStartMinSrcCompletionFraction, 0.0f);
        Assert.assertEquals(0.5f, manager.slowStartMaxSrcCompletionFraction, 0.0f);

        configurer = ShuffleVertexManager.createConfigBuilder(null);
        pluginDesc = configurer.setAutoReduceParallelism(false).build();
        when(mockContext.getUserPayload()).thenReturn(pluginDesc.getUserPayload());

        manager = ReflectionUtils.createClazzInstance(pluginDesc.getClassName(),
                new Class[] { VertexManagerPluginContext.class }, new Object[] { mockContext });
        manager.initialize();
        // Count is still 1: with auto-parallelism disabled Tez is not notified of reconfig.
        verify(mockContext, times(1)).vertexReconfigurationPlanned();

        Assert.assertFalse(manager.enableAutoParallelism);
        Assert.assertEquals(ShuffleVertexManager.TEZ_SHUFFLE_VERTEX_MANAGER_DESIRED_TASK_INPUT_SIZE_DEFAULT,
                manager.desiredTaskInputDataSize);
        Assert.assertEquals(1, manager.minTaskParallelism);
        Assert.assertEquals(ShuffleVertexManager.TEZ_SHUFFLE_VERTEX_MANAGER_MIN_SRC_FRACTION_DEFAULT,
                manager.slowStartMinSrcCompletionFraction, 0.0f);
        Assert.assertEquals(ShuffleVertexManager.TEZ_SHUFFLE_VERTEX_MANAGER_MAX_SRC_FRACTION_DEFAULT,
                manager.slowStartMaxSrcCompletionFraction, 0.0f);

        // Records the indices from the most recent scheduleVertexTasks call (earlier batches are discarded).
        final HashSet<Integer> scheduledTasks = new HashSet<Integer>();
        doAnswer(new Answer() {
            @Override
            public Object answer(InvocationOnMock invocation) {
                Object[] args = invocation.getArguments();
                scheduledTasks.clear();
                List<TaskWithLocationHint> tasks = (List<TaskWithLocationHint>) args[0];
                for (TaskWithLocationHint task : tasks) {
                    scheduledTasks.add(task.getTaskIndex());
                }
                return null;
            }
        }).when(mockContext).scheduleVertexTasks(anyList());

        // Captures the edge managers created when parallelism is reduced to 2, so their routing
        // can be inspected at the end of the test.
        final Map<String, EdgeManagerPlugin> newEdgeManagers = new HashMap<String, EdgeManagerPlugin>();

        doAnswer(new Answer() {
            @Override
            public Object answer(InvocationOnMock invocation) throws Exception {
                // Reflect the new parallelism back through the mocked context.
                when(mockContext.getVertexNumTasks(mockManagedVertexId)).thenReturn(2);
                newEdgeManagers.clear();
                // Argument 2 is a String-keyed map of edge manager descriptors
                // (presumably keyed by source vertex name).
                Map<String, EdgeManagerPluginDescriptor> edgeManagerDescriptors =
                        (Map<String, EdgeManagerPluginDescriptor>) invocation.getArguments()[2];
                for (Entry<String, EdgeManagerPluginDescriptor> entry : edgeManagerDescriptors.entrySet()) {

                    final UserPayload userPayload = entry.getValue().getUserPayload();
                    EdgeManagerPluginContext emContext = new EdgeManagerPluginContext() {
                        @Override
                        public UserPayload getUserPayload() {
                            return userPayload;
                        }

                        @Override
                        public String getSourceVertexName() {
                            return null;
                        }

                        @Override
                        public String getDestinationVertexName() {
                            return null;
                        }

                        @Override
                        public int getSourceVertexNumTasks() {
                            return 2;
                        }

                        @Override
                        public int getDestinationVertexNumTasks() {
                            return 2;
                        }
                    };
                    EdgeManagerPlugin edgeManager = ReflectionUtils.createClazzInstance(
                            entry.getValue().getClassName(), new Class[] { EdgeManagerPluginContext.class },
                            new Object[] { emContext });
                    edgeManager.initialize();
                    newEdgeManagers.put(entry.getKey(), edgeManager);
                }
                return null;
            }
        }).when(mockContext).setVertexParallelism(eq(2), any(VertexLocationHint.class), anyMap(), anyMap());

        // check initialization
        manager = createManager(conf, mockContext, 0.1f, 0.1f); // Tez notified of reconfig
        verify(mockContext, times(2)).vertexReconfigurationPlanned();
        Assert.assertEquals(2, manager.bipartiteSources);

        // source vertices have 0 tasks.
        when(mockContext.getVertexNumTasks(mockSrcVertexId1)).thenReturn(0);
        when(mockContext.getVertexNumTasks(mockSrcVertexId2)).thenReturn(0);
        when(mockContext.getVertexNumTasks(mockSrcVertexId3)).thenReturn(1);

        // check waiting for notification before scheduling
        manager.onVertexStarted(null);
        Assert.assertFalse(manager.pendingTasks.isEmpty());
        // source vertices have 0 tasks. so only 1 notification needed. triggers scheduling
        manager.onVertexStateUpdated(new VertexStateUpdate(mockSrcVertexId3, VertexState.CONFIGURED));
        Assert.assertTrue(manager.pendingTasks.isEmpty());
        verify(mockContext, times(1)).doneReconfiguringVertex(); // reconfig done
        Assert.assertEquals(4, scheduledTasks.size()); // all tasks scheduled
        scheduledTasks.clear();
        // TODO TEZ-1714 locking verify(mockContext, times(1)).vertexManagerDone(); // notified after scheduling all tasks

        // check scheduling only after onVertexStarted
        manager = createManager(conf, mockContext, 0.1f, 0.1f); // Tez notified of reconfig
        verify(mockContext, times(3)).vertexReconfigurationPlanned();
        Assert.assertEquals(2, manager.bipartiteSources);
        // source vertices have 0 tasks. so only 1 notification needed. does not trigger scheduling
        manager.onVertexStateUpdated(new VertexStateUpdate(mockSrcVertexId3, VertexState.CONFIGURED));
        verify(mockContext, times(1)).doneReconfiguringVertex(); // count unchanged: not yet reconfigured
        Assert.assertEquals(0, scheduledTasks.size()); // no tasks scheduled
        manager.onVertexStarted(null);
        verify(mockContext, times(2)).doneReconfiguringVertex(); // reconfig done
        Assert.assertTrue(manager.pendingTasks.isEmpty());
        Assert.assertEquals(4, scheduledTasks.size()); // all tasks scheduled

        when(mockContext.getVertexNumTasks(mockSrcVertexId1)).thenReturn(2);
        when(mockContext.getVertexNumTasks(mockSrcVertexId2)).thenReturn(2);

        ByteBuffer payload = VertexManagerEventPayloadProto.newBuilder().setOutputSize(5000L).build().toByteString()
                .asReadOnlyByteBuffer();
        VertexManagerEvent vmEvent = VertexManagerEvent.create("Vertex", payload);
        // parallelism not change due to large data size
        manager = createManager(conf, mockContext, 0.1f, 0.1f);
        verify(mockContext, times(4)).vertexReconfigurationPlanned(); // Tez notified of reconfig
        manager.onVertexStarted(null);
        Assert.assertEquals(4, manager.pendingTasks.size()); // no tasks scheduled
        Assert.assertEquals(4, manager.totalNumBipartiteSourceTasks);
        manager.onVertexManagerEventReceived(vmEvent);

        manager.onVertexStateUpdated(new VertexStateUpdate(mockSrcVertexId1, VertexState.CONFIGURED));
        manager.onVertexStateUpdated(new VertexStateUpdate(mockSrcVertexId2, VertexState.CONFIGURED));
        manager.onSourceTaskCompleted(mockSrcVertexId1, 0);
        verify(mockContext, times(0)).setVertexParallelism(anyInt(), any(VertexLocationHint.class), anyMap(),
                anyMap());
        verify(mockContext, times(2)).doneReconfiguringVertex();
        // trigger scheduling
        manager.onVertexStateUpdated(new VertexStateUpdate(mockSrcVertexId3, VertexState.CONFIGURED));
        verify(mockContext, times(0)).setVertexParallelism(anyInt(), any(VertexLocationHint.class), anyMap(),
                anyMap());
        verify(mockContext, times(3)).doneReconfiguringVertex(); // reconfig done
        Assert.assertEquals(0, manager.pendingTasks.size()); // all tasks scheduled
        Assert.assertEquals(4, scheduledTasks.size());
        // TODO TEZ-1714 locking verify(mockContext, times(2)).vertexManagerDone(); // notified after scheduling all tasks
        Assert.assertEquals(1, manager.numBipartiteSourceTasksCompleted);
        Assert.assertEquals(5000L, manager.completedSourceTasksOutputSize);

        /**
         * Test for TEZ-978
         * Delay determining parallelism until enough data has been received.
         */
        scheduledTasks.clear();
        payload = VertexManagerEventPayloadProto.newBuilder().setOutputSize(1L).build().toByteString()
                .asReadOnlyByteBuffer();
        vmEvent = VertexManagerEvent.create("Vertex", payload);

        // min/max fraction of 0.01/0.75 ensures the determineParallelism code path is hit on the first event.
        manager = createManager(conf, mockContext, 0.01f, 0.75f);
        manager.onVertexStarted(null);
        Assert.assertEquals(4, manager.pendingTasks.size()); // no tasks scheduled
        Assert.assertEquals(4, manager.totalNumBipartiteSourceTasks);
        Assert.assertEquals(0, manager.numBipartiteSourceTasksCompleted);

        // First task in src1 completed with small payload
        manager.onVertexManagerEventReceived(vmEvent); //small payload
        manager.onSourceTaskCompleted(mockSrcVertexId1, 0);
        Assert.assertFalse(manager.determineParallelismAndApply());
        Assert.assertEquals(4, manager.pendingTasks.size());
        Assert.assertEquals(0, scheduledTasks.size()); // no tasks scheduled
        Assert.assertEquals(1, manager.numBipartiteSourceTasksCompleted);
        Assert.assertEquals(1, manager.numVertexManagerEventsReceived);
        Assert.assertEquals(1L, manager.completedSourceTasksOutputSize);

        // A second small-payload event arrives. The completion re-reports src1 task 0, a duplicate,
        // so the completed-task count stays at 1 while the event count and output size grow.
        manager.onVertexManagerEventReceived(vmEvent); //small payload
        manager.onSourceTaskCompleted(mockSrcVertexId1, 0);
        // Overall data gathered has still not reached the threshold, so parallelism is not yet determined.
        Assert.assertFalse(manager.determineParallelismAndApply());
        Assert.assertEquals(4, manager.pendingTasks.size());
        Assert.assertEquals(0, scheduledTasks.size()); // no tasks scheduled
        Assert.assertEquals(1, manager.numBipartiteSourceTasksCompleted);
        Assert.assertEquals(2, manager.numVertexManagerEventsReceived);
        Assert.assertEquals(2L, manager.completedSourceTasksOutputSize);

        // A larger-payload event pushes the gathered output over the threshold, so determining
        // parallelism now succeeds. The matching src2 task 0 completion is reported further below.
        payload = VertexManagerEventPayloadProto.newBuilder().setOutputSize(1200L).build().toByteString()
                .asReadOnlyByteBuffer();
        vmEvent = VertexManagerEvent.create("Vertex", payload);
        manager.onVertexManagerEventReceived(vmEvent);
        Assert.assertTrue(manager.determineParallelismAndApply()); //ensure parallelism is determined
        verify(mockContext, times(1)).setVertexParallelism(eq(2), any(VertexLocationHint.class), anyMap(),
                anyMap());
        manager.onVertexStateUpdated(new VertexStateUpdate(mockSrcVertexId1, VertexState.CONFIGURED));
        manager.onVertexStateUpdated(new VertexStateUpdate(mockSrcVertexId2, VertexState.CONFIGURED));
        manager.onVertexStateUpdated(new VertexStateUpdate(mockSrcVertexId3, VertexState.CONFIGURED));
        manager.onSourceTaskCompleted(mockSrcVertexId2, 0);
        Assert.assertEquals(1, manager.pendingTasks.size());
        Assert.assertEquals(1, scheduledTasks.size());
        Assert.assertEquals(2, manager.numBipartiteSourceTasksCompleted);
        Assert.assertEquals(3, manager.numVertexManagerEventsReceived);
        Assert.assertEquals(1202L, manager.completedSourceTasksOutputSize);

        // Test for max fraction. Min fraction is just an instruction to the framework, but max fraction is honored.
        when(mockContext.getVertexNumTasks(mockSrcVertexId1)).thenReturn(20);
        when(mockContext.getVertexNumTasks(mockSrcVertexId2)).thenReturn(20);
        when(mockContext.getVertexNumTasks(mockManagedVertexId)).thenReturn(40);
        scheduledTasks.clear();
        payload = VertexManagerEventPayloadProto.newBuilder().setOutputSize(100L).build().toByteString()
                .asReadOnlyByteBuffer();
        vmEvent = VertexManagerEvent.create("Vertex", payload);

        // min/max fraction of 0.0/0.2
        manager = createManager(conf, mockContext, 0.0f, 0.2f);
        manager.onVertexStarted(null);
        manager.onVertexStateUpdated(new VertexStateUpdate(mockSrcVertexId1, VertexState.CONFIGURED));
        manager.onVertexStateUpdated(new VertexStateUpdate(mockSrcVertexId2, VertexState.CONFIGURED));
        manager.onVertexStateUpdated(new VertexStateUpdate(mockSrcVertexId3, VertexState.CONFIGURED));
        Assert.assertEquals(40, manager.pendingTasks.size()); // no tasks scheduled
        Assert.assertEquals(40, manager.totalNumBipartiteSourceTasks);
        Assert.assertEquals(0, manager.numBipartiteSourceTasksCompleted);
        // send 7 events with payload size as 100
        for (int i = 0; i < 7; i++) {
            manager.onVertexManagerEventReceived(vmEvent); //small payload
            manager.onSourceTaskCompleted(mockSrcVertexId1, i);
            // should not change parallelism
            verify(mockContext, times(0)).setVertexParallelism(eq(4), any(VertexLocationHint.class), anyMap(),
                    anyMap());
        }
        // send 8th event with payload size as 100
        manager.onVertexManagerEventReceived(vmEvent);
        manager.onSourceTaskCompleted(mockSrcVertexId2, 8);
        // Since max threshold (40 * 0.2 = 8) is met, vertex manager should determine parallelism
        verify(mockContext, times(1)).setVertexParallelism(eq(4), any(VertexLocationHint.class), anyMap(),
                anyMap());

        // reset context for next test
        when(mockContext.getVertexNumTasks(mockSrcVertexId1)).thenReturn(2);
        when(mockContext.getVertexNumTasks(mockSrcVertexId2)).thenReturn(2);
        when(mockContext.getVertexNumTasks(mockManagedVertexId)).thenReturn(4);

        // parallelism changed due to small data size
        scheduledTasks.clear();
        payload = VertexManagerEventPayloadProto.newBuilder().setOutputSize(500L).build().toByteString()
                .asReadOnlyByteBuffer();
        vmEvent = VertexManagerEvent.create("Vertex", payload);

        manager = createManager(conf, mockContext, 0.5f, 0.5f);
        manager.onVertexStarted(null);
        manager.onVertexStateUpdated(new VertexStateUpdate(mockSrcVertexId1, VertexState.CONFIGURED));
        manager.onVertexStateUpdated(new VertexStateUpdate(mockSrcVertexId2, VertexState.CONFIGURED));
        manager.onVertexStateUpdated(new VertexStateUpdate(mockSrcVertexId3, VertexState.CONFIGURED));
        Assert.assertEquals(4, manager.pendingTasks.size()); // no tasks scheduled
        Assert.assertEquals(4, manager.totalNumBipartiteSourceTasks);
        // task completion from non-bipartite stage does nothing
        manager.onSourceTaskCompleted(mockSrcVertexId3, 0);
        Assert.assertEquals(4, manager.pendingTasks.size()); // no tasks scheduled
        Assert.assertEquals(4, manager.totalNumBipartiteSourceTasks);
        Assert.assertEquals(0, manager.numBipartiteSourceTasksCompleted);
        manager.onVertexManagerEventReceived(vmEvent);
        manager.onSourceTaskCompleted(mockSrcVertexId1, 0);
        Assert.assertEquals(4, manager.pendingTasks.size());
        Assert.assertEquals(0, scheduledTasks.size()); // no tasks scheduled
        Assert.assertEquals(1, manager.numBipartiteSourceTasksCompleted);
        Assert.assertEquals(1, manager.numVertexManagerEventsReceived);
        Assert.assertEquals(500L, manager.completedSourceTasksOutputSize);
        // ignore duplicate completion
        manager.onSourceTaskCompleted(mockSrcVertexId1, 0);
        Assert.assertEquals(4, manager.pendingTasks.size());
        Assert.assertEquals(0, scheduledTasks.size()); // no tasks scheduled
        Assert.assertEquals(1, manager.numBipartiteSourceTasksCompleted);
        Assert.assertEquals(500L, manager.completedSourceTasksOutputSize);

        manager.onVertexManagerEventReceived(vmEvent);
        manager.onSourceTaskCompleted(mockSrcVertexId2, 1);
        // managedVertex tasks reduced
        verify(mockContext, times(2)).setVertexParallelism(eq(2), any(VertexLocationHint.class), anyMap(),
                anyMap());
        Assert.assertEquals(2, newEdgeManagers.size());
        // TODO improve tests for parallelism
        Assert.assertEquals(0, manager.pendingTasks.size()); // all tasks scheduled
        Assert.assertEquals(2, scheduledTasks.size());
        Assert.assertTrue(scheduledTasks.contains(0));
        Assert.assertTrue(scheduledTasks.contains(1));
        Assert.assertEquals(2, manager.numBipartiteSourceTasksCompleted);
        Assert.assertEquals(2, manager.numVertexManagerEventsReceived);
        Assert.assertEquals(1000L, manager.completedSourceTasksOutputSize);

        // more completions dont cause recalculation of parallelism
        manager.onSourceTaskCompleted(mockSrcVertexId2, 0);
        verify(mockContext, times(2)).setVertexParallelism(eq(2), any(VertexLocationHint.class), anyMap(),
                anyMap());
        Assert.assertEquals(2, newEdgeManagers.size());

        EdgeManagerPlugin edgeManager = newEdgeManagers.values().iterator().next();
        Map<Integer, List<Integer>> targets = Maps.newHashMap();
        DataMovementEvent dmEvent = DataMovementEvent.create(1, ByteBuffer.wrap(new byte[0]));
        // 4 source task outputs - same as original number of partitions
        Assert.assertEquals(4, edgeManager.getNumSourceTaskPhysicalOutputs(0));
        // 4 destination task inputs - 2 source tasks, each contributing 2 merged partitions
        Assert.assertEquals(4, edgeManager.getNumDestinationTaskPhysicalInputs(0));
        edgeManager.routeDataMovementEventToDestination(dmEvent, 1, dmEvent.getSourceIndex(), targets);
        Assert.assertEquals(1, targets.size());
        Map.Entry<Integer, List<Integer>> e = targets.entrySet().iterator().next();
        Assert.assertEquals(0, e.getKey().intValue());
        Assert.assertEquals(1, e.getValue().size());
        Assert.assertEquals(3, e.getValue().get(0).intValue());
        targets.clear();
        dmEvent = DataMovementEvent.create(2, ByteBuffer.wrap(new byte[0]));
        edgeManager.routeDataMovementEventToDestination(dmEvent, 0, dmEvent.getSourceIndex(), targets);
        Assert.assertEquals(1, targets.size());
        e = targets.entrySet().iterator().next();
        Assert.assertEquals(1, e.getKey().intValue());
        Assert.assertEquals(1, e.getValue().size());
        Assert.assertEquals(0, e.getValue().get(0).intValue());
        targets.clear();
        edgeManager.routeInputSourceTaskFailedEventToDestination(2, targets);
        Assert.assertEquals(2, targets.size());
        for (Map.Entry<Integer, List<Integer>> entry : targets.entrySet()) {
            int destTask = entry.getKey().intValue();
            Assert.assertTrue("unexpected destination task " + destTask, destTask == 0 || destTask == 1);
            Assert.assertEquals(2, entry.getValue().size());
            Assert.assertEquals(4, entry.getValue().get(0).intValue());
            Assert.assertEquals(5, entry.getValue().get(1).intValue());
        }
    }

    @SuppressWarnings({ "unchecked", "rawtypes" })
    @Test(timeout = 5000)
    public void testShuffleVertexManagerSlowStart() {
        Configuration conf = new Configuration();
        ShuffleVertexManager manager = null;
        HashMap<String, EdgeProperty> mockInputVertices = new HashMap<String, EdgeProperty>();
        String mockSrcVertexId1 = "Vertex1";
        EdgeProperty eProp1 = EdgeProperty.create(EdgeProperty.DataMovementType.SCATTER_GATHER,
                EdgeProperty.DataSourceType.PERSISTED, SchedulingType.SEQUENTIAL, OutputDescriptor.create("out"),
                InputDescriptor.create("in"));
        String mockSrcVertexId2 = "Vertex2";
        EdgeProperty eProp2 = EdgeProperty.create(EdgeProperty.DataMovementType.SCATTER_GATHER,
                EdgeProperty.DataSourceType.PERSISTED, SchedulingType.SEQUENTIAL, OutputDescriptor.create("out"),
                InputDescriptor.create("in"));
        String mockSrcVertexId3 = "Vertex3";
        EdgeProperty eProp3 = EdgeProperty.create(EdgeProperty.DataMovementType.BROADCAST,
                EdgeProperty.DataSourceType.PERSISTED, SchedulingType.SEQUENTIAL, OutputDescriptor.create("out"),
                InputDescriptor.create("in"));

        String mockManagedVertexId = "Vertex4";

        VertexManagerPluginContext mockContext = mock(VertexManagerPluginContext.class);
        when(mockContext.getInputVertexEdgeProperties()).thenReturn(mockInputVertices);
        when(mockContext.getVertexName()).thenReturn(mockManagedVertexId);
        when(mockContext.getVertexNumTasks(mockManagedVertexId)).thenReturn(3);

        // fail if there is no bipartite src vertex
        mockInputVertices.put(mockSrcVertexId3, eProp3);
        try {
            manager = createManager(conf, mockContext, 0.1f, 0.1f);
            Assert.assertFalse(true);
        } catch (TezUncheckedException e) {
            Assert.assertTrue(e.getMessage().contains("Atleast 1 bipartite source should exist"));
        }

        mockInputVertices.put(mockSrcVertexId1, eProp1);
        mockInputVertices.put(mockSrcVertexId2, eProp2);

        // check initialization
        manager = createManager(conf, mockContext, 0.1f, 0.1f);
        Assert.assertTrue(manager.bipartiteSources == 2);

        final HashSet<Integer> scheduledTasks = new HashSet<Integer>();
        doAnswer(new Answer() {
            public Object answer(InvocationOnMock invocation) {
                Object[] args = invocation.getArguments();
                scheduledTasks.clear();
                List<TaskWithLocationHint> tasks = (List<TaskWithLocationHint>) args[0];
                for (TaskWithLocationHint task : tasks) {
                    scheduledTasks.add(task.getTaskIndex());
                }
                return null;
            }
        }).when(mockContext).scheduleVertexTasks(anyList());

        // source vertices have 0 tasks. immediate start of all managed tasks
        when(mockContext.getVertexNumTasks(mockSrcVertexId1)).thenReturn(0);
        when(mockContext.getVertexNumTasks(mockSrcVertexId2)).thenReturn(0);
        manager.onVertexStarted(null);
        Assert.assertTrue(manager.pendingTasks.isEmpty());
        Assert.assertTrue(scheduledTasks.size() == 3); // all tasks scheduled

        when(mockContext.getVertexNumTasks(mockSrcVertexId1)).thenReturn(2);
        when(mockContext.getVertexNumTasks(mockSrcVertexId2)).thenReturn(2);

        try {
            // source vertex have some tasks. min < 0.
            manager = createManager(conf, mockContext, -0.1f, 0);
            Assert.assertTrue(false); // should not come here
        } catch (IllegalArgumentException e) {
            Assert.assertTrue(e.getMessage().contains("Invalid values for slowStartMinSrcCompletionFraction"));
        }

        try {
            // source vertex have some tasks. min > max
            manager = createManager(conf, mockContext, 0.5f, 0.3f);
            Assert.assertTrue(false); // should not come here
        } catch (IllegalArgumentException e) {
            Assert.assertTrue(e.getMessage().contains("Invalid values for slowStartMinSrcCompletionFraction"));
        }

        // source vertex have some tasks. min, max == 0
        manager = createManager(conf, mockContext, 0, 0);
        manager.onVertexStarted(null);
        Assert.assertTrue(manager.totalNumBipartiteSourceTasks == 4);
        Assert.assertTrue(manager.totalTasksToSchedule == 3);
        Assert.assertTrue(manager.numBipartiteSourceTasksCompleted == 0);
        // all source vertices need to be configured
        manager.onVertexStateUpdated(new VertexStateUpdate(mockSrcVertexId1, VertexState.CONFIGURED));
        manager.onVertexStateUpdated(new VertexStateUpdate(mockSrcVertexId2, VertexState.CONFIGURED));
        manager.onVertexStateUpdated(new VertexStateUpdate(mockSrcVertexId3, VertexState.CONFIGURED));
        Assert.assertTrue(manager.pendingTasks.isEmpty());
        Assert.assertTrue(scheduledTasks.size() == 3); // all tasks scheduled

        // min, max > 0 and min == max
        manager = createManager(conf, mockContext, 0.25f, 0.25f);
        manager.onVertexStarted(null);
        manager.onVertexStateUpdated(new VertexStateUpdate(mockSrcVertexId1, VertexState.CONFIGURED));
        manager.onVertexStateUpdated(new VertexStateUpdate(mockSrcVertexId2, VertexState.CONFIGURED));
        manager.onVertexStateUpdated(new VertexStateUpdate(mockSrcVertexId3, VertexState.CONFIGURED));
        Assert.assertTrue(manager.pendingTasks.size() == 3); // no tasks scheduled
        Assert.assertTrue(manager.totalNumBipartiteSourceTasks == 4);
        // task completion from non-bipartite stage does nothing
        manager.onSourceTaskCompleted(mockSrcVertexId3, new Integer(0));
        Assert.assertTrue(manager.pendingTasks.size() == 3); // no tasks scheduled
        Assert.assertTrue(manager.totalNumBipartiteSourceTasks == 4);
        Assert.assertTrue(manager.numBipartiteSourceTasksCompleted == 0);
        manager.onSourceTaskCompleted(mockSrcVertexId1, new Integer(0));
        Assert.assertTrue(manager.pendingTasks.isEmpty());
        Assert.assertTrue(scheduledTasks.size() == 3); // all tasks scheduled
        Assert.assertTrue(manager.numBipartiteSourceTasksCompleted == 1);

        // min, max > 0 and min == max == absolute max 1.0
        manager = createManager(conf, mockContext, 1.0f, 1.0f);
        manager.onVertexStarted(null);
        manager.onVertexStateUpdated(new VertexStateUpdate(mockSrcVertexId1, VertexState.CONFIGURED));
        manager.onVertexStateUpdated(new VertexStateUpdate(mockSrcVertexId2, VertexState.CONFIGURED));
        manager.onVertexStateUpdated(new VertexStateUpdate(mockSrcVertexId3, VertexState.CONFIGURED));
        Assert.assertTrue(manager.pendingTasks.size() == 3); // no tasks scheduled
        Assert.assertTrue(manager.totalNumBipartiteSourceTasks == 4);
        // task completion from non-bipartite stage does nothing
        manager.onSourceTaskCompleted(mockSrcVertexId3, new Integer(0));
        Assert.assertTrue(manager.pendingTasks.size() == 3); // no tasks scheduled
        Assert.assertTrue(manager.totalNumBipartiteSourceTasks == 4);
        Assert.assertTrue(manager.numBipartiteSourceTasksCompleted == 0);
        manager.onSourceTaskCompleted(mockSrcVertexId1, new Integer(0));
        Assert.assertTrue(manager.pendingTasks.size() == 3);
        Assert.assertTrue(manager.numBipartiteSourceTasksCompleted == 1);
        manager.onSourceTaskCompleted(mockSrcVertexId1, new Integer(1));
        Assert.assertTrue(manager.pendingTasks.size() == 3);
        Assert.assertTrue(manager.numBipartiteSourceTasksCompleted == 2);
        manager.onSourceTaskCompleted(mockSrcVertexId2, new Integer(0));
        Assert.assertTrue(manager.pendingTasks.size() == 3);
        Assert.assertTrue(manager.numBipartiteSourceTasksCompleted == 3);
        manager.onSourceTaskCompleted(mockSrcVertexId2, new Integer(1));
        Assert.assertTrue(manager.pendingTasks.isEmpty());
        Assert.assertTrue(scheduledTasks.size() == 3); // all tasks scheduled
        Assert.assertTrue(manager.numBipartiteSourceTasksCompleted == 4);

        // min, max > 0 and min == max
        manager = createManager(conf, mockContext, 1.0f, 1.0f);
        manager.onVertexStarted(null);
        manager.onVertexStateUpdated(new VertexStateUpdate(mockSrcVertexId1, VertexState.CONFIGURED));
        manager.onVertexStateUpdated(new VertexStateUpdate(mockSrcVertexId2, VertexState.CONFIGURED));
        manager.onVertexStateUpdated(new VertexStateUpdate(mockSrcVertexId3, VertexState.CONFIGURED));
        Assert.assertTrue(manager.pendingTasks.size() == 3); // no tasks scheduled
        Assert.assertTrue(manager.totalNumBipartiteSourceTasks == 4);
        // task completion from non-bipartite stage does nothing
        manager.onSourceTaskCompleted(mockSrcVertexId3, new Integer(0));
        Assert.assertTrue(manager.pendingTasks.size() == 3); // no tasks scheduled
        Assert.assertTrue(manager.totalNumBipartiteSourceTasks == 4);
        Assert.assertTrue(manager.numBipartiteSourceTasksCompleted == 0);
        manager.onSourceTaskCompleted(mockSrcVertexId1, new Integer(0));
        Assert.assertTrue(manager.pendingTasks.size() == 3);
        Assert.assertTrue(manager.numBipartiteSourceTasksCompleted == 1);
        manager.onSourceTaskCompleted(mockSrcVertexId1, new Integer(1));
        Assert.assertTrue(manager.pendingTasks.size() == 3);
        Assert.assertTrue(manager.numBipartiteSourceTasksCompleted == 2);
        manager.onSourceTaskCompleted(mockSrcVertexId2, new Integer(0));
        Assert.assertTrue(manager.pendingTasks.size() == 3);
        Assert.assertTrue(manager.numBipartiteSourceTasksCompleted == 3);
        manager.onSourceTaskCompleted(mockSrcVertexId2, new Integer(1));
        Assert.assertTrue(manager.pendingTasks.isEmpty());
        Assert.assertTrue(scheduledTasks.size() == 3); // all tasks scheduled
        Assert.assertTrue(manager.numBipartiteSourceTasksCompleted == 4);

        // min, max > and min < max
        manager = createManager(conf, mockContext, 0.25f, 0.75f);
        manager.onVertexStarted(null);
        manager.onVertexStateUpdated(new VertexStateUpdate(mockSrcVertexId1, VertexState.CONFIGURED));
        manager.onVertexStateUpdated(new VertexStateUpdate(mockSrcVertexId2, VertexState.CONFIGURED));
        manager.onVertexStateUpdated(new VertexStateUpdate(mockSrcVertexId3, VertexState.CONFIGURED));
        Assert.assertTrue(manager.pendingTasks.size() == 3); // no tasks scheduled
        Assert.assertTrue(manager.totalNumBipartiteSourceTasks == 4);
        manager.onSourceTaskCompleted(mockSrcVertexId1, new Integer(0));
        manager.onSourceTaskCompleted(mockSrcVertexId2, new Integer(1));
        Assert.assertTrue(manager.pendingTasks.size() == 2);
        Assert.assertTrue(scheduledTasks.size() == 1); // 1 task scheduled
        Assert.assertTrue(manager.numBipartiteSourceTasksCompleted == 2);
        // completion of same task again should not get counted
        manager.onSourceTaskCompleted(mockSrcVertexId2, new Integer(1));
        Assert.assertTrue(manager.pendingTasks.size() == 2);
        Assert.assertTrue(scheduledTasks.size() == 1); // 1 task scheduled
        Assert.assertTrue(manager.numBipartiteSourceTasksCompleted == 2);
        manager.onSourceTaskCompleted(mockSrcVertexId2, new Integer(0));
        Assert.assertTrue(manager.pendingTasks.size() == 0);
        Assert.assertTrue(scheduledTasks.size() == 2); // 2 tasks scheduled
        Assert.assertTrue(manager.numBipartiteSourceTasksCompleted == 3);
        scheduledTasks.clear();
        manager.onSourceTaskCompleted(mockSrcVertexId1, new Integer(1)); // we are done. no action
        Assert.assertTrue(manager.pendingTasks.size() == 0);
        Assert.assertTrue(scheduledTasks.size() == 0); // no task scheduled
        Assert.assertTrue(manager.numBipartiteSourceTasksCompleted == 4);

        // min, max > and min < max
        manager = createManager(conf, mockContext, 0.25f, 1.0f);
        manager.onVertexStarted(null);
        manager.onVertexStateUpdated(new VertexStateUpdate(mockSrcVertexId1, VertexState.CONFIGURED));
        manager.onVertexStateUpdated(new VertexStateUpdate(mockSrcVertexId2, VertexState.CONFIGURED));
        manager.onVertexStateUpdated(new VertexStateUpdate(mockSrcVertexId3, VertexState.CONFIGURED));
        Assert.assertTrue(manager.pendingTasks.size() == 3); // no tasks scheduled
        Assert.assertTrue(manager.totalNumBipartiteSourceTasks == 4);
        manager.onSourceTaskCompleted(mockSrcVertexId1, new Integer(0));
        manager.onSourceTaskCompleted(mockSrcVertexId2, new Integer(1));
        Assert.assertTrue(manager.pendingTasks.size() == 2);
        Assert.assertTrue(scheduledTasks.size() == 1); // 1 task scheduled
        Assert.assertTrue(manager.numBipartiteSourceTasksCompleted == 2);
        manager.onSourceTaskCompleted(mockSrcVertexId2, new Integer(0));
        Assert.assertTrue(manager.pendingTasks.size() == 1);
        Assert.assertTrue(scheduledTasks.size() == 1); // 1 task scheduled
        Assert.assertTrue(manager.numBipartiteSourceTasksCompleted == 3);
        manager.onSourceTaskCompleted(mockSrcVertexId1, new Integer(1));
        Assert.assertTrue(manager.pendingTasks.size() == 0);
        Assert.assertTrue(scheduledTasks.size() == 1); // no task scheduled
        Assert.assertTrue(manager.numBipartiteSourceTasksCompleted == 4);

    }

    /**
     * TEZ-1649: tasks should be scheduled only when all source vertices are configured completely.
     *
     * <p>R2 has three scatter-gather inputs (R1, M2, M3). Even after source-task completions satisfy
     * the min/max source fractions, no task of R2 is scheduled and its parallelism is not changed
     * until R1, the last source vertex with tasks, reports CONFIGURED. The second scenario checks
     * that source vertices with zero tasks are not waited on.
     */
    @Test(timeout = 5000)
    public void test_Tez1649_with_scatter_gather_edges() {
        Configuration conf = new Configuration();
        conf.setBoolean(ShuffleVertexManager.TEZ_SHUFFLE_VERTEX_MANAGER_ENABLE_AUTO_PARALLEL, true);
        conf.setLong(ShuffleVertexManager.TEZ_SHUFFLE_VERTEX_MANAGER_DESIRED_TASK_INPUT_SIZE, 1000L);

        HashMap<String, EdgeProperty> mockInputVertices_R2 = new HashMap<String, EdgeProperty>();
        String r1 = "R1";
        EdgeProperty eProp1 = EdgeProperty.create(EdgeProperty.DataMovementType.SCATTER_GATHER,
                EdgeProperty.DataSourceType.PERSISTED, SchedulingType.SEQUENTIAL, OutputDescriptor.create("out"),
                InputDescriptor.create("in"));
        String m2 = "M2";
        EdgeProperty eProp2 = EdgeProperty.create(EdgeProperty.DataMovementType.SCATTER_GATHER,
                EdgeProperty.DataSourceType.PERSISTED, SchedulingType.SEQUENTIAL, OutputDescriptor.create("out"),
                InputDescriptor.create("in"));
        String m3 = "M3";
        EdgeProperty eProp3 = EdgeProperty.create(EdgeProperty.DataMovementType.SCATTER_GATHER,
                EdgeProperty.DataSourceType.PERSISTED, SchedulingType.SEQUENTIAL, OutputDescriptor.create("out"),
                InputDescriptor.create("in"));

        final String mockManagedVertexId_R2 = "R2";
        mockInputVertices_R2.put(r1, eProp1);
        mockInputVertices_R2.put(m2, eProp2);
        mockInputVertices_R2.put(m3, eProp3);

        final VertexManagerPluginContext mockContext_R2 = mock(VertexManagerPluginContext.class);
        when(mockContext_R2.getInputVertexEdgeProperties()).thenReturn(mockInputVertices_R2);
        when(mockContext_R2.getVertexName()).thenReturn(mockManagedVertexId_R2);
        when(mockContext_R2.getVertexNumTasks(mockManagedVertexId_R2)).thenReturn(3);
        when(mockContext_R2.getVertexNumTasks(r1)).thenReturn(3);
        when(mockContext_R2.getVertexNumTasks(m2)).thenReturn(3);
        when(mockContext_R2.getVertexNumTasks(m3)).thenReturn(3);

        final Map<String, EdgeManagerPlugin> edgeManagerR2 = new HashMap<String, EdgeManagerPlugin>();
        // On setVertexParallelism(2, ...): report 2 tasks for R2, then instantiate and initialize an
        // edge manager plugin from each supplied descriptor (argument 2), keyed by source vertex name.
        // NOTE(review): this answer only matches a parallelism of 2, while the verification below
        // expects setVertexParallelism(1, ...); confirm whether this stub is still exercised.
        doAnswer(new Answer<Object>() {
            @Override
            public Object answer(InvocationOnMock invocation) throws Exception {
                when(mockContext_R2.getVertexNumTasks(mockManagedVertexId_R2)).thenReturn(2);
                edgeManagerR2.clear();
                @SuppressWarnings("unchecked")
                Map<String, EdgeManagerPluginDescriptor> descriptors =
                        (Map<String, EdgeManagerPluginDescriptor>) invocation.getArguments()[2];
                for (Entry<String, EdgeManagerPluginDescriptor> entry : descriptors.entrySet()) {

                    final UserPayload userPayload = entry.getValue().getUserPayload();
                    EdgeManagerPluginContext emContext = new EdgeManagerPluginContext() {
                        @Override
                        public UserPayload getUserPayload() {
                            return userPayload;
                        }

                        @Override
                        public String getSourceVertexName() {
                            return null;
                        }

                        @Override
                        public String getDestinationVertexName() {
                            return null;
                        }

                        @Override
                        public int getSourceVertexNumTasks() {
                            return 2;
                        }

                        @Override
                        public int getDestinationVertexNumTasks() {
                            return 2;
                        }
                    };
                    EdgeManagerPlugin edgeManager = ReflectionUtils.createClazzInstance(
                            entry.getValue().getClassName(), new Class[] { EdgeManagerPluginContext.class },
                            new Object[] { emContext });
                    edgeManager.initialize();
                    edgeManagerR2.put(entry.getKey(), edgeManager);
                }
                return null;
            }
        }).when(mockContext_R2).setVertexParallelism(eq(2), any(VertexLocationHint.class), anyMap(), anyMap());

        // A single VertexManagerEvent reporting 50 bytes of source output.
        ByteBuffer payload = VertexManagerEventPayloadProto.newBuilder().setOutputSize(50L).build().toByteString()
                .asReadOnlyByteBuffer();
        VertexManagerEvent vmEvent = VertexManagerEvent.create("Vertex", payload);

        // check initialization: all three inputs are scatter-gather (bipartite) sources
        ShuffleVertexManager manager = createManager(conf, mockContext_R2, 0.001f, 0.001f);
        Assert.assertEquals(3, manager.bipartiteSources);

        // Records the task indices passed to the most recent scheduleVertexTasks call.
        final HashSet<Integer> scheduledTasks = new HashSet<Integer>();
        doAnswer(new Answer<Object>() {
            @Override
            public Object answer(InvocationOnMock invocation) {
                scheduledTasks.clear();
                @SuppressWarnings("unchecked")
                List<TaskWithLocationHint> tasks = (List<TaskWithLocationHint>) invocation.getArguments()[0];
                for (TaskWithLocationHint task : tasks) {
                    scheduledTasks.add(task.getTaskIndex());
                }
                return null;
            }
        }).when(mockContext_R2).scheduleVertexTasks(anyList());

        manager.onVertexStarted(null);
        manager.onVertexStateUpdated(new VertexStateUpdate(m2, VertexState.CONFIGURED));
        manager.onVertexStateUpdated(new VertexStateUpdate(m3, VertexState.CONFIGURED));

        manager.onVertexManagerEventReceived(vmEvent);
        Assert.assertEquals(3, manager.pendingTasks.size()); // no tasks scheduled
        Assert.assertEquals(9, manager.totalNumBipartiteSourceTasks);
        Assert.assertEquals(0, manager.numBipartiteSourceTasksCompleted);

        // Send events for all tasks of m3.
        manager.onSourceTaskCompleted(m3, 0);
        manager.onSourceTaskCompleted(m3, 1);
        manager.onSourceTaskCompleted(m3, 2);

        Assert.assertEquals(3, manager.pendingTasks.size()); // no tasks scheduled
        Assert.assertEquals(9, manager.totalNumBipartiteSourceTasks);

        // Send an event for m2. Scheduling must still wait until r1 reports CONFIGURED.
        manager.onSourceTaskCompleted(m2, 0);
        Assert.assertEquals(3, manager.pendingTasks.size()); // no tasks scheduled
        Assert.assertEquals(9, manager.totalNumBipartiteSourceTasks);

        // Ensure that setVertexParallelism is not called for R2 while r1 is unconfigured.
        verify(mockContext_R2, times(0)).setVertexParallelism(anyInt(), any(VertexLocationHint.class), anyMap(),
                anyMap());

        // complete configuration of r1 triggers the parallelism change and the scheduling
        manager.onVertexStateUpdated(new VertexStateUpdate(r1, VertexState.CONFIGURED));
        verify(mockContext_R2, times(1)).setVertexParallelism(eq(1), any(VertexLocationHint.class), anyMap(),
                anyMap());
        Assert.assertEquals(0, manager.pendingTasks.size()); // all tasks scheduled
        Assert.assertEquals(3, scheduledTasks.size());

        // try with zero task vertices: only m3 has tasks
        scheduledTasks.clear();
        when(mockContext_R2.getInputVertexEdgeProperties()).thenReturn(mockInputVertices_R2);
        when(mockContext_R2.getVertexName()).thenReturn(mockManagedVertexId_R2);
        when(mockContext_R2.getVertexNumTasks(mockManagedVertexId_R2)).thenReturn(3);
        when(mockContext_R2.getVertexNumTasks(r1)).thenReturn(0);
        when(mockContext_R2.getVertexNumTasks(m2)).thenReturn(0);
        when(mockContext_R2.getVertexNumTasks(m3)).thenReturn(3);

        manager = createManager(conf, mockContext_R2, 0.001f, 0.001f);
        manager.onVertexStarted(null);
        Assert.assertEquals(3, manager.pendingTasks.size()); // no tasks scheduled
        Assert.assertEquals(3, manager.totalNumBipartiteSourceTasks);
        Assert.assertEquals(0, manager.numBipartiteSourceTasksCompleted);

        // Only need completed configuration notification from m3
        manager.onVertexStateUpdated(new VertexStateUpdate(m3, VertexState.CONFIGURED));
        manager.onSourceTaskCompleted(m3, 0);
        Assert.assertEquals(0, manager.pendingTasks.size()); // all tasks scheduled
        Assert.assertEquals(3, scheduledTasks.size());
    }

    /**
     * TEZ-1649 with mixed edges: R2 has one scatter-gather input (R1) and two broadcast inputs
     * (M2, M3). Tasks must not be scheduled until every source vertex with tasks, including those on
     * non-scatter-gather edges, reports CONFIGURED. Source vertices with zero tasks are not waited on.
     */
    @Test(timeout = 5000)
    public void test_Tez1649_with_mixed_edges() {
        Configuration conf = new Configuration();
        conf.setBoolean(ShuffleVertexManager.TEZ_SHUFFLE_VERTEX_MANAGER_ENABLE_AUTO_PARALLEL, true);
        conf.setLong(ShuffleVertexManager.TEZ_SHUFFLE_VERTEX_MANAGER_DESIRED_TASK_INPUT_SIZE, 1000L);

        HashMap<String, EdgeProperty> mockInputVertices = new HashMap<String, EdgeProperty>();
        String r1 = "R1";
        EdgeProperty eProp1 = EdgeProperty.create(EdgeProperty.DataMovementType.SCATTER_GATHER,
                EdgeProperty.DataSourceType.PERSISTED, SchedulingType.SEQUENTIAL, OutputDescriptor.create("out"),
                InputDescriptor.create("in"));
        String m2 = "M2";
        EdgeProperty eProp2 = EdgeProperty.create(EdgeProperty.DataMovementType.BROADCAST,
                EdgeProperty.DataSourceType.PERSISTED, SchedulingType.SEQUENTIAL, OutputDescriptor.create("out"),
                InputDescriptor.create("in"));
        String m3 = "M3";
        EdgeProperty eProp3 = EdgeProperty.create(EdgeProperty.DataMovementType.BROADCAST,
                EdgeProperty.DataSourceType.PERSISTED, SchedulingType.SEQUENTIAL, OutputDescriptor.create("out"),
                InputDescriptor.create("in"));

        final String mockManagedVertexId = "R2";

        mockInputVertices.put(r1, eProp1);
        mockInputVertices.put(m2, eProp2);
        mockInputVertices.put(m3, eProp3);

        VertexManagerPluginContext mockContext = mock(VertexManagerPluginContext.class);
        when(mockContext.getInputVertexEdgeProperties()).thenReturn(mockInputVertices);
        when(mockContext.getVertexName()).thenReturn(mockManagedVertexId);
        when(mockContext.getVertexNumTasks(mockManagedVertexId)).thenReturn(3);
        when(mockContext.getVertexNumTasks(r1)).thenReturn(3);
        when(mockContext.getVertexNumTasks(m2)).thenReturn(3);
        when(mockContext.getVertexNumTasks(m3)).thenReturn(3);

        // check initialization: only R1 is a scatter-gather (bipartite) source
        ShuffleVertexManager manager = createManager(conf, mockContext, 0.001f, 0.001f);
        Assert.assertEquals(1, manager.bipartiteSources);

        // Records the task indices passed to the most recent scheduleVertexTasks call.
        final HashSet<Integer> scheduledTasks = new HashSet<Integer>();
        doAnswer(new Answer<Object>() {
            @Override
            public Object answer(InvocationOnMock invocation) {
                scheduledTasks.clear();
                @SuppressWarnings("unchecked")
                List<TaskWithLocationHint> tasks = (List<TaskWithLocationHint>) invocation.getArguments()[0];
                for (TaskWithLocationHint task : tasks) {
                    scheduledTasks.add(task.getTaskIndex());
                }
                return null;
            }
        }).when(mockContext).scheduleVertexTasks(anyList());

        manager.onVertexStarted(null);
        manager.onVertexStateUpdated(new VertexStateUpdate(r1, VertexState.CONFIGURED));
        manager.onVertexStateUpdated(new VertexStateUpdate(m2, VertexState.CONFIGURED));

        Assert.assertEquals(3, manager.pendingTasks.size()); // no tasks scheduled
        Assert.assertEquals(3, manager.totalNumBipartiteSourceTasks);
        Assert.assertEquals(0, manager.numBipartiteSourceTasksCompleted);

        // Send events for 2 tasks of r1.
        manager.onSourceTaskCompleted(r1, 0);
        manager.onSourceTaskCompleted(r1, 1);
        Assert.assertEquals(3, manager.pendingTasks.size()); // no tasks scheduled
        Assert.assertEquals(3, manager.totalNumBipartiteSourceTasks);

        // Send an event for m2.
        manager.onSourceTaskCompleted(m2, 0);
        Assert.assertEquals(3, manager.pendingTasks.size()); // no tasks scheduled
        Assert.assertEquals(3, manager.totalNumBipartiteSourceTasks);

        // m3 is the last source vertex to be configured; its CONFIGURED update triggers scheduling.
        manager.onVertexStateUpdated(new VertexStateUpdate(m3, VertexState.CONFIGURED));
        Assert.assertEquals(0, manager.pendingTasks.size()); // all tasks scheduled
        Assert.assertEquals(3, scheduledTasks.size());

        // Scenario when numBipartiteSourceTasksCompleted == totalNumBipartiteSourceTasks.
        // Still, wait for configuration to complete on the other (broadcast) edges.
        // The mock still returns 3 tasks for every vertex, as stubbed above.
        scheduledTasks.clear();
        manager = createManager(conf, mockContext, 0.001f, 0.001f);
        manager.onVertexStarted(null);
        manager.onVertexStateUpdated(new VertexStateUpdate(r1, VertexState.CONFIGURED));
        Assert.assertEquals(3, manager.pendingTasks.size()); // no tasks scheduled
        Assert.assertEquals(3, manager.totalNumBipartiteSourceTasks);

        manager.onSourceTaskCompleted(r1, 0);
        manager.onSourceTaskCompleted(r1, 1);
        manager.onSourceTaskCompleted(r1, 2);
        // m2 and m3 (non-scatter-gather edges) are not configured yet.
        Assert.assertEquals(3, manager.pendingTasks.size()); // no tasks scheduled
        manager.onVertexStateUpdated(new VertexStateUpdate(m2, VertexState.CONFIGURED));
        manager.onVertexStateUpdated(new VertexStateUpdate(m3, VertexState.CONFIGURED));
        // All source vertices are configured now. Schedule all.
        Assert.assertEquals(0, manager.pendingTasks.size()); // all tasks scheduled
        Assert.assertEquals(3, scheduledTasks.size());

        // try with a zero task vertex (with non-scatter-gather edges)
        scheduledTasks.clear();
        when(mockContext.getInputVertexEdgeProperties()).thenReturn(mockInputVertices);
        when(mockContext.getVertexName()).thenReturn(mockManagedVertexId);
        when(mockContext.getVertexNumTasks(mockManagedVertexId)).thenReturn(3);
        when(mockContext.getVertexNumTasks(r1)).thenReturn(3); //scatter gather
        when(mockContext.getVertexNumTasks(m2)).thenReturn(0); //broadcast
        when(mockContext.getVertexNumTasks(m3)).thenReturn(3); //broadcast

        manager = createManager(conf, mockContext, 0.001f, 0.001f);
        manager.onVertexStarted(null);
        manager.onVertexStateUpdated(new VertexStateUpdate(r1, VertexState.CONFIGURED));

        Assert.assertEquals(3, manager.pendingTasks.size()); // no tasks scheduled
        Assert.assertEquals(3, manager.totalNumBipartiteSourceTasks);
        Assert.assertEquals(0, manager.numBipartiteSourceTasksCompleted);

        // Send 2 events for tasks of r1.
        manager.onSourceTaskCompleted(r1, 0);
        manager.onSourceTaskCompleted(r1, 1);
        Assert.assertEquals(3, manager.pendingTasks.size()); // no tasks scheduled
        Assert.assertEquals(0, scheduledTasks.size());

        // event from m3 triggers scheduling. no need for m2 since it has 0 tasks
        manager.onVertexStateUpdated(new VertexStateUpdate(m3, VertexState.CONFIGURED));
        Assert.assertEquals(0, manager.pendingTasks.size()); // all tasks scheduled
        Assert.assertEquals(3, scheduledTasks.size());

        // try with all zero task vertices in non-SG edges
        scheduledTasks.clear();
        manager = createManager(conf, mockContext, 0.001f, 0.001f);
        manager.onVertexStarted(null);
        // NOTE(review): unlike the scenario above, these stubs are applied after the manager has
        // started, so this relies on the manager re-reading source task counts; confirm intent.
        when(mockContext.getInputVertexEdgeProperties()).thenReturn(mockInputVertices);
        when(mockContext.getVertexName()).thenReturn(mockManagedVertexId);
        when(mockContext.getVertexNumTasks(mockManagedVertexId)).thenReturn(3);
        when(mockContext.getVertexNumTasks(r1)).thenReturn(3); //scatter gather
        when(mockContext.getVertexNumTasks(m2)).thenReturn(0); //broadcast
        when(mockContext.getVertexNumTasks(m3)).thenReturn(0); //broadcast

        // Configure r1 and send 1 event for a task of r1.
        manager.onVertexStateUpdated(new VertexStateUpdate(r1, VertexState.CONFIGURED));
        manager.onSourceTaskCompleted(r1, 0);
        Assert.assertEquals(0, manager.pendingTasks.size()); // all tasks scheduled
        Assert.assertEquals(3, scheduledTasks.size());
    }

    /**
     * Builds an initialized {@link ShuffleVertexManager} bound to the given mocked context.
     *
     * <p>The min/max source-fraction settings are written into {@code conf}, so the caller's
     * Configuration is modified in place. The configuration is then serialized into a
     * {@link UserPayload}, and {@code context.getUserPayload()} is stubbed to return it. Only then is
     * the manager constructed and initialized.
     *
     * @param conf    configuration to update and serialize; mutated by this call
     * @param context mocked plugin context whose {@code getUserPayload()} is (re-)stubbed
     * @param min     value stored under TEZ_SHUFFLE_VERTEX_MANAGER_MIN_SRC_FRACTION
     * @param max     value stored under TEZ_SHUFFLE_VERTEX_MANAGER_MAX_SRC_FRACTION
     * @return the initialized manager
     * @throws RuntimeException wrapping the IOException if the configuration cannot be serialized
     */
    private ShuffleVertexManager createManager(Configuration conf, VertexManagerPluginContext context, float min,
            float max) {
        conf.setFloat(ShuffleVertexManager.TEZ_SHUFFLE_VERTEX_MANAGER_MIN_SRC_FRACTION, min);
        conf.setFloat(ShuffleVertexManager.TEZ_SHUFFLE_VERTEX_MANAGER_MAX_SRC_FRACTION, max);

        final UserPayload serializedConf;
        try {
            serializedConf = TezUtils.createUserPayloadFromConf(conf);
        } catch (IOException ioe) {
            throw new RuntimeException(ioe);
        }
        when(context.getUserPayload()).thenReturn(serializedConf);

        final ShuffleVertexManager vertexManager = new ShuffleVertexManager(context);
        vertexManager.initialize();
        return vertexManager;
    }
}