package org.apache.spark.shuffle.sort;

import java.io.ByteArrayOutputStream;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.channels.Channels;
import java.nio.channels.FileChannel;
import java.nio.channels.WritableByteChannel;
import java.util.Arrays;
import java.util.Iterator;
import java.util.Optional;
import javax.annotation.Nullable;
import org.apache.spark.Partitioner;
import org.apache.spark.ShuffleDependency;
import org.apache.spark.SparkConf;
import org.apache.spark.TaskContext;
import org.apache.spark.annotation.Private;
import org.apache.spark.internal.config.package$;
import org.apache.spark.io.CompressionCodec;
import org.apache.spark.io.CompressionCodec$;
import org.apache.spark.io.NioBufferedFileInputStream;
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.network.util.LimitedInputStream;
import org.apache.spark.scheduler.MapStatus;
import org.apache.spark.scheduler.MapStatus$;
import org.apache.spark.serializer.SerializationStream;
import org.apache.spark.serializer.SerializerInstance;
import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
import org.apache.spark.shuffle.ShuffleWriter;
import org.apache.spark.shuffle.api.ShuffleExecutorComponents;
import org.apache.spark.shuffle.api.ShuffleMapOutputWriter;
import org.apache.spark.shuffle.api.ShufflePartitionWriter;
import org.apache.spark.shuffle.api.SingleSpillShuffleMapOutputWriter;
import org.apache.spark.shuffle.api.WritableByteChannelWrapper;
import org.apache.spark.storage.BlockManager;
import org.apache.spark.storage.TimeTrackingOutputStream;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.util.Utils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.sparkproject.guava.annotations.VisibleForTesting;
import org.sparkproject.guava.io.ByteStreams;
import org.sparkproject.guava.io.Closeables;
import scala.Option;
import scala.Product2;
import scala.collection.JavaConverters;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;

@Private
/* loaded from: input_file:org/apache/spark/shuffle/sort/UnsafeShuffleWriter.class */
public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
    // Class-wide logger; assigned in the static initializer.
    private static final Logger logger;
    // Class tag handed to the serialization stream for both keys and values (see insertRecordIntoSorter).
    private static final ClassTag<Object> OBJECT_CLASS_TAG;

    // Initial capacity, in bytes (1 MiB), of the reusable serialization buffer created in open().
    @VisibleForTesting
    static final int DEFAULT_INITIAL_SER_BUFFER_SIZE = 1048576;
    private final BlockManager blockManager;
    private final TaskMemoryManager memoryManager;
    // Serializer instance obtained from the shuffle dependency in the constructor.
    private final SerializerInstance serializer;
    // Maps each record key to its target partition id.
    private final Partitioner partitioner;
    private final ShuffleWriteMetricsReporter writeMetrics;
    // Factory for the map output writers used while merging spills.
    private final ShuffleExecutorComponents shuffleExecutorComponents;
    private final int shuffleId;
    private final long mapId;
    private final TaskContext taskContext;
    private final SparkConf sparkConf;
    // Value of "spark.file.transferTo" (default true); selects the channel-based merge path.
    private final boolean transferToEnabled;
    // Read from SHUFFLE_SORT_INIT_BUFFER_SIZE and passed to the sorter created in open().
    private final int initialSortBufferSize;
    // SHUFFLE_FILE_BUFFER_SIZE multiplied by 1024; buffer size for spill input streams.
    private final int inputBufferSizeInBytes;

    // Set by closeAndWriteOutput(); stop(true) returns it and fails if it is still null.
    @Nullable
    private MapStatus mapStatus;

    // Created in open(); reset to null by closeAndWriteOutput() once the spills have been collected.
    @Nullable
    private ShuffleExternalSorter sorter;
    // Reusable buffer each record is serialized into before being handed to the sorter.
    private MyByteArrayOutputStream serBuffer;
    // Serialization stream writing into serBuffer.
    private SerializationStream serOutputStream;
    // Decompiler-exposed assertion switch: true when assertions are disabled for this class.
    static final /* synthetic */ boolean $assertionsDisabled;
    // Highest sorter memory usage observed so far; refreshed by updatePeakMemoryUsed().
    private long peakMemoryUsedBytes = 0;
    // Set on the first call to stop(); later calls return an empty Option.
    private boolean stopping = false;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/spark/shuffle/sort/UnsafeShuffleWriter$MyByteArrayOutputStream.class */
    /**
     * A {@link ByteArrayOutputStream} that exposes its internal buffer without copying it.
     */
    public static final class MyByteArrayOutputStream extends ByteArrayOutputStream {
        /**
         * @param initialCapacity initial size of the backing array, in bytes
         */
        MyByteArrayOutputStream(int initialCapacity) {
            super(initialCapacity);
        }

        /**
         * Returns the live backing array (not a copy). Only the first {@code size()} bytes hold
         * written data; the array may be longer than that.
         */
        public byte[] getBuf() {
            final byte[] backingArray = buf;
            return backingArray;
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/spark/shuffle/sort/UnsafeShuffleWriter$StreamFallbackChannelWrapper.class */
    /**
     * Adapts a plain {@link OutputStream} to the {@link WritableByteChannelWrapper} interface by
     * wrapping it in a channel created with {@link Channels#newChannel(OutputStream)}.
     */
    public static final class StreamFallbackChannelWrapper implements WritableByteChannelWrapper {
        private final WritableByteChannel delegate;

        /**
         * @param outputStream stream that the exposed channel writes to
         */
        StreamFallbackChannelWrapper(OutputStream outputStream) {
            delegate = Channels.newChannel(outputStream);
        }

        /** Returns the channel wrapping the stream passed to the constructor. */
        @Override // org.apache.spark.shuffle.api.WritableByteChannelWrapper
        public WritableByteChannel channel() {
            return delegate;
        }

        /** Closes the wrapping channel. */
        @Override // java.io.Closeable, java.lang.AutoCloseable
        public void close() throws IOException {
            delegate.close();
        }
    }

    /**
     * Builds a writer for one map task and immediately opens its sorter and serialization buffer.
     *
     * @param blockManager block manager used by the sorter and when building the map status
     * @param taskMemoryManager memory manager handed to the sorter
     * @param serializedShuffleHandle handle whose dependency supplies the shuffle id, serializer
     *     and partitioner
     * @param j map id of this task
     * @param taskContext context of the running task
     * @param sparkConf configuration the transferTo flag and buffer sizes are read from
     * @param shuffleWriteMetricsReporter sink for shuffle write metrics
     * @param shuffleExecutorComponents factory for map output writers
     * @throws IllegalArgumentException if the partitioner has more partitions than
     *     {@code SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE()}
     */
    public UnsafeShuffleWriter(BlockManager blockManager, TaskMemoryManager taskMemoryManager, SerializedShuffleHandle<K, V> serializedShuffleHandle, long j, TaskContext taskContext, SparkConf sparkConf, ShuffleWriteMetricsReporter shuffleWriteMetricsReporter, ShuffleExecutorComponents shuffleExecutorComponents) {
        final ShuffleDependency<K, V, V> shuffleDep = serializedShuffleHandle.dependency();
        final Partitioner targetPartitioner = shuffleDep.partitioner();
        if (targetPartitioner.numPartitions() > SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE()) {
            throw new IllegalArgumentException("UnsafeShuffleWriter can only be used for shuffles with at most " + SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE() + " reduce partitions");
        }

        // Collaborators supplied by the caller.
        this.blockManager = blockManager;
        this.memoryManager = taskMemoryManager;
        this.taskContext = taskContext;
        this.sparkConf = sparkConf;
        this.writeMetrics = shuffleWriteMetricsReporter;
        this.shuffleExecutorComponents = shuffleExecutorComponents;

        // Identity of this map output and the pieces derived from the shuffle dependency.
        this.mapId = j;
        this.shuffleId = shuffleDep.shuffleId();
        this.serializer = shuffleDep.serializer().newInstance();
        this.partitioner = targetPartitioner;

        // Configuration-derived settings.
        this.transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", true);
        final long sortBufferSize = ((Long) sparkConf.get(package$.MODULE$.SHUFFLE_SORT_INIT_BUFFER_SIZE())).longValue();
        this.initialSortBufferSize = (int) sortBufferSize;
        final long fileBufferSize = ((Long) sparkConf.get(package$.MODULE$.SHUFFLE_FILE_BUFFER_SIZE())).longValue();
        this.inputBufferSizeInBytes = ((int) fileBufferSize) * 1024;

        open();
    }

    /**
     * Raises {@code peakMemoryUsedBytes} to the sorter's reported peak when that is larger.
     * Does nothing once the sorter has been released.
     */
    private void updatePeakMemoryUsed() {
        if (this.sorter == null) {
            return;
        }
        this.peakMemoryUsedBytes = Math.max(this.peakMemoryUsedBytes, this.sorter.getPeakMemoryUsedBytes());
    }

    /**
     * Refreshes the recorded peak from the sorter (if one is still active) and returns it.
     *
     * @return the largest sorter memory usage observed so far, in bytes
     */
    public long getPeakMemoryUsedBytes() {
        updatePeakMemoryUsed();
        final long observedPeak = this.peakMemoryUsedBytes;
        return observedPeak;
    }

    /**
     * Test-facing overload that accepts a Java iterator, converts it to a Scala iterator and
     * delegates to {@link #write(scala.collection.Iterator)}.
     *
     * @param it records to write
     * @throws IOException if writing the records fails
     */
    @VisibleForTesting
    public void write(Iterator<Product2<K, V>> it) throws IOException {
        final scala.collection.Iterator<Product2<K, V>> scalaRecords = JavaConverters.asScalaIteratorConverter(it).asScala();
        write(scalaRecords);
    }

    /**
     * Inserts every record into the sorter, then merges the spills into the final map output.
     *
     * <p>The whole operation runs under a single try/finally so that the sorter's resources are
     * released no matter where a failure occurs: while pulling from the iterator, while
     * inserting, or inside {@link #closeAndWriteOutput()}.
     *
     * @param iterator records to write
     * @throws IOException if inserting records or writing the output fails
     */
    @Override // org.apache.spark.shuffle.ShuffleWriter
    public void write(scala.collection.Iterator<Product2<K, V>> iterator) throws IOException {
        boolean success = false;
        try {
            while (iterator.hasNext()) {
                insertRecordIntoSorter(iterator.next());
            }
            closeAndWriteOutput();
            success = true;
        } finally {
            // closeAndWriteOutput() nulls the sorter once its spills are collected, so a non-null
            // sorter here means it still owns resources that have to be released.
            if (this.sorter != null) {
                try {
                    this.sorter.cleanupResources();
                } catch (Exception e) {
                    // Only surface a cleanup failure when it would not hide the exception that
                    // is already propagating from the try block.
                    if (success) {
                        throw e;
                    }
                    logger.error("In addition to a failure during writing, we failed during cleanup.", e);
                }
            }
        }
    }

    /**
     * Creates the sorter, the reusable serialization buffer and the serialization stream that
     * writes into that buffer. Asserts that no sorter exists yet.
     */
    private void open() {
        if (this.sorter != null && !$assertionsDisabled) {
            throw new AssertionError();
        }
        final int partitionCount = this.partitioner.numPartitions();
        this.sorter = new ShuffleExternalSorter(this.memoryManager, this.blockManager, this.taskContext, this.initialSortBufferSize, partitionCount, this.sparkConf, this.writeMetrics);
        final MyByteArrayOutputStream buffer = new MyByteArrayOutputStream(DEFAULT_INITIAL_SER_BUFFER_SIZE);
        this.serBuffer = buffer;
        this.serOutputStream = this.serializer.serializeStream(buffer);
    }

    /**
     * Closes the sorter, merges its spill files into the final map output and records the
     * resulting {@link MapStatus}.
     *
     * <p>The serialization buffer, the serialization stream and the sorter reference are all
     * released before merging. Spill files are deleted in a single {@code finally} block, so the
     * deletion runs exactly once whether or not the merge succeeds.
     *
     * @throws IOException if closing the sorter or merging the spills fails
     */
    @VisibleForTesting
    void closeAndWriteOutput() throws IOException {
        if (!$assertionsDisabled && this.sorter == null) {
            throw new AssertionError();
        }
        updatePeakMemoryUsed();
        this.serBuffer = null;
        this.serOutputStream = null;
        final SpillInfo[] spills = this.sorter.closeAndGetSpills();
        this.sorter = null;
        final long[] partitionLengths;
        try {
            partitionLengths = mergeSpills(spills);
        } finally {
            for (SpillInfo spill : spills) {
                // A failed delete is logged rather than thrown so it cannot mask a merge failure.
                if (spill.file.exists() && !spill.file.delete()) {
                    logger.error("Error while deleting spill file {}", spill.file.getPath());
                }
            }
        }
        this.mapStatus = MapStatus$.MODULE$.apply(this.blockManager.shuffleServerId(), partitionLengths, this.mapId);
    }

    /**
     * Serializes one key/value pair into the reusable buffer and hands the serialized bytes to
     * the sorter together with the partition id computed from the key.
     *
     * @param product2 the record; {@code _1} is the key and {@code _2} the value
     * @throws IOException if the sorter fails to insert the record
     */
    @VisibleForTesting
    void insertRecordIntoSorter(Product2<K, V> product2) throws IOException {
        if (this.sorter == null && !$assertionsDisabled) {
            throw new AssertionError();
        }
        final Object key = product2._1();
        final int partitionId = this.partitioner.getPartition(key);

        // The buffer is reused across records, so it is emptied before each serialization.
        this.serBuffer.reset();
        this.serOutputStream.writeKey(key, OBJECT_CLASS_TAG);
        this.serOutputStream.writeValue(product2._2(), OBJECT_CLASS_TAG);
        this.serOutputStream.flush();

        final int serializedRecordSize = this.serBuffer.size();
        if (serializedRecordSize <= 0 && !$assertionsDisabled) {
            throw new AssertionError();
        }
        final byte[] recordBytes = this.serBuffer.getBuf();
        this.sorter.insertRecord(recordBytes, Platform.BYTE_ARRAY_OFFSET, serializedRecordSize, partitionId);
    }

    /**
     * Test hook that asks the active sorter to spill. Asserts that a sorter exists.
     *
     * @throws IOException if the spill fails
     */
    @VisibleForTesting
    void forceSorterToSpill() throws IOException {
        final ShuffleExternalSorter activeSorter = this.sorter;
        if (activeSorter == null && !$assertionsDisabled) {
            throw new AssertionError();
        }
        activeSorter.spill();
    }

    /**
     * Produces the final map output from the given spills and returns the per-partition lengths.
     *
     * <ul>
     *   <li>No spills: commits an empty map output.</li>
     *   <li>Exactly one spill: hands the spill file to the single-file writer when the executor
     *       components provide one, otherwise uses the standard merge.</li>
     *   <li>Several spills: uses the standard merge.</li>
     * </ul>
     *
     * @param spillInfoArr spills returned by the sorter
     * @return length of each partition in the committed output
     * @throws IOException if writing the output fails
     */
    private long[] mergeSpills(SpillInfo[] spillInfoArr) throws IOException {
        final int spillCount = spillInfoArr.length;
        if (spillCount == 0) {
            final ShuffleMapOutputWriter emptyOutputWriter = this.shuffleExecutorComponents.createMapOutputWriter(this.shuffleId, this.mapId, this.partitioner.numPartitions());
            return emptyOutputWriter.commitAllPartitions().getPartitionLengths();
        }
        if (spillCount != 1) {
            return mergeSpillsUsingStandardWriter(spillInfoArr);
        }

        final Optional<SingleSpillShuffleMapOutputWriter> singleFileWriter = this.shuffleExecutorComponents.createSingleFileMapOutputWriter(this.shuffleId, this.mapId);
        if (!singleFileWriter.isPresent()) {
            return mergeSpillsUsingStandardWriter(spillInfoArr);
        }

        // With a single spill its partition lengths are the final lengths and the file is
        // transferred as-is.
        final SpillInfo onlySpill = spillInfoArr[0];
        final long[] partitionLengths = onlySpill.partitionLengths;
        logger.debug("Merge shuffle spills for mapId {} with length {}", Long.valueOf(this.mapId), Integer.valueOf(partitionLengths.length));
        singleFileWriter.get().transferMapSpillFile(onlySpill.file, partitionLengths);
        return partitionLengths;
    }

    /**
     * Merges the spills through a {@link ShuffleMapOutputWriter}, choosing one of three
     * strategies from configuration:
     *
     * <ul>
     *   <li>transferTo-based fast merge: fast merge is enabled and supported, transferTo is
     *       enabled and encryption is off;</li>
     *   <li>file-stream fast merge (no codec): fast merge is enabled and supported, but
     *       transferTo is disabled or encryption is on;</li>
     *   <li>slow merge (with the codec): everything else.</li>
     * </ul>
     *
     * Fast merge counts as supported when compression is off or the codec supports
     * concatenation of serialized streams. On failure the writer is aborted; an exception from
     * the abort is logged and attached to the original as suppressed.
     *
     * @param spillInfoArr spills to merge; must contain at least one element
     * @return length of each partition in the committed output
     * @throws IOException if merging or committing fails
     */
    private long[] mergeSpillsUsingStandardWriter(SpillInfo[] spillInfoArr) throws IOException {
        final boolean compressionEnabled = (Boolean) this.sparkConf.get(package$.MODULE$.SHUFFLE_COMPRESS());
        final CompressionCodec codec = CompressionCodec$.MODULE$.createCodec(this.sparkConf);
        final boolean fastMergeEnabled = (Boolean) this.sparkConf.get(package$.MODULE$.SHUFFLE_UNSAFE_FAST_MERGE_ENABLE());
        final boolean fastMergeSupported = !compressionEnabled || CompressionCodec$.MODULE$.supportsConcatenationOfSerializedStreams(codec);
        final boolean encryptionEnabled = this.blockManager.serializerManager().encryptionEnabled();
        final ShuffleMapOutputWriter mapWriter = this.shuffleExecutorComponents.createMapOutputWriter(this.shuffleId, this.mapId, this.partitioner.numPartitions());
        try {
            if (fastMergeEnabled && fastMergeSupported) {
                if (this.transferToEnabled && !encryptionEnabled) {
                    logger.debug("Using transferTo-based fast merge");
                    mergeSpillsWithTransferTo(spillInfoArr, mapWriter);
                } else {
                    logger.debug("Using fileStream-based fast merge");
                    mergeSpillsWithFileStream(spillInfoArr, mapWriter, null);
                }
            } else {
                logger.debug("Using slow merge");
                mergeSpillsWithFileStream(spillInfoArr, mapWriter, codec);
            }
            // The size of the last spill file is subtracted from the bytes-written metric.
            final SpillInfo lastSpill = spillInfoArr[spillInfoArr.length - 1];
            this.writeMetrics.decBytesWritten(lastSpill.file.length());
            return mapWriter.commitAllPartitions().getPartitionLengths();
        } catch (Exception mergeFailure) {
            try {
                mapWriter.abort(mergeFailure);
            } catch (Exception abortFailure) {
                logger.warn("Failed to abort writing the map output.", abortFailure);
                mergeFailure.addSuppressed(abortFailure);
            }
            throw mergeFailure;
        }
    }

    /**
     * Merges spill files partition by partition using Java stream copies.
     *
     * <p>One input stream is opened per spill. For each partition the bytes belonging to that
     * partition are read from every spill in order (limited to the recorded partition length)
     * and appended to the partition's output stream. Output is wrapped for write-time tracking
     * and encryption; when a codec is supplied both sides are additionally wrapped with it.
     *
     * <p>Every stream is closed in a {@code finally} block. On the failure path close errors
     * are swallowed so that the original exception propagates; on the success path they are
     * rethrown.
     *
     * @param spillInfoArr spills to merge
     * @param shuffleMapOutputWriter writer that supplies one partition writer per partition
     * @param compressionCodec codec used to wrap input and output streams, or {@code null} to
     *     copy the bytes without a codec
     * @throws IOException if opening, copying or closing a stream fails
     */
    private void mergeSpillsWithFileStream(SpillInfo[] spillInfoArr, ShuffleMapOutputWriter shuffleMapOutputWriter, @Nullable CompressionCodec compressionCodec) throws IOException {
        logger.debug("Merge shuffle spills with FileStream for mapId {}", Long.valueOf(this.mapId));
        final int numPartitions = this.partitioner.numPartitions();
        final InputStream[] spillInputStreams = new InputStream[spillInfoArr.length];

        boolean threwException = true;
        try {
            for (int i = 0; i < spillInfoArr.length; i++) {
                spillInputStreams[i] = new NioBufferedFileInputStream(spillInfoArr[i].file, this.inputBufferSizeInBytes);
                if (logger.isDebugEnabled()) {
                    logger.debug("Partition lengths for mapId {} in Spill {}: {}", new Object[]{Long.valueOf(this.mapId), Integer.valueOf(i), Arrays.toString(spillInfoArr[i].partitionLengths)});
                }
            }
            for (int partition = 0; partition < numPartitions; partition++) {
                boolean copyThrewException = true;
                final ShufflePartitionWriter partitionWriter = shuffleMapOutputWriter.getPartitionWriter(partition);
                OutputStream partitionOutput = partitionWriter.openStream();
                try {
                    // Each wrapper is assigned back immediately so the finally block always closes
                    // the outermost stream created so far.
                    partitionOutput = new TimeTrackingOutputStream(this.writeMetrics, partitionOutput);
                    partitionOutput = this.blockManager.serializerManager().wrapForEncryption(partitionOutput);
                    if (compressionCodec != null) {
                        partitionOutput = compressionCodec.compressedOutputStream(partitionOutput);
                    }
                    for (int i = 0; i < spillInfoArr.length; i++) {
                        final long partitionLengthInSpill = spillInfoArr[i].partitionLengths[partition];
                        if (partitionLengthInSpill <= 0) {
                            continue;
                        }
                        InputStream partitionInput = null;
                        boolean copySpillThrewException = true;
                        try {
                            // The limited view is constructed with 'false' as its third argument,
                            // as in the original code, so the shared spill stream can be reused
                            // for the next partition.
                            partitionInput = new LimitedInputStream(spillInputStreams[i], partitionLengthInSpill, false);
                            partitionInput = this.blockManager.serializerManager().wrapForEncryption(partitionInput);
                            if (compressionCodec != null) {
                                partitionInput = compressionCodec.compressedInputStream(partitionInput);
                            }
                            ByteStreams.copy(partitionInput, partitionOutput);
                            copySpillThrewException = false;
                        } finally {
                            Closeables.close(partitionInput, copySpillThrewException);
                        }
                    }
                    copyThrewException = false;
                } finally {
                    Closeables.close(partitionOutput, copyThrewException);
                }
                this.writeMetrics.incBytesWritten(partitionWriter.getNumBytesWritten());
            }
            threwException = false;
        } finally {
            // Runs on every exit path, so spill streams are released even when a partition copy
            // fails part-way through.
            for (InputStream spillInputStream : spillInputStreams) {
                Closeables.close(spillInputStream, threwException);
            }
        }
    }

    /**
     * Merges spill files partition by partition using NIO channel copies.
     *
     * <p>One {@link FileChannel} is opened per spill and a running read position is kept for
     * each. For every partition the corresponding byte range of every spill is copied, in spill
     * order, into the partition's channel (or into a stream-backed fallback channel when the
     * partition writer offers no channel wrapper).
     *
     * <p>All channels are closed in {@code finally} blocks. On the failure path close errors are
     * swallowed so that the original exception propagates. The check that each spill was fully
     * consumed is performed only after a successful merge.
     *
     * @param spillInfoArr spills to merge
     * @param shuffleMapOutputWriter writer that supplies one partition writer per partition
     * @throws IOException if opening, copying or closing a channel fails
     */
    private void mergeSpillsWithTransferTo(SpillInfo[] spillInfoArr, ShuffleMapOutputWriter shuffleMapOutputWriter) throws IOException {
        logger.debug("Merge shuffle spills with TransferTo for mapId {}", Long.valueOf(this.mapId));
        final int numPartitions = this.partitioner.numPartitions();
        final FileChannel[] spillInputChannels = new FileChannel[spillInfoArr.length];
        final long[] spillInputChannelPositions = new long[spillInfoArr.length];

        boolean threwException = true;
        try {
            for (int i = 0; i < spillInfoArr.length; i++) {
                spillInputChannels[i] = new FileInputStream(spillInfoArr[i].file).getChannel();
                if (logger.isDebugEnabled()) {
                    logger.debug("Partition lengths for mapId {} in Spill {}: {}", new Object[]{Long.valueOf(this.mapId), Integer.valueOf(i), Arrays.toString(spillInfoArr[i].partitionLengths)});
                }
            }
            for (int partition = 0; partition < numPartitions; partition++) {
                boolean copyThrewException = true;
                final ShufflePartitionWriter partitionWriter = shuffleMapOutputWriter.getPartitionWriter(partition);
                final WritableByteChannelWrapper resolvedChannel = partitionWriter.openChannelWrapper().orElseGet(() -> {
                    return new StreamFallbackChannelWrapper(openStreamUnchecked(partitionWriter));
                });
                try {
                    for (int i = 0; i < spillInfoArr.length; i++) {
                        final long partitionLengthInSpill = spillInfoArr[i].partitionLengths[partition];
                        final FileChannel spillInputChannel = spillInputChannels[i];
                        final long writeStartTime = System.nanoTime();
                        Utils.copyFileStreamNIO(spillInputChannel, resolvedChannel.channel(), spillInputChannelPositions[i], partitionLengthInSpill);
                        spillInputChannelPositions[i] += partitionLengthInSpill;
                        this.writeMetrics.incWriteTime(System.nanoTime() - writeStartTime);
                    }
                    // Cleared only after every spill has been copied, so a failure in any copy
                    // closes the channel quietly instead of risking a masking close exception.
                    copyThrewException = false;
                } finally {
                    Closeables.close(resolvedChannel, copyThrewException);
                }
                this.writeMetrics.incBytesWritten(partitionWriter.getNumBytesWritten());
            }
            threwException = false;
        } finally {
            // Runs on every exit path; unopened slots are null, which Closeables.close accepts.
            for (FileChannel spillInputChannel : spillInputChannels) {
                Closeables.close(spillInputChannel, threwException);
            }
        }

        // Reached only on success: every spill file must have been consumed completely.
        if (!$assertionsDisabled) {
            for (int i = 0; i < spillInfoArr.length; i++) {
                if (spillInputChannelPositions[i] != spillInfoArr[i].file.length()) {
                    throw new AssertionError();
                }
            }
        }
    }

    /**
     * Stops the writer, reports peak execution memory to the task metrics and returns the map
     * status when requested.
     *
     * <p>Only the first call has an effect on the result; later calls return an empty
     * {@link Option}. Any sorter that is still active is cleaned up exactly once, in the
     * {@code finally} block, on every exit path.
     *
     * @param z {@code true} if the task succeeded and the map status should be returned
     * @return the map status when {@code z} is {@code true} on the first call, otherwise an
     *     empty {@link Option}
     * @throws IllegalStateException if {@code z} is {@code true} but no map status has been
     *     produced by a prior write
     */
    @Override // org.apache.spark.shuffle.ShuffleWriter
    public Option<MapStatus> stop(boolean z) {
        try {
            this.taskContext.taskMetrics().incPeakExecutionMemory(getPeakMemoryUsedBytes());
            if (this.stopping) {
                return Option.<MapStatus>apply(null);
            }
            this.stopping = true;
            if (!z) {
                return Option.<MapStatus>apply(null);
            }
            if (this.mapStatus == null) {
                throw new IllegalStateException("Cannot call stop(true) without having called write()");
            }
            return Option.apply(this.mapStatus);
        } finally {
            if (this.sorter != null) {
                this.sorter.cleanupResources();
            }
        }
    }

    /**
     * Opens the partition writer's stream, rethrowing any {@link IOException} wrapped in a
     * {@link RuntimeException} so the call can be made from a lambda that cannot throw checked
     * exceptions.
     *
     * @param shufflePartitionWriter writer whose stream is opened
     * @return the opened stream
     */
    private static OutputStream openStreamUnchecked(ShufflePartitionWriter shufflePartitionWriter) {
        final OutputStream opened;
        try {
            opened = shufflePartitionWriter.openStream();
        } catch (IOException ioFailure) {
            throw new RuntimeException(ioFailure);
        }
        return opened;
    }

    static {
        $assertionsDisabled = !UnsafeShuffleWriter.class.desiredAssertionStatus();
        logger = LoggerFactory.getLogger(UnsafeShuffleWriter.class);
        OBJECT_CLASS_TAG = ClassTag$.MODULE$.Object();
    }
}
