/*
 * Decompiled with CFR 0.152.
 */
package com.nvidia.spark.rapids;

import java.io.Serializable;
import org.apache.spark.sql.types.ArrayType;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.MapType;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import scala.Array$;
import scala.Function0;
import scala.Function1;
import scala.Function2;
import scala.MatchError;
import scala.Predef$;
import scala.Tuple2;
import scala.collection.Seq;
import scala.collection.TraversableOnce;
import scala.collection.immutable.IndexedSeq$;
import scala.collection.immutable.Nil$;
import scala.collection.mutable.ArrayBuffer;
import scala.collection.mutable.ArrayBuffer$;
import scala.collection.mutable.ArrayOps;
import scala.math.Numeric;
import scala.reflect.ClassTag$;
import scala.runtime.BoxesRunTime;
import scala.runtime.IntRef;
import scala.runtime.RichInt$;
import scala.runtime.RichLong$;
import scala.runtime.java8.JFunction1;
import scala.runtime.java8.JFunction2;

public final class GpuBatchUtils$ {
    public static GpuBatchUtils$ MODULE$;
    private final int VALIDITY_BUFFER_BOUNDARY_BYTES;
    private final int VALIDITY_BUFFER_BOUNDARY_ROWS;
    private final int OFFSET_BYTES;

    static {
        new GpuBatchUtils$();
    }

    public int VALIDITY_BUFFER_BOUNDARY_BYTES() {
        return this.VALIDITY_BUFFER_BOUNDARY_BYTES;
    }

    public int VALIDITY_BUFFER_BOUNDARY_ROWS() {
        return this.VALIDITY_BUFFER_BOUNDARY_ROWS;
    }

    public int OFFSET_BYTES() {
        return this.OFFSET_BYTES;
    }

    public int estimateRowCount(long desiredBatchSizeBytes, long currentBatchSize, long currentBatchRowCount) {
        Predef$.MODULE$.assert(currentBatchRowCount > 0L, (Function0 & Serializable & scala.Serializable)() -> "batch must contain at least one row");
        long targetRowCount = currentBatchSize > desiredBatchSizeBytes ? currentBatchRowCount : (currentBatchSize == 0L ? currentBatchRowCount : (long)((float)((double)desiredBatchSizeBytes / (double)currentBatchSize) * (float)currentBatchRowCount));
        return (int)RichLong$.MODULE$.min$extension(Predef$.MODULE$.longWrapper(targetRowCount), Integer.MAX_VALUE);
    }

    public long estimateGpuMemory(StructType schema, long rowCount) {
        return BoxesRunTime.unboxToLong((Object)((TraversableOnce)new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])schema.fields())).indices().map((Function1)(JFunction1.mcJI.sp & Serializable & scala.Serializable)x$1 -> MODULE$.estimateGpuMemory(schema, x$1, rowCount), IndexedSeq$.MODULE$.canBuildFrom())).sum((Numeric)Numeric.LongIsIntegral$.MODULE$));
    }

    public long estimateGpuMemory(StructType schema, int columnIndex, long rowCount) {
        StructField field = schema.fields()[columnIndex];
        return this.estimateGpuMemory(field.dataType(), field.nullable(), rowCount);
    }

    public long estimateGpuMemory(DataType dataType, boolean nullable, long rowCount) {
        long l;
        long validityBufferSize = nullable ? this.calculateValidityBufferSize(rowCount) : 0L;
        DataType dataType2 = dataType;
        DataType dataType3 = DataTypes.BinaryType;
        DataType dataType4 = dataType2;
        if (!(dataType3 != null ? !dataType3.equals(dataType4) : dataType4 != null)) {
            DataType dataType5 = dataType2;
            long offsetBufferSize = this.calculateOffsetBufferSize(rowCount);
            long dataSize = (long)dataType5.defaultSize() * rowCount;
            l = dataSize + offsetBufferSize;
        } else {
            DataType dataType6 = DataTypes.StringType;
            DataType dataType7 = dataType2;
            if (!(dataType6 != null ? !dataType6.equals(dataType7) : dataType7 != null)) {
                DataType dataType8 = dataType2;
                long offsetBufferSize = this.calculateOffsetBufferSize(rowCount);
                long dataSize = (long)dataType8.defaultSize() * rowCount;
                l = dataSize + offsetBufferSize;
            } else if (dataType2 instanceof MapType) {
                MapType mapType = (MapType)dataType2;
                l = this.calculateOffsetBufferSize(rowCount) + this.estimateGpuMemory(mapType.keyType(), false, rowCount) + this.estimateGpuMemory(mapType.valueType(), mapType.valueContainsNull(), rowCount);
            } else if (dataType2 instanceof ArrayType) {
                ArrayType arrayType = (ArrayType)dataType2;
                l = this.calculateOffsetBufferSize(rowCount) + this.estimateGpuMemory(arrayType.elementType(), arrayType.containsNull(), rowCount);
            } else if (dataType2 instanceof StructType) {
                StructType structType = (StructType)dataType2;
                l = BoxesRunTime.unboxToLong((Object)new ArrayOps.ofLong(Predef$.MODULE$.longArrayOps((long[])new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])structType.fields())).map((Function1 & Serializable & scala.Serializable)f -> BoxesRunTime.boxToLong((long)GpuBatchUtils$.MODULE$.estimateGpuMemory(f.dataType(), f.nullable(), rowCount)), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.Long())))).sum((Numeric)Numeric.LongIsIntegral$.MODULE$));
            } else {
                l = (long)dataType2.defaultSize() * rowCount;
            }
        }
        long dataSize = l;
        return dataSize + validityBufferSize;
    }

    public long calculateValidityBufferSize(long rows) {
        return this.roundToBoundary((rows + 7L) / 8L, 64);
    }

    public long calculateOffsetBufferSize(long rows) {
        return (rows + 1L) * 4L;
    }

    public int[] generateSplitIndices(long rows, int numSplits) {
        Predef$.MODULE$.require(rows > 0L, (Function0 & Serializable & scala.Serializable)() -> new StringBuilder(19).append("invalid input rows ").append(rows).toString());
        Predef$.MODULE$.require(numSplits > 0, (Function0 & Serializable & scala.Serializable)() -> new StringBuilder(18).append("invalid numSplits ").append(numSplits).toString());
        int baseIncrement = (int)(rows / (long)numSplits);
        IntRef extraIncrements = IntRef.create((int)((int)(rows % (long)numSplits)));
        ArrayBuffer indicesBuf = (ArrayBuffer)ArrayBuffer$.MODULE$.apply((Seq)Nil$.MODULE$);
        RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(1), numSplits).foldLeft((Object)BoxesRunTime.boxToInteger((int)0), (Function2)(JFunction2.mcIII.sp & Serializable & scala.Serializable)(x0$1, x1$1) -> {
            int n;
            Tuple2.mcII.sp sp2 = new Tuple2.mcII.sp(x0$1, x1$1);
            if (sp2 != null) {
                int last = sp2._1$mcI$sp();
                if (extraIncrements$1.elem > 0) {
                    --extraIncrements$1.elem;
                    n = last + baseIncrement + 1;
                } else {
                    n = last + baseIncrement;
                }
            } else {
                throw new MatchError((Object)sp2);
            }
            int current = n;
            indicesBuf.$plus$eq((Object)BoxesRunTime.boxToInteger((int)current));
            int n2 = current;
            return n2;
        });
        return (int[])indicesBuf.toArray(ClassTag$.MODULE$.Int());
    }

    public boolean isVariableWidth(DataType dt) {
        return !this.isFixedWidth(dt);
    }

    public boolean isFixedWidth(DataType dt) {
        boolean bl;
        DataType dataType = dt;
        DataType dataType2 = DataTypes.StringType;
        DataType dataType3 = dataType;
        if (!(dataType2 != null ? !dataType2.equals(dataType3) : dataType3 != null)) {
            bl = true;
        } else {
            DataType dataType4 = DataTypes.BinaryType;
            DataType dataType5 = dataType;
            bl = !(dataType4 != null ? !dataType4.equals(dataType5) : dataType5 != null);
        }
        boolean bl2 = bl ? false : (dataType instanceof ArrayType ? false : (dataType instanceof StructType ? false : !(dataType instanceof MapType)));
        return bl2;
    }

    private long roundToBoundary(long bytes, int boundary) {
        long remainder = bytes % (long)boundary;
        return remainder > 0L ? bytes + (long)boundary - remainder : bytes;
    }

    private GpuBatchUtils$() {
        MODULE$ = this;
        this.VALIDITY_BUFFER_BOUNDARY_BYTES = 64;
        this.VALIDITY_BUFFER_BOUNDARY_ROWS = this.VALIDITY_BUFFER_BOUNDARY_BYTES() * 8;
        this.OFFSET_BYTES = 4;
    }
}

