// Java tutorial
/** * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.hadoop.mapreduce.task.reduce; import java.io.FilterInputStream; import java.lang.Void; import java.net.HttpURLConnection; import org.apache.hadoop.fs.ChecksumException; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; import org.apache.hadoop.mapred.MapOutputFile; import org.apache.hadoop.mapreduce.MRJobConfig; import org.apache.hadoop.mapreduce.TaskID; import org.junit.After; import org.junit.Before; import org.junit.Rule; import org.junit.rules.TestName; import static org.junit.Assert.*; import static org.mockito.Matchers.*; import static org.mockito.Mockito.*; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.DataOutputStream; import java.io.IOException; import java.io.InputStream; import java.net.SocketTimeoutException; import java.net.URL; import java.util.ArrayList; import javax.crypto.SecretKey; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.hadoop.io.Text; import org.apache.hadoop.mapred.Counters; import org.apache.hadoop.mapred.IFileInputStream; import org.apache.hadoop.mapred.IFileOutputStream; import org.apache.hadoop.mapred.JobConf; import 
org.apache.hadoop.mapred.Reporter; import org.apache.hadoop.mapreduce.TaskAttemptID; import org.apache.hadoop.mapreduce.security.SecureShuffleUtils; import org.apache.hadoop.mapreduce.security.token.JobTokenSecretManager; import org.apache.hadoop.util.DiskChecker.DiskErrorException; import org.apache.hadoop.util.Time; import org.junit.Test; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; import com.nimbusds.jose.util.StringUtils; /** * Test that the Fetcher does what we expect it to. */ public class TestFetcher { private static final Log LOG = LogFactory.getLog(TestFetcher.class); JobConf job = null; JobConf jobWithRetry = null; TaskAttemptID id = null; ShuffleSchedulerImpl<Text, Text> ss = null; MergeManagerImpl<Text, Text> mm = null; Reporter r = null; ShuffleClientMetrics metrics = null; ExceptionReporter except = null; SecretKey key = null; HttpURLConnection connection = null; Counters.Counter allErrs = null; final String encHash = "vFE234EIFCiBgYs2tCXY/SjT8Kg="; final MapHost host = new MapHost("localhost", "http://localhost:8080/"); final TaskAttemptID map1ID = TaskAttemptID.forName("attempt_0_1_m_1_1"); final TaskAttemptID map2ID = TaskAttemptID.forName("attempt_0_1_m_2_1"); FileSystem fs = null; @Rule public TestName name = new TestName(); @Before @SuppressWarnings("unchecked") // mocked generics public void setup() { LOG.info(">>>> " + name.getMethodName()); job = new JobConf(); job.setBoolean(MRJobConfig.SHUFFLE_FETCH_RETRY_ENABLED, false); jobWithRetry = new JobConf(); jobWithRetry.setBoolean(MRJobConfig.SHUFFLE_FETCH_RETRY_ENABLED, true); id = TaskAttemptID.forName("attempt_0_1_r_1_1"); ss = mock(ShuffleSchedulerImpl.class); mm = mock(MergeManagerImpl.class); r = mock(Reporter.class); metrics = mock(ShuffleClientMetrics.class); except = mock(ExceptionReporter.class); key = JobTokenSecretManager.createSecretKey(new byte[] { 0, 0, 0, 0 }); connection = mock(HttpURLConnection.class); allErrs = 
mock(Counters.Counter.class); when(r.getCounter(anyString(), anyString())).thenReturn(allErrs); ArrayList<TaskAttemptID> maps = new ArrayList<TaskAttemptID>(1); maps.add(map1ID); maps.add(map2ID); when(ss.getMapsForHost(host)).thenReturn(maps); } @After public void teardown() throws IllegalArgumentException, IOException { LOG.info("<<<< " + name.getMethodName()); if (fs != null) { fs.delete(new Path(name.getMethodName()), true); } } @Test public void testReduceOutOfDiskSpace() throws Throwable { LOG.info("testReduceOutOfDiskSpace"); Fetcher<Text, Text> underTest = new FakeFetcher<Text, Text>(job, id, ss, mm, r, metrics, except, key, connection); String replyHash = SecureShuffleUtils.generateHash(encHash.getBytes(), key); ShuffleHeader header = new ShuffleHeader(map1ID.toString(), 10, 10, 1); ByteArrayOutputStream bout = new ByteArrayOutputStream(); header.write(new DataOutputStream(bout)); ByteArrayInputStream in = new ByteArrayInputStream(bout.toByteArray()); when(connection.getResponseCode()).thenReturn(200); when(connection.getHeaderField(ShuffleHeader.HTTP_HEADER_NAME)) .thenReturn(ShuffleHeader.DEFAULT_HTTP_HEADER_NAME); when(connection.getHeaderField(ShuffleHeader.HTTP_HEADER_VERSION)) .thenReturn(ShuffleHeader.DEFAULT_HTTP_HEADER_VERSION); when(connection.getHeaderField(SecureShuffleUtils.HTTP_HEADER_REPLY_URL_HASH)).thenReturn(replyHash); when(connection.getInputStream()).thenReturn(in); when(mm.reserve(any(TaskAttemptID.class), anyLong(), anyInt())) .thenThrow(new DiskErrorException("No disk space available")); underTest.copyFromHost(host); verify(ss).reportLocalError(any(IOException.class)); } @Test(timeout = 30000) public void testCopyFromHostConnectionTimeout() throws Exception { when(connection.getInputStream()).thenThrow(new SocketTimeoutException("This is a fake timeout :)")); Fetcher<Text, Text> underTest = new FakeFetcher<Text, Text>(job, id, ss, mm, r, metrics, except, key, connection); underTest.copyFromHost(host); 
verify(connection).addRequestProperty(SecureShuffleUtils.HTTP_HEADER_URL_HASH, encHash); verify(allErrs).increment(1); verify(ss).copyFailed(map1ID, host, false, false); verify(ss).copyFailed(map2ID, host, false, false); verify(ss).putBackKnownMapOutput(any(MapHost.class), eq(map1ID)); verify(ss).putBackKnownMapOutput(any(MapHost.class), eq(map2ID)); } @Test public void testCopyFromHostBogusHeader() throws Exception { Fetcher<Text, Text> underTest = new FakeFetcher<Text, Text>(job, id, ss, mm, r, metrics, except, key, connection); String replyHash = SecureShuffleUtils.generateHash(encHash.getBytes(), key); when(connection.getResponseCode()).thenReturn(200); when(connection.getHeaderField(ShuffleHeader.HTTP_HEADER_NAME)) .thenReturn(ShuffleHeader.DEFAULT_HTTP_HEADER_NAME); when(connection.getHeaderField(ShuffleHeader.HTTP_HEADER_VERSION)) .thenReturn(ShuffleHeader.DEFAULT_HTTP_HEADER_VERSION); when(connection.getHeaderField(SecureShuffleUtils.HTTP_HEADER_REPLY_URL_HASH)).thenReturn(replyHash); ByteArrayInputStream in = new ByteArrayInputStream( "\u00010 BOGUS DATA\nBOGUS DATA\nBOGUS DATA\n".getBytes()); when(connection.getInputStream()).thenReturn(in); underTest.copyFromHost(host); verify(connection).addRequestProperty(SecureShuffleUtils.HTTP_HEADER_URL_HASH, encHash); verify(allErrs).increment(1); verify(ss).copyFailed(map1ID, host, true, false); verify(ss).copyFailed(map2ID, host, true, false); verify(ss).putBackKnownMapOutput(any(MapHost.class), eq(map1ID)); verify(ss).putBackKnownMapOutput(any(MapHost.class), eq(map2ID)); } @Test public void testCopyFromHostIncompatibleShuffleVersion() throws Exception { String replyHash = SecureShuffleUtils.generateHash(encHash.getBytes(), key); when(connection.getResponseCode()).thenReturn(200); when(connection.getHeaderField(ShuffleHeader.HTTP_HEADER_NAME)).thenReturn("mapreduce").thenReturn("other") .thenReturn("other"); 
when(connection.getHeaderField(ShuffleHeader.HTTP_HEADER_VERSION)).thenReturn("1.0.1").thenReturn("1.0.0") .thenReturn("1.0.1"); when(connection.getHeaderField(SecureShuffleUtils.HTTP_HEADER_REPLY_URL_HASH)).thenReturn(replyHash); ByteArrayInputStream in = new ByteArrayInputStream(new byte[0]); when(connection.getInputStream()).thenReturn(in); for (int i = 0; i < 3; ++i) { Fetcher<Text, Text> underTest = new FakeFetcher<Text, Text>(job, id, ss, mm, r, metrics, except, key, connection); underTest.copyFromHost(host); } verify(connection, times(3)).addRequestProperty(SecureShuffleUtils.HTTP_HEADER_URL_HASH, encHash); verify(allErrs, times(3)).increment(1); verify(ss, times(3)).copyFailed(map1ID, host, false, false); verify(ss, times(3)).copyFailed(map2ID, host, false, false); verify(ss, times(3)).putBackKnownMapOutput(any(MapHost.class), eq(map1ID)); verify(ss, times(3)).putBackKnownMapOutput(any(MapHost.class), eq(map2ID)); } @Test public void testCopyFromHostIncompatibleShuffleVersionWithRetry() throws Exception { String replyHash = SecureShuffleUtils.generateHash(encHash.getBytes(), key); when(connection.getResponseCode()).thenReturn(200); when(connection.getHeaderField(ShuffleHeader.HTTP_HEADER_NAME)).thenReturn("mapreduce").thenReturn("other") .thenReturn("other"); when(connection.getHeaderField(ShuffleHeader.HTTP_HEADER_VERSION)).thenReturn("1.0.1").thenReturn("1.0.0") .thenReturn("1.0.1"); when(connection.getHeaderField(SecureShuffleUtils.HTTP_HEADER_REPLY_URL_HASH)).thenReturn(replyHash); ByteArrayInputStream in = new ByteArrayInputStream(new byte[0]); when(connection.getInputStream()).thenReturn(in); for (int i = 0; i < 3; ++i) { Fetcher<Text, Text> underTest = new FakeFetcher<Text, Text>(jobWithRetry, id, ss, mm, r, metrics, except, key, connection); underTest.copyFromHost(host); } verify(connection, times(3)).addRequestProperty(SecureShuffleUtils.HTTP_HEADER_URL_HASH, encHash); verify(allErrs, times(3)).increment(1); verify(ss, times(3)).copyFailed(map1ID, 
host, false, false); verify(ss, times(3)).copyFailed(map2ID, host, false, false); verify(ss, times(3)).putBackKnownMapOutput(any(MapHost.class), eq(map1ID)); verify(ss, times(3)).putBackKnownMapOutput(any(MapHost.class), eq(map2ID)); } @Test public void testCopyFromHostWait() throws Exception { Fetcher<Text, Text> underTest = new FakeFetcher<Text, Text>(job, id, ss, mm, r, metrics, except, key, connection); String replyHash = SecureShuffleUtils.generateHash(encHash.getBytes(), key); when(connection.getResponseCode()).thenReturn(200); when(connection.getHeaderField(SecureShuffleUtils.HTTP_HEADER_REPLY_URL_HASH)).thenReturn(replyHash); ShuffleHeader header = new ShuffleHeader(map1ID.toString(), 10, 10, 1); ByteArrayOutputStream bout = new ByteArrayOutputStream(); header.write(new DataOutputStream(bout)); ByteArrayInputStream in = new ByteArrayInputStream(bout.toByteArray()); when(connection.getInputStream()).thenReturn(in); when(connection.getHeaderField(ShuffleHeader.HTTP_HEADER_NAME)) .thenReturn(ShuffleHeader.DEFAULT_HTTP_HEADER_NAME); when(connection.getHeaderField(ShuffleHeader.HTTP_HEADER_VERSION)) .thenReturn(ShuffleHeader.DEFAULT_HTTP_HEADER_VERSION); //Defaults to null, which is what we want to test when(mm.reserve(any(TaskAttemptID.class), anyLong(), anyInt())).thenReturn(null); underTest.copyFromHost(host); verify(connection).addRequestProperty(SecureShuffleUtils.HTTP_HEADER_URL_HASH, encHash); verify(allErrs, never()).increment(1); verify(ss, never()).copyFailed(map1ID, host, true, false); verify(ss, never()).copyFailed(map2ID, host, true, false); verify(ss).putBackKnownMapOutput(any(MapHost.class), eq(map1ID)); verify(ss).putBackKnownMapOutput(any(MapHost.class), eq(map2ID)); } @SuppressWarnings("unchecked") @Test(timeout = 10000) public void testCopyFromHostCompressFailure() throws Exception { InMemoryMapOutput<Text, Text> immo = mock(InMemoryMapOutput.class); Fetcher<Text, Text> underTest = new FakeFetcher<Text, Text>(job, id, ss, mm, r, metrics, 
except, key, connection); String replyHash = SecureShuffleUtils.generateHash(encHash.getBytes(), key); when(connection.getResponseCode()).thenReturn(200); when(connection.getHeaderField(SecureShuffleUtils.HTTP_HEADER_REPLY_URL_HASH)).thenReturn(replyHash); ShuffleHeader header = new ShuffleHeader(map1ID.toString(), 10, 10, 1); ByteArrayOutputStream bout = new ByteArrayOutputStream(); header.write(new DataOutputStream(bout)); ByteArrayInputStream in = new ByteArrayInputStream(bout.toByteArray()); when(connection.getInputStream()).thenReturn(in); when(connection.getHeaderField(ShuffleHeader.HTTP_HEADER_NAME)) .thenReturn(ShuffleHeader.DEFAULT_HTTP_HEADER_NAME); when(connection.getHeaderField(ShuffleHeader.HTTP_HEADER_VERSION)) .thenReturn(ShuffleHeader.DEFAULT_HTTP_HEADER_VERSION); when(mm.reserve(any(TaskAttemptID.class), anyLong(), anyInt())).thenReturn(immo); doThrow(new java.lang.InternalError()).when(immo).shuffle(any(MapHost.class), any(InputStream.class), anyLong(), anyLong(), any(ShuffleClientMetrics.class), any(Reporter.class)); underTest.copyFromHost(host); verify(connection).addRequestProperty(SecureShuffleUtils.HTTP_HEADER_URL_HASH, encHash); verify(ss, times(1)).copyFailed(map1ID, host, true, false); } @SuppressWarnings("unchecked") @Test(timeout = 10000) public void testCopyFromHostOnAnyException() throws Exception { InMemoryMapOutput<Text, Text> immo = mock(InMemoryMapOutput.class); Fetcher<Text, Text> underTest = new FakeFetcher<Text, Text>(job, id, ss, mm, r, metrics, except, key, connection); String replyHash = SecureShuffleUtils.generateHash(encHash.getBytes(), key); when(connection.getResponseCode()).thenReturn(200); when(connection.getHeaderField(SecureShuffleUtils.HTTP_HEADER_REPLY_URL_HASH)).thenReturn(replyHash); ShuffleHeader header = new ShuffleHeader(map1ID.toString(), 10, 10, 1); ByteArrayOutputStream bout = new ByteArrayOutputStream(); header.write(new DataOutputStream(bout)); ByteArrayInputStream in = new 
ByteArrayInputStream(bout.toByteArray()); when(connection.getInputStream()).thenReturn(in); when(connection.getHeaderField(ShuffleHeader.HTTP_HEADER_NAME)) .thenReturn(ShuffleHeader.DEFAULT_HTTP_HEADER_NAME); when(connection.getHeaderField(ShuffleHeader.HTTP_HEADER_VERSION)) .thenReturn(ShuffleHeader.DEFAULT_HTTP_HEADER_VERSION); when(mm.reserve(any(TaskAttemptID.class), anyLong(), anyInt())).thenReturn(immo); doThrow(new ArrayIndexOutOfBoundsException()).when(immo).shuffle(any(MapHost.class), any(InputStream.class), anyLong(), anyLong(), any(ShuffleClientMetrics.class), any(Reporter.class)); underTest.copyFromHost(host); verify(connection).addRequestProperty(SecureShuffleUtils.HTTP_HEADER_URL_HASH, encHash); verify(ss, times(1)).copyFailed(map1ID, host, true, false); } @SuppressWarnings("unchecked") @Test(timeout = 10000) public void testCopyFromHostWithRetry() throws Exception { InMemoryMapOutput<Text, Text> immo = mock(InMemoryMapOutput.class); ss = mock(ShuffleSchedulerImpl.class); Fetcher<Text, Text> underTest = new FakeFetcher<Text, Text>(jobWithRetry, id, ss, mm, r, metrics, except, key, connection, true); String replyHash = SecureShuffleUtils.generateHash(encHash.getBytes(), key); when(connection.getResponseCode()).thenReturn(200); when(connection.getHeaderField(SecureShuffleUtils.HTTP_HEADER_REPLY_URL_HASH)).thenReturn(replyHash); ShuffleHeader header = new ShuffleHeader(map1ID.toString(), 10, 10, 1); ByteArrayOutputStream bout = new ByteArrayOutputStream(); header.write(new DataOutputStream(bout)); ByteArrayInputStream in = new ByteArrayInputStream(bout.toByteArray()); when(connection.getInputStream()).thenReturn(in); when(connection.getHeaderField(ShuffleHeader.HTTP_HEADER_NAME)) .thenReturn(ShuffleHeader.DEFAULT_HTTP_HEADER_NAME); when(connection.getHeaderField(ShuffleHeader.HTTP_HEADER_VERSION)) .thenReturn(ShuffleHeader.DEFAULT_HTTP_HEADER_VERSION); when(mm.reserve(any(TaskAttemptID.class), anyLong(), anyInt())).thenReturn(immo); final long retryTime 
= Time.monotonicNow(); doAnswer(new Answer<Void>() { public Void answer(InvocationOnMock ignore) throws IOException { // Emulate host down for 3 seconds. if ((Time.monotonicNow() - retryTime) <= 3000) { throw new java.lang.InternalError(); } return null; } }).when(immo).shuffle(any(MapHost.class), any(InputStream.class), anyLong(), anyLong(), any(ShuffleClientMetrics.class), any(Reporter.class)); underTest.copyFromHost(host); verify(ss, never()).copyFailed(any(TaskAttemptID.class), any(MapHost.class), anyBoolean(), anyBoolean()); } @SuppressWarnings("unchecked") @Test(timeout = 10000) public void testCopyFromHostWithRetryThenTimeout() throws Exception { InMemoryMapOutput<Text, Text> immo = mock(InMemoryMapOutput.class); Fetcher<Text, Text> underTest = new FakeFetcher<Text, Text>(jobWithRetry, id, ss, mm, r, metrics, except, key, connection); String replyHash = SecureShuffleUtils.generateHash(encHash.getBytes(), key); when(connection.getResponseCode()).thenReturn(200).thenThrow(new SocketTimeoutException("forced timeout")); when(connection.getHeaderField(SecureShuffleUtils.HTTP_HEADER_REPLY_URL_HASH)).thenReturn(replyHash); ShuffleHeader header = new ShuffleHeader(map1ID.toString(), 10, 10, 1); ByteArrayOutputStream bout = new ByteArrayOutputStream(); header.write(new DataOutputStream(bout)); ByteArrayInputStream in = new ByteArrayInputStream(bout.toByteArray()); when(connection.getInputStream()).thenReturn(in); when(connection.getHeaderField(ShuffleHeader.HTTP_HEADER_NAME)) .thenReturn(ShuffleHeader.DEFAULT_HTTP_HEADER_NAME); when(connection.getHeaderField(ShuffleHeader.HTTP_HEADER_VERSION)) .thenReturn(ShuffleHeader.DEFAULT_HTTP_HEADER_VERSION); when(mm.reserve(any(TaskAttemptID.class), anyLong(), anyInt())).thenReturn(immo); doThrow(new IOException("forced error")).when(immo).shuffle(any(MapHost.class), any(InputStream.class), anyLong(), anyLong(), any(ShuffleClientMetrics.class), any(Reporter.class)); underTest.copyFromHost(host); verify(allErrs).increment(1); 
verify(ss).copyFailed(map1ID, host, false, false); } @Test public void testCopyFromHostExtraBytes() throws Exception { Fetcher<Text, Text> underTest = new FakeFetcher<Text, Text>(job, id, ss, mm, r, metrics, except, key, connection); String replyHash = SecureShuffleUtils.generateHash(encHash.getBytes(), key); when(connection.getResponseCode()).thenReturn(200); when(connection.getHeaderField(ShuffleHeader.HTTP_HEADER_NAME)) .thenReturn(ShuffleHeader.DEFAULT_HTTP_HEADER_NAME); when(connection.getHeaderField(ShuffleHeader.HTTP_HEADER_VERSION)) .thenReturn(ShuffleHeader.DEFAULT_HTTP_HEADER_VERSION); when(connection.getHeaderField(SecureShuffleUtils.HTTP_HEADER_REPLY_URL_HASH)).thenReturn(replyHash); ShuffleHeader header = new ShuffleHeader(map1ID.toString(), 14, 10, 1); ByteArrayOutputStream bout = new ByteArrayOutputStream(); DataOutputStream dos = new DataOutputStream(bout); IFileOutputStream ios = new IFileOutputStream(dos); header.write(dos); ios.write("MAPDATA123".getBytes()); ios.finish(); ShuffleHeader header2 = new ShuffleHeader(map2ID.toString(), 14, 10, 1); IFileOutputStream ios2 = new IFileOutputStream(dos); header2.write(dos); ios2.write("MAPDATA456".getBytes()); ios2.finish(); ByteArrayInputStream in = new ByteArrayInputStream(bout.toByteArray()); when(connection.getInputStream()).thenReturn(in); // 8 < 10 therefore there appear to be extra bytes in the IFileInputStream IFileWrappedMapOutput<Text, Text> mapOut = new InMemoryMapOutput<Text, Text>(job, map1ID, mm, 8, null, true); IFileWrappedMapOutput<Text, Text> mapOut2 = new InMemoryMapOutput<Text, Text>(job, map2ID, mm, 10, null, true); when(mm.reserve(eq(map1ID), anyLong(), anyInt())).thenReturn(mapOut); when(mm.reserve(eq(map2ID), anyLong(), anyInt())).thenReturn(mapOut2); underTest.copyFromHost(host); verify(allErrs).increment(1); verify(ss).copyFailed(map1ID, host, true, false); verify(ss, never()).copyFailed(map2ID, host, true, false); verify(ss).putBackKnownMapOutput(any(MapHost.class), eq(map1ID)); 
verify(ss).putBackKnownMapOutput(any(MapHost.class), eq(map2ID)); } @Test public void testCorruptedIFile() throws Exception { final int fetcher = 7; Path onDiskMapOutputPath = new Path(name.getMethodName() + "/foo"); Path shuffledToDisk = OnDiskMapOutput.getTempPath(onDiskMapOutputPath, fetcher); fs = FileSystem.getLocal(job).getRaw(); IFileWrappedMapOutput<Text, Text> odmo = new OnDiskMapOutput<Text, Text>(map1ID, mm, 100L, job, fetcher, true, fs, onDiskMapOutputPath); String mapData = "MAPDATA12345678901234567890"; ShuffleHeader header = new ShuffleHeader(map1ID.toString(), 14, 10, 1); ByteArrayOutputStream bout = new ByteArrayOutputStream(); DataOutputStream dos = new DataOutputStream(bout); IFileOutputStream ios = new IFileOutputStream(dos); header.write(dos); int headerSize = dos.size(); try { ios.write(mapData.getBytes()); } finally { ios.close(); } int dataSize = bout.size() - headerSize; // Ensure that the OnDiskMapOutput shuffler can successfully read the data. MapHost host = new MapHost("TestHost", "http://test/url"); ByteArrayInputStream bin = new ByteArrayInputStream(bout.toByteArray()); try { // Read past the shuffle header. bin.read(new byte[headerSize], 0, headerSize); odmo.shuffle(host, bin, dataSize, dataSize, metrics, Reporter.NULL); } finally { bin.close(); } // Now corrupt the IFile data. byte[] corrupted = bout.toByteArray(); corrupted[headerSize + (dataSize / 2)] = 0x0; try { bin = new ByteArrayInputStream(corrupted); // Read past the shuffle header. bin.read(new byte[headerSize], 0, headerSize); odmo.shuffle(host, bin, dataSize, dataSize, metrics, Reporter.NULL); fail("OnDiskMapOutput.shuffle didn't detect the corrupted map partition file"); } catch (ChecksumException e) { LOG.info("The expected checksum exception was thrown.", e); } finally { bin.close(); } // Ensure that the shuffled file can be read. 
IFileInputStream iFin = new IFileInputStream(fs.open(shuffledToDisk), dataSize, job); try { iFin.read(new byte[dataSize], 0, dataSize); } finally { iFin.close(); } } @Test(timeout = 10000) public void testInterruptInMemory() throws Exception { final int FETCHER = 2; IFileWrappedMapOutput<Text, Text> immo = spy( new InMemoryMapOutput<Text, Text>(job, id, mm, 100, null, true)); when(mm.reserve(any(TaskAttemptID.class), anyLong(), anyInt())).thenReturn(immo); doNothing().when(mm).waitForResource(); when(ss.getHost()).thenReturn(host); String replyHash = SecureShuffleUtils.generateHash(encHash.getBytes(), key); when(connection.getResponseCode()).thenReturn(200); when(connection.getHeaderField(ShuffleHeader.HTTP_HEADER_NAME)) .thenReturn(ShuffleHeader.DEFAULT_HTTP_HEADER_NAME); when(connection.getHeaderField(ShuffleHeader.HTTP_HEADER_VERSION)) .thenReturn(ShuffleHeader.DEFAULT_HTTP_HEADER_VERSION); when(connection.getHeaderField(SecureShuffleUtils.HTTP_HEADER_REPLY_URL_HASH)).thenReturn(replyHash); ShuffleHeader header = new ShuffleHeader(map1ID.toString(), 10, 10, 1); ByteArrayOutputStream bout = new ByteArrayOutputStream(); header.write(new DataOutputStream(bout)); final StuckInputStream in = new StuckInputStream(new ByteArrayInputStream(bout.toByteArray())); when(connection.getInputStream()).thenReturn(in); doAnswer(new Answer<Void>() { public Void answer(InvocationOnMock ignore) throws IOException { in.close(); return null; } }).when(connection).disconnect(); Fetcher<Text, Text> underTest = new FakeFetcher<Text, Text>(job, id, ss, mm, r, metrics, except, key, connection, FETCHER); underTest.start(); // wait for read in inputstream in.waitForFetcher(); underTest.shutDown(); underTest.join(); // rely on test timeout to kill if stuck assertTrue(in.wasClosedProperly()); verify(immo).abort(); } @Test(timeout = 10000) public void testInterruptOnDisk() throws Exception { final int FETCHER = 7; Path p = new Path("file:///tmp/foo"); Path pTmp = OnDiskMapOutput.getTempPath(p, 
FETCHER); FileSystem mFs = mock(FileSystem.class, RETURNS_DEEP_STUBS); IFileWrappedMapOutput<Text, Text> odmo = spy( new OnDiskMapOutput<Text, Text>(map1ID, mm, 100L, job, FETCHER, true, mFs, p)); when(mm.reserve(any(TaskAttemptID.class), anyLong(), anyInt())).thenReturn(odmo); doNothing().when(mm).waitForResource(); when(ss.getHost()).thenReturn(host); String replyHash = SecureShuffleUtils.generateHash(encHash.getBytes(), key); when(connection.getResponseCode()).thenReturn(200); when(connection.getHeaderField(SecureShuffleUtils.HTTP_HEADER_REPLY_URL_HASH)).thenReturn(replyHash); ShuffleHeader header = new ShuffleHeader(map1ID.toString(), 10, 10, 1); ByteArrayOutputStream bout = new ByteArrayOutputStream(); header.write(new DataOutputStream(bout)); final StuckInputStream in = new StuckInputStream(new ByteArrayInputStream(bout.toByteArray())); when(connection.getInputStream()).thenReturn(in); when(connection.getHeaderField(ShuffleHeader.HTTP_HEADER_NAME)) .thenReturn(ShuffleHeader.DEFAULT_HTTP_HEADER_NAME); when(connection.getHeaderField(ShuffleHeader.HTTP_HEADER_VERSION)) .thenReturn(ShuffleHeader.DEFAULT_HTTP_HEADER_VERSION); doAnswer(new Answer<Void>() { public Void answer(InvocationOnMock ignore) throws IOException { in.close(); return null; } }).when(connection).disconnect(); Fetcher<Text, Text> underTest = new FakeFetcher<Text, Text>(job, id, ss, mm, r, metrics, except, key, connection, FETCHER); underTest.start(); // wait for read in inputstream in.waitForFetcher(); underTest.shutDown(); underTest.join(); // rely on test timeout to kill if stuck assertTrue(in.wasClosedProperly()); verify(mFs).create(eq(pTmp)); verify(mFs).delete(eq(pTmp), eq(false)); verify(odmo).abort(); } @SuppressWarnings("unchecked") @Test(timeout = 10000) public void testCopyFromHostWithRetryUnreserve() throws Exception { InMemoryMapOutput<Text, Text> immo = mock(InMemoryMapOutput.class); Fetcher<Text, Text> underTest = new FakeFetcher<Text, Text>(jobWithRetry, id, ss, mm, r, metrics, 
except, key, connection); String replyHash = SecureShuffleUtils.generateHash(encHash.getBytes(), key); when(connection.getResponseCode()).thenReturn(200); when(connection.getHeaderField(SecureShuffleUtils.HTTP_HEADER_REPLY_URL_HASH)).thenReturn(replyHash); ShuffleHeader header = new ShuffleHeader(map1ID.toString(), 10, 10, 1); ByteArrayOutputStream bout = new ByteArrayOutputStream(); header.write(new DataOutputStream(bout)); ByteArrayInputStream in = new ByteArrayInputStream(bout.toByteArray()); when(connection.getInputStream()).thenReturn(in); when(connection.getHeaderField(ShuffleHeader.HTTP_HEADER_NAME)) .thenReturn(ShuffleHeader.DEFAULT_HTTP_HEADER_NAME); when(connection.getHeaderField(ShuffleHeader.HTTP_HEADER_VERSION)) .thenReturn(ShuffleHeader.DEFAULT_HTTP_HEADER_VERSION); // Verify that unreserve occurs if an exception happens after shuffle // buffer is reserved. when(mm.reserve(any(TaskAttemptID.class), anyLong(), anyInt())).thenReturn(immo); doThrow(new IOException("forced error")).when(immo).shuffle(any(MapHost.class), any(InputStream.class), anyLong(), anyLong(), any(ShuffleClientMetrics.class), any(Reporter.class)); underTest.copyFromHost(host); verify(immo).abort(); } public static class FakeFetcher<K, V> extends Fetcher<K, V> { // If connection need to be reopen. 
private boolean renewConnection = false; public FakeFetcher(JobConf job, TaskAttemptID reduceId, ShuffleSchedulerImpl<K, V> scheduler, MergeManagerImpl<K, V> merger, Reporter reporter, ShuffleClientMetrics metrics, ExceptionReporter exceptionReporter, SecretKey jobTokenSecret, HttpURLConnection connection) { super(job, reduceId, scheduler, merger, reporter, metrics, exceptionReporter, jobTokenSecret); this.connection = connection; } public FakeFetcher(JobConf job, TaskAttemptID reduceId, ShuffleSchedulerImpl<K, V> scheduler, MergeManagerImpl<K, V> merger, Reporter reporter, ShuffleClientMetrics metrics, ExceptionReporter exceptionReporter, SecretKey jobTokenSecret, HttpURLConnection connection, boolean renewConnection) { super(job, reduceId, scheduler, merger, reporter, metrics, exceptionReporter, jobTokenSecret); this.connection = connection; this.renewConnection = renewConnection; } public FakeFetcher(JobConf job, TaskAttemptID reduceId, ShuffleSchedulerImpl<K, V> scheduler, MergeManagerImpl<K, V> merger, Reporter reporter, ShuffleClientMetrics metrics, ExceptionReporter exceptionReporter, SecretKey jobTokenSecret, HttpURLConnection connection, int id) { super(job, reduceId, scheduler, merger, reporter, metrics, exceptionReporter, jobTokenSecret, id); this.connection = connection; } @Override protected void openConnection(URL url) throws IOException { if (null == connection || renewConnection) { super.openConnection(url); } // already 'opened' the mocked connection return; } } static class StuckInputStream extends FilterInputStream { boolean stuck = false; volatile boolean closed = false; StuckInputStream(InputStream inner) { super(inner); } int freeze() throws IOException { synchronized (this) { stuck = true; notify(); } // connection doesn't throw InterruptedException, but may return some // bytes geq 0 or throw an exception while (!Thread.currentThread().isInterrupted() || closed) { // spin if (closed) { throw new IOException("underlying stream closed, 
triggered an error"); } } return 0; } @Override public int read() throws IOException { int ret = super.read(); if (ret != -1) { return ret; } return freeze(); } @Override public int read(byte[] b) throws IOException { int ret = super.read(b); if (ret != -1) { return ret; } return freeze(); } @Override public int read(byte[] b, int off, int len) throws IOException { int ret = super.read(b, off, len); if (ret != -1) { return ret; } return freeze(); } @Override public void close() throws IOException { closed = true; } public synchronized void waitForFetcher() throws InterruptedException { while (!stuck) { wait(); } } public boolean wasClosedProperly() { return closed; } } }