package org.apache.hadoop.hive.ql.optimizer;

import java.io.IOException;
import java.net.URI;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.Stack;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.hadoop.hive.ql.ErrorMsg;
import org.apache.hadoop.hive.ql.exec.MapJoinOperator;
import org.apache.hadoop.hive.ql.exec.Operator;
import org.apache.hadoop.hive.ql.exec.ReduceSinkOperator;
import org.apache.hadoop.hive.ql.exec.TableScanOperator;
import org.apache.hadoop.hive.ql.exec.UnionOperator;
import org.apache.hadoop.hive.ql.lib.DefaultGraphWalker;
import org.apache.hadoop.hive.ql.lib.DefaultRuleDispatcher;
import org.apache.hadoop.hive.ql.lib.Node;
import org.apache.hadoop.hive.ql.lib.NodeProcessor;
import org.apache.hadoop.hive.ql.lib.NodeProcessorCtx;
import org.apache.hadoop.hive.ql.lib.RuleRegExp;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.metadata.Partition;
import org.apache.hadoop.hive.ql.metadata.Table;
import org.apache.hadoop.hive.ql.optimizer.ppr.PartitionPruner;
import org.apache.hadoop.hive.ql.parse.ParseContext;
import org.apache.hadoop.hive.ql.parse.PrunedPartitionList;
import org.apache.hadoop.hive.ql.parse.QBJoinTree;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.plan.ExprNodeColumnDesc;
import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
import org.apache.hadoop.hive.ql.plan.MapJoinDesc;
import org.apache.hadoop.hive.ql.plan.OperatorDesc;
import org.apache.hadoop.util.StringUtils;

/* loaded from: input_file:org/apache/hadoop/hive/ql/optimizer/BucketMapJoinOptimizer.class */
public class BucketMapJoinOptimizer implements Transform {
    /** Logger for this optimizer, named after this class so its messages are attributed to it. */
    private static final Log LOG = LogFactory.getLog(BucketMapJoinOptimizer.class.getName());

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:org/apache/hadoop/hive/ql/optimizer/BucketMapJoinOptimizer$BucketMapjoinOptProc.class */
    public class BucketMapjoinOptProc implements NodeProcessor {
        /** Parse context consulted for the map-join context, top operators, tables and partition lists. */
        protected ParseContext pGraphContext;
        // Decompiler-emitted flag guarding the assertion check in convertBucketMapJoin;
        // it is assigned in this class's static initializer.
        static final /* synthetic */ boolean $assertionsDisabled;

        /**
         * Creates a processor bound to the given parse context.
         *
         * @param parseContext context stored in {@link #pGraphContext}
         */
        public BucketMapjoinOptProc(ParseContext parseContext) {
            this.pGraphContext = parseContext;
        }

        /* JADX WARN: Multi-variable type inference failed */
        /**
         * Checks whether the visited map join can use bucket-level file mappings and, if so,
         * records those mappings on the operator's {@link MapJoinDesc}.
         *
         * <p>The method returns {@code false} without touching the descriptor when the operator was
         * rejected by another rule, has no join tree in the map-join context, has an alias without
         * a top operator, joins on something other than plain columns, has join keys that do not
         * line up with the bucketing columns, or has bucket counts that are not multiples of each
         * other.
         *
         * @param node the operator being visited; cast to {@link MapJoinOperator}
         * @param stack unused here
         * @param nodeProcessorCtx walker context; cast to {@link BucketMapjoinOptProcCtx}
         * @param objArr unused here
         * @return {@code true} when the bucket mappings were written to the descriptor
         * @throws SemanticException when the number of files of a table or partition differs from
         *     its declared bucket count, when listing files fails, or when partition pruning
         *     raises a {@link HiveException}
         */
        private boolean convertBucketMapJoin(Node node, Stack<Node> stack, NodeProcessorCtx nodeProcessorCtx, Object... objArr) throws SemanticException {
            QBJoinTree qBJoinTree;
            List<String> columns;
            MapJoinOperator mapJoinOperator = (MapJoinOperator) node;
            BucketMapjoinOptProcCtx bucketMapjoinOptProcCtx = (BucketMapjoinOptProcCtx) nodeProcessorCtx;
            // The returned configuration is discarded; this call has no effect on the logic below.
            bucketMapjoinOptProcCtx.getConf();
            // Skip operators another rule rejected, and operators with no registered join tree.
            if (bucketMapjoinOptProcCtx.getListOfRejectedMapjoins().contains(mapJoinOperator) || (qBJoinTree = this.pGraphContext.getMapJoinContext().get(mapJoinOperator)) == null) {
                return false;
            }
            // arrayList: de-duplicated join aliases, left aliases first, then base sources.
            ArrayList arrayList = new ArrayList();
            String[] baseSrc = qBJoinTree.getBaseSrc();
            String[] leftAliases = qBJoinTree.getLeftAliases();
            List<String> mapAliases = qBJoinTree.getMapAliases();
            // str: the alias that is NOT among the map aliases; it is later passed to
            // setBigTableAlias. If several aliases qualify, the last one seen wins.
            String str = null;
            for (String str2 : leftAliases) {
                if (str2 != null && !arrayList.contains(str2)) {
                    arrayList.add(str2);
                    if (!mapAliases.contains(str2)) {
                        str = str2;
                    }
                }
            }
            for (String str3 : baseSrc) {
                if (str3 != null && !arrayList.contains(str3)) {
                    arrayList.add(str3);
                    if (!mapAliases.contains(str3)) {
                        str = str3;
                    }
                }
            }
            MapJoinDesc mapJoinDesc = (MapJoinDesc) mapJoinOperator.getConf();
            // linkedHashMap:  small-table alias -> list of bucket counts (one per partition).
            LinkedHashMap linkedHashMap = new LinkedHashMap();
            // linkedHashMap2: small-table alias -> list of bucket file name lists (one per partition).
            LinkedHashMap linkedHashMap2 = new LinkedHashMap();
            HashMap<String, Operator<? extends OperatorDesc>> topOps = this.pGraphContext.getTopOps();
            HashMap<TableScanOperator, Table> topToTable = this.pGraphContext.getTopToTable();
            // linkedHashMap3: big-table partition (null key for an unpartitioned table) -> bucket file names.
            LinkedHashMap linkedHashMap3 = new LinkedHashMap();
            // linkedHashMap4: big-table partition (null key for an unpartitioned table) -> bucket count.
            LinkedHashMap linkedHashMap4 = new LinkedHashMap();
            // numArr: position of each join key within the bucketing columns; filled by the first
            // checkBucketColumns call and compared against on every later call.
            Integer[] numArr = null;
            // z stays true unless the big table turns out to be unpartitioned; it gates the
            // setBigTablePartSpecToFileMapping call at the end.
            boolean z = true;
            for (int i = 0; i < arrayList.size(); i++) {
                String str4 = (String) arrayList.get(i);
                TableScanOperator tableScanOperator = (TableScanOperator) topOps.get(str4);
                // The alias position is used as the key into the descriptor's join-key map.
                // Give up when the alias has no top operator or the keys are not plain columns.
                if (tableScanOperator == null || (columns = toColumns(mapJoinDesc.getKeys().get(Byte.valueOf((byte) i)))) == null || columns.isEmpty()) {
                    return false;
                }
                if (numArr == null) {
                    numArr = new Integer[columns.size()];
                }
                Table table = topToTable.get(tableScanOperator);
                if (table.isPartitioned()) {
                    try {
                        // Reuse the pruned partition list if one is already stored; otherwise
                        // compute it and store it in the parse context.
                        PrunedPartitionList prunedPartitionList = this.pGraphContext.getOpToPartList().get(tableScanOperator);
                        if (prunedPartitionList == null) {
                            prunedPartitionList = PartitionPruner.prune(table, this.pGraphContext.getOpToPartPruner().get(tableScanOperator), this.pGraphContext.getConf(), str4, this.pGraphContext.getPrunedPartitions());
                            this.pGraphContext.getOpToPartList().put(tableScanOperator, prunedPartitionList);
                        }
                        List<Partition> notDeniedPartns = prunedPartitionList.getNotDeniedPartns();
                        if (!notDeniedPartns.isEmpty()) {
                            // Per-partition bucket counts / file lists collected for a small table.
                            ArrayList arrayList2 = new ArrayList();
                            ArrayList arrayList3 = new ArrayList();
                            for (Partition partition : notDeniedPartns) {
                                if (!checkBucketColumns(partition.getBucketCols(), columns, numArr)) {
                                    return false;
                                }
                                List<String> onePartitionBucketFileNames = getOnePartitionBucketFileNames(partition.getDataLocation());
                                int bucketCount = partition.getBucketCount();
                                // The number of files in the partition must equal its declared bucket count.
                                if (onePartitionBucketFileNames.size() != bucketCount) {
                                    throw new SemanticException(ErrorMsg.BUCKETED_TABLE_METADATA_INCORRECT.getMsg("The number of buckets for table " + table.getTableName() + " partition " + partition.getName() + " is " + partition.getBucketCount() + ", whereas the number of files is " + onePartitionBucketFileNames.size()));
                                }
                                if (str4.equals(str)) {
                                    linkedHashMap3.put(partition, onePartitionBucketFileNames);
                                    linkedHashMap4.put(partition, Integer.valueOf(bucketCount));
                                } else {
                                    arrayList3.add(onePartitionBucketFileNames);
                                    arrayList2.add(Integer.valueOf(bucketCount));
                                }
                            }
                            if (!str4.equals(str)) {
                                linkedHashMap.put(str4, arrayList2);
                                linkedHashMap2.put(str4, arrayList3);
                            }
                        } else if (!str4.equals(str)) {
                            // Small table with no partitions left: register empty lists for the alias.
                            linkedHashMap.put(str4, Arrays.asList(new Integer[0]));
                            linkedHashMap2.put(str4, new ArrayList());
                        }
                    } catch (HiveException e) {
                        BucketMapJoinOptimizer.LOG.error(StringUtils.stringifyException(e));
                        throw new SemanticException(e.getMessage(), e);
                    }
                } else {
                    // Unpartitioned table: the table itself supplies bucket columns, location and count.
                    if (!checkBucketColumns(table.getBucketCols(), columns, numArr)) {
                        return false;
                    }
                    List<String> onePartitionBucketFileNames2 = getOnePartitionBucketFileNames(table.getDataLocation());
                    Integer num = new Integer(table.getNumBuckets());
                    if (onePartitionBucketFileNames2.size() != num.intValue()) {
                        throw new SemanticException(ErrorMsg.BUCKETED_TABLE_METADATA_INCORRECT.getMsg("The number of buckets for table " + table.getTableName() + " is " + table.getNumBuckets() + ", whereas the number of files is " + onePartitionBucketFileNames2.size()));
                    }
                    if (str4.equals(str)) {
                        // Big table without partitions is stored under a null key.
                        linkedHashMap3.put(null, onePartitionBucketFileNames2);
                        linkedHashMap4.put(null, Integer.valueOf(table.getNumBuckets()));
                        z = false;
                    } else {
                        linkedHashMap.put(str4, Arrays.asList(num));
                        linkedHashMap2.put(str4, Arrays.asList(onePartitionBucketFileNames2));
                    }
                }
            }
            // Every big-table bucket count must be compatible with every small-table bucket count.
            Iterator it = linkedHashMap4.values().iterator();
            while (it.hasNext()) {
                if (!checkBucketNumberAgainstBigTable(linkedHashMap, ((Integer) it.next()).intValue())) {
                    return false;
                }
            }
            // Same descriptor object as mapJoinDesc above, fetched again.
            MapJoinDesc mapJoinDesc2 = (MapJoinDesc) mapJoinOperator.getConf();
            // linkedHashMap5: small-table alias -> (big-table file name -> small-table file names).
            LinkedHashMap linkedHashMap5 = new LinkedHashMap();
            // Sort the big table's file names in place so list positions are deterministic.
            Iterator it2 = linkedHashMap3.values().iterator();
            while (it2.hasNext()) {
                Collections.sort((List) it2.next());
            }
            for (int i2 = 0; i2 < arrayList.size(); i2++) {
                String str5 = (String) arrayList.get(i2);
                if (!str5.equals(str)) {
                    // Sort each of this small table's file lists in place as well.
                    Iterator it3 = ((List) linkedHashMap2.get(str5)).iterator();
                    while (it3.hasNext()) {
                        Collections.sort((List) it3.next());
                    }
                    List<Integer> list = (List) linkedHashMap.get(str5);
                    List<List<String>> list2 = (List) linkedHashMap2.get(str5);
                    LinkedHashMap linkedHashMap6 = new LinkedHashMap();
                    linkedHashMap5.put(str5, linkedHashMap6);
                    // linkedHashMap3 and linkedHashMap4 were filled with the same keys in the same
                    // order, so the two iterators are advanced in lock step.
                    Iterator it4 = linkedHashMap3.entrySet().iterator();
                    Iterator it5 = linkedHashMap4.entrySet().iterator();
                    while (it4.hasNext()) {
                        if (!$assertionsDisabled && !it5.hasNext()) {
                            throw new AssertionError();
                        }
                        fillMapping(list, list2, linkedHashMap6, ((Integer) ((Map.Entry) it5.next()).getValue()).intValue(), (List) ((Map.Entry) it4.next()).getValue(), mapJoinDesc2.getBigTableBucketNumMapping());
                    }
                }
            }
            mapJoinDesc2.setAliasBucketFileNameMapping(linkedHashMap5);
            mapJoinDesc2.setBigTableAlias(str);
            // The partition-name -> files mapping is only set when the big table was not
            // found to be unpartitioned (convert() dereferences every key, so a null key is avoided).
            if (!z) {
                return true;
            }
            mapJoinDesc2.setBigTablePartSpecToFileMapping(convert(linkedHashMap3));
            return true;
        }

        /**
         * Runs the bucket map join conversion for the visited operator.
         *
         * <p>A failed conversion is only an error when
         * {@code HiveConf.ConfVars.HIVEENFORCEBUCKETMAPJOIN} is enabled.
         *
         * @return always {@code null}
         * @throws SemanticException if the conversion itself fails, or if it was not possible
         *     while enforcement is switched on
         */
        @Override // org.apache.hadoop.hive.ql.lib.NodeProcessor
        public Object process(Node node, Stack<Node> stack, NodeProcessorCtx nodeProcessorCtx, Object... objArr) throws SemanticException {
            boolean converted = convertBucketMapJoin(node, stack, nodeProcessorCtx, objArr);
            BucketMapjoinOptProcCtx optCtx = (BucketMapjoinOptProcCtx) nodeProcessorCtx;
            HiveConf hiveConf = optCtx.getConf();
            if (converted) {
                return null;
            }
            if (hiveConf.getBoolVar(HiveConf.ConfVars.HIVEENFORCEBUCKETMAPJOIN)) {
                throw new SemanticException(ErrorMsg.BUCKET_MAPJOIN_NOT_POSSIBLE.getMsg());
            }
            return null;
        }

        /**
         * Extracts the column name of every expression in the given list.
         *
         * @param list expressions to inspect
         * @return the column names in list order, or {@code null} as soon as an expression is
         *     not an {@link ExprNodeColumnDesc}
         */
        private List<String> toColumns(List<ExprNodeDesc> list) {
            List<String> columnNames = new ArrayList<String>();
            Iterator<ExprNodeDesc> expressions = list.iterator();
            while (expressions.hasNext()) {
                ExprNodeDesc expression = expressions.next();
                if (expression instanceof ExprNodeColumnDesc) {
                    ExprNodeColumnDesc columnExpr = (ExprNodeColumnDesc) expression;
                    columnNames.add(columnExpr.getColumn());
                } else {
                    // Anything other than a plain column reference disqualifies the whole key list.
                    return null;
                }
            }
            return columnNames;
        }

        /**
         * Re-keys a partition-to-files map by partition name.
         *
         * @param map bucket file names keyed by partition; every key is dereferenced, so a
         *     {@code null} key is not supported
         * @return a new map from {@code Partition.getName()} to the same file-name lists
         */
        private Map<String, List<String>> convert(Map<Partition, List<String>> map) {
            Map<String, List<String>> filesByPartitionName = new HashMap<String, List<String>>();
            Iterator<Map.Entry<Partition, List<String>>> entries = map.entrySet().iterator();
            while (entries.hasNext()) {
                Map.Entry<Partition, List<String>> current = entries.next();
                String partitionName = current.getKey().getName();
                filesByPartitionName.put(partitionName, current.getValue());
            }
            return filesByPartitionName;
        }

        /**
         * Maps every big-table bucket file to the small-table bucket files it is paired with.
         *
         * <p>For each big-table file (by position) and each small-table partition:
         * <ul>
         *   <li>if the big table has at least as many buckets, exactly one small file is chosen
         *       at index {@code bigIndex % smallBucketCount};</li>
         *   <li>otherwise every small file starting at {@code bigIndex} and stepping by
         *       {@code smallBucketCount / bigBucketCount} is chosen.</li>
         * </ul>
         *
         * @param list bucket count of each small-table partition
         * @param list2 bucket file names of each small-table partition, parallel to {@code list}
         * @param map output: big-table file name to the collected small-table file names
         * @param i bucket count of the big table (partition)
         * @param list3 big-table bucket file names
         * @param map2 output: big-table file name to its position in {@code list3}
         */
        private void fillMapping(List<Integer> list, List<List<String>> list2, Map<String, List<String>> map, int i, List<String> list3, Map<String, Integer> map2) {
            for (int bigIndex = 0; bigIndex < list3.size(); bigIndex++) {
                List<String> pairedSmallFiles = new ArrayList<String>();
                for (int partIndex = 0; partIndex < list.size(); partIndex++) {
                    int smallBucketCount = list.get(partIndex).intValue();
                    List<String> smallFiles = list2.get(partIndex);
                    if (i >= smallBucketCount) {
                        pairedSmallFiles.add(smallFiles.get(bigIndex % smallBucketCount));
                    } else {
                        // The stride loop must stop at the end of the file list; the previous
                        // form never left the loop once the index ran past the last file.
                        int stride = smallBucketCount / i;
                        for (int smallIndex = bigIndex; smallIndex < smallFiles.size(); smallIndex += stride) {
                            pairedSmallFiles.add(smallFiles.get(smallIndex));
                        }
                    }
                }
                String bigFileName = list3.get(bigIndex);
                map.put(bigFileName, pairedSmallFiles);
                map2.put(bigFileName, Integer.valueOf(bigIndex));
            }
        }

        /**
         * Verifies that every small-table bucket count and the big-table bucket count divide
         * one another (the larger must be a multiple of the smaller).
         *
         * @param map small-table alias to the bucket counts of its partitions
         * @param i bucket count of the big table (partition)
         * @return {@code false} at the first incompatible count, otherwise {@code true}
         */
        private boolean checkBucketNumberAgainstBigTable(Map<String, List<Integer>> map, int i) {
            for (List<Integer> bucketCounts : map.values()) {
                for (Integer boxedCount : bucketCounts) {
                    int smallCount = boxedCount.intValue();
                    if (smallCount >= i) {
                        if (smallCount % i != 0) {
                            return false;
                        }
                    } else if (i % smallCount != 0) {
                        return false;
                    }
                }
            }
            return true;
        }

        /**
         * Lists the entries directly under the given location.
         *
         * @param uri location of a table or partition
         * @return the string form of every listed path; empty when the listing is {@code null}
         * @throws SemanticException wrapping any {@link IOException} raised by the file system
         */
        private List<String> getOnePartitionBucketFileNames(URI uri) throws SemanticException {
            List<String> fileNames = new ArrayList<String>();
            FileStatus[] statuses;
            try {
                FileSystem fileSystem = FileSystem.get(uri, this.pGraphContext.getConf());
                Path location = new Path(uri.toString());
                statuses = fileSystem.listStatus(location);
            } catch (IOException e) {
                throw new SemanticException(e);
            }
            if (statuses == null) {
                return fileNames;
            }
            for (int idx = 0; idx < statuses.length; idx++) {
                fileNames.add(statuses[idx].getPath().toString());
            }
            return fileNames;
        }

        /**
         * Checks the join keys against a table's (or partition's) bucketing columns.
         *
         * <p>For each join key the method looks up its index in the bucketing columns
         * ({@code -1} when absent). If {@code numArr} already holds an index for that key
         * position, the two must agree; otherwise the index is recorded. Finally every
         * bucketing column must appear among the join keys.
         *
         * @param list bucketing columns
         * @param list2 join key columns
         * @param numArr per-key index memo shared across calls; updated in place, also for keys
         *     processed before a mismatch is found
         * @return {@code true} when the indexes agree with earlier calls and the join keys
         *     cover all bucketing columns
         */
        private boolean checkBucketColumns(List<String> list, List<String> list2, Integer[] numArr) {
            boolean inputsUsable = list2 != null && list != null && !list.isEmpty();
            if (!inputsUsable) {
                return false;
            }
            int keyIndex = 0;
            while (keyIndex < list2.size()) {
                int bucketPosition = list.indexOf(list2.get(keyIndex));
                Integer remembered = numArr[keyIndex];
                if (remembered != null && remembered.intValue() != bucketPosition) {
                    return false;
                }
                numArr[keyIndex] = Integer.valueOf(bucketPosition);
                keyIndex++;
            }
            return list2.containsAll(list);
        }

        // Decompiler rendering of the compiler-generated assertion switch: the flag is the
        // negation of the outer class's desired assertion status.
        static {
            $assertionsDisabled = !BucketMapJoinOptimizer.class.desiredAssertionStatus();
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:org/apache/hadoop/hive/ql/optimizer/BucketMapJoinOptimizer$BucketMapjoinOptProcCtx.class */
    /**
     * Walker context for this optimizer: carries the Hive configuration and the set of map
     * joins that the reject rules have flagged.
     */
    public class BucketMapjoinOptProcCtx implements NodeProcessorCtx {
        /** Map joins added by the reject processor; consulted before any conversion is tried. */
        Set<MapJoinOperator> listOfRejectedMapjoins;
        private final HiveConf conf;

        /**
         * @param hiveConf configuration returned by {@link #getConf()}
         */
        public BucketMapjoinOptProcCtx(HiveConf hiveConf) {
            this.listOfRejectedMapjoins = new HashSet<MapJoinOperator>();
            this.conf = hiveConf;
        }

        /** @return the configuration supplied at construction time */
        public HiveConf getConf() {
            return conf;
        }

        /** @return the live, mutable set of rejected map joins */
        public Set<MapJoinOperator> getListOfRejectedMapjoins() {
            return listOfRejectedMapjoins;
        }
    }

    /**
     * Walks the operator graph from every top operator and applies the bucket map join rules.
     *
     * <p>Rule R1 targets map join operators and runs the conversion processor; rules R2-R4
     * register the reject processor for map joins that appear after a reduce sink, a union or
     * another map join.
     *
     * @param parseContext the parse context to inspect and annotate
     * @return the same {@code parseContext} instance
     * @throws SemanticException propagated from the graph walk
     */
    @Override // org.apache.hadoop.hive.ql.optimizer.Transform
    public ParseContext transform(ParseContext parseContext) throws SemanticException {
        String mapJoinName = MapJoinOperator.getOperatorName();
        BucketMapjoinOptProcCtx optCtx = new BucketMapjoinOptProcCtx(parseContext.getConf());

        // Insertion order of the rules is preserved by the LinkedHashMap.
        LinkedHashMap rules = new LinkedHashMap();
        rules.put(new RuleRegExp("R1", mapJoinName + "%"), getBucketMapjoinProc(parseContext));
        // NOTE(review): unlike R3 and R4 this pattern has no trailing "%"; left unchanged,
        // confirm whether that is intentional.
        rules.put(new RuleRegExp("R2", ReduceSinkOperator.getOperatorName() + "%.*" + mapJoinName), getBucketMapjoinRejectProc(parseContext));
        rules.put(new RuleRegExp("R3", UnionOperator.getOperatorName() + "%.*" + mapJoinName + "%"), getBucketMapjoinRejectProc(parseContext));
        rules.put(new RuleRegExp("R4", mapJoinName + "%.*" + mapJoinName + "%"), getBucketMapjoinRejectProc(parseContext));

        DefaultRuleDispatcher dispatcher = new DefaultRuleDispatcher(getDefaultProc(), rules, optCtx);
        DefaultGraphWalker walker = new DefaultGraphWalker(dispatcher);

        // Start the walk from a copy of the top operators.
        ArrayList<Node> startNodes = new ArrayList<Node>(parseContext.getTopOps().values());
        walker.startWalking(startNodes, null);
        return parseContext;
    }

    /**
     * Builds a processor that marks the visited map join as rejected.
     *
     * @param parseContext not used by the returned processor
     * @return a processor that adds the visited node to the context's rejected set
     */
    private NodeProcessor getBucketMapjoinRejectProc(ParseContext parseContext) {
        NodeProcessor rejectProcessor = new NodeProcessor() {
            @Override // org.apache.hadoop.hive.ql.lib.NodeProcessor
            public Object process(Node node, Stack<Node> stack, NodeProcessorCtx nodeProcessorCtx, Object... objArr) throws SemanticException {
                BucketMapjoinOptProcCtx optCtx = (BucketMapjoinOptProcCtx) nodeProcessorCtx;
                Set<MapJoinOperator> rejected = optCtx.listOfRejectedMapjoins;
                rejected.add((MapJoinOperator) node);
                return null;
            }
        };
        return rejectProcessor;
    }

    /**
     * Builds the processor that attempts the bucket map join conversion.
     *
     * @param parseContext context handed to the new {@link BucketMapjoinOptProc}
     * @return a fresh conversion processor
     */
    private NodeProcessor getBucketMapjoinProc(ParseContext parseContext) {
        NodeProcessor conversionProcessor = new BucketMapjoinOptProc(parseContext);
        return conversionProcessor;
    }

    /**
     * Builds the fallback processor handed to the rule dispatcher.
     *
     * @return a processor that does nothing and returns {@code null}
     */
    private NodeProcessor getDefaultProc() {
        NodeProcessor noOpProcessor = new NodeProcessor() {
            @Override // org.apache.hadoop.hive.ql.lib.NodeProcessor
            public Object process(Node node, Stack<Node> stack, NodeProcessorCtx nodeProcessorCtx, Object... objArr) throws SemanticException {
                // Intentionally empty: nodes without a matching rule are ignored.
                return null;
            }
        };
        return noOpProcessor;
    }
}
