/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.optimizer.spark;

import hive.com.google.common.base.Preconditions;
import java.io.Serializable;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.regex.Pattern;
import org.apache.hadoop.hive.ql.exec.Operator;
import org.apache.hadoop.hive.ql.exec.ReduceSinkOperator;
import org.apache.hadoop.hive.ql.exec.Task;
import org.apache.hadoop.hive.ql.exec.Utilities;
import org.apache.hadoop.hive.ql.exec.spark.SparkTask;
import org.apache.hadoop.hive.ql.optimizer.physical.PhysicalContext;
import org.apache.hadoop.hive.ql.optimizer.physical.PhysicalPlanResolver;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.parse.spark.GenSparkUtils;
import org.apache.hadoop.hive.ql.plan.BaseWork;
import org.apache.hadoop.hive.ql.plan.MapWork;
import org.apache.hadoop.hive.ql.plan.OperatorDesc;
import org.apache.hadoop.hive.ql.plan.ReduceSinkDesc;
import org.apache.hadoop.hive.ql.plan.ReduceWork;
import org.apache.hadoop.hive.ql.plan.SparkEdgeProperty;
import org.apache.hadoop.hive.ql.plan.SparkWork;

/**
 * Physical-plan resolver that splits every Spark work whose operator tree ends in more than one
 * {@link ReduceSinkOperator}.
 *
 * <p>For such a work, one clone is created per child work. Each clone keeps only the reduce-sink
 * leaves whose output name equals its child's name. Only the first clone also keeps the
 * non-reduce-sink leaves; later clones drop them. Every clone is connected to all parents of the
 * original work (reusing the original edge properties) and to its one child. The original work is
 * then removed from the {@link SparkWork} graph, and each clone is recorded in
 * {@link SparkWork#getCloneToWork()}.
 */
public class SplitSparkWorkResolver implements PhysicalPlanResolver {

    /**
     * Matches a work name made of letters, whitespace and digits (for example "Map 1"). The digit
     * group is replaced to give each clone a fresh sequence number.
     */
    private static final Pattern WORK_NAME_PATTERN = Pattern.compile("^([a-zA-Z]+)(\\s+)(\\d+)");

    /**
     * Splits the {@link SparkWork} of every root task that is a {@link SparkTask}. Other root tasks,
     * and non-root tasks, are not inspected.
     *
     * @param pctx physical context whose root tasks are inspected
     * @return the same {@code pctx}; the Spark works it references may have been rewritten in place
     * @throws SemanticException declared by {@link PhysicalPlanResolver}; not thrown directly here
     */
    @Override
    public PhysicalContext resolve(PhysicalContext pctx) throws SemanticException {
        for (Task<? extends Serializable> task : pctx.getRootTasks()) {
            if (task instanceof SparkTask) {
                splitSparkWork(((SparkTask) task).getWork());
            }
        }
        return pctx;
    }

    /**
     * Visits {@code sparkWork} breadth-first from its roots and splits each visited work that has
     * more than one reduce-sink leaf.
     */
    private void splitSparkWork(SparkWork sparkWork) {
        Deque<BaseWork> queue = new ArrayDeque<BaseWork>(sparkWork.getRoots());
        Set<BaseWork> visited = new HashSet<BaseWork>();
        while (!queue.isEmpty()) {
            BaseWork work = queue.poll();
            if (!visited.add(work)) {
                continue;
            }
            // Read the children before splitting: the split replaces 'work' with clones in the
            // graph, but the children remain and still have to be visited. Clones are not enqueued.
            List<BaseWork> childWorks = sparkWork.getChildren(work);
            queue.addAll(childWorks);
            splitBaseWork(sparkWork, work, childWorks);
        }
    }

    /**
     * Replaces {@code parentWork} with one clone per child work, if it has more than one
     * reduce-sink leaf.
     *
     * @param sparkWork  graph that is rewritten in place
     * @param parentWork work to split; removed from the graph when a split happens
     * @param childWorks children of {@code parentWork}, captured before any graph change
     */
    private void splitBaseWork(SparkWork sparkWork, BaseWork parentWork, List<BaseWork> childWorks) {
        if (getAllReduceSinks(parentWork).size() <= 1) {
            return;
        }
        if (childWorks.isEmpty()) {
            // With no children there is nothing to split by. Falling through would remove
            // parentWork below without adding any clone, silently dropping it from the plan.
            return;
        }

        // The grand-parents become the parents of every clone.
        List<BaseWork> grandParentWorks = sparkWork.getParents(parentWork);
        boolean isFirst = true;
        for (BaseWork childWork : childWorks) {
            BaseWork clonedParentWork = Utilities.cloneBaseWork(parentWork);
            // Give the clone a distinct name by swapping in the next sequence number.
            String clonedName = WORK_NAME_PATTERN.matcher(clonedParentWork.getName())
                    .replaceAll("$1$2" + GenSparkUtils.getUtils().getNextSeqNumber());
            clonedParentWork.setName(clonedName);
            // Copy statistics/traits before pruning. setStatistics walks the origin and clone trees
            // in lock-step and only descends while their child counts still match.
            setStatistics(parentWork, clonedParentWork);
            SparkEdgeProperty clonedEdgeProperty = sparkWork.getEdgeProperty(parentWork, childWork);

            pruneLeafOperators(clonedParentWork, childWork.getName(), isFirst);
            isFirst = false;

            sparkWork.add(clonedParentWork);
            for (BaseWork grandParentWork : grandParentWorks) {
                sparkWork.connect(grandParentWork, clonedParentWork,
                        sparkWork.getEdgeProperty(grandParentWork, parentWork));
            }
            sparkWork.connect(clonedParentWork, childWork, clonedEdgeProperty);
            sparkWork.getCloneToWork().put(clonedParentWork, parentWork);
        }
        sparkWork.remove(parentWork);
    }

    /**
     * Removes the leaf branches of a cloned work that must not survive in that clone.
     *
     * <p>A reduce-sink leaf is kept only if its output name equals {@code childWorkName}. Every
     * other leaf is kept only if {@code keepNonReduceSinkLeaves} is true.
     */
    private void pruneLeafOperators(BaseWork work, String childWorkName,
                                    boolean keepNonReduceSinkLeaves) {
        // getAllLeafOperators returns a snapshot, so pruning does not disturb this iteration.
        for (Operator<?> leaf : work.getAllLeafOperators()) {
            if (leaf instanceof ReduceSinkOperator) {
                ReduceSinkDesc desc = ((ReduceSinkOperator) leaf).getConf();
                if (!desc.getOutputName().equals(childWorkName)) {
                    removeOpRecursive(leaf);
                }
            } else if (!keepNonReduceSinkLeaves) {
                removeOpRecursive(leaf);
            }
        }
    }

    /**
     * Returns the leaf operators of {@code work} that are {@link ReduceSinkOperator}s.
     *
     * <p>The set returned by {@code getAllLeafOperators} is filtered in place. This assumes it is a
     * fresh collection owned by the caller.
     */
    private Set<Operator<?>> getAllReduceSinks(BaseWork work) {
        Set<Operator<?>> resultSet = work.getAllLeafOperators();
        Iterator<Operator<?>> it = resultSet.iterator();
        while (it.hasNext()) {
            if (!(it.next() instanceof ReduceSinkOperator)) {
                it.remove();
            }
        }
        return resultSet;
    }

    /**
     * Detaches {@code operator} from all of its parents. Any parent left with no children is
     * detached recursively in the same way.
     *
     * @throws IllegalArgumentException if a parent does not list {@code operator} as a child
     */
    private void removeOpRecursive(Operator<?> operator) {
        List<Operator<? extends OperatorDesc>> parents = operator.getParentOperators();
        if (parents == null || parents.isEmpty()) {
            // Root operator, or one that is already fully detached: nothing left to unlink.
            return;
        }
        // Snapshot the list: removeChild mutates the child's parent list while we iterate.
        List<Operator<? extends OperatorDesc>> parentOperators =
                new ArrayList<Operator<? extends OperatorDesc>>(parents);
        for (Operator<? extends OperatorDesc> parentOperator : parentOperators) {
            Preconditions.checkArgument(parentOperator.getChildOperators().contains(operator),
                    "AssertionError: parent of " + operator.getName() + " doesn't have it as child.");
            parentOperator.removeChild(operator);
            if (parentOperator.getNumChild() == 0) {
                removeOpRecursive(parentOperator);
            }
        }
    }

    /**
     * Copies per-operator statistics and traits from {@code origin} onto its clone. Only
     * MapWork/MapWork and ReduceWork/ReduceWork pairs are handled; other pairs are left untouched.
     * For map works, only aliases present in both works are copied.
     */
    private void setStatistics(BaseWork origin, BaseWork clone) {
        if (origin instanceof MapWork && clone instanceof MapWork) {
            MapWork originMW = (MapWork) origin;
            MapWork cloneMW = (MapWork) clone;
            for (Map.Entry<String, Operator<? extends OperatorDesc>> entry
                    : originMW.getAliasToWork().entrySet()) {
                Operator<? extends OperatorDesc> cloneOP = cloneMW.getAliasToWork().get(entry.getKey());
                if (cloneOP != null) {
                    setStatistics(entry.getValue(), cloneOP);
                }
            }
        } else if (origin instanceof ReduceWork && clone instanceof ReduceWork) {
            setStatistics(((ReduceWork) origin).getReducer(), ((ReduceWork) clone).getReducer());
        }
    }

    /**
     * Copies statistics and traits from {@code origin}'s descriptor onto {@code clone}'s. It then
     * recurses into the children pairwise, by index, but only when both operators have the same
     * number of children.
     */
    private void setStatistics(Operator<? extends OperatorDesc> origin,
                               Operator<? extends OperatorDesc> clone) {
        clone.getConf().setStatistics(origin.getConf().getStatistics());
        clone.getConf().setTraits(origin.getConf().getTraits());
        List<Operator<? extends OperatorDesc>> originChildren = origin.getChildOperators();
        List<Operator<? extends OperatorDesc>> cloneChildren = clone.getChildOperators();
        if (originChildren.size() == cloneChildren.size()) {
            for (int i = 0; i < cloneChildren.size(); ++i) {
                setStatistics(originChildren.get(i), cloneChildren.get(i));
            }
        }
    }
}

