// Java tutorial
/** * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.tez.runtime.library.common.writers; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.mockito.Matchers.any; import static org.mockito.Matchers.eq; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import java.io.IOException; import java.nio.ByteBuffer; import java.util.Arrays; import java.util.BitSet; import java.util.Collection; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Random; import java.util.Set; import java.util.UUID; import com.google.protobuf.ByteString; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.hadoop.conf.Configurable; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.FSDataInputStream; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; import org.apache.hadoop.io.DataInputBuffer; import org.apache.hadoop.io.IntWritable; import 
org.apache.hadoop.io.LongWritable; import org.apache.hadoop.io.Text; import org.apache.hadoop.io.Writable; import org.apache.hadoop.io.compress.CompressionCodec; import org.apache.hadoop.io.compress.DefaultCodec; import org.apache.hadoop.util.DiskChecker.DiskErrorException; import org.apache.hadoop.yarn.api.records.ApplicationId; import org.apache.tez.common.TezCommonUtils; import org.apache.tez.common.TezRuntimeFrameworkConfigs; import org.apache.tez.common.TezUtilsInternal; import org.apache.tez.common.counters.TaskCounter; import org.apache.tez.common.counters.TezCounter; import org.apache.tez.common.counters.TezCounters; import org.apache.tez.runtime.api.Event; import org.apache.tez.runtime.api.OutputContext; import org.apache.tez.runtime.api.events.CompositeDataMovementEvent; import org.apache.tez.runtime.library.api.Partitioner; import org.apache.tez.runtime.library.api.TezRuntimeConfiguration; import org.apache.tez.runtime.library.common.sort.impl.IFile; import org.apache.tez.runtime.library.common.sort.impl.TezIndexRecord; import org.apache.tez.runtime.library.common.sort.impl.TezSpillRecord; import org.apache.tez.runtime.library.common.task.local.output.TezTaskOutput; import org.apache.tez.runtime.library.common.task.local.output.TezTaskOutputFiles; import org.apache.tez.runtime.library.partitioner.HashPartitioner; import org.apache.tez.runtime.library.common.shuffle.ShuffleUtils; import org.apache.tez.runtime.library.shuffle.impl.ShuffleUserPayloads.DataMovementEventPayloadProto; import org.junit.After; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.junit.runners.Parameterized.Parameters; import com.google.common.collect.LinkedListMultimap; import com.google.common.collect.Multimap; import com.google.common.collect.Sets; @RunWith(value = Parameterized.class) public class TestUnorderedPartitionedKVWriter { private static final Log LOG = 
LogFactory.getLog(TestUnorderedPartitionedKVWriter.class); private static final String HOST_STRING = "localhost"; private static final int SHUFFLE_PORT = 4000; private static String testTmpDir = System.getProperty("test.build.data", "/tmp"); private static final Path TEST_ROOT_DIR = new Path(testTmpDir, TestUnorderedPartitionedKVWriter.class.getSimpleName()); private static FileSystem localFs; private boolean shouldCompress; public TestUnorderedPartitionedKVWriter(boolean shouldCompress) { this.shouldCompress = shouldCompress; } @Parameters public static Collection<Object[]> data() { Object[][] data = new Object[][] { { false }, { true } }; return Arrays.asList(data); } @Before public void setup() throws IOException { LOG.info("Setup. Using test dir: " + TEST_ROOT_DIR); localFs = FileSystem.getLocal(new Configuration()); localFs.delete(TEST_ROOT_DIR, true); localFs.mkdirs(TEST_ROOT_DIR); } @After public void cleanup() throws IOException { LOG.info("CleanUp"); localFs.delete(TEST_ROOT_DIR, true); } @Test(timeout = 10000) public void testBufferSizing() throws IOException { ApplicationId appId = ApplicationId.newInstance(10000, 1); TezCounters counters = new TezCounters(); String uniqueId = UUID.randomUUID().toString(); OutputContext outputContext = createMockOutputContext(counters, appId, uniqueId); int maxSingleBufferSizeBytes = 2047; Configuration conf = createConfiguration(outputContext, IntWritable.class, LongWritable.class, false, maxSingleBufferSizeBytes); int numOutputs = 10; UnorderedPartitionedKVWriter kvWriter = null; kvWriter = new UnorderedPartitionedKVWriterForTest(outputContext, conf, numOutputs, 2048); assertEquals(2, kvWriter.numBuffers); assertEquals(1024, kvWriter.sizePerBuffer); assertEquals(1, kvWriter.numInitializedBuffers); kvWriter = new UnorderedPartitionedKVWriterForTest(outputContext, conf, numOutputs, maxSingleBufferSizeBytes * 3); assertEquals(3, kvWriter.numBuffers); assertEquals(maxSingleBufferSizeBytes - maxSingleBufferSizeBytes % 4, 
kvWriter.sizePerBuffer); assertEquals(1, kvWriter.numInitializedBuffers); kvWriter = new UnorderedPartitionedKVWriterForTest(outputContext, conf, numOutputs, maxSingleBufferSizeBytes * 2 + 1); assertEquals(3, kvWriter.numBuffers); assertEquals(1364, kvWriter.sizePerBuffer); assertEquals(1, kvWriter.numInitializedBuffers); kvWriter = new UnorderedPartitionedKVWriterForTest(outputContext, conf, numOutputs, 10240); assertEquals(6, kvWriter.numBuffers); assertEquals(1704, kvWriter.sizePerBuffer); assertEquals(1, kvWriter.numInitializedBuffers); } @Test(timeout = 10000) public void testNoSpill() throws IOException, InterruptedException { baseTest(10, 10, null, shouldCompress); } @Test(timeout = 10000) public void testSingleSpill() throws IOException, InterruptedException { baseTest(50, 10, null, shouldCompress); } @Test(timeout = 10000) public void testMultipleSpills() throws IOException, InterruptedException { baseTest(200, 10, null, shouldCompress); } @Test(timeout = 10000) public void testNoRecords() throws IOException, InterruptedException { baseTest(0, 10, null, shouldCompress); } @Test(timeout = 10000) public void testSkippedPartitions() throws IOException, InterruptedException { baseTest(200, 10, Sets.newHashSet(2, 5), shouldCompress); } @Test(timeout = 10000) public void testRandomText() throws IOException, InterruptedException { textTest(100, 10, 2048, 0, 0, 0); } @Test(timeout = 10000) public void testLargeKeys() throws IOException, InterruptedException { textTest(0, 10, 2048, 10, 0, 0); } @Test(timeout = 10000) public void testLargevalues() throws IOException, InterruptedException { textTest(0, 10, 2048, 0, 10, 0); } @Test(timeout = 10000) public void testLargeKvPairs() throws IOException, InterruptedException { textTest(0, 10, 2048, 0, 0, 10); } @Test(timeout = 10000) public void testTextMixedRecords() throws IOException, InterruptedException { textTest(100, 10, 2048, 10, 10, 10); } public void textTest(int numRegularRecords, int numPartitions, long 
availableMemory, int numLargeKeys, int numLargevalues, int numLargeKvPairs) throws IOException, InterruptedException { Partitioner partitioner = new HashPartitioner(); ApplicationId appId = ApplicationId.newInstance(10000, 1); TezCounters counters = new TezCounters(); String uniqueId = UUID.randomUUID().toString(); OutputContext outputContext = createMockOutputContext(counters, appId, uniqueId); Random random = new Random(); Configuration conf = createConfiguration(outputContext, Text.class, Text.class, shouldCompress, -1, HashPartitioner.class); CompressionCodec codec = null; if (shouldCompress) { codec = new DefaultCodec(); ((Configurable) codec).setConf(conf); } int numRecordsWritten = 0; Map<Integer, Multimap<String, String>> expectedValues = new HashMap<Integer, Multimap<String, String>>(); for (int i = 0; i < numPartitions; i++) { expectedValues.put(i, LinkedListMultimap.<String, String>create()); } UnorderedPartitionedKVWriter kvWriter = new UnorderedPartitionedKVWriterForTest(outputContext, conf, numPartitions, availableMemory); int sizePerBuffer = kvWriter.sizePerBuffer; BitSet partitionsWithData = new BitSet(numPartitions); Text keyText = new Text(); Text valText = new Text(); for (int i = 0; i < numRegularRecords; i++) { String key = createRandomString(Math.abs(random.nextInt(10))); String val = createRandomString(Math.abs(random.nextInt(20))); keyText.set(key); valText.set(val); int partition = partitioner.getPartition(keyText, valText, numPartitions); partitionsWithData.set(partition); expectedValues.get(partition).put(key, val); kvWriter.write(keyText, valText); numRecordsWritten++; } // Write Large key records for (int i = 0; i < numLargeKeys; i++) { String key = createRandomString(sizePerBuffer + Math.abs(random.nextInt(100))); String val = createRandomString(Math.abs(random.nextInt(20))); keyText.set(key); valText.set(val); int partition = partitioner.getPartition(keyText, valText, numPartitions); partitionsWithData.set(partition); 
expectedValues.get(partition).put(key, val); kvWriter.write(keyText, valText); numRecordsWritten++; } // Write Large val records for (int i = 0; i < numLargevalues; i++) { String key = createRandomString(Math.abs(random.nextInt(10))); String val = createRandomString(sizePerBuffer + Math.abs(random.nextInt(100))); keyText.set(key); valText.set(val); int partition = partitioner.getPartition(keyText, valText, numPartitions); partitionsWithData.set(partition); expectedValues.get(partition).put(key, val); kvWriter.write(keyText, valText); numRecordsWritten++; } // Write records where key + val are large (but both can fit in the buffer individually) for (int i = 0; i < numLargeKvPairs; i++) { String key = createRandomString(sizePerBuffer / 2 + Math.abs(random.nextInt(100))); String val = createRandomString(sizePerBuffer / 2 + Math.abs(random.nextInt(100))); keyText.set(key); valText.set(val); int partition = partitioner.getPartition(keyText, valText, numPartitions); partitionsWithData.set(partition); expectedValues.get(partition).put(key, val); kvWriter.write(keyText, valText); numRecordsWritten++; } List<Event> events = kvWriter.close(); verify(outputContext, never()).fatalError(any(Throwable.class), any(String.class)); TezCounter outputLargeRecordsCounter = counters.findCounter(TaskCounter.OUTPUT_LARGE_RECORDS); assertEquals(numLargeKeys + numLargevalues + numLargeKvPairs, outputLargeRecordsCounter.getValue()); // Validate the event assertEquals(1, events.size()); assertTrue(events.get(0) instanceof CompositeDataMovementEvent); CompositeDataMovementEvent cdme = (CompositeDataMovementEvent) events.get(0); assertEquals(0, cdme.getSourceIndexStart()); assertEquals(numPartitions, cdme.getCount()); DataMovementEventPayloadProto eventProto = DataMovementEventPayloadProto .parseFrom(ByteString.copyFrom(cdme.getUserPayload())); assertFalse(eventProto.hasData()); BitSet emptyPartitionBits = null; if (partitionsWithData.cardinality() != numPartitions) { 
assertTrue(eventProto.hasEmptyPartitions()); byte[] emptyPartitions = TezCommonUtils .decompressByteStringToByteArray(eventProto.getEmptyPartitions()); emptyPartitionBits = TezUtilsInternal.fromByteArray(emptyPartitions); assertEquals(numPartitions - partitionsWithData.cardinality(), emptyPartitionBits.cardinality()); } else { assertFalse(eventProto.hasEmptyPartitions()); emptyPartitionBits = new BitSet(numPartitions); } assertEquals(HOST_STRING, eventProto.getHost()); assertEquals(SHUFFLE_PORT, eventProto.getPort()); assertEquals(uniqueId, eventProto.getPathComponent()); // Verify the data // Verify the actual data TezTaskOutput taskOutput = new TezTaskOutputFiles(conf, uniqueId); Path outputFilePath = null; Path spillFilePath = null; try { outputFilePath = taskOutput.getOutputFile(); } catch (DiskErrorException e) { if (numRecordsWritten > 0) { fail(); } else { // Record checking not required. return; } } try { spillFilePath = taskOutput.getOutputIndexFile(); } catch (DiskErrorException e) { if (numRecordsWritten > 0) { fail(); } else { // Record checking not required. return; } } // Special case for 0 records. 
TezSpillRecord spillRecord = new TezSpillRecord(spillFilePath, conf); DataInputBuffer keyBuffer = new DataInputBuffer(); DataInputBuffer valBuffer = new DataInputBuffer(); Text keyDeser = new Text(); Text valDeser = new Text(); for (int i = 0; i < numPartitions; i++) { if (emptyPartitionBits.get(i)) { continue; } TezIndexRecord indexRecord = spillRecord.getIndex(i); FSDataInputStream inStream = FileSystem.getLocal(conf).open(outputFilePath); inStream.seek(indexRecord.getStartOffset()); IFile.Reader reader = new IFile.Reader(inStream, indexRecord.getPartLength(), codec, null, null, false, 0, -1); while (reader.nextRawKey(keyBuffer)) { reader.nextRawValue(valBuffer); keyDeser.readFields(keyBuffer); valDeser.readFields(valBuffer); int partition = partitioner.getPartition(keyDeser, valDeser, numPartitions); assertTrue(expectedValues.get(partition).remove(keyDeser.toString(), valDeser.toString())); } inStream.close(); } for (int i = 0; i < numPartitions; i++) { assertEquals(0, expectedValues.get(i).size()); expectedValues.remove(i); } assertEquals(0, expectedValues.size()); } private void baseTest(int numRecords, int numPartitions, Set<Integer> skippedPartitions, boolean shouldCompress) throws IOException, InterruptedException { PartitionerForTest partitioner = new PartitionerForTest(); ApplicationId appId = ApplicationId.newInstance(10000, 1); TezCounters counters = new TezCounters(); String uniqueId = UUID.randomUUID().toString(); OutputContext outputContext = createMockOutputContext(counters, appId, uniqueId); Configuration conf = createConfiguration(outputContext, IntWritable.class, LongWritable.class, shouldCompress, -1); CompressionCodec codec = null; if (shouldCompress) { codec = new DefaultCodec(); ((Configurable) codec).setConf(conf); } int numOutputs = numPartitions; long availableMemory = 2048; int numRecordsWritten = 0; Map<Integer, Multimap<Integer, Long>> expectedValues = new HashMap<Integer, Multimap<Integer, Long>>(); for (int i = 0; i < numOutputs; i++) 
{ expectedValues.put(i, LinkedListMultimap.<Integer, Long>create()); } UnorderedPartitionedKVWriter kvWriter = new UnorderedPartitionedKVWriterForTest(outputContext, conf, numOutputs, availableMemory); int sizePerBuffer = kvWriter.sizePerBuffer; int sizePerRecord = 4 + 8; // IntW + LongW int sizePerRecordWithOverhead = sizePerRecord + 12; // Record + META_OVERHEAD IntWritable intWritable = new IntWritable(); LongWritable longWritable = new LongWritable(); for (int i = 0; i < numRecords; i++) { intWritable.set(i); longWritable.set(i); int partition = partitioner.getPartition(intWritable, longWritable, numOutputs); if (skippedPartitions != null && skippedPartitions.contains(partition)) { continue; } expectedValues.get(partition).put(intWritable.get(), longWritable.get()); kvWriter.write(intWritable, longWritable); numRecordsWritten++; } List<Event> events = kvWriter.close(); int recordsPerBuffer = sizePerBuffer / sizePerRecordWithOverhead; int numExpectedSpills = numRecordsWritten / recordsPerBuffer; verify(outputContext, never()).fatalError(any(Throwable.class), any(String.class)); // Verify the status of the buffers if (numExpectedSpills == 0) { assertEquals(1, kvWriter.numInitializedBuffers); } else { assertTrue(kvWriter.numInitializedBuffers > 1); } assertNull(kvWriter.currentBuffer); assertEquals(0, kvWriter.availableBuffers.size()); // Verify the counters TezCounter outputRecordBytesCounter = counters.findCounter(TaskCounter.OUTPUT_BYTES); TezCounter outputRecordsCounter = counters.findCounter(TaskCounter.OUTPUT_RECORDS); TezCounter outputBytesWithOverheadCounter = counters.findCounter(TaskCounter.OUTPUT_BYTES_WITH_OVERHEAD); TezCounter fileOutputBytesCounter = counters.findCounter(TaskCounter.OUTPUT_BYTES_PHYSICAL); TezCounter spilledRecordsCounter = counters.findCounter(TaskCounter.SPILLED_RECORDS); TezCounter additionalSpillBytesWritternCounter = counters .findCounter(TaskCounter.ADDITIONAL_SPILLS_BYTES_WRITTEN); TezCounter additionalSpillBytesReadCounter = 
counters.findCounter(TaskCounter.ADDITIONAL_SPILLS_BYTES_READ); TezCounter numAdditionalSpillsCounter = counters.findCounter(TaskCounter.ADDITIONAL_SPILL_COUNT); assertEquals(numRecordsWritten * sizePerRecord, outputRecordBytesCounter.getValue()); assertEquals(numRecordsWritten, outputRecordsCounter.getValue()); assertEquals(numRecordsWritten * sizePerRecordWithOverhead, outputBytesWithOverheadCounter.getValue()); long fileOutputBytes = fileOutputBytesCounter.getValue(); if (numRecordsWritten > 0) { assertTrue(fileOutputBytes > 0); if (!shouldCompress) { assertTrue(fileOutputBytes > outputRecordBytesCounter.getValue()); } } else { assertEquals(0, fileOutputBytes); } assertEquals(recordsPerBuffer * numExpectedSpills, spilledRecordsCounter.getValue()); long additionalSpillBytesWritten = additionalSpillBytesWritternCounter.getValue(); long additionalSpillBytesRead = additionalSpillBytesReadCounter.getValue(); if (numExpectedSpills == 0) { assertEquals(0, additionalSpillBytesWritten); assertEquals(0, additionalSpillBytesRead); } else { assertTrue(additionalSpillBytesWritten > 0); assertTrue(additionalSpillBytesRead > 0); if (!shouldCompress) { assertTrue(additionalSpillBytesWritten > (recordsPerBuffer * numExpectedSpills * sizePerRecord)); assertTrue(additionalSpillBytesRead > (recordsPerBuffer * numExpectedSpills * sizePerRecord)); } } assertTrue(additionalSpillBytesWritten == additionalSpillBytesRead); assertEquals(numExpectedSpills, numAdditionalSpillsCounter.getValue()); BitSet emptyPartitionBits = null; // Verify the event returned assertEquals(1, events.size()); assertTrue(events.get(0) instanceof CompositeDataMovementEvent); CompositeDataMovementEvent cdme = (CompositeDataMovementEvent) events.get(0); assertEquals(0, cdme.getSourceIndexStart()); assertEquals(numOutputs, cdme.getCount()); DataMovementEventPayloadProto eventProto = DataMovementEventPayloadProto .parseFrom(ByteString.copyFrom(cdme.getUserPayload())); assertFalse(eventProto.hasData()); if 
(skippedPartitions == null && numRecordsWritten > 0) { assertFalse(eventProto.hasEmptyPartitions()); emptyPartitionBits = new BitSet(numPartitions); } else { assertTrue(eventProto.hasEmptyPartitions()); byte[] emptyPartitions = TezCommonUtils .decompressByteStringToByteArray(eventProto.getEmptyPartitions()); emptyPartitionBits = TezUtilsInternal.fromByteArray(emptyPartitions); if (numRecordsWritten == 0) { assertEquals(numPartitions, emptyPartitionBits.cardinality()); } else { for (Integer e : skippedPartitions) { assertTrue(emptyPartitionBits.get(e)); } assertEquals(skippedPartitions.size(), emptyPartitionBits.cardinality()); } } if (emptyPartitionBits.cardinality() != numPartitions) { assertEquals(HOST_STRING, eventProto.getHost()); assertEquals(SHUFFLE_PORT, eventProto.getPort()); assertEquals(uniqueId, eventProto.getPathComponent()); } else { assertFalse(eventProto.hasHost()); assertFalse(eventProto.hasPort()); assertFalse(eventProto.hasPathComponent()); } // Verify the actual data TezTaskOutput taskOutput = new TezTaskOutputFiles(conf, uniqueId); Path outputFilePath = null; Path spillFilePath = null; try { outputFilePath = taskOutput.getOutputFile(); } catch (DiskErrorException e) { if (numRecordsWritten > 0) { fail(); } else { // Record checking not required. return; } } try { spillFilePath = taskOutput.getOutputIndexFile(); } catch (DiskErrorException e) { if (numRecordsWritten > 0) { fail(); } else { // Record checking not required. return; } } // Special case for 0 records. 
TezSpillRecord spillRecord = new TezSpillRecord(spillFilePath, conf); DataInputBuffer keyBuffer = new DataInputBuffer(); DataInputBuffer valBuffer = new DataInputBuffer(); IntWritable keyDeser = new IntWritable(); LongWritable valDeser = new LongWritable(); for (int i = 0; i < numOutputs; i++) { if (skippedPartitions != null && skippedPartitions.contains(i)) { continue; } TezIndexRecord indexRecord = spillRecord.getIndex(i); FSDataInputStream inStream = FileSystem.getLocal(conf).open(outputFilePath); inStream.seek(indexRecord.getStartOffset()); IFile.Reader reader = new IFile.Reader(inStream, indexRecord.getPartLength(), codec, null, null, false, 0, -1); while (reader.nextRawKey(keyBuffer)) { reader.nextRawValue(valBuffer); keyDeser.readFields(keyBuffer); valDeser.readFields(valBuffer); int partition = partitioner.getPartition(keyDeser, valDeser, numOutputs); assertTrue(expectedValues.get(partition).remove(keyDeser.get(), valDeser.get())); } inStream.close(); } for (int i = 0; i < numOutputs; i++) { assertEquals(0, expectedValues.get(i).size()); expectedValues.remove(i); } assertEquals(0, expectedValues.size()); } private static String createRandomString(int size) { StringBuilder sb = new StringBuilder(); Random random = new Random(); for (int i = 0; i < size; i++) { int r = Math.abs(random.nextInt()) % 26; sb.append((char) (65 + r)); } return sb.toString(); } private OutputContext createMockOutputContext(TezCounters counters, ApplicationId appId, String uniqueId) { OutputContext outputContext = mock(OutputContext.class); doReturn(counters).when(outputContext).getCounters(); doReturn(appId).when(outputContext).getApplicationId(); doReturn(1).when(outputContext).getDAGAttemptNumber(); doReturn("dagName").when(outputContext).getDAGName(); doReturn("destinationVertexName").when(outputContext).getDestinationVertexName(); doReturn(1).when(outputContext).getOutputIndex(); doReturn(1).when(outputContext).getTaskAttemptNumber(); 
doReturn(1).when(outputContext).getTaskIndex(); doReturn(1).when(outputContext).getTaskVertexIndex(); doReturn("vertexName").when(outputContext).getTaskVertexName(); doReturn(uniqueId).when(outputContext).getUniqueIdentifier(); ByteBuffer portBuffer = ByteBuffer.allocate(4); portBuffer.mark(); portBuffer.putInt(SHUFFLE_PORT); portBuffer.reset(); doReturn(portBuffer).when(outputContext) .getServiceProviderMetaData(eq(ShuffleUtils.SHUFFLE_HANDLER_SERVICE_ID)); Path outDirBase = new Path(TEST_ROOT_DIR, "outDir_" + uniqueId); String[] outDirs = new String[] { outDirBase.toString() }; doReturn(outDirs).when(outputContext).getWorkDirs(); return outputContext; } private Configuration createConfiguration(OutputContext outputContext, Class<? extends Writable> keyClass, Class<? extends Writable> valClass, boolean shouldCompress, int maxSingleBufferSizeBytes) { return createConfiguration(outputContext, keyClass, valClass, shouldCompress, maxSingleBufferSizeBytes, PartitionerForTest.class); } private Configuration createConfiguration(OutputContext outputContext, Class<? extends Writable> keyClass, Class<? extends Writable> valClass, boolean shouldCompress, int maxSingleBufferSizeBytes, Class<? 
extends Partitioner> partitionerClass) { Configuration conf = new Configuration(false); conf.setStrings(TezRuntimeFrameworkConfigs.LOCAL_DIRS, outputContext.getWorkDirs()); conf.set(TezRuntimeConfiguration.TEZ_RUNTIME_KEY_CLASS, keyClass.getName()); conf.set(TezRuntimeConfiguration.TEZ_RUNTIME_VALUE_CLASS, valClass.getName()); conf.set(TezRuntimeConfiguration.TEZ_RUNTIME_PARTITIONER_CLASS, partitionerClass.getName()); if (maxSingleBufferSizeBytes >= 0) { conf.setInt(TezRuntimeConfiguration.TEZ_RUNTIME_UNORDERED_OUTPUT_MAX_PER_BUFFER_SIZE_BYTES, maxSingleBufferSizeBytes); } conf.setBoolean(TezRuntimeConfiguration.TEZ_RUNTIME_COMPRESS, shouldCompress); if (shouldCompress) { conf.set(TezRuntimeConfiguration.TEZ_RUNTIME_COMPRESS_CODEC, DefaultCodec.class.getName()); } return conf; } public static class PartitionerForTest implements Partitioner { @Override public int getPartition(Object key, Object value, int numPartitions) { if (key instanceof IntWritable) { return ((IntWritable) key).get() % numPartitions; } else { throw new UnsupportedOperationException( "Test partitioner expected to be called with IntWritable only"); } } } private static class UnorderedPartitionedKVWriterForTest extends UnorderedPartitionedKVWriter { public UnorderedPartitionedKVWriterForTest(OutputContext outputContext, Configuration conf, int numOutputs, long availableMemoryBytes) throws IOException { super(outputContext, conf, numOutputs, availableMemoryBytes); } @Override String getHost() { return HOST_STRING; } } }