/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.sql.rapids;

import ai.rapids.cudf.ContiguousTable;
import ai.rapids.cudf.DeviceMemoryBuffer;
import ai.rapids.cudf.NvtxColor;
import ai.rapids.cudf.NvtxRange;
import com.nvidia.spark.rapids.DegenerateRapidsBuffer;
import com.nvidia.spark.rapids.GpuCompressedColumnVector;
import com.nvidia.spark.rapids.GpuPackedTableColumn;
import com.nvidia.spark.rapids.MetaUtils$;
import com.nvidia.spark.rapids.RapidsDeviceMemoryStore;
import com.nvidia.spark.rapids.ShuffleBufferCatalog;
import com.nvidia.spark.rapids.ShuffleBufferId;
import com.nvidia.spark.rapids.SpillCallback;
import com.nvidia.spark.rapids.SpillPriorities$;
import com.nvidia.spark.rapids.format.TableMeta;
import com.nvidia.spark.rapids.shuffle.RapidsShuffleServer;
import com.nvidia.spark.rapids.shuffle.RapidsShuffleTransport$;
import java.io.Serializable;
import org.apache.spark.internal.Logging;
import org.apache.spark.scheduler.MapStatus;
import org.apache.spark.scheduler.MapStatus$;
import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
import org.apache.spark.shuffle.ShuffleWriter;
import org.apache.spark.sql.execution.metric.SQLMetric;
import org.apache.spark.sql.rapids.GpuShuffleHandle;
import org.apache.spark.sql.vectorized.ColumnVector;
import org.apache.spark.sql.vectorized.ColumnarBatch;
import org.apache.spark.storage.BlockManager;
import org.apache.spark.storage.BlockManagerId;
import org.apache.spark.storage.BlockManagerId$;
import org.apache.spark.storage.ShuffleBlockId;
import org.slf4j.Logger;
import scala.Function0;
import scala.Function1;
import scala.None$;
import scala.Option;
import scala.Predef$;
import scala.Product2;
import scala.Some;
import scala.collection.Iterator;
import scala.collection.Seq;
import scala.collection.immutable.Map;
import scala.collection.mutable.ArrayBuffer;
import scala.collection.mutable.ArrayOps;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.LongRef;

@ScalaSignature(bytes="\u0006\u0001\u0005\rf\u0001\u0002\f\u0018\u0001\tB\u0001\"\u0011\u0001\u0003\u0002\u0003\u0006IA\u0011\u0005\t\u0011\u0002\u0011\t\u0011)A\u0005\u0013\"AQ\n\u0001B\u0001B\u0003%a\n\u0003\u0005R\u0001\t\u0005\t\u0015!\u0003S\u0011!)\u0006A!A!\u0002\u00131\u0006\u0002\u00031\u0001\u0005\u0003\u0005\u000b\u0011B1\t\u0011\u0011\u0004!\u0011!Q\u0001\n\u0015D\u0001\"\u001c\u0001\u0003\u0002\u0003\u0006IA\u001c\u0005\b\u0003\u0013\u0001A\u0011AA\u0006\u0011%\ty\u0002\u0001b\u0001\n\u0013\t\t\u0003\u0003\u0005\u0002*\u0001\u0001\u000b\u0011BA\u0012\u0011%\tY\u0003\u0001b\u0001\n\u0013\ti\u0003\u0003\u0005\u00026\u0001\u0001\u000b\u0011BA\u0018\u0011%\t9\u0004\u0001b\u0001\n\u0013\tI\u0004\u0003\u0005\u0002R\u0001\u0001\u000b\u0011BA\u001e\u0011%\t\u0019\u0006\u0001b\u0001\n\u0013\t)\u0006C\u0004\u0002X\u0001\u0001\u000b\u0011\u0002?\t\u000f\u0005e\u0003\u0001\"\u0011\u0002\\!9\u0011q\u0010\u0001\u0005\n\u0005\u0005\u0005bBAB\u0001\u0011\u0005\u0013Q\u0011\u0005\b\u0003?\u0003A\u0011AAQ\u0005M\u0011\u0016\r]5eg\u000e\u000b7\r[5oO^\u0013\u0018\u000e^3s\u0015\tA\u0012$\u0001\u0004sCBLGm\u001d\u0006\u00035m\t1a]9m\u0015\taR$A\u0003ta\u0006\u00148N\u0003\u0002\u001f?\u00051\u0011\r]1dQ\u0016T\u0011\u0001I\u0001\u0004_J<7\u0001A\u000b\u0004G1J4c\u0001\u0001%wA!Q\u0005\u000b\u00169\u001b\u00051#BA\u0014\u001c\u0003\u001d\u0019\b.\u001e4gY\u0016L!!\u000b\u0014\u0003\u001bMCWO\u001a4mK^\u0013\u0018\u000e^3s!\tYC\u0006\u0004\u0001\u0005\u000b5\u0002!\u0019\u0001\u0018\u0003\u0003-\u000b\"aL\u001b\u0011\u0005A\u001aT\"A\u0019\u000b\u0003I\nQa]2bY\u0006L!\u0001N\u0019\u0003\u000f9{G\u000f[5oOB\u0011\u0001GN\u0005\u0003oE\u00121!\u00118z!\tY\u0013\bB\u0003;\u0001\t\u0007aFA\u0001W!\tat(D\u0001>\u0015\tq4$\u0001\u0005j]R,'O\\1m\u0013\t\u0001UHA\u0004M_\u001e<\u0017N\\4\u0002\u0019\tdwnY6NC:\fw-\u001a:\u0011\u0005\r3U\"\u0001#\u000b\u0005\u0015[\u0012aB:u_J\fw-Z\u0005\u0003\u000f\u0012\u0013AB\u00117pG.l\u0015M\\1hKJ\fa\u0001[1oI2,\u0007\u0003\u0002&LUaj\u0011aF\u00
05\u0003\u0019^\u0011\u0001c\u00129v'\",hM\u001a7f\u0011\u0006tG\r\\3\u0002\u000b5\f\u0007/\u00133\u0011\u0005Az\u0015B\u0001)2\u0005\u0011auN\\4\u0002\u001f5,GO]5dgJ+\u0007o\u001c:uKJ\u0004\"!J*\n\u0005Q3#aG*ik\u001a4G.Z,sSR,W*\u001a;sS\u000e\u001c(+\u001a9peR,'/A\u0004dCR\fGn\\4\u0011\u0005]sV\"\u0001-\u000b\u0005aI&B\u0001\u000f[\u0015\tYF,\u0001\u0004om&$\u0017.\u0019\u0006\u0002;\u0006\u00191m\\7\n\u0005}C&\u0001F*ik\u001a4G.\u001a\"vM\u001a,'oQ1uC2|w-\u0001\btQV4g\r\\3Ti>\u0014\u0018mZ3\u0011\u0005]\u0013\u0017BA2Y\u0005]\u0011\u0016\r]5eg\u0012+g/[2f\u001b\u0016lwN]=Ti>\u0014X-A\nsCBLGm]*ik\u001a4G.Z*feZ,'\u000fE\u00021M\"L!aZ\u0019\u0003\r=\u0003H/[8o!\tI7.D\u0001k\u0015\t9\u0003,\u0003\u0002mU\n\u0019\"+\u00199jIN\u001c\u0006.\u001e4gY\u0016\u001cVM\u001d<fe\u00069Q.\u001a;sS\u000e\u001c\b\u0003B8wsrt!\u0001\u001d;\u0011\u0005E\fT\"\u0001:\u000b\u0005M\f\u0013A\u0002\u001fs_>$h(\u0003\u0002vc\u00051\u0001K]3eK\u001aL!a\u001e=\u0003\u00075\u000b\u0007O\u0003\u0002vcA\u0011qN_\u0005\u0003wb\u0014aa\u0015;sS:<\u0007cA?\u0002\u00065\taPC\u0002\u0000\u0003\u0003\ta!\\3ue&\u001c'bAA\u00023\u0005IQ\r_3dkRLwN\\\u0005\u0004\u0003\u000fq(!C*R\u00196+GO]5d\u0003\u0019a\u0014N\\5u}Q\u0011\u0012QBA\b\u0003#\t\u0019\"!\u0006\u0002\u0018\u0005e\u00111DA\u000f!\u0011Q\u0005A\u000b\u001d\t\u000b\u0005K\u0001\u0019\u0001\"\t\u000b!K\u0001\u0019A%\t\u000b5K\u0001\u0019\u0001(\t\u000bEK\u0001\u0019\u0001*\t\u000bUK\u0001\u0019\u0001,\t\u000b\u0001L\u0001\u0019A1\t\u000b\u0011L\u0001\u0019A3\t\u000b5L\u0001\u0019\u00018\u0002\u00119,X\u000eU1siN,\"!a\t\u0011\u0007A\n)#C\u0002\u0002(E\u00121!\u00138u\u0003%qW/\u001c)beR\u001c\b%A\u0003tSj,7/\u0006\u0002\u00020A!\u0001'!\rO\u0013\r\t\u0019$\r\u0002\u0006\u0003J\u0014\u0018-_\u0001\u0007g&TXm\u001d\u0011\u0002!]\u0014\u0018\u000e\u001e;f]\n+hMZ3s\u0013\u0012\u001cXCAA\u001e!\u0019\ti$a\u0012\u0002L5\u0011\u0011q\b\u0006\u0005\u0003\u0003\n\u0019%A\u0004nkR\f'\r\\3\u000b\u0007\u0005\u0015\u0013'\u0001\u0006d_2dWm\u0019;j_:LA!!\u001
3\u0002@\tY\u0011I\u001d:bs\n+hMZ3s!\r9\u0016QJ\u0005\u0004\u0003\u001fB&aD*ik\u001a4G.\u001a\"vM\u001a,'/\u00133\u0002#]\u0014\u0018\u000e\u001e;f]\n+hMZ3s\u0013\u0012\u001c\b%\u0001\nv]\u000e|W\u000e\u001d:fgN,G-T3ue&\u001cW#\u0001?\u0002'Ut7m\\7qe\u0016\u001c8/\u001a3NKR\u0014\u0018n\u0019\u0011\u0002\u000b]\u0014\u0018\u000e^3\u0015\t\u0005u\u00131\r\t\u0004a\u0005}\u0013bAA1c\t!QK\\5u\u0011\u001d\t)G\u0005a\u0001\u0003O\nqA]3d_J$7\u000f\u0005\u0004\u0002j\u0005M\u0014\u0011\u0010\b\u0005\u0003W\nyGD\u0002r\u0003[J\u0011AM\u0005\u0004\u0003c\n\u0014a\u00029bG.\fw-Z\u0005\u0005\u0003k\n9H\u0001\u0005Ji\u0016\u0014\u0018\r^8s\u0015\r\t\t(\r\t\u0006a\u0005m$\u0006O\u0005\u0004\u0003{\n$\u0001\u0003)s_\u0012,8\r\u001e\u001a\u0002\u0019\rdW-\u00198Ti>\u0014\u0018mZ3\u0015\u0005\u0005u\u0013\u0001B:u_B$B!a\"\u0002\u0016B!\u0001GZAE!\u0011\tY)!%\u000e\u0005\u00055%bAAH7\u0005I1o\u00195fIVdWM]\u0005\u0005\u0003'\u000biIA\u0005NCB\u001cF/\u0019;vg\"9\u0011q\u0013\u000bA\u0002\u0005e\u0015aB:vG\u000e,7o\u001d\t\u0004a\u0005m\u0015bAAOc\t9!i\\8mK\u0006t\u0017aE4fiB\u000b'\u000f^5uS>tG*\u001a8hi\"\u001cHCAA\u0018\u0001")
/**
 * Shuffle writer that keeps the output of one map task cached in RAPIDS shuffle storage.
 *
 * <p>Every incoming batch is registered under a fresh {@link ShuffleBufferId}:
 * <ul>
 *   <li>batches with at least one row and one column are added to {@code shuffleStorage}, either
 *       as a packed contiguous table or as a compressed device buffer;</li>
 *   <li>all other batches are registered with the catalog as metadata-only
 *       {@link DegenerateRapidsBuffer}s.</li>
 * </ul>
 * Per-partition sizes are accumulated while writing and reported through the {@link MapStatus}
 * returned from {@link #stop(boolean)}. On an unsuccessful stop, every buffer registered by this
 * writer is removed from the catalog again.
 *
 * @param <K> shuffle key type; the key is unboxed as an {@code int} partition id
 * @param <V> shuffle value type; each value is cast to {@link ColumnarBatch}
 */
public class RapidsCachingWriter<K, V>
extends ShuffleWriter<K, V>
implements Logging {
    private final BlockManager blockManager;
    private final GpuShuffleHandle<K, V> handle;
    private final long mapId;
    private final ShuffleWriteMetricsReporter metricsReporter;
    private final ShuffleBufferCatalog catalog;
    private final RapidsDeviceMemoryStore shuffleStorage;
    private final Option<RapidsShuffleServer> rapidsShuffleServer;
    /** Number of reduce partitions, taken from the shuffle dependency's partitioner. */
    private final int numParts;
    /** Accumulated size per reduce partition; reported in the MapStatus. */
    private final long[] sizes;
    /** Every buffer id registered by this writer, so a failed task can remove them. */
    private final ArrayBuffer<ShuffleBufferId> writtenBufferIds;
    /** The {@code "dataSize"} SQL metric, incremented with the uncompressed size of each batch. */
    private final SQLMetric uncompressedMetric;
    private transient Logger org$apache$spark$internal$Logging$$log_;

    public String logName() {
        return Logging.logName$((Logging)this);
    }

    public Logger log() {
        return Logging.log$((Logging)this);
    }

    public void logInfo(Function0<String> msg) {
        Logging.logInfo$((Logging)this, msg);
    }

    public void logDebug(Function0<String> msg) {
        Logging.logDebug$((Logging)this, msg);
    }

    public void logTrace(Function0<String> msg) {
        Logging.logTrace$((Logging)this, msg);
    }

    public void logWarning(Function0<String> msg) {
        Logging.logWarning$((Logging)this, msg);
    }

    public void logError(Function0<String> msg) {
        Logging.logError$((Logging)this, msg);
    }

    public void logInfo(Function0<String> msg, Throwable throwable) {
        Logging.logInfo$((Logging)this, msg, (Throwable)throwable);
    }

    public void logDebug(Function0<String> msg, Throwable throwable) {
        Logging.logDebug$((Logging)this, msg, (Throwable)throwable);
    }

    public void logTrace(Function0<String> msg, Throwable throwable) {
        Logging.logTrace$((Logging)this, msg, (Throwable)throwable);
    }

    public void logWarning(Function0<String> msg, Throwable throwable) {
        Logging.logWarning$((Logging)this, msg, (Throwable)throwable);
    }

    public void logError(Function0<String> msg, Throwable throwable) {
        Logging.logError$((Logging)this, msg, (Throwable)throwable);
    }

    public boolean isTraceEnabled() {
        return Logging.isTraceEnabled$((Logging)this);
    }

    public void initializeLogIfNecessary(boolean isInterpreter) {
        Logging.initializeLogIfNecessary$((Logging)this, (boolean)isInterpreter);
    }

    public boolean initializeLogIfNecessary(boolean isInterpreter, boolean silent) {
        return Logging.initializeLogIfNecessary$((Logging)this, (boolean)isInterpreter, (boolean)silent);
    }

    public boolean initializeLogIfNecessary$default$2() {
        return Logging.initializeLogIfNecessary$default$2$((Logging)this);
    }

    public void initializeForcefully(boolean isInterpreter, boolean silent) {
        Logging.initializeForcefully$((Logging)this, (boolean)isInterpreter, (boolean)silent);
    }

    public Logger org$apache$spark$internal$Logging$$log_() {
        return this.org$apache$spark$internal$Logging$$log_;
    }

    public void org$apache$spark$internal$Logging$$log__$eq(Logger x$1) {
        this.org$apache$spark$internal$Logging$$log_ = x$1;
    }

    /**
     * Caches every record of this map task.
     *
     * <p>Each record's key is the target partition id and its value is a {@link ColumnarBatch}.
     * After all records are consumed, the total cached device bytes and the total row count are
     * reported to the shuffle write metrics.
     *
     * @param records the map output, one (partition id, batch) pair per record
     * @throws IllegalStateException if a non-empty batch's first column is neither a
     *         {@link GpuPackedTableColumn} nor a {@link GpuCompressedColumnVector}
     */
    @Override
    public void write(Iterator<Product2<K, V>> records) {
        try (NvtxRange ignored = new NvtxRange("RapidsCachingWriter.write", NvtxColor.CYAN)) {
            long bytesWritten = 0L;
            long recordsWritten = 0L;
            while (records.hasNext()) {
                Product2<K, V> record = records.next();
                int partId = BoxesRunTime.unboxToInt(record._1());
                ColumnarBatch batch = (ColumnarBatch) record._2();
                recordsWritten += batch.numRows();
                bytesWritten += cacheBatch(partId, batch);
            }
            this.metricsReporter.incBytesWritten(bytesWritten);
            this.metricsReporter.incRecordsWritten(recordsWritten);
        }
    }

    /**
     * Registers one batch under a new shuffle buffer id and updates the size of its partition.
     *
     * @param partId target reduce partition of the batch
     * @param batch the batch to cache
     * @return the number of device bytes cached; 0 for metadata-only (degenerate) batches
     */
    private long cacheBatch(int partId, ColumnarBatch batch) {
        logDebug(() -> "Caching shuffle_id=" + this.handle.shuffleId() + " map_id=" + this.mapId
            + ", partId=" + partId + ", " + "batch=[num_cols=" + batch.numCols()
            + ", num_rows=" + batch.numRows() + "]");
        ShuffleBlockId blockId = new ShuffleBlockId(this.handle.shuffleId(), this.mapId, partId);
        ShuffleBufferId bufferId = this.catalog.nextShuffleBufferId(blockId);
        long partSize = 0L;
        if (batch.numRows() > 0 && batch.numCols() > 0) {
            partSize = cacheDeviceBatch(bufferId, batch);
            this.sizes[partId] += partSize;
        } else {
            // No device data to store: only the batch metadata is registered.
            TableMeta tableMeta = MetaUtils$.MODULE$.buildDegenerateTableMeta(batch);
            this.catalog.registerNewBuffer(new DegenerateRapidsBuffer(bufferId, tableMeta));
            if (batch.numRows() > 0) {
                // A batch with rows but no columns still gets a non-zero reported size.
                // NOTE(review): presumably so the partition is not treated as empty; the value
                // 100 looks arbitrary - confirm.
                this.sizes[partId] += 100L;
            }
        }
        this.writtenBufferIds.$plus$eq(bufferId);
        return partSize;
    }

    /**
     * Adds a batch that has device data to the shuffle store.
     *
     * <p>The store type depends on the batch's first column: a packed contiguous table or a
     * compressed device buffer.
     *
     * @param bufferId id to store the data under
     * @param batch batch with at least one row and one column
     * @return the length in bytes of the stored device buffer
     * @throws IllegalStateException if the first column has an unexpected type
     */
    private long cacheDeviceBatch(ShuffleBufferId bufferId, ColumnarBatch batch) {
        ColumnVector firstColumn = batch.column(0);
        if (firstColumn instanceof GpuPackedTableColumn) {
            GpuPackedTableColumn packed = (GpuPackedTableColumn) firstColumn;
            ContiguousTable contigTable = packed.getContiguousTable();
            long partSize = packed.getTableBuffer().getLength();
            this.uncompressedMetric.$plus$eq(partSize);
            this.shuffleStorage.addContiguousTable(
                bufferId,
                contigTable,
                SpillPriorities$.MODULE$.OUTPUT_FOR_SHUFFLE_INITIAL_PRIORITY(),
                this.shuffleStorage.addContiguousTable$default$4(),
                false);
            return partSize;
        }
        if (firstColumn instanceof GpuCompressedColumnVector) {
            GpuCompressedColumnVector compressed = (GpuCompressedColumnVector) firstColumn;
            DeviceMemoryBuffer buffer = compressed.getTableBuffer();
            long partSize = buffer.getLength();
            TableMeta tableMeta = compressed.getTableMeta();
            // Re-key the buffer metadata to the table id this buffer is stored under.
            tableMeta.bufferMeta().mutateId(bufferId.tableId());
            this.uncompressedMetric.$plus$eq(tableMeta.bufferMeta().uncompressedSize());
            // Take the additional reference only right before handing the buffer over, so a
            // failure in the metadata handling above cannot leak it.
            // NOTE(review): assumes addBuffer takes ownership of this extra reference while the
            // batch keeps its own - confirm against RapidsDeviceMemoryStore.addBuffer.
            buffer.incRefCount();
            this.shuffleStorage.addBuffer(
                bufferId,
                buffer,
                tableMeta,
                SpillPriorities$.MODULE$.OUTPUT_FOR_SHUFFLE_INITIAL_PRIORITY(),
                this.shuffleStorage.addBuffer$default$5(),
                false);
            return partSize;
        }
        throw new IllegalStateException("Unexpected column type: " + firstColumn.getClass());
    }

    /** Removes every buffer registered by this writer from the catalog. */
    private void cleanStorage() {
        for (int i = 0; i < this.writtenBufferIds.size(); i++) {
            this.catalog.removeBuffer(this.writtenBufferIds.apply(i));
        }
    }

    /**
     * Finishes the map task.
     *
     * @param success whether the map task completed successfully
     * @return {@code None} when {@code success} is false (after removing all cached buffers);
     *         otherwise a {@code MapStatus} carrying the shuffle server id, the per-partition
     *         sizes and this writer's map id
     */
    @Override
    public Option<MapStatus> stop(boolean success) {
        try (NvtxRange ignored = new NvtxRange("RapidsCachingWriter.close", NvtxColor.CYAN)) {
            if (!success) {
                cleanStorage();
                return Option.empty();
            }
            BlockManagerId shuffleServerId = resolveShuffleServerId();
            logInfo(() -> "Done caching shuffle success=" + success + ", server_id=" + shuffleServerId
                + ", " + "map_id=" + this.mapId + ", sizes=" + formatSizes());
            return new Some<>(MapStatus$.MODULE$.apply(shuffleServerId, this.sizes, this.mapId));
        }
    }

    /**
     * Picks the block manager id to report in the map status.
     *
     * <p>When a RAPIDS shuffle server is present, this is the server's original shuffle server id
     * with a topology entry {@code <prefix>=<server port>} added. Otherwise it is the block
     * manager's own shuffle server id.
     */
    private BlockManagerId resolveShuffleServerId() {
        if (!this.rapidsShuffleServer.isDefined()) {
            return this.blockManager.shuffleServerId();
        }
        RapidsShuffleServer server = this.rapidsShuffleServer.get();
        BlockManagerId original = server.originalShuffleServerId();
        String topology = RapidsShuffleTransport$.MODULE$.BLOCK_MANAGER_ID_TOPO_PREFIX() + "=" + server.getPort();
        return BlockManagerId$.MODULE$.apply(original.executorId(), original.host(), original.port(), new Some<>(topology));
    }

    /** Renders the per-partition sizes as a comma-separated list for logging. */
    private String formatSizes() {
        StringBuilder joined = new StringBuilder();
        for (int i = 0; i < this.sizes.length; i++) {
            if (i > 0) {
                joined.append(',');
            }
            joined.append(this.sizes[i]);
        }
        return joined.toString();
    }

    /**
     * Not supported by this writer.
     *
     * @throws UnsupportedOperationException always
     */
    public long[] getPartitionLengths() {
        throw new UnsupportedOperationException("TODO");
    }

    /**
     * Creates a writer for one map task.
     *
     * @param blockManager block manager whose shuffle server id is reported when no RAPIDS
     *        shuffle server is configured
     * @param handle shuffle handle; supplies the shuffle id and the partition count
     * @param mapId id of the map task being written
     * @param metricsReporter receives the bytes/records written totals
     * @param catalog catalog that hands out buffer ids and tracks degenerate buffers
     * @param shuffleStorage device store that receives batches with device data
     * @param rapidsShuffleServer optional RAPIDS shuffle server used to build the reported id
     * @param metrics SQL metrics; must contain a {@code "dataSize"} entry
     */
    public RapidsCachingWriter(BlockManager blockManager, GpuShuffleHandle<K, V> handle, long mapId, ShuffleWriteMetricsReporter metricsReporter, ShuffleBufferCatalog catalog, RapidsDeviceMemoryStore shuffleStorage, Option<RapidsShuffleServer> rapidsShuffleServer, Map<String, SQLMetric> metrics) {
        this.blockManager = blockManager;
        this.handle = handle;
        this.mapId = mapId;
        this.metricsReporter = metricsReporter;
        this.catalog = catalog;
        this.shuffleStorage = shuffleStorage;
        this.rapidsShuffleServer = rapidsShuffleServer;
        Logging.$init$((Logging)this);
        this.numParts = handle.dependency().partitioner().numPartitions();
        this.sizes = new long[this.numParts];
        this.writtenBufferIds = new ArrayBuffer<>(this.numParts);
        this.uncompressedMetric = metrics.apply("dataSize");
    }
}

