/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.optimizer.physical;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Stack;
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.hadoop.hive.ql.exec.FilterOperator;
import org.apache.hadoop.hive.ql.exec.FunctionInfo;
import org.apache.hadoop.hive.ql.exec.FunctionRegistry;
import org.apache.hadoop.hive.ql.exec.GroupByOperator;
import org.apache.hadoop.hive.ql.exec.Operator;
import org.apache.hadoop.hive.ql.exec.ScriptOperator;
import org.apache.hadoop.hive.ql.exec.SelectOperator;
import org.apache.hadoop.hive.ql.exec.Task;
import org.apache.hadoop.hive.ql.exec.tez.TezTask;
import org.apache.hadoop.hive.ql.exec.vector.VectorizedInputFormatInterface;
import org.apache.hadoop.hive.ql.lib.DefaultGraphWalker;
import org.apache.hadoop.hive.ql.lib.DefaultRuleDispatcher;
import org.apache.hadoop.hive.ql.lib.Dispatcher;
import org.apache.hadoop.hive.ql.lib.Node;
import org.apache.hadoop.hive.ql.lib.NodeProcessor;
import org.apache.hadoop.hive.ql.lib.NodeProcessorCtx;
import org.apache.hadoop.hive.ql.lib.Rule;
import org.apache.hadoop.hive.ql.lib.RuleRegExp;
import org.apache.hadoop.hive.ql.lib.TaskGraphWalker;
import org.apache.hadoop.hive.ql.optimizer.physical.PhysicalContext;
import org.apache.hadoop.hive.ql.optimizer.physical.PhysicalPlanResolver;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.plan.AggregationDesc;
import org.apache.hadoop.hive.ql.plan.BaseWork;
import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
import org.apache.hadoop.hive.ql.plan.ExprNodeGenericFuncDesc;
import org.apache.hadoop.hive.ql.plan.FilterDesc;
import org.apache.hadoop.hive.ql.plan.GroupByDesc;
import org.apache.hadoop.hive.ql.plan.MapWork;
import org.apache.hadoop.hive.ql.plan.OperatorDesc;
import org.apache.hadoop.hive.ql.plan.PartitionDesc;
import org.apache.hadoop.hive.ql.plan.ReduceWork;
import org.apache.hadoop.hive.ql.plan.SelectDesc;
import org.apache.hadoop.hive.ql.plan.Statistics;
import org.apache.hadoop.hive.ql.plan.TezWork;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Physical plan resolver that walks the task graph and, for every {@link TezTask},
 * decides per {@link BaseWork} whether it should be flagged to run in LLAP mode.
 *
 * <p>The decision is driven by the {@link LlapMode} read from
 * {@code HiveConf.ConfVars.LLAP_EXECUTION_MODE}; see
 * {@link LlapDecisionDispatcher#evaluateWork(TezWork, BaseWork)} for the individual checks.
 */
public class LlapDecider
implements PhysicalPlanResolver {
    protected static final Logger LOG = LoggerFactory.getLogger(LlapDecider.class);
    private PhysicalContext physicalContext;
    private HiveConf conf;
    private LlapMode mode;

    /**
     * Reads the configured LLAP execution mode and, unless it is {@link LlapMode#none},
     * walks all root tasks of the plan with a {@link LlapDecisionDispatcher}.
     *
     * @param pctx the physical context whose root tasks are examined
     * @return the same {@code pctx} instance that was passed in
     * @throws SemanticException if walking the task graph fails
     */
    @Override
    public PhysicalContext resolve(PhysicalContext pctx) throws SemanticException {
        this.physicalContext = pctx;
        this.conf = pctx.getConf();
        this.mode = LlapMode.valueOf(HiveConf.getVar(this.conf, HiveConf.ConfVars.LLAP_EXECUTION_MODE));
        LOG.info("llap mode: " + this.mode);
        if (this.mode == LlapMode.none) {
            LOG.info("LLAP disabled.");
            return pctx;
        }
        Dispatcher disp = new LlapDecisionDispatcher(pctx);
        TaskGraphWalker ogw = new TaskGraphWalker(disp);
        ArrayList<Node> topNodes = new ArrayList<Node>();
        topNodes.addAll(pctx.getRootTasks());
        ogw.startWalking(topNodes, null);
        return pctx;
    }

    /**
     * Dispatcher invoked for every task in the graph; only {@link TezTask}s are processed.
     */
    class LlapDecisionDispatcher
    implements Dispatcher {
        private final PhysicalContext pctx;
        private final HiveConf conf;

        public LlapDecisionDispatcher(PhysicalContext pctx) {
            this.pctx = pctx;
            this.conf = pctx.getConf();
        }

        /**
         * Evaluates every work item of a {@link TezTask}; other node types are ignored.
         *
         * @return always {@code null}
         */
        @Override
        public Object dispatch(Node nd, Stack<Node> stack, Object ... nodeOutputs) throws SemanticException {
            if (nd instanceof TezTask) {
                TezWork work = ((TezTask)nd).getWork();
                for (BaseWork w : work.getAllWork()) {
                    this.handleWork(work, w);
                }
            }
            return null;
        }

        /** Converts {@code work} to LLAP mode when {@link #evaluateWork} accepts it. */
        private void handleWork(TezWork tezWork, BaseWork work) throws SemanticException {
            if (this.evaluateWork(tezWork, work)) {
                this.convertWork(tezWork, work);
            }
        }

        /**
         * Flags {@code work} as LLAP. Additionally enables uber mode when
         * {@code LLAP_AUTO_ALLOW_UBER} is set and the work is a childless
         * {@link ReduceWork} with exactly one reduce task.
         */
        private void convertWork(TezWork tezWork, BaseWork work) throws SemanticException {
            boolean uberAllowed = HiveConf.getBoolVar(this.conf, HiveConf.ConfVars.LLAP_AUTO_ALLOW_UBER);
            if (uberAllowed
                    && tezWork.getChildren(work).isEmpty()
                    && work instanceof ReduceWork
                    && ((ReduceWork)work).getNumReduceTasks() == 1) {
                work.setUberMode(true);
            }
            work.setLlapMode(true);
        }

        /**
         * Decides whether {@code work} may run in LLAP.
         *
         * <p>Checks are applied in this order: mode {@code none} rejects; a {@link MapWork}
         * using the one-null-row input format rejects; any operator failing
         * {@link #evaluateOperators} rejects; mode {@code all} accepts; mode {@code map}
         * accepts only {@link MapWork}. In mode {@code auto} the optional parent, vectorization
         * and statistics checks plus the input/output size limits are applied.
         * A negative size limit disables that limit.
         *
         * @return {@code true} if the work item should be converted to LLAP mode
         */
        private boolean evaluateWork(TezWork tezWork, BaseWork work) throws SemanticException {
            LOG.info("Evaluating work item: " + work.getName());
            if (LlapDecider.this.mode == LlapMode.none) {
                return false;
            }
            if (work instanceof MapWork && ((MapWork)work).isUseOneNullRowInputFormat()) {
                return false;
            }
            if (!this.evaluateOperators(work)) {
                LOG.info("some operators cannot be run in llap");
                return false;
            }
            if (LlapDecider.this.mode == LlapMode.all) {
                return true;
            }
            if (LlapDecider.this.mode == LlapMode.map) {
                return work instanceof MapWork;
            }
            assert (LlapDecider.this.mode == LlapMode.auto);
            if (HiveConf.getBoolVar(this.conf, HiveConf.ConfVars.LLAP_AUTO_ENFORCE_TREE)
                    && !this.checkParentsInLlap(tezWork, work)) {
                LOG.info("Parent not in llap.");
                return false;
            }
            if (work instanceof MapWork
                    && HiveConf.getBoolVar(this.conf, HiveConf.ConfVars.LLAP_AUTO_ENFORCE_VECTORIZED)
                    && !this.checkInputsVectorized((MapWork)work)) {
                LOG.info("Inputs not vectorized.");
                return false;
            }
            if (HiveConf.getBoolVar(this.conf, HiveConf.ConfVars.LLAP_AUTO_ENFORCE_STATS)
                    && !this.checkPartialStatsAvailable(work)) {
                LOG.info("No column stats available.");
                return false;
            }
            long maxInput = HiveConf.getLongVar(this.conf, HiveConf.ConfVars.LLAP_AUTO_MAX_INPUT);
            long expectedInput = this.computeInputSize(work);
            if (maxInput >= 0L && expectedInput > maxInput) {
                LOG.info(String.format("Inputs too big (%d > %d)", expectedInput, maxInput));
                return false;
            }
            long maxOutput = HiveConf.getLongVar(this.conf, HiveConf.ConfVars.LLAP_AUTO_MAX_OUTPUT);
            long expectedOutput = this.computeOutputSize(work);
            if (maxOutput >= 0L && expectedOutput > maxOutput) {
                LOG.info(String.format("Outputs too big (%d > %d)", expectedOutput, maxOutput));
                return false;
            }
            return true;
        }

        /**
         * Walks the expression tree rooted at {@code expr} breadth-first and rejects it as soon
         * as a generic function node is found that is not a built-in function.
         *
         * @return {@code true} if no non-built-in function was found
         */
        private boolean checkExpression(ExprNodeDesc expr) {
            // LinkedList rather than ArrayDeque: null entries are tolerated and skipped below.
            LinkedList<ExprNodeDesc> pending = new LinkedList<ExprNodeDesc>();
            pending.add(expr);
            while (!pending.isEmpty()) {
                ExprNodeDesc cur = pending.removeFirst();
                if (cur == null) {
                    continue;
                }
                // Log the node actually being examined, not the root of the tree.
                if (LOG.isDebugEnabled()) {
                    LOG.debug(String.format("Checking '%s'", cur.getExprString()));
                }
                if (cur.getChildren() != null) {
                    pending.addAll(cur.getChildren());
                }
                if (cur instanceof ExprNodeGenericFuncDesc
                        && !FunctionRegistry.isBuiltInFuncExpr((ExprNodeGenericFuncDesc)cur)) {
                    LOG.info("Not a built-in function: " + cur.getExprString());
                    return false;
                }
            }
            return true;
        }

        /**
         * Accepts an aggregation only if all of its parameter expressions pass
         * {@link #checkExpressions} and its UDAF is registered and native.
         *
         * @throws SemanticException if the function registry lookup fails
         */
        private boolean checkAggregator(AggregationDesc agg) throws SemanticException {
            if (LOG.isDebugEnabled()) {
                LOG.debug(String.format("Checking '%s'", agg.getExprString()));
            }
            boolean paramsOk = this.checkExpressions(agg.getParameters());
            // The lookup is performed unconditionally, as before, so lookup failures still surface.
            FunctionInfo fi = FunctionRegistry.getFunctionInfo(agg.getGenericUDAFName());
            boolean result = paramsOk && fi != null && fi.isNative();
            if (!result) {
                LOG.info("Aggregator is not native: " + agg.getExprString());
            }
            return result;
        }

        /** Returns {@code true} only if every expression passes {@link #checkExpression}. */
        private boolean checkExpressions(Collection<ExprNodeDesc> exprs) {
            for (ExprNodeDesc expr : exprs) {
                if (!this.checkExpression(expr)) {
                    return false;
                }
            }
            return true;
        }

        /**
         * Returns {@code true} only if every aggregation passes {@link #checkAggregator}.
         * A {@link SemanticException} raised while checking is logged and treated as a rejection.
         */
        private boolean checkAggregators(Collection<AggregationDesc> aggs) {
            try {
                for (AggregationDesc agg : aggs) {
                    if (!this.checkAggregator(agg)) {
                        return false;
                    }
                }
            }
            catch (SemanticException e) {
                LOG.warn("Exception testing aggregators.", (Throwable)e);
                return false;
            }
            return true;
        }

        /**
         * Builds the operator rules used by {@link #evaluateOperators}. Each processor returns
         * a {@link Boolean}: {@code false} marks the operator as not runnable in LLAP.
         * Script operators are always rejected; filter, group-by and select operators are
         * rejected when their expressions or aggregators fail the checks above.
         */
        private Map<Rule, NodeProcessor> getRules() {
            Map<Rule, NodeProcessor> opRules = new LinkedHashMap<Rule, NodeProcessor>();
            opRules.put(new RuleRegExp("No scripts", ScriptOperator.getOperatorName() + "%"), new NodeProcessor(){

                @Override
                public Object process(Node n, Stack<Node> s, NodeProcessorCtx c, Object ... os) {
                    return Boolean.FALSE;
                }
            });
            opRules.put(new RuleRegExp("No user code in fil", FilterOperator.getOperatorName() + "%"), new NodeProcessor(){

                @Override
                public Object process(Node n, Stack<Node> s, NodeProcessorCtx c, Object ... os) {
                    ExprNodeDesc expr = ((FilterDesc)((FilterOperator)n).getConf()).getPredicate();
                    return Boolean.valueOf(LlapDecisionDispatcher.this.checkExpression(expr));
                }
            });
            opRules.put(new RuleRegExp("No user code in gby", GroupByOperator.getOperatorName() + "%"), new NodeProcessor(){

                @Override
                public Object process(Node n, Stack<Node> s, NodeProcessorCtx c, Object ... os) {
                    List<AggregationDesc> aggs = ((GroupByDesc)((GroupByOperator)n).getConf()).getAggregators();
                    return Boolean.valueOf(LlapDecisionDispatcher.this.checkAggregators(aggs));
                }
            });
            opRules.put(new RuleRegExp("No user code in select", SelectOperator.getOperatorName() + "%"), new NodeProcessor(){

                @Override
                public Object process(Node n, Stack<Node> s, NodeProcessorCtx c, Object ... os) {
                    List<ExprNodeDesc> exprs = ((SelectDesc)((SelectOperator)n).getConf()).getColList();
                    return Boolean.valueOf(LlapDecisionDispatcher.this.checkExpressions(exprs));
                }
            });
            return opRules;
        }

        /**
         * Walks the operator graph of {@code work} with the rules from {@link #getRules()}.
         *
         * @return {@code false} if any visited operator produced {@code Boolean.FALSE};
         *         operators with no recorded output are ignored
         */
        private boolean evaluateOperators(BaseWork work) throws SemanticException {
            Dispatcher disp = new DefaultRuleDispatcher(null, this.getRules(), null);
            DefaultGraphWalker ogw = new DefaultGraphWalker(disp);
            ArrayList<Node> topNodes = new ArrayList<Node>();
            topNodes.addAll(work.getAllRootOperators());
            HashMap<Node, Object> nodeOutput = new HashMap<Node, Object>();
            ogw.startWalking(topNodes, nodeOutput);
            for (Object verdict : nodeOutput.values()) {
                if (verdict != null && !((Boolean)verdict).booleanValue()) {
                    return false;
                }
            }
            return true;
        }

        /** Returns {@code true} only if every parent of {@code base} is already in LLAP mode. */
        private boolean checkParentsInLlap(TezWork tezWork, BaseWork base) {
            for (BaseWork parent : tezWork.getParents(base)) {
                if (!parent.getLlapMode()) {
                    LOG.info("Not all parents are run in llap");
                    return false;
                }
            }
            return true;
        }

        /**
         * Returns {@code true} only if the input format class of every partition is a
         * {@link VectorizedInputFormatInterface}.
         *
         * <p>{@code isAssignableFrom} is used so that the interface is also recognised when it
         * is inherited from a superclass or super-interface rather than declared directly on
         * the input format class. A partition without an input format class is treated as
         * not vectorized.
         */
        private boolean checkInputsVectorized(MapWork mapWork) {
            for (PartitionDesc pd : mapWork.getPathToPartitionInfo().values()) {
                Class<?> inputFormat = pd.getInputFileFormatClass();
                if (inputFormat != null && VectorizedInputFormatInterface.class.isAssignableFrom(inputFormat)) {
                    continue;
                }
                LOG.info("Input format: " + pd.getInputFileFormatClassName() + ", doesn't provide vectorized input");
                return false;
            }
            return true;
        }

        /**
         * Returns {@code false} if any root operator has no statistics object or reports
         * a column statistics state of {@code NONE}.
         */
        private boolean checkPartialStatsAvailable(BaseWork base) {
            for (Operator<? extends OperatorDesc> o : base.getAllRootOperators()) {
                Statistics stats = o.getStatistics();
                // Missing statistics are handled the same way as in computeEdgeSize: no NPE.
                if (stats == null || stats.getColumnStatsState() == Statistics.State.NONE) {
                    return false;
                }
            }
            return true;
        }

        /**
         * Sums the statistics data size over the root operators ({@code input == true}) or the
         * leaf operators ({@code input == false}) of {@code base}.
         *
         * @return the total size, or {@link Long#MAX_VALUE} if any operator has no statistics,
         *         reports a negative size, or the sum would overflow
         */
        private long computeEdgeSize(BaseWork base, boolean input) {
            long total = 0L;
            for (Operator<? extends OperatorDesc> o : input ? base.getAllRootOperators() : base.getAllLeafOperators()) {
                Statistics stats = o.getStatistics();
                if (stats == null) {
                    return Long.MAX_VALUE;
                }
                long currSize = stats.getDataSize();
                if (currSize < 0L || Long.MAX_VALUE - total < currSize) {
                    return Long.MAX_VALUE;
                }
                total += currSize;
            }
            return total;
        }

        /** Estimated input size: the sum over all root operators. */
        private long computeInputSize(BaseWork base) {
            return this.computeEdgeSize(base, true);
        }

        /** Estimated output size: the sum over all leaf operators. */
        private long computeOutputSize(BaseWork base) {
            return this.computeEdgeSize(base, false);
        }
    }

    /**
     * LLAP execution modes. The constant names are matched against the configuration
     * value with {@code valueOf}, so they must stay lower-case.
     */
    public static enum LlapMode {
        /** Only {@link MapWork} items that pass the operator checks are converted. */
        map,
        /** Every work item that passes the operator checks is converted. */
        all,
        /** LLAP is disabled; the plan is returned untouched. */
        none,
        /** Work items must additionally pass the configurable tree/vectorization/stats/size checks. */
        auto;

    }
}

