package com.nvidia.spark.rapids;

import ai.rapids.cudf.NvtxColor;
import ai.rapids.cudf.NvtxRange;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Semaphore;
import org.apache.commons.lang3.mutable.MutableInt;
import org.apache.spark.TaskContext;
import org.apache.spark.internal.Logging;
import org.slf4j.Logger;
import scala.Function0;
import scala.Function1;
import scala.Option;
import scala.Predef$;
import scala.collection.Seq;
import scala.collection.mutable.ArrayBuffer;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

/* compiled from: GpuSemaphore.scala */
@ScalaSignature(bytes = "\u0006\u0001\u00055r!\u0002\r\u001a\u0011\u0003\u0011c!\u0002\u0013\u001a\u0011\u0003)\u0003\"\u0002\u0017\u0002\t\u0003i\u0003b\u0002\u0018\u0002\u0005\u0004%Ia\f\u0005\u0007g\u0005\u0001\u000b\u0011\u0002\u0019\t\u0013Q\n\u0001\u0019!a\u0001\n\u0013)\u0004bCA\u0005\u0003\u0001\u0007\t\u0019!C\u0005\u0003\u0017A!\"!\u0005\u0002\u0001\u0004\u0005\t\u0015)\u00037\u0011\u0019\tY\"\u0001C\u0005k!9\u0011QD\u0001\u0005\u0002\u0005}\u0001B\u00027\u0002\t\u0003\t\u0019\u0003\u0003\u0004}\u0003\u0011\u0005\u0011\u0011\u0006\u0005\b\u0003\u000b\tA\u0011AA\u0004\r\u0011!\u0013DB\u001c\t\u0011\u0019k!\u0011!Q\u0001\n\u001dCQ\u0001L\u0007\u0005\u0002)Cq\u0001T\u0007C\u0002\u0013%Q\n\u0003\u0004Y\u001b\u0001\u0006IA\u0014\u0005\b36\u0011\r\u0011\"\u0003[\u0011\u0019YW\u0002)A\u00057\")A.\u0004C\u0001[\")A0\u0004C\u0001{\"1q0\u0004C\u0001\u0003\u0003Aq!!\u0002\u000e\t\u0003\t9!\u0001\u0007HaV\u001cV-\\1qQ>\u0014XM\u0003\u0002\u001b7\u00051!/\u00199jINT!\u0001H\u000f\u0002\u000bM\u0004\u0018M]6\u000b\u0005yy\u0012A\u00028wS\u0012L\u0017MC\u0001!\u0003\r\u0019w.\\\u0002\u0001!\t\u0019\u0013!D\u0001\u001a\u000519\u0005/^*f[\u0006\u0004\bn\u001c:f'\t\ta\u0005\u0005\u0002(U5\t\u0001FC\u0001*\u0003\u0015\u00198-\u00197b\u0013\tY\u0003F\u0001\u0004B]f\u0014VMZ\u0001\u0007y%t\u0017\u000e\u001e 
\u0015\u0003\t\nq!\u001a8bE2,G-F\u00011!\t9\u0013'\u0003\u00023Q\t9!i\\8mK\u0006t\u0017\u0001C3oC\ndW\r\u001a\u0011\u0002\u0011%t7\u000f^1oG\u0016,\u0012A\u000e\t\u0003G5\u0019B!\u0004\u00149\u0007B\u0011\u0011(Q\u0007\u0002u)\u00111\bP\u0001\tS:$XM\u001d8bY*\u0011A$\u0010\u0006\u0003}}\na!\u00199bG\",'\"\u0001!\u0002\u0007=\u0014x-\u0003\u0002Cu\t9Aj\\4hS:<\u0007CA\u0012E\u0013\t)\u0015DA\u0002Be6\f1\u0002^1tWN\u0004VM]$qkB\u0011q\u0005S\u0005\u0003\u0013\"\u00121!\u00138u)\t14\nC\u0003G\u001f\u0001\u0007q)A\u0005tK6\f\u0007\u000f[8sKV\ta\n\u0005\u0002P-6\t\u0001K\u0003\u0002R%\u0006Q1m\u001c8dkJ\u0014XM\u001c;\u000b\u0005M#\u0016\u0001B;uS2T\u0011!V\u0001\u0005U\u00064\u0018-\u0003\u0002X!\nI1+Z7ba\"|'/Z\u0001\u000bg\u0016l\u0017\r\u001d5pe\u0016\u0004\u0013aC1di&4X\rV1tWN,\u0012a\u0017\t\u0005\u001frs\u0016-\u0003\u0002^!\n\t2i\u001c8dkJ\u0014XM\u001c;ICNDW*\u00199\u0011\u0005\u001dz\u0016B\u00011)\u0005\u0011auN\\4\u0011\u0005\tLW\"A2\u000b\u0005\u0011,\u0017aB7vi\u0006\u0014G.\u001a\u0006\u0003M\u001e\fQ\u0001\\1oONR!\u0001[\u001f\u0002\u000f\r|W.\\8og&\u0011!n\u0019\u0002\u000b\u001bV$\u0018M\u00197f\u0013:$\u0018\u0001D1di&4X\rV1tWN\u0004\u0013AE1dcVL'/Z%g\u001d\u0016\u001cWm]:bef$2A\\9x!\t9s.\u0003\u0002qQ\t!QK\\5u\u0011\u0015\u0011H\u00031\u0001t\u0003\u001d\u0019wN\u001c;fqR\u0004\"\u0001^;\u000e\u0003qJ!A\u001e\u001f\u0003\u0017Q\u000b7o[\"p]R,\u0007\u0010\u001e\u0005\u0006qR\u0001\r!_\u0001\u000bo\u0006LG/T3ue&\u001c\u0007CA\u0012{\u0013\tY\u0018DA\u0005HaVlU\r\u001e:jG\u0006\u0011\"/\u001a7fCN,\u0017J\u001a(fG\u0016\u001c8/\u0019:z)\tqg\u0010C\u0003s+\u0001\u00071/\u0001\u0007d_6\u0004H.\u001a;f)\u0006\u001c8\u000eF\u0002o\u0003\u0007AQA\u001d\fA\u0002M\f\u0001b\u001d5vi\u0012|wO\u001c\u000b\u0002]\u0006a\u0011N\\:uC:\u001cWm\u0018\u0013fcR\u0019a.!\u0004\t\u0011\u0005=a!!AA\u0002Y\n1\u0001\u001f\u00132\u0003%Ign\u001d;b]\u000e,\u0007\u0005K\u0002\b\u0003+\u00012aJA\f\u0013\r\tI\u0002\u000b\u0002\tm>d\u0017\r^5mK\u0006Yq-\u001a;J]N$\u0018M\\2f\u0003)Ig.
\u001b;jC2L'0\u001a\u000b\u0004]\u0006\u0005\u0002\"\u0002$\n\u0001\u00049E#\u00028\u0002&\u0005\u001d\u0002\"\u0002:\u000b\u0001\u0004\u0019\b\"\u0002=\u000b\u0001\u0004IHc\u00018\u0002,!)!o\u0003a\u0001g\u0002")
/* loaded from: input_file:com/nvidia/spark/rapids/GpuSemaphore.class */
public final class GpuSemaphore implements Logging, Arm {
    /** Counting semaphore; its permit count is the value passed to the constructor. */
    private final Semaphore semaphore;
    /**
     * Task attempt id (boxed {@code Long}) -> acquisition count. A positive count means the task
     * currently holds one permit of {@link #semaphore}.
     */
    private final ConcurrentHashMap<Object, MutableInt> activeTasks;
    /** Backing field read/written by the {@link Logging} trait accessors below. */
    private transient Logger org$apache$spark$internal$Logging$$log_;

    /**
     * Delegates to the companion object's {@code initialize}; presumably creates the shared
     * instance with {@code i} permits (NOTE(review): confirm against GpuSemaphore$).
     */
    public static void initialize(int i) {
        GpuSemaphore$.MODULE$.initialize(i);
    }

    // ---------------------------------------------------------------------------------------
    // Arm trait forwarders. Each one delegates to the trait's static implementation method
    // (Arm.name$), mirroring the Logging forwarders below. Calling the unqualified method name
    // here would recurse into the forwarder itself and overflow the stack.
    // The (V) casts are no-ops if the static methods carry generic signatures, and required if
    // they are erased to Object.
    // ---------------------------------------------------------------------------------------

    @Override // com.nvidia.spark.rapids.Arm
    public <T extends AutoCloseable, V> V withResource(T t, Function1<T, V> function1) {
        return (V) Arm.withResource$(this, t, function1);
    }

    @Override // com.nvidia.spark.rapids.Arm
    public <T extends AutoCloseable, V> V withResource(Option<T> option, Function1<Option<T>, V> function1) {
        return (V) Arm.withResource$(this, option, function1);
    }

    @Override // com.nvidia.spark.rapids.Arm
    public <T extends AutoCloseable, V> V withResource(Seq<T> seq, Function1<Seq<T>, V> function1) {
        return (V) Arm.withResource$(this, seq, function1);
    }

    @Override // com.nvidia.spark.rapids.Arm
    public <T extends AutoCloseable, V> V withResource(T[] tArr, Function1<T[], V> function1) {
        return (V) Arm.withResource$(this, tArr, function1);
    }

    @Override // com.nvidia.spark.rapids.Arm
    public <T extends AutoCloseable, V> V withResource(ArrayBuffer<T> arrayBuffer, Function1<ArrayBuffer<T>, V> function1) {
        return (V) Arm.withResource$(this, arrayBuffer, function1);
    }

    @Override // com.nvidia.spark.rapids.Arm
    public <T, V> V withResourceIfAllowed(T t, Function1<T, V> function1) {
        return (V) Arm.withResourceIfAllowed$(this, t, function1);
    }

    @Override // com.nvidia.spark.rapids.Arm
    public <T extends AutoCloseable, V> V closeOnExcept(T t, Function1<T, V> function1) {
        return (V) Arm.closeOnExcept$(this, t, function1);
    }

    @Override // com.nvidia.spark.rapids.Arm
    public <T extends AutoCloseable, V> V closeOnExcept(Seq<T> seq, Function1<Seq<T>, V> function1) {
        return (V) Arm.closeOnExcept$(this, seq, function1);
    }

    @Override // com.nvidia.spark.rapids.Arm
    public <T extends AutoCloseable, V> V closeOnExcept(T[] tArr, Function1<T[], V> function1) {
        return (V) Arm.closeOnExcept$(this, tArr, function1);
    }

    @Override // com.nvidia.spark.rapids.Arm
    public <T extends AutoCloseable, V> V closeOnExcept(ArrayBuffer<T> arrayBuffer, Function1<ArrayBuffer<T>, V> function1) {
        return (V) Arm.closeOnExcept$(this, arrayBuffer, function1);
    }

    @Override // com.nvidia.spark.rapids.Arm
    public <T extends AutoCloseable, V> V closeOnExcept(Option<T> option, Function1<Option<T>, V> function1) {
        return (V) Arm.closeOnExcept$(this, option, function1);
    }

    @Override // com.nvidia.spark.rapids.Arm
    public <T extends RapidsBuffer, V> V freeOnExcept(T t, Function1<T, V> function1) {
        return (V) Arm.freeOnExcept$(this, t, function1);
    }

    @Override // com.nvidia.spark.rapids.Arm
    public <T extends AutoCloseable, V> V withResource(CloseableHolder<T> closeableHolder, Function1<CloseableHolder<T>, V> function1) {
        return (V) Arm.withResource$(this, closeableHolder, function1);
    }

    // ---------------------------------------------------------------------------------------
    // Logging trait forwarders (delegate to the trait's static implementation methods).
    // ---------------------------------------------------------------------------------------

    public String logName() {
        return Logging.logName$(this);
    }

    public Logger log() {
        return Logging.log$(this);
    }

    public void logInfo(Function0<String> function0) {
        Logging.logInfo$(this, function0);
    }

    public void logDebug(Function0<String> function0) {
        Logging.logDebug$(this, function0);
    }

    public void logTrace(Function0<String> function0) {
        Logging.logTrace$(this, function0);
    }

    public void logWarning(Function0<String> function0) {
        Logging.logWarning$(this, function0);
    }

    public void logError(Function0<String> function0) {
        Logging.logError$(this, function0);
    }

    public void logInfo(Function0<String> function0, Throwable th) {
        Logging.logInfo$(this, function0, th);
    }

    public void logDebug(Function0<String> function0, Throwable th) {
        Logging.logDebug$(this, function0, th);
    }

    public void logTrace(Function0<String> function0, Throwable th) {
        Logging.logTrace$(this, function0, th);
    }

    public void logWarning(Function0<String> function0, Throwable th) {
        Logging.logWarning$(this, function0, th);
    }

    public void logError(Function0<String> function0, Throwable th) {
        Logging.logError$(this, function0, th);
    }

    public boolean isTraceEnabled() {
        return Logging.isTraceEnabled$(this);
    }

    public void initializeLogIfNecessary(boolean z) {
        Logging.initializeLogIfNecessary$(this, z);
    }

    public boolean initializeLogIfNecessary(boolean z, boolean z2) {
        return Logging.initializeLogIfNecessary$(this, z, z2);
    }

    public boolean initializeLogIfNecessary$default$2() {
        return Logging.initializeLogIfNecessary$default$2$(this);
    }

    public void initializeForcefully(boolean z, boolean z2) {
        Logging.initializeForcefully$(this, z, z2);
    }

    public Logger org$apache$spark$internal$Logging$$log_() {
        return this.org$apache$spark$internal$Logging$$log_;
    }

    public void org$apache$spark$internal$Logging$$log__$eq(Logger logger) {
        this.org$apache$spark$internal$Logging$$log_ = logger;
    }

    private Semaphore semaphore() {
        return this.semaphore;
    }

    private ConcurrentHashMap<Object, MutableInt> activeTasks() {
        return this.activeTasks;
    }

    /**
     * Acquires a GPU permit for the task unless its acquisition count is already positive.
     * The work runs inside an NVTX range named "Acquire GPU", built with {@code gpuMetric}
     * (presumably so the wait time is recorded in that metric — confirm in NvtxWithMetrics).
     *
     * @param taskContext the task requesting the GPU
     * @param gpuMetric metric handed to {@code NvtxWithMetrics}
     */
    public void acquireIfNecessary(TaskContext taskContext, GpuMetric gpuMetric) {
        withResource(
            new NvtxWithMetrics("Acquire GPU", NvtxColor.RED, Predef$.MODULE$.wrapRefArray(new GpuMetric[]{gpuMetric})),
            nvtxWithMetrics -> {
                $anonfun$acquireIfNecessary$1(this, taskContext, nvtxWithMetrics);
                return BoxedUnit.UNIT;
            });
    }

    /**
     * Decrements the task's acquisition count and, when it reaches zero, returns the permit to
     * the semaphore. Unknown tasks and tasks with a zero count are ignored.
     *
     * @param taskContext the task giving up the GPU
     */
    public void releaseIfNecessary(TaskContext taskContext) {
        NvtxRange nvtxRange = new NvtxRange("Release GPU", NvtxColor.RED);
        try {
            long taskAttemptId = taskContext.taskAttemptId();
            MutableInt refs = activeTasks().get(taskAttemptId);
            if (refs != null && refs.intValue() > 0 && refs.decrementAndGet() == 0) {
                logDebug(() -> "Task " + taskAttemptId + " releasing GPU");
                semaphore().release();
            }
        } finally {
            nvtxRange.close();
        }
    }

    /**
     * Forgets the task and, if it still holds a permit, releases it. Registered as a task
     * completion listener the first time a task acquires the GPU.
     *
     * @param taskContext the completed task
     * @throws IllegalStateException if the task was never registered
     */
    public void completeTask(TaskContext taskContext) {
        long taskAttemptId = taskContext.taskAttemptId();
        MutableInt refs = activeTasks().remove(taskAttemptId);
        if (refs == null) {
            throw new IllegalStateException("Completion of unknown task " + taskAttemptId);
        }
        if (refs.intValue() > 0) {
            logDebug(() -> "Task " + taskAttemptId + " releasing GPU");
            semaphore().release();
        }
    }

    /**
     * Logs (at debug level) how many tasks are still registered. Does not release permits or
     * clear any state.
     */
    public void shutdown() {
        if (activeTasks().isEmpty()) {
            return;
        }
        logDebug(() -> "shutting down with " + this.activeTasks().size() + " tasks still registered");
    }

    /**
     * Body of the {@link #acquireIfNecessary} closure, kept public static because it is the
     * compiler-generated entry point. Blocks on the semaphore when the task has no positive
     * acquisition count. On a task's first acquisition, it registers the task and a completion
     * listener that calls {@link #completeTask}. It then calls
     * {@code GpuDeviceManager.initializeFromTask()}.
     */
    public static final /* synthetic */ void $anonfun$acquireIfNecessary$1(GpuSemaphore gpuSemaphore, TaskContext taskContext, NvtxWithMetrics nvtxWithMetrics) {
        long taskAttemptId = taskContext.taskAttemptId();
        MutableInt refs = gpuSemaphore.activeTasks().get(taskAttemptId);
        if (refs != null && refs.intValue() != 0) {
            return; // already holds a permit
        }
        gpuSemaphore.logDebug(() -> "Task " + taskAttemptId + " acquiring GPU");
        try {
            gpuSemaphore.semaphore().acquire();
        } catch (InterruptedException e) {
            // Scala has no checked exceptions: the original code let this propagate unchanged.
            throw GpuSemaphore.<RuntimeException>rethrowUnchecked(e);
        }
        if (refs != null) {
            refs.increment();
        } else {
            // First time this task has been seen.
            gpuSemaphore.activeTasks().put(taskAttemptId, new MutableInt(1));
            taskContext.addTaskCompletionListener(completed -> {
                gpuSemaphore.completeTask(completed);
                return BoxedUnit.UNIT;
            });
        }
        GpuDeviceManager$.MODULE$.initializeFromTask();
    }

    /**
     * Rethrows {@code t} as-is, without wrapping, bypassing Java's checked-exception analysis.
     * The return type only lets callers write {@code throw rethrowUnchecked(e)}.
     */
    @SuppressWarnings("unchecked")
    private static <E extends Throwable> E rethrowUnchecked(Throwable t) throws E {
        throw (E) t;
    }

    /**
     * @param i number of permits for the GPU semaphore
     */
    public GpuSemaphore(int i) {
        Logging.$init$(this);
        Arm.$init$(this);
        this.semaphore = new Semaphore(i);
        this.activeTasks = new ConcurrentHashMap<>();
    }
}
