/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.optimizer.physical;

import com.google.common.base.Preconditions;
import java.util.ArrayList;
import java.util.Collection;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Stack;
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.hadoop.hive.ql.exec.FilterOperator;
import org.apache.hadoop.hive.ql.exec.FunctionInfo;
import org.apache.hadoop.hive.ql.exec.FunctionRegistry;
import org.apache.hadoop.hive.ql.exec.GroupByOperator;
import org.apache.hadoop.hive.ql.exec.MapJoinOperator;
import org.apache.hadoop.hive.ql.exec.Operator;
import org.apache.hadoop.hive.ql.exec.ScriptOperator;
import org.apache.hadoop.hive.ql.exec.SelectOperator;
import org.apache.hadoop.hive.ql.exec.Task;
import org.apache.hadoop.hive.ql.exec.Utilities;
import org.apache.hadoop.hive.ql.exec.tez.TezTask;
import org.apache.hadoop.hive.ql.io.HiveInputFormat;
import org.apache.hadoop.hive.ql.lib.DefaultGraphWalker;
import org.apache.hadoop.hive.ql.lib.DefaultRuleDispatcher;
import org.apache.hadoop.hive.ql.lib.Dispatcher;
import org.apache.hadoop.hive.ql.lib.Node;
import org.apache.hadoop.hive.ql.lib.NodeProcessor;
import org.apache.hadoop.hive.ql.lib.NodeProcessorCtx;
import org.apache.hadoop.hive.ql.lib.Rule;
import org.apache.hadoop.hive.ql.lib.RuleRegExp;
import org.apache.hadoop.hive.ql.lib.TaskGraphWalker;
import org.apache.hadoop.hive.ql.optimizer.physical.LlapClusterStateForCompile;
import org.apache.hadoop.hive.ql.optimizer.physical.PhysicalContext;
import org.apache.hadoop.hive.ql.optimizer.physical.PhysicalPlanResolver;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.plan.AggregationDesc;
import org.apache.hadoop.hive.ql.plan.BaseWork;
import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
import org.apache.hadoop.hive.ql.plan.ExprNodeGenericFuncDesc;
import org.apache.hadoop.hive.ql.plan.FilterDesc;
import org.apache.hadoop.hive.ql.plan.GroupByDesc;
import org.apache.hadoop.hive.ql.plan.MapJoinDesc;
import org.apache.hadoop.hive.ql.plan.MapWork;
import org.apache.hadoop.hive.ql.plan.OperatorDesc;
import org.apache.hadoop.hive.ql.plan.PartitionDesc;
import org.apache.hadoop.hive.ql.plan.ReduceWork;
import org.apache.hadoop.hive.ql.plan.SelectDesc;
import org.apache.hadoop.hive.ql.plan.Statistics;
import org.apache.hadoop.hive.ql.plan.TezWork;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class LlapDecider
implements PhysicalPlanResolver {
    /** Logger shared by the decider and its nested dispatcher. */
    protected static final transient Logger LOG = LoggerFactory.getLogger(LlapDecider.class);
    /** Configuration of the plan being resolved; assigned at the start of {@code resolve}. */
    private HiveConf conf;
    /** LLAP execution mode parsed from {@code LLAP_EXECUTION_MODE}; assigned in {@code resolve}. */
    private LlapMode mode;
    /** Cluster state consulted when reducer parallelism is adjusted for LLAP. */
    private final LlapClusterStateForCompile clusterState;

    /**
     * Creates a decider backed by the given cluster state.
     *
     * @param clusterState source of LLAP cluster information; stored as-is (no null check here)
     */
    public LlapDecider(LlapClusterStateForCompile clusterState) {
        this.clusterState = clusterState;
    }

    /**
     * Walks every root task of the physical plan and lets the dispatcher decide, per Tez work
     * item, whether it is switched to LLAP.
     *
     * <p>The execution mode is read from {@code LLAP_EXECUTION_MODE}. When it is
     * {@link LlapMode#none} the plan is returned without walking it.
     *
     * @param pctx physical context providing the configuration and the root tasks
     * @return the same {@code pctx} instance that was passed in
     * @throws SemanticException propagated from the dispatcher while walking the task graph
     * @throws IllegalArgumentException if the configured mode does not name a {@link LlapMode}
     *         constant
     */
    @Override
    public PhysicalContext resolve(PhysicalContext pctx) throws SemanticException {
        this.conf = pctx.getConf();
        String modeName = HiveConf.getVar(this.conf, HiveConf.ConfVars.LLAP_EXECUTION_MODE);
        try {
            this.mode = LlapMode.valueOf(modeName);
        } catch (IllegalArgumentException e) {
            // Enum.valueOf never returns null: it throws for an unknown name. A null check on the
            // result can therefore never fire, so report the offending value here instead while
            // keeping the exception type callers already see.
            throw new IllegalArgumentException("Unrecognized LLAP mode configuration: " + modeName, e);
        }
        LOG.info("llap mode: " + this.mode);
        if (this.mode == LlapMode.none) {
            LOG.info("LLAP disabled.");
            return pctx;
        }
        Dispatcher disp = new LlapDecisionDispatcher(pctx, this.mode);
        TaskGraphWalker ogw = new TaskGraphWalker(disp);
        ArrayList<Node> topNodes = new ArrayList<Node>(pctx.getRootTasks());
        ogw.startWalking(topNodes, null);
        return pctx;
    }

    class LlapDecisionDispatcher
    implements Dispatcher {
        /** Configuration taken from the physical context in the constructor. */
        private final HiveConf conf;
        /** Value of {@code LLAP_SKIP_COMPILE_UDF_CHECK}; when true, function expressions are not vetted. */
        private final boolean doSkipUdfCheck;
        /** Value of {@code LLAP_ALLOW_PERMANENT_FNS}; when false, any non-built-in function is rejected. */
        private final boolean arePermanentFnsAllowed;
        /** True when {@code LLAP_AUTO_ALLOW_UBER} is set and the mode is not {@code all}. */
        private final boolean shouldUber;
        /** Value of {@code TEZ_LLAP_MIN_REDUCER_PER_EXECUTOR}; values {@code <= 0} disable the parallelism adjustment. */
        private final float minReducersPerExec;
        /** Value of {@code LLAP_DAEMON_NUM_EXECUTORS}; used for nodes whose executor count is unknown. */
        private final int executorsPerNode;
        /** Hybrid grace hash joins collected by the map-join rule for the work item being evaluated; cleared after each item. */
        private List<MapJoinOperator> mapJoinOpList;
        /** Operator rules built once by {@code getRules()} and reused for every work item. */
        private final Map<Rule, NodeProcessor> rules;

        public LlapDecisionDispatcher(PhysicalContext pctx, LlapMode mode) {
            this.conf = pctx.getConf();
            this.doSkipUdfCheck = HiveConf.getBoolVar(this.conf, HiveConf.ConfVars.LLAP_SKIP_COMPILE_UDF_CHECK);
            this.arePermanentFnsAllowed = HiveConf.getBoolVar(this.conf, HiveConf.ConfVars.LLAP_ALLOW_PERMANENT_FNS);
            this.shouldUber = HiveConf.getBoolVar(this.conf, HiveConf.ConfVars.LLAP_AUTO_ALLOW_UBER) && mode != LlapMode.all;
            this.minReducersPerExec = HiveConf.getFloatVar(this.conf, HiveConf.ConfVars.TEZ_LLAP_MIN_REDUCER_PER_EXECUTOR);
            this.executorsPerNode = HiveConf.getIntVar(this.conf, HiveConf.ConfVars.LLAP_DAEMON_NUM_EXECUTORS);
            this.mapJoinOpList = new ArrayList<MapJoinOperator>();
            this.rules = this.getRules();
        }

        /**
         * Processes one node of the task graph: for a {@code TezTask}, every work item of its
         * {@code TezWork} is evaluated; any other task is ignored.
         *
         * @param nd the task node being visited
         * @param stack walker stack (unused)
         * @param nodeOutputs outputs of previously visited nodes (unused)
         * @return always {@code null}
         * @throws SemanticException propagated from the evaluation of a work item
         */
        @Override
        public Object dispatch(Node nd, Stack<Node> stack, Object ... nodeOutputs) throws SemanticException {
            Task task = (Task)nd;
            if (!(task instanceof TezTask)) {
                return null;
            }
            TezWork tezWork = (TezWork)((TezTask)task).getWork();
            for (BaseWork unit : tezWork.getAllWork()) {
                handleWork(tezWork, unit);
            }
            return null;
        }

        /**
         * Evaluates a single work item and, when it qualifies, prepares and converts it for LLAP.
         *
         * <p>For a qualifying item the hybrid grace hash joins recorded during evaluation are
         * switched off, reducer parallelism is adjusted and the item is flagged for LLAP.
         *
         * @param tezWork the Tez work graph containing {@code work}
         * @param work the work item to evaluate
         * @throws SemanticException propagated from evaluation or conversion
         */
        private void handleWork(TezWork tezWork, BaseWork work) throws SemanticException {
            final boolean llapCapable = evaluateWork(tezWork, work);
            final String verdict = llapCapable ? "can" : "cannot";
            LOG.debug("Work " + work + " " + verdict + " be done in LLAP");

            if (llapCapable) {
                // Filled by the map-join rule while the operators of this work item were walked.
                for (MapJoinOperator joinOp : mapJoinOpList) {
                    LOG.debug("Disabling hybrid grace hash join in case of LLAP and non-dynamic partition hash join.");
                    ((MapJoinDesc)joinOp.getConf()).setHybridHashJoin(false);
                }
                adjustAutoParallelism(work);
                convertWork(tezWork, work);
            }

            // The list is scoped to one work item; reset it before the next one is evaluated.
            mapJoinOpList.clear();
        }

        /**
         * Raises the reducer count of a reduce work item so it matches the executor capacity.
         *
         * <p>Nothing happens when the minimum-reducers-per-executor factor is {@code <= 0}, when
         * the work item is not a {@code ReduceWork}, or when the reduce work uses neither auto
         * reduce parallelism nor uniform distribution.
         *
         * @param work the work item that is about to be converted to LLAP
         */
        private void adjustAutoParallelism(BaseWork work) {
            if (minReducersPerExec <= 0.0f || !(work instanceof ReduceWork)) {
                return;
            }
            final ReduceWork reduceWork = (ReduceWork)work;
            final boolean autoParallel = reduceWork.isAutoReduceParallelism();
            if (!autoParallel && !reduceWork.isUniformDistribution()) {
                return;
            }

            final int targetCount = targetReducerCount();

            if (!autoParallel) {
                // Fixed parallelism: only ever raise the reducer count.
                reduceWork.setNumReduceTasks(Math.max(reduceWork.getNumReduceTasks(), targetCount));
                return;
            }

            // New lower bound: at least the current minimum and the target, capped by MAXREDUCERS.
            final int raisedMin = Math.max(reduceWork.getMinReduceTasks(), targetCount);
            final int newMin = Math.min(conf.getIntVar(HiveConf.ConfVars.MAXREDUCERS), raisedMin);

            if (newMin < reduceWork.getMaxReduceTasks()) {
                reduceWork.setMinReduceTasks(newMin);
                reduceWork.getEdgePropRef().setAutoReduce(conf, true, newMin, reduceWork.getMaxReduceTasks(), conf.getLongVar(HiveConf.ConfVars.BYTESPERREDUCER));
            } else {
                // The lower bound reached the maximum, so there is no range left to auto-tune.
                reduceWork.setAutoReduceParallelism(false);
                reduceWork.setNumReduceTasks(newMin);
                reduceWork.getEdgePropRef().setAutoReduce(null, false, 0, 0, 0L);
            }
        }

        /**
         * Computes the reducer count to aim for: the configured factor times the executor count.
         * Without cluster information a single node with the configured executor count is assumed.
         */
        private int targetReducerCount() {
            final LlapClusterStateForCompile cluster = LlapDecider.this.clusterState;
            cluster.initClusterInfo();
            if (cluster.hasClusterInfo()) {
                return (int)Math.ceil(minReducersPerExec * (float)(cluster.getKnownExecutorCount() + cluster.getNodeCountWithUnknownExecutors() * executorsPerNode));
            }
            LOG.warn("Cannot determine LLAP cluster information");
            return (int)Math.ceil(minReducersPerExec * 1.0f * (float)executorsPerNode);
        }

        /**
         * Flags the work item for LLAP execution, and additionally for uber mode when uber is
         * allowed and the item is a childless {@code ReduceWork} with exactly one reduce task.
         *
         * @param tezWork the Tez work graph containing {@code work}
         * @param work the work item to convert
         * @throws SemanticException declared for callers; not thrown by the visible body
         */
        private void convertWork(TezWork tezWork, BaseWork work) throws SemanticException {
            final boolean uberCandidate = shouldUber
                && tezWork.getChildren(work).isEmpty()
                && work instanceof ReduceWork
                && ((ReduceWork)work).getNumReduceTasks() == 1;
            if (uberCandidate) {
                LOG.info("Converting work to uber: {}", (Object)work);
                work.setUberMode(true);
            }
            work.setLlapMode(true);
        }

        /**
         * Decides whether a work item may run in LLAP under the current mode.
         *
         * <p>The operators are vetted first. Modes {@code all} and {@code only} then accept any
         * work item, {@code map} accepts only {@code MapWork}, and {@code auto} applies the
         * configurable heuristics: parents in LLAP, vectorized inputs, column stats, and the
         * input and output size limits.
         *
         * @param tezWork the Tez work graph containing {@code work}
         * @param work the work item to evaluate
         * @return {@code true} if the work item can be converted to LLAP
         * @throws SemanticException propagated from walking the operators
         * @throws RuntimeException in mode {@code only} when some operator cannot run in LLAP
         */
        private boolean evaluateWork(TezWork tezWork, BaseWork work) throws SemanticException {
            LOG.info("Evaluating work item: " + work.getName());

            final LlapMode currentMode = LlapDecider.this.mode;
            if (currentMode == LlapMode.none) {
                return false;
            }

            if (!evaluateOperators(work)) {
                LOG.info("some operators cannot be run in llap");
                if (currentMode != LlapMode.only) {
                    return false;
                }
                throw new RuntimeException("Cannot run all parts of query in llap. Failing since " + HiveConf.ConfVars.LLAP_EXECUTION_MODE.varname + " is set to " + LlapMode.only.name());
            }

            if (currentMode == LlapMode.all || currentMode == LlapMode.only) {
                LOG.info("LLAP mode set to '" + currentMode + "' so can convert any work.");
                return true;
            }
            if (currentMode == LlapMode.map) {
                return work instanceof MapWork;
            }

            // Only the heuristic mode is left.
            assert (currentMode == LlapMode.auto) : "Mode must be " + LlapMode.auto.name() + " at this point";

            if (HiveConf.getBoolVar(conf, HiveConf.ConfVars.LLAP_AUTO_ENFORCE_TREE) && !checkParentsInLlap(tezWork, work)) {
                LOG.info("Parent not in llap.");
                return false;
            }

            if (work instanceof MapWork
                && HiveConf.getBoolVar(conf, HiveConf.ConfVars.LLAP_AUTO_ENFORCE_VECTORIZED)
                && !checkInputsVectorized((MapWork)work)) {
                LOG.info("Inputs not vectorized.");
                return false;
            }

            if (HiveConf.getBoolVar(conf, HiveConf.ConfVars.LLAP_AUTO_ENFORCE_STATS) && !checkPartialStatsAvailable(work)) {
                LOG.info("No column stats available.");
                return false;
            }

            // A negative limit disables the corresponding size check.
            final long maxInput = HiveConf.getLongVar(conf, HiveConf.ConfVars.LLAP_AUTO_MAX_INPUT);
            final long expectedInput = computeInputSize(work);
            if (maxInput >= 0L && expectedInput > maxInput) {
                LOG.info(String.format("Inputs too big (%d > %d)", expectedInput, maxInput));
                return false;
            }

            final long maxOutput = HiveConf.getLongVar(conf, HiveConf.ConfVars.LLAP_AUTO_MAX_OUTPUT);
            final long expectedOutput = computeOutputSize(work);
            if (maxOutput >= 0L && expectedOutput > maxOutput) {
                LOG.info(String.format("Outputs too big (%d > %d)", expectedOutput, maxOutput));
                return false;
            }

            LOG.info("Can run work " + work.getName() + " in llap mode.");
            return true;
        }

        /**
         * Walks an expression tree breadth-first and rejects it if it contains a function that
         * may not run in LLAP.
         *
         * <p>Unless the UDF check is skipped by configuration, every generic function node must
         * be built-in, or (when permanent functions are allowed) a permanent function.
         *
         * @param expr root of the expression tree; {@code null} nodes are skipped
         * @return {@code true} if no disallowed function was found
         */
        private boolean checkExpression(ExprNodeDesc expr) {
            // LinkedList rather than ArrayDeque: null entries are tolerated here and skipped below,
            // whereas ArrayDeque rejects null elements.
            LinkedList<ExprNodeDesc> pending = new LinkedList<ExprNodeDesc>();
            pending.add(expr);
            while (!pending.isEmpty()) {
                ExprNodeDesc cur = pending.removeFirst();
                if (cur == null) {
                    continue;
                }
                // Log the node actually being examined (not the root on every iteration), and
                // only after the null check so a null node cannot trigger an NPE in debug mode.
                if (LOG.isDebugEnabled()) {
                    LOG.debug(String.format("Checking '%s'", cur.getExprString()));
                }
                if (cur.getChildren() != null) {
                    pending.addAll(cur.getChildren());
                }
                if (this.doSkipUdfCheck || !(cur instanceof ExprNodeGenericFuncDesc)) {
                    continue;
                }
                ExprNodeGenericFuncDesc funcDesc = (ExprNodeGenericFuncDesc)cur;
                if (FunctionRegistry.isBuiltInFuncExpr(funcDesc)) {
                    continue;
                }
                if (!this.arePermanentFnsAllowed) {
                    LOG.info("Not a built-in function: " + cur.getExprString() + " (permanent functions are disabled)");
                    return false;
                }
                if (!FunctionRegistry.isPermanentFunction(funcDesc)) {
                    LOG.info("Not a built-in or permanent function: " + cur.getExprString());
                    return false;
                }
            }
            return true;
        }

        /**
         * Checks one aggregation: its parameter expressions must pass {@code checkExpressions}
         * and its UDAF must be registered and native.
         *
         * @param agg the aggregation descriptor to vet
         * @return {@code true} if the aggregation may run in LLAP
         * @throws SemanticException propagated from the function registry lookup
         */
        private boolean checkAggregator(AggregationDesc agg) throws SemanticException {
            if (LOG.isDebugEnabled()) {
                LOG.debug(String.format("Checking '%s'", agg.getExprString()));
            }

            final boolean parametersOk = checkExpressions(agg.getParameters());
            // The lookup happens regardless of the parameter verdict.
            final FunctionInfo info = FunctionRegistry.getFunctionInfo(agg.getGenericUDAFName());

            if (parametersOk && info != null && info.isNative()) {
                return true;
            }
            LOG.info("Aggregator is not native: " + agg.getExprString());
            return false;
        }

        /**
         * Vets a collection of expressions, stopping at the first one that is rejected.
         *
         * @param exprs expressions to check
         * @return {@code true} if every expression passes {@code checkExpression}
         */
        private boolean checkExpressions(Collection<ExprNodeDesc> exprs) {
            for (ExprNodeDesc candidate : exprs) {
                if (!checkExpression(candidate)) {
                    return false;
                }
            }
            return true;
        }

        /**
         * Vets a collection of aggregations, stopping at the first one that is rejected.
         * A {@code SemanticException} raised while checking is logged and treated as a rejection.
         *
         * @param aggs aggregations to check
         * @return {@code true} if every aggregation passes {@code checkAggregator}
         */
        private boolean checkAggregators(Collection<AggregationDesc> aggs) {
            try {
                for (AggregationDesc candidate : aggs) {
                    if (!checkAggregator(candidate)) {
                        return false;
                    }
                }
                return true;
            } catch (SemanticException e) {
                LOG.warn("Exception testing aggregators.", (Throwable)e);
                return false;
            }
        }

        /**
         * Builds the operator rules used to vet a work item, in insertion order.
         *
         * <p>Each processor returns a {@code Boolean}: {@code false} means the matched operator
         * cannot run in LLAP. Script operators are always rejected; filter, group-by and select
         * operators are rejected when they reference disallowed functions. Unless grace join is
         * enabled for LLAP, a further rule records hybrid grace hash joins that are not dynamic
         * partition hash joins in {@code mapJoinOpList}.
         *
         * @return the rule-to-processor map
         */
        private Map<Rule, NodeProcessor> getRules() {
            Map<Rule, NodeProcessor> opRules = new LinkedHashMap<Rule, NodeProcessor>();
            opRules.put(new RuleRegExp("No scripts", ScriptOperator.getOperatorName() + "%"), new NodeProcessor(){

                @Override
                public Object process(Node n, Stack<Node> s, NodeProcessorCtx c, Object ... os) {
                    LOG.debug("Cannot run operator [" + n + "] in llap mode.");
                    return Boolean.FALSE;
                }
            });
            opRules.put(new RuleRegExp("No user code in fil", FilterOperator.getOperatorName() + "%"), new NodeProcessor(){

                @Override
                public Object process(Node n, Stack<Node> s, NodeProcessorCtx c, Object ... os) {
                    ExprNodeDesc expr = ((FilterDesc)((FilterOperator)n).getConf()).getPredicate();
                    boolean allowed = LlapDecisionDispatcher.this.checkExpression(expr);
                    if (!allowed) {
                        LOG.info("Cannot run filter operator [" + n + "] in llap mode");
                    }
                    return Boolean.valueOf(allowed);
                }
            });
            opRules.put(new RuleRegExp("No user code in gby", GroupByOperator.getOperatorName() + "%"), new NodeProcessor(){

                @Override
                public Object process(Node n, Stack<Node> s, NodeProcessorCtx c, Object ... os) {
                    ArrayList<AggregationDesc> aggs = ((GroupByDesc)((Operator)n).getConf()).getAggregators();
                    boolean allowed = LlapDecisionDispatcher.this.checkAggregators(aggs);
                    if (!allowed) {
                        LOG.info("Cannot run group by operator [" + n + "] in llap mode");
                    }
                    return Boolean.valueOf(allowed);
                }
            });
            opRules.put(new RuleRegExp("No user code in select", SelectOperator.getOperatorName() + "%"), new NodeProcessor(){

                @Override
                public Object process(Node n, Stack<Node> s, NodeProcessorCtx c, Object ... os) {
                    List<ExprNodeDesc> exprs = ((SelectDesc)((Operator)n).getConf()).getColList();
                    boolean allowed = LlapDecisionDispatcher.this.checkExpressions(exprs);
                    if (!allowed) {
                        LOG.info("Cannot run select operator [" + n + "] in llap mode");
                    }
                    return Boolean.valueOf(allowed);
                }
            });
            if (!this.conf.getBoolVar(HiveConf.ConfVars.LLAP_ENABLE_GRACE_JOIN_IN_LLAP)) {
                opRules.put(new RuleRegExp("Disable grace hash join if LLAP mode and not dynamic partition hash join", MapJoinOperator.getOperatorName() + "%"), new NodeProcessor(){

                    @Override
                    public Object process(Node n, Stack<Node> s, NodeProcessorCtx c, Object ... os) {
                        MapJoinOperator mapJoinOp = (MapJoinOperator)n;
                        MapJoinDesc joinDesc = (MapJoinDesc)mapJoinOp.getConf();
                        if (joinDesc.isHybridHashJoin() && !joinDesc.isDynamicPartitionHashJoin()) {
                            // Only recorded here; the flag is flipped later if the work goes to LLAP.
                            LlapDecisionDispatcher.this.mapJoinOpList.add(mapJoinOp);
                        }
                        return Boolean.TRUE;
                    }
                });
            }
            return opRules;
        }

        /**
         * Walks the operator tree of a work item with the LLAP rules and reports whether any
         * processor vetoed an operator.
         *
         * @param work the work item whose root operators start the walk
         * @return {@code false} if any visited node produced a {@code false} result; nodes with
         *         no result are ignored
         * @throws SemanticException propagated from the graph walk
         */
        private boolean evaluateOperators(BaseWork work) throws SemanticException {
            DefaultRuleDispatcher ruleDispatcher = new DefaultRuleDispatcher(null, rules, null);
            DefaultGraphWalker walker = new DefaultGraphWalker(ruleDispatcher);

            ArrayList<Node> roots = new ArrayList<Node>(work.getAllRootOperators());
            HashMap<Node, Object> verdicts = new HashMap<Node, Object>();
            walker.startWalking(roots, verdicts);

            for (Object verdict : verdicts.values()) {
                if (verdict != null && !((Boolean)verdict).booleanValue()) {
                    return false;
                }
            }
            return true;
        }

        /**
         * Checks that every parent of the work item in the Tez graph is already flagged for LLAP.
         *
         * @param tezWork the Tez work graph
         * @param base the work item whose parents are inspected
         * @return {@code true} if all parents are in LLAP mode (also when there are none)
         */
        private boolean checkParentsInLlap(TezWork tezWork, BaseWork base) {
            for (BaseWork parent : tezWork.getParents(base)) {
                if (!parent.getLlapMode()) {
                    LOG.info("Not all parents are run in llap");
                    return false;
                }
            }
            return true;
        }

        /**
         * Checks that every input partition of a map work item can be read in vectorized form:
         * either its input format is vectorized and not excluded by configuration, or the
         * non-vector wrapper is enabled and the format can be wrapped for LLAP.
         *
         * @param mapWork the map work item whose partitions are inspected
         * @return {@code true} if all partitions qualify
         */
        private boolean checkInputsVectorized(MapWork mapWork) {
            final boolean mayWrap = HiveConf.getBoolVar(conf, HiveConf.ConfVars.LLAP_IO_NONVECTOR_WRAPPER_ENABLED);
            final Collection<Class<?>> excludedInputFormats = Utilities.getClassNamesFromConfig(conf, HiveConf.ConfVars.HIVE_VECTORIZATION_VECTORIZED_INPUT_FILE_FORMAT_EXCLUDES);

            for (PartitionDesc partition : mapWork.getPathToPartitionInfo().values()) {
                final boolean nativelyVectorized = Utilities.isInputFileFormatVectorized(partition)
                    && !excludedInputFormats.contains(partition.getInputFileFormatClass());
                if (nativelyVectorized) {
                    continue;
                }
                // Fall back to the wrapper only when the native path did not qualify.
                final boolean wrappable = mayWrap && HiveInputFormat.canWrapForLlap(partition.getInputFileFormatClass(), true);
                if (wrappable) {
                    continue;
                }
                LOG.info("Input format: " + partition.getInputFileFormatClassName() + ", doesn't provide vectorized input");
                return false;
            }
            return true;
        }

        /**
         * Checks that every root operator of the work item has at least partial column
         * statistics.
         *
         * <p>An operator without a statistics object is treated like one whose column stats
         * state is {@code NONE}, matching how {@code computeEdgeSize} tolerates missing
         * statistics instead of failing with a {@code NullPointerException}.
         *
         * @param base the work item whose root operators are inspected
         * @return {@code false} if any root operator has no statistics or a column stats state
         *         of {@code NONE}
         */
        private boolean checkPartialStatsAvailable(BaseWork base) {
            for (Operator<? extends OperatorDesc> o : base.getAllRootOperators()) {
                Statistics stats = o.getStatistics();
                if (stats == null || stats.getColumnStatsState() == Statistics.State.NONE) {
                    return false;
                }
            }
            return true;
        }

        /**
         * Sums the estimated data size over the root operators (input side) or the leaf
         * operators (output side) of a work item.
         *
         * @param base the work item
         * @param input {@code true} to sum over root operators, {@code false} for leaf operators
         * @return the total size, or {@code Long.MAX_VALUE} when an operator has no statistics,
         *         reports a negative size, or the sum would overflow
         */
        private long computeEdgeSize(BaseWork base, boolean input) {
            final Collection<? extends Operator<? extends OperatorDesc>> edgeOperators;
            if (input) {
                edgeOperators = base.getAllRootOperators();
            } else {
                edgeOperators = base.getAllLeafOperators();
            }

            long total = 0L;
            for (Operator<? extends OperatorDesc> op : edgeOperators) {
                final Statistics stats = op.getStatistics();
                if (stats == null) {
                    return Long.MAX_VALUE;
                }
                final long opSize = stats.getDataSize();
                // Saturate on an unknown (negative) size or when adding would overflow.
                if (opSize < 0L || Long.MAX_VALUE - total < opSize) {
                    return Long.MAX_VALUE;
                }
                total += opSize;
            }
            return total;
        }

        /**
         * Estimates the input size of a work item from its root operators.
         *
         * @param base the work item
         * @return the result of {@code computeEdgeSize(base, true)}
         */
        private long computeInputSize(BaseWork base) {
            final boolean inputSide = true;
            return computeEdgeSize(base, inputSide);
        }

        /**
         * Estimates the output size of a work item from its leaf operators.
         *
         * @param base the work item
         * @return the result of {@code computeEdgeSize(base, false)}
         */
        private long computeOutputSize(BaseWork base) {
            final boolean inputSide = false;
            return computeEdgeSize(base, inputSide);
        }
    }

    /**
     * LLAP execution modes. The constant names are the values accepted for
     * {@code LLAP_EXECUTION_MODE}, since the setting is parsed with {@code valueOf}.
     */
    public static enum LlapMode {
        /** Only {@code MapWork} items (whose operators pass the checks) are converted. */
        map,
        /** Every work item whose operators pass the checks is converted; uber mode is never used. */
        all,
        /** LLAP is disabled; the plan is returned unchanged. */
        none,
        /** Like {@code all}, but evaluation throws when some operator cannot run in LLAP. */
        only,
        /** Work items are converted only if they also pass the configurable heuristics. */
        auto;

    }
}

