/*
 * Decompiled with CFR 0.152.
 */
package org.apache.calcite.rel.rules;

import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.rel.logical.LogicalProject;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLocalRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexProgram;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.Pair;
import org.apache.calcite.util.Permutation;
import org.apache.calcite.util.mapping.Mapping;
import org.apache.calcite.util.mapping.MappingType;

/**
 * Planner rule that removes constant keys from the GROUP BY list of a
 * {@link LogicalAggregate} whose input is a {@link LogicalProject}.
 *
 * <p>A grouping key is treated as constant when the project expression that
 * produces it is constant according to {@link RexProgram#isConstant}. The rule
 * re-creates the aggregate over the remaining (non-constant) keys and puts a
 * projection on top that re-emits each constant in its original position, so
 * the row type's field order and names are preserved.
 *
 * <p>At least one grouping key is always kept. An aggregate with exactly one
 * key is left untouched, and if every key is constant the first key stays in
 * the GROUP BY list.
 *
 * <p>Only aggregates accepted by {@link Aggregate#IS_SIMPLE} are matched.
 */
public class AggregateProjectPullUpConstantsRule
extends RelOptRule {
    /** The singleton instance of this rule. */
    public static final AggregateProjectPullUpConstantsRule INSTANCE = new AggregateProjectPullUpConstantsRule();

    private AggregateProjectPullUpConstantsRule() {
        super(operand(LogicalAggregate.class, null, Aggregate.IS_SIMPLE,
                operand(LogicalProject.class, any())));
    }

    /**
     * Rewrites the matched aggregate so it no longer groups on constant keys.
     *
     * @param call rule call whose rel(0) is the {@link LogicalAggregate} and
     *             rel(1) is its {@link LogicalProject} input
     */
    @Override
    public void onMatch(RelOptRuleCall call) {
        final LogicalAggregate aggregate = (LogicalAggregate) call.rel(0);
        final LogicalProject input = (LogicalProject) call.rel(1);

        final int groupCount = aggregate.getGroupCount();
        if (groupCount == 1) {
            // Dropping the only key would turn a non-empty GROUP BY into an
            // empty one, which this rule never does (presumably because an
            // empty GROUP BY yields a row even over empty input).
            return;
        }

        // Analyze the project's expressions so constant ones can be detected.
        final RexProgram program = RexProgram.create(
                input.getInput().getRowType(),
                input.getProjects(),
                null,
                input.getRowType(),
                input.getCluster().getRexBuilder());

        // Grouping keys are ordinals of the project's output. They are listed
        // in ascending order but need not be contiguous or start at 0, so the
        // i-th output field of the aggregate corresponds to groupKeys.get(i).
        final List<Integer> groupKeys = aggregate.getGroupSet().toList();

        // Constant grouping key -> the constant expression the project computes.
        final Map<Integer, RexNode> constants = new HashMap<>();
        for (int key : groupKeys) {
            final RexLocalRef ref = program.getProjectList().get(key);
            if (program.isConstant(ref)) {
                constants.put(key, program.gatherExpr(ref));
            }
        }
        if (constants.isEmpty()) {
            // No constant keys: nothing to pull up.
            return;
        }
        if (constants.size() == groupCount) {
            // Every key is constant; keep the first one as a real grouping key
            // so that the GROUP BY list stays non-empty.
            constants.remove(groupKeys.get(0));
        }

        // Group directly on the surviving keys. The aggregate's input is left
        // unchanged, so argument and filter ordinals of the aggregate calls
        // still refer to the same fields and need no remapping.
        ImmutableBitSet newGroupSet = aggregate.getGroupSet();
        for (int key : constants.keySet()) {
            newGroupSet = newGroupSet.clear(key);
        }
        final int newGroupCount = newGroupSet.cardinality();

        final RelBuilder relBuilder = call.builder();
        relBuilder.push(input);
        final List<AggregateCall> newAggCalls = new ArrayList<>();
        for (AggregateCall aggCall : aggregate.getAggCallList()) {
            newAggCalls.add(aggCall.adaptTo(input, aggCall.getArgList(),
                    aggCall.filterArg, groupCount, newGroupCount));
        }
        relBuilder.aggregate(
                relBuilder.groupKey(newGroupSet, false, (ImmutableList<ImmutableBitSet>) null),
                newAggCalls);

        // Project back to the original aggregate's field order and names:
        // constant keys are re-materialized, surviving keys are read from the
        // new aggregate in order, and aggregate results shift left by the
        // number of removed keys.
        final List<Pair<RexNode, String>> projects = new ArrayList<>();
        int source = 0;
        for (RelDataTypeField field : aggregate.getRowType().getFieldList()) {
            final int i = field.getIndex();
            final RexNode expr;
            if (i >= groupCount) {
                expr = relBuilder.field(i - constants.size());
            } else if (constants.containsKey(groupKeys.get(i))) {
                expr = constants.get(groupKeys.get(i));
            } else {
                expr = relBuilder.field(source);
                ++source;
            }
            projects.add(Pair.of(expr, field.getName()));
        }
        relBuilder.project(Pair.left(projects), Pair.right(projects));
        call.transformTo(relBuilder.build());
    }
}

