package org.eigenbase.rel.rules;

import com.google.common.collect.ImmutableList;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.eigenbase.rel.AggregateCall;
import org.eigenbase.rel.AggregateRel;
import org.eigenbase.rel.AggregateRelBase;
import org.eigenbase.rel.CalcRel;
import org.eigenbase.rel.RelNode;
import org.eigenbase.relopt.RelOptRule;
import org.eigenbase.relopt.RelOptRuleCall;
import org.eigenbase.relopt.RelOptRuleOperand;
import org.eigenbase.reltype.RelDataType;
import org.eigenbase.reltype.RelDataTypeFactory;
import org.eigenbase.reltype.RelDataTypeField;
import org.eigenbase.rex.RexBuilder;
import org.eigenbase.rex.RexLiteral;
import org.eigenbase.rex.RexNode;
import org.eigenbase.sql.SqlAggFunction;
import org.eigenbase.sql.fun.SqlAvgAggFunction;
import org.eigenbase.sql.fun.SqlStdOperatorTable;
import org.eigenbase.sql.fun.SqlSumAggFunction;
import org.eigenbase.sql.fun.SqlSumEmptyIsZeroAggFunction;
import org.eigenbase.sql.type.SqlTypeUtil;
import org.eigenbase.util.CompositeList;
import org.eigenbase.util.ImmutableIntList;
import org.eigenbase.util.Util;

/* loaded from: input_file:org/eigenbase/rel/rules/ReduceAggregatesRule.class */
/**
 * Planner rule that rewrites calls to SUM, AVG, STDDEV_POP, STDDEV_SAMP,
 * VAR_POP and VAR_SAMP in an aggregate into expressions over simpler
 * aggregate calls (SUM, SUM-empty-is-zero, COUNT), computed by a new
 * aggregate with a projection on top.
 *
 * <p>The rewrites performed are:
 * <ul>
 * <li>{@code SUM(x)} becomes the empty-is-zero sum of {@code x}, wrapped in
 *     {@code CASE WHEN COUNT(x) = 0 THEN NULL ELSE ... END} when the original
 *     call is nullable;
 * <li>{@code AVG(x)} becomes {@code SUM(x) / COUNT(x)};
 * <li>the variance and standard deviation functions become
 *     {@code (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x)) / d}, where {@code d}
 *     is {@code COUNT(x)} for the {@code _POP} variants and
 *     {@code CASE WHEN COUNT(x) = 1 THEN NULL ELSE COUNT(x) - 1 END} for the
 *     {@code _SAMP} variants, raised to the power 0.5 for the
 *     {@code STDDEV_} variants.
 * </ul>
 */
public class ReduceAggregatesRule extends RelOptRule {
    /** Shared instance of the rule; its operand matches any {@link AggregateRel}. */
    public static final ReduceAggregatesRule INSTANCE =
            new ReduceAggregatesRule(operand(AggregateRel.class, any()));

    /**
     * Creates the rule.
     *
     * @param operand root operand the rule is matched against
     */
    protected ReduceAggregatesRule(RelOptRuleOperand operand) {
        super(operand);
    }

    /**
     * Accepts the call only if the base rule accepts it and the matched
     * aggregate contains at least one call this rule can rewrite.
     */
    @Override
    public boolean matches(RelOptRuleCall call) {
        if (!super.matches(call)) {
            return false;
        }
        final AggregateRelBase oldAggRel = (AggregateRelBase) call.rels[0];
        return containsAvgStddevVarCall(oldAggRel.getAggCallList());
    }

    /** Rewrites the matched aggregate and registers the equivalent expression. */
    @Override
    public void onMatch(RelOptRuleCall ruleCall) {
        final AggregateRelBase oldAggRel = (AggregateRelBase) ruleCall.rels[0];
        reduceAggs(ruleCall, oldAggRel);
    }

    /**
     * Returns whether any call in the list uses an aggregation that is a
     * {@link SqlAvgAggFunction} or a {@link SqlSumAggFunction}.
     *
     * @param aggCallList aggregate calls to inspect
     * @return {@code true} if at least one call is reducible by this rule
     */
    private boolean containsAvgStddevVarCall(List<AggregateCall> aggCallList) {
        for (AggregateCall call : aggCallList) {
            final Object aggregation = call.getAggregation();
            if (aggregation instanceof SqlAvgAggFunction
                    || aggregation instanceof SqlSumAggFunction) {
                return true;
            }
        }
        return false;
    }

    /**
     * Builds the replacement for {@code oldAggRel}: an optional projection
     * below (only when a rewrite needed extra input expressions), a new
     * aggregate holding the simplified calls, and a projection on top that
     * reproduces the original output fields.
     *
     * @param ruleCall  rule call used to register the transformed expression
     * @param oldAggRel aggregate being rewritten
     */
    private void reduceAggs(RelOptRuleCall ruleCall, AggregateRelBase oldAggRel) {
        final RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
        final List<AggregateCall> oldCalls = oldAggRel.getAggCallList();
        final int groupCount = oldAggRel.getGroupCount();

        final List<AggregateCall> newCalls = new ArrayList<AggregateCall>();
        final Map<AggregateCall, RexNode> aggCallMapping =
                new HashMap<AggregateCall, RexNode>();
        final List<RexNode> projects = new ArrayList<RexNode>();

        // Group keys pass straight through at the same ordinals.
        for (int i = 0; i < groupCount; i++) {
            projects.add(rexBuilder.makeInputRef(getFieldType(oldAggRel, i), i));
        }

        // Start with one reference per input field; a rewrite that needs a
        // computed argument (e.g. x * x) appends it to the end of this list.
        RelNode input = oldAggRel.getChild();
        final List<RexNode> inputExprs = new ArrayList<RexNode>();
        for (RelDataTypeField field : input.getRowType().getFieldList()) {
            inputExprs.add(
                    rexBuilder.makeInputRef(field.getType(), inputExprs.size()));
        }

        // Each old call contributes one output expression and zero or more
        // new aggregate calls.
        for (AggregateCall oldCall : oldCalls) {
            projects.add(
                    reduceAgg(oldAggRel, oldCall, newCalls, aggCallMapping, inputExprs));
        }

        // If any expressions were appended, project them below the new
        // aggregate. The appended columns are given null names.
        final int extraExprCount =
                inputExprs.size() - input.getRowType().getFieldCount();
        if (extraExprCount > 0) {
            input = CalcRel.createProject(
                    input,
                    inputExprs,
                    CompositeList.of(
                            input.getRowType().getFieldNames(),
                            Collections.<String>nCopies(extraExprCount, null)));
        }

        final AggregateRelBase newAggRel = newAggregateRel(oldAggRel, input, newCalls);
        final RelNode project = CalcRel.createProject(
                newAggRel, projects, oldAggRel.getRowType().getFieldNames());
        ruleCall.transformTo(project);
    }

    /**
     * Produces the output expression for a single aggregate call, adding
     * whatever new aggregate calls it needs to {@code newCalls}.
     *
     * @param oldAggRel      aggregate being rewritten
     * @param oldCall        call to translate
     * @param newCalls       calls of the new aggregate; appended to
     * @param aggCallMapping map from new aggregate calls to the expressions
     *                       that reference them; passed to
     *                       {@code RexBuilder.addAggCall}
     * @param inputExprs     expressions feeding the new aggregate; may be
     *                       appended to
     * @return expression, over the new aggregate's output, equivalent to
     *         {@code oldCall}
     */
    private RexNode reduceAgg(
            AggregateRelBase oldAggRel,
            AggregateCall oldCall,
            List<AggregateCall> newCalls,
            Map<AggregateCall, RexNode> aggCallMapping,
            List<RexNode> inputExprs) {
        if (oldCall.getAggregation() instanceof SqlSumAggFunction) {
            return reduceSum(oldAggRel, oldCall, newCalls, aggCallMapping);
        }
        if (oldCall.getAggregation() instanceof SqlAvgAggFunction) {
            final SqlAvgAggFunction.Subtype subtype =
                    ((SqlAvgAggFunction) oldCall.getAggregation()).getSubtype();
            // reduceStddev flags are (biased, sqrt).
            switch (subtype) {
                case AVG:
                    return reduceAvg(oldAggRel, oldCall, newCalls, aggCallMapping);
                case STDDEV_POP:
                    return reduceStddev(
                            oldAggRel, oldCall, true, true, newCalls, aggCallMapping, inputExprs);
                case STDDEV_SAMP:
                    return reduceStddev(
                            oldAggRel, oldCall, false, true, newCalls, aggCallMapping, inputExprs);
                case VAR_POP:
                    return reduceStddev(
                            oldAggRel, oldCall, true, false, newCalls, aggCallMapping, inputExprs);
                case VAR_SAMP:
                    return reduceStddev(
                            oldAggRel, oldCall, false, false, newCalls, aggCallMapping, inputExprs);
                default:
                    throw Util.unexpected(subtype);
            }
        }

        // Anything else is preserved as-is. Argument ordinals index the
        // aggregate's input (reduceAvg/reduceSum/reduceStddev resolve them
        // against getChild()), so the argument types must be projected from
        // the child's row type, not from the aggregate's own output row type.
        final RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
        final List<RelDataType> argTypes = SqlTypeUtil.projectTypes(
                oldAggRel.getChild().getRowType(), oldCall.getArgList());
        return rexBuilder.addAggCall(
                oldCall, oldAggRel.getGroupCount(), newCalls, aggCallMapping, argTypes);
    }

    /**
     * Rewrites {@code AVG(x)} as {@code CAST(SUM(x) / COUNT(x) AS <type of AVG(x)>)}.
     *
     * @param oldAggRel      aggregate being rewritten
     * @param oldCall        the AVG call
     * @param newCalls       calls of the new aggregate; appended to
     * @param aggCallMapping map from new aggregate calls to their references
     * @return the replacement expression
     */
    private RexNode reduceAvg(
            AggregateRelBase oldAggRel,
            AggregateCall oldCall,
            List<AggregateCall> newCalls,
            Map<AggregateCall, RexNode> aggCallMapping) {
        final int groupCount = oldAggRel.getGroupCount();
        final RelDataTypeFactory typeFactory = oldAggRel.getCluster().getTypeFactory();
        final RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();

        final int argOrdinal = oldCall.getArgList().get(0);
        final RelDataType argType = getFieldType(oldAggRel.getChild(), argOrdinal);

        // The sum is nullable when the argument is nullable or there is no
        // group key (presumably because an ungrouped aggregate can see an
        // empty input).
        final RelDataType sumType = typeFactory.createTypeWithNullability(
                argType, argType.isNullable() || groupCount == 0);
        final AggregateCall sumCall = new AggregateCall(
                new SqlSumAggFunction(sumType),
                oldCall.isDistinct(),
                oldCall.getArgList(),
                sumType,
                null);

        final SqlAggFunction countAgg = SqlStdOperatorTable.COUNT;
        final AggregateCall countCall = new AggregateCall(
                countAgg,
                oldCall.isDistinct(),
                oldCall.getArgList(),
                countAgg.getReturnType(typeFactory),
                null);

        final List<RelDataType> argTypes = ImmutableList.of(argType);
        final RexNode sumRef = rexBuilder.addAggCall(
                sumCall, groupCount, newCalls, aggCallMapping, argTypes);
        final RexNode countRef = rexBuilder.addAggCall(
                countCall, groupCount, newCalls, aggCallMapping, argTypes);

        final RexNode quotient =
                rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, sumRef, countRef);
        return rexBuilder.makeCast(oldCall.getType(), quotient);
    }

    /**
     * Rewrites {@code SUM(x)} in terms of the empty-is-zero sum. If the
     * original call's type is not nullable the empty-is-zero sum is returned
     * directly; otherwise the result is
     * {@code CASE WHEN COUNT(x) = 0 THEN NULL ELSE <sum> END}.
     *
     * @param oldAggRel      aggregate being rewritten
     * @param oldCall        the SUM call
     * @param newCalls       calls of the new aggregate; appended to
     * @param aggCallMapping map from new aggregate calls to their references
     * @return the replacement expression
     */
    private RexNode reduceSum(
            AggregateRelBase oldAggRel,
            AggregateCall oldCall,
            List<AggregateCall> newCalls,
            Map<AggregateCall, RexNode> aggCallMapping) {
        final int groupCount = oldAggRel.getGroupCount();
        final RelDataTypeFactory typeFactory = oldAggRel.getCluster().getTypeFactory();
        final RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();

        final int argOrdinal = oldCall.getArgList().get(0);
        final RelDataType argType = getFieldType(oldAggRel.getChild(), argOrdinal);
        final RelDataType sumType =
                typeFactory.createTypeWithNullability(argType, argType.isNullable());

        final AggregateCall sumZeroCall = new AggregateCall(
                new SqlSumEmptyIsZeroAggFunction(sumType),
                oldCall.isDistinct(),
                oldCall.getArgList(),
                sumType,
                null);

        final SqlAggFunction countAgg = SqlStdOperatorTable.COUNT;
        final AggregateCall countCall = new AggregateCall(
                countAgg,
                oldCall.isDistinct(),
                oldCall.getArgList(),
                countAgg.getReturnType(typeFactory),
                null);

        final List<RelDataType> argTypes = ImmutableList.of(argType);
        final RexNode sumZeroRef = rexBuilder.addAggCall(
                sumZeroCall, groupCount, newCalls, aggCallMapping, argTypes);
        if (!oldCall.getType().isNullable()) {
            // The original SUM can never be null, so no COUNT guard (and no
            // COUNT call in the new aggregate) is needed.
            return sumZeroRef;
        }

        final RexNode countRef = rexBuilder.addAggCall(
                countCall, groupCount, newCalls, aggCallMapping, argTypes);
        final RexNode countIsZero = rexBuilder.makeCall(
                SqlStdOperatorTable.EQUALS,
                countRef,
                rexBuilder.makeExactLiteral(BigDecimal.ZERO));
        return rexBuilder.makeCall(
                SqlStdOperatorTable.CASE,
                countIsZero,
                rexBuilder.constantNull(),
                sumZeroRef);
    }

    /**
     * Rewrites a variance or standard-deviation call as
     * {@code (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x)) / d}, optionally raised
     * to the power 0.5, and cast to the original call's type.
     *
     * @param oldAggRel      aggregate being rewritten
     * @param oldCall        the call to rewrite; must have exactly one argument
     * @param z              biased: if {@code true} the divisor {@code d} is
     *                       {@code COUNT(x)}; otherwise it is
     *                       {@code CASE WHEN COUNT(x) = 1 THEN NULL
     *                       ELSE COUNT(x) - 1 END}
     * @param z2             sqrt: if {@code true} the result is raised to the
     *                       power 0.5
     * @param newCalls       calls of the new aggregate; appended to
     * @param aggCallMapping map from new aggregate calls to their references
     * @param inputExprs     expressions feeding the new aggregate; {@code x * x}
     *                       is appended if not already present
     * @return the replacement expression
     */
    private RexNode reduceStddev(
            AggregateRelBase oldAggRel,
            AggregateCall oldCall,
            boolean z,
            boolean z2,
            List<AggregateCall> newCalls,
            Map<AggregateCall, RexNode> aggCallMapping,
            List<RexNode> inputExprs) {
        final int groupCount = oldAggRel.getGroupCount();
        final RelDataTypeFactory typeFactory = oldAggRel.getCluster().getTypeFactory();
        final RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();

        assert oldCall.getArgList().size() == 1 : oldCall.getArgList();
        final int argOrdinal = oldCall.getArgList().get(0);
        final RelDataType argType = getFieldType(oldAggRel.getChild(), argOrdinal);

        // x * x must be computed below the aggregate; reuse it if some other
        // call already added the same expression.
        final RexNode argRef = inputExprs.get(argOrdinal);
        final RexNode argSquared =
                rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, argRef, argRef);
        final int argSquaredOrdinal = lookupOrAdd(inputExprs, argSquared);

        final RelDataType sumType = typeFactory.createTypeWithNullability(argType, true);
        final List<RelDataType> argTypes = ImmutableList.of(argType);

        // SUM(x * x)
        final AggregateCall sumArgSquaredCall = new AggregateCall(
                new SqlSumAggFunction(sumType),
                oldCall.isDistinct(),
                ImmutableIntList.of(argSquaredOrdinal),
                sumType,
                null);
        final RexNode sumArgSquared = rexBuilder.addAggCall(
                sumArgSquaredCall, groupCount, newCalls, aggCallMapping, argTypes);

        // SUM(x), then SUM(x) * SUM(x)
        final AggregateCall sumArgCall = new AggregateCall(
                new SqlSumAggFunction(sumType),
                oldCall.isDistinct(),
                ImmutableIntList.of(argOrdinal),
                sumType,
                null);
        final RexNode sumArg = rexBuilder.addAggCall(
                sumArgCall, groupCount, newCalls, aggCallMapping, argTypes);
        final RexNode sumSquared =
                rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, sumArg, sumArg);

        // COUNT(x)
        final SqlAggFunction countAgg = SqlStdOperatorTable.COUNT;
        final AggregateCall countCall = new AggregateCall(
                countAgg,
                oldCall.isDistinct(),
                oldCall.getArgList(),
                countAgg.getReturnType(typeFactory),
                null);
        final RexNode count = rexBuilder.addAggCall(
                countCall, groupCount, newCalls, aggCallMapping, argTypes);

        // SUM(x * x) - SUM(x) * SUM(x) / COUNT(x)
        final RexNode avgSumSquared =
                rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, sumSquared, count);
        final RexNode numerator =
                rexBuilder.makeCall(SqlStdOperatorTable.MINUS, sumArgSquared, avgSumSquared);

        final RexNode denominator;
        if (z) {
            denominator = count;
        } else {
            // COUNT(x) - 1, replaced by NULL when COUNT(x) = 1 so the
            // division never has a zero divisor.
            final RexLiteral one = rexBuilder.makeExactLiteral(BigDecimal.ONE);
            final RexNode nullLiteral =
                    rexBuilder.makeNullLiteral(count.getType().getSqlTypeName());
            final RexNode countMinusOne =
                    rexBuilder.makeCall(SqlStdOperatorTable.MINUS, count, one);
            final RexNode countIsOne =
                    rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, count, one);
            denominator = rexBuilder.makeCall(
                    SqlStdOperatorTable.CASE, countIsOne, nullLiteral, countMinusOne);
        }

        RexNode result =
                rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, numerator, denominator);
        if (z2) {
            final RexNode half = rexBuilder.makeExactLiteral(new BigDecimal("0.5"));
            result = rexBuilder.makeCall(SqlStdOperatorTable.POWER, result, half);
        }
        return rexBuilder.makeCast(oldCall.getType(), result);
    }

    /**
     * Returns the index of {@code element} in {@code list}, appending it first
     * if the list does not already contain an equal element.
     *
     * @param list    list to search and possibly extend
     * @param element element to find or add
     * @return index of the existing or newly added element
     */
    private static <T> int lookupOrAdd(List<T> list, T element) {
        final int existing = list.indexOf(element);
        if (existing >= 0) {
            return existing;
        }
        list.add(element);
        return list.size() - 1;
    }

    /**
     * Creates the aggregate that replaces {@code oldAggRel}. Subclasses may
     * override this to create a different kind of aggregate.
     *
     * @param oldAggRel aggregate being replaced; supplies cluster and group set
     * @param inputRel  input of the new aggregate
     * @param newCalls  aggregate calls of the new aggregate
     * @return the new aggregate
     */
    protected AggregateRelBase newAggregateRel(
            AggregateRelBase oldAggRel, RelNode inputRel, List<AggregateCall> newCalls) {
        return new AggregateRel(
                oldAggRel.getCluster(), inputRel, oldAggRel.getGroupSet(), newCalls);
    }

    /**
     * Returns the type of the {@code i}-th field of a relational expression's
     * row type.
     */
    private RelDataType getFieldType(RelNode relNode, int i) {
        final RelDataTypeField field = relNode.getRowType().getFieldList().get(i);
        return field.getType();
    }
}
