/*
 * Decompiled with CFR 0.152.
 */
package com.nvidia.spark.rapids;

import com.nvidia.spark.rapids.AvoidTransition;
import com.nvidia.spark.rapids.CostModel;
import com.nvidia.spark.rapids.CpuCostModel;
import com.nvidia.spark.rapids.GpuCostModel;
import com.nvidia.spark.rapids.MemoryCostHelper$;
import com.nvidia.spark.rapids.Optimization;
import com.nvidia.spark.rapids.Optimizer;
import com.nvidia.spark.rapids.RapidsConf;
import com.nvidia.spark.rapids.ReplaceSection;
import com.nvidia.spark.rapids.RowCountPlanVisitor$;
import com.nvidia.spark.rapids.SparkPlanMeta;
import com.nvidia.spark.rapids.shims.SparkShimImpl$;
import java.io.Serializable;
import org.apache.spark.internal.Logging;
import org.apache.spark.sql.catalyst.plans.QueryPlan;
import org.apache.spark.sql.execution.SparkPlan;
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec;
import org.slf4j.Logger;
import scala.Function0;
import scala.Function1;
import scala.MatchError;
import scala.Predef$;
import scala.Tuple2;
import scala.collection.GenIterable;
import scala.collection.IterableLike;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.TraversableLike;
import scala.collection.TraversableOnce;
import scala.collection.mutable.ListBuffer;
import scala.math.Numeric;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.DoubleRef;
import scala.runtime.java8.JFunction0;

@ScalaSignature(bytes="\u0006\u0001\u0005Ec\u0001B\u0005\u000b\u0001MAQ!\u000b\u0001\u0005\u0002)BQ\u0001\f\u0001\u0005\u00025BQa\u0014\u0001\u0005\nACQa\u001c\u0001\u0005\nADq!!\t\u0001\t\u0013\t\u0019\u0003C\u0004\u00022\u0001!I!a\r\t\u000f\u0005e\u0002\u0001\"\u0003\u0002<!9\u0011\u0011\t\u0001\u0005\n\u0005\r#AE\"pgR\u0014\u0015m]3e\u001fB$\u0018.\\5{KJT!a\u0003\u0007\u0002\rI\f\u0007/\u001b3t\u0015\tia\"A\u0003ta\u0006\u00148N\u0003\u0002\u0010!\u00051aN^5eS\u0006T\u0011!E\u0001\u0004G>l7\u0001A\n\u0005\u0001QQb\u0004\u0005\u0002\u001615\taCC\u0001\u0018\u0003\u0015\u00198-\u00197b\u0013\tIbC\u0001\u0004B]f\u0014VM\u001a\t\u00037qi\u0011AC\u0005\u0003;)\u0011\u0011b\u00149uS6L'0\u001a:\u0011\u0005}9S\"\u0001\u0011\u000b\u0005\u0005\u0012\u0013\u0001C5oi\u0016\u0014h.\u00197\u000b\u00055\u0019#B\u0001\u0013&\u0003\u0019\t\u0007/Y2iK*\ta%A\u0002pe\u001eL!\u0001\u000b\u0011\u0003\u000f1{wmZ5oO\u00061A(\u001b8jiz\"\u0012a\u000b\t\u00037\u0001\t\u0001b\u001c9uS6L'0\u001a\u000b\u0004]u\u0012\u0005cA\u00188u9\u0011\u0001'\u000e\b\u0003cQj\u0011A\r\u0006\u0003gI\ta\u0001\u0010:p_Rt\u0014\"A\f\n\u0005Y2\u0012a\u00029bG.\fw-Z\u0005\u0003qe\u00121aU3r\u0015\t1d\u0003\u0005\u0002\u001cw%\u0011AH\u0003\u0002\r\u001fB$\u0018.\\5{CRLwN\u001c\u0005\u0006}\t\u0001\raP\u0001\u0005G>tg\r\u0005\u0002\u001c\u0001&\u0011\u0011I\u0003\u0002\u000b%\u0006\u0004\u0018\u000eZ:D_:4\u0007\"B\"\u0003\u0001\u0004!\u0015\u0001\u00029mC:\u00042aG#H\u0013\t1%BA\u0007Ta\u0006\u00148\u000e\u00157b]6+G/\u0019\t\u0003\u00116k\u0011!\u0013\u0006\u0003\u0015.\u000b\u0011\"\u001a=fGV$\u0018n\u001c8\u000b\u00051\u0013\u0013aA:rY&\u0011a*\u0013\u0002\n'B\f'o\u001b)mC:\f1C]3dkJ\u001c\u0018N^3ms>\u0003H/[7ju\u0016$r!U,Y;~\u0003'\u000e\u0005\u0003\u0016%R#\u0016BA*\u0017\u0005\u0019!V\u000f\u001d7feA\u0011Q#V\u0005\u0003-Z\u0011a\u0001R8vE2,\u0007\"\u0002 
\u0004\u0001\u0004y\u0004\"B-\u0004\u0001\u0004Q\u0016\u0001D2qk\u000e{7\u000f^'pI\u0016d\u0007CA\u000e\\\u0013\ta&BA\u0005D_N$Xj\u001c3fY\")al\u0001a\u00015\u0006aq\r];D_N$Xj\u001c3fY\")1i\u0001a\u0001\t\")\u0011m\u0001a\u0001E\u0006iq\u000e\u001d;j[&T\u0018\r^5p]N\u00042a\u00195;\u001b\u0005!'BA3g\u0003\u001diW\u000f^1cY\u0016T!a\u001a\f\u0002\u0015\r|G\u000e\\3di&|g.\u0003\u0002jI\nQA*[:u\u0005V4g-\u001a:\t\u000b-\u001c\u0001\u0019\u00017\u0002\u001b\u0019Lg.\u00197Pa\u0016\u0014\u0018\r^8s!\t)R.\u0003\u0002o-\t9!i\\8mK\u0006t\u0017\u0001\u00037pO\u000e{7\u000f^:\u0015\u0011E$\u0018QAA\r\u0003;\u0001\"!\u0006:\n\u0005M4\"\u0001B+oSRDQa\u0011\u0003A\u0002U\u0004$A^=\u0011\u0007m)u\u000f\u0005\u0002ys2\u0001A!\u0003>u\u0003\u0003\u0005\tQ!\u0001|\u0005\ryF%M\t\u0003y~\u0004\"!F?\n\u0005y4\"a\u0002(pi\"Lgn\u001a\t\u0004+\u0005\u0005\u0011bAA\u0002-\t\u0019\u0011I\\=\t\u000f\u0005\u001dA\u00011\u0001\u0002\n\u00059Q.Z:tC\u001e,\u0007\u0003BA\u0006\u0003'qA!!\u0004\u0002\u0010A\u0011\u0011GF\u0005\u0004\u0003#1\u0012A\u0002)sK\u0012,g-\u0003\u0003\u0002\u0016\u0005]!AB*ue&twMC\u0002\u0002\u0012YAa!a\u0007\u0005\u0001\u0004!\u0016aB2qk\u000e{7\u000f\u001e\u0005\u0007\u0003?!\u0001\u0019\u0001+\u0002\u000f\u001d\u0004XoQ8ti\u0006Y1-\u00198Sk:|en\u00129v)\ra\u0017Q\u0005\u0005\u0007\u0007\u0016\u0001\r!a\n1\t\u0005%\u0012Q\u0006\t\u00057\u0015\u000bY\u0003E\u0002y\u0003[!1\"a\f\u0002&\u0005\u0005\t\u0011!B\u0001w\n\u0019q\f\n\u001a\u0002'Q\u0014\u0018M\\:ji&|g\u000eV8HaV\u001cun\u001d;\u0015\u000bQ\u000b)$a\u000e\t\u000by2\u0001\u0019A \t\u000b\r3\u0001\u0019\u0001#\u0002'Q\u0014\u0018M\\:ji&|g\u000eV8DaV\u001cun\u001d;\u0015\u000bQ\u000bi$a\u0010\t\u000by:\u0001\u0019A \t\u000b\r;\u0001\u0019\u0001#\u0002\u0019%\u001cX\t_2iC:<Wm\u00149\u0015\u00071\f)\u0005\u0003\u0004D\u0011\u0001\u0007\u0011q\t\u0019\u0005\u0003\u0013\ni\u0005\u0005\u0003\u001c\u000b\u0006-\u0003c\u0001=\u0002N\u0011Y\u0011qJA#\u0003\u0003\u0005\tQ!\u0001|\u0005\ryFe\r")
/**
 * Cost-based optimizer (CBO) that walks a {@link SparkPlanMeta} tree children-first, compares the
 * estimated CPU and GPU cost of each operator (plus the cost of transitions between CPU and GPU
 * operators), and records {@link Optimization}s that keep parts of the plan on the CPU when the
 * GPU is estimated to be more expensive.
 *
 * <p>This file was decompiled (CFR) from Scala; the {@code Logging} forwarders below are the
 * trait-mixin bridge methods generated for {@code org.apache.spark.internal.Logging}.
 */
public class CostBasedOptimizer
implements Optimizer,
Logging {
    private transient Logger org$apache$spark$internal$Logging$$log_;

    // ---------------------------------------------------------------------------------------
    // org.apache.spark.internal.Logging trait forwarders (compiler-generated mixin bridges).
    // ---------------------------------------------------------------------------------------

    public String logName() {
        return Logging.logName$((Logging)this);
    }

    public Logger log() {
        return Logging.log$((Logging)this);
    }

    public void logInfo(Function0<String> msg) {
        Logging.logInfo$((Logging)this, msg);
    }

    public void logDebug(Function0<String> msg) {
        Logging.logDebug$((Logging)this, msg);
    }

    public void logTrace(Function0<String> msg) {
        Logging.logTrace$((Logging)this, msg);
    }

    public void logWarning(Function0<String> msg) {
        Logging.logWarning$((Logging)this, msg);
    }

    public void logError(Function0<String> msg) {
        Logging.logError$((Logging)this, msg);
    }

    public void logInfo(Function0<String> msg, Throwable throwable) {
        Logging.logInfo$((Logging)this, msg, (Throwable)throwable);
    }

    public void logDebug(Function0<String> msg, Throwable throwable) {
        Logging.logDebug$((Logging)this, msg, (Throwable)throwable);
    }

    public void logTrace(Function0<String> msg, Throwable throwable) {
        Logging.logTrace$((Logging)this, msg, (Throwable)throwable);
    }

    public void logWarning(Function0<String> msg, Throwable throwable) {
        Logging.logWarning$((Logging)this, msg, (Throwable)throwable);
    }

    public void logError(Function0<String> msg, Throwable throwable) {
        Logging.logError$((Logging)this, msg, (Throwable)throwable);
    }

    public boolean isTraceEnabled() {
        return Logging.isTraceEnabled$((Logging)this);
    }

    public void initializeLogIfNecessary(boolean isInterpreter) {
        Logging.initializeLogIfNecessary$((Logging)this, (boolean)isInterpreter);
    }

    public boolean initializeLogIfNecessary(boolean isInterpreter, boolean silent) {
        return Logging.initializeLogIfNecessary$((Logging)this, (boolean)isInterpreter, (boolean)silent);
    }

    public boolean initializeLogIfNecessary$default$2() {
        return Logging.initializeLogIfNecessary$default$2$((Logging)this);
    }

    public void initializeForcefully(boolean isInterpreter, boolean silent) {
        Logging.initializeForcefully$((Logging)this, (boolean)isInterpreter, (boolean)silent);
    }

    public Logger org$apache$spark$internal$Logging$$log_() {
        return this.org$apache$spark$internal$Logging$$log_;
    }

    public void org$apache$spark$internal$Logging$$log__$eq(Logger x$1) {
        this.org$apache$spark$internal$Logging$$log_ = x$1;
    }

    /**
     * Runs the cost-based optimizer over {@code plan}.
     *
     * @param conf plugin configuration used to build the CPU and GPU cost models
     * @param plan root of the plan to optimize; treated as the final operator of the query
     * @return the optimizations that were applied, in the order they were recorded
     *         (empty when nothing was changed)
     */
    @Override
    public Seq<Optimization> optimize(RapidsConf conf, SparkPlanMeta<SparkPlan> plan) {
        this.logTrace((Function0<String>)(Function0 & Serializable & scala.Serializable)() -> "CBO optimizing plan");
        CpuCostModel cpuCostModel = new CpuCostModel(conf);
        GpuCostModel gpuCostModel = new GpuCostModel(conf);
        ListBuffer<Optimization> optimizations = new ListBuffer<Optimization>();
        this.recursivelyOptimize(conf, cpuCostModel, gpuCostModel, plan, optimizations, true);
        if (optimizations.isEmpty()) {
            this.logTrace((Function0<String>)(Function0 & Serializable & scala.Serializable)() -> "CBO finished optimizing plan. No optimizations applied.");
        } else {
            this.logTrace((Function0<String>)(Function0 & Serializable & scala.Serializable)() -> new StringBuilder(55).append("CBO finished optimizing plan. ").append(optimizations.length()).append(" optimizations applied:\n\t").append(optimizations.mkString("\n\t")).toString());
        }
        return optimizations;
    }

    /**
     * Optimizes {@code plan} after recursively optimizing its children. Returns the estimated
     * (CPU cost, GPU cost) of the whole subtree after any optimizations were applied.
     *
     * <p>Side effects: appends {@link AvoidTransition} / {@link ReplaceSection} entries to
     * {@code optimizations}, calls {@code costPreventsRunningOnGpu()} /
     * {@code recursiveCostPreventsRunningOnGpu()} on operators that are kept on the CPU, and
     * stores the row-count estimate on {@code plan} via {@code estimatedOutputRows_$eq}.
     *
     * @param conf          plugin configuration (transition costs, default row count)
     * @param cpuCostModel  cost model for a single operator on the CPU
     * @param gpuCostModel  cost model for a single operator on the GPU
     * @param plan          root of the subtree to optimize
     * @param optimizations accumulator that new optimizations are appended to
     * @param finalOperator true for the root of the whole plan; a GPU root then also pays the
     *                      cost of transitioning its output to the CPU
     * @return a tuple of (total CPU cost, total GPU cost) for the subtree, as primitive doubles
     */
    private Tuple2<Object, Object> recursivelyOptimize(RapidsConf conf, CostModel cpuCostModel, CostModel gpuCostModel, SparkPlanMeta<SparkPlan> plan, ListBuffer<Optimization> optimizations, boolean finalOperator) {
        double operatorCpuCost = cpuCostModel.getCost(plan);
        double operatorGpuCost = gpuCostModel.getCost(plan);

        // Optimize the children first; each returns its own (cpuCost, gpuCost) pair.
        Seq childCosts = (Seq)plan.childPlans().map((Function1 & Serializable & scala.Serializable)child -> this.recursivelyOptimize(conf, cpuCostModel, gpuCostModel, (SparkPlanMeta<SparkPlan>)child, optimizations, false), Seq$.MODULE$.canBuildFrom());
        Tuple2 unzipped = childCosts.unzip((Function1)Predef$.MODULE$.$conforms());
        if (unzipped == null) {
            throw new MatchError((Object)unzipped);
        }
        Seq childCpuCosts = (Seq)unzipped._1();
        Seq childGpuCosts = (Seq)unzipped._2();

        double totalCpuCost = operatorCpuCost + BoxesRunTime.unboxToDouble((Object)childCpuCosts.sum((Numeric)Numeric.DoubleIsFractional$.MODULE$));
        // Boxed in a DoubleRef because it is mutated below and is also handed to the per-child
        // helper invoked from a lambda.
        DoubleRef totalGpuCost = DoubleRef.create(operatorGpuCost + BoxesRunTime.unboxToDouble((Object)childGpuCosts.sum((Numeric)Numeric.DoubleIsFractional$.MODULE$)));
        this.logCosts(plan, "Operator costs", operatorCpuCost, operatorGpuCost);
        this.logCosts(plan, "Operator + child costs", totalCpuCost, totalGpuCost.elem);
        plan.estimatedOutputRows_$eq(RowCountPlanVisitor$.MODULE$.visit(plan));

        // A transition is a child whose CPU/GPU placement differs from this operator's.
        int numTransitions = plan.childPlans().count((Function1 & Serializable & scala.Serializable)child -> BoxesRunTime.boxToBoolean((boolean)CostBasedOptimizer.$anonfun$recursivelyOptimize$2(this, plan, child)));
        this.logCosts(plan, new StringBuilder(15).append("numTransitions=").append(numTransitions).toString(), totalCpuCost, totalGpuCost.elem);

        if (numTransitions > 0) {
            if (this.canRunOnGpu(plan)) {
                // GPU operator: every CPU child needs a transition to the GPU.
                double transitionCost = BoxesRunTime.unboxToDouble((Object)((TraversableOnce)((TraversableLike)plan.childPlans().filterNot((Function1 & Serializable & scala.Serializable)child -> BoxesRunTime.boxToBoolean((boolean)this.canRunOnGpu(child)))).map((Function1 & Serializable & scala.Serializable)child -> BoxesRunTime.boxToDouble((double)this.transitionToGpuCost(conf, child)), Seq$.MODULE$.canBuildFrom())).sum((Numeric)Numeric.DoubleIsFractional$.MODULE$));
                if (operatorGpuCost + transitionCost > operatorCpuCost && !this.isExchangeOp(plan)) {
                    // Moving this single operator to the GPU costs more than running it on the
                    // CPU, so keep it on the CPU.
                    optimizations.append((Seq)Predef$.MODULE$.wrapRefArray((Object[])new Optimization[]{new AvoidTransition<SparkPlan>(plan)}));
                    plan.costPreventsRunningOnGpu();
                    totalGpuCost.elem = totalCpuCost;
                    this.logCosts(plan, "Avoid transition to GPU", totalCpuCost, totalGpuCost.elem);
                } else {
                    totalGpuCost.elem += transitionCost;
                    this.logCosts(plan, new StringBuilder(22).append("transitionFromCpuCost=").append(transitionCost).toString(), totalCpuCost, totalGpuCost.elem);
                }
            } else {
                // CPU operator: re-check each GPU child now that the cost of transitioning its
                // output back to the CPU is known.
                ((IterableLike)plan.childPlans().zip((GenIterable)childCosts, Seq$.MODULE$.canBuildFrom())).foreach((Function1 & Serializable & scala.Serializable)childAndCosts -> {
                    CostBasedOptimizer.$anonfun$recursivelyOptimize$5(this, conf, optimizations, totalCpuCost, totalGpuCost, childAndCosts);
                    return BoxedUnit.UNIT;
                });
                // Computed after the loop above, which may have called
                // recursiveCostPreventsRunningOnGpu() on some children.
                double transitionCost = BoxesRunTime.unboxToDouble((Object)((TraversableOnce)((TraversableLike)plan.childPlans().filter((Function1 & Serializable & scala.Serializable)child -> BoxesRunTime.boxToBoolean((boolean)this.canRunOnGpu(child)))).map((Function1 & Serializable & scala.Serializable)child -> BoxesRunTime.boxToDouble((double)this.transitionToCpuCost(conf, child)), Seq$.MODULE$.canBuildFrom())).sum((Numeric)Numeric.DoubleIsFractional$.MODULE$));
                totalGpuCost.elem += transitionCost;
                this.logCosts(plan, new StringBuilder(22).append("transitionFromGpuCost=").append(transitionCost).toString(), totalCpuCost, totalGpuCost.elem);
            }
        }

        if (finalOperator && this.canRunOnGpu(plan)) {
            // A GPU root also pays for transitioning its output to the CPU.
            double transitionCost = this.transitionToCpuCost(conf, plan);
            totalGpuCost.elem += transitionCost;
            this.logCosts(plan, new StringBuilder(38).append("final operator, transitionFromGpuCost=").append(transitionCost).toString(), totalCpuCost, totalGpuCost.elem);
        }

        if (totalGpuCost.elem > totalCpuCost && this.canRunOnGpu(plan) && !this.isExchangeOp(plan)) {
            // The whole GPU section rooted here is more expensive than the CPU: move it back.
            optimizations.append((Seq)Predef$.MODULE$.wrapRefArray((Object[])new Optimization[]{new ReplaceSection<SparkPlan>(plan, totalCpuCost, totalGpuCost.elem)}));
            plan.recursiveCostPreventsRunningOnGpu();
            totalGpuCost.elem = totalCpuCost;
            if (this.isTraceEnabled()) {
                // Only render the plan as a string when trace logging is on; the message
                // argument of logCosts is otherwise evaluated eagerly.
                this.logCosts(plan, new StringBuilder(16).append("ReplaceSection: ").append(plan).toString(), totalCpuCost, totalGpuCost.elem);
            }
        }

        if (!this.canRunOnGpu(plan) || this.isExchangeOp(plan)) {
            // Operators that stay on the CPU (and exchanges) report the CPU cost for both sides.
            totalGpuCost.elem = totalCpuCost;
            this.logCosts(plan, "Reset costs (not on GPU / exchange)", totalCpuCost, totalGpuCost.elem);
        }
        this.logCosts(plan, "END", totalCpuCost, totalGpuCost.elem);
        return new Tuple2.mcDD.sp(totalCpuCost, totalGpuCost.elem);
    }

    /**
     * Emits a trace-level line comparing the CPU and GPU cost of {@code plan}.
     *
     * @param plan    operator the costs belong to (its wrapped class's simple name is logged)
     * @param message free-form context for the log line
     * @param cpuCost estimated CPU cost
     * @param gpuCost estimated GPU cost
     */
    private void logCosts(SparkPlanMeta<?> plan, String message, double cpuCost, double gpuCost) {
        String sign = cpuCost == gpuCost ? "==" : (cpuCost < gpuCost ? "<" : ">");
        this.logTrace((Function0<String>)(Function0 & Serializable & scala.Serializable)() -> new StringBuilder(28).append("CBO [").append(plan.wrapped().getClass().getSimpleName()).append("] ").append(message).append(": ").append("cpuCost=").append(cpuCost).append(" ").append(sign).append(" gpuCost=").append(gpuCost).toString());
    }

    /**
     * Returns whether {@code plan} is currently placed on the GPU. An
     * {@code AdaptiveSparkPlanExec} always counts as GPU-capable; any other operator defers to
     * {@code canThisBeReplaced()}.
     */
    private boolean canRunOnGpu(SparkPlanMeta<?> plan) {
        return plan.wrapped() instanceof AdaptiveSparkPlanExec || plan.canThisBeReplaced();
    }

    /**
     * Estimated output row count of {@code plan}, taken from {@code RowCountPlanVisitor}. Falls
     * back to {@code conf.defaultRowCount()} when the visitor has no estimate.
     */
    private double estimatedRowCount(RapidsConf conf, SparkPlanMeta<SparkPlan> plan) {
        return BoxesRunTime.unboxToDouble((Object)RowCountPlanVisitor$.MODULE$.visit(plan).map((Function1 & Serializable & scala.Serializable)rows -> BoxesRunTime.boxToDouble((double)rows.toDouble())).getOrElse((Function0)(JFunction0.mcD.sp & Serializable & scala.Serializable)() -> conf.defaultRowCount()));
    }

    /**
     * Cost of moving {@code plan}'s output from the CPU to the GPU. This is the per-row
     * {@code GpuRowToColumnarExec} operator cost (0.0 when not configured) times the row count,
     * plus the CPU-read and GPU-write memory costs of the estimated data size.
     */
    private double transitionToGpuCost(RapidsConf conf, SparkPlanMeta<SparkPlan> plan) {
        double rowCount = this.estimatedRowCount(conf, plan);
        long dataSize = MemoryCostHelper$.MODULE$.estimateGpuMemory(((QueryPlan)plan.wrapped()).schema(), rowCount);
        double perRowCost = BoxesRunTime.unboxToDouble((Object)conf.getGpuOperatorCost("GpuRowToColumnarExec").getOrElse((Function0)(JFunction0.mcD.sp & Serializable & scala.Serializable)() -> 0.0));
        return perRowCost * rowCount + MemoryCostHelper$.MODULE$.calculateCost(dataSize, conf.cpuReadMemorySpeed()) + MemoryCostHelper$.MODULE$.calculateCost(dataSize, conf.gpuWriteMemorySpeed());
    }

    /**
     * Cost of moving {@code plan}'s output from the GPU to the CPU. This is the per-row
     * {@code GpuColumnarToRowExec} operator cost (0.0 when not configured) times the row count,
     * plus the GPU-read and CPU-write memory costs of the estimated data size.
     */
    private double transitionToCpuCost(RapidsConf conf, SparkPlanMeta<SparkPlan> plan) {
        double rowCount = this.estimatedRowCount(conf, plan);
        long dataSize = MemoryCostHelper$.MODULE$.estimateGpuMemory(((QueryPlan)plan.wrapped()).schema(), rowCount);
        double perRowCost = BoxesRunTime.unboxToDouble((Object)conf.getGpuOperatorCost("GpuColumnarToRowExec").getOrElse((Function0)(JFunction0.mcD.sp & Serializable & scala.Serializable)() -> 0.0));
        return perRowCost * rowCount + MemoryCostHelper$.MODULE$.calculateCost(dataSize, conf.gpuReadMemorySpeed()) + MemoryCostHelper$.MODULE$.calculateCost(dataSize, conf.cpuWriteMemorySpeed());
    }

    /** Delegates the shim-specific "is this an exchange operator" check. */
    private boolean isExchangeOp(SparkPlanMeta<?> plan) {
        return SparkShimImpl$.MODULE$.isExchangeOp(plan);
    }

    /** Synthetic lambda body: true when {@code x$2} and {@code plan$1} differ in GPU placement. */
    public static final /* synthetic */ boolean $anonfun$recursivelyOptimize$2(CostBasedOptimizer $this, SparkPlanMeta plan$1, SparkPlanMeta x$2) {
        return $this.canRunOnGpu(x$2) != $this.canRunOnGpu(plan$1);
    }

    /**
     * Synthetic lambda body: for a (child, (childCpuCost, childGpuCost)) pair under a CPU
     * operator, forces the child section back to the CPU when its GPU cost plus the GPU-to-CPU
     * transition cost exceeds its CPU cost.
     *
     * <p>The recorded {@link ReplaceSection} carries the child section's own CPU and GPU
     * costs. This matches the {@code ReplaceSection} created in {@code recursivelyOptimize},
     * which records the costs of the section being replaced. The parent totals
     * ({@code totalCpuCost$1}, {@code totalGpuCost$1}) are no longer used but stay in the
     * signature so the synthetic method is unchanged for existing callers.
     */
    public static final /* synthetic */ void $anonfun$recursivelyOptimize$5(CostBasedOptimizer $this, RapidsConf conf$1, ListBuffer optimizations$2, double totalCpuCost$1, DoubleRef totalGpuCost$1, Tuple2 x0$1) {
        if (x0$1 == null) {
            throw new MatchError((Object)x0$1);
        }
        SparkPlanMeta child = (SparkPlanMeta)x0$1._1();
        Tuple2 childCosts = (Tuple2)x0$1._2();
        if (childCosts == null) {
            throw new MatchError((Object)childCosts);
        }
        double childCpuCost = childCosts._1$mcD$sp();
        double childGpuCost = childCosts._2$mcD$sp();
        double transitionCost = $this.transitionToCpuCost(conf$1, child);
        double childGpuTotal = childGpuCost + transitionCost;
        if ($this.canRunOnGpu(child) && !$this.isExchangeOp(child) && childGpuTotal > childCpuCost) {
            optimizations$2.append((Seq)Predef$.MODULE$.wrapRefArray((Object[])new Optimization[]{new ReplaceSection(child, childCpuCost, childGpuTotal)}));
            child.recursiveCostPreventsRunningOnGpu();
        }
    }

    public CostBasedOptimizer() {
        Logging.$init$((Logging)this);
    }
}

