/*
 * Decompiled with CFR 0.152.
 */
package hive.org.apache.calcite.rel.metadata;

import hive.com.google.common.collect.ImmutableList;
import hive.org.apache.calcite.plan.RelOptUtil;
import hive.org.apache.calcite.rel.RelCollation;
import hive.org.apache.calcite.rel.RelNode;
import hive.org.apache.calcite.rel.core.Aggregate;
import hive.org.apache.calcite.rel.core.AggregateCall;
import hive.org.apache.calcite.rel.core.Join;
import hive.org.apache.calcite.rel.core.JoinRelType;
import hive.org.apache.calcite.rel.core.Minus;
import hive.org.apache.calcite.rel.core.Project;
import hive.org.apache.calcite.rel.core.SemiJoin;
import hive.org.apache.calcite.rel.core.Union;
import hive.org.apache.calcite.rel.metadata.RelMetadataQuery;
import hive.org.apache.calcite.rex.RexBuilder;
import hive.org.apache.calcite.rex.RexCall;
import hive.org.apache.calcite.rex.RexInputRef;
import hive.org.apache.calcite.rex.RexLiteral;
import hive.org.apache.calcite.rex.RexLocalRef;
import hive.org.apache.calcite.rex.RexNode;
import hive.org.apache.calcite.rex.RexProgram;
import hive.org.apache.calcite.rex.RexUtil;
import hive.org.apache.calcite.rex.RexVisitorImpl;
import hive.org.apache.calcite.sql.SqlFunction;
import hive.org.apache.calcite.sql.SqlFunctionCategory;
import hive.org.apache.calcite.sql.SqlKind;
import hive.org.apache.calcite.sql.SqlOperator;
import hive.org.apache.calcite.sql.type.OperandTypes;
import hive.org.apache.calcite.sql.type.ReturnTypes;
import hive.org.apache.calcite.util.ImmutableBitSet;
import hive.org.apache.calcite.util.NumberUtil;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;

/**
 * Static helpers used by relational metadata providers: selectivity guesses,
 * distinct-value estimates, uniqueness checks and row-count formulas.
 *
 * <p>Many helpers return {@code null} to mean "unknown". Where an input
 * estimate is {@code null}, the helpers below propagate {@code null} rather
 * than failing on auto-unboxing (except the methods that return a primitive
 * {@code double}).
 */
public class RelMdUtil {
    /**
     * Marker boolean function whose single numeric operand carries a
     * precomputed selectivity. {@link #guessSelectivity(RexNode, boolean)}
     * multiplies in that literal value instead of guessing.
     */
    public static final SqlFunction ARTIFICIAL_SELECTIVITY_FUNC = new SqlFunction(
            "ARTIFICIAL_SELECTIVITY", SqlKind.OTHER_FUNCTION, ReturnTypes.BOOLEAN, null,
            OperandTypes.NUMERIC, SqlFunctionCategory.SYSTEM);

    private RelMdUtil() {
    }

    /**
     * Builds a call to {@link #ARTIFICIAL_SELECTIVITY_FUNC} whose operand is the
     * selectivity of {@code rel} computed by
     * {@link #computeSemiJoinSelectivity(RelMetadataQuery, RelNode, RelNode, SemiJoin)}.
     *
     * @param mq  metadata query
     * @param rel semi-join whose selectivity is encoded
     * @return a call expression carrying the selectivity as an approximate literal
     */
    public static RexNode makeSemiJoinSelectivityRexNode(RelMetadataQuery mq, SemiJoin rel) {
        RexBuilder rexBuilder = rel.getCluster().getRexBuilder();
        double selectivity = RelMdUtil.computeSemiJoinSelectivity(mq, rel.getLeft(), rel.getRight(), rel);
        RexLiteral selec = rexBuilder.makeApproxLiteral(new BigDecimal(selectivity));
        return rexBuilder.makeCall((SqlOperator)ARTIFICIAL_SELECTIVITY_FUNC, selec);
    }

    /**
     * Extracts the selectivity stored in an {@link #ARTIFICIAL_SELECTIVITY_FUNC} call.
     *
     * <p>With assertions enabled, a non-call or a call to another operator fails
     * an assertion. Without them, a wrong node type surfaces as a
     * {@link ClassCastException}.
     *
     * @param artificialSelecFuncNode a call to {@link #ARTIFICIAL_SELECTIVITY_FUNC}
     * @return the first operand's {@link BigDecimal} value as a double
     */
    public static double getSelectivityValue(RexNode artificialSelecFuncNode) {
        assert (artificialSelecFuncNode instanceof RexCall);
        RexCall call = (RexCall)artificialSelecFuncNode;
        assert (call.getOperator() == ARTIFICIAL_SELECTIVITY_FUNC);
        RexNode operand = call.getOperands().get(0);
        BigDecimal bd = (BigDecimal)((RexLiteral)operand).getValue();
        return bd.doubleValue();
    }

    /**
     * Computes semi-join selectivity, using the semi-join's own inputs as fact
     * and dimension relations.
     *
     * @param mq  metadata query
     * @param rel semi-join
     * @return selectivity in (0, 1]
     */
    public static double computeSemiJoinSelectivity(RelMetadataQuery mq, SemiJoin rel) {
        return RelMdUtil.computeSemiJoinSelectivity(mq, rel.getLeft(), rel.getRight(), rel.getLeftKeys(), rel.getRightKeys());
    }

    /**
     * Computes semi-join selectivity using explicit fact and dimension relations.
     * The keys are taken from {@code rel}.
     *
     * @param mq      metadata query
     * @param factRel relation whose rows are being filtered
     * @param dimRel  relation providing the matching keys
     * @param rel     semi-join providing the left/right key lists
     * @return selectivity in (0, 1]
     */
    public static double computeSemiJoinSelectivity(RelMetadataQuery mq, RelNode factRel, RelNode dimRel, SemiJoin rel) {
        return RelMdUtil.computeSemiJoinSelectivity(mq, factRel, dimRel, rel.getLeftKeys(), rel.getRightKeys());
    }

    /**
     * Estimates semi-join selectivity as
     * {@code distinct(dimRel, dimKeys) / max(population(factRel, factKeys), 1)}.
     *
     * <p>If the fact population is unknown, the dimension population over the
     * dimension keys is used instead. If either quantity is still unknown, the
     * dimension relation's percentage of original rows is used. If that is also
     * unknown, the result is {@code 0.1 ^ (number of distinct dimension keys)}.
     * The result is capped at 1.0.
     *
     * @param mq          metadata query
     * @param factRel     relation whose rows are being filtered
     * @param dimRel      relation providing the matching keys
     * @param factKeyList join key columns of {@code factRel}
     * @param dimKeyList  join key columns of {@code dimRel}
     * @return selectivity estimate
     */
    public static double computeSemiJoinSelectivity(RelMetadataQuery mq, RelNode factRel, RelNode dimRel, List<Integer> factKeyList, List<Integer> dimKeyList) {
        ImmutableBitSet.Builder factKeys = ImmutableBitSet.builder();
        for (int factCol : factKeyList) {
            factKeys.set(factCol);
        }
        ImmutableBitSet.Builder dimKeyBuilder = ImmutableBitSet.builder();
        for (int dimCol : dimKeyList) {
            dimKeyBuilder.set(dimCol);
        }
        ImmutableBitSet dimKeys = dimKeyBuilder.build();

        Double factPop = mq.getPopulationSize(factRel, factKeys.build());
        if (factPop == null) {
            // Fall back to the dimension side's population over its keys.
            factPop = mq.getPopulationSize(dimRel, dimKeys);
        }
        Double dimCard = mq.getDistinctRowCount(dimRel, dimKeys, null);

        Double selectivity;
        if (dimCard != null && factPop != null) {
            // Guard against dividing by a population smaller than one row.
            if (factPop < 1.0) {
                factPop = 1.0;
            }
            selectivity = dimCard / factPop;
        } else {
            selectivity = mq.getPercentageOriginalRows(dimRel);
        }
        if (selectivity == null) {
            selectivity = Math.pow(0.1, dimKeys.cardinality());
        } else if (selectivity > 1.0) {
            selectivity = 1.0;
        }
        return selectivity;
    }

    /**
     * Reports whether {@code colMask} is known to be unique in {@code rel},
     * without ignoring nulls.
     *
     * @return true only if the metadata answer is {@code Boolean.TRUE}; an
     *         unknown ({@code null}) answer yields false
     */
    public static boolean areColumnsDefinitelyUnique(RelMetadataQuery mq, RelNode rel, ImmutableBitSet colMask) {
        Boolean b = mq.areColumnsUnique(rel, colMask, false);
        return Boolean.TRUE.equals(b);
    }

    /**
     * Asks whether the columns referenced by {@code columnRefs} are unique in {@code rel}.
     *
     * @return the metadata answer, possibly {@code null} if unknown
     */
    public static Boolean areColumnsUnique(RelMetadataQuery mq, RelNode rel, List<RexInputRef> columnRefs) {
        ImmutableBitSet.Builder colMask = ImmutableBitSet.builder();
        for (RexInputRef columnRef : columnRefs) {
            colMask.set(columnRef.getIndex());
        }
        return mq.areColumnsUnique(rel, colMask.build());
    }

    /**
     * Reports whether the referenced columns are known to be unique in {@code rel}.
     *
     * @return true only if {@link #areColumnsUnique(RelMetadataQuery, RelNode, List)}
     *         returns {@code Boolean.TRUE}
     */
    public static boolean areColumnsDefinitelyUnique(RelMetadataQuery mq, RelNode rel, List<RexInputRef> columnRefs) {
        Boolean b = RelMdUtil.areColumnsUnique(mq, rel, columnRefs);
        return Boolean.TRUE.equals(b);
    }

    /**
     * Reports whether {@code colMask} is known to be unique in {@code rel} once
     * rows with nulls are ignored.
     *
     * @return true only if the metadata answer is {@code Boolean.TRUE}
     */
    public static boolean areColumnsDefinitelyUniqueWhenNullsFiltered(RelMetadataQuery mq, RelNode rel, ImmutableBitSet colMask) {
        Boolean b = mq.areColumnsUnique(rel, colMask, true);
        return Boolean.TRUE.equals(b);
    }

    /**
     * Asks whether the referenced columns are unique in {@code rel}, ignoring nulls.
     *
     * @return the metadata answer, possibly {@code null} if unknown
     */
    public static Boolean areColumnsUniqueWhenNullsFiltered(RelMetadataQuery mq, RelNode rel, List<RexInputRef> columnRefs) {
        ImmutableBitSet.Builder colMask = ImmutableBitSet.builder();
        for (RexInputRef columnRef : columnRefs) {
            colMask.set(columnRef.getIndex());
        }
        return mq.areColumnsUnique(rel, colMask.build(), true);
    }

    /**
     * Reports whether the referenced columns are known to be unique in
     * {@code rel}, ignoring nulls.
     *
     * @return true only if the metadata answer is {@code Boolean.TRUE}
     */
    public static boolean areColumnsDefinitelyUniqueWhenNullsFiltered(RelMetadataQuery mq, RelNode rel, List<RexInputRef> columnRefs) {
        Boolean b = RelMdUtil.areColumnsUniqueWhenNullsFiltered(mq, rel, columnRefs);
        return Boolean.TRUE.equals(b);
    }

    /**
     * Splits join-output column indexes into left-input and right-input masks.
     *
     * @param groupKey      columns of the join output
     * @param leftMask      receives indexes below {@code nFieldsOnLeft}, unchanged
     * @param rightMask     receives the remaining indexes, shifted down by {@code nFieldsOnLeft}
     * @param nFieldsOnLeft number of fields in the left input
     */
    public static void setLeftRightBitmaps(ImmutableBitSet groupKey, ImmutableBitSet.Builder leftMask, ImmutableBitSet.Builder rightMask, int nFieldsOnLeft) {
        for (int bit : groupKey) {
            if (bit < nFieldsOnLeft) {
                leftMask.set(bit);
            } else {
                rightMask.set(bit - nFieldsOnLeft);
            }
        }
    }

    /**
     * Estimates how many distinct values appear when {@code numSelected} values
     * are drawn from a domain of {@code domainSize} values.
     *
     * <p>The formula is {@code domainSize * (1 - exp(-numSelected / domainSize))},
     * clamped to {@code [0, min(domainSize, numSelected)]}. Infinite inputs are
     * capped at {@link Double#MAX_VALUE}. A non-positive domain yields 0.
     *
     * @return the estimate, or {@code null} if either argument is {@code null}
     */
    public static Double numDistinctVals(Double domainSize, Double numSelected) {
        if (domainSize == null || numSelected == null) {
            return null;
        }
        double dSize = RelMdUtil.capInfinity(domainSize);
        double numSel = RelMdUtil.capInfinity(numSelected);
        double res = dSize > 0.0 ? (1.0 - Math.exp(-1.0 * numSel / dSize)) * dSize : 0.0;
        if (res > dSize) {
            res = dSize;
        }
        if (res > numSel) {
            res = numSel;
        }
        if (res < 0.0) {
            res = 0.0;
        }
        return res;
    }

    /**
     * Replaces an infinite value with {@link Double#MAX_VALUE}.
     *
     * @param d non-null value
     * @return {@code d}, or {@code Double.MAX_VALUE} if {@code d} is infinite
     *         (either sign)
     */
    public static double capInfinity(Double d) {
        return d.isInfinite() ? Double.MAX_VALUE : d;
    }

    /**
     * Guesses the selectivity of {@code predicate}, including any artificial
     * selectivity factors.
     *
     * @see #guessSelectivity(RexNode, boolean)
     */
    public static double guessSelectivity(RexNode predicate) {
        return RelMdUtil.guessSelectivity(predicate, false);
    }

    /**
     * Guesses the selectivity of a predicate by multiplying a fixed factor per
     * conjunct.
     *
     * <p>The factors are: IS NOT NULL 0.9, = 0.15, other comparisons 0.5,
     * anything else 0.25. Calls to {@link #ARTIFICIAL_SELECTIVITY_FUNC}
     * contribute their literal value to a separate "artificial" product.
     *
     * @param predicate      predicate; {@code null} or always-true yields 1.0
     * @param artificialOnly if true, return only the artificial product
     * @return the combined (or artificial-only) selectivity
     */
    public static double guessSelectivity(RexNode predicate, boolean artificialOnly) {
        double sel = 1.0;
        if (predicate == null || predicate.isAlwaysTrue()) {
            return sel;
        }
        double artificialSel = 1.0;
        for (RexNode pred : RelOptUtil.conjunctions(predicate)) {
            if (pred.getKind() == SqlKind.IS_NOT_NULL) {
                sel *= 0.9;
            } else if (pred instanceof RexCall && ((RexCall)pred).getOperator() == ARTIFICIAL_SELECTIVITY_FUNC) {
                artificialSel *= RelMdUtil.getSelectivityValue(pred);
            } else if (pred.isA(SqlKind.EQUALS)) {
                sel *= 0.15;
            } else if (pred.isA(SqlKind.COMPARISON)) {
                sel *= 0.5;
            } else {
                sel *= 0.25;
            }
        }
        if (artificialOnly) {
            return artificialSel;
        }
        return sel * artificialSel;
    }

    /**
     * Returns the AND of the conjuncts of {@code pred1} and {@code pred2}.
     * Conjuncts whose {@code toString()} form was already seen are dropped;
     * first occurrence order is kept.
     */
    public static RexNode unionPreds(RexBuilder rexBuilder, RexNode pred1, RexNode pred2) {
        ArrayList<RexNode> unionList = new ArrayList<RexNode>();
        HashSet<String> strings = new HashSet<String>();
        for (RexNode rex : RelOptUtil.conjunctions(pred1)) {
            if (strings.add(rex.toString())) {
                unionList.add(rex);
            }
        }
        for (RexNode rex2 : RelOptUtil.conjunctions(pred2)) {
            if (strings.add(rex2.toString())) {
                unionList.add(rex2);
            }
        }
        return RexUtil.composeConjunction(rexBuilder, unionList, true);
    }

    /**
     * Returns the AND of those conjuncts of {@code pred1} whose {@code toString()}
     * form does not match any conjunct of {@code pred2}.
     */
    public static RexNode minusPreds(RexBuilder rexBuilder, RexNode pred1, RexNode pred2) {
        // Precompute pred2's conjunct strings once instead of re-stringifying per pair.
        HashSet<String> excluded = new HashSet<String>();
        for (RexNode rex2 : RelOptUtil.conjunctions(pred2)) {
            excluded.add(rex2.toString());
        }
        ArrayList<RexNode> minusList = new ArrayList<RexNode>();
        for (RexNode rex1 : RelOptUtil.conjunctions(pred1)) {
            if (!excluded.contains(rex1.toString())) {
                minusList.add(rex1);
            }
        }
        return RexUtil.composeConjunction(rexBuilder, minusList, true);
    }

    /**
     * Maps aggregate-output columns to the input columns they depend on.
     *
     * <p>Each bit below the group count is copied to {@code childKey} unchanged.
     * Each later bit selects the aggregate call at
     * {@code bit - (groupCount + indicatorCount)}, and all of that call's
     * argument columns are set.
     *
     * <p>NOTE(review): a bit in {@code [groupCount, groupCount + indicatorCount)}
     * (an indicator column) would compute a negative aggregate-call index.
     * Confirm that callers never pass indicator columns.
     */
    public static void setAggChildKeys(ImmutableBitSet groupKey, Aggregate aggRel, ImmutableBitSet.Builder childKey) {
        List<AggregateCall> aggCalls = aggRel.getAggCallList();
        for (int bit : groupKey) {
            if (bit < aggRel.getGroupCount()) {
                childKey.set(bit);
                continue;
            }
            AggregateCall agg = aggCalls.get(bit - (aggRel.getGroupCount() + aggRel.getIndicatorCount()));
            for (Integer arg : agg.getArgList()) {
                childKey.set(arg);
            }
        }
    }

    /**
     * Partitions the projection columns in {@code groupKey} into columns that
     * are plain input references and columns that are computed expressions.
     *
     * @param projExprs projection expressions
     * @param groupKey  projection output columns to classify
     * @param baseCols  receives the referenced input index of each plain input reference
     * @param projCols  receives the output index of each other expression
     */
    public static void splitCols(List<RexNode> projExprs, ImmutableBitSet groupKey, ImmutableBitSet.Builder baseCols, ImmutableBitSet.Builder projCols) {
        for (int bit : groupKey) {
            RexNode e = projExprs.get(bit);
            if (e instanceof RexInputRef) {
                baseCols.set(((RexInputRef)e).getIndex());
            } else {
                projCols.set(bit);
            }
        }
    }

    /**
     * Estimates the number of distinct values of {@code expr} evaluated over
     * the rows of {@code rel}.
     *
     * @return the estimate, or {@code null} if unknown
     */
    public static Double cardOfProjExpr(RelMetadataQuery mq, Project rel, RexNode expr) {
        return expr.accept(new CardOfProjExpr(mq, rel));
    }

    /**
     * Estimates the population size of {@code groupKey} over a two-input join.
     *
     * <p>The product of the left and right population sizes is passed through
     * {@link #numDistinctVals} together with the join's row count.
     *
     * @return the estimate, or {@code null} if any input estimate is unknown
     */
    public static Double getJoinPopulationSize(RelMetadataQuery mq, RelNode joinRel, ImmutableBitSet groupKey) {
        ImmutableBitSet.Builder leftMask = ImmutableBitSet.builder();
        ImmutableBitSet.Builder rightMask = ImmutableBitSet.builder();
        RelNode left = joinRel.getInputs().get(0);
        RelNode right = joinRel.getInputs().get(1);
        RelMdUtil.setLeftRightBitmaps(groupKey, leftMask, rightMask, left.getRowType().getFieldCount());
        Double population = NumberUtil.multiply(mq.getPopulationSize(left, leftMask.build()), mq.getPopulationSize(right, rightMask.build()));
        return RelMdUtil.numDistinctVals(population, mq.getRowCount(joinRel));
    }

    /**
     * Estimates the distinct row count of {@code groupKey} over a two-input
     * join, optionally filtered by {@code predicate}.
     *
     * <p>The predicate's conjuncts are classified into left-only and right-only
     * filters, which are pushed into each side's distinct-count estimate. The
     * two side estimates are combined by product, or by maximum when
     * {@code useMaxNdv} is set. The combined value is then passed through
     * {@link #numDistinctVals} with the join's row count.
     *
     * @return the estimate, or {@code null} if either side's estimate is unknown
     */
    public static Double getJoinDistinctRowCount(RelMetadataQuery mq, RelNode joinRel, JoinRelType joinType, ImmutableBitSet groupKey, RexNode predicate, boolean useMaxNdv) {
        ImmutableBitSet.Builder leftMask = ImmutableBitSet.builder();
        ImmutableBitSet.Builder rightMask = ImmutableBitSet.builder();
        RelNode left = joinRel.getInputs().get(0);
        RelNode right = joinRel.getInputs().get(1);
        RelMdUtil.setLeftRightBitmaps(groupKey, leftMask, rightMask, left.getRowType().getFieldCount());
        RexNode leftPred = null;
        RexNode rightPred = null;
        if (predicate != null) {
            ArrayList<RexNode> leftFilters = new ArrayList<RexNode>();
            ArrayList<RexNode> rightFilters = new ArrayList<RexNode>();
            ArrayList<RexNode> joinFilters = new ArrayList<RexNode>();
            List<RexNode> predList = RelOptUtil.conjunctions(predicate);
            RelOptUtil.classifyFilters(joinRel, predList, joinType, joinType == JoinRelType.INNER, !joinType.generatesNullsOnLeft(), !joinType.generatesNullsOnRight(), joinFilters, leftFilters, rightFilters);
            RexBuilder rexBuilder = joinRel.getCluster().getRexBuilder();
            leftPred = RexUtil.composeConjunction(rexBuilder, leftFilters, true);
            rightPred = RexUtil.composeConjunction(rexBuilder, rightFilters, true);
        }
        Double leftNdv = mq.getDistinctRowCount(left, leftMask.build(), leftPred);
        Double rightNdv = mq.getDistinctRowCount(right, rightMask.build(), rightPred);
        Double distRowCount;
        if (useMaxNdv) {
            // Previously Math.max auto-unboxed both values and threw an NPE when
            // either side was unknown; treat that as unknown, like multiply does.
            distRowCount = leftNdv == null || rightNdv == null ? null : Double.valueOf(Math.max(leftNdv, rightNdv));
        } else {
            distRowCount = NumberUtil.multiply(leftNdv, rightNdv);
        }
        return RelMdUtil.numDistinctVals(distRowCount, mq.getRowCount(joinRel));
    }

    /**
     * Returns the sum of the row counts of all inputs of a union.
     * This assumes every input's row count is non-null.
     */
    public static double getUnionAllRowCount(RelMetadataQuery mq, Union rel) {
        double rowCount = 0.0;
        for (RelNode input : rel.getInputs()) {
            rowCount += mq.getRowCount(input).doubleValue();
        }
        return rowCount;
    }

    /**
     * Estimates the row count of a MINUS.
     *
     * <p>The estimate is the first input's rows minus half the rows of each
     * subsequent input, floored at 0. This assumes every input's row count is
     * non-null.
     */
    public static double getMinusRowCount(RelMetadataQuery mq, Minus minus) {
        List<RelNode> inputs = minus.getInputs();
        double dRows = mq.getRowCount(inputs.get(0));
        for (int i = 1; i < inputs.size(); ++i) {
            dRows -= 0.5 * mq.getRowCount(inputs.get(i));
        }
        return dRows < 0.0 ? 0.0 : dRows;
    }

    /**
     * Estimates a join's row count as
     * {@code rows(left) * rows(right) * selectivity(condition)}.
     *
     * <p>If either input has at most one row and the join's max row count is
     * known to be at most 1, that max is returned instead.
     *
     * @return the estimate, or {@code null} if an input row count or the
     *         selectivity is unknown
     */
    public static Double getJoinRowCount(RelMetadataQuery mq, Join join, RexNode condition) {
        Double left = mq.getRowCount(join.getLeft());
        Double right = mq.getRowCount(join.getRight());
        if (left == null || right == null) {
            return null;
        }
        if (left <= 1.0 || right <= 1.0) {
            Double max = mq.getMaxRowCount(join);
            if (max != null && max <= 1.0) {
                return max;
            }
        }
        Double selectivity = mq.getSelectivity(join, condition);
        if (selectivity == null) {
            // Unknown selectivity: report an unknown row count rather than NPE on unboxing.
            return null;
        }
        double product = left * right;
        return product * selectivity;
    }

    /**
     * Estimates a semi-join's row count as
     * {@code rows(left) * RexUtil.getSelectivity(condition)}.
     *
     * <p>{@code right} and {@code joinType} are not used by this estimate.
     *
     * @return the estimate, or {@code null} if the left row count is unknown
     */
    public static Double getSemiJoinRowCount(RelMetadataQuery mq, RelNode left, RelNode right, JoinRelType joinType, RexNode condition) {
        Double leftCount = mq.getRowCount(left);
        if (leftCount == null) {
            return null;
        }
        return leftCount * RexUtil.getSelectivity(condition);
    }

    /**
     * Estimates rows surviving a program's condition, after expanding that
     * condition's local references.
     *
     * @see #estimateFilteredRows(RelNode, RexNode, RelMetadataQuery)
     */
    public static double estimateFilteredRows(RelNode child, RexProgram program, RelMetadataQuery mq) {
        RexLocalRef programCondition = program.getCondition();
        RexNode condition = programCondition == null ? null : program.expandLocalRef(programCondition);
        return RelMdUtil.estimateFilteredRows(child, condition, mq);
    }

    /**
     * Returns {@code rows(child) * selectivity(child, condition)}.
     * This assumes both metadata values are non-null.
     */
    public static double estimateFilteredRows(RelNode child, RexNode condition, RelMetadataQuery mq) {
        return mq.getRowCount(child) * mq.getSelectivity(child, condition);
    }

    /**
     * Reports whether a sort/limit over {@code input} is redundant.
     *
     * <p>That is the case when {@code input} already has a collation satisfying
     * {@code collation} and its max row count does not exceed
     * {@code offset + fetch}. If the max row count or {@code fetch} is unknown,
     * the size condition is treated as satisfied.
     *
     * @param offset literal row offset, or {@code null} for 0
     * @param fetch  literal row limit, or {@code null} for none
     */
    public static boolean checkInputForCollationAndLimit(RelMetadataQuery mq, RelNode input, RelCollation collation, RexNode offset, RexNode fetch) {
        boolean alreadySorted = false;
        for (RelCollation inputCollation : mq.collations(input)) {
            if (inputCollation.satisfies(collation)) {
                alreadySorted = true;
                break;
            }
        }
        boolean alreadySmaller = true;
        Double rowCount = mq.getMaxRowCount(input);
        if (rowCount != null && fetch != null) {
            int offsetVal = offset == null ? 0 : RexLiteral.intValue(offset);
            int limit = RexLiteral.intValue(fetch);
            if ((double)offsetVal + (double)limit < rowCount) {
                alreadySmaller = false;
            }
        }
        return alreadySorted && alreadySmaller;
    }

    /**
     * Visitor that estimates the number of distinct values of a projection
     * expression over a {@link Project}'s rows. A {@code null} result means
     * "unknown".
     */
    private static class CardOfProjExpr
    extends RexVisitorImpl<Double> {
        private final RelMetadataQuery mq;
        private final Project rel;

        public CardOfProjExpr(RelMetadataQuery mq, Project rel) {
            super(true);
            this.mq = mq;
            this.rel = rel;
        }

        /** Uses the input column's distinct count, bounded by the project's row count. */
        @Override
        public Double visitInputRef(RexInputRef var) {
            int index = var.getIndex();
            ImmutableBitSet col = ImmutableBitSet.of(index);
            Double distinctRowCount = this.mq.getDistinctRowCount(this.rel.getInput(), col, null);
            if (distinctRowCount == null) {
                return null;
            }
            return RelMdUtil.numDistinctVals(distinctRowCount, this.mq.getRowCount(this.rel));
        }

        /** A literal has a single distinct value. */
        @Override
        public Double visitLiteral(RexLiteral literal) {
            return RelMdUtil.numDistinctVals(1.0, this.mq.getRowCount(this.rel));
        }

        /**
         * Combines operand estimates per operator.
         *
         * <ul>
         *   <li>unary minus: its operand's estimate</li>
         *   <li>plus/minus: the maximum of both operands</li>
         *   <li>times/divide: the product of both operands</li>
         *   <li>other single-operand calls: the operand's estimate</li>
         *   <li>otherwise: one tenth of the row count</li>
         * </ul>
         *
         * The result is bounded by the row count through
         * {@link RelMdUtil#numDistinctVals}.
         */
        @Override
        public Double visitCall(RexCall call) {
            Double rowCount = this.mq.getRowCount(this.rel);
            Double distinctRowCount;
            if (call.isA(SqlKind.MINUS_PREFIX)) {
                distinctRowCount = RelMdUtil.cardOfProjExpr(this.mq, this.rel, call.getOperands().get(0));
            } else if (call.isA(ImmutableList.of(SqlKind.PLUS, SqlKind.MINUS))) {
                Double card0 = RelMdUtil.cardOfProjExpr(this.mq, this.rel, call.getOperands().get(0));
                if (card0 == null) {
                    return null;
                }
                Double card1 = RelMdUtil.cardOfProjExpr(this.mq, this.rel, call.getOperands().get(1));
                if (card1 == null) {
                    return null;
                }
                distinctRowCount = Math.max(card0, card1);
            } else if (call.isA(ImmutableList.of(SqlKind.TIMES, SqlKind.DIVIDE))) {
                distinctRowCount = NumberUtil.multiply(
                        RelMdUtil.cardOfProjExpr(this.mq, this.rel, call.getOperands().get(0)),
                        RelMdUtil.cardOfProjExpr(this.mq, this.rel, call.getOperands().get(1)));
            } else if (call.getOperands().size() == 1) {
                distinctRowCount = RelMdUtil.cardOfProjExpr(this.mq, this.rel, call.getOperands().get(0));
            } else {
                // Previously unboxed a possibly-null row count; propagate "unknown" instead.
                distinctRowCount = rowCount == null ? null : Double.valueOf(rowCount / 10.0);
            }
            return RelMdUtil.numDistinctVals(distinctRowCount, rowCount);
        }
    }
}

