package com.nvidia.spark.rapids;

import org.apache.spark.sql.catalyst.expressions.Alias;
import org.apache.spark.sql.catalyst.expressions.Alias$;
import org.apache.spark.sql.catalyst.expressions.AttributeReference;
import org.apache.spark.sql.catalyst.expressions.Expression;
import org.apache.spark.sql.catalyst.expressions.NamedExpression;
import org.apache.spark.sql.catalyst.expressions.NamedExpression$;
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression;
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateMode;
import org.apache.spark.sql.catalyst.expressions.aggregate.Final$;
import org.apache.spark.sql.catalyst.expressions.aggregate.PartialMerge$;
import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate;
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
import org.apache.spark.sql.catalyst.trees.TreeNodeTag;
import org.apache.spark.sql.execution.SparkPlan;
import org.apache.spark.sql.rapids.CpuToGpuAggregateBufferConverter;
import org.apache.spark.sql.rapids.GpuToCpuAggregateBufferConverter;
import org.apache.spark.sql.types.DataType;
import scala.MatchError;
import scala.None$;
import scala.Option;
import scala.Predef$;
import scala.Some;
import scala.Tuple2;
import scala.collection.IterableLike;
import scala.collection.LinearSeqOptimized;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.immutable.$colon;
import scala.collection.immutable.IndexedSeq;
import scala.collection.immutable.IndexedSeq$;
import scala.collection.immutable.List;
import scala.collection.immutable.List$;
import scala.collection.immutable.Nil$;
import scala.collection.mutable.Queue;
import scala.collection.mutable.Queue$;
import scala.package$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.util.Either;
import scala.util.Left;
import scala.util.Right;

/* compiled from: aggregate.scala */
/* loaded from: input_file:com/nvidia/spark/rapids/GpuTypedImperativeSupportedAggregateExecMeta$.class */
/**
 * Decompiled Java view of the Scala companion object for
 * {@code GpuTypedImperativeSupportedAggregateExecMeta}.
 *
 * <p>It inspects all stages of an aggregation for places where a stage that can be replaced
 * (run on the GPU) is adjacent to one that cannot. Where both stages contain a
 * {@link TypedImperativeAggregate}, it either:
 * <ul>
 *   <li>tags the plan nodes at that boundary with aggregation-buffer converter projections, or</li>
 *   <li>forces every replaceable stage back to the CPU when a runtime buffer transition is not
 *       available.</li>
 * </ul>
 *
 * <p>NOTE(review): this is decompiler output, and several statements would not compile as plain
 * Java. For example, an {@code Either} is assigned to a {@code Left} variable, and a
 * {@code NamedExpression} is assigned to an {@code Alias} variable. The {@code BoxedUnit} locals
 * are artifacts with no effect. The {@code $anonfun$...} methods are the compiled bodies of Scala
 * lambdas and are invoked by name from those lambdas, so keep their names as-is.
 */
public final class GpuTypedImperativeSupportedAggregateExecMeta$ {
    /** The singleton instance; assigned by the private constructor, which the static initializer runs. */
    public static GpuTypedImperativeSupportedAggregateExecMeta$ MODULE$;
    /**
     * Tree-node tag named "rapids.gpu.bufferConverterInjected". It is set to {@code true} on an
     * aggregate plan once buffer handling has run for it, so the handling is not repeated.
     */
    private final TreeNodeTag<Object> bufferConverterInjected;

    static {
        // Creating the instance assigns MODULE$ (see the constructor).
        new GpuTypedImperativeSupportedAggregateExecMeta$();
    }

    /** Accessor for the {@code bufferConverterInjected} tag. */
    private TreeNodeTag<Object> bufferConverterInjected() {
        return this.bufferConverterInjected;
    }

    /**
     * Handles TypedImperativeAggregate buffers at GPU/CPU boundaries between aggregate stages
     * (Scala name: {@code handleAggregationBuffer}).
     *
     * <p>It does nothing unless both of these hold:
     * <ul>
     *   <li>the meta contains a TypedImperativeAggregate in {@code Final} mode, and</li>
     *   <li>its wrapped plan is not already tagged with {@code bufferConverterInjected = true}.</li>
     * </ul>
     *
     * <p>Otherwise it tags the plan and gathers the metas of all aggregate stages. It then flags
     * each stage {@code i} (never the last) for which:
     * <ul>
     *   <li>exactly one of stages {@code i} and {@code i + 1} can be replaced, and</li>
     *   <li>both stages contain a TypedImperativeAggregate in any mode.</li>
     * </ul>
     *
     * <p>The outcome depends on the flags:
     * <ul>
     *   <li>No stage flagged: returns without further changes.</li>
     *   <li>Every flagged stage reports {@code availableRuntimeDataTransition()}: binds the buffer
     *       converters.</li>
     *   <li>Otherwise: marks every replaceable stage as not working on the GPU.</li>
     * </ul>
     *
     * @param gpuTypedImperativeSupportedAggregateExecMeta meta of the aggregate being tagged
     */
    public void com$nvidia$spark$rapids$GpuTypedImperativeSupportedAggregateExecMeta$$handleAggregationBuffer(GpuTypedImperativeSupportedAggregateExecMeta<?> gpuTypedImperativeSupportedAggregateExecMeta) {
        if (containTypedImperativeAggregate(gpuTypedImperativeSupportedAggregateExecMeta, new Some(Final$.MODULE$)) && !gpuTypedImperativeSupportedAggregateExecMeta.mo265agg().getTagValue(bufferConverterInjected()).contains(BoxesRunTime.boxToBoolean(true))) {
            // Mark the plan first so later invocations for the same plan are skipped.
            gpuTypedImperativeSupportedAggregateExecMeta.mo265agg().setTagValue(bufferConverterInjected(), BoxesRunTime.boxToBoolean(true));
            // NOTE(review): requires the plan to carry a logical link; Option.get() throws if it is absent.
            List<GpuBaseAggregateMeta<?>> aggregateOfAllStages = GpuBaseAggregateMeta$.MODULE$.getAggregateOfAllStages(gpuTypedImperativeSupportedAggregateExecMeta, (LogicalPlan) gpuTypedImperativeSupportedAggregateExecMeta.mo265agg().logicalLink().get());
            // Per-stage flag: true when stage i and stage i + 1 straddle a replaceable/non-replaceable
            // boundary and both contain a TypedImperativeAggregate. The last stage is always false.
            // (The switch with only a default branch is a decompiler artifact.)
            IndexedSeq indexedSeq = (IndexedSeq) aggregateOfAllStages.indices().map(i -> {
                switch (i) {
                    default:
                        return i != aggregateOfAllStages.length() - 1 && (((RapidsMeta) aggregateOfAllStages.apply(i)).canThisBeReplaced() ^ ((RapidsMeta) aggregateOfAllStages.apply(i + 1)).canThisBeReplaced()) && MODULE$.containTypedImperativeAggregate((GpuBaseAggregateMeta) aggregateOfAllStages.apply(i), MODULE$.containTypedImperativeAggregate$default$2()) && MODULE$.containTypedImperativeAggregate((GpuBaseAggregateMeta) aggregateOfAllStages.apply(i + 1), MODULE$.containTypedImperativeAggregate$default$2());
                }
            }, IndexedSeq$.MODULE$.canBuildFrom());
            // forall(!flag): no boundary needs handling.
            if (indexedSeq.forall(obj -> {
                return BoxesRunTime.boxToBoolean($anonfun$handleAggregationBuffer$2(BoxesRunTime.unboxToBoolean(obj)));
            })) {
                return;
            }
            // Every flagged stage must support the runtime buffer transition.
            if (!((LinearSeqOptimized) aggregateOfAllStages.zip(indexedSeq, List$.MODULE$.canBuildFrom())).forall(tuple2 -> {
                return BoxesRunTime.boxToBoolean($anonfun$handleAggregationBuffer$3(tuple2));
            })) {
                // At least one flagged stage cannot transition: push all replaceable stages to the CPU.
                aggregateOfAllStages.foreach(gpuBaseAggregateMeta -> {
                    $anonfun$handleAggregationBuffer$4(gpuBaseAggregateMeta);
                    return BoxedUnit.UNIT;
                });
            } else {
                // All flagged boundaries can be converted; install the converter projections.
                bindBufferConverters(aggregateOfAllStages, indexedSeq);
            }
        }
    }

    /**
     * Installs buffer-converter projections for every flagged stage.
     *
     * <p>For each index whose flag is true, the work is done by {@code $anonfun$bindBufferConverters$1}.
     *
     * @param seq metas of all aggregate stages
     * @param seq2 per-stage boolean flags, aligned by index with {@code seq}
     */
    private void bindBufferConverters(Seq<GpuBaseAggregateMeta<?>> seq, Seq<Object> seq2) {
        ((IterableLike) seq2.zipWithIndex(Seq$.MODULE$.canBuildFrom())).foreach(tuple2 -> {
            $anonfun$bindBufferConverters$1(seq, tuple2);
            return BoxedUnit.UNIT;
        });
    }

    /**
     * Reports whether the wrapped aggregate plan has a TypedImperativeAggregate matching a mode.
     *
     * <p>An aggregate expression matches when it is non-null, its mode is accepted by
     * {@code option} ({@code None} accepts every mode), and its function is a
     * {@link TypedImperativeAggregate}.
     *
     * @param gpuBaseAggregateMeta the aggregate stage to inspect
     * @param option the required aggregate mode, or {@code None} for any mode
     * @return true if at least one aggregate expression matches
     */
    private boolean containTypedImperativeAggregate(GpuBaseAggregateMeta<?> gpuBaseAggregateMeta, Option<AggregateMode> option) {
        return gpuBaseAggregateMeta.mo265agg().aggregateExpressions().exists(aggregateExpression -> {
            return BoxesRunTime.boxToBoolean($anonfun$containTypedImperativeAggregate$1(option, aggregateExpression));
        });
    }

    /** Default value for the {@code option} parameter of containTypedImperativeAggregate: {@code None} (any mode). */
    private Option<AggregateMode> containTypedImperativeAggregate$default$2() {
        return None$.MODULE$;
    }

    /**
     * Builds the projection that converts TypedImperativeAggregate buffers between two stages.
     *
     * <p>First it collects, in order, one converter per child expression of
     * {@code gpuBaseAggregateMeta} whose single child is a {@code TypedImperativeAggExprMeta} and
     * whose aggregate mode is {@code Final} or {@code PartialMerge}. The converter is CPU-to-GPU
     * ({@code Left}) when {@code z} is true and GPU-to-CPU ({@code Right}) otherwise.
     *
     * <p>It then maps the result expressions of {@code gpuBaseAggregateMeta2}:
     * <ul>
     *   <li>If the type meta reports {@code typeConverted()}, the expression is wrapped by the next
     *       converter. It is aliased as {@code <name>_converted} with a fresh expression id.</li>
     *   <li>Otherwise it is passed through unchanged.</li>
     * </ul>
     *
     * <p>NOTE(review): this assumes the number and order of type-converted result expressions
     * match the collected converters, because {@code Queue.dequeue()} fails on an empty queue.
     * Please confirm.
     *
     * @param gpuBaseAggregateMeta stage whose aggregate expressions supply the converters
     * @param gpuBaseAggregateMeta2 stage whose result expressions are converted
     * @param z true for CPU-to-GPU converters, false for GPU-to-CPU converters
     * @return one named expression per result expression of {@code gpuBaseAggregateMeta2}
     */
    private Seq<NamedExpression> createBufferConverter(GpuBaseAggregateMeta<?> gpuBaseAggregateMeta, GpuBaseAggregateMeta<?> gpuBaseAggregateMeta2, boolean z) {
        Queue apply = Queue$.MODULE$.apply(Nil$.MODULE$);
        // Fill the queue with converters in child-expression order.
        gpuBaseAggregateMeta.childExprs().foreach(baseExprMeta -> {
            $anonfun$createBufferConverter$1(z, apply, baseExprMeta);
            return BoxedUnit.UNIT;
        });
        return (Seq) gpuBaseAggregateMeta2.resultExpressions().map(baseExprMeta2 -> {
            Alias alias;
            AttributeReference copy;
            Alias alias2;
            if (baseExprMeta2.typeMeta().typeConverted()) {
                AttributeReference attributeReference = (AttributeReference) baseExprMeta2.wrapped();
                // CPU-to-GPU: use the attribute as-is. GPU-to-CPU: re-type it to
                // typeMeta().dataType(), keeping the same exprId and qualifier.
                if (z) {
                    copy = attributeReference;
                } else {
                    copy = attributeReference.copy(attributeReference.copy$default$1(), (DataType) baseExprMeta2.typeMeta().dataType().get(), attributeReference.copy$default$3(), attributeReference.copy$default$4(), attributeReference.exprId(), attributeReference.qualifier());
                }
                AttributeReference attributeReference2 = copy;
                // NOTE(review): decompiler artifact. The Either is assigned to a Left-typed variable;
                // the Scala source pattern-matched on Left/Right.
                Left left = (Either) apply.dequeue();
                if (left instanceof Left) {
                    Expression createExpression = ((CpuToGpuAggregateBufferConverter) left.value()).createExpression(attributeReference2);
                    String sb = new StringBuilder(10).append(attributeReference2.name()).append("_converted").toString();
                    alias2 = new Alias(createExpression, sb, NamedExpression$.MODULE$.newExprId(), Alias$.MODULE$.apply$default$4(createExpression, sb), Alias$.MODULE$.apply$default$5(createExpression, sb), Alias$.MODULE$.apply$default$6(createExpression, sb));
                } else {
                    if (!(left instanceof Right)) {
                        throw new MatchError(left);
                    }
                    Expression createExpression2 = ((GpuToCpuAggregateBufferConverter) ((Right) left).value()).createExpression(attributeReference2);
                    String sb2 = new StringBuilder(10).append(attributeReference2.name()).append("_converted").toString();
                    alias2 = new Alias(createExpression2, sb2, NamedExpression$.MODULE$.newExprId(), Alias$.MODULE$.apply$default$4(createExpression2, sb2), Alias$.MODULE$.apply$default$5(createExpression2, sb2), Alias$.MODULE$.apply$default$6(createExpression2, sb2));
                }
                alias = alias2;
            } else {
                // Non-converted expressions pass through unchanged. (Assigning a NamedExpression to an
                // Alias-typed local is a decompiler artifact.)
                alias = (NamedExpression) baseExprMeta2.wrapped();
            }
            return alias;
        }, Seq$.MODULE$.canBuildFrom());
    }

    /**
     * Finds the first parent/child pair where exactly one side can be replaced.
     *
     * <p>It walks down from {@code sparkPlanMeta} through the first child of each node and returns
     * the pair as {@code List(parent, child)}.
     *
     * <p>NOTE(review): only the head child is followed, and {@code head()} throws if a node with no
     * children is reached before a boundary. The loop form is presumably a compiled tail-recursive
     * call.
     *
     * @param sparkPlanMeta the node to start from
     * @return a two-element list of the parent meta and the child meta at the boundary
     */
    private Seq<SparkPlanMeta<?>> nextEdgeForConversion(SparkPlanMeta<?> sparkPlanMeta) {
        while (true) {
            SparkPlanMeta<?> sparkPlanMeta2 = (SparkPlanMeta) sparkPlanMeta.childPlans().head();
            if (sparkPlanMeta.canThisBeReplaced() ^ sparkPlanMeta2.canThisBeReplaced()) {
                return new $colon.colon(sparkPlanMeta, new $colon.colon(sparkPlanMeta2, Nil$.MODULE$));
            }
            sparkPlanMeta = sparkPlanMeta2;
        }
    }

    /** Negates a per-stage flag; used as {@code forall(!flag)} to detect that no stage is flagged. */
    public static final /* synthetic */ boolean $anonfun$handleAggregationBuffer$2(boolean z) {
        return !z;
    }

    /**
     * Checks whether a (stage, flag) pair allows buffer conversion.
     *
     * <p>Returns true when the stage ({@code _1}) is not flagged ({@code _2}) or reports
     * {@code availableRuntimeDataTransition()}. Throws {@code MatchError} on a null tuple.
     */
    public static final /* synthetic */ boolean $anonfun$handleAggregationBuffer$3(Tuple2 tuple2) {
        if (tuple2 != null) {
            return !tuple2._2$mcZ$sp() || ((GpuBaseAggregateMeta) tuple2._1()).availableRuntimeDataTransition();
        }
        throw new MatchError(tuple2);
    }

    /** Marks a replaceable stage as not working on the GPU; non-replaceable stages are left alone. */
    public static final /* synthetic */ void $anonfun$handleAggregationBuffer$4(GpuBaseAggregateMeta gpuBaseAggregateMeta) {
        if (!gpuBaseAggregateMeta.canThisBeReplaced()) {
            BoxedUnit boxedUnit = BoxedUnit.UNIT;
        } else {
            gpuBaseAggregateMeta.willNotWorkOnGpu("Associated fallback for TypedImperativeAggregate");
            BoxedUnit boxedUnit2 = BoxedUnit.UNIT;
        }
    }

    /**
     * Handles one {@code (flag, index)} pair; unflagged pairs are ignored.
     *
     * <p>When the flag is set, it finds the boundary edge by walking down from stage {@code index}.
     * It then tags one plan node with a converter projection built from stages {@code index} and
     * {@code index + 1}:
     * <ul>
     *   <li>If the parent side of the edge can be replaced, the child plan gets a
     *       {@code preRowToColProjection} with CPU-to-GPU converters.</li>
     *   <li>Otherwise the parent plan gets a {@code postColToRowProjection} with GPU-to-CPU
     *       converters.</li>
     * </ul>
     *
     * <p>A flagged index is never the last stage, so {@code index + 1} is in range.
     */
    public static final /* synthetic */ void $anonfun$bindBufferConverters$1(Seq seq, Tuple2 tuple2) {
        if (tuple2 != null) {
            boolean _1$mcZ$sp = tuple2._1$mcZ$sp();
            int _2$mcI$sp = tuple2._2$mcI$sp();
            if (_1$mcZ$sp) {
                boolean z = false;
                Seq seq2 = null;
                Seq nextEdgeForConversion = MODULE$.nextEdgeForConversion((SparkPlanMeta) seq.apply(_2$mcI$sp));
                // Case List(parent, child) where the parent can be replaced: convert CPU -> GPU on the child.
                if (nextEdgeForConversion instanceof List) {
                    z = true;
                    seq2 = (List) nextEdgeForConversion;
                    Some unapplySeq = List$.MODULE$.unapplySeq(seq2);
                    if (!unapplySeq.isEmpty() && unapplySeq.get() != null && ((LinearSeqOptimized) unapplySeq.get()).lengthCompare(2) == 0) {
                        SparkPlanMeta sparkPlanMeta = (SparkPlanMeta) ((LinearSeqOptimized) unapplySeq.get()).apply(0);
                        SparkPlanMeta sparkPlanMeta2 = (SparkPlanMeta) ((LinearSeqOptimized) unapplySeq.get()).apply(1);
                        if (sparkPlanMeta.canThisBeReplaced()) {
                            ((SparkPlan) sparkPlanMeta2.wrapped()).setTagValue(GpuOverrides$.MODULE$.preRowToColProjection(), MODULE$.createBufferConverter((GpuBaseAggregateMeta) seq.apply(_2$mcI$sp), (GpuBaseAggregateMeta) seq.apply(_2$mcI$sp + 1), true));
                            BoxedUnit boxedUnit = BoxedUnit.UNIT;
                            BoxedUnit boxedUnit2 = BoxedUnit.UNIT;
                            return;
                        }
                    }
                }
                // Case List(parent, _) otherwise: convert GPU -> CPU on the parent.
                if (z) {
                    Some unapplySeq2 = List$.MODULE$.unapplySeq(seq2);
                    if (!unapplySeq2.isEmpty() && unapplySeq2.get() != null && ((LinearSeqOptimized) unapplySeq2.get()).lengthCompare(2) == 0) {
                        ((SparkPlan) ((SparkPlanMeta) ((LinearSeqOptimized) unapplySeq2.get()).apply(0)).wrapped()).setTagValue(GpuOverrides$.MODULE$.postColToRowProjection(), MODULE$.createBufferConverter((GpuBaseAggregateMeta) seq.apply(_2$mcI$sp), (GpuBaseAggregateMeta) seq.apply(_2$mcI$sp + 1), false));
                        BoxedUnit boxedUnit3 = BoxedUnit.UNIT;
                        BoxedUnit boxedUnit22 = BoxedUnit.UNIT;
                        return;
                    }
                }
                // nextEdgeForConversion always returns a two-element list, so this is not expected in practice.
                throw new MatchError(nextEdgeForConversion);
            }
        }
        BoxedUnit boxedUnit4 = BoxedUnit.UNIT;
    }

    /** Null-safe equality between the requested mode and the expression's mode. */
    public static final /* synthetic */ boolean $anonfun$containTypedImperativeAggregate$2(AggregateExpression aggregateExpression, AggregateMode aggregateMode) {
        AggregateMode mode = aggregateExpression.mode();
        return aggregateMode != null ? aggregateMode.equals(mode) : mode == null;
    }

    /**
     * Predicate used by containTypedImperativeAggregate.
     *
     * <p>Returns true for a non-null expression whose mode is accepted by {@code option} (Option.forall
     * is true for {@code None}) and whose function is a TypedImperativeAggregate.
     */
    public static final /* synthetic */ boolean $anonfun$containTypedImperativeAggregate$1(Option option, AggregateExpression aggregateExpression) {
        return (aggregateExpression == null || !option.forall(aggregateMode -> {
            return BoxesRunTime.boxToBoolean($anonfun$containTypedImperativeAggregate$2(aggregateExpression, aggregateMode));
        })) ? false : aggregateExpression.aggregateFunction() instanceof TypedImperativeAggregate;
    }

    /**
     * Enqueues a buffer converter for one child expression when it qualifies.
     *
     * <p>It qualifies when its only child is a {@code TypedImperativeAggExprMeta} and its wrapped
     * AggregateExpression mode is {@code Final} or {@code PartialMerge}. The converter enqueued is
     * {@code Left(CPU-to-GPU)} when {@code z} is true and {@code Right(GPU-to-CPU)} otherwise.
     *
     * <p>NOTE(review): this assumes {@code baseExprMeta} wraps an AggregateExpression whenever its
     * only child is a TypedImperativeAggExprMeta; the cast would fail otherwise.
     */
    public static final /* synthetic */ void $anonfun$createBufferConverter$1(boolean z, Queue queue, BaseExprMeta baseExprMeta) {
        if (baseExprMeta.childExprs().length() != 1 || !(baseExprMeta.childExprs().head() instanceof TypedImperativeAggExprMeta)) {
            BoxedUnit boxedUnit = BoxedUnit.UNIT;
            return;
        }
        AggregateMode mode = ((AggregateExpression) baseExprMeta.wrapped()).mode();
        if (Final$.MODULE$.equals(mode) ? true : PartialMerge$.MODULE$.equals(mode)) {
            TypedImperativeAggExprMeta typedImperativeAggExprMeta = (TypedImperativeAggExprMeta) baseExprMeta.childExprs().head();
            queue.enqueue(Predef$.MODULE$.wrapRefArray(new Either[]{z ? package$.MODULE$.Left().apply(typedImperativeAggExprMeta.createCpuToGpuBufferConverter()) : package$.MODULE$.Right().apply(typedImperativeAggExprMeta.createGpuToCpuBufferConverter())}));
            BoxedUnit boxedUnit2 = BoxedUnit.UNIT;
        } else {
            BoxedUnit boxedUnit3 = BoxedUnit.UNIT;
        }
        BoxedUnit boxedUnit4 = BoxedUnit.UNIT;
    }

    /** Publishes the singleton and creates the {@code bufferConverterInjected} tag. */
    private GpuTypedImperativeSupportedAggregateExecMeta$() {
        MODULE$ = this;
        this.bufferConverterInjected = new TreeNodeTag<>("rapids.gpu.bufferConverterInjected");
    }
}
