package com.nvidia.spark.rapids;

import com.nvidia.spark.rapids.shims.SparkShimImpl$;
import org.apache.commons.lang3.StringUtils;
import org.apache.spark.internal.Logging;
import org.apache.spark.sql.catalyst.plans.QueryPlan;
import org.apache.spark.sql.execution.SparkPlan;
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec;
import org.slf4j.Logger;
import scala.Function0;
import scala.MatchError;
import scala.Predef$;
import scala.Tuple2;
import scala.collection.IterableLike;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.TraversableLike;
import scala.collection.TraversableOnce;
import scala.collection.mutable.ListBuffer;
import scala.math.Numeric$DoubleIsFractional$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.DoubleRef;

/* compiled from: CostBasedOptimizer.scala */
@ScalaSignature(bytes = "\u0006\u0001\u0005Ec\u0001B\u0005\u000b\u0001MAQ!\u000b\u0001\u0005\u0002)BQ\u0001\f\u0001\u0005\u00025BQa\u0014\u0001\u0005\nACQa\u001c\u0001\u0005\nADq!!\t\u0001\t\u0013\t\u0019\u0003C\u0004\u00022\u0001!I!a\r\t\u000f\u0005e\u0002\u0001\"\u0003\u0002<!9\u0011\u0011\t\u0001\u0005\n\u0005\r#AE\"pgR\u0014\u0015m]3e\u001fB$\u0018.\\5{KJT!a\u0003\u0007\u0002\rI\f\u0007/\u001b3t\u0015\tia\"A\u0003ta\u0006\u00148N\u0003\u0002\u0010!\u00051aN^5eS\u0006T\u0011!E\u0001\u0004G>l7\u0001A\n\u0005\u0001QQb\u0004\u0005\u0002\u001615\taCC\u0001\u0018\u0003\u0015\u00198-\u00197b\u0013\tIbC\u0001\u0004B]f\u0014VM\u001a\t\u00037qi\u0011AC\u0005\u0003;)\u0011\u0011b\u00149uS6L'0\u001a:\u0011\u0005}9S\"\u0001\u0011\u000b\u0005\u0005\u0012\u0013\u0001C5oi\u0016\u0014h.\u00197\u000b\u00055\u0019#B\u0001\u0013&\u0003\u0019\t\u0007/Y2iK*\ta%A\u0002pe\u001eL!\u0001\u000b\u0011\u0003\u000f1{wmZ5oO\u00061A(\u001b8jiz\"\u0012a\u000b\t\u00037\u0001\t\u0001b\u001c9uS6L'0\u001a\u000b\u0004]u\u0012\u0005cA\u00188u9\u0011\u0001'\u000e\b\u0003cQj\u0011A\r\u0006\u0003gI\ta\u0001\u0010:p_Rt\u0014\"A\f\n\u0005Y2\u0012a\u00029bG.\fw-Z\u0005\u0003qe\u00121aU3r\u0015\t1d\u0003\u0005\u0002\u001cw%\u0011AH\u0003\u0002\r\u001fB$\u0018.\\5{CRLwN\u001c\u0005\u0006}\t\u0001\raP\u0001\u0005G>tg\r\u0005\u0002\u001c\u0001&\u0011\u0011I\u0003\u0002\u000b%\u0006\u0004\u0018\u000eZ:D_:4\u0007\"B\"\u0003\u0001\u0004!\u0015\u0001\u00029mC:\u00042aG#H\u0013\t1%BA\u0007Ta\u0006\u00148\u000e\u00157b]6+G/\u0019\t\u0003\u00116k\u0011!\u0013\u0006\u0003\u0015.\u000b\u0011\"\u001a=fGV$\u0018n\u001c8\u000b\u00051\u0013\u0013aA:rY&\u0011a*\u0013\u0002\n'B\f'o\u001b)mC:\f1C]3dkJ\u001c\u0018N^3ms>\u0003H/[7ju\u0016$r!U,Y;~\u0003'\u000e\u0005\u0003\u0016%R#\u0016BA*\u0017\u0005\u0019!V\u000f\u001d7feA\u0011Q#V\u0005\u0003-Z\u0011a\u0001R8vE2,\u0007\"\u0002 
\u0004\u0001\u0004y\u0004\"B-\u0004\u0001\u0004Q\u0016\u0001D2qk\u000e{7\u000f^'pI\u0016d\u0007CA\u000e\\\u0013\ta&BA\u0005D_N$Xj\u001c3fY\")al\u0001a\u00015\u0006aq\r];D_N$Xj\u001c3fY\")1i\u0001a\u0001\t\")\u0011m\u0001a\u0001E\u0006iq\u000e\u001d;j[&T\u0018\r^5p]N\u00042a\u00195;\u001b\u0005!'BA3g\u0003\u001diW\u000f^1cY\u0016T!a\u001a\f\u0002\u0015\r|G\u000e\\3di&|g.\u0003\u0002jI\nQA*[:u\u0005V4g-\u001a:\t\u000b-\u001c\u0001\u0019\u00017\u0002\u001b\u0019Lg.\u00197Pa\u0016\u0014\u0018\r^8s!\t)R.\u0003\u0002o-\t9!i\\8mK\u0006t\u0017\u0001\u00037pO\u000e{7\u000f^:\u0015\u0011E$\u0018QAA\r\u0003;\u0001\"!\u0006:\n\u0005M4\"\u0001B+oSRDQa\u0011\u0003A\u0002U\u0004$A^=\u0011\u0007m)u\u000f\u0005\u0002ys2\u0001A!\u0003>u\u0003\u0003\u0005\tQ!\u0001|\u0005\ryF%M\t\u0003y~\u0004\"!F?\n\u0005y4\"a\u0002(pi\"Lgn\u001a\t\u0004+\u0005\u0005\u0011bAA\u0002-\t\u0019\u0011I\\=\t\u000f\u0005\u001dA\u00011\u0001\u0002\n\u00059Q.Z:tC\u001e,\u0007\u0003BA\u0006\u0003'qA!!\u0004\u0002\u0010A\u0011\u0011GF\u0005\u0004\u0003#1\u0012A\u0002)sK\u0012,g-\u0003\u0003\u0002\u0016\u0005]!AB*ue&twMC\u0002\u0002\u0012YAa!a\u0007\u0005\u0001\u0004!\u0016aB2qk\u000e{7\u000f\u001e\u0005\u0007\u0003?!\u0001\u0019\u0001+\u0002\u000f\u001d\u0004XoQ8ti\u0006Y1-\u00198Sk:|en\u00129v)\ra\u0017Q\u0005\u0005\u0007\u0007\u0016\u0001\r!a\n1\t\u0005%\u0012Q\u0006\t\u00057\u0015\u000bY\u0003E\u0002y\u0003[!1\"a\f\u0002&\u0005\u0005\t\u0011!B\u0001w\n\u0019q\f\n\u001a\u0002'Q\u0014\u0018M\\:ji&|g\u000eV8HaV\u001cun\u001d;\u0015\u000bQ\u000b)$a\u000e\t\u000by2\u0001\u0019A \t\u000b\r3\u0001\u0019\u0001#\u0002'Q\u0014\u0018M\\:ji&|g\u000eV8DaV\u001cun\u001d;\u0015\u000bQ\u000bi$a\u0010\t\u000by:\u0001\u0019A \t\u000b\r;\u0001\u0019\u0001#\u0002\u0019%\u001cX\t_2iC:<Wm\u00149\u0015\u00071\f)\u0005\u0003\u0004D\u0011\u0001\u0007\u0011q\t\u0019\u0005\u0003\u0013\ni\u0005\u0005\u0003\u001c\u000b\u0006-\u0003c\u0001=\u0002N\u0011Y\u0011qJA#\u0003\u0003\u0005\tQ!\u0001|\u0005\ryFe\r")
/* loaded from: input_file:com/nvidia/spark/rapids/CostBasedOptimizer.class */
/*
 * Cost-based optimizer (CBO) for GPU plan placement.
 *
 * The optimizer walks a SparkPlanMeta tree bottom-up. For every operator it compares the
 * estimated CPU cost with the estimated GPU cost of the subtree rooted there. The GPU cost
 * includes the cost of row<->columnar transitions between CPU and GPU operators. When the
 * GPU is not expected to pay off, it records an Optimization (AvoidTransition or
 * ReplaceSection) and marks the affected operators so they stay on the CPU.
 *
 * NOTE(review): this is decompiled from CostBasedOptimizer.scala. The Logging forwarders and
 * the public static $anonfun$... methods mirror what scalac generates, so their names and
 * signatures are kept as-is.
 */
public class CostBasedOptimizer implements Optimizer, Logging {
    private transient Logger org$apache$spark$internal$Logging$$log_;

    // ---------------------------------------------------------------------------------------
    // Forwarders to the static implementations of Spark's Logging trait (scalac-generated).
    // ---------------------------------------------------------------------------------------

    public String logName() {
        return Logging.logName$(this);
    }

    public Logger log() {
        return Logging.log$(this);
    }

    public void logInfo(Function0<String> function0) {
        Logging.logInfo$(this, function0);
    }

    public void logDebug(Function0<String> function0) {
        Logging.logDebug$(this, function0);
    }

    public void logTrace(Function0<String> function0) {
        Logging.logTrace$(this, function0);
    }

    public void logWarning(Function0<String> function0) {
        Logging.logWarning$(this, function0);
    }

    public void logError(Function0<String> function0) {
        Logging.logError$(this, function0);
    }

    public void logInfo(Function0<String> function0, Throwable th) {
        Logging.logInfo$(this, function0, th);
    }

    public void logDebug(Function0<String> function0, Throwable th) {
        Logging.logDebug$(this, function0, th);
    }

    public void logTrace(Function0<String> function0, Throwable th) {
        Logging.logTrace$(this, function0, th);
    }

    public void logWarning(Function0<String> function0, Throwable th) {
        Logging.logWarning$(this, function0, th);
    }

    public void logError(Function0<String> function0, Throwable th) {
        Logging.logError$(this, function0, th);
    }

    public boolean isTraceEnabled() {
        return Logging.isTraceEnabled$(this);
    }

    public void initializeLogIfNecessary(boolean z) {
        Logging.initializeLogIfNecessary$(this, z);
    }

    public boolean initializeLogIfNecessary(boolean z, boolean z2) {
        return Logging.initializeLogIfNecessary$(this, z, z2);
    }

    public boolean initializeLogIfNecessary$default$2() {
        return Logging.initializeLogIfNecessary$default$2$(this);
    }

    public void initializeForcefully(boolean z, boolean z2) {
        Logging.initializeForcefully$(this, z, z2);
    }

    // Accessor pair for the trait's backing logger field.
    public Logger org$apache$spark$internal$Logging$$log_() {
        return this.org$apache$spark$internal$Logging$$log_;
    }

    public void org$apache$spark$internal$Logging$$log__$eq(Logger logger) {
        this.org$apache$spark$internal$Logging$$log_ = logger;
    }

    /**
     * Runs cost-based optimization over the whole plan rooted at {@code sparkPlanMeta}.
     *
     * @param rapidsConf    configuration used to build the CPU and GPU cost models and to price
     *                      CPU/GPU transitions
     * @param sparkPlanMeta root of the plan. It is treated as the final operator, so a root that
     *                      can run on the GPU is also charged the GPU-to-CPU transition cost.
     * @return the optimizations that were applied, in the order they were decided (may be empty)
     */
    @Override // com.nvidia.spark.rapids.Optimizer
    public Seq<Optimization> optimize(RapidsConf rapidsConf, SparkPlanMeta<SparkPlan> sparkPlanMeta) {
        logTrace(() -> "CBO optimizing plan");
        CpuCostModel cpuModel = new CpuCostModel(rapidsConf);
        GpuCostModel gpuModel = new GpuCostModel(rapidsConf);
        ListBuffer<Optimization> optimizations = new ListBuffer<>();
        recursivelyOptimize(rapidsConf, cpuModel, gpuModel, sparkPlanMeta, optimizations, true);
        if (optimizations.isEmpty()) {
            logTrace(() -> "CBO finished optimizing plan. No optimizations applied.");
        } else {
            logTrace(() -> "CBO finished optimizing plan. " + optimizations.length()
                + " optimizations applied:\n\t" + optimizations.mkString("\n\t"));
        }
        return optimizations;
    }

    /**
     * Optimizes the subtree rooted at {@code plan} bottom-up and returns its total costs.
     *
     * <p>Declared private in the Scala source, per the decompiler note. It is kept public here
     * so the signature stays unchanged.
     *
     * @param conf          configuration used for transition costs
     * @param cpuCostModel  model pricing an operator on the CPU (excluding children)
     * @param gpuCostModel  model pricing an operator on the GPU (excluding children)
     * @param plan          root of the subtree to optimize
     * @param optimizations sink that receives every optimization decided in this subtree
     * @param finalOperator true for the plan root. A root that can run on the GPU is charged the
     *                      cost of transitioning back to the CPU.
     * @return {@code (totalCpuCost, totalGpuCost)} for the subtree, as boxed doubles. The GPU
     *         total is reset to the CPU total whenever this operator ends up not running on the
     *         GPU or is an exchange.
     */
    public Tuple2<Object, Object> recursivelyOptimize(RapidsConf conf, CostModel cpuCostModel, CostModel gpuCostModel, SparkPlanMeta<SparkPlan> plan, ListBuffer<Optimization> optimizations, boolean finalOperator) {
        // Cost of this operator alone, excluding its children.
        double operatorCpuCost = cpuCostModel.getCost(plan);
        double operatorGpuCost = gpuCostModel.getCost(plan);

        // Optimize the children first. Each yields (totalCpuCost, totalGpuCost) for its subtree.
        Seq childCosts = (Seq) plan.childPlans().map(
            child -> this.recursivelyOptimize(conf, cpuCostModel, gpuCostModel, child, optimizations, false),
            Seq$.MODULE$.canBuildFrom());
        Tuple2 unzipped = childCosts.unzip(Predef$.MODULE$.$conforms());
        if (unzipped == null) {
            throw new MatchError(unzipped);
        }
        Seq childCpuCosts = (Seq) unzipped._1();
        Seq childGpuCosts = (Seq) unzipped._2();

        // Totals for this operator plus its children. The GPU total lives in a DoubleRef
        // because the per-child closure below reads it (mirrors the captured Scala var).
        double totalCpuCost = operatorCpuCost
            + BoxesRunTime.unboxToDouble(childCpuCosts.sum(Numeric$DoubleIsFractional$.MODULE$));
        DoubleRef totalGpuCost = DoubleRef.create(operatorGpuCost
            + BoxesRunTime.unboxToDouble(childGpuCosts.sum(Numeric$DoubleIsFractional$.MODULE$)));
        logCosts(plan, "Operator costs", operatorCpuCost, operatorGpuCost);
        logCosts(plan, "Operator + child costs", totalCpuCost, totalGpuCost.elem);

        plan.estimatedOutputRows_$eq(RowCountPlanVisitor$.MODULE$.visit(plan));

        // Number of children placed on a different device (CPU vs GPU) than this operator.
        int numTransitions = plan.childPlans().count(
            meta -> BoxesRunTime.boxToBoolean($anonfun$recursivelyOptimize$2(this, plan, meta)));
        logCosts(plan, "numTransitions=" + numTransitions, totalCpuCost, totalGpuCost.elem);

        if (numTransitions > 0) {
            if (canRunOnGpu(plan)) {
                // GPU operator fed by at least one CPU child: price the CPU->GPU transitions.
                double transitionCost = BoxesRunTime.unboxToDouble(((TraversableOnce) ((TraversableLike) plan.childPlans()
                    .filterNot(meta -> BoxesRunTime.boxToBoolean(this.canRunOnGpu(meta))))
                    .map(meta -> BoxesRunTime.boxToDouble(this.transitionToGpuCost(conf, (SparkPlanMeta<SparkPlan>) meta)),
                        Seq$.MODULE$.canBuildFrom()))
                    .sum(Numeric$DoubleIsFractional$.MODULE$));
                if (operatorGpuCost + transitionCost > operatorCpuCost && !isExchangeOp(plan)) {
                    // Not worth moving to the GPU: keep this operator on CPU and drop the GPU total
                    // back to the CPU total.
                    optimizations.append(Predef$.MODULE$.wrapRefArray(new Optimization[]{new AvoidTransition(plan)}));
                    plan.costPreventsRunningOnGpu();
                    totalGpuCost.elem = totalCpuCost;
                    logCosts(plan, "Avoid transition to GPU", totalCpuCost, totalGpuCost.elem);
                } else {
                    totalGpuCost.elem += transitionCost;
                    logCosts(plan, "transitionFromCpuCost=" + transitionCost, totalCpuCost, totalGpuCost.elem);
                }
            } else {
                // CPU operator fed by at least one GPU child: re-check each child now that the
                // GPU->CPU transition cost is known. Some children may be forced back onto CPU.
                ((IterableLike) plan.childPlans().zip(childCosts, Seq$.MODULE$.canBuildFrom())).foreach(pair -> {
                    $anonfun$recursivelyOptimize$5(this, conf, optimizations, totalCpuCost, totalGpuCost, (Tuple2) pair);
                    return BoxedUnit.UNIT;
                });
                // Recompute, because children may have changed placement in the loop above.
                double transitionCost = BoxesRunTime.unboxToDouble(((TraversableOnce) ((TraversableLike) plan.childPlans()
                    .filter(meta -> BoxesRunTime.boxToBoolean(this.canRunOnGpu(meta))))
                    .map(meta -> BoxesRunTime.boxToDouble(this.transitionToCpuCost(conf, (SparkPlanMeta<SparkPlan>) meta)),
                        Seq$.MODULE$.canBuildFrom()))
                    .sum(Numeric$DoubleIsFractional$.MODULE$));
                totalGpuCost.elem += transitionCost;
                logCosts(plan, "transitionFromGpuCost=" + transitionCost, totalCpuCost, totalGpuCost.elem);
            }
        }

        // A GPU root still has to hand its result back to the CPU.
        if (finalOperator && canRunOnGpu(plan)) {
            double transitionCost = transitionToCpuCost(conf, plan);
            totalGpuCost.elem += transitionCost;
            logCosts(plan, "final operator, transitionFromGpuCost=" + transitionCost, totalCpuCost, totalGpuCost.elem);
        }

        // The whole section is cheaper on CPU: force it back there.
        if (totalGpuCost.elem > totalCpuCost && canRunOnGpu(plan) && !isExchangeOp(plan)) {
            optimizations.append(Predef$.MODULE$.wrapRefArray(new Optimization[]{new ReplaceSection(plan, totalCpuCost, totalGpuCost.elem)}));
            plan.recursiveCostPreventsRunningOnGpu();
            totalGpuCost.elem = totalCpuCost;
            logCosts(plan, "ReplaceSection: " + plan, totalCpuCost, totalGpuCost.elem);
        }

        // Sections that stay on CPU (or exchanges) contribute their CPU cost to the parent.
        if (!canRunOnGpu(plan) || isExchangeOp(plan)) {
            totalGpuCost.elem = totalCpuCost;
            logCosts(plan, "Reset costs (not on GPU / exchange)", totalCpuCost, totalGpuCost.elem);
        }

        logCosts(plan, "END", totalCpuCost, totalGpuCost.elem);
        return new scala.Tuple2$mcDD$sp(totalCpuCost, totalGpuCost.elem);
    }

    /**
     * Emits a trace line comparing a CPU and a GPU cost for {@code sparkPlanMeta}.
     *
     * <p>The message is built lazily inside the logTrace closure.
     *
     * @param sparkPlanMeta operator being costed; its wrapped node's simple class name is logged
     * @param str           description of the cost being reported
     * @param d             CPU cost
     * @param d2            GPU cost
     */
    private void logCosts(SparkPlanMeta<?> sparkPlanMeta, String str, double d, double d2) {
        String sign = d == d2 ? "==" : (d < d2 ? "<" : ">");
        // Previously the message ended with an unbalanced ")"; it is dropped here.
        logTrace(() -> "CBO [" + sparkPlanMeta.wrapped().getClass().getSimpleName() + "] " + str + ": "
            + "cpuCost=" + d + StringUtils.SPACE + sign + " gpuCost=" + d2);
    }

    /**
     * Returns whether {@code sparkPlanMeta} is treated as running on the GPU.
     *
     * <p>An {@code AdaptiveSparkPlanExec} is always treated as GPU-capable. Any other node defers
     * to {@code canThisBeReplaced()}. Declared private in the Scala source, per the decompiler
     * note.
     */
    public boolean canRunOnGpu(SparkPlanMeta<?> sparkPlanMeta) {
        return sparkPlanMeta.wrapped() instanceof AdaptiveSparkPlanExec || sparkPlanMeta.canThisBeReplaced();
    }

    /**
     * Estimates the cost of a CPU->GPU transition for the output of {@code sparkPlanMeta}.
     *
     * <p>The cost is the sum of:
     * <ul>
     *   <li>the configured {@code GpuRowToColumnarExec} operator cost (0.0 when absent) times
     *       the estimated row count;</li>
     *   <li>reading the estimated memory size at CPU read speed;</li>
     *   <li>writing it at GPU write speed.</li>
     * </ul>
     * The row count comes from {@code RowCountPlanVisitor}. It falls back to
     * {@code defaultRowCount()} when no estimate exists. The memory size comes from
     * {@code MemoryCostHelper.estimateGpuMemory}.
     */
    public double transitionToGpuCost(RapidsConf rapidsConf, SparkPlanMeta<SparkPlan> sparkPlanMeta) {
        double rowCount = BoxesRunTime.unboxToDouble(RowCountPlanVisitor$.MODULE$.visit(sparkPlanMeta)
            .map(bigInt -> BoxesRunTime.boxToDouble(bigInt.toDouble()))
            .getOrElse(() -> rapidsConf.defaultRowCount()));
        long memorySize = MemoryCostHelper$.MODULE$.estimateGpuMemory(((QueryPlan) sparkPlanMeta.wrapped()).schema(), rowCount);
        double operatorCost = BoxesRunTime.unboxToDouble(rapidsConf.getGpuOperatorCost("GpuRowToColumnarExec").getOrElse(() -> 0.0d));
        return (operatorCost * rowCount)
            + MemoryCostHelper$.MODULE$.calculateCost(memorySize, rapidsConf.cpuReadMemorySpeed())
            + MemoryCostHelper$.MODULE$.calculateCost(memorySize, rapidsConf.gpuWriteMemorySpeed());
    }

    /**
     * Estimates the cost of a GPU->CPU transition for the output of {@code sparkPlanMeta}.
     *
     * <p>The cost is the sum of:
     * <ul>
     *   <li>the configured {@code GpuColumnarToRowExec} operator cost (0.0 when absent) times
     *       the estimated row count;</li>
     *   <li>reading the estimated memory size at GPU read speed;</li>
     *   <li>writing it at CPU write speed.</li>
     * </ul>
     * The row count is derived the same way as in {@code transitionToGpuCost}.
     */
    public double transitionToCpuCost(RapidsConf rapidsConf, SparkPlanMeta<SparkPlan> sparkPlanMeta) {
        double rowCount = BoxesRunTime.unboxToDouble(RowCountPlanVisitor$.MODULE$.visit(sparkPlanMeta)
            .map(bigInt -> BoxesRunTime.boxToDouble(bigInt.toDouble()))
            .getOrElse(() -> rapidsConf.defaultRowCount()));
        long memorySize = MemoryCostHelper$.MODULE$.estimateGpuMemory(((QueryPlan) sparkPlanMeta.wrapped()).schema(), rowCount);
        double operatorCost = BoxesRunTime.unboxToDouble(rapidsConf.getGpuOperatorCost("GpuColumnarToRowExec").getOrElse(() -> 0.0d));
        return (operatorCost * rowCount)
            + MemoryCostHelper$.MODULE$.calculateCost(memorySize, rapidsConf.gpuReadMemorySpeed())
            + MemoryCostHelper$.MODULE$.calculateCost(memorySize, rapidsConf.cpuWriteMemorySpeed());
    }

    /** Delegates the "is this an exchange operator" check to the active Spark shim. */
    private boolean isExchangeOp(SparkPlanMeta<?> sparkPlanMeta) {
        return SparkShimImpl$.MODULE$.isExchangeOp(sparkPlanMeta);
    }

    /**
     * Synthetic predicate behind the numTransitions count.
     *
     * @return true when {@code child} and {@code parent} disagree on {@code canRunOnGpu}
     */
    public static final /* synthetic */ boolean $anonfun$recursivelyOptimize$2(CostBasedOptimizer optimizer, SparkPlanMeta parent, SparkPlanMeta child) {
        return optimizer.canRunOnGpu(child) != optimizer.canRunOnGpu(parent);
    }

    /**
     * Synthetic per-child step for a CPU parent with GPU children.
     *
     * <p>{@code childAndCosts} is {@code (childMeta, (childCpuCost, childGpuCost))}. It forces a
     * GPU-capable, non-exchange child back onto CPU when its GPU cost plus the GPU->CPU
     * transition cost exceeds its CPU cost.
     *
     * <p>NOTE(review): the recorded ReplaceSection carries the parent's running totals
     * ({@code parentCpuTotal}, {@code parentGpuTotal.elem}), not the child's own costs. Confirm
     * whether that is intended.
     */
    public static final /* synthetic */ void $anonfun$recursivelyOptimize$5(CostBasedOptimizer optimizer, RapidsConf conf, ListBuffer optimizations, double parentCpuTotal, DoubleRef parentGpuTotal, Tuple2 childAndCosts) {
        if (childAndCosts == null) {
            throw new MatchError(childAndCosts);
        }
        SparkPlanMeta<SparkPlan> child = (SparkPlanMeta) childAndCosts._1();
        Tuple2 costs = (Tuple2) childAndCosts._2();
        if (costs == null) {
            throw new MatchError(costs);
        }
        double childCpuCost = costs._1$mcD$sp();
        double childGpuTotal = costs._2$mcD$sp() + optimizer.transitionToCpuCost(conf, child);
        if (optimizer.canRunOnGpu(child) && !optimizer.isExchangeOp(child) && childGpuTotal > childCpuCost) {
            optimizations.append(Predef$.MODULE$.wrapRefArray(new Optimization[]{new ReplaceSection(child, parentCpuTotal, parentGpuTotal.elem)}));
            child.recursiveCostPreventsRunningOnGpu();
        }
    }

    public CostBasedOptimizer() {
        Logging.$init$(this);
    }
}
