org.apache.tez.mapreduce.common.TestMRInputAMSplitGenerator.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.tez.mapreduce.common.TestMRInputAMSplitGenerator.java

Source

/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.tez.mapreduce.common;

import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.mapreduce.InputFormat;
import org.apache.hadoop.mapreduce.InputSplit;
import org.apache.hadoop.mapreduce.JobContext;
import org.apache.hadoop.mapreduce.RecordReader;
import org.apache.hadoop.mapreduce.TaskAttemptContext;
import org.apache.hadoop.mapreduce.split.TezGroupedSplit;
import org.apache.tez.dag.api.DataSourceDescriptor;
import org.apache.tez.dag.api.UserPayload;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

import org.apache.hadoop.classification.InterfaceAudience.Private;
import org.apache.hadoop.conf.Configuration;
import org.apache.tez.mapreduce.TezTestUtils;
import org.apache.tez.mapreduce.input.MRInput;
import org.apache.tez.mapreduce.lib.MRInputUtils;
import org.apache.tez.mapreduce.protos.MRRuntimeProtos.MRSplitProto;
import org.apache.tez.runtime.api.Event;
import org.apache.tez.runtime.api.InputInitializerContext;
import org.apache.tez.runtime.api.events.InputConfigureVertexTasksEvent;
import org.apache.tez.runtime.api.events.InputDataInformationEvent;
import org.junit.Test;

import com.google.protobuf.ByteString;

/**
 * Tests {@link MRInputAMSplitGenerator} with every combination of the
 * {@code groupSplits} and {@code sortSplits} switches of the MRInput config builder.
 *
 * <p>Each test feeds the generator 50 raw splits of strictly ascending length
 * (produced by {@link InputFormatForTest}) and inspects the events it returns.
 */
public class TestMRInputAMSplitGenerator {

    /** Configuration key used to hand the desired split lengths to {@link InputFormatForTest}. */
    private static final String SPLITS_LENGTHS = "splits.length";

    /** Grouping off, sorting on. */
    @Test(timeout = 5000)
    public void testGroupSplitsDisabledSortSplitsEnabled() throws Exception {
        testGroupSplitsAndSortSplits(false, true);
    }

    /** Grouping off, sorting off. */
    @Test(timeout = 5000)
    public void testGroupSplitsDisabledSortSplitsDisabled() throws Exception {
        testGroupSplitsAndSortSplits(false, false);
    }

    /** Grouping on, sorting on. */
    @Test(timeout = 5000)
    public void testGroupSplitsEnabledSortSplitsEnabled() throws Exception {
        testGroupSplitsAndSortSplits(true, true);
    }

    /** Grouping on, sorting off. */
    @Test(timeout = 5000)
    public void testGroupSplitsEnabledSortSplitsDisabled() throws Exception {
        testGroupSplitsAndSortSplits(true, false);
    }

    /**
     * Runs the split generator with the given switches and verifies the emitted events.
     *
     * <p>Checks performed:
     * <ul>
     *   <li>the first event is an {@link InputConfigureVertexTasksEvent};</li>
     *   <li>every following event is an {@link InputDataInformationEvent} that carries a
     *       serialized (not deserialized) payload;</li>
     *   <li>the payload decodes to a {@link TezGroupedSplit} wrapping
     *       {@link InputSplitForTest} instances when grouping is enabled, or directly to an
     *       {@link InputSplitForTest} otherwise;</li>
     *   <li>the total number of raw splits seen equals the number supplied;</li>
     *   <li>split lengths never increase when sorting is enabled, and increase at least once
     *       when it is disabled.</li>
     * </ul>
     *
     * @param groupSplitsEnabled value passed to the builder's {@code groupSplits}
     * @param sortSplitsEnabled  value passed to the builder's {@code sortSplits}
     * @throws Exception if building the input, initializing the generator or decoding a
     *                   split fails
     */
    private void testGroupSplitsAndSortSplits(boolean groupSplitsEnabled, boolean sortSplitsEnabled)
            throws Exception {
        Configuration conf = new Configuration();
        // Raw split lengths are 1000, 2000, ..., 50000 (strictly ascending).
        String[] splitLengths = new String[50];
        for (int i = 0; i < splitLengths.length; i++) {
            splitLengths[i] = Integer.toString(1000 * (i + 1));
        }
        conf.setStrings(SPLITS_LENGTHS, splitLengths);
        DataSourceDescriptor dataSource = MRInput.createConfigBuilder(conf, InputFormatForTest.class)
                .groupSplits(groupSplitsEnabled).sortSplits(sortSplitsEnabled).build();
        UserPayload userPayload = dataSource.getInputDescriptor().getUserPayload();

        InputInitializerContext context = new TezTestUtils.TezRootInputInitializerContextForTest(userPayload);
        MRInputAMSplitGenerator splitGenerator = new MRInputAMSplitGenerator(context);

        List<Event> events = splitGenerator.initialize();

        assertTrue(events.get(0) instanceof InputConfigureVertexTasksEvent);

        // A single Configuration is sufficient for decoding every split; no need to
        // build a new one per event.
        Configuration splitConf = new Configuration();
        boolean shuffled = false;
        InputSplit previousIs = null;
        int numRawInputSplits = 0;
        // Index 0 is the vertex-configuration event checked above; the rest carry splits.
        for (int i = 1; i < events.size(); i++) {
            Event event = events.get(i);
            assertTrue(event instanceof InputDataInformationEvent);
            InputDataInformationEvent diEvent = (InputDataInformationEvent) event;
            assertNull(diEvent.getDeserializedUserPayload());
            assertNotNull(diEvent.getUserPayload());
            MRSplitProto eventProto = MRSplitProto.parseFrom(ByteString.copyFrom(diEvent.getUserPayload()));
            InputSplit is = MRInputUtils.getNewSplitDetailsFromEvent(eventProto, splitConf);
            if (groupSplitsEnabled) {
                TezGroupedSplit groupedSplit = (TezGroupedSplit) is;
                // A group must wrap at least one raw split.
                assertTrue(!groupedSplit.getGroupedSplits().isEmpty());
                numRawInputSplits += groupedSplit.getGroupedSplits().size();
                for (InputSplit inputSplit : groupedSplit.getGroupedSplits()) {
                    assertTrue(inputSplit instanceof InputSplitForTest);
                }
            } else {
                numRawInputSplits++;
                assertTrue(is instanceof InputSplitForTest);
            }
            // The InputFormat returns its splits in ascending order of length.
            // With sortSplitsEnabled, MRInputAMSplitGenerator is expected to sort
            // them in descending order, so lengths must never increase.
            // Without it, the generator is expected to shuffle them, so at least
            // one consecutive pair must be in ascending order.
            if (previousIs != null) {
                if (sortSplitsEnabled) {
                    assertTrue(is.getLength() <= previousIs.getLength());
                } else {
                    shuffled |= (is.getLength() > previousIs.getLength());
                }
            }
            previousIs = is;
        }
        assertEquals(splitLengths.length, numRawInputSplits);
        if (!sortSplitsEnabled) {
            assertTrue(shuffled);
        }
    }

    /**
     * Input format that fabricates one {@link InputSplitForTest} per length listed under
     * the {@link #SPLITS_LENGTHS} configuration key.
     */
    private static class InputFormatForTest extends InputFormat<IntWritable, IntWritable> {

        /**
         * Returns a reader that yields exactly one record whose key and value are both 0.
         */
        @Override
        public RecordReader<IntWritable, IntWritable> createRecordReader(
                org.apache.hadoop.mapreduce.InputSplit split, TaskAttemptContext context)
                throws IOException, InterruptedException {
            return new RecordReader<IntWritable, IntWritable>() {

                /** Set once the single record has been handed out. */
                private boolean done = false;

                @Override
                public void close() throws IOException {
                    // Nothing to release.
                }

                @Override
                public IntWritable getCurrentKey() throws IOException, InterruptedException {
                    return new IntWritable(0);
                }

                @Override
                public IntWritable getCurrentValue() throws IOException, InterruptedException {
                    return new IntWritable(0);
                }

                @Override
                public float getProgress() throws IOException, InterruptedException {
                    // Progress is the fraction consumed: 0 before the record is read, 1 after.
                    return done ? 1.0f : 0.0f;
                }

                @Override
                public void initialize(org.apache.hadoop.mapreduce.InputSplit split, TaskAttemptContext context)
                        throws IOException, InterruptedException {
                    // No state to set up.
                }

                @Override
                public boolean nextKeyValue() throws IOException, InterruptedException {
                    if (!done) {
                        done = true;
                        return true;
                    }
                    return false;
                }
            };
        }

        /**
         * Builds the split list from the integers stored under {@link #SPLITS_LENGTHS},
         * preserving their order. Identifiers are assigned 1, 2, 3, ... in that order.
         */
        @Override
        public List<org.apache.hadoop.mapreduce.InputSplit> getSplits(JobContext context)
                throws IOException, InterruptedException {
            int[] lengths = context.getConfiguration().getInts(SPLITS_LENGTHS);
            List<org.apache.hadoop.mapreduce.InputSplit> list =
                    new ArrayList<org.apache.hadoop.mapreduce.InputSplit>(lengths.length);
            for (int i = 0; i < lengths.length; i++) {
                list.add(new InputSplitForTest(i + 1, lengths[i]));
            }
            return list;
        }
    }

    /**
     * Minimal writable split that carries only an identifier and a length.
     */
    @Private
    public static class InputSplitForTest extends InputSplit implements Writable {

        private int identifier;
        private int length;

        /** No-arg constructor; fields are populated afterwards via {@link #readFields(DataInput)}. */
        @SuppressWarnings("unused")
        public InputSplitForTest() {
            // For writable
        }

        /**
         * @return the identifier given at construction or read from the stream
         */
        public int getIdentifier() {
            return this.identifier;
        }

        /**
         * @param identifier identifier of this split
         * @param length     value reported by {@link #getLength()}
         */
        public InputSplitForTest(int identifier, int length) {
            this.identifier = identifier;
            this.length = length;
        }

        /** Writes the identifier followed by the length, each as a 4-byte int. */
        @Override
        public void write(DataOutput out) throws IOException {
            out.writeInt(identifier);
            out.writeInt(length);
        }

        /** Reads the identifier followed by the length, mirroring {@link #write(DataOutput)}. */
        @Override
        public void readFields(DataInput in) throws IOException {
            identifier = in.readInt();
            length = in.readInt();
        }

        @Override
        public long getLength() throws IOException {
            return length;
        }

        /** Always reports a single location, {@code "localhost"}. */
        @Override
        public String[] getLocations() throws IOException {
            return new String[] { "localhost" };
        }
    }
}