/*
 * Decompiled with CFR 0.152.
 */
package org.apache.calcite.rel.rules;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
import org.apache.calcite.plan.Contexts;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.tools.RelBuilderFactory;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.ImmutableIntList;
import org.apache.calcite.util.Pair;
import org.apache.calcite.util.Util;

public final class AggregateExpandDistinctAggregatesRule
extends RelOptRule {
    public static final AggregateExpandDistinctAggregatesRule INSTANCE = new AggregateExpandDistinctAggregatesRule(LogicalAggregate.class, true, RelFactories.LOGICAL_BUILDER);
    public static final AggregateExpandDistinctAggregatesRule JOIN = new AggregateExpandDistinctAggregatesRule(LogicalAggregate.class, false, RelFactories.LOGICAL_BUILDER);
    public static final BigDecimal TWO = BigDecimal.valueOf(2L);
    public final boolean useGroupingSets;

    public AggregateExpandDistinctAggregatesRule(Class<? extends LogicalAggregate> clazz, boolean useGroupingSets, RelBuilderFactory relBuilderFactory) {
        super(AggregateExpandDistinctAggregatesRule.operand(clazz, AggregateExpandDistinctAggregatesRule.any()), relBuilderFactory, null);
        this.useGroupingSets = useGroupingSets;
    }

    @Deprecated
    public AggregateExpandDistinctAggregatesRule(Class<? extends LogicalAggregate> clazz, boolean useGroupingSets, RelFactories.JoinFactory joinFactory) {
        this(clazz, useGroupingSets, RelBuilder.proto(Contexts.of((Object)joinFactory)));
    }

    @Deprecated
    public AggregateExpandDistinctAggregatesRule(Class<? extends LogicalAggregate> clazz, RelFactories.JoinFactory joinFactory) {
        this(clazz, false, RelBuilder.proto(Contexts.of((Object)joinFactory)));
    }

    @Override
    public void onMatch(RelOptRuleCall call) {
        Aggregate aggregate = (Aggregate)call.rel(0);
        if (!aggregate.containsDistinctCall()) {
            return;
        }
        int nonDistinctCount = 0;
        LinkedHashSet<Pair<List<Integer>, Integer>> argLists = new LinkedHashSet<Pair<List<Integer>, Integer>>();
        for (AggregateCall aggCall : aggregate.getAggCallList()) {
            if (!aggCall.isDistinct()) {
                ++nonDistinctCount;
                continue;
            }
            argLists.add(Pair.of(aggCall.getArgList(), aggCall.filterArg));
        }
        Preconditions.checkState((argLists.size() > 0 ? 1 : 0) != 0, (Object)"containsDistinctCall lied");
        if (nonDistinctCount == 0 && argLists.size() == 1) {
            Pair pair = (Pair)Iterables.getOnlyElement(argLists);
            RelBuilder relBuilder = call.builder();
            this.convertMonopole(relBuilder, aggregate, (List)pair.left, (Integer)pair.right);
            call.transformTo(relBuilder.build());
            return;
        }
        if (this.useGroupingSets) {
            this.rewriteUsingGroupingSets(call, aggregate, argLists);
            return;
        }
        List<RelDataTypeField> aggFields = aggregate.getRowType().getFieldList();
        ArrayList<RexInputRef> refs = new ArrayList<RexInputRef>();
        List<String> fieldNames = aggregate.getRowType().getFieldNames();
        ImmutableBitSet groupSet = aggregate.getGroupSet();
        int groupAndIndicatorCount = aggregate.getGroupCount() + aggregate.getIndicatorCount();
        for (int i2 : Util.range(groupAndIndicatorCount)) {
            refs.add(RexInputRef.of(i2, aggFields));
        }
        ArrayList<AggregateCall> newAggCallList = new ArrayList<AggregateCall>();
        int i = -1;
        for (AggregateCall aggCall : aggregate.getAggCallList()) {
            ++i;
            if (aggCall.isDistinct()) {
                refs.add(null);
                continue;
            }
            refs.add(new RexInputRef(groupAndIndicatorCount + newAggCallList.size(), aggFields.get(groupAndIndicatorCount + i).getType()));
            newAggCallList.add(aggCall);
        }
        RelBuilder relBuilder = call.builder();
        relBuilder.push(aggregate.getInput());
        int n = 0;
        if (!newAggCallList.isEmpty()) {
            RelBuilder.GroupKey groupKey = relBuilder.groupKey(groupSet, aggregate.indicator, aggregate.getGroupSets());
            relBuilder.aggregate(groupKey, (List<AggregateCall>)newAggCallList);
            ++n;
        }
        for (Pair pair : argLists) {
            this.doRewrite(relBuilder, aggregate, n++, (List)pair.left, (Integer)pair.right, refs);
        }
        relBuilder.project(refs, fieldNames);
        call.transformTo(relBuilder.build());
    }

    private void rewriteUsingGroupingSets(RelOptRuleCall call, Aggregate aggregate, Set<Pair<List<Integer>, Integer>> argLists) {
        TreeSet<ImmutableBitSet> groupSetTreeSet = new TreeSet<ImmutableBitSet>((Comparator<ImmutableBitSet>)ImmutableBitSet.ORDERING);
        groupSetTreeSet.add(aggregate.getGroupSet());
        for (Pair<List<Integer>, Integer> argList : argLists) {
            groupSetTreeSet.add(ImmutableBitSet.of((Iterable)argList.left).setIf((Integer)argList.right, (Integer)argList.right >= 0).union(aggregate.getGroupSet()));
        }
        ImmutableList groupSets = ImmutableList.copyOf(groupSetTreeSet);
        final ImmutableBitSet fullGroupSet = ImmutableBitSet.union((Iterable<? extends ImmutableBitSet>)groupSets);
        final ArrayList<AggregateCall> distinctAggCalls = new ArrayList<AggregateCall>();
        for (Pair<AggregateCall, String> aggCall : aggregate.getNamedAggCalls()) {
            if (((AggregateCall)aggCall.left).isDistinct()) continue;
            distinctAggCalls.add(((AggregateCall)aggCall.left).rename((String)aggCall.right));
        }
        RelBuilder relBuilder = call.builder();
        relBuilder.push(aggregate.getInput());
        relBuilder.aggregate(relBuilder.groupKey(fullGroupSet, groupSets.size() > 1, (ImmutableList<ImmutableBitSet>)groupSets), (List<AggregateCall>)distinctAggCalls);
        RelNode distinct = relBuilder.peek();
        final int groupCount = fullGroupSet.cardinality();
        final int indicatorCount = groupSets.size() > 1 ? groupCount : 0;
        RelOptCluster cluster = aggregate.getCluster();
        final RexBuilder rexBuilder = cluster.getRexBuilder();
        RelDataTypeFactory typeFactory = cluster.getTypeFactory();
        final RelDataType booleanType = typeFactory.createTypeWithNullability(typeFactory.createSqlType(SqlTypeName.BOOLEAN), false);
        final ArrayList predicates = new ArrayList();
        HashMap<ImmutableBitSet, Integer> filters = new HashMap<ImmutableBitSet, Integer>();
        class Registrar {
            RexNode group = null;

            Registrar() {
            }

            int register(ImmutableBitSet groupSet) {
                if (this.group == null) {
                    this.group = this.makeGroup(groupCount - 1);
                }
                RexNode node = rexBuilder.makeCall((SqlOperator)SqlStdOperatorTable.EQUALS, this.group, rexBuilder.makeExactLiteral(this.toNumber(AggregateExpandDistinctAggregatesRule.remap(fullGroupSet, groupSet))));
                predicates.add(Pair.of(node, this.toString(groupSet)));
                return groupCount + indicatorCount + distinctAggCalls.size() + predicates.size() - 1;
            }

            private RexNode makeGroup(int i) {
                RexInputRef ref = rexBuilder.makeInputRef(booleanType, groupCount + i);
                RexNode kase = rexBuilder.makeCall((SqlOperator)SqlStdOperatorTable.CASE, ref, rexBuilder.makeExactLiteral(BigDecimal.ZERO), rexBuilder.makeExactLiteral(TWO.pow(i)));
                if (i == 0) {
                    return kase;
                }
                return rexBuilder.makeCall((SqlOperator)SqlStdOperatorTable.PLUS, this.makeGroup(i - 1), kase);
            }

            private BigDecimal toNumber(ImmutableBitSet bitSet) {
                BigDecimal n = BigDecimal.ZERO;
                for (int key : bitSet) {
                    n = n.add(TWO.pow(key));
                }
                return n;
            }

            private String toString(ImmutableBitSet bitSet) {
                StringBuilder buf = new StringBuilder("$i");
                for (int key : bitSet) {
                    buf.append(key).append('_');
                }
                return buf.substring(0, buf.length() - 1);
            }
        }
        Registrar registrar = new Registrar();
        for (ImmutableBitSet groupSet : groupSets) {
            filters.put(groupSet, registrar.register(groupSet));
        }
        if (!predicates.isEmpty()) {
            ArrayList<Pair<RexInputRef, String>> nodes = new ArrayList<Pair<RexInputRef, String>>();
            for (RelDataTypeField f : relBuilder.peek().getRowType().getFieldList()) {
                RexInputRef node = rexBuilder.makeInputRef(f.getType(), f.getIndex());
                nodes.add(Pair.of(node, f.getName()));
            }
            nodes.addAll(predicates);
            relBuilder.project(Pair.left(nodes), Pair.right(nodes));
        }
        int x = groupCount + indicatorCount;
        ArrayList<AggregateCall> newCalls = new ArrayList<AggregateCall>();
        for (AggregateCall aggCall : aggregate.getAggCallList()) {
            int newFilterArg;
            ImmutableIntList newArgList;
            SqlAggFunction aggregation;
            if (!aggCall.isDistinct()) {
                aggregation = SqlStdOperatorTable.MIN;
                newArgList = ImmutableIntList.of(x++);
                newFilterArg = (Integer)filters.get(aggregate.getGroupSet());
            } else {
                aggregation = aggCall.getAggregation();
                newArgList = AggregateExpandDistinctAggregatesRule.remap(fullGroupSet, aggCall.getArgList());
                newFilterArg = (Integer)filters.get(ImmutableBitSet.of(aggCall.getArgList()).setIf(aggCall.filterArg, aggCall.filterArg >= 0).union(aggregate.getGroupSet()));
            }
            AggregateCall newCall = AggregateCall.create(aggregation, false, newArgList, newFilterArg, aggregate.getGroupCount(), distinct, null, aggCall.name);
            newCalls.add(newCall);
        }
        relBuilder.aggregate(relBuilder.groupKey(AggregateExpandDistinctAggregatesRule.remap(fullGroupSet, aggregate.getGroupSet()), aggregate.indicator, AggregateExpandDistinctAggregatesRule.remap(fullGroupSet, aggregate.getGroupSets())), (List<AggregateCall>)newCalls);
        relBuilder.convert(aggregate.getRowType(), true);
        call.transformTo(relBuilder.build());
    }

    private static ImmutableBitSet remap(ImmutableBitSet groupSet, ImmutableBitSet bitSet) {
        ImmutableBitSet.Builder builder = ImmutableBitSet.builder();
        for (Integer bit : bitSet) {
            builder.set(AggregateExpandDistinctAggregatesRule.remap(groupSet, bit));
        }
        return builder.build();
    }

    private static ImmutableList<ImmutableBitSet> remap(ImmutableBitSet groupSet, Iterable<ImmutableBitSet> bitSets) {
        ImmutableList.Builder builder = ImmutableList.builder();
        for (ImmutableBitSet bitSet : bitSets) {
            builder.add((Object)AggregateExpandDistinctAggregatesRule.remap(groupSet, bitSet));
        }
        return builder.build();
    }

    private static List<Integer> remap(ImmutableBitSet groupSet, List<Integer> argList) {
        ImmutableIntList list = ImmutableIntList.of();
        for (int arg : argList) {
            list = list.add(AggregateExpandDistinctAggregatesRule.remap(groupSet, arg));
        }
        return list;
    }

    private static int remap(ImmutableBitSet groupSet, int arg) {
        return arg < 0 ? -1 : groupSet.indexOf(arg);
    }

    private RelBuilder convertMonopole(RelBuilder relBuilder, Aggregate aggregate, List<Integer> argList, int filterArg) {
        HashMap<Integer, Integer> sourceOf = new HashMap<Integer, Integer>();
        this.createSelectDistinct(relBuilder, aggregate, argList, filterArg, sourceOf);
        ArrayList newAggCalls = Lists.newArrayList(aggregate.getAggCallList());
        AggregateExpandDistinctAggregatesRule.rewriteAggCalls(newAggCalls, argList, sourceOf);
        int cardinality = aggregate.getGroupSet().cardinality();
        relBuilder.push(aggregate.copy(aggregate.getTraitSet(), relBuilder.build(), aggregate.indicator, ImmutableBitSet.range(cardinality), null, newAggCalls));
        return relBuilder;
    }

    private void doRewrite(RelBuilder relBuilder, Aggregate aggregate, int n, List<Integer> argList, int filterArg, List<RexInputRef> refs) {
        RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
        List<RelDataTypeField> leftFields = n == 0 ? null : relBuilder.peek().getRowType().getFieldList();
        HashMap<Integer, Integer> sourceOf = new HashMap<Integer, Integer>();
        this.createSelectDistinct(relBuilder, aggregate, argList, filterArg, sourceOf);
        ArrayList<AggregateCall> aggCallList = new ArrayList<AggregateCall>();
        List<AggregateCall> aggCalls = aggregate.getAggCallList();
        int groupAndIndicatorCount = aggregate.getGroupCount() + aggregate.getIndicatorCount();
        int i = groupAndIndicatorCount - 1;
        for (AggregateCall aggCall : aggCalls) {
            ++i;
            if (!aggCall.isDistinct() || !aggCall.getArgList().equals(argList)) continue;
            int argCount = aggCall.getArgList().size();
            ArrayList<Integer> newArgs = new ArrayList<Integer>(argCount);
            for (int j = 0; j < argCount; ++j) {
                Integer arg = aggCall.getArgList().get(j);
                newArgs.add((Integer)sourceOf.get(arg));
            }
            int newFilterArg = aggCall.filterArg >= 0 ? (Integer)sourceOf.get(aggCall.filterArg) : -1;
            AggregateCall newAggCall = AggregateCall.create(aggCall.getAggregation(), false, newArgs, newFilterArg, aggCall.getType(), aggCall.getName());
            assert (refs.get(i) == null);
            if (n == 0) {
                refs.set(i, new RexInputRef(groupAndIndicatorCount + aggCallList.size(), newAggCall.getType()));
            } else {
                refs.set(i, new RexInputRef(leftFields.size() + groupAndIndicatorCount + aggCallList.size(), newAggCall.getType()));
            }
            aggCallList.add(newAggCall);
        }
        HashMap<Integer, Integer> map = new HashMap<Integer, Integer>();
        for (Integer key : aggregate.getGroupSet()) {
            map.put(key, map.size());
        }
        ImmutableBitSet newGroupSet = aggregate.getGroupSet().permute(map);
        assert (newGroupSet.equals(ImmutableBitSet.range(aggregate.getGroupSet().cardinality())));
        ImmutableList newGroupingSets = null;
        if (aggregate.indicator) {
            newGroupingSets = ImmutableBitSet.ORDERING.immutableSortedCopy(ImmutableBitSet.permute(aggregate.getGroupSets(), map));
        }
        relBuilder.push(aggregate.copy(aggregate.getTraitSet(), relBuilder.build(), aggregate.indicator, newGroupSet, (List<ImmutableBitSet>)newGroupingSets, aggCallList));
        if (n == 0) {
            return;
        }
        List<RelDataTypeField> distinctFields = relBuilder.peek().getRowType().getFieldList();
        ArrayList conditions = Lists.newArrayList();
        for (i = 0; i < groupAndIndicatorCount; ++i) {
            conditions.add(rexBuilder.makeCall((SqlOperator)SqlStdOperatorTable.IS_NOT_DISTINCT_FROM, RexInputRef.of(i, leftFields), new RexInputRef(leftFields.size() + i, distinctFields.get(i).getType())));
        }
        relBuilder.join(JoinRelType.INNER, conditions);
    }

    private static void rewriteAggCalls(List<AggregateCall> newAggCalls, List<Integer> argList, Map<Integer, Integer> sourceOf) {
        for (int i = 0; i < newAggCalls.size(); ++i) {
            AggregateCall aggCall = newAggCalls.get(i);
            if (!aggCall.isDistinct() || !aggCall.getArgList().equals(argList)) continue;
            int argCount = aggCall.getArgList().size();
            ArrayList<Integer> newArgs = new ArrayList<Integer>(argCount);
            for (int j = 0; j < argCount; ++j) {
                Integer arg = aggCall.getArgList().get(j);
                newArgs.add(sourceOf.get(arg));
            }
            AggregateCall newAggCall = AggregateCall.create(aggCall.getAggregation(), false, newArgs, -1, aggCall.getType(), aggCall.getName());
            newAggCalls.set(i, newAggCall);
        }
    }

    private RelBuilder createSelectDistinct(RelBuilder relBuilder, Aggregate aggregate, List<Integer> argList, int filterArg, Map<Integer, Integer> sourceOf) {
        relBuilder.push(aggregate.getInput());
        ArrayList<Pair<RexNode, String>> projects = new ArrayList<Pair<RexNode, String>>();
        List<RelDataTypeField> childFields = relBuilder.peek().getRowType().getFieldList();
        for (int i : aggregate.getGroupSet()) {
            sourceOf.put(i, projects.size());
            projects.add(RexInputRef.of2(i, childFields));
        }
        for (Integer arg : argList) {
            if (filterArg >= 0) {
                RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
                RexInputRef filterRef = RexInputRef.of(filterArg, childFields);
                Pair<RexNode, String> argRef = RexInputRef.of2(arg, childFields);
                RexNode condition = rexBuilder.makeCall((SqlOperator)SqlStdOperatorTable.CASE, filterRef, (RexNode)argRef.left, rexBuilder.ensureType(((RexNode)argRef.left).getType(), rexBuilder.constantNull(), true));
                sourceOf.put(arg, projects.size());
                projects.add(Pair.of(condition, "i$" + (String)argRef.right));
                continue;
            }
            if (sourceOf.get(arg) != null) continue;
            sourceOf.put(arg, projects.size());
            projects.add(RexInputRef.of2(arg, childFields));
        }
        relBuilder.project(Pair.left(projects), Pair.right(projects));
        relBuilder.push(aggregate.copy(aggregate.getTraitSet(), relBuilder.build(), false, ImmutableBitSet.range(projects.size()), null, (List<AggregateCall>)ImmutableList.of()));
        return relBuilder;
    }
}

