org.apache.tez.dag.library.vertexmanager.TestShuffleVertexManagerUtils.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.tez.dag.library.vertexmanager.TestShuffleVertexManagerUtils.java

Source

/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.tez.dag.library.vertexmanager;

import com.google.protobuf.ByteString;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.DataOutputBuffer;
import org.apache.tez.common.ReflectionUtils;
import org.apache.tez.common.TezCommonUtils;
import org.apache.tez.common.TezUtils;
import org.apache.tez.dag.api.EdgeManagerPlugin;
import org.apache.tez.dag.api.EdgeManagerPluginContext;
import org.apache.tez.dag.api.EdgeManagerPluginDescriptor;
import org.apache.tez.dag.api.EdgeProperty;
import org.apache.tez.dag.api.InputDescriptor;
import org.apache.tez.dag.api.OutputDescriptor;
import org.apache.tez.dag.api.UserPayload;
import org.apache.tez.dag.api.VertexLocationHint;
import org.apache.tez.dag.api.VertexManagerPluginContext;
import org.apache.tez.dag.library.vertexmanager.FairShuffleVertexManager.FairRoutingType;
import org.apache.tez.dag.library.vertexmanager.FairShuffleVertexManager.FairShuffleVertexManagerConfigBuilder;
import org.apache.tez.dag.records.TaskAttemptIdentifierImpl;
import org.apache.tez.dag.records.TezTaskAttemptID;
import org.apache.tez.dag.records.TezTaskID;
import org.apache.tez.dag.records.TezVertexID;
import org.apache.tez.runtime.api.TaskAttemptIdentifier;
import org.apache.tez.runtime.api.TaskIdentifier;
import org.apache.tez.runtime.api.VertexIdentifier;
import org.apache.tez.runtime.api.events.VertexManagerEvent;
import org.apache.tez.runtime.library.common.shuffle.ShuffleUtils;
import org.apache.tez.runtime.library.shuffle.impl.ShuffleUserPayloads.VertexManagerEventPayloadProto;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import org.roaringbitmap.RoaringBitmap;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyInt;
import static org.mockito.Matchers.anyList;
import static org.mockito.Matchers.anyMap;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

@SuppressWarnings({ "unchecked", "rawtypes" })
public class TestShuffleVertexManagerUtils {
    static long MB = 1024l * 1024l;
    TezVertexID vertexId = TezVertexID.fromString("vertex_1436907267600_195589_1_00");
    int taskId = 0;

    VertexManagerPluginContext createVertexManagerContext(String mockSrcVertexId1, int numTasksSrcVertexId1,
            String mockSrcVertexId2, int numTasksSrcVertexId2, String mockSrcVertexId3, int numTasksSrcVertexId3,
            String mockManagedVertexId, int numTasksmockManagedVertexId, List<Integer> scheduledTasks,
            Map<String, EdgeManagerPlugin> newEdgeManagers) {
        HashMap<String, EdgeProperty> mockInputVertices = new HashMap<String, EdgeProperty>();
        EdgeProperty eProp1 = EdgeProperty.create(EdgeProperty.DataMovementType.SCATTER_GATHER,
                EdgeProperty.DataSourceType.PERSISTED, EdgeProperty.SchedulingType.SEQUENTIAL,
                OutputDescriptor.create("out"), InputDescriptor.create("in"));
        EdgeProperty eProp2 = EdgeProperty.create(EdgeProperty.DataMovementType.SCATTER_GATHER,
                EdgeProperty.DataSourceType.PERSISTED, EdgeProperty.SchedulingType.SEQUENTIAL,
                OutputDescriptor.create("out"), InputDescriptor.create("in"));
        EdgeProperty eProp3 = EdgeProperty.create(EdgeProperty.DataMovementType.BROADCAST,
                EdgeProperty.DataSourceType.PERSISTED, EdgeProperty.SchedulingType.SEQUENTIAL,
                OutputDescriptor.create("out"), InputDescriptor.create("in"));

        mockInputVertices.put(mockSrcVertexId1, eProp1);
        mockInputVertices.put(mockSrcVertexId2, eProp2);
        mockInputVertices.put(mockSrcVertexId3, eProp3);

        final VertexManagerPluginContext mockContext = mock(VertexManagerPluginContext.class);
        when(mockContext.getInputVertexEdgeProperties()).thenReturn(mockInputVertices);
        when(mockContext.getVertexName()).thenReturn(mockManagedVertexId);
        when(mockContext.getVertexNumTasks(mockSrcVertexId1)).thenReturn(numTasksSrcVertexId1);
        when(mockContext.getVertexNumTasks(mockSrcVertexId2)).thenReturn(numTasksSrcVertexId2);
        when(mockContext.getVertexNumTasks(mockSrcVertexId3)).thenReturn(numTasksSrcVertexId3);
        when(mockContext.getVertexNumTasks(mockManagedVertexId)).thenReturn(numTasksmockManagedVertexId);
        doAnswer(new ScheduledTasksAnswer(scheduledTasks)).when(mockContext).scheduleTasks(anyList());
        doAnswer(new reconfigVertexAnswer(mockContext, mockManagedVertexId, newEdgeManagers)).when(mockContext)
                .reconfigureVertex(anyInt(), any(VertexLocationHint.class), anyMap());
        return mockContext;
    }

    VertexManagerEvent getVertexManagerEvent(long[] sizes, long totalSize, String vertexName) throws IOException {
        return getVertexManagerEvent(sizes, totalSize, vertexName, false);
    }

    VertexManagerEvent getVertexManagerEvent(long[] sizes, long totalSize, String vertexName,
            boolean reportDetailedStats) throws IOException {
        ByteBuffer payload;
        if (sizes != null) {
            RoaringBitmap partitionStats = ShuffleUtils.getPartitionStatsForPhysicalOutput(sizes);
            DataOutputBuffer dout = new DataOutputBuffer();
            partitionStats.serialize(dout);
            ByteString partitionStatsBytes = TezCommonUtils.compressByteArrayToByteString(dout.getData());
            if (reportDetailedStats) {
                payload = VertexManagerEventPayloadProto.newBuilder().setOutputSize(totalSize)
                        .setDetailedPartitionStats(ShuffleUtils.getDetailedPartitionStatsForPhysicalOutput(sizes))
                        .build().toByteString().asReadOnlyByteBuffer();
            } else {
                payload = VertexManagerEventPayloadProto.newBuilder().setOutputSize(totalSize)
                        .setPartitionStats(partitionStatsBytes).build().toByteString().asReadOnlyByteBuffer();
            }

        } else {
            payload = VertexManagerEventPayloadProto.newBuilder().setOutputSize(totalSize).build().toByteString()
                    .asReadOnlyByteBuffer();
        }
        TaskAttemptIdentifierImpl taId = new TaskAttemptIdentifierImpl("dag", vertexName,
                TezTaskAttemptID.getInstance(TezTaskID.getInstance(vertexId, taskId++), 0));
        VertexManagerEvent vmEvent = VertexManagerEvent.create(vertexName, payload);
        vmEvent.setProducerAttemptIdentifier(taId);
        return vmEvent;
    }

    public static TaskAttemptIdentifier createTaskAttemptIdentifier(String vName, int tId) {
        VertexIdentifier mockVertex = mock(VertexIdentifier.class);
        when(mockVertex.getName()).thenReturn(vName);
        TaskIdentifier mockTask = mock(TaskIdentifier.class);
        when(mockTask.getIdentifier()).thenReturn(tId);
        when(mockTask.getVertexIdentifier()).thenReturn(mockVertex);
        TaskAttemptIdentifier mockAttempt = mock(TaskAttemptIdentifier.class);
        when(mockAttempt.getIdentifier()).thenReturn(0);
        when(mockAttempt.getTaskIdentifier()).thenReturn(mockTask);
        return mockAttempt;
    }

    static public ShuffleVertexManagerBase createManager(
            Class<? extends ShuffleVertexManagerBase> shuffleVertexManagerClass, Configuration conf,
            VertexManagerPluginContext context, Boolean enableAutoParallelism, Long desiredTaskInputSize, Float min,
            Float max) {
        if (shuffleVertexManagerClass.equals(ShuffleVertexManager.class)) {
            return createShuffleVertexManager(conf, context, enableAutoParallelism, desiredTaskInputSize, min, max);
        } else if (shuffleVertexManagerClass.equals(FairShuffleVertexManager.class)) {
            FairRoutingType fairRoutingType = null;
            if (enableAutoParallelism != null) {
                fairRoutingType = enableAutoParallelism ? FairRoutingType.REDUCE_PARALLELISM : FairRoutingType.NONE;
            }
            return createFairShuffleVertexManager(conf, context, fairRoutingType, desiredTaskInputSize, min, max);
        } else {
            return null;
        }
    }

    static ShuffleVertexManager createShuffleVertexManager(Configuration conf, VertexManagerPluginContext context,
            Boolean enableAutoParallelism, Long desiredTaskInputSize, Float min, Float max) {
        if (min != null) {
            conf.setFloat(ShuffleVertexManager.TEZ_SHUFFLE_VERTEX_MANAGER_MIN_SRC_FRACTION, min);
        } else {
            conf.unset(ShuffleVertexManager.TEZ_SHUFFLE_VERTEX_MANAGER_MIN_SRC_FRACTION);
        }
        if (max != null) {
            conf.setFloat(ShuffleVertexManager.TEZ_SHUFFLE_VERTEX_MANAGER_MAX_SRC_FRACTION, max);
        } else {
            conf.unset(ShuffleVertexManager.TEZ_SHUFFLE_VERTEX_MANAGER_MAX_SRC_FRACTION);
        }
        if (enableAutoParallelism != null) {
            conf.setBoolean(ShuffleVertexManager.TEZ_SHUFFLE_VERTEX_MANAGER_ENABLE_AUTO_PARALLEL,
                    enableAutoParallelism);
        }
        if (desiredTaskInputSize != null) {
            conf.setLong(ShuffleVertexManager.TEZ_SHUFFLE_VERTEX_MANAGER_DESIRED_TASK_INPUT_SIZE,
                    desiredTaskInputSize);
        }
        UserPayload payload;
        try {
            payload = TezUtils.createUserPayloadFromConf(conf);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
        when(context.getUserPayload()).thenReturn(payload);
        ShuffleVertexManager manager = new ShuffleVertexManager(context);
        manager.initialize();
        return manager;
    }

    static FairShuffleVertexManager createFairShuffleVertexManager(Configuration conf,
            VertexManagerPluginContext context, FairRoutingType fairRoutingType, Long desiredTaskInputSize,
            Float min, Float max) {
        FairShuffleVertexManagerConfigBuilder builder = FairShuffleVertexManager.createConfigBuilder(conf);
        if (min != null) {
            builder.setSlowStartMinSrcCompletionFraction(min);
        } else if (conf != null) {
            conf.unset(FairShuffleVertexManager.TEZ_FAIR_SHUFFLE_VERTEX_MANAGER_MIN_SRC_FRACTION);
        }
        if (max != null) {
            builder.setSlowStartMaxSrcCompletionFraction(max);
        } else if (conf != null) {
            conf.unset(FairShuffleVertexManager.TEZ_FAIR_SHUFFLE_VERTEX_MANAGER_MAX_SRC_FRACTION);
        }
        if (fairRoutingType != null) {
            builder.setAutoParallelism(fairRoutingType);
        }
        if (desiredTaskInputSize != null) {
            builder.setDesiredTaskInputSize(desiredTaskInputSize);
        }
        UserPayload payload = builder.build().getUserPayload();
        when(context.getUserPayload()).thenReturn(payload);
        FairShuffleVertexManager manager = new FairShuffleVertexManager(context);
        manager.initialize();
        return manager;
    }

    protected static class ScheduledTasksAnswer implements Answer<Object> {
        private List<Integer> scheduledTasks;

        public ScheduledTasksAnswer(List<Integer> scheduledTasks) {
            this.scheduledTasks = scheduledTasks;
        }

        @Override
        public Object answer(InvocationOnMock invocation) throws IOException {
            Object[] args = invocation.getArguments();
            scheduledTasks.clear();
            List<VertexManagerPluginContext.ScheduleTaskRequest> tasks = (List<VertexManagerPluginContext.ScheduleTaskRequest>) args[0];
            for (VertexManagerPluginContext.ScheduleTaskRequest task : tasks) {
                scheduledTasks.add(task.getTaskIndex());
            }
            return null;
        }
    }

    protected static class reconfigVertexAnswer implements Answer<Object> {

        private VertexManagerPluginContext mockContext;
        private String mockManagedVertexId;
        private Map<String, EdgeManagerPlugin> newEdgeManagers;

        public reconfigVertexAnswer(VertexManagerPluginContext mockContext, String mockManagedVertexId,
                Map<String, EdgeManagerPlugin> newEdgeManagers) {
            this.mockContext = mockContext;
            this.mockManagedVertexId = mockManagedVertexId;
            this.newEdgeManagers = newEdgeManagers;

        }

        public Object answer(InvocationOnMock invocation) throws Exception {
            final int numTasks = ((Integer) invocation.getArguments()[0]).intValue();
            when(mockContext.getVertexNumTasks(mockManagedVertexId)).thenReturn(numTasks);
            if (newEdgeManagers != null) {
                newEdgeManagers.clear();
            }
            for (Map.Entry<String, EdgeProperty> entry : ((Map<String, EdgeProperty>) invocation.getArguments()[2])
                    .entrySet()) {

                EdgeManagerPluginDescriptor pluginDesc = entry.getValue().getEdgeManagerDescriptor();
                final UserPayload userPayload = pluginDesc.getUserPayload();
                EdgeManagerPluginContext emContext = new EdgeManagerPluginContext() {
                    @Override
                    public UserPayload getUserPayload() {
                        return userPayload == null ? null : userPayload;
                    }

                    @Override
                    public String getSourceVertexName() {
                        return null;
                    }

                    @Override
                    public String getDestinationVertexName() {
                        return null;
                    }

                    @Override
                    public int getSourceVertexNumTasks() {
                        return 2;
                    }

                    @Override
                    public int getDestinationVertexNumTasks() {
                        return numTasks;
                    }
                };
                if (newEdgeManagers != null) {
                    EdgeManagerPlugin edgeManager = ReflectionUtils.createClazzInstance(pluginDesc.getClassName(),
                            new Class[] { EdgeManagerPluginContext.class }, new Object[] { emContext });
                    edgeManager.initialize();
                    newEdgeManagers.put(entry.getKey(), edgeManager);
                }
            }
            return null;
        }
    }
}