package org.apache.spark.sql.rapids.execution;

import com.nvidia.spark.rapids.GpuBindReferences$;
import com.nvidia.spark.rapids.GpuBoundReference;
import com.nvidia.spark.rapids.GpuExpression;
import com.nvidia.spark.rapids.GpuMetric;
import com.nvidia.spark.rapids.GpuMetric$;
import com.nvidia.spark.rapids.GpuPartitioning;
import com.nvidia.spark.rapids.GpuRangePartitioner;
import com.nvidia.spark.rapids.GpuRangePartitioner$;
import com.nvidia.spark.rapids.GpuRoundRobinPartitioning;
import com.nvidia.spark.rapids.GpuSinglePartitioning$;
import com.nvidia.spark.rapids.GpuSortEachBatchIterator;
import com.nvidia.spark.rapids.GpuSortEachBatchIterator$;
import com.nvidia.spark.rapids.GpuSorter;
import com.nvidia.spark.rapids.RapidsPluginImplicits;
import com.nvidia.spark.rapids.RapidsPluginImplicits$;
import com.nvidia.spark.rapids.shims.GpuHashPartitioning;
import com.nvidia.spark.rapids.shims.GpuRangePartitioning;
import java.util.NoSuchElementException;
import org.apache.spark.ShuffleDependency;
import org.apache.spark.rdd.RDD;
import org.apache.spark.serializer.Serializer;
import org.apache.spark.sql.catalyst.expressions.Ascending$;
import org.apache.spark.sql.catalyst.expressions.Attribute;
import org.apache.spark.sql.catalyst.expressions.SortOrder;
import org.apache.spark.sql.catalyst.expressions.SortOrder$;
import org.apache.spark.sql.catalyst.expressions.package$;
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec$;
import org.apache.spark.sql.execution.metric.SQLMetric;
import org.apache.spark.sql.internal.SQLConf$;
import org.apache.spark.sql.rapids.GpuShuffleDependency;
import org.apache.spark.sql.rapids.GpuShuffleDependency$;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.vectorized.ColumnarBatch;
import org.apache.spark.util.MutablePair;
import scala.Array$;
import scala.Function1;
import scala.MatchError;
import scala.Predef$;
import scala.Product2;
import scala.Serializable;
import scala.Tuple2;
import scala.collection.AbstractIterator;
import scala.collection.Iterator;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.TraversableLike;
import scala.collection.TraversableOnce;
import scala.collection.immutable.Map;
import scala.collection.mutable.ArrayOps;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

/* compiled from: GpuShuffleExchangeExecBase.scala */
/* loaded from: input_file:org/apache/spark/sql/rapids/execution/GpuShuffleExchangeExecBase$.class */
/**
 * Decompiled form of the Scala companion object {@code GpuShuffleExchangeExecBase}: a serializable
 * singleton (see {@link #MODULE$}) with helpers that turn an {@code RDD<ColumnarBatch>} into a
 * {@link ShuffleDependency} whose records are (partition id, columnar batch) pairs.
 *
 * <p>NOTE(review): this is decompiler output (synthetic names such as {@code MODULE$},
 * {@code apply$default$N}, {@code $anonfun$...}, and an anonymous class with "constructor"
 * arguments). Changes presumably belong in the Scala source named in the banner above — confirm.
 */
public final class GpuShuffleExchangeExecBase$ implements Serializable {
    /** The singleton instance; assigned by the private constructor during class initialization. */
    public static GpuShuffleExchangeExecBase$ MODULE$;

    // Class initialization creates the one instance; the constructor stores it into MODULE$.
    static {
        new GpuShuffleExchangeExecBase$();
    }

    /**
     * Builds the shuffle dependency for a GPU columnar exchange.
     *
     * <ol>
     *   <li>If {@code gpuPartitioning} is round-robin and {@code SQLConf.sortBeforeRepartition()} is
     *       enabled, every input batch is first sorted ascending on all of the output columns.</li>
     *   <li>A partitioner expression for {@code gpuPartitioning} is obtained from
     *       {@link #getPartitioner}.</li>
     *   <li>Each RDD partition is mapped to an iterator that splits every batch with that
     *       partitioner and emits one (partition id, batch piece) record per piece.</li>
     * </ol>
     *
     * @param rdd the input batches
     * @param seq output attributes of the child; used to bind the partitioning expressions and, in
     *     the sorted round-robin case, to build one ascending sort key per column
     * @param gpuPartitioning the requested partitioning; its {@code numPartitions()} sizes the
     *     {@code BatchPartitionIdPassthrough} partitioner of the dependency
     * @param dataTypeArr column types, passed through to {@code GpuShuffleDependency}
     * @param serializer passed through to {@code GpuShuffleDependency}
     * @param z boolean passed through to {@code GpuShuffleDependency} (meaning not visible here —
     *     TODO confirm against the Scala source)
     * @param z2 boolean passed through to {@code GpuShuffleDependency} (meaning not visible here —
     *     TODO confirm)
     * @param map GPU metrics updated by the partitioning iterator; entries for
     *     {@code GpuMetric.NUM_OUTPUT_ROWS()} and {@code GpuMetric.NUM_OUTPUT_BATCHES()} are looked up
     *     with {@code Map.apply}, so both keys must be present
     * @param map2 SQL metrics used to create the shuffle write processor
     * @param map3 GPU metrics that are unwrapped and passed to {@code GpuShuffleDependency}
     * @return a dependency keyed by the integer partition id produced by the partitioner
     */
    public ShuffleDependency<Object, ColumnarBatch, ColumnarBatch> prepareBatchShuffleDependency(RDD<ColumnarBatch> rdd, Seq<Attribute> seq, GpuPartitioning gpuPartitioning, DataType[] dataTypeArr, Serializer serializer, boolean z, boolean z2, Map<String, GpuMetric> map, Map<String, SQLMetric> map2, Map<String, GpuMetric> map3) {
        RDD<ColumnarBatch> rdd2;
        // Mirrors Spark's spark.sql.execution.sortBeforeRepartition for round-robin exchanges:
        // presumably the per-batch sort makes the round-robin assignment reproducible when a task
        // is retried — confirm.
        if ((gpuPartitioning instanceof GpuRoundRobinPartitioning) && SQLConf$.MODULE$.get().sortBeforeRepartition()) {
            // One ascending sort key per output column, each referencing the column by its ordinal.
            GpuSorter gpuSorter = new GpuSorter((Seq<SortOrder>) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[]) ((TraversableOnce) ((TraversableLike) seq.zipWithIndex(Seq$.MODULE$.canBuildFrom())).map(tuple2 -> {
                // Decompiled pattern match on (attribute, index).
                if (tuple2 == null) {
                    throw new MatchError(tuple2);
                }
                Attribute attribute = (Attribute) tuple2._1();
                return SortOrder$.MODULE$.apply(new GpuBoundReference(tuple2._2$mcI$sp(), attribute.dataType(), attribute.nullable(), attribute.exprId(), attribute.name()), Ascending$.MODULE$, SortOrder$.MODULE$.apply$default$3());
            }, Seq$.MODULE$.canBuildFrom())).toArray(ClassTag$.MODULE$.apply(SortOrder.class)))).toSeq(), seq);
            // Sorts each batch independently (not a global sort across the partition).
            rdd2 = rdd.mapPartitions(iterator -> {
                return new GpuSortEachBatchIterator(iterator, gpuSorter, false, GpuSortEachBatchIterator$.MODULE$.apply$default$4(), GpuSortEachBatchIterator$.MODULE$.apply$default$5(), GpuSortEachBatchIterator$.MODULE$.apply$default$6(), GpuSortEachBatchIterator$.MODULE$.apply$default$7(), GpuSortEachBatchIterator$.MODULE$.apply$default$8());
            }, rdd.mapPartitions$default$2(), ClassTag$.MODULE$.apply(ColumnarBatch.class));
        } else {
            rdd2 = rdd;
        }
        RDD<ColumnarBatch> rdd3 = rdd2;
        // Built once on the driver side of this call. For range partitioning, rdd3 is handed to
        // createRangeBounds (presumably sampled, which may run a job — confirm).
        GpuExpression partitioner = getPartitioner(rdd3, seq, gpuPartitioning);
        return new GpuShuffleDependency(rdd3.mapPartitions(iterator2 -> {
            // Splits one batch; the result is cast below to an array of (piece, partition id).
            final Function1 function1 = columnarBatch -> {
                return partitioner.columnarEval(columnarBatch);
            };
            // Flattens the split pieces of each input batch into a stream of (partition id, piece)
            // records. The constructor arguments are the decompiler's rendering of captured variables.
            return new AbstractIterator<Product2<Object, ColumnarBatch>>(iterator2, function1, map) { // from class: org.apache.spark.sql.rapids.execution.GpuShuffleExchangeExecBase$$anon$1
                // Pieces of the current input batch paired with their partition ids; null when none.
                private Tuple2<ColumnarBatch, Object>[] partitioned;
                // Index of the next piece in `partitioned` to emit.
                private int at = 0;
                // Reused for every record returned by next(); a caller must consume a record before
                // calling next() again, since the same object is overwritten.
                private final MutablePair<Object, ColumnarBatch> mutablePair = new MutablePair<>();
                // Captured input iterator, split function, and metrics map.
                private final Iterator iter$1;
                private final Function1 getParts$1;
                private final Map metrics$1;

                private Tuple2<ColumnarBatch, Object>[] partitioned() {
                    return this.partitioned;
                }

                private void partitioned_$eq(Tuple2<ColumnarBatch, Object>[] tuple2Arr) {
                    this.partitioned = tuple2Arr;
                }

                private int at() {
                    return this.at;
                }

                private void at_$eq(int i) {
                    this.at = i;
                }

                private MutablePair<Object, ColumnarBatch> mutablePair() {
                    return this.mutablePair;
                }

                /**
                 * Advances to the next input batch.
                 *
                 * <p>First closes every piece of the previous batch (if any) and resets the cursor,
                 * so pieces handed out by next() are only valid until the iterator moves past them.
                 * It then pulls from the input, closing and skipping zero-row batches, and splits the
                 * first non-empty batch (or the final batch, even if it has zero rows). Afterwards it
                 * adds the piece row counts to NUM_OUTPUT_ROWS and the piece count to
                 * NUM_OUTPUT_BATCHES. If the input is exhausted, {@code partitioned} is left null.
                 *
                 * <p>NOTE(review): a zero-row final batch is passed to the partitioner instead of being
                 * closed — confirm the partitioner handles empty input. The batch given to getParts is
                 * not closed here; presumably the partitioner takes ownership — confirm.
                 */
                private void partNextBatch() {
                    ColumnarBatch columnarBatch2;
                    if (partitioned() != null) {
                        RapidsPluginImplicits.AutoCloseableArray AutoCloseableArray = RapidsPluginImplicits$.MODULE$.AutoCloseableArray((AutoCloseable[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(partitioned())).map(tuple22 -> {
                            return (ColumnarBatch) tuple22._1();
                        }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(ColumnarBatch.class))));
                        AutoCloseableArray.safeClose(AutoCloseableArray.safeClose$default$1());
                        partitioned_$eq(null);
                        at_$eq(0);
                    }
                    if (this.iter$1.hasNext()) {
                        Object next = this.iter$1.next();
                        // Skip (and close) zero-row batches while more input remains.
                        while (true) {
                            columnarBatch2 = (ColumnarBatch) next;
                            if (columnarBatch2.numRows() != 0 || !this.iter$1.hasNext()) {
                                break;
                            }
                            columnarBatch2.close();
                            next = this.iter$1.next();
                        }
                        partitioned_$eq((Tuple2[]) this.getParts$1.apply(columnarBatch2));
                        new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(partitioned())).foreach(tuple23 -> {
                            $anonfun$partNextBatch$2(this, tuple23);
                            return BoxedUnit.UNIT;
                        });
                        ((GpuMetric) this.metrics$1.apply(GpuMetric$.MODULE$.NUM_OUTPUT_BATCHES())).$plus$eq(partitioned().length);
                        at_$eq(0);
                    }
                }

                /**
                 * Returns whether another record is available, splitting the next input batch when
                 * the current pieces are used up.
                 *
                 * <p>NOTE(review): only one input batch is split per call, so if the partitioner ever
                 * returned an empty array for a batch, this would report false even though more input
                 * could remain — confirm that cannot happen.
                 */
                public boolean hasNext() {
                    if (partitioned() == null || at() >= partitioned().length) {
                        partNextBatch();
                    }
                    return partitioned() != null && at() < partitioned().length;
                }

                /**
                 * Returns the next (boxed partition id, piece) record in the shared MutablePair.
                 *
                 * @throws NoSuchElementException when no records remain
                 */
                /* renamed from: next, reason: merged with bridge method [inline-methods] */
                public Product2<Object, ColumnarBatch> m1470next() {
                    if (partitioned() == null || at() >= partitioned().length) {
                        partNextBatch();
                    }
                    if (partitioned() == null || at() >= partitioned().length) {
                        throw new NoSuchElementException("Walked off of the end...");
                    }
                    Tuple2<ColumnarBatch, Object> tuple22 = partitioned()[at()];
                    mutablePair().update(BoxesRunTime.boxToInteger(tuple22._2$mcI$sp()), tuple22._1());
                    at_$eq(at() + 1);
                    return mutablePair();
                }

                // Synthetic lambda body: adds the row count of one piece to NUM_OUTPUT_ROWS.
                public static final /* synthetic */ void $anonfun$partNextBatch$2(GpuShuffleExchangeExecBase$$anon$1 gpuShuffleExchangeExecBase$$anon$1, Tuple2 tuple22) {
                    ((GpuMetric) gpuShuffleExchangeExecBase$$anon$1.metrics$1.apply(GpuMetric$.MODULE$.NUM_OUTPUT_ROWS())).$plus$eq(((ColumnarBatch) tuple22._1()).numRows());
                }

                // Stores the captured variables (decompiler rendering of the closure constructor).
                {
                    this.iter$1 = iterator2;
                    this.getParts$1 = function1;
                    this.metrics$1 = map;
                }
            };
        }, rdd3.mapPartitions$default$2(), ClassTag$.MODULE$.apply(Product2.class)), new BatchPartitionIdPassthrough(gpuPartitioning.numPartitions()), dataTypeArr, serializer, GpuShuffleDependency$.MODULE$.$lessinit$greater$default$5(), GpuShuffleDependency$.MODULE$.$lessinit$greater$default$6(), GpuShuffleDependency$.MODULE$.$lessinit$greater$default$7(), ShuffleExchangeExec$.MODULE$.createShuffleWriteProcessor(map2), z, z2, GpuMetric$.MODULE$.unwrap(map3), ClassTag$.MODULE$.Int(), ClassTag$.MODULE$.apply(ColumnarBatch.class), ClassTag$.MODULE$.apply(ColumnarBatch.class));
    }

    /**
     * Builds the expression that splits a batch according to {@code gpuPartitioning}.
     *
     * <ul>
     *   <li>hash: the partitioning itself, bound against {@code seq};</li>
     *   <li>range: a {@code GpuRangePartitioner} over a {@code GpuSorter} for the partitioning's
     *       ordering. Its bounds come from {@code GpuRangePartitioner.createRangeBounds} applied to
     *       {@code rdd} with {@code SQLConf.rangeExchangeSampleSizePerPartition()}, presumably by
     *       sampling the RDD (which may launch a job — confirm);</li>
     *   <li>single: the {@code GpuSinglePartitioning} singleton itself;</li>
     *   <li>round-robin: the partitioning itself, bound against {@code seq}.</li>
     * </ul>
     *
     * @param rdd input batches; only used for range partitioning
     * @param seq attributes the partitioning expressions are bound against
     * @param gpuPartitioning the requested partitioning
     * @return the partitioner expression
     * @throws RuntimeException (via {@code scala.sys.error}) for any other partitioning type
     */
    private GpuExpression getPartitioner(RDD<ColumnarBatch> rdd, Seq<Attribute> seq, GpuPartitioning gpuPartitioning) {
        GpuExpression bindReference;
        if (gpuPartitioning instanceof GpuHashPartitioning) {
            bindReference = (GpuExpression) GpuBindReferences$.MODULE$.bindReference((GpuHashPartitioning) gpuPartitioning, package$.MODULE$.AttributeSeq(seq));
        } else if (gpuPartitioning instanceof GpuRangePartitioning) {
            GpuRangePartitioning gpuRangePartitioning = (GpuRangePartitioning) gpuPartitioning;
            GpuSorter gpuSorter = new GpuSorter(gpuRangePartitioning.gpuOrdering(), seq);
            bindReference = new GpuRangePartitioner(GpuRangePartitioner$.MODULE$.createRangeBounds(gpuRangePartitioning.numPartitions(), gpuSorter, rdd, SQLConf$.MODULE$.get().rangeExchangeSampleSizePerPartition()), gpuSorter);
        } else if (GpuSinglePartitioning$.MODULE$.equals(gpuPartitioning)) {
            bindReference = GpuSinglePartitioning$.MODULE$;
        } else {
            if (!(gpuPartitioning instanceof GpuRoundRobinPartitioning)) {
                throw scala.sys.package$.MODULE$.error(new StringBuilder(29).append("Exchange not implemented for ").append(gpuPartitioning).toString());
            }
            bindReference = GpuBindReferences$.MODULE$.bindReference((GpuRoundRobinPartitioning) gpuPartitioning, package$.MODULE$.AttributeSeq(seq));
        }
        return bindReference;
    }

    // Java serialization hook: any deserialized copy is replaced by the singleton instance.
    private Object readResolve() {
        return MODULE$;
    }

    // Private constructor, invoked once from the static initializer; publishes this as MODULE$.
    private GpuShuffleExchangeExecBase$() {
        MODULE$ = this;
    }
}
