/*
 * Decompiled with CFR 0.152.
 */
package com.nvidia.spark.rapids;

import ai.rapids.cudf.Cuda;
import ai.rapids.cudf.CudaMemInfo;
import ai.rapids.cudf.PinnedMemoryPool;
import ai.rapids.cudf.Rmm;
import com.nvidia.spark.rapids.Errored$;
import com.nvidia.spark.rapids.Initialized$;
import com.nvidia.spark.rapids.MemoryState;
import com.nvidia.spark.rapids.RapidsBufferCatalog$;
import com.nvidia.spark.rapids.RapidsConf;
import com.nvidia.spark.rapids.RapidsConf$;
import com.nvidia.spark.rapids.Uninitialized$;
import java.io.Serializable;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.SerializedLambda;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.ThreadFactory;
import org.apache.spark.SparkEnv$;
import org.apache.spark.TaskContext;
import org.apache.spark.TaskContext$;
import org.apache.spark.internal.Logging;
import org.apache.spark.resource.ResourceInformation;
import org.apache.spark.sql.rapids.GpuShuffleEnv$;
import org.slf4j.Logger;
import scala.Function0;
import scala.Function1;
import scala.None$;
import scala.Option;
import scala.Predef$;
import scala.Some;
import scala.collection.Seq;
import scala.collection.TraversableOnce;
import scala.collection.immutable.Map;
import scala.collection.immutable.Nil$;
import scala.collection.immutable.StringOps;
import scala.collection.mutable.ArrayBuffer;
import scala.collection.mutable.ArrayBuffer$;
import scala.collection.mutable.ArrayOps;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.LambdaDeserialize;
import scala.runtime.LongRef;
import scala.runtime.RichInt$;
import scala.runtime.java8.JFunction0;
import scala.runtime.java8.JFunction1;
import scala.util.control.NonFatal$;

/**
 * Process-wide manager that selects the GPU used by executor threads and performs the one-time
 * RMM and pinned host memory setup.
 *
 * <p>This is the Java shape of a Scala {@code object} (the file header notes it was decompiled):
 * the single instance is created by the static initializer and published through
 * {@link #MODULE$}. The memory setup is tracked by a volatile {@link MemoryState} which, after
 * construction, is only written from {@code synchronized} code on this instance.
 */
public final class GpuDeviceManager$
implements Logging {
    /** The singleton instance; assigned by the private constructor during class initialization. */
    public static GpuDeviceManager$ MODULE$;
    /**
     * Whether {@link #initializeFromTask()} also initializes device memory. Seeded at construction
     * from the {@code com.nvidia.spark.rapids.memory.gpu.rmm.init.task} system property.
     */
    private boolean rmmTaskInitEnabled;
    /** Per-thread marker holding {@code Boolean.TRUE} once {@link #initializeFromTask()} completed. */
    private final ThreadLocal<Object> threadGpuInitialized;
    /** Uninitialized / Initialized / Errored state of the process-wide memory setup. */
    private volatile MemoryState singletonMemoryInitialized;
    /** GPU ordinal set by {@code initializeRmm} just before it calls {@code Rmm.initialize}. */
    private volatile Option<Object> deviceId;
    private transient Logger org$apache$spark$internal$Logging$$log_;

    static {
        new GpuDeviceManager$();
    }

    // ---- Spark Logging trait members: each delegates to the trait's static implementation. ----

    public String logName() {
        return Logging.logName$((Logging)this);
    }

    public Logger log() {
        return Logging.log$((Logging)this);
    }

    public void logInfo(Function0<String> msg) {
        Logging.logInfo$((Logging)this, msg);
    }

    public void logDebug(Function0<String> msg) {
        Logging.logDebug$((Logging)this, msg);
    }

    public void logTrace(Function0<String> msg) {
        Logging.logTrace$((Logging)this, msg);
    }

    public void logWarning(Function0<String> msg) {
        Logging.logWarning$((Logging)this, msg);
    }

    public void logError(Function0<String> msg) {
        Logging.logError$((Logging)this, msg);
    }

    public void logInfo(Function0<String> msg, Throwable throwable) {
        Logging.logInfo$((Logging)this, msg, (Throwable)throwable);
    }

    public void logDebug(Function0<String> msg, Throwable throwable) {
        Logging.logDebug$((Logging)this, msg, (Throwable)throwable);
    }

    public void logTrace(Function0<String> msg, Throwable throwable) {
        Logging.logTrace$((Logging)this, msg, (Throwable)throwable);
    }

    public void logWarning(Function0<String> msg, Throwable throwable) {
        Logging.logWarning$((Logging)this, msg, (Throwable)throwable);
    }

    public void logError(Function0<String> msg, Throwable throwable) {
        Logging.logError$((Logging)this, msg, (Throwable)throwable);
    }

    public boolean isTraceEnabled() {
        return Logging.isTraceEnabled$((Logging)this);
    }

    public void initializeLogIfNecessary(boolean isInterpreter) {
        Logging.initializeLogIfNecessary$((Logging)this, (boolean)isInterpreter);
    }

    public boolean initializeLogIfNecessary(boolean isInterpreter, boolean silent) {
        return Logging.initializeLogIfNecessary$((Logging)this, (boolean)isInterpreter, (boolean)silent);
    }

    public boolean initializeLogIfNecessary$default$2() {
        return Logging.initializeLogIfNecessary$default$2$((Logging)this);
    }

    public void initializeForcefully(boolean isInterpreter, boolean silent) {
        Logging.initializeForcefully$((Logging)this, (boolean)isInterpreter, (boolean)silent);
    }

    public Logger org$apache$spark$internal$Logging$$log_() {
        return this.org$apache$spark$internal$Logging$$log_;
    }

    public void org$apache$spark$internal$Logging$$log__$eq(Logger x$1) {
        this.org$apache$spark$internal$Logging$$log_ = x$1;
    }

    // ---- Field accessors (Scala getter/setter naming). ----

    public boolean rmmTaskInitEnabled() {
        return this.rmmTaskInitEnabled;
    }

    public void rmmTaskInitEnabled_$eq(boolean x$1) {
        this.rmmTaskInitEnabled = x$1;
    }

    /** Overrides whether {@link #initializeFromTask()} also sets up device memory. */
    public void setRmmTaskInitEnabled(boolean enabled) {
        this.rmmTaskInitEnabled_$eq(enabled);
    }

    private ThreadLocal<Object> threadGpuInitialized() {
        return this.threadGpuInitialized;
    }

    private MemoryState singletonMemoryInitialized() {
        return this.singletonMemoryInitialized;
    }

    private void singletonMemoryInitialized_$eq(MemoryState x$1) {
        this.singletonMemoryInitialized = x$1;
    }

    private Option<Object> deviceId() {
        return this.deviceId;
    }

    private void deviceId_$eq(Option<Object> x$1) {
        this.deviceId = x$1;
    }

    /** Returns the GPU ordinal recorded by RMM initialization, or an empty Option if none yet. */
    public Option<Object> getDeviceId() {
        return this.deviceId();
    }

    /**
     * Attempts {@link #setGpuDeviceAndAcquire(int)} on {@code addr}.
     *
     * @return true on success; false if a non-fatal throwable occurred (it is logged at INFO)
     * @throws Throwable any fatal throwable (per {@code NonFatal}) is rethrown unchanged
     */
    public boolean tryToSetGpuDeviceAndAcquire(int addr) {
        try {
            this.setGpuDeviceAndAcquire(addr);
        } catch (Throwable t) {
            if (!NonFatal$.MODULE$.apply(t)) {
                // Fatal errors (e.g. VirtualMachineError) must not be treated as "GPU unusable".
                throw t;
            }
            this.logInfo(() -> "Will not use GPU " + addr + " because of " + t);
            return false;
        }
        return true;
    }

    /**
     * Probes device ordinals {@code 0 .. deviceCount-1} in order and returns the first one that
     * can be selected, making up to two passes over the list.
     *
     * @throws IllegalStateException if no device could be selected on any pass
     */
    private int findGpuAndAcquire() {
        int deviceCount = Cuda.getDeviceCount();
        // A second pass covers devices that became usable after a failure in the first pass.
        for (int passesLeft = 2; passesLeft > 0; --passesLeft) {
            for (int addr = 0; addr < deviceCount; ++addr) {
                if (this.tryToSetGpuDeviceAndAcquire(addr)) {
                    return addr;
                }
            }
        }
        throw new IllegalStateException("Could not find a single GPU to use");
    }

    /**
     * Makes {@code addr} the current CUDA device for the calling thread and calls
     * {@code Cuda.freeZero()} on it.
     *
     * <p>NOTE(review): {@code freeZero} presumably forces CUDA context creation so that an unusable
     * device fails here rather than later; confirm against the cuDF Java docs.
     *
     * @return {@code addr}, unchanged
     */
    public int setGpuDeviceAndAcquire(int addr) {
        this.logDebug(() -> "Initializing GPU device ID to " + addr);
        Cuda.setDevice((int)addr);
        Cuda.freeZero();
        return addr;
    }

    /**
     * Looks up the GPU address Spark assigned under {@code conf.getSparkGpuResourceName()}.
     *
     * @return the single address parsed as an int, or an empty Option if the resource is absent
     * @throws IllegalArgumentException if the resource lists zero or more than one address
     * @throws NumberFormatException if the address is not a decimal integer
     */
    public Option<Object> getGPUAddrFromResources(Map<String, ResourceInformation> resources, RapidsConf conf) {
        String sparkGpuResourceName = conf.getSparkGpuResourceName();
        if (!resources.contains(sparkGpuResourceName)) {
            return Option.empty();
        }
        this.logDebug(() -> "Spark resources contain: " + sparkGpuResourceName);
        String[] addrs = resources.apply(sparkGpuResourceName).addresses();
        if (addrs.length > 1) {
            throw new IllegalArgumentException("Spark GPU Plugin only supports 1 gpu per executor");
        }
        if (addrs.length == 0) {
            throw new IllegalArgumentException("Spark GPU resource '" + sparkGpuResourceName
                + "' does not list any addresses");
        }
        return new Some<Object>(BoxesRunTime.boxToInteger(Integer.parseInt(addrs[0])));
    }

    /**
     * Selects the GPU assigned in {@code resources}, if any, for the calling thread.
     *
     * @return the selected ordinal, or an empty Option when no GPU resource was assigned
     */
    public Option<Object> initializeGpu(Map<String, ResourceInformation> resources, RapidsConf conf) {
        Option<Object> addr = this.getGPUAddrFromResources(resources, conf);
        if (addr.isEmpty()) {
            return addr;
        }
        int selected = this.setGpuDeviceAndAcquire(BoxesRunTime.unboxToInt(addr.get()));
        return new Some<Object>(BoxesRunTime.boxToInteger(selected));
    }

    /**
     * Selects the assigned GPU (if any) and then initializes device memory, but only when
     * {@code conf.isSqlExecuteOnGPU()} is true; otherwise does nothing.
     */
    public void initializeGpuAndMemory(Map<String, ResourceInformation> resources, RapidsConf conf) {
        if (conf.isSqlExecuteOnGPU()) {
            Option<Object> addr = this.initializeGpu(resources, conf);
            this.initializeMemory(addr, this.initializeMemory$default$2());
        }
    }

    /**
     * Tears down the buffer catalog, shuffle environment and RMM.
     *
     * <p>The state is set to Errored first and only reset to Uninitialized after every step
     * succeeds, so a failed shutdown makes later {@link #initializeMemory} calls fail fast.
     */
    public synchronized void shutdown() {
        this.singletonMemoryInitialized_$eq(Errored$.MODULE$);
        RapidsBufferCatalog$.MODULE$.close();
        GpuShuffleEnv$.MODULE$.shutdown();
        Rmm.shutdown();
        this.singletonMemoryInitialized_$eq(Uninitialized$.MODULE$);
    }

    /** Returns the current task's resources, or an empty map when not running inside a task. */
    public Map<String, ResourceInformation> getResourcesFromTaskContext() {
        TaskContext tc = TaskContext$.MODULE$.get();
        if (tc == null) {
            return Predef$.MODULE$.Map().empty();
        }
        return tc.resources();
    }

    /**
     * Per-thread, idempotent GPU setup for task threads.
     *
     * <p>Selects the task's GPU and, when {@link #rmmTaskInitEnabled()} is set, also initializes
     * memory via {@link #initializeGpuAndMemory}. The thread is only marked as initialized after
     * that succeeds, so a failure is retried on the next call.
     */
    public void initializeFromTask() {
        if (BoxesRunTime.unboxToBoolean(this.threadGpuInitialized().get())) {
            return;
        }
        Map<String, ResourceInformation> resources = this.getResourcesFromTaskContext();
        RapidsConf conf = new RapidsConf(SparkEnv$.MODULE$.get().conf());
        if (this.rmmTaskInitEnabled()) {
            this.initializeGpuAndMemory(resources, conf);
        } else {
            this.initializeGpu(resources, conf);
        }
        this.threadGpuInitialized().set(BoxesRunTime.boxToBoolean(true));
    }

    /** Bytes to MiB; the integer division by 1024 first drops any sub-KiB remainder. */
    private double toMB(long x) {
        return (double)(x / 1024L) / 1024.0;
    }

    /** Returns the caller-supplied conf, or builds one from the live SparkEnv when absent. */
    private RapidsConf confOrDefault(Option<RapidsConf> rapidsConf) {
        return rapidsConf.isDefined() ? rapidsConf.get() : new RapidsConf(SparkEnv$.MODULE$.get().conf());
    }

    /**
     * Computes the RMM pool size in bytes, aligned down to 512 bytes.
     *
     * <p>An explicit {@code rmmExactAlloc} is used as-is (after alignment). Otherwise the size is
     * {@code rmmAllocFraction * (free - reserve)}, validated against the min/max fractions of total
     * memory and lowered (with a warning) if it would not leave room for the reserve.
     *
     * @throws IllegalArgumentException if the computed size is outside [min, max] or the reserve
     *     is at least the maximum allocation
     */
    private long computeRmmPoolSize(RapidsConf conf, CudaMemInfo info) {
        Option<?> exactAlloc = conf.rmmExactAlloc();
        if (exactAlloc.isDefined()) {
            return GpuDeviceManager$.truncateToAlignment(BoxesRunTime.unboxToLong(exactAlloc.get()));
        }
        long minAllocation = GpuDeviceManager$.truncateToAlignment((long)(conf.rmmAllocMinFraction() * (double)info.total));
        long maxAllocation = GpuDeviceManager$.truncateToAlignment((long)(conf.rmmAllocMaxFraction() * (double)info.total));
        // With UCX shuffle and the ASYNC pool, twice the UCX bounce-buffer size is reserved on top.
        long reserveAmount = conf.isUCXShuffleManagerMode() && conf.rmmPool().equalsIgnoreCase("ASYNC")
            ? conf.rmmAllocReserve() + conf.shuffleUcxBounceBuffersSize() * 2L
            : conf.rmmAllocReserve();
        long poolAllocation = GpuDeviceManager$.truncateToAlignment((long)(conf.rmmAllocFraction() * (double)(info.free - reserveAmount)));
        if (poolAllocation < minAllocation) {
            throw new IllegalArgumentException("The pool allocation of " + this.toMB(poolAllocation)
                + " MB (calculated from " + RapidsConf$.MODULE$.RMM_ALLOC_FRACTION()
                + " (=" + conf.rmmAllocFraction() + ") and " + this.toMB(info.free)
                + " MB free memory) was less than the minimum allocation of " + this.toMB(minAllocation)
                + " (calculated from " + RapidsConf$.MODULE$.RMM_ALLOC_MIN_FRACTION()
                + " (=" + conf.rmmAllocMinFraction() + ") and " + this.toMB(info.total)
                + " MB total memory)");
        }
        if (maxAllocation < poolAllocation) {
            throw new IllegalArgumentException("The pool allocation of " + this.toMB(poolAllocation)
                + " MB (calculated from " + RapidsConf$.MODULE$.RMM_ALLOC_FRACTION()
                + " (=" + conf.rmmAllocFraction() + ") and " + this.toMB(info.free)
                + " MB free memory) was more than the maximum allocation of " + this.toMB(maxAllocation)
                + " (calculated from " + RapidsConf$.MODULE$.RMM_ALLOC_MAX_FRACTION()
                + " (=" + conf.rmmAllocMaxFraction() + ") and " + this.toMB(info.total)
                + " MB total memory)");
        }
        if (reserveAmount >= maxAllocation) {
            // Report the actual setting values (previously printed rmmAllocFraction and the derived reserve).
            throw new IllegalArgumentException("RMM reserve memory (" + this.toMB(reserveAmount)
                + " MB) larger than maximum pool size (" + this.toMB(maxAllocation)
                + " MB). Check the settings for " + RapidsConf$.MODULE$.RMM_ALLOC_MAX_FRACTION()
                + " (=" + conf.rmmAllocMaxFraction() + ") and " + RapidsConf$.MODULE$.RMM_ALLOC_RESERVE()
                + " (=" + conf.rmmAllocReserve() + ")");
        }
        long adjustedMaxAllocation = GpuDeviceManager$.truncateToAlignment(maxAllocation - reserveAmount);
        if (poolAllocation > adjustedMaxAllocation) {
            final long requested = poolAllocation;
            this.logWarning(() -> "RMM pool allocation (" + this.toMB(requested)
                + " MB) does not leave enough free memory for reserve memory (" + this.toMB(reserveAmount)
                + " MB), lowering the pool size to " + this.toMB(adjustedMaxAllocation)
                + " MB to accommodate the requested reserve amount.");
            poolAllocation = adjustedMaxAllocation;
        }
        return poolAllocation;
    }

    /**
     * Translates {@code conf.rmmPool()} into RMM allocation-mode bits, appending the matching
     * feature tag to {@code features}.
     *
     * @return the mode bits; 0 for "none" or when pooling is disabled
     * @throws IllegalArgumentException for an unrecognized pool name
     */
    private int rmmPoolModeBits(RapidsConf conf, List<String> features) {
        if (!conf.isPooledMemEnabled()) {
            if (!"none".equalsIgnoreCase(conf.rmmPool())) {
                this.logWarning(() -> "RMM pool is disabled since spark.rapids.memory.gpu.pooling.enabled is set to false; however, this configuration is deprecated and the behavior may change in a future release.");
            }
            return 0;
        }
        String pool = conf.rmmPool();
        if ("default".equalsIgnoreCase(pool)) {
            if (Cuda.isPtdsEnabled()) {
                this.logWarning(() -> "Configuring the DEFAULT allocator with a CUDF built for Per-Thread Default Stream (PTDS). This is known to be unstable! We recommend you use the ARENA allocator when PTDS is enabled.");
            }
            features.add("POOLED");
            return 1;
        }
        if ("arena".equalsIgnoreCase(pool)) {
            features.add("ARENA");
            return 4;
        }
        if ("async".equalsIgnoreCase(pool)) {
            features.add("ASYNC");
            return 8;
        }
        if ("none".equalsIgnoreCase(pool)) {
            return 0;
        }
        throw new IllegalArgumentException("RMM pool set to '" + pool + "' is not supported.");
    }

    /**
     * Builds the RMM log configuration from {@code conf.rmmDebugLocation()}, appending a feature
     * tag for stdout/stderr.
     *
     * @return the log config, or null for "none" or an unsupported location (which is logged)
     */
    private Rmm.LogConf rmmLogConf(RapidsConf conf, List<String> features) {
        String location = conf.rmmDebugLocation();
        if ("none".equalsIgnoreCase(location)) {
            return null;
        }
        if ("stdout".equalsIgnoreCase(location)) {
            features.add("LOG: STDOUT");
            return Rmm.logToStdout();
        }
        if ("stderr".equalsIgnoreCase(location)) {
            features.add("LOG: STDERR");
            // Bug fix: this branch previously returned Rmm.logToStdout().
            return Rmm.logToStderr();
        }
        this.logWarning(() -> "RMM logging set to '" + location + "' is not supported and is being ignored.");
        return null;
    }

    /**
     * Initializes RMM on {@code gpuId} unless RMM is already initialized.
     *
     * <p>Order: read device memory info, initialize the shuffle env, size the pool, resolve
     * pool/UVM/log options, record {@code gpuId} in {@link #deviceId}, select the device, call
     * {@code Rmm.initialize}, then initialize the buffer catalog. Note that {@code deviceId} is
     * recorded even if {@code Rmm.initialize} subsequently throws.
     */
    private void initializeRmm(int gpuId, Option<RapidsConf> rapidsConf) {
        if (Rmm.isInitialized()) {
            return;
        }
        RapidsConf conf = this.confOrDefault(rapidsConf);
        CudaMemInfo info = Cuda.memGetInfo();
        GpuShuffleEnv$.MODULE$.init(conf);
        long poolAllocation = this.computeRmmPoolSize(conf, info);
        List<String> features = new ArrayList<>();
        int init = this.rmmPoolModeBits(conf, features);
        if (conf.isUvmEnabled()) {
            init |= 2;
            features.add("UVM");
        }
        Rmm.LogConf logConf = this.rmmLogConf(conf, features);
        this.deviceId_$eq(new Some<Object>(BoxesRunTime.boxToInteger(gpuId)));
        this.logInfo(() -> "Initializing RMM " + String.join(" ", features) + " pool size = "
            + this.toMB(poolAllocation) + " MB on gpuId " + gpuId);
        if (Cuda.isPtdsEnabled()) {
            this.logInfo(() -> "Using per-thread default stream");
        } else {
            this.logInfo(() -> "Using legacy default stream");
        }
        Cuda.setDevice((int)gpuId);
        Rmm.initialize((int)init, logConf, (long)poolAllocation);
        RapidsBufferCatalog$.MODULE$.init(conf);
    }

    private Option<RapidsConf> initializeRmm$default$2() {
        return Option.empty();
    }

    /**
     * Initializes the pinned host memory pool for {@code gpuId} when it is not yet initialized
     * and the configured size is positive; otherwise does nothing.
     */
    private void allocatePinnedMemory(int gpuId, Option<RapidsConf> rapidsConf) {
        RapidsConf conf = this.confOrDefault(rapidsConf);
        if (PinnedMemoryPool.isInitialized() || conf.pinnedPoolSize() <= 0L) {
            return;
        }
        this.logInfo(() -> "Initializing pinned memory pool (" + this.toMB(conf.pinnedPoolSize()) + " MB)");
        PinnedMemoryPool.initialize((long)conf.pinnedPoolSize(), (int)gpuId);
    }

    private Option<RapidsConf> allocatePinnedMemory$default$2() {
        return Option.empty();
    }

    /**
     * One-time, process-wide memory setup: RMM followed by the pinned host pool.
     *
     * <p>An unsynchronized volatile read skips the lock once the state is Initialized; the state
     * is re-read under the monitor before acting. If {@code gpuId} is empty, the first usable
     * device is found by probing.
     *
     * @throws IllegalStateException if a previous {@link #shutdown()} did not complete
     */
    public void initializeMemory(Option<Object> gpuId, Option<RapidsConf> rapidsConf) {
        if (Objects.equals(this.singletonMemoryInitialized(), Initialized$.MODULE$)) {
            return;
        }
        synchronized (this) {
            MemoryState state = this.singletonMemoryInitialized();
            if (Objects.equals(state, Errored$.MODULE$)) {
                throw new IllegalStateException("Cannot initialize memory due to previous shutdown failing");
            }
            if (Objects.equals(state, Uninitialized$.MODULE$)) {
                int gpu = gpuId.isDefined()
                    ? BoxesRunTime.unboxToInt(gpuId.get())
                    : this.findGpuAndAcquire();
                this.initializeRmm(gpu, rapidsConf);
                this.allocatePinnedMemory(gpu, rapidsConf);
                this.singletonMemoryInitialized_$eq(Initialized$.MODULE$);
            }
        }
    }

    public Option<RapidsConf> initializeMemory$default$2() {
        return Option.empty();
    }

    /**
     * Wraps {@code factory} so every thread it creates first selects the device recorded in
     * {@link #getDeviceId()}; the id is captured once, when this method is called.
     *
     * @throws IllegalStateException if no device id has been recorded yet
     */
    public ThreadFactory wrapThreadFactory(ThreadFactory factory) {
        Option<Object> recorded = this.getDeviceId();
        if (recorded.isEmpty()) {
            throw new IllegalStateException("Device ID is not set");
        }
        final int devId = BoxesRunTime.unboxToInt(recorded.get());
        return runnable -> factory.newThread(() -> {
            Cuda.setDevice((int)devId);
            runnable.run();
        });
    }

    /** Rounds {@code x} down to a multiple of 512 by clearing its low 9 bits. */
    private static long truncateToAlignment(long x) {
        return x & ~0x1FFL;
    }

    private GpuDeviceManager$() {
        MODULE$ = this;
        Logging.$init$((Logging)this);
        this.rmmTaskInitEnabled = Boolean.getBoolean("com.nvidia.spark.rapids.memory.gpu.rmm.init.task");
        this.threadGpuInitialized = new ThreadLocal<Object>();
        this.singletonMemoryInitialized = Uninitialized$.MODULE$;
        this.deviceId = Option.empty();
    }
}

