package org.apache.hadoop.hive.ql.optimizer.spark;

import java.io.Serializable;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.Stack;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.hadoop.hive.ql.exec.CommonJoinOperator;
import org.apache.hadoop.hive.ql.exec.JoinOperator;
import org.apache.hadoop.hive.ql.exec.MapJoinOperator;
import org.apache.hadoop.hive.ql.exec.Operator;
import org.apache.hadoop.hive.ql.exec.ReduceSinkOperator;
import org.apache.hadoop.hive.ql.exec.SMBMapJoinOperator;
import org.apache.hadoop.hive.ql.exec.TableScanOperator;
import org.apache.hadoop.hive.ql.exec.Task;
import org.apache.hadoop.hive.ql.exec.TaskFactory;
import org.apache.hadoop.hive.ql.exec.Utilities;
import org.apache.hadoop.hive.ql.exec.spark.SparkTask;
import org.apache.hadoop.hive.ql.lib.Node;
import org.apache.hadoop.hive.ql.lib.NodeProcessor;
import org.apache.hadoop.hive.ql.lib.NodeProcessorCtx;
import org.apache.hadoop.hive.ql.optimizer.GenMapRedUtils;
import org.apache.hadoop.hive.ql.optimizer.physical.GenMRSkewJoinProcessor;
import org.apache.hadoop.hive.ql.optimizer.physical.GenSparkSkewJoinProcessor;
import org.apache.hadoop.hive.ql.optimizer.physical.SkewJoinProcFactory;
import org.apache.hadoop.hive.ql.optimizer.physical.SparkMapJoinResolver;
import org.apache.hadoop.hive.ql.optimizer.spark.SparkSkewJoinResolver;
import org.apache.hadoop.hive.ql.parse.ParseContext;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.parse.spark.GenSparkUtils;
import org.apache.hadoop.hive.ql.plan.BaseWork;
import org.apache.hadoop.hive.ql.plan.JoinDesc;
import org.apache.hadoop.hive.ql.plan.MapJoinDesc;
import org.apache.hadoop.hive.ql.plan.MapWork;
import org.apache.hadoop.hive.ql.plan.OperatorDesc;
import org.apache.hadoop.hive.ql.plan.PlanUtils;
import org.apache.hadoop.hive.ql.plan.ReduceWork;
import org.apache.hadoop.hive.ql.plan.SMBJoinDesc;
import org.apache.hadoop.hive.ql.plan.SparkEdgeProperty;
import org.apache.hadoop.hive.ql.plan.SparkWork;
import org.apache.hadoop.hive.ql.plan.TableDesc;

/* loaded from: input_file:WEB-INF/lib/hive-exec-2.3.6-mapr-2101.jar:org/apache/hadoop/hive/ql/optimizer/spark/SparkSkewJoinProcFactory.class */
public class SparkSkewJoinProcFactory {
    /**
     * Join operators that {@link SparkSkewJoinJoinProcessor#process} has already handled.
     * An operator is added after its task has been split and the skew join rewrite applied,
     * and any operator found here is skipped on later visits.
     *
     * <p>NOTE(review): this is a plain {@link HashSet} held in a static field and handed out
     * as-is by {@link #getVisitedJoinOp()}; nothing in this class clears or synchronizes it.
     */
    private static final Set<JoinOperator> visitedJoinOp = new HashSet<JoinOperator>();

    /* loaded from: input_file:WEB-INF/lib/hive-exec-2.3.6-mapr-2101.jar:org/apache/hadoop/hive/ql/optimizer/spark/SparkSkewJoinProcFactory$SparkSkewJoinJoinProcessor.class */
    /**
     * Node processor for join operators: when the join qualifies for a runtime skew join,
     * splits the current Spark task after the join's reduce work and hands the join to
     * {@link GenSparkSkewJoinProcessor#processSkewJoin}.
     */
    public static class SparkSkewJoinJoinProcessor implements NodeProcessor {
        /**
         * Handles a single join operator node.
         *
         * @param node             the node being visited; cast to {@link JoinOperator}
         * @param stack            the traversal stack (not used here)
         * @param nodeProcessorCtx cast to {@code SparkSkewJoinResolver.SparkSkewJoinProcCtx}
         * @param objArr           additional outputs (not used here)
         * @return always {@code null}
         * @throws SemanticException if splitting the task or the skew join rewrite fails
         */
        @Override
        public Object process(Node node, Stack<Node> stack, NodeProcessorCtx nodeProcessorCtx, Object... objArr) throws SemanticException {
            SparkSkewJoinResolver.SparkSkewJoinProcCtx skewCtx = (SparkSkewJoinResolver.SparkSkewJoinProcCtx) nodeProcessorCtx;
            Task<? extends Serializable> task = skewCtx.getCurrentTask();
            JoinOperator join = (JoinOperator) node;
            ReduceWork joinWork = skewCtx.getReducerToReduceWork().get(join);
            ParseContext parseContext = skewCtx.getParseCtx();

            // No reduce work is registered for this join: nothing to rewrite.
            if (joinWork == null) {
                return null;
            }
            // Each join operator is rewritten at most once.
            if (visitedJoinOp.contains(join)) {
                return null;
            }
            if (!supportRuntimeSkewJoin(join, joinWork, task, parseContext.getConf())) {
                return null;
            }

            // The cast is safe: supportRuntimeSkewJoin only returns true for a SparkTask.
            splitTask((SparkTask) task, joinWork, parseContext);
            GenSparkSkewJoinProcessor.processSkewJoin(join, task, joinWork, parseContext);
            visitedJoinOp.add(join);
            return null;
        }
    }

    /** Private constructor: this class only exposes static members and is not meant to be instantiated. */
    private SparkSkewJoinProcFactory() {
    }

    /**
     * Returns the default node processor, obtained from
     * {@link SkewJoinProcFactory#getDefaultProc()}.
     *
     * @return whatever {@code SkewJoinProcFactory.getDefaultProc()} returns
     */
    public static NodeProcessor getDefaultProc() {
        NodeProcessor delegate = SkewJoinProcFactory.getDefaultProc();
        return delegate;
    }

    /**
     * Creates the node processor used for join operators.
     *
     * @return a new {@link SparkSkewJoinJoinProcessor} on every call
     */
    public static NodeProcessor getJoinProc() {
        NodeProcessor joinProcessor = new SparkSkewJoinJoinProcessor();
        return joinProcessor;
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Splits {@code sparkTask} right after {@code reduceWork}: the single child work of
     * {@code reduceWork}, together with everything connected to it, is moved into a new
     * {@link SparkWork} wrapped in a new task, and a new {@link MapWork} reading from a
     * temporary path is connected in front of that child.
     *
     * <p>Nothing happens unless {@code reduceWork} has exactly one child work, no work in
     * the graph has more than one child, and {@code reduceWork} contains exactly one
     * {@link ReduceSinkOperator}.
     *
     * @param sparkTask    the task whose work graph is split; the new task becomes its dependent
     * @param reduceWork   the reduce work after which the graph is cut
     * @param parseContext supplies the configuration, the temporary path and the known join operators
     * @throws SemanticException declared for failures raised by the plan helpers called here
     */
    public static void splitTask(SparkTask sparkTask, ReduceWork reduceWork, ParseContext parseContext) throws SemanticException {
        SparkWork currentWork = sparkTask.getWork();
        Set<Operator<?>> reduceSinks = SparkMapJoinResolver.getOp(reduceWork, ReduceSinkOperator.class);
        if (currentWork.getChildren(reduceWork).size() != 1 || !canSplit(currentWork) || reduceSinks.size() != 1) {
            return;
        }

        ReduceSinkOperator reduceSink = (ReduceSinkOperator) reduceSinks.iterator().next();
        BaseWork childWork = currentWork.getChildren(reduceWork).get(0);
        // Remember the edge before cutting it; it is reused for the new map work below.
        SparkEdgeProperty originalEdge = currentWork.getEdgeProperty(reduceWork, childWork);
        currentWork.disconnect(reduceWork, childWork);

        // Collect the child work and everything still connected to it into a new graph.
        SparkWork newWork = new SparkWork(parseContext.getConf().getVar(HiveConf.ConfVars.HIVEQUERYID));
        newWork.add(childWork);
        copyWorkGraph(currentWork, newWork, childWork);
        // Whatever went into the new graph is dropped from the current one.
        for (BaseWork moved : newWork.getAllWorkUnsorted()) {
            currentWork.remove(moved);
            currentWork.getCloneToWork().remove(moved);
        }

        // Set up the temporary file between the reduce sink's parent and the reduce sink.
        Path tmpPath = parseContext.getContext().getMRTmpPath();
        Operator<? extends OperatorDesc> sinkParent = reduceSink.getParentOperators().get(0);
        TableDesc tmpTableDesc = PlanUtils.getIntermediateFileTableDesc(PlanUtils.getFieldSchemasFromRowSchema(sinkParent.getSchema(), "temporarycol"));
        TableScanOperator tmpTableScan = GenMapRedUtils.createTemporaryFile(sinkParent, reduceSink, tmpPath, tmpTableDesc, parseContext);

        // New map work that takes the place of reduceWork as the parent of the child work.
        MapWork newMapWork = PlanUtils.getMapRedWork().getMapWork();
        newMapWork.setName(Utilities.MAPNAME + GenSparkUtils.getUtils().getNextSeqNumber());
        newWork.add(newMapWork);
        newWork.connect(newMapWork, childWork, originalEdge);

        String alias = tmpPath.toUri().toString();
        if (GenMapRedUtils.needsTagging((ReduceWork) childWork)) {
            Operator<?> childReducer = ((ReduceWork) childWork).getReducer();
            String joinId = null;
            if (childReducer instanceof JoinOperator) {
                if (parseContext.getJoinOps().contains(childReducer)) {
                    joinId = ((JoinDesc) ((JoinOperator) childReducer).getConf()).getId();
                }
            } else if (childReducer instanceof MapJoinOperator) {
                if (parseContext.getMapJoinOps().contains(childReducer)) {
                    joinId = ((MapJoinDesc) ((MapJoinOperator) childReducer).getConf()).getId();
                }
            } else if (childReducer instanceof SMBMapJoinOperator) {
                if (parseContext.getSmbMapJoinOps().contains(childReducer)) {
                    joinId = ((SMBJoinDesc) ((SMBMapJoinOperator) childReducer).getConf()).getId();
                }
            }
            if (joinId != null) {
                alias = joinId + ":$INTNAME";
            } else {
                alias = "$INTNAME";
            }
            // Append 1, 2, 3, ... to the base alias until it is not yet present in the map work.
            String baseAlias = alias;
            for (int suffix = 1; newMapWork.getAliasToWork().get(alias) != null; suffix++) {
                alias = baseAlias.concat(String.valueOf(suffix));
            }
        }
        GenMapRedUtils.setTaskPlan(tmpPath, alias, (Operator<? extends OperatorDesc>) tmpTableScan, newMapWork, false, tmpTableDesc);

        // Insert the new task between sparkTask and its first child task, if it has one.
        Task<? extends Serializable> newTask = TaskFactory.get(newWork, parseContext.getConf(), new Task[0]);
        List<Task<? extends Serializable>> dependents = sparkTask.getChildTasks();
        if (dependents != null && dependents.size() > 0) {
            Task<? extends Serializable> firstDependent = dependents.get(0);
            sparkTask.removeDependentTask(firstDependent);
            newTask.addDependentTask(firstDependent);
        }
        sparkTask.addDependentTask(newTask);
        newTask.setFetchSource(sparkTask.isFetchSource());
    }

    /**
     * Tells whether the work graph is simple enough to be split.
     *
     * @param sparkWork the work graph to inspect
     * @return {@code false} as soon as any work in the graph has more than one child,
     *         {@code true} otherwise
     */
    private static boolean canSplit(SparkWork sparkWork) {
        for (BaseWork candidate : sparkWork.getAllWorkUnsorted()) {
            int childCount = sparkWork.getChildren(candidate).size();
            if (childCount > 1) {
                return false;
            }
        }
        return true;
    }

    /**
     * Recursively copies into {@code target} every work reachable from {@code start} in
     * {@code source}, following child edges first and then parent edges, and recreates each
     * followed edge with the edge property it has in {@code source}.
     *
     * <p>A neighbour already contained in {@code target} is skipped entirely, so no edge
     * is added for it by this call. {@code start} itself is not added here; the caller
     * is expected to have added it to {@code target} beforehand.
     *
     * @param source the graph to read from
     * @param target the graph being populated
     * @param start  the work whose neighbours are copied
     */
    private static void copyWorkGraph(SparkWork source, SparkWork target, BaseWork start) {
        for (BaseWork child : source.getChildren(start)) {
            if (target.contains(child)) {
                continue;
            }
            target.add(child);
            target.connect(start, child, source.getEdgeProperty(start, child));
            copyWorkGraph(source, target, child);
        }
        for (BaseWork parent : source.getParents(start)) {
            if (target.contains(parent)) {
                continue;
            }
            target.add(parent);
            target.connect(parent, start, source.getEdgeProperty(parent, start));
            copyWorkGraph(source, target, parent);
        }
    }

    /**
     * Exposes the set of join operators already handled by the join processor.
     *
     * @return the live backing set itself (not a copy), so changes made through the
     *         returned reference are visible to this class
     */
    public static Set<JoinOperator> getVisitedJoinOp() {
        return SparkSkewJoinProcFactory.visitedJoinOp;
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Decides whether a join can be rewritten as a runtime skew join.
     *
     * <p>All of the following must hold: the task is a {@link SparkTask}; skew join is
     * enabled for the join according to {@link GenMRSkewJoinProcessor#skewJoinEnabled};
     * the join descriptor is not fixed as sorted; the task's work contains
     * {@code reduceWork}; the task has at most one child task; and {@code reduceWork}
     * contains exactly one {@link CommonJoinOperator}.
     *
     * @param joinOperator the join being considered
     * @param reduceWork   the reduce work that holds the join
     * @param task         the task currently being processed
     * @param hiveConf     the configuration used for the skew join check
     * @return {@code true} only when every condition above is satisfied
     */
    public static boolean supportRuntimeSkewJoin(JoinOperator joinOperator, ReduceWork reduceWork, Task<? extends Serializable> task, HiveConf hiveConf) {
        if (!(task instanceof SparkTask)) {
            return false;
        }
        if (!GenMRSkewJoinProcessor.skewJoinEnabled(hiveConf, joinOperator)) {
            return false;
        }

        SparkWork sparkWork = ((SparkTask) task).getWork();
        List<Task<? extends Serializable>> dependents = task.getChildTasks();

        if (((JoinDesc) joinOperator.getConf()).isFixedAsSorted()) {
            return false;
        }
        if (!sparkWork.contains(reduceWork)) {
            return false;
        }
        if (dependents != null && dependents.size() > 1) {
            return false;
        }
        return SparkMapJoinResolver.getOp(reduceWork, CommonJoinOperator.class).size() == 1;
    }
}
