package org.apache.hadoop.hive.ql.ppd;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.Stack;
import jodd.util.StringPool;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.hadoop.hive.ql.exec.CommonJoinOperator;
import org.apache.hadoop.hive.ql.exec.FunctionRegistry;
import org.apache.hadoop.hive.ql.exec.JoinOperator;
import org.apache.hadoop.hive.ql.exec.Operator;
import org.apache.hadoop.hive.ql.exec.OperatorFactory;
import org.apache.hadoop.hive.ql.exec.ReduceSinkOperator;
import org.apache.hadoop.hive.ql.exec.RowSchema;
import org.apache.hadoop.hive.ql.exec.TableScanOperator;
import org.apache.hadoop.hive.ql.lib.DefaultRuleDispatcher;
import org.apache.hadoop.hive.ql.lib.Node;
import org.apache.hadoop.hive.ql.lib.NodeProcessor;
import org.apache.hadoop.hive.ql.lib.NodeProcessorCtx;
import org.apache.hadoop.hive.ql.lib.PreOrderWalker;
import org.apache.hadoop.hive.ql.lib.Rule;
import org.apache.hadoop.hive.ql.lib.RuleRegExp;
import org.apache.hadoop.hive.ql.optimizer.Transform;
import org.apache.hadoop.hive.ql.parse.ParseContext;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
import org.apache.hadoop.hive.ql.plan.ExprNodeDynamicListDesc;
import org.apache.hadoop.hive.ql.plan.ExprNodeGenericFuncDesc;
import org.apache.hadoop.hive.ql.plan.FilterDesc;
import org.apache.hadoop.hive.ql.plan.JoinCondDesc;
import org.apache.hadoop.hive.ql.plan.JoinDesc;
import org.apache.hadoop.hive.ql.plan.OperatorDesc;
import org.apache.hadoop.hive.ql.plan.ReduceSinkDesc;

/* loaded from: input_file:WEB-INF/lib/hive-exec-1.2.0-mapr-1707.jar:org/apache/hadoop/hive/ql/ppd/SyntheticJoinPredicate.class */
public class SyntheticJoinPredicate implements Transform {
    private static transient Log LOG = LogFactory.getLog(SyntheticJoinPredicate.class.getName());

    /* loaded from: input_file:WEB-INF/lib/hive-exec-1.2.0-mapr-1707.jar:org/apache/hadoop/hive/ql/ppd/SyntheticJoinPredicate$JoinSynthetic.class */
    private static class JoinSynthetic implements NodeProcessor {
        private JoinSynthetic() {
        }

        /* JADX WARN: Multi-variable type inference failed */
        @Override // org.apache.hadoop.hive.ql.lib.NodeProcessor
        public Object process(Node node, Stack<Node> stack, NodeProcessorCtx nodeProcessorCtx, Object... objArr) throws SemanticException {
            ExprNodeGenericFuncDesc exprNodeGenericFuncDesc;
            ((SyntheticContext) nodeProcessorCtx).getParseContext();
            CommonJoinOperator<JoinDesc> commonJoinOperator = (CommonJoinOperator) node;
            ReduceSinkOperator reduceSinkOperator = (ReduceSinkOperator) stack.get(stack.size() - 2);
            int indexOf = commonJoinOperator.getParentOperators().indexOf(reduceSinkOperator);
            List<Operator<? extends OperatorDesc>> parentOperators = commonJoinOperator.getParentOperators();
            int[][] targets = getTargets(commonJoinOperator);
            Operator<? extends OperatorDesc> operator = reduceSinkOperator.getParentOperators().get(0);
            RowSchema schema = operator.getSchema();
            if (((JoinDesc) commonJoinOperator.getConf()).getNullSafes() != null) {
                for (boolean z : ((JoinDesc) commonJoinOperator.getConf()).getNullSafes()) {
                    if (z) {
                        return null;
                    }
                }
            }
            for (int i : targets[indexOf]) {
                if (indexOf != i) {
                    if (SyntheticJoinPredicate.LOG.isDebugEnabled()) {
                        SyntheticJoinPredicate.LOG.debug("Synthetic predicate: " + indexOf + " --> " + i);
                    }
                    ReduceSinkOperator reduceSinkOperator2 = (ReduceSinkOperator) parentOperators.get(i);
                    ArrayList<ExprNodeDesc> keyCols = ((ReduceSinkDesc) reduceSinkOperator.getConf()).getKeyCols();
                    ArrayList<ExprNodeDesc> keyCols2 = ((ReduceSinkDesc) reduceSinkOperator2.getConf()).getKeyCols();
                    if (keyCols.size() >= 1) {
                        ExprNodeGenericFuncDesc exprNodeGenericFuncDesc2 = null;
                        for (int i2 = 0; i2 < keyCols.size(); i2++) {
                            ArrayList arrayList = new ArrayList();
                            arrayList.add(keyCols.get(i2));
                            arrayList.add(new ExprNodeDynamicListDesc(keyCols2.get(i2).getTypeInfo(), reduceSinkOperator2, i2));
                            ExprNodeGenericFuncDesc newInstance = ExprNodeGenericFuncDesc.newInstance(FunctionRegistry.getFunctionInfo("in").getGenericUDF(), arrayList);
                            if (exprNodeGenericFuncDesc2 != null) {
                                ArrayList arrayList2 = new ArrayList();
                                arrayList2.add(exprNodeGenericFuncDesc2);
                                arrayList2.add(newInstance);
                                exprNodeGenericFuncDesc = ExprNodeGenericFuncDesc.newInstance(FunctionRegistry.getFunctionInfo(StringPool.AND).getGenericUDF(), arrayList2);
                            } else {
                                exprNodeGenericFuncDesc = newInstance;
                            }
                            exprNodeGenericFuncDesc2 = exprNodeGenericFuncDesc;
                        }
                        operator = SyntheticJoinPredicate.createFilter(reduceSinkOperator, operator, schema, exprNodeGenericFuncDesc2);
                    }
                }
            }
            return null;
        }

        /* JADX WARN: Multi-variable type inference failed */
        /* JADX WARN: Type inference failed for: r0v14, types: [int[], int[][]] */
        private int[][] getTargets(CommonJoinOperator<JoinDesc> commonJoinOperator) {
            JoinCondDesc[] conds = ((JoinDesc) commonJoinOperator.getConf()).getConds();
            int length = conds.length + 1;
            Vectors vectors = new Vectors(length);
            for (JoinCondDesc joinCondDesc : conds) {
                int left = joinCondDesc.getLeft();
                int right = joinCondDesc.getRight();
                switch (joinCondDesc.getType()) {
                    case 0:
                    case 5:
                        vectors.add(left, right);
                        vectors.add(right, left);
                        break;
                    case 1:
                        vectors.add(right, left);
                        break;
                    case 2:
                        vectors.add(left, right);
                        break;
                }
            }
            ?? r0 = new int[length];
            for (int i = 0; i < length; i++) {
                r0[i] = vectors.traverse(i);
            }
            return r0;
        }
    }

    /* loaded from: input_file:WEB-INF/lib/hive-exec-1.2.0-mapr-1707.jar:org/apache/hadoop/hive/ql/ppd/SyntheticJoinPredicate$SyntheticContext.class */
    private static class SyntheticContext implements NodeProcessorCtx {
        ParseContext parseContext;

        public SyntheticContext(ParseContext parseContext) {
            this.parseContext = parseContext;
        }

        public ParseContext getParseContext() {
            return this.parseContext;
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:WEB-INF/lib/hive-exec-1.2.0-mapr-1707.jar:org/apache/hadoop/hive/ql/ppd/SyntheticJoinPredicate$Vectors.class */
    public static class Vectors {
        private final Set<Integer>[] vector;

        public Vectors(int i) {
            this.vector = new Set[i];
        }

        public void add(int i, int i2) {
            if (this.vector[i] == null) {
                this.vector[i] = new HashSet();
            }
            this.vector[i].add(Integer.valueOf(i2));
        }

        public int[] traverse(int i) {
            HashSet hashSet = new HashSet();
            traverse(hashSet, i);
            return toArray(hashSet);
        }

        private int[] toArray(Set<Integer> set) {
            int i = 0;
            int[] iArr = new int[set.size()];
            Iterator<Integer> it = set.iterator();
            while (it.hasNext()) {
                int i2 = i;
                i++;
                iArr[i2] = it.next().intValue();
            }
            return iArr;
        }

        private void traverse(Set<Integer> set, int i) {
            if (this.vector[i] == null) {
                return;
            }
            Iterator<Integer> it = this.vector[i].iterator();
            while (it.hasNext()) {
                int intValue = it.next().intValue();
                if (set.add(Integer.valueOf(intValue))) {
                    traverse(set, intValue);
                }
            }
        }
    }

    @Override // org.apache.hadoop.hive.ql.optimizer.Transform
    public ParseContext transform(ParseContext parseContext) throws SemanticException {
        if (!parseContext.getConf().getVar(HiveConf.ConfVars.HIVE_EXECUTION_ENGINE).equals("tez") || !parseContext.getConf().getBoolVar(HiveConf.ConfVars.TEZ_DYNAMIC_PARTITION_PRUNING)) {
            return parseContext;
        }
        LinkedHashMap linkedHashMap = new LinkedHashMap();
        linkedHashMap.put(new RuleRegExp("R1", "(" + TableScanOperator.getOperatorName() + "%.*" + ReduceSinkOperator.getOperatorName() + "%" + JoinOperator.getOperatorName() + "%)"), new JoinSynthetic());
        PreOrderWalker preOrderWalker = new PreOrderWalker(new DefaultRuleDispatcher(null, linkedHashMap, new SyntheticContext(parseContext)));
        ArrayList arrayList = new ArrayList();
        arrayList.addAll(parseContext.getTopOps().values());
        preOrderWalker.startWalking(arrayList, null);
        return parseContext;
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* JADX WARN: Multi-variable type inference failed */
    public static Operator<FilterDesc> createFilter(Operator<?> operator, Operator<?> operator2, RowSchema rowSchema, ExprNodeDesc exprNodeDesc) {
        Operator<FilterDesc> operator3 = OperatorFactory.get(new FilterDesc(exprNodeDesc, false), new RowSchema(rowSchema.getSignature()), new Operator[0]);
        operator3.getParentOperators().add(operator2);
        operator3.getChildOperators().add(operator);
        operator2.replaceChild(operator, operator3);
        operator.replaceParent(operator2, operator3);
        return operator3;
    }
}
