/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.optimizer.calcite.cost;

import hive.com.google.common.collect.ImmutableCollection;
import hive.com.google.common.collect.ImmutableList;
import hive.com.google.common.collect.Sets;
import hive.org.apache.calcite.plan.RelOptCost;
import hive.org.apache.calcite.rel.RelCollation;
import hive.org.apache.calcite.rel.RelDistribution;
import hive.org.apache.calcite.rel.RelNode;
import hive.org.apache.calcite.rel.metadata.RelMetadataQuery;
import hive.org.apache.calcite.util.ImmutableBitSet;
import hive.org.apache.calcite.util.ImmutableIntList;
import hive.org.apache.calcite.util.Pair;
import java.util.ArrayList;
import java.util.Collection;
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.hadoop.hive.ql.optimizer.calcite.HiveCalciteUtil;
import org.apache.hadoop.hive.ql.optimizer.calcite.cost.HiveAlgorithmsConf;
import org.apache.hadoop.hive.ql.optimizer.calcite.cost.HiveAlgorithmsUtil;
import org.apache.hadoop.hive.ql.optimizer.calcite.cost.HiveCost;
import org.apache.hadoop.hive.ql.optimizer.calcite.cost.HiveCostModel;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAggregate;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveJoin;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveTableScan;

/**
 * Cost model for Hive on Tez.
 *
 * <p>Registers the Tez join algorithms ({@link TezCommonJoinAlgorithm}, {@link TezMapJoinAlgorithm},
 * {@link TezBucketJoinAlgorithm}, {@link TezSMBJoinAlgorithm}) with {@link HiveCostModel} and costs
 * table scans and aggregates through {@link HiveAlgorithmsUtil}.
 */
public class HiveOnTezCostModel extends HiveCostModel {
    private static HiveOnTezCostModel INSTANCE;
    // Shared by the static nested join algorithms; assigned when the singleton is constructed.
    private static HiveAlgorithmsUtil algoUtils;

    /**
     * Returns the shared cost model, creating it on the first call.
     *
     * <p>NOTE(review): {@code conf} is only read when the instance is first created; later calls
     * return the cached instance regardless of the configuration passed in.
     *
     * @param conf configuration used to build the {@link HiveAlgorithmsUtil} on first use
     * @return the singleton cost model
     */
    public static synchronized HiveOnTezCostModel getCostModel(HiveConf conf) {
        if (INSTANCE == null) {
            INSTANCE = new HiveOnTezCostModel(conf);
        }
        return INSTANCE;
    }

    private HiveOnTezCostModel(HiveConf conf) {
        super(Sets.newHashSet(TezCommonJoinAlgorithm.INSTANCE, TezMapJoinAlgorithm.INSTANCE,
                TezBucketJoinAlgorithm.INSTANCE, TezSMBJoinAlgorithm.INSTANCE));
        algoUtils = new HiveAlgorithmsUtil(conf);
    }

    /** Returns a zero cost. */
    @Override
    public RelOptCost getDefaultCost() {
        return HiveCost.FACTORY.makeZeroCost();
    }

    /** Costs a table scan from its row count and average row size. */
    @Override
    public RelOptCost getScanCost(HiveTableScan ts) {
        return algoUtils.computeScanCost(ts.getRows(), RelMetadataQuery.getAverageRowSize(ts));
    }

    /**
     * Costs an aggregate.
     *
     * <p>Bucketed input is costed at zero. Otherwise the cost is a sort of the input (CPU and IO),
     * with the input row count as the row component.
     *
     * @return the cost, or {@code null} if the input row count or average row size is unknown
     */
    @Override
    public RelOptCost getAggregateCost(HiveAggregate aggregate) {
        if (aggregate.isBucketedInput()) {
            return HiveCost.FACTORY.makeZeroCost();
        }
        Double rCount = RelMetadataQuery.getRowCount(aggregate.getInput());
        if (rCount == null) {
            return null;
        }
        double cpuCost = algoUtils.computeSortCPUCost(rCount);
        Double rAverageSize = RelMetadataQuery.getAverageRowSize(aggregate.getInput());
        if (rAverageSize == null) {
            return null;
        }
        double ioCost = algoUtils.computeSortIOCost(new Pair<Double, Double>(rCount, rAverageSize));
        return HiveCost.FACTORY.makeCost(rCount, cpuCost, ioCost);
    }

    // ------------------------------------------------------------------------------------------
    // Helpers shared by the join algorithms below.
    // ------------------------------------------------------------------------------------------

    /**
     * Returns the row counts of the left and right join inputs, in that order.
     *
     * @return the two row counts, or {@code null} if either is unknown
     */
    private static ImmutableList<Double> inputRowCounts(HiveJoin join) {
        Double leftRCount = RelMetadataQuery.getRowCount(join.getLeft());
        Double rightRCount = RelMetadataQuery.getRowCount(join.getRight());
        if (leftRCount == null || rightRCount == null) {
            return null;
        }
        return ImmutableList.of(leftRCount, rightRCount);
    }

    /** Returns the sum of the left and right row counts produced by {@link #inputRowCounts}. */
    private static double totalRowCount(ImmutableList<Double> cardinalities) {
        return cardinalities.get(0) + cardinalities.get(1);
    }

    /**
     * Pairs each input's row count with its average row size.
     *
     * @param cardinalities the left/right row counts from {@link #inputRowCounts}
     * @return {@code [(leftRows, leftAvgSize), (rightRows, rightAvgSize)]}, or {@code null} if
     *     either average row size is unknown
     */
    private static ImmutableList<Pair<Double, Double>> inputRelationInfos(HiveJoin join,
            ImmutableList<Double> cardinalities) {
        Double leftRAverageSize = RelMetadataQuery.getAverageRowSize(join.getLeft());
        Double rightRAverageSize = RelMetadataQuery.getAverageRowSize(join.getRight());
        if (leftRAverageSize == null || rightRAverageSize == null) {
            return null;
        }
        return ImmutableList.of(
                new Pair<Double, Double>(cardinalities.get(0), leftRAverageSize),
                new Pair<Double, Double>(cardinalities.get(1), rightRAverageSize));
    }

    /** Returns true if the join's streaming side is LEFT_RELATION or RIGHT_RELATION. */
    private static boolean hasMapJoinStreamingSide(HiveJoin join) {
        HiveJoin.MapJoinStreamingRelation side = join.getStreamingSide();
        return side == HiveJoin.MapJoinStreamingRelation.LEFT_RELATION
                || side == HiveJoin.MapJoinStreamingRelation.RIGHT_RELATION;
    }

    /**
     * Returns a bit set holding the index of the streaming input (0 = left, 1 = right).
     *
     * @return the bit set, or {@code null} if the streaming side is neither LEFT_RELATION nor
     *     RIGHT_RELATION
     */
    private static ImmutableBitSet streamingInputs(HiveJoin join) {
        HiveJoin.MapJoinStreamingRelation side = join.getStreamingSide();
        if (side == HiveJoin.MapJoinStreamingRelation.LEFT_RELATION) {
            return ImmutableBitSet.of(0);
        }
        if (side == HiveJoin.MapJoinStreamingRelation.RIGHT_RELATION) {
            return ImmutableBitSet.of(1);
        }
        return null;
    }

    /**
     * Returns the input opposite to the streaming side: the right input when the left streams,
     * the left input when the right streams, otherwise {@code null}.
     */
    private static RelNode inMemoryInput(HiveJoin join) {
        HiveJoin.MapJoinStreamingRelation side = join.getStreamingSide();
        if (side == HiveJoin.MapJoinStreamingRelation.LEFT_RELATION) {
            return join.getRight();
        }
        if (side == HiveJoin.MapJoinStreamingRelation.RIGHT_RELATION) {
            return join.getLeft();
        }
        return null;
    }

    /**
     * Returns the join key positions of the left (index 0) and right (index 1) inputs, expressed
     * in each child's own schema.
     */
    private static ImmutableList<ImmutableIntList> joinKeysInChildren(HiveJoin join) {
        HiveCalciteUtil.JoinPredicateInfo joinPredInfo = join.getJoinPredicateInfo();
        return ImmutableList.of(
                ImmutableIntList.copyOf(joinPredInfo.getProjsFromLeftPartOfJoinKeysInChildSchema()),
                ImmutableIntList.copyOf(joinPredInfo.getProjsFromRightPartOfJoinKeysInChildSchema()));
    }

    /**
     * Returns true if {@code input} is HASH_DISTRIBUTED and its distribution keys contain every
     * position in {@code joinKeys}.
     */
    private static boolean isHashDistributedOnJoinKeys(RelNode input, ImmutableIntList joinKeys) {
        RelDistribution distribution = RelMetadataQuery.distribution(input);
        return distribution.getType() == RelDistribution.Type.HASH_DISTRIBUTED
                && distribution.getKeys().containsAll(joinKeys);
    }

    /**
     * Queries the split count of {@code node} while {@code algo} is temporarily installed on
     * {@code join}. The previous algorithm is restored even if the metadata query throws.
     */
    private static Integer splitCountWith(HiveJoin join, HiveCostModel.JoinAlgorithm algo,
            RelNode node) {
        HiveCostModel.JoinAlgorithm oldAlgo = join.getJoinAlgorithm();
        join.setJoinAlgorithm(algo);
        try {
            return RelMetadataQuery.splitCount(node);
        } finally {
            join.setJoinAlgorithm(oldAlgo);
        }
    }

    /**
     * Returns the join's split count under {@code algo}, defaulting to 1 when it is unknown.
     * The metadata query is issued only once.
     */
    private static int parallelismWith(HiveJoin join, HiveCostModel.JoinAlgorithm algo) {
        Integer splitCount = splitCountWith(join, algo, join);
        return splitCount == null ? 1 : splitCount;
    }

    /**
     * With {@code algo} temporarily installed on {@code join}, divides the join's cumulative
     * memory within the phase by its split count. The previous algorithm is restored even if a
     * metadata query throws.
     *
     * @return the per-split memory, or {@code null} if either metadata value is unknown
     */
    private static Double memoryPerSplitWith(HiveJoin join, HiveCostModel.JoinAlgorithm algo) {
        HiveCostModel.JoinAlgorithm oldAlgo = join.getJoinAlgorithm();
        join.setJoinAlgorithm(algo);
        Double memoryWithinPhase;
        Integer splitCount;
        try {
            memoryWithinPhase = RelMetadataQuery.cumulativeMemoryWithinPhase(join);
            splitCount = RelMetadataQuery.splitCount(join);
        } finally {
            join.setJoinAlgorithm(oldAlgo);
        }
        if (memoryWithinPhase == null || splitCount == null) {
            return null;
        }
        return memoryWithinPhase / (double) splitCount.intValue();
    }

    /**
     * Sort-merge-bucket join.
     *
     * <p>Executable only when every input is sorted and hash-distributed on keys covering its
     * join keys. Does not report a phase transition.
     */
    public static class TezSMBJoinAlgorithm implements HiveCostModel.JoinAlgorithm {
        public static final HiveCostModel.JoinAlgorithm INSTANCE = new TezSMBJoinAlgorithm();
        private static final String ALGORITHM_NAME = "SMBJoin";

        @Override
        public String toString() {
            return ALGORITHM_NAME;
        }

        /**
         * Returns true if every input is marked sorted in {@code join.getSortedInputs()} and is
         * hash-distributed on keys containing its join keys.
         */
        @Override
        public boolean isExecutable(HiveJoin join) {
            ImmutableList<ImmutableIntList> joinKeys = joinKeysInChildren(join);
            for (int i = 0; i < join.getInputs().size(); ++i) {
                if (!join.getSortedInputs().get(i)) {
                    return false;
                }
                if (!isHashDistributedOnJoinKeys(join.getInputs().get(i), joinKeys.get(i))) {
                    return false;
                }
            }
            return true;
        }

        /**
         * Costs the join using the SMB map-join CPU and IO formulas of {@link HiveAlgorithmsUtil}.
         *
         * @return the cost, or {@code null} if a row count or average row size is unknown, or if
         *     the streaming side is neither left nor right
         */
        @Override
        public RelOptCost getCost(HiveJoin join) {
            ImmutableList<Double> cardinalities = inputRowCounts(join);
            if (cardinalities == null) {
                return null;
            }
            ImmutableBitSet streaming = streamingInputs(join);
            if (streaming == null) {
                return null;
            }
            double cpuCost = HiveAlgorithmsUtil.computeSMBMapJoinCPUCost(cardinalities);
            ImmutableList<Pair<Double, Double>> relationInfos = inputRelationInfos(join, cardinalities);
            if (relationInfos == null) {
                return null;
            }
            int parallelism = parallelismWith(join, INSTANCE);
            double ioCost = algoUtils.computeSMBMapJoinIOCost(relationInfos, streaming, parallelism);
            return HiveCost.FACTORY.makeCost(totalRowCount(cardinalities), cpuCost, ioCost);
        }

        @Override
        public ImmutableList<RelCollation> getCollation(HiveJoin join) {
            return HiveAlgorithmsUtil.getJoinCollation(join.getJoinPredicateInfo(),
                    HiveJoin.MapJoinStreamingRelation.NONE);
        }

        @Override
        public RelDistribution getDistribution(HiveJoin join) {
            return HiveAlgorithmsUtil.getJoinRedistribution(join.getJoinPredicateInfo());
        }

        /** Returns 0.0. */
        @Override
        public Double getMemory(HiveJoin join) {
            return 0.0;
        }

        /** Cumulative memory within the phase (computed under this algorithm) per split. */
        @Override
        public Double getCumulativeMemoryWithinPhaseSplit(HiveJoin join) {
            return memoryPerSplitWith(join, INSTANCE);
        }

        @Override
        public Boolean isPhaseTransition(HiveJoin join) {
            return false;
        }

        @Override
        public Integer getSplitCount(HiveJoin join) {
            return HiveAlgorithmsUtil.getSplitCountWithoutRepartition(join);
        }
    }

    /**
     * Bucket map join.
     *
     * <p>Executable when the input returned by {@code join.getStreamingInput()} (treated here as
     * the small side) fits into memory per bucket, and every input is hash-distributed on keys
     * covering its join keys.
     */
    public static class TezBucketJoinAlgorithm implements HiveCostModel.JoinAlgorithm {
        public static final HiveCostModel.JoinAlgorithm INSTANCE = new TezBucketJoinAlgorithm();
        private static final String ALGORITHM_NAME = "BucketJoin";

        @Override
        public String toString() {
            return ALGORITHM_NAME;
        }

        @Override
        public boolean isExecutable(HiveJoin join) {
            Double maxMemory = join.getCluster().getPlanner().getContext()
                    .unwrap(HiveAlgorithmsConf.class).getMaxMemory();
            RelNode smallInput = join.getStreamingInput();
            if (smallInput == null) {
                return false;
            }
            ImmutableList<ImmutableIntList> joinKeys = joinKeysInChildren(join);
            // The small input's split count is queried with this algorithm installed and is
            // used as the bucket count.
            Integer buckets = splitCountWith(join, INSTANCE, smallInput);
            if (buckets == null) {
                return false;
            }
            if (!HiveAlgorithmsUtil.isFittingIntoMemory(maxMemory, smallInput, buckets)) {
                return false;
            }
            for (int i = 0; i < join.getInputs().size(); ++i) {
                if (!isHashDistributedOnJoinKeys(join.getInputs().get(i), joinKeys.get(i))) {
                    return false;
                }
            }
            return true;
        }

        /**
         * Costs the join using the bucket map-join CPU and IO formulas of
         * {@link HiveAlgorithmsUtil}.
         *
         * @return the cost, or {@code null} if a row count or average row size is unknown, or if
         *     the streaming side is neither left nor right
         */
        @Override
        public RelOptCost getCost(HiveJoin join) {
            ImmutableList<Double> cardinalities = inputRowCounts(join);
            if (cardinalities == null) {
                return null;
            }
            ImmutableBitSet streaming = streamingInputs(join);
            if (streaming == null) {
                return null;
            }
            double cpuCost = algoUtils.computeBucketMapJoinCPUCost(cardinalities, streaming);
            ImmutableList<Pair<Double, Double>> relationInfos = inputRelationInfos(join, cardinalities);
            if (relationInfos == null) {
                return null;
            }
            int parallelism = parallelismWith(join, INSTANCE);
            double ioCost = algoUtils.computeBucketMapJoinIOCost(relationInfos, streaming, parallelism);
            return HiveCost.FACTORY.makeCost(totalRowCount(cardinalities), cpuCost, ioCost);
        }

        /**
         * Returns the join collation for the chosen streaming side, or {@code null} if the
         * streaming side is neither left nor right.
         */
        @Override
        public ImmutableList<RelCollation> getCollation(HiveJoin join) {
            // Bail out only when NEITHER side streams (the previous '||' test was always true).
            if (!hasMapJoinStreamingSide(join)) {
                return null;
            }
            return HiveAlgorithmsUtil.getJoinCollation(join.getJoinPredicateInfo(),
                    join.getStreamingSide());
        }

        @Override
        public RelDistribution getDistribution(HiveJoin join) {
            return HiveAlgorithmsUtil.getJoinRedistribution(join.getJoinPredicateInfo());
        }

        @Override
        public Double getMemory(HiveJoin join) {
            return HiveAlgorithmsUtil.getJoinMemory(join);
        }

        /**
         * Cumulative memory within the phase of the non-streaming input divided by that input's
         * split count; {@code null} if no streaming side is chosen or a value is unknown.
         */
        @Override
        public Double getCumulativeMemoryWithinPhaseSplit(HiveJoin join) {
            RelNode inMemoryInput = inMemoryInput(join);
            if (inMemoryInput == null) {
                return null;
            }
            Double memoryInput = RelMetadataQuery.cumulativeMemoryWithinPhase(inMemoryInput);
            Integer splitCount = RelMetadataQuery.splitCount(inMemoryInput);
            if (memoryInput == null || splitCount == null) {
                return null;
            }
            return memoryInput / (double) splitCount.intValue();
        }

        @Override
        public Boolean isPhaseTransition(HiveJoin join) {
            return false;
        }

        @Override
        public Integer getSplitCount(HiveJoin join) {
            return HiveAlgorithmsUtil.getSplitCountWithoutRepartition(join);
        }
    }

    /**
     * Map join.
     *
     * <p>Executable when the input returned by {@code join.getStreamingInput()} (treated here as
     * the small side) fits into memory as a single split.
     */
    public static class TezMapJoinAlgorithm implements HiveCostModel.JoinAlgorithm {
        public static final HiveCostModel.JoinAlgorithm INSTANCE = new TezMapJoinAlgorithm();
        private static final String ALGORITHM_NAME = "MapJoin";

        @Override
        public String toString() {
            return ALGORITHM_NAME;
        }

        @Override
        public boolean isExecutable(HiveJoin join) {
            Double maxMemory = join.getCluster().getPlanner().getContext()
                    .unwrap(HiveAlgorithmsConf.class).getMaxMemory();
            RelNode smallInput = join.getStreamingInput();
            if (smallInput == null) {
                return false;
            }
            return HiveAlgorithmsUtil.isFittingIntoMemory(maxMemory, smallInput, 1);
        }

        /**
         * Costs the join using the map-join CPU and IO formulas of {@link HiveAlgorithmsUtil}.
         *
         * @return the cost, or {@code null} if a row count or average row size is unknown, or if
         *     the streaming side is neither left nor right
         */
        @Override
        public RelOptCost getCost(HiveJoin join) {
            ImmutableList<Double> cardinalities = inputRowCounts(join);
            if (cardinalities == null) {
                return null;
            }
            ImmutableBitSet streaming = streamingInputs(join);
            if (streaming == null) {
                return null;
            }
            double cpuCost = HiveAlgorithmsUtil.computeMapJoinCPUCost(cardinalities, streaming);
            ImmutableList<Pair<Double, Double>> relationInfos = inputRelationInfos(join, cardinalities);
            if (relationInfos == null) {
                return null;
            }
            int parallelism = parallelismWith(join, INSTANCE);
            double ioCost = algoUtils.computeMapJoinIOCost(relationInfos, streaming, parallelism);
            return HiveCost.FACTORY.makeCost(totalRowCount(cardinalities), cpuCost, ioCost);
        }

        /**
         * Returns the join collation for the chosen streaming side, or {@code null} if the
         * streaming side is neither left nor right.
         */
        @Override
        public ImmutableList<RelCollation> getCollation(HiveJoin join) {
            // Bail out only when NEITHER side streams (the previous '||' test was always true).
            if (!hasMapJoinStreamingSide(join)) {
                return null;
            }
            return HiveAlgorithmsUtil.getJoinCollation(join.getJoinPredicateInfo(),
                    join.getStreamingSide());
        }

        /**
         * Returns the join distribution for the chosen streaming side, or {@code null} if the
         * streaming side is neither left nor right.
         */
        @Override
        public RelDistribution getDistribution(HiveJoin join) {
            // Bail out only when NEITHER side streams (the previous '||' test was always true).
            if (!hasMapJoinStreamingSide(join)) {
                return null;
            }
            return HiveAlgorithmsUtil.getJoinDistribution(join.getJoinPredicateInfo(),
                    join.getStreamingSide());
        }

        @Override
        public Double getMemory(HiveJoin join) {
            return HiveAlgorithmsUtil.getJoinMemory(join);
        }

        /**
         * Cumulative memory within the phase of the non-streaming input. Unlike the bucket join,
         * this is not divided by a split count. Returns {@code null} if no streaming side is
         * chosen.
         */
        @Override
        public Double getCumulativeMemoryWithinPhaseSplit(HiveJoin join) {
            RelNode inMemoryInput = inMemoryInput(join);
            if (inMemoryInput == null) {
                return null;
            }
            return RelMetadataQuery.cumulativeMemoryWithinPhase(inMemoryInput);
        }

        @Override
        public Boolean isPhaseTransition(HiveJoin join) {
            return false;
        }

        @Override
        public Integer getSplitCount(HiveJoin join) {
            return HiveAlgorithmsUtil.getSplitCountWithoutRepartition(join);
        }
    }

    /**
     * Common (sort-merge) join.
     *
     * <p>Always executable. Reports a phase transition and computes its split count with
     * repartition.
     */
    public static class TezCommonJoinAlgorithm implements HiveCostModel.JoinAlgorithm {
        public static final HiveCostModel.JoinAlgorithm INSTANCE = new TezCommonJoinAlgorithm();
        private static final String ALGORITHM_NAME = "CommonJoin";

        @Override
        public String toString() {
            return ALGORITHM_NAME;
        }

        @Override
        public boolean isExecutable(HiveJoin join) {
            return true;
        }

        /**
         * Costs the join using the sort-merge CPU and IO formulas of {@link HiveAlgorithmsUtil}.
         * The CPU cost takes the already-sorted inputs into account.
         *
         * @return the cost, or {@code null} if a row count or average row size is unknown
         */
        @Override
        public RelOptCost getCost(HiveJoin join) {
            ImmutableList<Double> cardinalities = inputRowCounts(join);
            if (cardinalities == null) {
                return null;
            }
            double cpuCost = algoUtils.computeSortMergeCPUCost(cardinalities, join.getSortedInputs());
            ImmutableList<Pair<Double, Double>> relationInfos = inputRelationInfos(join, cardinalities);
            if (relationInfos == null) {
                return null;
            }
            double ioCost = algoUtils.computeSortMergeIOCost(relationInfos);
            return HiveCost.FACTORY.makeCost(totalRowCount(cardinalities), cpuCost, ioCost);
        }

        @Override
        public ImmutableList<RelCollation> getCollation(HiveJoin join) {
            return HiveAlgorithmsUtil.getJoinCollation(join.getJoinPredicateInfo(),
                    HiveJoin.MapJoinStreamingRelation.NONE);
        }

        @Override
        public RelDistribution getDistribution(HiveJoin join) {
            return HiveAlgorithmsUtil.getJoinRedistribution(join.getJoinPredicateInfo());
        }

        @Override
        public Double getMemory(HiveJoin join) {
            return HiveAlgorithmsUtil.getJoinMemory(join, HiveJoin.MapJoinStreamingRelation.NONE);
        }

        /** Cumulative memory within the phase (computed under this algorithm) per split. */
        @Override
        public Double getCumulativeMemoryWithinPhaseSplit(HiveJoin join) {
            return memoryPerSplitWith(join, INSTANCE);
        }

        @Override
        public Boolean isPhaseTransition(HiveJoin join) {
            return true;
        }

        @Override
        public Integer getSplitCount(HiveJoin join) {
            return HiveAlgorithmsUtil.getSplitCountWithRepartition(join);
        }
    }
}

