package org.apache.spark.sql.rapids;

import ai.rapids.cudf.ContiguousTable;
import ai.rapids.cudf.DeviceMemoryBuffer;
import ai.rapids.cudf.NvtxColor;
import ai.rapids.cudf.NvtxRange;
import com.nvidia.spark.rapids.DegenerateRapidsBuffer;
import com.nvidia.spark.rapids.GpuCompressedColumnVector;
import com.nvidia.spark.rapids.GpuPackedTableColumn;
import com.nvidia.spark.rapids.MetaUtils$;
import com.nvidia.spark.rapids.RapidsDeviceMemoryStore;
import com.nvidia.spark.rapids.ShuffleBufferCatalog;
import com.nvidia.spark.rapids.ShuffleBufferId;
import com.nvidia.spark.rapids.SpillPriorities$;
import com.nvidia.spark.rapids.format.TableMeta;
import com.nvidia.spark.rapids.shuffle.RapidsShuffleServer;
import com.nvidia.spark.rapids.shuffle.RapidsShuffleTransport$;
import org.apache.spark.internal.Logging;
import org.apache.spark.scheduler.MapStatus;
import org.apache.spark.scheduler.MapStatus$;
import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
import org.apache.spark.shuffle.ShuffleWriter;
import org.apache.spark.sql.execution.metric.SQLMetric;
import org.apache.spark.sql.vectorized.ColumnVector;
import org.apache.spark.sql.vectorized.ColumnarBatch;
import org.apache.spark.storage.BlockManager;
import org.apache.spark.storage.BlockManagerId;
import org.apache.spark.storage.BlockManagerId$;
import org.apache.spark.storage.ShuffleBlockId;
import org.slf4j.Logger;
import scala.Function0;
import scala.None$;
import scala.Option;
import scala.Predef$;
import scala.Product2;
import scala.Some;
import scala.collection.Iterator;
import scala.collection.immutable.Map;
import scala.collection.mutable.ArrayBuffer;
import scala.collection.mutable.ArrayOps;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.LongRef;

/* compiled from: RapidsShuffleInternalManagerBase.scala */
@ScalaSignature(bytes = "\u0006\u0001\u0005\rf\u0001\u0002\f\u0018\u0001\tB\u0001\"\u0011\u0001\u0003\u0002\u0003\u0006IA\u0011\u0005\t\u0011\u0002\u0011\t\u0011)A\u0005\u0013\"AQ\n\u0001B\u0001B\u0003%a\n\u0003\u0005R\u0001\t\u0005\t\u0015!\u0003S\u0011!)\u0006A!A!\u0002\u00131\u0006\u0002\u00031\u0001\u0005\u0003\u0005\u000b\u0011B1\t\u0011\u0011\u0004!\u0011!Q\u0001\n\u0015D\u0001\"\u001c\u0001\u0003\u0002\u0003\u0006IA\u001c\u0005\b\u0003\u0013\u0001A\u0011AA\u0006\u0011%\ty\u0002\u0001b\u0001\n\u0013\t\t\u0003\u0003\u0005\u0002*\u0001\u0001\u000b\u0011BA\u0012\u0011%\tY\u0003\u0001b\u0001\n\u0013\ti\u0003\u0003\u0005\u00026\u0001\u0001\u000b\u0011BA\u0018\u0011%\t9\u0004\u0001b\u0001\n\u0013\tI\u0004\u0003\u0005\u0002R\u0001\u0001\u000b\u0011BA\u001e\u0011%\t\u0019\u0006\u0001b\u0001\n\u0013\t)\u0006C\u0004\u0002X\u0001\u0001\u000b\u0011\u0002?\t\u000f\u0005e\u0003\u0001\"\u0011\u0002\\!9\u0011q\u0010\u0001\u0005\n\u0005\u0005\u0005bBAB\u0001\u0011\u0005\u0013Q\u0011\u0005\b\u0003?\u0003A\u0011AAQ\u0005M\u0011\u0016\r]5eg\u000e\u000b7\r[5oO^\u0013\u0018\u000e^3s\u0015\tA\u0012$\u0001\u0004sCBLGm\u001d\u0006\u00035m\t1a]9m\u0015\taR$A\u0003ta\u0006\u00148N\u0003\u0002\u001f?\u00051\u0011\r]1dQ\u0016T\u0011\u0001I\u0001\u0004_J<7\u0001A\u000b\u0004G1J4c\u0001\u0001%wA!Q\u0005\u000b\u00169\u001b\u00051#BA\u0014\u001c\u0003\u001d\u0019\b.\u001e4gY\u0016L!!\u000b\u0014\u0003\u001bMCWO\u001a4mK^\u0013\u0018\u000e^3s!\tYC\u0006\u0004\u0001\u0005\u000b5\u0002!\u0019\u0001\u0018\u0003\u0003-\u000b\"aL\u001b\u0011\u0005A\u001aT\"A\u0019\u000b\u0003I\nQa]2bY\u0006L!\u0001N\u0019\u0003\u000f9{G\u000f[5oOB\u0011\u0001GN\u0005\u0003oE\u00121!\u00118z!\tY\u0013\bB\u0003;\u0001\t\u0007aFA\u0001W!\tat(D\u0001>\u0015\tq4$\u0001\u0005j]R,'O\\1m\u0013\t\u0001UHA\u0004M_\u001e<\u0017N\\4\u0002\u0019\tdwnY6NC:\fw-\u001a:\u0011\u0005\r3U\"\u0001#\u000b\u0005\u0015[\u0012aB:u_J\fw-Z\u0005\u0003\u000f\u0012\u0013AB\u00117pG.l\u0015M\\1hKJ\fa\u0001[1oI2,\u0007\u0003\u0002&LUaj\u0011aF\u
0005\u0003\u0019^\u0011\u0001c\u00129v'\",hM\u001a7f\u0011\u0006tG\r\\3\u0002\u000b5\f\u0007/\u00133\u0011\u0005Az\u0015B\u0001)2\u0005\u0011auN\\4\u0002\u001f5,GO]5dgJ+\u0007o\u001c:uKJ\u0004\"!J*\n\u0005Q3#aG*ik\u001a4G.Z,sSR,W*\u001a;sS\u000e\u001c(+\u001a9peR,'/A\u0004dCR\fGn\\4\u0011\u0005]sV\"\u0001-\u000b\u0005aI&B\u0001\u000f[\u0015\tYF,\u0001\u0004om&$\u0017.\u0019\u0006\u0002;\u0006\u00191m\\7\n\u0005}C&\u0001F*ik\u001a4G.\u001a\"vM\u001a,'oQ1uC2|w-\u0001\btQV4g\r\\3Ti>\u0014\u0018mZ3\u0011\u0005]\u0013\u0017BA2Y\u0005]\u0011\u0016\r]5eg\u0012+g/[2f\u001b\u0016lwN]=Ti>\u0014X-A\nsCBLGm]*ik\u001a4G.Z*feZ,'\u000fE\u00021M\"L!aZ\u0019\u0003\r=\u0003H/[8o!\tI7.D\u0001k\u0015\t9\u0003,\u0003\u0002mU\n\u0019\"+\u00199jIN\u001c\u0006.\u001e4gY\u0016\u001cVM\u001d<fe\u00069Q.\u001a;sS\u000e\u001c\b\u0003B8wsrt!\u0001\u001d;\u0011\u0005E\fT\"\u0001:\u000b\u0005M\f\u0013A\u0002\u001fs_>$h(\u0003\u0002vc\u00051\u0001K]3eK\u001aL!a\u001e=\u0003\u00075\u000b\u0007O\u0003\u0002vcA\u0011qN_\u0005\u0003wb\u0014aa\u0015;sS:<\u0007cA?\u0002\u00065\taPC\u0002��\u0003\u0003\ta!\\3ue&\u001c'bAA\u00023\u0005IQ\r_3dkRLwN\\\u0005\u0004\u0003\u000fq(!C*R\u00196+GO]5d\u0003\u0019a\u0014N\\5u}Q\u0011\u0012QBA\b\u0003#\t\u0019\"!\u0006\u0002\u0018\u0005e\u00111DA\u000f!\u0011Q\u0005A\u000b\u001d\t\u000b\u0005K\u0001\u0019\u0001\"\t\u000b!K\u0001\u0019A%\t\u000b5K\u0001\u0019\u0001(\t\u000bEK\u0001\u0019\u0001*\t\u000bUK\u0001\u0019\u0001,\t\u000b\u0001L\u0001\u0019A1\t\u000b\u0011L\u0001\u0019A3\t\u000b5L\u0001\u0019\u00018\u0002\u00119,X\u000eU1siN,\"!a\t\u0011\u0007A\n)#C\u0002\u0002(E\u00121!\u00138u\u0003%qW/\u001c)beR\u001c\b%A\u0003tSj,7/\u0006\u0002\u00020A!\u0001'!\rO\u0013\r\t\u0019$\r\u0002\u0006\u0003J\u0014\u0018-_\u0001\u0007g&TXm\u001d\u0011\u0002!]\u0014\u0018\u000e\u001e;f]\n+hMZ3s\u0013\u0012\u001cXCAA\u001e!\u0019\ti$a\u0012\u0002L5\u0011\u0011q\b\u0006\u0005\u0003\u0003\n\u0019%A\u0004nkR\f'\r\\3\u000b\u0007\u0005\u0015\u0013'\u0001\u0006d_2dWm\u0019;j_:LA!!\u0013\
u0002@\tY\u0011I\u001d:bs\n+hMZ3s!\r9\u0016QJ\u0005\u0004\u0003\u001fB&aD*ik\u001a4G.\u001a\"vM\u001a,'/\u00133\u0002#]\u0014\u0018\u000e\u001e;f]\n+hMZ3s\u0013\u0012\u001c\b%\u0001\nv]\u000e|W\u000e\u001d:fgN,G-T3ue&\u001cW#\u0001?\u0002'Ut7m\\7qe\u0016\u001c8/\u001a3NKR\u0014\u0018n\u0019\u0011\u0002\u000b]\u0014\u0018\u000e^3\u0015\t\u0005u\u00131\r\t\u0004a\u0005}\u0013bAA1c\t!QK\\5u\u0011\u001d\t)G\u0005a\u0001\u0003O\nqA]3d_J$7\u000f\u0005\u0004\u0002j\u0005M\u0014\u0011\u0010\b\u0005\u0003W\nyGD\u0002r\u0003[J\u0011AM\u0005\u0004\u0003c\n\u0014a\u00029bG.\fw-Z\u0005\u0005\u0003k\n9H\u0001\u0005Ji\u0016\u0014\u0018\r^8s\u0015\r\t\t(\r\t\u0006a\u0005m$\u0006O\u0005\u0004\u0003{\n$\u0001\u0003)s_\u0012,8\r\u001e\u001a\u0002\u0019\rdW-\u00198Ti>\u0014\u0018mZ3\u0015\u0005\u0005u\u0013\u0001B:u_B$B!a\"\u0002\u0016B!\u0001GZAE!\u0011\tY)!%\u000e\u0005\u00055%bAAH7\u0005I1o\u00195fIVdWM]\u0005\u0005\u0003'\u000biIA\u0005NCB\u001cF/\u0019;vg\"9\u0011q\u0013\u000bA\u0002\u0005e\u0015aB:vG\u000e,7o\u001d\t\u0004a\u0005m\u0015bAAOc\t9!i\\8mK\u0006t\u0017aE4fiB\u000b'\u000f^5uS>tG*\u001a8hi\"\u001cHCAA\u0018\u0001")
/* loaded from: input_file:org/apache/spark/sql/rapids/RapidsCachingWriter.class */
/*
 * Shuffle writer that, instead of producing shuffle files, registers each incoming
 * (partitionId, ColumnarBatch) record with the RAPIDS shuffle catalog / device store and
 * records per-partition byte counts that are reported through the MapStatus from stop(true).
 * Batches with no rows or no columns are tracked as metadata-only (degenerate) buffers.
 */
public class RapidsCachingWriter<K, V> extends ShuffleWriter<K, V> implements Logging {
    /*
     * Size credited to a partition for a degenerate batch that has rows but no columns.
     * A size of 0 would tell readers there is nothing to fetch for the partition; the exact
     * value is arbitrary, it only needs to be a small non-zero number.
     */
    private static final long DEGENERATE_BATCH_NOMINAL_SIZE = 100L;

    private final BlockManager blockManager;
    private final GpuShuffleHandle<K, V> handle;
    private final long mapId;
    private final ShuffleWriteMetricsReporter metricsReporter;
    private final ShuffleBufferCatalog catalog;
    private final RapidsDeviceMemoryStore shuffleStorage;
    private final Option<RapidsShuffleServer> rapidsShuffleServer;
    /** Number of reduce partitions, taken from the shuffle dependency's partitioner. */
    private final int numParts;
    /** Running byte count per reduce partition; passed to MapStatus on success. */
    private final long[] sizes;
    /** Every buffer id registered by this writer, so they can be removed on failure. */
    private final ArrayBuffer<ShuffleBufferId> writtenBufferIds;
    /** The "dataSize" metric, incremented by the uncompressed size of every cached table. */
    private final SQLMetric uncompressedMetric;
    private transient Logger org$apache$spark$internal$Logging$$log_;

    // ---- Forwarders for the mixed-in Scala Logging trait ----

    public String logName() {
        return Logging.logName$(this);
    }

    public Logger log() {
        return Logging.log$(this);
    }

    public void logInfo(Function0<String> function0) {
        Logging.logInfo$(this, function0);
    }

    public void logDebug(Function0<String> function0) {
        Logging.logDebug$(this, function0);
    }

    public void logTrace(Function0<String> function0) {
        Logging.logTrace$(this, function0);
    }

    public void logWarning(Function0<String> function0) {
        Logging.logWarning$(this, function0);
    }

    public void logError(Function0<String> function0) {
        Logging.logError$(this, function0);
    }

    public void logInfo(Function0<String> function0, Throwable th) {
        Logging.logInfo$(this, function0, th);
    }

    public void logDebug(Function0<String> function0, Throwable th) {
        Logging.logDebug$(this, function0, th);
    }

    public void logTrace(Function0<String> function0, Throwable th) {
        Logging.logTrace$(this, function0, th);
    }

    public void logWarning(Function0<String> function0, Throwable th) {
        Logging.logWarning$(this, function0, th);
    }

    public void logError(Function0<String> function0, Throwable th) {
        Logging.logError$(this, function0, th);
    }

    public boolean isTraceEnabled() {
        return Logging.isTraceEnabled$(this);
    }

    public void initializeLogIfNecessary(boolean z) {
        Logging.initializeLogIfNecessary$(this, z);
    }

    public boolean initializeLogIfNecessary(boolean z, boolean z2) {
        return Logging.initializeLogIfNecessary$(this, z, z2);
    }

    public boolean initializeLogIfNecessary$default$2() {
        return Logging.initializeLogIfNecessary$default$2$(this);
    }

    public void initializeForcefully(boolean z, boolean z2) {
        Logging.initializeForcefully$(this, z, z2);
    }

    public Logger org$apache$spark$internal$Logging$$log_() {
        return this.org$apache$spark$internal$Logging$$log_;
    }

    public void org$apache$spark$internal$Logging$$log__$eq(Logger logger) {
        this.org$apache$spark$internal$Logging$$log_ = logger;
    }

    // ---- Field accessors ----

    private int numParts() {
        return this.numParts;
    }

    private long[] sizes() {
        return this.sizes;
    }

    private ArrayBuffer<ShuffleBufferId> writtenBufferIds() {
        return this.writtenBufferIds;
    }

    private SQLMetric uncompressedMetric() {
        return this.uncompressedMetric;
    }

    /**
     * Caches every record from {@code records} and then reports the accumulated byte and
     * record counts to the shuffle write metrics reporter.
     *
     * @param records records whose key is the boxed reduce-partition id and whose value is a
     *                {@link ColumnarBatch}
     * @throws IllegalStateException if a non-empty batch's first column is neither a
     *                               {@link GpuPackedTableColumn} nor a
     *                               {@link GpuCompressedColumnVector}
     */
    @Override
    public void write(Iterator<Product2<K, V>> records) {
        try (NvtxRange ignored = new NvtxRange("RapidsCachingWriter.write", NvtxColor.CYAN)) {
            LongRef bytesWritten = LongRef.create(0L);
            LongRef recordsWritten = LongRef.create(0L);
            records.foreach(record -> {
                $anonfun$write$10(this, recordsWritten, bytesWritten, record);
                return BoxedUnit.UNIT;
            });
            // Metrics are only reported once the whole iterator has been consumed.
            this.metricsReporter.incBytesWritten(bytesWritten.elem);
            this.metricsReporter.incRecordsWritten(recordsWritten.elem);
        }
    }

    /** Removes every buffer this writer registered from the shuffle catalog. */
    private void cleanStorage() {
        writtenBufferIds().foreach(shuffleBufferId -> {
            $anonfun$cleanStorage$1(this, shuffleBufferId);
            return BoxedUnit.UNIT;
        });
    }

    /**
     * Finishes this map task's output.
     *
     * @param success whether the map task succeeded
     * @return on success, a MapStatus built from the resolved shuffle server id, the
     *         per-partition sizes and the map id; on failure, an empty Option after every
     *         buffer this writer registered has been removed from the catalog
     */
    @Override
    public Option<MapStatus> stop(boolean success) {
        try (NvtxRange ignored = new NvtxRange("RapidsCachingWriter.close", NvtxColor.CYAN)) {
            if (!success) {
                cleanStorage();
                return Option.empty();
            }
            BlockManagerId blockManagerId = resolveShuffleServerId();
            logInfo(() -> new StringBuilder(58).append("Done caching shuffle success=").append(success).append(", server_id=").append(blockManagerId).append(", ").append("map_id=").append(this.mapId).append(", sizes=").append(new ArrayOps.ofLong(Predef$.MODULE$.longArrayOps(this.sizes())).mkString(",")).toString());
            return new Some<>(MapStatus$.MODULE$.apply(blockManagerId, sizes(), this.mapId));
        }
    }

    /*
     * Picks the id reported in the MapStatus. When a RAPIDS shuffle server is present, this is
     * the server's original shuffle server id with topology info "<TOPO_PREFIX>=<server port>".
     * Otherwise it is the block manager's own shuffle server id.
     */
    private BlockManagerId resolveShuffleServerId() {
        if (!this.rapidsShuffleServer.isDefined()) {
            return this.blockManager.shuffleServerId();
        }
        RapidsShuffleServer server = this.rapidsShuffleServer.get();
        BlockManagerId original = server.originalShuffleServerId();
        String topologyInfo = new StringBuilder(1).append(RapidsShuffleTransport$.MODULE$.BLOCK_MANAGER_ID_TOPO_PREFIX()).append("=").append(server.getPort()).toString();
        return BlockManagerId$.MODULE$.apply(original.executorId(), original.host(), original.port(), new Some<>(topologyInfo));
    }

    /**
     * Not implemented: always throws. The per-partition byte counts are instead reported
     * through the MapStatus returned by {@code stop(true)}.
     *
     * @throws UnsupportedOperationException always
     */
    public long[] getPartitionLengths() {
        throw new UnsupportedOperationException("TODO");
    }

    /*
     * Per-record body of write(). It caches one (partitionId, ColumnarBatch) record,
     * adds its row count to recordsWritten, and adds the cached byte length to bytesWritten
     * and to sizes[partitionId]. The name and parameter order are kept as compiled by scalac.
     */
    public static final /* synthetic */ void $anonfun$write$10(RapidsCachingWriter writer, LongRef recordsWritten, LongRef bytesWritten, Product2 record) {
        int partId = BoxesRunTime.unboxToInt(record._1());
        ColumnarBatch batch = (ColumnarBatch) record._2();
        writer.logDebug(() -> new StringBuilder(66).append("Caching shuffle_id=").append(writer.handle.shuffleId()).append(" map_id=").append(writer.mapId).append(", partId=").append(partId).append(", ").append("batch=[num_cols=").append(batch.numCols()).append(", num_rows=").append(batch.numRows()).append("]").toString());
        recordsWritten.elem += batch.numRows();
        ShuffleBufferId bufferId = writer.catalog.nextShuffleBufferId(new ShuffleBlockId(writer.handle.shuffleId(), writer.mapId, partId));
        if (batch.numRows() > 0 && batch.numCols() > 0) {
            long length = writer.cacheDeviceBatch(bufferId, batch.column(0));
            bytesWritten.elem += length;
            writer.sizes()[partId] += length;
        } else {
            // No device data: only the table metadata is tracked.
            writer.catalog.registerNewBuffer(new DegenerateRapidsBuffer(bufferId, MetaUtils$.MODULE$.buildDegenerateTableMeta(batch)));
            if (batch.numRows() > 0) {
                // Rows but no columns: still report a non-zero size so the partition is fetched.
                writer.sizes()[partId] += DEGENERATE_BATCH_NOMINAL_SIZE;
            }
        }
        writer.writtenBufferIds().append(Predef$.MODULE$.wrapRefArray(new ShuffleBufferId[]{bufferId}));
    }

    /*
     * Hands the device data behind a non-empty batch's first column to the shuffle device store
     * under bufferId, and adds the uncompressed size to the "dataSize" metric.
     * Returns the byte length of the table buffer that was cached.
     * Throws IllegalStateException for any column type other than packed or compressed.
     */
    private long cacheDeviceBatch(ShuffleBufferId bufferId, ColumnVector column) {
        if (column instanceof GpuPackedTableColumn) {
            GpuPackedTableColumn packed = (GpuPackedTableColumn) column;
            ContiguousTable table = packed.getContiguousTable();
            long length = packed.getTableBuffer().getLength();
            // Packed tables are uncompressed, so the buffer length is the uncompressed size.
            uncompressedMetric().$plus$eq(length);
            this.shuffleStorage.addContiguousTable(bufferId, table, SpillPriorities$.MODULE$.OUTPUT_FOR_SHUFFLE_INITIAL_PRIORITY(), this.shuffleStorage.addContiguousTable$default$4(), false);
            return length;
        }
        if (column instanceof GpuCompressedColumnVector) {
            GpuCompressedColumnVector compressed = (GpuCompressedColumnVector) column;
            DeviceMemoryBuffer tableBuffer = compressed.getTableBuffer();
            // Take an extra reference before handing the buffer to the store.
            tableBuffer.incRefCount();
            long length = tableBuffer.getLength();
            TableMeta tableMeta = compressed.getTableMeta();
            // Re-key the buffer metadata to this shuffle buffer's table id.
            tableMeta.bufferMeta().mutateId(bufferId.tableId());
            uncompressedMetric().$plus$eq(tableMeta.bufferMeta().uncompressedSize());
            this.shuffleStorage.addBuffer(bufferId, tableBuffer, tableMeta, SpillPriorities$.MODULE$.OUTPUT_FOR_SHUFFLE_INITIAL_PRIORITY(), this.shuffleStorage.addBuffer$default$5(), false);
            return length;
        }
        throw new IllegalStateException(new StringBuilder(24).append("Unexpected column type: ").append(column.getClass()).toString());
    }

    /* Per-buffer body of cleanStorage(): removes one buffer from the catalog. */
    public static final /* synthetic */ void $anonfun$cleanStorage$1(RapidsCachingWriter rapidsCachingWriter, ShuffleBufferId shuffleBufferId) {
        rapidsCachingWriter.catalog.removeBuffer(shuffleBufferId);
    }

    /**
     * @param blockManager        source of the shuffle server id when no RAPIDS shuffle server is set
     * @param handle              shuffle handle; supplies the shuffle id and partition count
     * @param mapId               id of the map task whose output this writer caches
     * @param metricsReporter     receives bytes/records written after {@link #write}
     * @param catalog             catalog in which buffer ids are allocated and registered
     * @param shuffleStorage      device store that receives the cached tables
     * @param rapidsShuffleServer optional RAPIDS shuffle server whose id is advertised in the MapStatus
     * @param metrics             SQL metrics; must contain a "dataSize" entry
     */
    public RapidsCachingWriter(BlockManager blockManager, GpuShuffleHandle<K, V> handle, long mapId, ShuffleWriteMetricsReporter metricsReporter, ShuffleBufferCatalog catalog, RapidsDeviceMemoryStore shuffleStorage, Option<RapidsShuffleServer> rapidsShuffleServer, Map<String, SQLMetric> metrics) {
        this.blockManager = blockManager;
        this.handle = handle;
        this.mapId = mapId;
        this.metricsReporter = metricsReporter;
        this.catalog = catalog;
        this.shuffleStorage = shuffleStorage;
        this.rapidsShuffleServer = rapidsShuffleServer;
        Logging.$init$(this);
        this.numParts = handle.m1763dependency().partitioner().numPartitions();
        this.sizes = new long[numParts()];
        this.writtenBufferIds = new ArrayBuffer<>(numParts());
        this.uncompressedMetric = (SQLMetric) metrics.apply("dataSize");
    }
}
