package org.apache.hadoop.hive.ql.optimizer.calcite.cost;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Sets;
import java.util.ArrayList;
import java.util.Collection;
import org.apache.calcite.plan.RelOptCost;
import org.apache.calcite.rel.RelCollation;
import org.apache.calcite.rel.RelDistribution;
import org.apache.calcite.rel.RelDistributions;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.ImmutableIntList;
import org.apache.calcite.util.Pair;
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.hadoop.hive.ql.optimizer.calcite.CalciteSemanticException;
import org.apache.hadoop.hive.ql.optimizer.calcite.HiveCalciteUtil;
import org.apache.hadoop.hive.ql.optimizer.calcite.cost.HiveCostModel;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAggregate;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveJoin;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveTableScan;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:WEB-INF/lib/hive-exec-2.3.6-mapr-2201-core.jar:org/apache/hadoop/hive/ql/optimizer/calcite/cost/HiveOnTezCostModel.class */
public class HiveOnTezCostModel extends HiveCostModel {
    /** Lazily created singleton returned by {@link #getCostModel(HiveConf)}. */
    private static HiveOnTezCostModel INSTANCE;
    /**
     * Cost helper built by the constructor from the {@link HiveConf} passed to the first
     * {@link #getCostModel(HiveConf)} call; read by the nested join algorithms and the scan and
     * aggregate cost methods.
     */
    private static HiveAlgorithmsUtil algoUtils;
    // Static fields are never serialized, so the former `transient` modifier had no effect.
    private static final Logger LOG = LoggerFactory.getLogger(HiveOnTezCostModel.class);

    /**
     * Join algorithm reported as {@code "BucketJoin"}.
     *
     * <p>It is executable only when a streaming side has been chosen, the input returned by
     * {@link HiveJoin#getStreamingInput()} fits into memory across its split count, and every input
     * is hash-distributed on keys that include its join keys. It is not a phase transition.
     */
    public static class TezBucketJoinAlgorithm implements HiveCostModel.JoinAlgorithm {
        public static final HiveCostModel.JoinAlgorithm INSTANCE = new TezBucketJoinAlgorithm();
        private static final String ALGORITHM_NAME = "BucketJoin";

        @Override
        public String toString() {
            return ALGORITHM_NAME;
        }

        /**
         * Decides whether {@code hiveJoin} can run with this algorithm.
         *
         * @return {@code false} when there is no streaming input; when that input's split count
         *     (evaluated while the join is temporarily set to this algorithm) is unknown; when the
         *     input does not fit into memory across that many splits; or when any input is not
         *     hash-distributed on keys containing all of its join keys. Otherwise {@code true}.
         */
        @Override
        public boolean isExecutable(HiveJoin hiveJoin) {
            Double maxMemory = hiveJoin.getCluster().getPlanner().getContext()
                .unwrap(HiveAlgorithmsConf.class).getMaxMemory();
            RelNode streamingInput = hiveJoin.getStreamingInput();
            if (streamingInput == null) {
                return false;
            }
            // Join key positions of the left (index 0) and right (index 1) inputs.
            HiveCalciteUtil.JoinPredicateInfo joinPredicateInfo = hiveJoin.getJoinPredicateInfo();
            ImmutableList<ImmutableIntList> joinKeysInChildren = ImmutableList.of(
                ImmutableIntList.copyOf(joinPredicateInfo.getProjsFromLeftPartOfJoinKeysInChildSchema()),
                ImmutableIntList.copyOf(joinPredicateInfo.getProjsFromRightPartOfJoinKeysInChildSchema()));

            RelMetadataQuery mq = RelMetadataQuery.instance();
            // NOTE(review): the split count is used here as a stand-in for the bucket count.
            Integer splitCount = splitCountAsBucketJoin(mq, hiveJoin, streamingInput);
            if (splitCount == null
                || !HiveAlgorithmsUtil.isFittingIntoMemory(maxMemory, streamingInput, splitCount)) {
                return false;
            }
            for (int i = 0; i < hiveJoin.getInputs().size(); i++) {
                RelDistribution distribution = mq.distribution(hiveJoin.getInputs().get(i));
                if (distribution.getType() != RelDistribution.Type.HASH_DISTRIBUTED
                    || !distribution.getKeys().containsAll(joinKeysInChildren.get(i))) {
                    return false;
                }
            }
            return true;
        }

        /**
         * Computes the bucket-join cost.
         *
         * <p>The components are:
         * <ul>
         *   <li>row count: the sum of both input row counts;</li>
         *   <li>CPU: {@code computeBucketMapJoinCPUCost};</li>
         *   <li>IO: {@code computeBucketMapJoinIOCost}, using the join's split count (defaulting to
         *       1 when unknown) as parallelism.</li>
         * </ul>
         *
         * @return the cost, or {@code null} if a row count or average row size is unknown or no
         *     streaming side has been chosen
         */
        @Override
        public RelOptCost getCost(HiveJoin hiveJoin) {
            RelMetadataQuery mq = RelMetadataQuery.instance();
            Double leftRowCount = mq.getRowCount(hiveJoin.getLeft());
            Double rightRowCount = mq.getRowCount(hiveJoin.getRight());
            if (leftRowCount == null || rightRowCount == null) {
                return null;
            }
            double rowCount = leftRowCount + rightRowCount;

            ImmutableList<Double> cardinalities = ImmutableList.of(leftRowCount, rightRowCount);
            ImmutableBitSet streaming = streamingSideBits(hiveJoin);
            if (streaming == null) {
                return null;
            }
            double cpuCost = algoUtils.computeBucketMapJoinCPUCost(cardinalities, streaming);

            Double leftAvgRowSize = mq.getAverageRowSize(hiveJoin.getLeft());
            Double rightAvgRowSize = mq.getAverageRowSize(hiveJoin.getRight());
            if (leftAvgRowSize == null || rightAvgRowSize == null) {
                return null;
            }
            ImmutableList<Pair<Double, Double>> relationInfos = ImmutableList.of(
                Pair.of(leftRowCount, leftAvgRowSize), Pair.of(rightRowCount, rightAvgRowSize));
            // Query the split count once; fall back to a parallelism of 1 when it is unknown.
            Integer splitCount = splitCountAsBucketJoin(mq, hiveJoin, hiveJoin);
            int parallelism = splitCount == null ? 1 : splitCount;
            double ioCost = algoUtils.computeBucketMapJoinIOCost(relationInfos, streaming, parallelism);
            return HiveCost.FACTORY.makeCost(rowCount, cpuCost, ioCost);
        }

        @Override
        public ImmutableList<RelCollation> getCollation(HiveJoin hiveJoin) {
            HiveJoin.MapJoinStreamingRelation streamingSide = hiveJoin.getStreamingSide();
            if (streamingSide == HiveJoin.MapJoinStreamingRelation.LEFT_RELATION
                || streamingSide == HiveJoin.MapJoinStreamingRelation.RIGHT_RELATION) {
                return HiveAlgorithmsUtil.getJoinCollation(hiveJoin.getJoinPredicateInfo(), streamingSide);
            }
            LOG.warn("Streaming side for map join not chosen");
            return ImmutableList.of();
        }

        @Override
        public RelDistribution getDistribution(HiveJoin hiveJoin) {
            return HiveAlgorithmsUtil.getJoinRedistribution(hiveJoin.getJoinPredicateInfo());
        }

        @Override
        public Double getMemory(HiveJoin hiveJoin) {
            return HiveAlgorithmsUtil.getJoinMemory(hiveJoin);
        }

        /**
         * Returns the cumulative in-phase memory of the non-streaming input divided by that
         * input's split count.
         *
         * @return the per-split memory, or {@code null} if no streaming side has been chosen or
         *     either metadata value is unknown
         */
        @Override
        public Double getCumulativeMemoryWithinPhaseSplit(HiveJoin hiveJoin) {
            HiveJoin.MapJoinStreamingRelation streamingSide = hiveJoin.getStreamingSide();
            RelNode nonStreamingInput;
            if (streamingSide == HiveJoin.MapJoinStreamingRelation.LEFT_RELATION) {
                nonStreamingInput = hiveJoin.getRight();
            } else if (streamingSide == HiveJoin.MapJoinStreamingRelation.RIGHT_RELATION) {
                nonStreamingInput = hiveJoin.getLeft();
            } else {
                return null;
            }
            RelMetadataQuery mq = RelMetadataQuery.instance();
            Double memory = mq.cumulativeMemoryWithinPhase(nonStreamingInput);
            Integer splitCount = mq.splitCount(nonStreamingInput);
            if (memory == null || splitCount == null) {
                return null;
            }
            return memory / splitCount;
        }

        @Override
        public Boolean isPhaseTransition(HiveJoin hiveJoin) {
            return false;
        }

        @Override
        public Integer getSplitCount(HiveJoin hiveJoin) {
            return HiveAlgorithmsUtil.getSplitCountWithoutRepartition(hiveJoin);
        }

        /**
         * Evaluates {@code mq.splitCount(rel)} while {@code join} is temporarily set to this
         * algorithm. The previous algorithm is restored even if the metadata call throws.
         */
        private static Integer splitCountAsBucketJoin(RelMetadataQuery mq, HiveJoin join, RelNode rel) {
            HiveCostModel.JoinAlgorithm previous = join.getJoinAlgorithm();
            join.setJoinAlgorithm(TezBucketJoinAlgorithm.INSTANCE);
            try {
                return mq.splitCount(rel);
            } finally {
                join.setJoinAlgorithm(previous);
            }
        }

        /**
         * Returns a bit set holding the index of the streaming input (0 for left, 1 for right), or
         * {@code null} when neither side has been chosen.
         */
        private static ImmutableBitSet streamingSideBits(HiveJoin join) {
            ImmutableBitSet.Builder builder = ImmutableBitSet.builder();
            switch (join.getStreamingSide()) {
                case LEFT_RELATION:
                    builder.set(0);
                    break;
                case RIGHT_RELATION:
                    builder.set(1);
                    break;
                default:
                    return null;
            }
            return builder.build();
        }
    }

    /* loaded from: input_file:WEB-INF/lib/hive-exec-2.3.6-mapr-2201-core.jar:org/apache/hadoop/hive/ql/optimizer/calcite/cost/HiveOnTezCostModel$TezCommonJoinAlgorithm.class */
    public static class TezCommonJoinAlgorithm implements HiveCostModel.JoinAlgorithm {
        public static final HiveCostModel.JoinAlgorithm INSTANCE = new TezCommonJoinAlgorithm();
        private static final String ALGORITHM_NAME = "CommonJoin";

        @Override // org.apache.hadoop.hive.ql.optimizer.calcite.cost.HiveCostModel.JoinAlgorithm
        public String toString() {
            return ALGORITHM_NAME;
        }

        @Override // org.apache.hadoop.hive.ql.optimizer.calcite.cost.HiveCostModel.JoinAlgorithm
        public boolean isExecutable(HiveJoin hiveJoin) {
            return true;
        }

        @Override // org.apache.hadoop.hive.ql.optimizer.calcite.cost.HiveCostModel.JoinAlgorithm
        public RelOptCost getCost(HiveJoin hiveJoin) {
            RelMetadataQuery instance = RelMetadataQuery.instance();
            Double rowCount = instance.getRowCount(hiveJoin.getLeft());
            Double rowCount2 = instance.getRowCount(hiveJoin.getRight());
            if (rowCount == null || rowCount2 == null) {
                return null;
            }
            double doubleValue = rowCount.doubleValue() + rowCount2.doubleValue();
            try {
                double computeSortMergeCPUCost = HiveOnTezCostModel.algoUtils.computeSortMergeCPUCost(new ImmutableList.Builder().add((ImmutableList.Builder) rowCount).add((ImmutableList.Builder) rowCount2).build(), hiveJoin.getSortedInputs());
                Double averageRowSize = instance.getAverageRowSize(hiveJoin.getLeft());
                Double averageRowSize2 = instance.getAverageRowSize(hiveJoin.getRight());
                if (averageRowSize == null || averageRowSize2 == null) {
                    return null;
                }
                return HiveCost.FACTORY.makeCost(doubleValue, computeSortMergeCPUCost, HiveOnTezCostModel.algoUtils.computeSortMergeIOCost(new ImmutableList.Builder().add((ImmutableList.Builder) new Pair(rowCount, averageRowSize)).add((ImmutableList.Builder) new Pair(rowCount2, averageRowSize2)).build()));
            } catch (CalciteSemanticException e) {
                HiveOnTezCostModel.LOG.trace("Failed to compute sort merge cpu cost ", (Throwable) e);
                return null;
            }
        }

        @Override // org.apache.hadoop.hive.ql.optimizer.calcite.cost.HiveCostModel.JoinAlgorithm
        public ImmutableList<RelCollation> getCollation(HiveJoin hiveJoin) {
            return HiveAlgorithmsUtil.getJoinCollation(hiveJoin.getJoinPredicateInfo(), HiveJoin.MapJoinStreamingRelation.NONE);
        }

        @Override // org.apache.hadoop.hive.ql.optimizer.calcite.cost.HiveCostModel.JoinAlgorithm
        public RelDistribution getDistribution(HiveJoin hiveJoin) {
            return HiveAlgorithmsUtil.getJoinRedistribution(hiveJoin.getJoinPredicateInfo());
        }

        @Override // org.apache.hadoop.hive.ql.optimizer.calcite.cost.HiveCostModel.JoinAlgorithm
        public Double getMemory(HiveJoin hiveJoin) {
            return HiveAlgorithmsUtil.getJoinMemory(hiveJoin, HiveJoin.MapJoinStreamingRelation.NONE);
        }

        @Override // org.apache.hadoop.hive.ql.optimizer.calcite.cost.HiveCostModel.JoinAlgorithm
        public Double getCumulativeMemoryWithinPhaseSplit(HiveJoin hiveJoin) {
            HiveCostModel.JoinAlgorithm joinAlgorithm = hiveJoin.getJoinAlgorithm();
            hiveJoin.setJoinAlgorithm(INSTANCE);
            Double cumulativeMemoryWithinPhase = RelMetadataQuery.instance().cumulativeMemoryWithinPhase(hiveJoin);
            Integer splitCount = RelMetadataQuery.instance().splitCount(hiveJoin);
            hiveJoin.setJoinAlgorithm(joinAlgorithm);
            if (cumulativeMemoryWithinPhase == null || splitCount == null) {
                return null;
            }
            return Double.valueOf(cumulativeMemoryWithinPhase.doubleValue() / splitCount.intValue());
        }

        @Override // org.apache.hadoop.hive.ql.optimizer.calcite.cost.HiveCostModel.JoinAlgorithm
        public Boolean isPhaseTransition(HiveJoin hiveJoin) {
            return true;
        }

        @Override // org.apache.hadoop.hive.ql.optimizer.calcite.cost.HiveCostModel.JoinAlgorithm
        public Integer getSplitCount(HiveJoin hiveJoin) {
            return HiveAlgorithmsUtil.getSplitCountWithRepartition(hiveJoin);
        }
    }

    /* loaded from: input_file:WEB-INF/lib/hive-exec-2.3.6-mapr-2201-core.jar:org/apache/hadoop/hive/ql/optimizer/calcite/cost/HiveOnTezCostModel$TezMapJoinAlgorithm.class */
    public static class TezMapJoinAlgorithm implements HiveCostModel.JoinAlgorithm {
        public static final HiveCostModel.JoinAlgorithm INSTANCE = new TezMapJoinAlgorithm();
        private static final String ALGORITHM_NAME = "MapJoin";

        @Override // org.apache.hadoop.hive.ql.optimizer.calcite.cost.HiveCostModel.JoinAlgorithm
        public String toString() {
            return ALGORITHM_NAME;
        }

        @Override // org.apache.hadoop.hive.ql.optimizer.calcite.cost.HiveCostModel.JoinAlgorithm
        public boolean isExecutable(HiveJoin hiveJoin) {
            Double maxMemory = ((HiveAlgorithmsConf) hiveJoin.getCluster().getPlanner().getContext().unwrap(HiveAlgorithmsConf.class)).getMaxMemory();
            RelNode streamingInput = hiveJoin.getStreamingInput();
            if (streamingInput == null) {
                return false;
            }
            return HiveAlgorithmsUtil.isFittingIntoMemory(maxMemory, streamingInput, 1);
        }

        @Override // org.apache.hadoop.hive.ql.optimizer.calcite.cost.HiveCostModel.JoinAlgorithm
        public RelOptCost getCost(HiveJoin hiveJoin) {
            RelMetadataQuery instance = RelMetadataQuery.instance();
            Double rowCount = instance.getRowCount(hiveJoin.getLeft());
            Double rowCount2 = instance.getRowCount(hiveJoin.getRight());
            if (rowCount == null || rowCount2 == null) {
                return null;
            }
            double doubleValue = rowCount.doubleValue() + rowCount2.doubleValue();
            ImmutableList build = new ImmutableList.Builder().add((ImmutableList.Builder) rowCount).add((ImmutableList.Builder) rowCount2).build();
            ImmutableBitSet.Builder builder = ImmutableBitSet.builder();
            switch (hiveJoin.getStreamingSide()) {
                case LEFT_RELATION:
                    builder.set(0);
                    break;
                case RIGHT_RELATION:
                    builder.set(1);
                    break;
                default:
                    return null;
            }
            ImmutableBitSet build2 = builder.build();
            double computeMapJoinCPUCost = HiveAlgorithmsUtil.computeMapJoinCPUCost(build, build2);
            Double averageRowSize = instance.getAverageRowSize(hiveJoin.getLeft());
            Double averageRowSize2 = instance.getAverageRowSize(hiveJoin.getRight());
            if (averageRowSize == null || averageRowSize2 == null) {
                return null;
            }
            ImmutableList<Pair<Double, Double>> build3 = new ImmutableList.Builder().add((ImmutableList.Builder) new Pair(rowCount, averageRowSize)).add((ImmutableList.Builder) new Pair(rowCount2, averageRowSize2)).build();
            HiveCostModel.JoinAlgorithm joinAlgorithm = hiveJoin.getJoinAlgorithm();
            hiveJoin.setJoinAlgorithm(INSTANCE);
            int intValue = instance.splitCount(hiveJoin) == null ? 1 : instance.splitCount(hiveJoin).intValue();
            hiveJoin.setJoinAlgorithm(joinAlgorithm);
            return HiveCost.FACTORY.makeCost(doubleValue, computeMapJoinCPUCost, HiveOnTezCostModel.algoUtils.computeMapJoinIOCost(build3, build2, intValue));
        }

        @Override // org.apache.hadoop.hive.ql.optimizer.calcite.cost.HiveCostModel.JoinAlgorithm
        public ImmutableList<RelCollation> getCollation(HiveJoin hiveJoin) {
            HiveJoin.MapJoinStreamingRelation streamingSide = hiveJoin.getStreamingSide();
            if (streamingSide == HiveJoin.MapJoinStreamingRelation.LEFT_RELATION || streamingSide == HiveJoin.MapJoinStreamingRelation.RIGHT_RELATION) {
                return HiveAlgorithmsUtil.getJoinCollation(hiveJoin.getJoinPredicateInfo(), hiveJoin.getStreamingSide());
            }
            HiveOnTezCostModel.LOG.warn("Streaming side for map join not chosen");
            return ImmutableList.of();
        }

        @Override // org.apache.hadoop.hive.ql.optimizer.calcite.cost.HiveCostModel.JoinAlgorithm
        public RelDistribution getDistribution(HiveJoin hiveJoin) {
            HiveJoin.MapJoinStreamingRelation streamingSide = hiveJoin.getStreamingSide();
            if (streamingSide == HiveJoin.MapJoinStreamingRelation.LEFT_RELATION || streamingSide == HiveJoin.MapJoinStreamingRelation.RIGHT_RELATION) {
                return HiveAlgorithmsUtil.getJoinDistribution(hiveJoin.getJoinPredicateInfo(), hiveJoin.getStreamingSide());
            }
            HiveOnTezCostModel.LOG.warn("Streaming side for map join not chosen");
            return RelDistributions.SINGLETON;
        }

        @Override // org.apache.hadoop.hive.ql.optimizer.calcite.cost.HiveCostModel.JoinAlgorithm
        public Double getMemory(HiveJoin hiveJoin) {
            return HiveAlgorithmsUtil.getJoinMemory(hiveJoin);
        }

        @Override // org.apache.hadoop.hive.ql.optimizer.calcite.cost.HiveCostModel.JoinAlgorithm
        public Double getCumulativeMemoryWithinPhaseSplit(HiveJoin hiveJoin) {
            RelNode left;
            if (hiveJoin.getStreamingSide() == HiveJoin.MapJoinStreamingRelation.LEFT_RELATION) {
                left = hiveJoin.getRight();
            } else {
                if (hiveJoin.getStreamingSide() != HiveJoin.MapJoinStreamingRelation.RIGHT_RELATION) {
                    return null;
                }
                left = hiveJoin.getLeft();
            }
            return RelMetadataQuery.instance().cumulativeMemoryWithinPhase(left);
        }

        @Override // org.apache.hadoop.hive.ql.optimizer.calcite.cost.HiveCostModel.JoinAlgorithm
        public Boolean isPhaseTransition(HiveJoin hiveJoin) {
            return false;
        }

        @Override // org.apache.hadoop.hive.ql.optimizer.calcite.cost.HiveCostModel.JoinAlgorithm
        public Integer getSplitCount(HiveJoin hiveJoin) {
            return HiveAlgorithmsUtil.getSplitCountWithoutRepartition(hiveJoin);
        }
    }

    /* loaded from: input_file:WEB-INF/lib/hive-exec-2.3.6-mapr-2201-core.jar:org/apache/hadoop/hive/ql/optimizer/calcite/cost/HiveOnTezCostModel$TezSMBJoinAlgorithm.class */
    public static class TezSMBJoinAlgorithm implements HiveCostModel.JoinAlgorithm {
        public static final HiveCostModel.JoinAlgorithm INSTANCE = new TezSMBJoinAlgorithm();
        private static final String ALGORITHM_NAME = "SMBJoin";

        @Override // org.apache.hadoop.hive.ql.optimizer.calcite.cost.HiveCostModel.JoinAlgorithm
        public String toString() {
            return ALGORITHM_NAME;
        }

        @Override // org.apache.hadoop.hive.ql.optimizer.calcite.cost.HiveCostModel.JoinAlgorithm
        public boolean isExecutable(HiveJoin hiveJoin) {
            HiveCalciteUtil.JoinPredicateInfo joinPredicateInfo = hiveJoin.getJoinPredicateInfo();
            ArrayList arrayList = new ArrayList();
            arrayList.add(ImmutableIntList.copyOf((Iterable<? extends Number>) joinPredicateInfo.getProjsFromLeftPartOfJoinKeysInChildSchema()));
            arrayList.add(ImmutableIntList.copyOf((Iterable<? extends Number>) joinPredicateInfo.getProjsFromRightPartOfJoinKeysInChildSchema()));
            for (int i = 0; i < hiveJoin.getInputs().size(); i++) {
                RelNode relNode = hiveJoin.getInputs().get(i);
                try {
                    if (!hiveJoin.getSortedInputs().get(i)) {
                        return false;
                    }
                    RelDistribution distribution = RelMetadataQuery.instance().distribution(relNode);
                    if (distribution.getType() != RelDistribution.Type.HASH_DISTRIBUTED || !distribution.getKeys().containsAll((Collection) arrayList.get(i))) {
                        return false;
                    }
                } catch (CalciteSemanticException e) {
                    HiveOnTezCostModel.LOG.trace("Not possible to do SMB Join ", (Throwable) e);
                    return false;
                }
            }
            return true;
        }

        @Override // org.apache.hadoop.hive.ql.optimizer.calcite.cost.HiveCostModel.JoinAlgorithm
        public RelOptCost getCost(HiveJoin hiveJoin) {
            RelMetadataQuery instance = RelMetadataQuery.instance();
            Double rowCount = instance.getRowCount(hiveJoin.getLeft());
            Double rowCount2 = instance.getRowCount(hiveJoin.getRight());
            if (rowCount == null || rowCount2 == null) {
                return null;
            }
            double doubleValue = rowCount.doubleValue() + rowCount2.doubleValue();
            ImmutableList build = new ImmutableList.Builder().add((ImmutableList.Builder) rowCount).add((ImmutableList.Builder) rowCount2).build();
            ImmutableBitSet.Builder builder = ImmutableBitSet.builder();
            switch (hiveJoin.getStreamingSide()) {
                case LEFT_RELATION:
                    builder.set(0);
                    break;
                case RIGHT_RELATION:
                    builder.set(1);
                    break;
                default:
                    return null;
            }
            ImmutableBitSet build2 = builder.build();
            double computeSMBMapJoinCPUCost = HiveAlgorithmsUtil.computeSMBMapJoinCPUCost(build);
            Double averageRowSize = instance.getAverageRowSize(hiveJoin.getLeft());
            Double averageRowSize2 = instance.getAverageRowSize(hiveJoin.getRight());
            if (averageRowSize == null || averageRowSize2 == null) {
                return null;
            }
            ImmutableList<Pair<Double, Double>> build3 = new ImmutableList.Builder().add((ImmutableList.Builder) new Pair(rowCount, averageRowSize)).add((ImmutableList.Builder) new Pair(rowCount2, averageRowSize2)).build();
            HiveCostModel.JoinAlgorithm joinAlgorithm = hiveJoin.getJoinAlgorithm();
            hiveJoin.setJoinAlgorithm(INSTANCE);
            int intValue = instance.splitCount(hiveJoin) == null ? 1 : instance.splitCount(hiveJoin).intValue();
            hiveJoin.setJoinAlgorithm(joinAlgorithm);
            return HiveCost.FACTORY.makeCost(doubleValue, computeSMBMapJoinCPUCost, HiveOnTezCostModel.algoUtils.computeSMBMapJoinIOCost(build3, build2, intValue));
        }

        @Override // org.apache.hadoop.hive.ql.optimizer.calcite.cost.HiveCostModel.JoinAlgorithm
        public ImmutableList<RelCollation> getCollation(HiveJoin hiveJoin) {
            return HiveAlgorithmsUtil.getJoinCollation(hiveJoin.getJoinPredicateInfo(), HiveJoin.MapJoinStreamingRelation.NONE);
        }

        @Override // org.apache.hadoop.hive.ql.optimizer.calcite.cost.HiveCostModel.JoinAlgorithm
        public RelDistribution getDistribution(HiveJoin hiveJoin) {
            return HiveAlgorithmsUtil.getJoinRedistribution(hiveJoin.getJoinPredicateInfo());
        }

        @Override // org.apache.hadoop.hive.ql.optimizer.calcite.cost.HiveCostModel.JoinAlgorithm
        public Double getMemory(HiveJoin hiveJoin) {
            return Double.valueOf(0.0d);
        }

        @Override // org.apache.hadoop.hive.ql.optimizer.calcite.cost.HiveCostModel.JoinAlgorithm
        public Double getCumulativeMemoryWithinPhaseSplit(HiveJoin hiveJoin) {
            RelMetadataQuery instance = RelMetadataQuery.instance();
            HiveCostModel.JoinAlgorithm joinAlgorithm = hiveJoin.getJoinAlgorithm();
            hiveJoin.setJoinAlgorithm(INSTANCE);
            Double cumulativeMemoryWithinPhase = instance.cumulativeMemoryWithinPhase(hiveJoin);
            Integer splitCount = instance.splitCount(hiveJoin);
            hiveJoin.setJoinAlgorithm(joinAlgorithm);
            if (cumulativeMemoryWithinPhase == null || splitCount == null) {
                return null;
            }
            return Double.valueOf(cumulativeMemoryWithinPhase.doubleValue() / splitCount.intValue());
        }

        @Override // org.apache.hadoop.hive.ql.optimizer.calcite.cost.HiveCostModel.JoinAlgorithm
        public Boolean isPhaseTransition(HiveJoin hiveJoin) {
            return false;
        }

        @Override // org.apache.hadoop.hive.ql.optimizer.calcite.cost.HiveCostModel.JoinAlgorithm
        public Integer getSplitCount(HiveJoin hiveJoin) {
            return HiveAlgorithmsUtil.getSplitCountWithoutRepartition(hiveJoin);
        }
    }

    /**
     * Returns the shared Tez cost model, building it on the first call.
     *
     * <p>Only the configuration passed to the first call is used. Later calls return the cached
     * instance and ignore their argument.
     *
     * @param hiveConf configuration used to build the model if it does not exist yet
     * @return the singleton cost model
     */
    public static synchronized HiveOnTezCostModel getCostModel(HiveConf hiveConf) {
        if (INSTANCE != null) {
            return INSTANCE;
        }
        INSTANCE = new HiveOnTezCostModel(hiveConf);
        return INSTANCE;
    }

    /**
     * Registers the four Tez join algorithms with the base cost model.
     *
     * <p>It also (re)builds the shared static {@code algoUtils} helper from {@code hiveConf}.
     */
    private HiveOnTezCostModel(HiveConf hiveConf) {
        super(Sets.newHashSet(
            TezCommonJoinAlgorithm.INSTANCE,
            TezMapJoinAlgorithm.INSTANCE,
            TezBucketJoinAlgorithm.INSTANCE,
            TezSMBJoinAlgorithm.INSTANCE));
        HiveOnTezCostModel.algoUtils = new HiveAlgorithmsUtil(hiveConf);
    }

    /** Returns the zero cost produced by {@code HiveCost.FACTORY}. */
    @Override
    public RelOptCost getDefaultCost() {
        final RelOptCost zeroCost = HiveCost.FACTORY.makeZeroCost();
        return zeroCost;
    }

    /**
     * Computes the cost of scanning {@code hiveTableScan} from its row count and average row size.
     *
     * @return the scan cost, or {@code null} when the average row size is unknown. Previously an
     *     unknown size caused a NullPointerException; {@code null} matches how the other cost
     *     methods in this class report missing metadata.
     */
    @Override
    public RelOptCost getScanCost(HiveTableScan hiveTableScan) {
        Double averageRowSize = RelMetadataQuery.instance().getAverageRowSize(hiveTableScan);
        if (averageRowSize == null) {
            return null;
        }
        return algoUtils.computeScanCost(hiveTableScan.getRows(), averageRowSize.doubleValue());
    }

    /**
     * Costs an aggregate as a sort of its input.
     *
     * <p>Aggregates over bucketed input are costed as zero.
     *
     * @return the cost, or {@code null} if the input row count or average row size is unknown
     */
    @Override
    public RelOptCost getAggregateCost(HiveAggregate hiveAggregate) {
        if (hiveAggregate.isBucketedInput()) {
            return HiveCost.FACTORY.makeZeroCost();
        }
        RelMetadataQuery mq = RelMetadataQuery.instance();
        RelNode input = hiveAggregate.getInput();

        Double inputRows = mq.getRowCount(input);
        if (inputRows == null) {
            return null;
        }
        double cpu = algoUtils.computeSortCPUCost(inputRows);

        Double inputAvgRowSize = mq.getAverageRowSize(input);
        if (inputAvgRowSize == null) {
            return null;
        }
        double io = algoUtils.computeSortIOCost(new Pair<>(inputRows, inputAvgRowSize));
        return HiveCost.FACTORY.makeCost(inputRows.doubleValue(), cpu, io);
    }
}
