package org.apache.hadoop.hive.ql.optimizer.spark;

import com.google.common.base.Preconditions;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Stack;
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.hadoop.hive.ql.exec.GroupByOperator;
import org.apache.hadoop.hive.ql.exec.HashTableDummyOperator;
import org.apache.hadoop.hive.ql.exec.MapJoinOperator;
import org.apache.hadoop.hive.ql.exec.Operator;
import org.apache.hadoop.hive.ql.exec.OperatorFactory;
import org.apache.hadoop.hive.ql.exec.ReduceSinkOperator;
import org.apache.hadoop.hive.ql.exec.SparkHashTableSinkOperator;
import org.apache.hadoop.hive.ql.lib.DefaultGraphWalker;
import org.apache.hadoop.hive.ql.lib.DefaultRuleDispatcher;
import org.apache.hadoop.hive.ql.lib.Node;
import org.apache.hadoop.hive.ql.lib.NodeProcessor;
import org.apache.hadoop.hive.ql.lib.NodeProcessorCtx;
import org.apache.hadoop.hive.ql.lib.RuleRegExp;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.parse.spark.GenSparkProcContext;
import org.apache.hadoop.hive.ql.plan.BaseWork;
import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
import org.apache.hadoop.hive.ql.plan.HashTableDummyDesc;
import org.apache.hadoop.hive.ql.plan.MapJoinDesc;
import org.apache.hadoop.hive.ql.plan.OperatorDesc;
import org.apache.hadoop.hive.ql.plan.PlanUtils;
import org.apache.hadoop.hive.ql.plan.ReduceSinkDesc;
import org.apache.hadoop.hive.ql.plan.SparkEdgeProperty;
import org.apache.hadoop.hive.ql.plan.SparkHashTableSinkDesc;
import org.apache.hadoop.hive.ql.plan.SparkWork;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:WEB-INF/lib/hive-exec-2.3.3-mapr-1901.jar:org/apache/hadoop/hive/ql/optimizer/spark/SparkReduceSinkMapJoinProc.class */
/**
 * Operator-tree rule processor that handles a {@link MapJoinOperator} whose parent on the
 * walker stack is a {@link ReduceSinkOperator}.
 *
 * <p>For that pattern, {@link #process} detaches the reduce sink from the map join, puts a
 * {@link HashTableDummyOperator} in its place as the map join's parent, and substitutes a
 * {@link SparkHashTableSinkOperator} for the reduce sink under the reduce sink's own parents.
 * Along the way it records the links between the map join, the parent work and the dummy
 * operators in the {@link GenSparkProcContext} so they can be consulted later.
 */
public class SparkReduceSinkMapJoinProc implements NodeProcessor {
    public static final Logger LOG = LoggerFactory.getLogger(SparkReduceSinkMapJoinProc.class.getName());

    /**
     * Processor fired for every {@link GroupByOperator} found while walking the operators
     * below a map join. It remembers that a group-by was seen and applies the
     * "map join followed by map-side aggregation" memory setting to that group-by.
     */
    public static class SparkMapJoinFollowedByGroupByProcessor implements NodeProcessor {
        /** Set to {@code true} the first time {@link #process} is invoked. */
        private boolean hasGroupBy = false;

        /**
         * Marks that a group-by was encountered and overrides its memory usage with
         * {@code HIVEMAPJOINFOLLOWEDBYMAPAGGRHASHMEMORY} from the context's configuration.
         *
         * @param node the matched node; must be a {@link GroupByOperator}
         * @param stack the walker stack (unused)
         * @param nodeProcessorCtx must be a {@link GenSparkProcContext}
         * @param objArr outputs of previously processed nodes (unused)
         * @return always {@code null}
         */
        @Override
        public Object process(Node node, Stack<Node> stack, NodeProcessorCtx nodeProcessorCtx, Object... objArr) throws SemanticException {
            GenSparkProcContext context = (GenSparkProcContext) nodeProcessorCtx;
            this.hasGroupBy = true;
            float memoryUsage = context.conf.getFloatVar(HiveConf.ConfVars.HIVEMAPJOINFOLLOWEDBYMAPAGGRHASHMEMORY);
            ((GroupByOperator) node).getConf().setGroupByMemoryUsage(memoryUsage);
            return null;
        }

        /**
         * @return {@code true} if {@link #process} has been invoked at least once
         */
        public boolean getHasGroupBy() {
            return this.hasGroupBy;
        }
    }

    /**
     * Walks the operator graph starting at the children of {@code operator} and reports
     * whether any {@link GroupByOperator} matched the "GBY" rule. As a side effect, every
     * matched group-by gets its memory usage adjusted by
     * {@link SparkMapJoinFollowedByGroupByProcessor}.
     *
     * @param operator the operator whose descendants are inspected
     * @param genSparkProcContext context handed to the rule dispatcher
     * @return {@code true} if a group-by operator was matched during the walk
     * @throws SemanticException if the graph walk fails
     */
    // The rule map stays raw because its key type is not imported in this file; the
    // suppression is scoped to this method only.
    @SuppressWarnings({"rawtypes", "unchecked"})
    private boolean hasGroupBy(Operator<? extends OperatorDesc> operator, GenSparkProcContext genSparkProcContext) throws SemanticException {
        SparkMapJoinFollowedByGroupByProcessor groupByProcessor = new SparkMapJoinFollowedByGroupByProcessor();
        Map rules = new LinkedHashMap();
        rules.put(new RuleRegExp("GBY", GroupByOperator.getOperatorName() + "%"), groupByProcessor);
        DefaultGraphWalker walker = new DefaultGraphWalker(new DefaultRuleDispatcher(null, rules, genSparkProcContext));
        List<Node> startNodes = new ArrayList<Node>(operator.getChildOperators());
        walker.startWalking(startNodes, null);
        return groupByProcessor.getHasGroupBy();
    }

    /**
     * Handles one ReduceSink -> MapJoin edge.
     *
     * <p>If {@code node} is not exactly a {@link MapJoinOperator}, nothing happens. If it is
     * a map join but the element below it on {@code stack} is not a
     * {@link ReduceSinkOperator}, the map join is only recorded in
     * {@code currentMapJoinOperators}. Otherwise the reduce sink is replaced as described
     * on the class.
     *
     * @param node the node the rule fired on
     * @param stack the walker stack; the element below the top is taken as the parent
     * @param nodeProcessorCtx must be a {@link GenSparkProcContext}
     * @param objArr outputs of previously processed nodes (unused)
     * @return {@code null} when the pattern does not apply, {@code true} after the
     *     operator tree has been rewritten
     * @throws SemanticException if the reduce sink is not among the map join's recorded
     *     parents, or if the group-by walk fails
     * @throws IllegalArgumentException if the reduce sink is not associated with exactly
     *     one work
     */
    @Override
    public Object process(Node node, Stack<Node> stack, NodeProcessorCtx nodeProcessorCtx, Object... objArr) throws SemanticException {
        GenSparkProcContext context = (GenSparkProcContext) nodeProcessorCtx;
        // Exact class match on purpose: subclasses of MapJoinOperator are not handled here.
        if (!node.getClass().equals(MapJoinOperator.class)) {
            return null;
        }
        MapJoinOperator mapJoinOp = (MapJoinOperator) node;
        if (stack.size() < 2 || !(stack.get(stack.size() - 2) instanceof ReduceSinkOperator)) {
            context.currentMapJoinOperators.add(mapJoinOp);
            return null;
        }
        context.preceedingWork = null;
        context.currentRootOperator = null;

        ReduceSinkOperator parentRS = (ReduceSinkOperator) stack.get(stack.size() - 2);
        ((ReduceSinkDesc) parentRS.getConf()).setSkipTag(true);
        parentRS.setSkipTag(true);

        // Snapshot the map join's parent list before it is modified below, so the
        // position of each parent stays stable across repeated invocations.
        if (!context.mapJoinParentMap.containsKey(mapJoinOp)) {
            context.mapJoinParentMap.put(mapJoinOp, new ArrayList<Operator<?>>(mapJoinOp.getParentOperators()));
        }

        List<BaseWork> mapJoinWorks = context.mapJoinWorkMap.get(mapJoinOp);

        // A missing entry is reported through the same precondition as a wrong-sized one
        // instead of surfacing as a bare NullPointerException.
        List<BaseWork> parentWorks = context.childToWorkMap.get(parentRS);
        int parentWorkCount = parentWorks == null ? 0 : parentWorks.size();
        Preconditions.checkArgument(parentWorkCount == 1, "AssertionError: expected context.childToWorkMap.get(parentRS).size() to be 1, but was " + parentWorkCount);
        BaseWork parentWork = parentWorks.get(0);

        int pos = context.mapJoinParentMap.get(mapJoinOp).indexOf(parentRS);
        if (pos == -1) {
            throw new SemanticException("Cannot find position of parent in mapjoin");
        }
        LOG.debug("Mapjoin {}, pos: {} --> {}", mapJoinOp, pos, parentWork.getName());

        MapJoinDesc mapJoinDesc = (MapJoinDesc) mapJoinOp.getConf();
        mapJoinDesc.getParentToInput().put(Integer.valueOf(pos), parentWork.getName());

        // NOTE(review): 0 is presumably the "no shuffle" edge type - confirm against SparkEdgeProperty.
        SparkEdgeProperty edgeProperty = new SparkEdgeProperty(0L);
        if (mapJoinWorks != null) {
            for (BaseWork mapJoinWork : mapJoinWorks) {
                SparkWork sparkWork = context.currentTask.getWork();
                LOG.debug("connecting {} with {}", parentWork.getName(), mapJoinWork.getName());
                sparkWork.connect(parentWork, mapJoinWork, edgeProperty);
            }
        }

        // Remember the parent work / edge for this map join in the context.
        Map<BaseWork, SparkEdgeProperty> linkWorkMap = context.linkOpWithWorkMap.get(mapJoinOp);
        if (linkWorkMap == null) {
            linkWorkMap = new HashMap<BaseWork, SparkEdgeProperty>();
        }
        linkWorkMap.put(parentWork, edgeProperty);
        context.linkOpWithWorkMap.put(mapJoinOp, linkWorkMap);

        List<ReduceSinkOperator> reduceSinks = context.linkWorkWithReduceSinkMap.get(parentWork);
        if (reduceSinks == null) {
            reduceSinks = new ArrayList<ReduceSinkOperator>();
        }
        reduceSinks.add(parentRS);
        context.linkWorkWithReduceSinkMap.put(parentWork, reduceSinks);

        // Build the dummy operator that will stand in for the reduce sink as the map
        // join's parent; its table descriptor is derived from the schema of the reduce
        // sink's first parent.
        List<Operator<?>> dummyOperators = new ArrayList<Operator<?>>();
        HashTableDummyOperator dummyOp = (HashTableDummyOperator) OperatorFactory.get(mapJoinOp.getCompilationOpContext(), new HashTableDummyDesc());
        dummyOp.getConf().setTbl(PlanUtils.getReduceValueTableDesc(PlanUtils.getFieldSchemasFromRowSchema(parentRS.getParentOperators().get(0).getSchema(), "")));

        // One "+" (order) and one "a" (null order) character per join key column of tag 0.
        List<ExprNodeDesc> keyCols = mapJoinDesc.getKeys().get((byte) 0);
        StringBuilder keyOrder = new StringBuilder();
        StringBuilder keyNullOrder = new StringBuilder();
        for (int i = 0; i < keyCols.size(); i++) {
            keyOrder.append("+");
            keyNullOrder.append("a");
        }
        mapJoinDesc.setKeyTableDesc(PlanUtils.getReduceKeyTableDesc(PlanUtils.getFieldSchemasFromColumnList(keyCols, "mapjoinkey"), keyOrder.toString(), keyNullOrder.toString()));

        // Make the dummy operator the parent of the map join in place of the reduce sink.
        mapJoinOp.replaceParent(parentRS, dummyOp);
        List<Operator<? extends OperatorDesc>> dummyChildren = new ArrayList<Operator<? extends OperatorDesc>>();
        dummyChildren.add(mapJoinOp);
        dummyOp.setChildOperators(dummyChildren);
        dummyOperators.add(dummyOp);

        // Cut the reduce sink -> map join edge on the reduce sink's side as well.
        List<Operator<? extends OperatorDesc>> rsChildren = parentRS.getChildOperators();
        rsChildren.remove(rsChildren.indexOf(mapJoinOp));

        if (mapJoinWorks != null) {
            for (BaseWork mapJoinWork : mapJoinWorks) {
                mapJoinWork.addDummyOp(dummyOp);
            }
        }
        // Carry over dummy operators registered for this map join by earlier invocations.
        List<Operator<?>> previousDummyOps = context.linkChildOpWithDummyOp.get(mapJoinOp);
        if (previousDummyOps != null) {
            dummyOperators.addAll(previousDummyOps);
        }
        context.linkChildOpWithDummyOp.put(mapJoinOp, dummyOperators);

        // Configure the map join's hash table memory, depending on whether a group-by
        // follows it, and create the hash table sink that replaces the reduce sink.
        HiveConf conf = context.conf;
        mapJoinDesc.resetOrder();
        float hashTableMemoryUsage;
        if (hasGroupBy(mapJoinOp, context)) {
            hashTableMemoryUsage = conf.getFloatVar(HiveConf.ConfVars.HIVEHASHTABLEFOLLOWBYGBYMAXMEMORYUSAGE);
        } else {
            hashTableMemoryUsage = conf.getFloatVar(HiveConf.ConfVars.HIVEHASHTABLEMAXMEMORYUSAGE);
        }
        mapJoinDesc.setHashTableMemoryUsage(hashTableMemoryUsage);

        SparkHashTableSinkDesc hashTableSinkDesc = new SparkHashTableSinkDesc(mapJoinDesc);
        SparkHashTableSinkOperator hashTableSinkOp = (SparkHashTableSinkOperator) OperatorFactory.get(mapJoinOp.getCompilationOpContext(), hashTableSinkDesc);

        // Keep only the value expressions whose value index is negative for this tag.
        byte tag = (byte) pos;
        int[] valueIndex = mapJoinDesc.getValueIndex(tag);
        if (valueIndex != null) {
            List<ExprNodeDesc> keptValues = new ArrayList<ExprNodeDesc>();
            List<ExprNodeDesc> values = hashTableSinkDesc.getExprs().get(Byte.valueOf(tag));
            for (int i = 0; i < values.size(); i++) {
                if (valueIndex[i] < 0) {
                    keptValues.add(values.get(i));
                }
            }
            hashTableSinkDesc.getExprs().put(Byte.valueOf(tag), keptValues);
        }

        // Re-point every parent of the reduce sink at the new hash table sink.
        List<Operator<? extends OperatorDesc>> rsParents = parentRS.getParentOperators();
        for (Operator<? extends OperatorDesc> rsParent : rsParents) {
            rsParent.replaceChild(parentRS, hashTableSinkOp);
        }
        hashTableSinkOp.setParentOperators(rsParents);
        ((SparkHashTableSinkDesc) hashTableSinkOp.getConf()).setTag(tag);
        return true;
    }
}
