package org.apache.hadoop.hive.ql.optimizer.physical;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;
import java.util.SortedSet;
import java.util.Stack;
import java.util.TreeSet;
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.hadoop.hive.ql.exec.MapJoinOperator;
import org.apache.hadoop.hive.ql.exec.StatsTask;
import org.apache.hadoop.hive.ql.exec.Task;
import org.apache.hadoop.hive.ql.exec.tez.TezTask;
import org.apache.hadoop.hive.ql.lib.DefaultGraphWalker;
import org.apache.hadoop.hive.ql.lib.DefaultRuleDispatcher;
import org.apache.hadoop.hive.ql.lib.Dispatcher;
import org.apache.hadoop.hive.ql.lib.GraphWalker;
import org.apache.hadoop.hive.ql.lib.Node;
import org.apache.hadoop.hive.ql.lib.NodeProcessor;
import org.apache.hadoop.hive.ql.lib.NodeProcessorCtx;
import org.apache.hadoop.hive.ql.lib.Rule;
import org.apache.hadoop.hive.ql.lib.RuleRegExp;
import org.apache.hadoop.hive.ql.lib.TaskGraphWalker;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.plan.BaseWork;
import org.apache.hadoop.hive.ql.plan.MapJoinDesc;
import org.apache.hadoop.hive.ql.plan.MapWork;
import org.apache.hadoop.hive.ql.plan.MergeJoinWork;
import org.apache.hadoop.hive.ql.plan.ReduceWork;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:WEB-INF/lib/hive-exec-2.3.6-mapr-2009.jar:org/apache/hadoop/hive/ql/optimizer/physical/MemoryDecider.class */
/**
 * Physical plan resolver that assigns a memory budget to every map join of each Tez task.
 *
 * <p>It walks the task graph starting at the root tasks. For a {@link StatsTask}, it looks at the
 * task returned by {@code getSourceTask()} instead. For each {@link TezTask} found, it collects the
 * {@link MapJoinOperator}s of every work unit. It then records a per-operator budget through
 * {@link MapJoinDesc#setMemoryNeeded(long)}. The budget shared by the map joins of one work unit
 * comes from {@code HIVECONVERTJOINNOCONDITIONALTASKTHRESHOLD}.
 */
public class MemoryDecider implements PhysicalPlanResolver {
    protected static final transient Logger LOG = LoggerFactory.getLogger(MemoryDecider.class);

    /**
     * Task-graph dispatcher that evaluates the map-join memory needs of every work unit of a
     * Tez task.
     */
    public class MemoryCalculator implements Dispatcher {
        /** Bytes shared by all map joins of one work unit (HIVECONVERTJOINNOCONDITIONALTASKTHRESHOLD). */
        private final long totalAvailableMemory;
        /**
         * Lower bound for a single map join's estimate. It is the product of the hybrid grace hash
         * join minimum partition count and minimum write-buffer size settings.
         */
        private final long minimumHashTableSize;
        /** Multiplier applied to each size estimate (HIVE_HASH_TABLE_INFLATION_FACTOR). */
        private final double inflationFactor;
        private final PhysicalContext pctx;

        /** No-op node processor that always returns {@code null}. */
        public class DefaultRule implements NodeProcessor {
            public DefaultRule() {
            }

            @Override
            public Object process(Node node, Stack<Node> stack, NodeProcessorCtx nodeProcessorCtx, Object... objArr) throws SemanticException {
                return null;
            }
        }

        /**
         * Reads the memory-related settings from the configuration of {@code physicalContext}.
         *
         * @param physicalContext the physical plan being resolved
         */
        public MemoryCalculator(PhysicalContext physicalContext) {
            this.pctx = physicalContext;
            this.totalAvailableMemory = HiveConf.getLongVar(physicalContext.conf, HiveConf.ConfVars.HIVECONVERTJOINNOCONDITIONALTASKTHRESHOLD);
            // Widen to long before multiplying so large settings cannot overflow int arithmetic.
            this.minimumHashTableSize = (long) HiveConf.getIntVar(physicalContext.conf, HiveConf.ConfVars.HIVEHYBRIDGRACEHASHJOINMINNUMPARTITIONS)
                * HiveConf.getIntVar(physicalContext.conf, HiveConf.ConfVars.HIVEHYBRIDGRACEHASHJOINMINWBSIZE);
            this.inflationFactor = HiveConf.getFloatVar(physicalContext.conf, HiveConf.ConfVars.HIVE_HASH_TABLE_INFLATION_FACTOR);
        }

        /**
         * Evaluates every work unit of a Tez task.
         *
         * <p>A {@link StatsTask} is replaced by its source task. Any task that is not a
         * {@link TezTask} is ignored.
         *
         * @return always {@code null}
         */
        @Override
        public Object dispatch(Node node, Stack<Node> stack, Object... objArr) throws SemanticException {
            Task<?> task = (Task<?>) node;
            if (task instanceof StatsTask) {
                task = ((StatsTask) task).getWork().getSourceTask();
            }
            if (!(task instanceof TezTask)) {
                return null;
            }
            for (BaseWork work : ((TezTask) task).getWork().getAllWork()) {
                evaluateWork(work);
            }
            return null;
        }

        /** Dispatches on the concrete work type; unsupported types are only logged. */
        private void evaluateWork(BaseWork baseWork) throws SemanticException {
            if (baseWork instanceof MapWork) {
                evaluateMapWork((MapWork) baseWork);
            } else if (baseWork instanceof ReduceWork) {
                evaluateReduceWork((ReduceWork) baseWork);
            } else if (baseWork instanceof MergeJoinWork) {
                evaluateMergeWork((MergeJoinWork) baseWork);
            } else {
                MemoryDecider.LOG.info("We are not going to evaluate this work type: " + baseWork.getClass().getCanonicalName());
            }
        }

        /** Evaluates each work unit of a merge join separately. */
        private void evaluateMergeWork(MergeJoinWork mergeJoinWork) throws SemanticException {
            for (BaseWork work : mergeJoinWork.getBaseWorkList()) {
                evaluateOperators(work, this.pctx);
            }
        }

        private void evaluateReduceWork(ReduceWork reduceWork) throws SemanticException {
            evaluateOperators(reduceWork, this.pctx);
        }

        private void evaluateMapWork(MapWork mapWork) throws SemanticException {
            evaluateOperators(mapWork, this.pctx);
        }

        /**
         * Assigns a memory budget to every map join in {@code baseWork}.
         *
         * <p>Budgets are derived from the operators' parent data sizes. If any estimate is
         * unavailable, or the stats-based split cannot be applied, the total budget is divided
         * evenly instead.
         *
         * @param baseWork the work unit whose operator tree is inspected
         * @param physicalContext unused; kept for call-site compatibility
         */
        private void evaluateOperators(BaseWork baseWork, PhysicalContext physicalContext) throws SemanticException {
            Set<MapJoinOperator> mapJoins = collectMapJoins(baseWork);
            if (mapJoins.isEmpty()) {
                return;
            }
            try {
                assignMemoryFromStats(mapJoins);
            } catch (HiveException e) {
                if (MemoryDecider.LOG.isDebugEnabled()) {
                    MemoryDecider.LOG.debug("Stats-based map join sizing not possible: " + e.getMessage());
                }
                assignMemoryEvenly(mapJoins);
            }
        }

        /** Walks the operator tree of {@code work} and returns its map joins in walk order. */
        private Set<MapJoinOperator> collectMapJoins(BaseWork work) throws SemanticException {
            final Set<MapJoinOperator> mapJoins = new LinkedHashSet<>();
            Map<Rule, NodeProcessor> rules = new LinkedHashMap<>();
            rules.put(new RuleRegExp("Map join memory estimator", MapJoinOperator.getOperatorName() + "%"),
                new NodeProcessor() {
                    @Override
                    public Object process(Node node, Stack<Node> stack, NodeProcessorCtx procCtx, Object... nodeOutputs) {
                        mapJoins.add((MapJoinOperator) node);
                        return null;
                    }
                });
            GraphWalker walker = new DefaultGraphWalker(new DefaultRuleDispatcher(null, rules, null));
            ArrayList<Node> topNodes = new ArrayList<Node>(work.getAllRootOperators());
            walker.startWalking(topNodes, new LinkedHashMap<Node, Object>());
            return mapJoins;
        }

        /**
         * Splits {@link #totalAvailableMemory} among {@code mapJoins} using their size estimates.
         *
         * <p>Joins are visited smallest first, with ties broken by walk order. A join whose
         * estimate is strictly smaller than what is left of half the budget gets its full
         * estimate ("in-mem"). The remaining joins share the leftover of that half plus the other
         * half, in proportion to their estimates ("spills"). If every join fits in the first
         * half, all joins are instead scaled proportionally over the whole budget.
         *
         * @throws HiveException if an estimate is missing, or no positive estimate exists for the
         *     joins that must be scaled
         */
        private void assignMemoryFromStats(Set<MapJoinOperator> mapJoins) throws HiveException {
            long total = 0;
            final Map<MapJoinOperator, Long> sizes = new HashMap<>();
            final Map<MapJoinOperator, Integer> positions = new HashMap<>();
            int position = 0;
            for (MapJoinOperator mj : mapJoins) {
                long size = computeSizeToFitInMem(mj);
                sizes.put(mj, size);
                positions.put(mj, position++);
                total += size;
            }

            Comparator<MapJoinOperator> bySizeThenPosition = new Comparator<MapJoinOperator>() {
                @Override
                public int compare(MapJoinOperator a, MapJoinOperator b) {
                    if (a == null || b == null) {
                        throw new NullPointerException();
                    }
                    int res = Long.compare(sizes.get(a), sizes.get(b));
                    return res != 0 ? res : Integer.compare(positions.get(a), positions.get(b));
                }
            };
            SortedSet<MapJoinOperator> largeJoins = new TreeSet<MapJoinOperator>(bySizeThenPosition);
            largeJoins.addAll(mapJoins);

            long remaining = totalAvailableMemory / 2;
            long largeTotal = 0;
            Iterator<MapJoinOperator> it = largeJoins.iterator();
            while (it.hasNext()) {
                MapJoinOperator mj = it.next();
                long size = sizes.get(mj);
                if (MemoryDecider.LOG.isDebugEnabled()) {
                    MemoryDecider.LOG.debug("MapJoin: " + mj + ", size: " + size + ", remaining: " + remaining);
                }
                if (size < remaining) {
                    if (MemoryDecider.LOG.isInfoEnabled()) {
                        MemoryDecider.LOG.info("Setting " + size + " bytes needed for " + mj + " (in-mem)");
                    }
                    mj.getConf().setMemoryNeeded(size);
                    remaining -= size;
                    it.remove();
                } else {
                    largeTotal += size;
                }
            }

            if (largeJoins.isEmpty()) {
                // Every join fit into half the budget: rescale all of them over the whole budget.
                largeJoins.addAll(mapJoins);
                largeTotal = total;
                if (largeTotal > totalAvailableMemory) {
                    throw new HiveException("Map join sizes " + largeTotal + " exceed available memory " + totalAvailableMemory);
                }
                remaining = totalAvailableMemory / 2;
            }
            if (largeTotal <= 0) {
                // Nothing to scale against; let the caller fall back to an even split.
                throw new HiveException("No positive size estimate for map joins to scale");
            }

            // Floating-point division is required. Long division would truncate the weight
            // (often to 0) and starve every spilling join.
            double weight = (remaining + totalAvailableMemory / 2) / (double) largeTotal;
            for (MapJoinOperator mj : largeJoins) {
                long size = (long) (weight * sizes.get(mj));
                if (MemoryDecider.LOG.isInfoEnabled()) {
                    MemoryDecider.LOG.info("Setting " + size + " bytes needed for " + mj + " (spills)");
                }
                mj.getConf().setMemoryNeeded(size);
            }
        }

        /** Gives every map join an equal share of {@link #totalAvailableMemory}. */
        private void assignMemoryEvenly(Set<MapJoinOperator> mapJoins) {
            long size = this.totalAvailableMemory / mapJoins.size();
            if (MemoryDecider.LOG.isInfoEnabled()) {
                MemoryDecider.LOG.info("Scaling mapjoin memory w/o stats");
            }
            for (MapJoinOperator mj : mapJoins) {
                if (MemoryDecider.LOG.isInfoEnabled()) {
                    MemoryDecider.LOG.info("Setting " + size + " bytes needed for " + mj + " (fallback)");
                }
                mj.getConf().setMemoryNeeded(size);
            }
        }

        /**
         * Returns the estimated bytes needed to hold {@code mapJoinOperator}'s hash table.
         *
         * <p>The estimate is {@code max(minimumHashTableSize, input size) * inflationFactor},
         * truncated to a long.
         *
         * @throws HiveException if the operator has no usable parent data sizes
         */
        private long computeSizeToFitInMem(MapJoinOperator mapJoinOperator) throws HiveException {
            return (long) (Math.max(this.minimumHashTableSize, computeInputSize(mapJoinOperator)) * this.inflationFactor);
        }

        /**
         * Sums the parent data sizes recorded on the operator's descriptor.
         *
         * @throws HiveException if the descriptor or its sizes are missing, or the sum is 0
         */
        private long computeInputSize(MapJoinOperator mapJoinOperator) throws HiveException {
            long size = 0;
            MapJoinDesc desc = mapJoinOperator.getConf();
            if (desc != null && desc.getParentDataSizes() != null) {
                for (long parentSize : desc.getParentDataSizes().values()) {
                    size += parentSize;
                }
            }
            if (size == 0) {
                throw new HiveException("No data sizes");
            }
            return size;
        }
    }

    /**
     * Walks the task graph from the root tasks and assigns map-join memory budgets in place.
     *
     * @param physicalContext the physical plan to resolve
     * @return the same {@code physicalContext}, after its map join descriptors were updated
     */
    @Override
    public PhysicalContext resolve(PhysicalContext physicalContext) throws SemanticException {
        TaskGraphWalker walker = new TaskGraphWalker(new MemoryCalculator(physicalContext));
        ArrayList<Node> rootTasks = new ArrayList<Node>(physicalContext.getRootTasks());
        walker.startWalking(rootTasks, null);
        return physicalContext;
    }
}
