/*
 * Decompiled with CFR 0.152.
 */
package hive.org.apache.calcite.rel.rules;

import hive.com.google.common.collect.ImmutableList;
import hive.com.google.common.collect.Lists;
import hive.com.google.common.collect.Maps;
import hive.org.apache.calcite.plan.RelOptRule;
import hive.org.apache.calcite.plan.RelOptRuleCall;
import hive.org.apache.calcite.plan.RelOptRuleOperand;
import hive.org.apache.calcite.plan.RelOptUtil;
import hive.org.apache.calcite.rel.RelNode;
import hive.org.apache.calcite.rel.core.Aggregate;
import hive.org.apache.calcite.rel.core.AggregateCall;
import hive.org.apache.calcite.rel.logical.LogicalAggregate;
import hive.org.apache.calcite.rel.type.RelDataType;
import hive.org.apache.calcite.rel.type.RelDataTypeFactory;
import hive.org.apache.calcite.rel.type.RelDataTypeField;
import hive.org.apache.calcite.rex.RexBuilder;
import hive.org.apache.calcite.rex.RexLiteral;
import hive.org.apache.calcite.rex.RexNode;
import hive.org.apache.calcite.sql.SqlOperator;
import hive.org.apache.calcite.sql.fun.SqlAvgAggFunction;
import hive.org.apache.calcite.sql.fun.SqlStdOperatorTable;
import hive.org.apache.calcite.sql.fun.SqlSumAggFunction;
import hive.org.apache.calcite.sql.type.SqlTypeUtil;
import hive.org.apache.calcite.util.CompositeList;
import hive.org.apache.calcite.util.ImmutableIntList;
import hive.org.apache.calcite.util.Util;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Planner rule that reduces selected aggregate calls of an {@link Aggregate}
 * to simpler aggregate calls, and combines their results in a projection
 * placed on top of a newly created aggregate.
 *
 * <p>The rewrites performed are:
 * <ul>
 *   <li>{@code AVG(x)} becomes {@code CAST(SUM(x) / COUNT(x))}.</li>
 *   <li>{@code SUM(x)} becomes {@code SUM0(x)} when the call's result type is
 *       not nullable. Otherwise it becomes
 *       {@code CASE WHEN COUNT(x) = 0 THEN NULL ELSE SUM0(x) END}.</li>
 *   <li>{@code VAR_POP(x)} becomes
 *       {@code CAST((SUM(x * x) - SUM(x) * SUM(x) / COUNT(x)) / COUNT(x))}.</li>
 *   <li>{@code VAR_SAMP(x)} is computed like {@code VAR_POP}, except that the
 *       divisor is {@code CASE WHEN COUNT(x) = 1 THEN NULL ELSE COUNT(x) - 1 END}.</li>
 *   <li>{@code STDDEV_POP(x)} and {@code STDDEV_SAMP(x)} apply
 *       {@code POWER(variance, 0.5)} to the corresponding variance.</li>
 * </ul>
 *
 * <p>All other aggregate calls are re-registered unchanged. A rewrite may need
 * an expression the input does not produce yet (such as {@code x * x}). In
 * that case, the input is wrapped in a projection that appends it.
 */
public class AggregateReduceFunctionsRule
extends RelOptRule {
    /** Shared instance of the rule, operating on any {@link LogicalAggregate}. */
    public static final AggregateReduceFunctionsRule INSTANCE = new AggregateReduceFunctionsRule(AggregateReduceFunctionsRule.operand(LogicalAggregate.class, AggregateReduceFunctionsRule.any()));

    /**
     * Creates the rule.
     *
     * @param operand root operand; the matched relational expression is cast
     *     to {@link Aggregate}, so the operand must match an aggregate
     */
    protected AggregateReduceFunctionsRule(RelOptRuleOperand operand) {
        super(operand);
    }

    /**
     * Accepts the match only if the aggregate has at least one reducible call.
     * A call is reducible if its function is a {@link SqlAvgAggFunction}
     * (AVG, STDDEV_*, VAR_*) or a {@link SqlSumAggFunction}.
     */
    @Override
    public boolean matches(RelOptRuleCall call) {
        if (!super.matches(call)) {
            return false;
        }
        Aggregate oldAggRel = (Aggregate) call.rels[0];
        return containsAvgStddevVarCall(oldAggRel.getAggCallList());
    }

    /** Replaces the matched aggregate with its reduced equivalent. */
    @Override
    public void onMatch(RelOptRuleCall ruleCall) {
        Aggregate oldAggRel = (Aggregate) ruleCall.rels[0];
        reduceAggs(ruleCall, oldAggRel);
    }

    /**
     * Returns whether any call in the list uses an AVG-family or SUM aggregate
     * function. Despite its name, this method also returns true for SUM.
     */
    private boolean containsAvgStddevVarCall(List<AggregateCall> aggCallList) {
        for (AggregateCall call : aggCallList) {
            if (call.getAggregation() instanceof SqlAvgAggFunction
                    || call.getAggregation() instanceof SqlSumAggFunction) {
                return true;
            }
        }
        return false;
    }

    /**
     * Builds the reduced plan and registers it with {@code ruleCall}. The plan
     * is a projection over a new aggregate, optionally over an extended input.
     * The projection keeps the original aggregate's field names.
     */
    private void reduceAggs(RelOptRuleCall ruleCall, Aggregate oldAggRel) {
        RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
        List<AggregateCall> oldCalls = oldAggRel.getAggCallList();
        int groupCount = oldAggRel.getGroupCount();
        int indicatorCount = oldAggRel.getIndicatorCount();

        // Aggregate calls of the new aggregate, and a map from each call to the
        // expression that references its output.
        List<AggregateCall> newCalls = Lists.newArrayList();
        Map<AggregateCall, RexNode> aggCallMapping = Maps.newHashMap();

        // Group keys and indicator columns pass through the projection unchanged.
        List<RexNode> projList = Lists.newArrayList();
        for (int i = 0; i < groupCount + indicatorCount; ++i) {
            projList.add(rexBuilder.makeInputRef(getFieldType(oldAggRel, i), i));
        }

        // Starts as the identity projection of the input. The reductions may
        // append expressions to it (for example, x * x for variance).
        RelNode input = oldAggRel.getInput();
        List<RexNode> inputExprs = new ArrayList<RexNode>(rexBuilder.identityProjects(input.getRowType()));

        for (AggregateCall oldCall : oldCalls) {
            projList.add(reduceAgg(oldAggRel, oldCall, newCalls, aggCallMapping, inputExprs));
        }

        // If expressions were appended, put a projection on top of the input.
        // It keeps the existing field names and uses null names for the new columns.
        int extraArgCount = inputExprs.size() - input.getRowType().getFieldCount();
        if (extraArgCount > 0) {
            input = RelOptUtil.createProject(
                input,
                inputExprs,
                CompositeList.of(
                    input.getRowType().getFieldNames(),
                    Collections.<String>nCopies(extraArgCount, null)));
        }

        Aggregate newAggRel = newAggregateRel(oldAggRel, input, newCalls);
        RelNode projectRel = RelOptUtil.createProject(newAggRel, projList, oldAggRel.getRowType().getFieldNames());
        ruleCall.transformTo(projectRel);
    }

    /**
     * Returns an expression, over the new aggregate's output, that computes
     * {@code oldCall}.
     *
     * <p>Any aggregate calls it needs are added to {@code newCalls}. Any input
     * expressions it needs are appended to {@code inputExprs}.
     *
     * @throws Error (via {@link Util#unexpected}) for an unknown AVG-family subtype
     */
    private RexNode reduceAgg(Aggregate oldAggRel, AggregateCall oldCall, List<AggregateCall> newCalls, Map<AggregateCall, RexNode> aggCallMapping, List<RexNode> inputExprs) {
        if (oldCall.getAggregation() instanceof SqlSumAggFunction) {
            return reduceSum(oldAggRel, oldCall, newCalls, aggCallMapping);
        }
        if (oldCall.getAggregation() instanceof SqlAvgAggFunction) {
            SqlAvgAggFunction.Subtype subtype = ((SqlAvgAggFunction) oldCall.getAggregation()).getSubtype();
            // reduceStddev flags: (biased = divide by COUNT rather than COUNT - 1,
            //                      sqrt   = take the square root of the variance)
            switch (subtype) {
                case AVG:
                    return reduceAvg(oldAggRel, oldCall, newCalls, aggCallMapping);
                case STDDEV_POP:
                    return reduceStddev(oldAggRel, oldCall, true, true, newCalls, aggCallMapping, inputExprs);
                case STDDEV_SAMP:
                    return reduceStddev(oldAggRel, oldCall, false, true, newCalls, aggCallMapping, inputExprs);
                case VAR_POP:
                    return reduceStddev(oldAggRel, oldCall, true, false, newCalls, aggCallMapping, inputExprs);
                case VAR_SAMP:
                    return reduceStddev(oldAggRel, oldCall, false, false, newCalls, aggCallMapping, inputExprs);
                default:
                    throw Util.unexpected(subtype);
            }
        }
        // Not reducible: register the call unchanged with the new aggregate.
        RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
        int nGroups = oldAggRel.getGroupCount();
        List<RelDataType> oldArgTypes = SqlTypeUtil.projectTypes(oldAggRel.getInput().getRowType(), oldCall.getArgList());
        return rexBuilder.addAggCall(oldCall, nGroups, oldAggRel.indicator, newCalls, aggCallMapping, oldArgTypes);
    }

    /** Rewrites {@code AVG(x)} as {@code CAST(SUM(x) / COUNT(x))}. */
    private RexNode reduceAvg(Aggregate oldAggRel, AggregateCall oldCall, List<AggregateCall> newCalls, Map<AggregateCall, RexNode> aggCallMapping) {
        int nGroups = oldAggRel.getGroupCount();
        RelDataTypeFactory typeFactory = oldAggRel.getCluster().getTypeFactory();
        RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
        int iAvgInput = oldCall.getArgList().get(0);

        // Resolve the argument type before using it. The previous form read this
        // variable while it was still unassigned, which does not compile in Java.
        RelDataType avgInputType = getFieldType(oldAggRel.getInput(), iAvgInput);

        // SUM is nullable if its argument is nullable or there are no group keys.
        // Presumably this is because a grand-total SUM over zero rows is NULL.
        RelDataType sumType = typeFactory.createTypeWithNullability(avgInputType, avgInputType.isNullable() || nGroups == 0);
        SqlSumAggFunction sumAgg = new SqlSumAggFunction(sumType);
        AggregateCall sumCall = AggregateCall.create(sumAgg, oldCall.isDistinct(), oldCall.getArgList(), oldCall.filterArg, sumType, null);
        AggregateCall countCall = AggregateCall.create(SqlStdOperatorTable.COUNT, oldCall.isDistinct(), oldCall.getArgList(), oldCall.filterArg, nGroups, oldAggRel.getInput(), null, null);

        RexNode numeratorRef = rexBuilder.addAggCall(sumCall, nGroups, oldAggRel.indicator, newCalls, aggCallMapping, ImmutableList.of(avgInputType));
        RexNode denominatorRef = rexBuilder.addAggCall(countCall, nGroups, oldAggRel.indicator, newCalls, aggCallMapping, ImmutableList.of(avgInputType));
        RexNode divideRef = rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, numeratorRef, denominatorRef);
        return rexBuilder.makeCast(oldCall.getType(), divideRef);
    }

    /**
     * Rewrites {@code SUM(x)} in terms of {@code SUM0(x)}.
     *
     * <p>If the call's result type is not nullable, {@code SUM0(x)} is used
     * directly. Otherwise the result is
     * {@code CASE WHEN COUNT(x) = 0 THEN NULL ELSE SUM0(x) END}.
     */
    private RexNode reduceSum(Aggregate oldAggRel, AggregateCall oldCall, List<AggregateCall> newCalls, Map<AggregateCall, RexNode> aggCallMapping) {
        int nGroups = oldAggRel.getGroupCount();
        RelDataTypeFactory typeFactory = oldAggRel.getCluster().getTypeFactory();
        RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
        int arg = oldCall.getArgList().get(0);
        RelDataType argType = getFieldType(oldAggRel.getInput(), arg);
        RelDataType sumType = typeFactory.createTypeWithNullability(argType, argType.isNullable());

        AggregateCall sumZeroCall = AggregateCall.create(SqlStdOperatorTable.SUM0, oldCall.isDistinct(), oldCall.getArgList(), oldCall.filterArg, sumType, oldCall.name);
        RexNode sumZeroRef = rexBuilder.addAggCall(sumZeroCall, nGroups, oldAggRel.indicator, newCalls, aggCallMapping, ImmutableList.of(argType));
        if (!oldCall.getType().isNullable()) {
            return sumZeroRef;
        }

        // The argument ordinals index the aggregate's INPUT, so COUNT's type must
        // be derived from that input, as reduceAvg and reduceStddev do. Deriving
        // it from the aggregate itself would read the wrong row type.
        AggregateCall countCall = AggregateCall.create(SqlStdOperatorTable.COUNT, oldCall.isDistinct(), oldCall.getArgList(), oldCall.filterArg, nGroups, oldAggRel.getInput(), null, null);
        RexNode countRef = rexBuilder.addAggCall(countCall, nGroups, oldAggRel.indicator, newCalls, aggCallMapping, ImmutableList.of(argType));
        RexNode countIsZero = rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, countRef, rexBuilder.makeExactLiteral(BigDecimal.ZERO));
        return rexBuilder.makeCall(SqlStdOperatorTable.CASE, countIsZero, rexBuilder.constantNull(), sumZeroRef);
    }

    /**
     * Rewrites VAR_POP, VAR_SAMP, STDDEV_POP or STDDEV_SAMP of {@code x}
     * using SUM(x * x), SUM(x) and COUNT(x).
     *
     * <p>The result is
     * {@code (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x)) / d}, cast to the
     * original call's type. The divisor {@code d} depends on {@code biased}:
     * <ul>
     *   <li>if {@code biased}, {@code d} is {@code COUNT(x)};</li>
     *   <li>otherwise {@code d} is
     *       {@code CASE WHEN COUNT(x) = 1 THEN NULL ELSE COUNT(x) - 1 END}.</li>
     * </ul>
     * If {@code sqrt} is true, {@code POWER(.., 0.5)} is applied before the cast.
     *
     * @param biased divide by COUNT (population) instead of COUNT - 1 (sample)
     * @param sqrt   whether to return the standard deviation instead of the variance
     * @param inputExprs input expressions; {@code x * x} is appended if absent
     */
    private RexNode reduceStddev(Aggregate oldAggRel, AggregateCall oldCall, boolean biased, boolean sqrt, List<AggregateCall> newCalls, Map<AggregateCall, RexNode> aggCallMapping, List<RexNode> inputExprs) {
        int nGroups = oldAggRel.getGroupCount();
        RelDataTypeFactory typeFactory = oldAggRel.getCluster().getTypeFactory();
        RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
        assert oldCall.getArgList().size() == 1 : oldCall.getArgList();
        int argOrdinal = oldCall.getArgList().get(0);
        RelDataType argType = getFieldType(oldAggRel.getInput(), argOrdinal);

        // x * x is computed below the aggregate, as an extra input column.
        RexNode argRef = inputExprs.get(argOrdinal);
        RexNode argSquared = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, argRef, argRef);
        int argSquaredOrdinal = lookupOrAdd(inputExprs, argSquared);

        RelDataType sumType = typeFactory.createTypeWithNullability(argType, true);

        // SUM(x * x)
        AggregateCall sumArgSquaredAggCall = AggregateCall.create(new SqlSumAggFunction(sumType), oldCall.isDistinct(), ImmutableIntList.of(argSquaredOrdinal), oldCall.filterArg, sumType, null);
        RexNode sumArgSquared = rexBuilder.addAggCall(sumArgSquaredAggCall, nGroups, oldAggRel.indicator, newCalls, aggCallMapping, ImmutableList.of(argType));

        // SUM(x), then SUM(x) * SUM(x)
        AggregateCall sumArgAggCall = AggregateCall.create(new SqlSumAggFunction(sumType), oldCall.isDistinct(), ImmutableIntList.of(argOrdinal), oldCall.filterArg, sumType, null);
        RexNode sumArg = rexBuilder.addAggCall(sumArgAggCall, nGroups, oldAggRel.indicator, newCalls, aggCallMapping, ImmutableList.of(argType));
        RexNode sumSquaredArg = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, sumArg, sumArg);

        // COUNT(x)
        AggregateCall countArgAggCall = AggregateCall.create(SqlStdOperatorTable.COUNT, oldCall.isDistinct(), oldCall.getArgList(), oldCall.filterArg, nGroups, oldAggRel.getInput(), null, null);
        RexNode countArg = rexBuilder.addAggCall(countArgAggCall, nGroups, oldAggRel.indicator, newCalls, aggCallMapping, ImmutableList.of(argType));

        // SUM(x * x) - SUM(x) * SUM(x) / COUNT(x)
        RexNode avgSumSquaredArg = rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, sumSquaredArg, countArg);
        RexNode diff = rexBuilder.makeCall(SqlStdOperatorTable.MINUS, sumArgSquared, avgSumSquaredArg);

        RexNode denominator;
        if (biased) {
            denominator = countArg;
        } else {
            // Sample divisor COUNT - 1; a single row yields NULL rather than 0.
            RexLiteral one = rexBuilder.makeExactLiteral(BigDecimal.ONE);
            RexNode nul = rexBuilder.makeNullLiteral(countArg.getType().getSqlTypeName());
            RexNode countMinusOne = rexBuilder.makeCall(SqlStdOperatorTable.MINUS, countArg, one);
            RexNode countEqOne = rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, countArg, one);
            denominator = rexBuilder.makeCall(SqlStdOperatorTable.CASE, countEqOne, nul, countMinusOne);
        }

        RexNode variance = rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, diff, denominator);
        RexNode result = variance;
        if (sqrt) {
            RexLiteral half = rexBuilder.makeExactLiteral(new BigDecimal("0.5"));
            result = rexBuilder.makeCall(SqlStdOperatorTable.POWER, variance, half);
        }
        return rexBuilder.makeCast(oldCall.getType(), result);
    }

    /**
     * Returns the index of {@code element} in {@code list}. If the element is
     * not present (per {@code equals}), it is appended first.
     */
    private static <T> int lookupOrAdd(List<T> list, T element) {
        int ordinal = list.indexOf(element);
        if (ordinal == -1) {
            ordinal = list.size();
            list.add(element);
        }
        return ordinal;
    }

    /**
     * Creates the replacement aggregate over {@code input}. It uses the old
     * aggregate's indicator flag, group set and grouping sets, together with
     * {@code newCalls}. Subclasses may override this to produce a different
     * {@link Aggregate} implementation.
     */
    protected Aggregate newAggregateRel(Aggregate oldAggregate, RelNode input, List<AggregateCall> newCalls) {
        return LogicalAggregate.create(input, oldAggregate.indicator, oldAggregate.getGroupSet(), oldAggregate.getGroupSets(), newCalls);
    }

    /** Returns the type of the {@code i}-th field of {@code relNode}'s row type. */
    private RelDataType getFieldType(RelNode relNode, int i) {
        RelDataTypeField inputField = relNode.getRowType().getFieldList().get(i);
        return inputField.getType();
    }
}

