package org.apache.hadoop.hive.ql.parse.spark;

import hive.com.google.common.base.Preconditions;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Stack;
import org.apache.hadoop.hive.ql.exec.Operator;
import org.apache.hadoop.hive.ql.exec.OperatorUtils;
import org.apache.hadoop.hive.ql.exec.SerializationUtilities;
import org.apache.hadoop.hive.ql.exec.TableScanOperator;
import org.apache.hadoop.hive.ql.exec.Utilities;
import org.apache.hadoop.hive.ql.exec.spark.SparkUtilities;
import org.apache.hadoop.hive.ql.lib.Node;
import org.apache.hadoop.hive.ql.lib.NodeProcessor;
import org.apache.hadoop.hive.ql.lib.NodeProcessorCtx;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.plan.OperatorDesc;

/* loaded from: input_file:org/apache/hadoop/hive/ql/parse/spark/SplitOpTreeForDPP.class */
/**
 * Node processor that moves a dynamic-partition-pruning (DPP) branch into its own operator tree.
 *
 * <p>For each {@link SparkPartitionPruningSinkOperator} it visits, the processor clones the
 * operator tree above the sink's branching operator, hangs the pruning branch(es) under the
 * cloned branching operator, and detaches them from the original branching operator. Pruning
 * sinks combined with a map join are only recorded in the context, not split.
 */
public class SplitOpTreeForDPP implements NodeProcessor {

    /** Prefix of the marker used to locate the branching operator inside the cloned trees. */
    private static final String BRANCH_MARKER_PREFIX = "SPARK_DPP_BRANCH_POINT_";

    /**
     * Splits the pruning branch leading to the visited sink off into a cloned operator tree.
     *
     * @param nd the {@link SparkPartitionPruningSinkOperator} being visited
     * @param stack the walker's node stack (not used here)
     * @param procCtx the {@link GenSparkProcContext}; its {@code pruningSinkSet} and
     *     {@code clonedPruningTableScanSet} are updated
     * @param nodeOutputs outputs of previously processed nodes (not used here)
     * @return always {@code null}
     * @throws SemanticException declared by {@link NodeProcessor}; not thrown directly here
     */
    @Override // org.apache.hadoop.hive.ql.lib.NodeProcessor
    public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procCtx,
            Object... nodeOutputs) throws SemanticException {
        SparkPartitionPruningSinkOperator pruningSinkOp = (SparkPartitionPruningSinkOperator) nd;
        GenSparkProcContext context = (GenSparkProcContext) procCtx;

        // Skip sinks that are already recorded: either via the map-join path below, or because
        // an earlier split moved them along with another pruning branch. Matching is by
        // operator ID, not by reference.
        for (Operator<?> recorded : context.pruningSinkSet) {
            if (pruningSinkOp.getOperatorId().equals(recorded.getOperatorId())) {
                return null;
            }
        }

        // A pruning sink combined with a map join is not split into a separate tree;
        // it is only recorded.
        if (pruningSinkOp.isWithMapjoin()) {
            context.pruningSinkSet.add(pruningSinkOp);
            return null;
        }

        // Every root found here is later cast to TableScanOperator.
        List<Operator<?>> roots = new LinkedList<>();
        collectRoots(roots, pruningSinkOp);

        Operator<?> branchingOp = pruningSinkOp.getBranchingOp();
        // The marker is searched for in the cloned trees to find the copy of branchingOp.
        String marker = BRANCH_MARKER_PREFIX + branchingOp.getOperatorId();
        branchingOp.setMarker(marker);
        List<Operator<? extends OperatorDesc>> savedChildOps = branchingOp.getChildOperators();
        List<Operator<?>> firstNodesOfPruningBranch = findFirstNodesOfPruningBranch(branchingOp);

        // Temporarily detach everything below branchingOp so the clone stops there. The links
        // are restored in a finally block so a failing clone never leaves the original plan
        // with a childless branching operator.
        List<Operator<?>> newRoots;
        branchingOp.setChildOperators(null);
        try {
            newRoots = SerializationUtilities.cloneOperatorTree(roots);
        } finally {
            branchingOp.setChildOperators(savedChildOps);
        }

        // Table metadata is copied explicitly onto each cloned scan; the clone does not appear
        // to carry it over (NOTE(review): presumably a transient field — confirm).
        for (int i = 0; i < roots.size(); i++) {
            TableScanOperator newTs = (TableScanOperator) newRoots.get(i);
            TableScanOperator oldTs = (TableScanOperator) roots.get(i);
            newTs.getConf().setTableMetadata(oldTs.getConf().getTableMetadata());
        }
        context.clonedPruningTableScanSet.addAll(newRoots);

        Operator<?> newBranchingOp = null;
        for (int i = 0; i < newRoots.size() && newBranchingOp == null; i++) {
            newBranchingOp = OperatorUtils.findOperatorByMarker(newRoots.get(i), marker);
        }
        Preconditions.checkNotNull(newBranchingOp,
                "Cannot find the branching operator in cloned tree.");

        // Move the pruning branch(es): attach them below the cloned branching operator and
        // remove them from the original one.
        newBranchingOp.setChildOperators(firstNodesOfPruningBranch);
        for (Operator<?> firstNode : firstNodesOfPruningBranch) {
            branchingOp.removeChild(firstNode);
        }

        // Record every pruning sink reachable from the moved branches, and re-parent each
        // branch head onto the cloned branching operator.
        LinkedHashSet<Operator<?>> sinkSet = new LinkedHashSet<>();
        for (Operator<?> firstNode : firstNodesOfPruningBranch) {
            SparkUtilities.collectOp(sinkSet, firstNode, SparkPartitionPruningSinkOperator.class);
            firstNode.setParentOperators(Utilities.makeList(newBranchingOp));
        }
        context.pruningSinkSet.addAll(sinkSet);
        return null;
    }

    /**
     * Returns the children of {@code branchingOp} for which
     * {@link SparkUtilities#isDirectDPPBranch} is true, in child order.
     *
     * @param branchingOp the operator whose children are inspected
     * @return a new mutable list of the matching children (possibly empty)
     */
    private List<Operator<?>> findFirstNodesOfPruningBranch(Operator<?> branchingOp) {
        List<Operator<?>> firstNodes = new ArrayList<>();
        for (Operator<? extends OperatorDesc> child : branchingOp.getChildOperators()) {
            if (SparkUtilities.isDirectDPPBranch(child)) {
                firstNodes.add(child);
            }
        }
        return firstNodes;
    }

    /**
     * Appends to {@code roots} every operator without parents that is reachable from
     * {@code op} by following parent links; {@code op} itself is appended if it has no parents.
     *
     * <p>No de-duplication is done: a root reachable through several paths is appended once
     * per path.
     *
     * @param roots the list that receives the roots
     * @param op the operator to start the upward walk from
     */
    private void collectRoots(List<Operator<?>> roots, Operator<?> op) {
        if (op.getNumParent() == 0) {
            roots.add(op);
            return;
        }
        for (Operator<? extends OperatorDesc> parent : op.getParentOperators()) {
            collectRoots(roots, parent);
        }
    }
}
