package com.nvidia.spark.rapids;

import org.apache.spark.sql.catalyst.expressions.Attribute;
import org.apache.spark.sql.catalyst.expressions.AttributeSet;
import org.apache.spark.sql.catalyst.expressions.If;
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression;
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction;
import org.apache.spark.sql.catalyst.expressions.aggregate.First;
import org.apache.spark.sql.types.DataType;
import scala.collection.IterableLike;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.TraversableOnce;
import scala.collection.immutable.$colon;
import scala.collection.immutable.List;
import scala.collection.immutable.Nil$;
import scala.math.Numeric$LongIsIntegral$;
import scala.runtime.BoxesRunTime;

/* compiled from: aggregate.scala */
/* loaded from: input_file:com/nvidia/spark/rapids/AggregateUtils$.class */
/**
 * Singleton holder for aggregate-related helper routines (Java view of a Scala {@code object}).
 *
 * <p>The single instance is created by the static initializer and published through
 * {@link #MODULE$}. The {@code $anonfun$...} members are compiler-generated lambda bodies; they
 * are kept public and static with their original names so existing references keep resolving.
 */
public final class AggregateUtils$ {
    /** The one instance of this class; assigned inside the private constructor. */
    public static AggregateUtils$ MODULE$;

    /** Name fragments that {@link #validateAggregate(AttributeSet)} searches attribute names for. */
    private final List<String> aggs;

    static {
        // Instantiating the class is what populates MODULE$ (see the constructor).
        new AggregateUtils$();
    }

    /**
     * Publishes this instance as {@link #MODULE$} and then builds the immutable list
     * {@code min, max, avg, sum, count, first, last}.
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private AggregateUtils$() {
        MODULE$ = this;
        String[] names = {"min", "max", "avg", "sum", "count", "first", "last"};
        // Prepend from the back so that the finished list starts with "min" and ends with "last".
        List chain = Nil$.MODULE$;
        for (int i = names.length - 1; i >= 0; i--) {
            chain = new $colon.colon(names[i], chain);
        }
        this.aggs = chain;
    }

    /** @return the list of name fragments held by this instance */
    private List<String> aggs() {
        return this.aggs;
    }

    /**
     * Checks whether any attribute in the set has a name containing one of the fragments
     * in {@link #aggs()}.
     *
     * @param attributeSet attributes to inspect
     * @return {@code true} if at least one attribute name contains at least one fragment
     */
    public boolean validateAggregate(AttributeSet attributeSet) {
        Seq<Attribute> attributes = attributeSet.toSeq();
        return attributes.exists(attribute -> BoxesRunTime.boxToBoolean($anonfun$validateAggregate$1(attribute)));
    }

    /**
     * Checks whether any of the given aggregate expressions wraps an aggregate function that
     * satisfies {@link #$anonfun$shouldFallbackMultiDistinct$2(AggregateFunction)}.
     *
     * @param seq aggregate expressions to inspect
     * @return {@code true} if at least one extracted aggregate function matches
     */
    public boolean shouldFallbackMultiDistinct(Seq<AggregateExpression> seq) {
        IterableLike functions = (IterableLike) seq.map(
                aggregateExpression -> aggregateExpression.aggregateFunction(),
                Seq$.MODULE$.canBuildFrom());
        return functions.exists(aggregateFunction -> BoxesRunTime.boxToBoolean($anonfun$shouldFallbackMultiDistinct$2(aggregateFunction)));
    }

    /**
     * Derives a batch size from a base size and two lists of data types.
     *
     * <p>Let {@code a} be the summed single-row estimate for {@code seq} and {@code b} the same
     * for {@code seq2}. The budget is {@code 4 * j}, reduced by {@code b} when {@code z} is true.
     * The per-row cost is {@code a + 16}, increased by {@code b} when {@code z} is false. The
     * result is {@code a * (budget / perRowCost)} using integer division, capped at
     * {@link Integer#MAX_VALUE}.
     *
     * @param j    base size the budget is derived from
     * @param seq  first list of data types (contributes {@code a})
     * @param seq2 second list of data types (contributes {@code b})
     * @param z    selects whether {@code b} is subtracted from the budget or added to the per-row cost
     * @return the computed size, never greater than {@link Integer#MAX_VALUE}
     */
    public long computeTargetBatchSize(long j, Seq<DataType> seq, Seq<DataType> seq2, boolean z) {
        long firstSize = typesToSize$1(seq);
        long secondSize = typesToSize$1(seq2);
        long budget = z ? (4 * j) - secondSize : 4 * j;
        // NOTE(review): the constant 16 is presumably a fixed per-row overhead — confirm against
        // the original Scala source.
        long perRowCost = z ? firstSize + 16 : firstSize + 16 + secondSize;
        long rows = budget / perRowCost;
        return Math.min(firstSize * rows, (long) Integer.MAX_VALUE);
    }

    /** Lambda body: does the attribute's name contain {@code str}? */
    public static final /* synthetic */ boolean $anonfun$validateAggregate$2(Attribute attribute, String str) {
        String attributeName = attribute.name();
        return attributeName.contains(str);
    }

    /** Lambda body: does the attribute's name contain any fragment from {@link #aggs()}? */
    public static final /* synthetic */ boolean $anonfun$validateAggregate$1(Attribute attribute) {
        List<String> fragments = MODULE$.aggs();
        return fragments.exists(str -> BoxesRunTime.boxToBoolean($anonfun$validateAggregate$2(attribute, str)));
    }

    /**
     * Lambda body: matches a {@link First} whose child is an {@link If} and whose referenced
     * attributes pass {@link #validateAggregate(AttributeSet)}. The checks run in that order.
     */
    public static final /* synthetic */ boolean $anonfun$shouldFallbackMultiDistinct$2(AggregateFunction aggregateFunction) {
        if (!(aggregateFunction instanceof First)) {
            return false;
        }
        if (!(((First) aggregateFunction).child() instanceof If)) {
            return false;
        }
        return MODULE$.validateAggregate(aggregateFunction.references());
    }

    /**
     * Lambda body: delegates to {@code GpuBatchUtils$.estimateGpuMemory} with the fixed
     * arguments {@code false} and {@code 1L}.
     */
    public static final /* synthetic */ long $anonfun$computeTargetBatchSize$1(DataType dataType) {
        return GpuBatchUtils$.MODULE$.estimateGpuMemory(dataType, false, 1L);
    }

    /** Sums {@link #$anonfun$computeTargetBatchSize$1(DataType)} over every element of {@code seq}. */
    private static final long typesToSize$1(Seq seq) {
        TraversableOnce perTypeSizes = (TraversableOnce) seq.map(
                dataType -> BoxesRunTime.boxToLong($anonfun$computeTargetBatchSize$1(dataType)),
                Seq$.MODULE$.canBuildFrom());
        Object total = perTypeSizes.sum(Numeric$LongIsIntegral$.MODULE$);
        return BoxesRunTime.unboxToLong(total);
    }
}
