Java tutorial: Apache Spark's UnsafeShuffleWriter (org.apache.spark.shuffle.sort)
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.shuffle.sort;

import javax.annotation.Nullable;
import java.io.*;
import java.nio.channels.FileChannel;
import java.util.Iterator;

import scala.Option;
import scala.Product2;
import scala.collection.JavaConverters;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.io.ByteStreams;
import com.google.common.io.Closeables;
import com.google.common.io.Files;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.spark.*;
import org.apache.spark.annotation.Private;
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.io.CompressionCodec;
import org.apache.spark.io.CompressionCodec$;
import org.apache.commons.io.output.CloseShieldOutputStream;
import org.apache.commons.io.output.CountingOutputStream;
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.network.util.LimitedInputStream;
import org.apache.spark.scheduler.MapStatus;
import org.apache.spark.scheduler.MapStatus$;
import org.apache.spark.serializer.SerializationStream;
import org.apache.spark.serializer.SerializerInstance;
import org.apache.spark.shuffle.IndexShuffleBlockResolver;
import org.apache.spark.shuffle.ShuffleWriter;
import org.apache.spark.storage.BlockManager;
import org.apache.spark.storage.TimeTrackingOutputStream;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.util.Utils;

@Private
public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {

  private static final Logger logger = LoggerFactory.getLogger(UnsafeShuffleWriter.class);

  private static final ClassTag<Object> OBJECT_CLASS_TAG = ClassTag$.MODULE$.Object();

  @VisibleForTesting
  static final int DEFAULT_INITIAL_SORT_BUFFER_SIZE = 4096;

  private final BlockManager blockManager;
  private final IndexShuffleBlockResolver shuffleBlockResolver;
  private final TaskMemoryManager memoryManager;
  private final SerializerInstance serializer;
  private final Partitioner partitioner;
  private final ShuffleWriteMetrics writeMetrics;
  private final int shuffleId;
  private final int mapId;
  private final TaskContext taskContext;
  private final SparkConf sparkConf;
  private final boolean transferToEnabled;
  private final int initialSortBufferSize;

  @Nullable private MapStatus mapStatus;
  @Nullable private ShuffleExternalSorter sorter;
  private long peakMemoryUsedBytes = 0;

  /** Subclass of ByteArrayOutputStream that exposes `buf` directly. */
  private static final class MyByteArrayOutputStream extends ByteArrayOutputStream {
    MyByteArrayOutputStream(int size) { super(size); }
    public byte[] getBuf() { return buf; }
  }

  private MyByteArrayOutputStream serBuffer;
  private SerializationStream serOutputStream;

  /**
   * Are we in the process of stopping? Because map tasks can call stop() with success = true
   * and then call stop() with success = false if they get an exception, we want to make sure
   * we don't try deleting files, etc twice.
   */
  private boolean stopping = false;

  public UnsafeShuffleWriter(
      BlockManager blockManager,
      IndexShuffleBlockResolver shuffleBlockResolver,
      TaskMemoryManager memoryManager,
      SerializedShuffleHandle<K, V> handle,
      int mapId,
      TaskContext taskContext,
      SparkConf sparkConf) throws IOException {
    final int numPartitions = handle.dependency().partitioner().numPartitions();
    if (numPartitions > SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE()) {
      throw new IllegalArgumentException(
        "UnsafeShuffleWriter can only be used for shuffles with at most " +
        SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE() +
        " reduce partitions");
    }
    this.blockManager = blockManager;
    this.shuffleBlockResolver = shuffleBlockResolver;
    this.memoryManager = memoryManager;
    this.mapId = mapId;
    final ShuffleDependency<K, V, V> dep = handle.dependency();
    this.shuffleId = dep.shuffleId();
    this.serializer = dep.serializer().newInstance();
    this.partitioner = dep.partitioner();
    this.writeMetrics = taskContext.taskMetrics().shuffleWriteMetrics();
    this.taskContext = taskContext;
    this.sparkConf = sparkConf;
    this.transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", true);
    this.initialSortBufferSize = sparkConf.getInt("spark.shuffle.sort.initialBufferSize",
      DEFAULT_INITIAL_SORT_BUFFER_SIZE);
    open();
  }

  private void updatePeakMemoryUsed() {
    // sorter can be null if this writer is closed
    if (sorter != null) {
      long mem = sorter.getPeakMemoryUsedBytes();
      if (mem > peakMemoryUsedBytes) {
        peakMemoryUsedBytes = mem;
      }
    }
  }

  /**
   * Return the peak memory used so far, in bytes.
   */
  public long getPeakMemoryUsedBytes() {
    updatePeakMemoryUsed();
    return peakMemoryUsedBytes;
  }

  /**
   * This convenience method should only be called in test code.
   */
  @VisibleForTesting
  public void write(Iterator<Product2<K, V>> records) throws IOException {
    write(JavaConverters.asScalaIteratorConverter(records).asScala());
  }

  @Override
  public void write(scala.collection.Iterator<Product2<K, V>> records) throws IOException {
    // Keep track of success so we know if we encountered an exception
    // We do this rather than a standard try/catch/re-throw to handle
    // generic throwables.
    boolean success = false;
    try {
      while (records.hasNext()) {
        insertRecordIntoSorter(records.next());
      }
      closeAndWriteOutput();
      success = true;
    } finally {
      if (sorter != null) {
        try {
          sorter.cleanupResources();
        } catch (Exception e) {
          // Only throw this error if we won't be masking another
          // error.
          if (success) {
            throw e;
          } else {
            logger.error("In addition to a failure during writing, we failed during " +
              "cleanup.", e);
          }
        }
      }
    }
  }

  private void open() throws IOException {
    assert (sorter == null);
    sorter = new ShuffleExternalSorter(
      memoryManager,
      blockManager,
      taskContext,
      initialSortBufferSize,
      partitioner.numPartitions(),
      sparkConf,
      writeMetrics);
    serBuffer = new MyByteArrayOutputStream(1024 * 1024);
    serOutputStream = serializer.serializeStream(serBuffer);
  }

  @VisibleForTesting
  void closeAndWriteOutput() throws IOException {
    assert (sorter != null);
    updatePeakMemoryUsed();
    serBuffer = null;
    serOutputStream = null;
    final SpillInfo[] spills = sorter.closeAndGetSpills();
    sorter = null;
    final long[] partitionLengths;
    final File output = shuffleBlockResolver.getDataFile(shuffleId, mapId);
    final File tmp = Utils.tempFileWith(output);
    try {
      try {
        partitionLengths = mergeSpills(spills, tmp);
      } finally {
        for (SpillInfo spill : spills) {
          if (spill.file.exists() && !spill.file.delete()) {
            logger.error("Error while deleting spill file {}", spill.file.getPath());
          }
        }
      }
      shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp);
    } finally {
      if (tmp.exists() && !tmp.delete()) {
        logger.error("Error while deleting temp file {}", tmp.getAbsolutePath());
      }
    }
    mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
  }

  @VisibleForTesting
  void insertRecordIntoSorter(Product2<K, V> record) throws IOException {
    assert (sorter != null);
    final K key = record._1();
    final int partitionId = partitioner.getPartition(key);
    serBuffer.reset();
    serOutputStream.writeKey(key, OBJECT_CLASS_TAG);
    serOutputStream.writeValue(record._2(), OBJECT_CLASS_TAG);
    serOutputStream.flush();

    final int serializedRecordSize = serBuffer.size();
    assert (serializedRecordSize > 0);

    sorter.insertRecord(
      serBuffer.getBuf(), Platform.BYTE_ARRAY_OFFSET, serializedRecordSize, partitionId);
  }

  @VisibleForTesting
  void forceSorterToSpill() throws IOException {
    assert (sorter != null);
    sorter.spill();
  }

  /**
   * Merge zero or more spill files together, choosing the fastest merging strategy based on the
   * number of spills and the IO compression codec.
   *
   * @return the partition lengths in the merged file.
   */
  private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOException {
    final boolean compressionEnabled = sparkConf.getBoolean("spark.shuffle.compress", true);
    final CompressionCodec compressionCodec = CompressionCodec$.MODULE$.createCodec(sparkConf);
    final boolean fastMergeEnabled =
      sparkConf.getBoolean("spark.shuffle.unsafe.fastMergeEnabled", true);
    final boolean fastMergeIsSupported = !compressionEnabled ||
      CompressionCodec$.MODULE$.supportsConcatenationOfSerializedStreams(compressionCodec);
    final boolean encryptionEnabled = blockManager.serializerManager().encryptionEnabled();
    try {
      if (spills.length == 0) {
        new FileOutputStream(outputFile).close(); // Create an empty file
        return new long[partitioner.numPartitions()];
      } else if (spills.length == 1) {
        // Here, we don't need to perform any metrics updates because the bytes written to this
        // output file would have already been counted as shuffle bytes written.
        Files.move(spills[0].file, outputFile);
        return spills[0].partitionLengths;
      } else {
        final long[] partitionLengths;
        // There are multiple spills to merge, so none of these spill files' lengths were counted
        // towards our shuffle write count or shuffle write time.
        // If we use the slow merge path,
        // then the final output file's size won't necessarily be equal to the sum of the spill
        // files' sizes. To guard against this case, we look at the output file's actual size when
        // computing shuffle bytes written.
        //
        // We allow the individual merge methods to report their own IO times since different merge
        // strategies use different IO techniques. We count IO during merge towards the shuffle
        // write time, which appears to be consistent with the "not bypassing merge-sort" branch
        // in ExternalSorter.
        if (fastMergeEnabled && fastMergeIsSupported) {
          // Compression is disabled or we are using an IO compression codec that supports
          // decompression of concatenated compressed streams, so we can perform a fast spill merge
          // that doesn't need to interpret the spilled bytes.
          if (transferToEnabled && !encryptionEnabled) {
            logger.debug("Using transferTo-based fast merge");
            partitionLengths = mergeSpillsWithTransferTo(spills, outputFile);
          } else {
            logger.debug("Using fileStream-based fast merge");
            partitionLengths = mergeSpillsWithFileStream(spills, outputFile, null);
          }
        } else {
          logger.debug("Using slow merge");
          partitionLengths = mergeSpillsWithFileStream(spills, outputFile, compressionCodec);
        }
        // When closing an UnsafeShuffleExternalSorter that has already spilled once but also has
        // in-memory records, we write out the in-memory records to a file but do not count that
        // final write as bytes spilled (instead, it's accounted as shuffle write). The merge needs
        // to be counted as shuffle write, but this will lead to double-counting of the final
        // SpillInfo's bytes.
        writeMetrics.decBytesWritten(spills[spills.length - 1].file.length());
        writeMetrics.incBytesWritten(outputFile.length());
        return partitionLengths;
      }
    } catch (IOException e) {
      if (outputFile.exists() && !outputFile.delete()) {
        logger.error("Unable to delete output file {}", outputFile.getPath());
      }
      throw e;
    }
  }

  /**
   * Merges spill files using Java FileStreams. This code path is slower than the NIO-based merge,
   * {@link UnsafeShuffleWriter#mergeSpillsWithTransferTo(SpillInfo[], File)}, so it's only used in
   * cases where the IO compression codec does not support concatenation of compressed data, when
   * encryption is enabled, or when users have explicitly disabled use of {@code transferTo} in
   * order to work around kernel bugs.
   *
   * @param spills the spills to merge.
   * @param outputFile the file to write the merged data to.
   * @param compressionCodec the IO compression codec, or null if shuffle compression is disabled.
   * @return the partition lengths in the merged file.
   */
  private long[] mergeSpillsWithFileStream(
      SpillInfo[] spills,
      File outputFile,
      @Nullable CompressionCodec compressionCodec) throws IOException {
    assert (spills.length >= 2);
    final int numPartitions = partitioner.numPartitions();
    final long[] partitionLengths = new long[numPartitions];
    final InputStream[] spillInputStreams = new FileInputStream[spills.length];

    // Use a counting output stream to avoid having to close the underlying file and ask
    // the file system for its size after each partition is written.
    final CountingOutputStream mergedFileOutputStream = new CountingOutputStream(
      new FileOutputStream(outputFile));

    boolean threwException = true;
    try {
      for (int i = 0; i < spills.length; i++) {
        spillInputStreams[i] = new FileInputStream(spills[i].file);
      }
      for (int partition = 0; partition < numPartitions; partition++) {
        final long initialFileLength = mergedFileOutputStream.getByteCount();
        // Shield the underlying output stream from close() calls, so that we can close the higher
        // level streams to make sure all data is really flushed and internal state is cleaned.
        OutputStream partitionOutput = new CloseShieldOutputStream(
          new TimeTrackingOutputStream(writeMetrics, mergedFileOutputStream));
        partitionOutput = blockManager.serializerManager().wrapForEncryption(partitionOutput);
        if (compressionCodec != null) {
          partitionOutput = compressionCodec.compressedOutputStream(partitionOutput);
        }
        for (int i = 0; i < spills.length; i++) {
          final long partitionLengthInSpill = spills[i].partitionLengths[partition];
          if (partitionLengthInSpill > 0) {
            InputStream partitionInputStream = new LimitedInputStream(spillInputStreams[i],
              partitionLengthInSpill, false);
            try {
              partitionInputStream = blockManager.serializerManager()
                .wrapForEncryption(partitionInputStream);
              if (compressionCodec != null) {
                partitionInputStream = compressionCodec.compressedInputStream(partitionInputStream);
              }
              ByteStreams.copy(partitionInputStream, partitionOutput);
            } finally {
              partitionInputStream.close();
            }
          }
        }
        partitionOutput.flush();
        partitionOutput.close();
        partitionLengths[partition] = (mergedFileOutputStream.getByteCount() - initialFileLength);
      }
      threwException = false;
    } finally {
      // To avoid masking exceptions that caused us to prematurely enter the finally block, only
      // throw exceptions during cleanup if threwException == false.
      for (InputStream stream : spillInputStreams) {
        Closeables.close(stream, threwException);
      }
      Closeables.close(mergedFileOutputStream, threwException);
    }
    return partitionLengths;
  }

  /**
   * Merges spill files by using NIO's transferTo to concatenate spill partitions' bytes.
   * This is only safe when the IO compression codec and serializer support concatenation of
   * serialized streams.
   *
   * @return the partition lengths in the merged file.
   */
  private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) throws IOException {
    assert (spills.length >= 2);
    final int numPartitions = partitioner.numPartitions();
    final long[] partitionLengths = new long[numPartitions];
    final FileChannel[] spillInputChannels = new FileChannel[spills.length];
    final long[] spillInputChannelPositions = new long[spills.length];
    FileChannel mergedFileOutputChannel = null;

    boolean threwException = true;
    try {
      for (int i = 0; i < spills.length; i++) {
        spillInputChannels[i] = new FileInputStream(spills[i].file).getChannel();
      }
      // This file needs to be opened in append mode in order to work around a Linux kernel bug
      // that affects transferTo; see SPARK-3948 for more details.
      mergedFileOutputChannel = new FileOutputStream(outputFile, true).getChannel();

      long bytesWrittenToMergedFile = 0;
      for (int partition = 0; partition < numPartitions; partition++) {
        for (int i = 0; i < spills.length; i++) {
          final long partitionLengthInSpill = spills[i].partitionLengths[partition];
          long bytesToTransfer = partitionLengthInSpill;
          final FileChannel spillInputChannel = spillInputChannels[i];
          final long writeStartTime = System.nanoTime();
          while (bytesToTransfer > 0) {
            final long actualBytesTransferred = spillInputChannel.transferTo(
              spillInputChannelPositions[i],
              bytesToTransfer,
              mergedFileOutputChannel);
            spillInputChannelPositions[i] += actualBytesTransferred;
            bytesToTransfer -= actualBytesTransferred;
          }
          writeMetrics.incWriteTime(System.nanoTime() - writeStartTime);
          bytesWrittenToMergedFile += partitionLengthInSpill;
          partitionLengths[partition] += partitionLengthInSpill;
        }
      }
      // Check the position after the transferTo loop to see if it is in the right position and
      // raise an exception if it is incorrect. The position will not be increased to the expected
      // length after calling transferTo in kernel version 2.6.32. This issue is described at
      // https://bugs.openjdk.java.net/browse/JDK-7052359 and SPARK-3948.
      if (mergedFileOutputChannel.position() != bytesWrittenToMergedFile) {
        throw new IOException(
          "Current position " + mergedFileOutputChannel.position() + " does not equal expected " +
            "position " + bytesWrittenToMergedFile + " after transferTo. Please check your kernel" +
            " version to see if it is 2.6.32, as there is a kernel bug which will lead to " +
            "unexpected behavior when using transferTo. You can set spark.file.transferTo=false " +
            "to disable this NIO feature."
        );
      }
      threwException = false;
    } finally {
      // To avoid masking exceptions that caused us to prematurely enter the finally block, only
      // throw exceptions during cleanup if threwException == false.
      for (int i = 0; i < spills.length; i++) {
        assert (spillInputChannelPositions[i] == spills[i].file.length());
        Closeables.close(spillInputChannels[i], threwException);
      }
      Closeables.close(mergedFileOutputChannel, threwException);
    }
    return partitionLengths;
  }

  @Override
  public Option<MapStatus> stop(boolean success) {
    try {
      taskContext.taskMetrics().incPeakExecutionMemory(getPeakMemoryUsedBytes());

      if (stopping) {
        return Option.apply(null);
      } else {
        stopping = true;
        if (success) {
          if (mapStatus == null) {
            throw new IllegalStateException("Cannot call stop(true) without having called write()");
          }
          return Option.apply(mapStatus);
        } else {
          return Option.apply(null);
        }
      }
    } finally {
      if (sorter != null) {
        // If sorter is non-null, then this implies that we called stop() in response to an error,
        // so we need to clean up memory and spill files created by the sorter
        sorter.cleanupResources();
      }
    }
  }
}
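A note on the MyByteArrayOutputStream helper used above: java.io.ByteArrayOutputStream keeps its data in a protected `buf` field, so a subclass can expose that array and hand it to a consumer without the copy that toByteArray() would make. Below is a minimal standalone sketch of the same pattern; the class and method names here are illustrative and not part of the Spark source.

    import java.io.ByteArrayOutputStream;
    import java.io.IOException;
    import java.io.ObjectOutputStream;

    // Illustrative example: expose the internal buffer of ByteArrayOutputStream
    // so callers can read the serialized bytes without an extra copy.
    final class ExposedByteArrayOutputStream extends ByteArrayOutputStream {
      ExposedByteArrayOutputStream(int size) { super(size); }

      // Returns the backing array; only the first size() bytes hold valid data.
      byte[] getBuf() { return buf; }
    }

    public class ExposedBufferDemo {
      public static void main(String[] args) throws IOException {
        ExposedByteArrayOutputStream out = new ExposedByteArrayOutputStream(64);
        try (ObjectOutputStream oos = new ObjectOutputStream(out)) {
          oos.writeObject("hello");
        }
        // No defensive copy: pass the backing array plus the valid length downstream,
        // just as insertRecordIntoSorter passes getBuf() and serBuffer.size() to the sorter.
        byte[] backing = out.getBuf();
        System.out.println("valid bytes: " + out.size() + ", capacity: " + backing.length);
      }
    }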
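mergeSpillsWithFileStream layers several streams: a commons-io CountingOutputStream tracks how many bytes reach the output file (so each partition's length is just the difference of two counts), and a CloseShieldOutputStream lets the per-partition compression wrapper be closed and flushed without closing the shared underlying stream. The sketch below shows that layering in isolation; it uses java.util.zip.GZIPOutputStream as a stand-in for Spark's CompressionCodec, and the class name is hypothetical.

    import java.io.ByteArrayOutputStream;
    import java.io.IOException;
    import java.io.OutputStream;
    import java.util.zip.GZIPOutputStream;

    import org.apache.commons.io.output.CloseShieldOutputStream;
    import org.apache.commons.io.output.CountingOutputStream;

    public class CountingCloseShieldDemo {
      public static void main(String[] args) throws IOException {
        ByteArrayOutputStream sink = new ByteArrayOutputStream();
        // CountingOutputStream records how many bytes reach the underlying sink.
        CountingOutputStream counting = new CountingOutputStream(sink);

        long[] sectionLengths = new long[2];
        for (int section = 0; section < 2; section++) {
          long before = counting.getByteCount();
          // CloseShieldOutputStream lets us close the compressed wrapper (flushing its
          // footer) without closing the shared counting stream underneath.
          OutputStream shielded = new CloseShieldOutputStream(counting);
          GZIPOutputStream compressed = new GZIPOutputStream(shielded);
          compressed.write(("section " + section).getBytes("UTF-8"));
          compressed.close();
          // Section length measured at the file level, like partitionLengths[partition] above.
          sectionLengths[section] = counting.getByteCount() - before;
        }
        counting.close();
        System.out.println(sectionLengths[0] + " and " + sectionLengths[1] + " bytes written");
      }
    }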
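Finally, mergeSpillsWithTransferTo depends on FileChannel.transferTo, which is allowed to transfer fewer bytes than requested; that is why the code loops with `while (bytesToTransfer > 0)` and advances a per-spill position. Here is a self-contained sketch of the same concatenation loop outside Spark, under the assumption that the files already exist; the file names are placeholders.

    import java.io.File;
    import java.io.FileInputStream;
    import java.io.FileOutputStream;
    import java.io.IOException;
    import java.nio.channels.FileChannel;

    public class TransferToConcatDemo {
      // Appends the entire contents of `src` to `dst` using transferTo.
      // transferTo may move fewer bytes than requested, so loop until done.
      static void appendFile(File src, File dst) throws IOException {
        try (FileChannel in = new FileInputStream(src).getChannel();
             FileChannel out = new FileOutputStream(dst, true).getChannel()) {
          long position = 0;
          long remaining = in.size();
          while (remaining > 0) {
            long transferred = in.transferTo(position, remaining, out);
            position += transferred;
            remaining -= transferred;
          }
        }
      }

      public static void main(String[] args) throws IOException {
        // Placeholder paths for illustration only.
        appendFile(new File("spill-0.data"), new File("merged.data"));
        appendFile(new File("spill-1.data"), new File("merged.data"));
      }
    }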