package org.apache.hadoop.hive.ql.optimizer.spark;

import com.google.common.base.Preconditions;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.hadoop.hive.ql.exec.Operator;
import org.apache.hadoop.hive.ql.exec.ReduceSinkOperator;
import org.apache.hadoop.hive.ql.exec.SerializationUtilities;
import org.apache.hadoop.hive.ql.exec.Task;
import org.apache.hadoop.hive.ql.exec.spark.SparkTask;
import org.apache.hadoop.hive.ql.optimizer.physical.PhysicalContext;
import org.apache.hadoop.hive.ql.optimizer.physical.PhysicalPlanResolver;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.parse.spark.GenSparkUtils;
import org.apache.hadoop.hive.ql.plan.BaseWork;
import org.apache.hadoop.hive.ql.plan.MapWork;
import org.apache.hadoop.hive.ql.plan.OperatorDesc;
import org.apache.hadoop.hive.ql.plan.ReduceSinkDesc;
import org.apache.hadoop.hive.ql.plan.ReduceWork;
import org.apache.hadoop.hive.ql.plan.SparkEdgeProperty;
import org.apache.hadoop.hive.ql.plan.SparkWork;

/* loaded from: input_file:WEB-INF/lib/hive-exec-2.3.9-mapr-640-core.jar:org/apache/hadoop/hive/ql/optimizer/spark/SplitSparkWorkResolver.class */
/**
 * Physical-plan resolver that splits Spark works which feed more than one child.
 *
 * <p>For every {@link SparkTask} among the root tasks, the task's {@link SparkWork} graph is
 * walked from its roots. Any {@link BaseWork} whose operator tree ends in more than one
 * {@link ReduceSinkOperator} is replaced by one clone per child work; each clone keeps only the
 * reduce-sink branch whose output name matches that child.
 */
public class SplitSparkWorkResolver implements PhysicalPlanResolver {

    /**
     * Applies the split to the work graph of every root task that is a {@link SparkTask}.
     *
     * @param physicalContext the physical plan context whose root tasks are inspected
     * @return the same {@code physicalContext} instance that was passed in
     * @throws SemanticException declared by {@link PhysicalPlanResolver#resolve}
     */
    @Override
    public PhysicalContext resolve(PhysicalContext physicalContext) throws SemanticException {
        for (Task<? extends Serializable> task : physicalContext.getRootTasks()) {
            if (task instanceof SparkTask) {
                splitSparkWork(((SparkTask) task).getWork());
            }
        }
        return physicalContext;
    }

    /**
     * Walks the work graph breadth-first from its roots and attempts a split on each work
     * exactly once.
     *
     * @param sparkWork the work graph to traverse and modify in place
     */
    private void splitSparkWork(SparkWork sparkWork) {
        LinkedList<BaseWork> pending = new LinkedList<BaseWork>(sparkWork.getRoots());
        Set<BaseWork> visited = new HashSet<BaseWork>();

        while (!pending.isEmpty()) {
            BaseWork work = pending.poll();
            if (!visited.add(work)) {
                // Already handled via another path through the graph.
                continue;
            }

            // Look up and enqueue the children BEFORE splitting: the split may remove
            // `work` from the graph, after which its children could no longer be queried.
            List<BaseWork> children = sparkWork.getChildren(work);
            pending.addAll(children);

            splitBaseWork(sparkWork, work, children);
        }
    }

    /**
     * Replaces {@code parentWork} with one clone per entry of {@code childWorks}, provided the
     * work has more than one reduce-sink leaf; otherwise the graph is left untouched.
     *
     * <p>Each clone is renamed with a fresh sequence number, receives the statistics of the
     * original, is connected to every parent of the original (reusing the original edge
     * properties) and to its single child, and is recorded in the clone-to-work map of
     * {@code sparkWork}. The original work is removed from the graph afterwards.
     *
     * @param sparkWork  the graph that owns {@code parentWork}
     * @param parentWork the work that may be split
     * @param childWorks the children of {@code parentWork}, one clone is created for each
     */
    private void splitBaseWork(SparkWork sparkWork, BaseWork parentWork, List<BaseWork> childWorks) {
        if (getAllReduceSinks(parentWork).size() <= 1) {
            return;
        }

        List<BaseWork> grandParentWorks = sparkWork.getParents(parentWork);
        boolean isFirstClone = true;

        for (BaseWork childWork : childWorks) {
            BaseWork clonedWork = SerializationUtilities.cloneBaseWork(parentWork);

            // Swap the numeric suffix of a leading "<letters><whitespace><digits>" prefix
            // for a new sequence number so the clone gets a distinct name.
            clonedWork.setName(clonedWork.getName().replaceAll(
                    "^([a-zA-Z]+)(\\s+)(\\d+)", "$1$2" + GenSparkUtils.getUtils().getNextSeqNumber()));
            setStatistics(parentWork, clonedWork);

            String childName = childWork.getName();
            SparkEdgeProperty edgeToChild = sparkWork.getEdgeProperty(parentWork, childWork);

            // Prune the clone: drop every reduce-sink branch that does not target this child.
            // Leaves that are not reduce sinks survive only in the first clone.
            for (Operator<? extends OperatorDesc> leaf : clonedWork.getAllLeafOperators()) {
                if (leaf instanceof ReduceSinkOperator) {
                    ReduceSinkDesc sinkDesc = (ReduceSinkDesc) ((ReduceSinkOperator) leaf).getConf();
                    if (!sinkDesc.getOutputName().equals(childName)) {
                        removeOpRecursive(leaf);
                    }
                } else if (!isFirstClone) {
                    removeOpRecursive(leaf);
                }
            }
            isFirstClone = false;

            // Wire the clone into the graph in place of the original.
            sparkWork.add(clonedWork);
            for (BaseWork grandParentWork : grandParentWorks) {
                sparkWork.connect(grandParentWork, clonedWork,
                        sparkWork.getEdgeProperty(grandParentWork, parentWork));
            }
            sparkWork.connect(clonedWork, childWork, edgeToChild);
            sparkWork.getCloneToWork().put(clonedWork, parentWork);
        }

        sparkWork.remove(parentWork);
    }

    /**
     * Returns the leaf operators of {@code work} that are {@link ReduceSinkOperator}s.
     *
     * <p>The set obtained from {@link BaseWork#getAllLeafOperators()} is filtered in place and
     * returned.
     *
     * @param work the work whose leaf operators are examined
     * @return the leaf-operator set with every non-reduce-sink element removed
     */
    private Set<Operator<?>> getAllReduceSinks(BaseWork work) {
        Set<Operator<? extends OperatorDesc>> leaves = work.getAllLeafOperators();
        for (Iterator<Operator<? extends OperatorDesc>> it = leaves.iterator(); it.hasNext();) {
            if (!(it.next() instanceof ReduceSinkOperator)) {
                it.remove();
            }
        }
        return leaves;
    }

    /**
     * Detaches {@code operator} from each of its parents and recursively does the same for any
     * parent that is left without children.
     *
     * @param operator the operator to unlink from the tree
     * @throws IllegalArgumentException if a parent does not list {@code operator} as a child
     */
    private void removeOpRecursive(Operator<?> operator) {
        // Iterate over a snapshot so the loop is independent of any change removeChild()
        // may make to the operator's own parent list.
        List<Operator<? extends OperatorDesc>> parents =
                new ArrayList<Operator<? extends OperatorDesc>>(operator.getParentOperators());

        for (Operator<? extends OperatorDesc> parent : parents) {
            Preconditions.checkArgument(parent.getChildOperators().contains(operator),
                    "AssertionError: parent of " + operator.getName() + " doesn't have it as child.");
            parent.removeChild(operator);
            if (parent.getNumChild() == 0) {
                removeOpRecursive(parent);
            }
        }
    }

    /**
     * Copies operator statistics and traits from {@code origin} to {@code clone}.
     *
     * <p>When both are {@link MapWork}s, operator trees are paired by alias and aliases missing
     * from the clone are skipped. When both are {@link ReduceWork}s, the two reducers are
     * paired. Any other combination is a no-op.
     *
     * @param origin the work to read statistics from
     * @param clone  the work to write statistics to
     */
    private void setStatistics(BaseWork origin, BaseWork clone) {
        if (origin instanceof MapWork && clone instanceof MapWork) {
            MapWork originMapWork = (MapWork) origin;
            MapWork cloneMapWork = (MapWork) clone;
            for (Map.Entry<String, Operator<? extends OperatorDesc>> entry
                    : originMapWork.getAliasToWork().entrySet()) {
                Operator<? extends OperatorDesc> cloneOp = cloneMapWork.getAliasToWork().get(entry.getKey());
                if (cloneOp != null) {
                    setStatistics(entry.getValue(), cloneOp);
                }
            }
        } else if (origin instanceof ReduceWork && clone instanceof ReduceWork) {
            setStatistics(((ReduceWork) origin).getReducer(), ((ReduceWork) clone).getReducer());
        }
    }

    /**
     * Copies statistics and traits from {@code origin} to {@code clone}, then descends into the
     * children pairwise by position.
     *
     * <p>The descent happens only when both operators have the same number of children; on a
     * mismatch the subtrees below this pair are left as they are.
     *
     * @param origin the operator to read statistics and traits from
     * @param clone  the operator to write statistics and traits to
     */
    private void setStatistics(Operator<? extends OperatorDesc> origin, Operator<? extends OperatorDesc> clone) {
        clone.getConf().setStatistics(origin.getConf().getStatistics());
        clone.getConf().setTraits(origin.getConf().getTraits());

        List<Operator<? extends OperatorDesc>> originChildren = origin.getChildOperators();
        List<Operator<? extends OperatorDesc>> cloneChildren = clone.getChildOperators();
        if (originChildren.size() != cloneChildren.size()) {
            return;
        }
        for (int i = 0; i < cloneChildren.size(); i++) {
            setStatistics(originChildren.get(i), cloneChildren.get(i));
        }
    }
}
